@@ -169,26 +169,24 @@ def run_worker(rank, world_size):
     def print_with_rank(msg):
         print('[RANK {}]: {}'.format(rank, msg))

-    import io
-    from torchtext.utils import download_from_url, extract_archive
+    from torchtext.datasets import WikiText2
     from torchtext.data.utils import get_tokenizer
     from torchtext.vocab import build_vocab_from_iterator

-    url = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip'
-    test_filepath, valid_filepath, train_filepath = extract_archive(download_from_url(url, root=".data{}".format(rank)))
+    train_iter = WikiText2(split='train')
     tokenizer = get_tokenizer('basic_english')
-    vocab = build_vocab_from_iterator(map(tokenizer,
-                                          iter(io.open(train_filepath,
-                                                       encoding="utf8"))))
+    vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=["<unk>"])
+    vocab.set_default_index(vocab["<unk>"])

     def data_process(raw_text_iter):
-        data = [torch.tensor([vocab[token] for token in tokenizer(item)],
-                             dtype=torch.long) for item in raw_text_iter]
+        data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
         return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

-    train_data = data_process(iter(io.open(train_filepath, encoding="utf8")))
-    val_data = data_process(iter(io.open(valid_filepath, encoding="utf8")))
-    test_data = data_process(iter(io.open(test_filepath, encoding="utf8")))
+    train_iter, val_iter, test_iter = WikiText2()
+    train_data = data_process(train_iter)
+    val_data = data_process(val_iter)
+    test_data = data_process(test_iter)
+
     device = torch.device(2 * rank)

     def batchify(data, bsz, rank, world_size, is_train=False):
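For reference, the replacement data pipeline introduced by this hunk can be exercised on its own. The following is a minimal sketch that mirrors the added lines, assuming a torchtext release new enough to provide the WikiText2 split API, the specials argument of build_vocab_from_iterator, and a callable Vocab (roughly 0.10+):

    import torch
    from torchtext.datasets import WikiText2
    from torchtext.data.utils import get_tokenizer
    from torchtext.vocab import build_vocab_from_iterator

    tokenizer = get_tokenizer('basic_english')
    # Build the vocabulary from the training split; unknown tokens fall back to <unk>.
    vocab = build_vocab_from_iterator(map(tokenizer, WikiText2(split='train')), specials=["<unk>"])
    vocab.set_default_index(vocab["<unk>"])

    def data_process(raw_text_iter):
        # The Vocab object is callable and maps a list of tokens to a list of indices.
        data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long)
                for item in raw_text_iter]
        return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

    train_iter, val_iter, test_iter = WikiText2()
    train_data = data_process(train_iter)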
@@ -264,7 +262,7 @@ def get_batch(source, i):
 # another across GPUs 2 and 3. Both pipes are then replicated using DistributedDataParallel.

 # In 'run_worker'
-    ntokens = len(vocab.stoi) # the size of vocabulary
+    ntokens = len(vocab) # the size of vocabulary
     emsize = 4096 # embedding dimension
     nhid = 4096 # the dimension of the feedforward network model in nn.TransformerEncoder
     nlayers = 8 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
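The comment above describes the setup this hunk configures: each process builds one pipeline-parallel replica spanning two GPUs, and the replicas are kept in sync with DistributedDataParallel. A hedged sketch of that pattern, not the tutorial's actual transformer model code, with placeholder nn.Linear stages and a hypothetical build_replica helper, and assuming torch.distributed.rpc has already been initialized in run_worker (Pipe requires it):

    import torch.nn as nn
    from torch.distributed.pipeline.sync import Pipe
    from torch.nn.parallel import DistributedDataParallel

    def build_replica(rank):
        # First stage on one GPU, second stage on the neighbouring GPU,
        # matching device = torch.device(2 * rank) above.
        stage1 = nn.Linear(4096, 4096).cuda(2 * rank)
        stage2 = nn.Linear(4096, 4096).cuda(2 * rank + 1)
        # chunks controls micro-batching; checkpoint='never' avoids Pipe's
        # activation checkpointing, which does not combine with DDP.
        pipe = Pipe(nn.Sequential(stage1, stage2), chunks=8, checkpoint='never')
        return DistributedDataParallel(pipe)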
@@ -361,7 +359,7 @@ def train():
     model.train() # Turn on the train mode
     total_loss = 0.
     start_time = time.time()
-    ntokens = len(vocab.stoi)
+    ntokens = len(vocab)

     # Train only for 50 batches to keep script execution time low.
     nbatches = min(50 * bptt, train_data.size(0) - 1)
@@ -388,7 +386,7 @@ def train():
             print_with_rank('| epoch {:3d} | {:5d}/{:5d} batches | '
                   'lr {:02.2f} | ms/batch {:5.2f} | '
                   'loss {:5.2f} | ppl {:8.2f}'.format(
-                    epoch, batch, nbatches // bptt, scheduler.get_lr()[0],
+                    epoch, batch, nbatches // bptt, scheduler.get_last_lr()[0],
                     elapsed * 1000 / log_interval,
                     cur_loss, math.exp(cur_loss)))
             total_loss = 0
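Calling scheduler.get_lr() directly triggers a deprecation warning and may not reflect the rate actually applied; get_last_lr() is the public accessor for the learning rate most recently computed by the scheduler. A small self-contained sketch, assuming a StepLR schedule with gamma=0.95, which matches the lr values 5.00, 4.75, 4.51 in the log below:

    import torch

    params = [torch.zeros(1, requires_grad=True)]
    optimizer = torch.optim.SGD(params, lr=5.0)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

    print(scheduler.get_last_lr()[0])  # 5.0: the lr applied so far
    optimizer.step()
    scheduler.step()
    print(scheduler.get_last_lr()[0])  # 4.75 after one scheduler step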
@@ -397,7 +395,7 @@ def train():
 def evaluate(eval_model, data_source):
     eval_model.eval() # Turn on the evaluation mode
     total_loss = 0.
-    ntokens = len(vocab.stoi)
+    ntokens = len(vocab)
     # Evaluate only for 50 batches to keep script execution time low.
     nbatches = min(50 * bptt, data_source.size(0) - 1)
     with torch.no_grad():
@@ -455,8 +453,6 @@ def evaluate(eval_model, data_source):
 if __name__ == "__main__":
     world_size = 2
     mp.spawn(run_worker, args=(world_size, ), nprocs=world_size, join=True)
-
-
 ######################################################################
 # Output
 # ------
@@ -466,52 +462,52 @@ def evaluate(eval_model, data_source):
 ######################################################################
 #.. code-block:: py
 #
-# [RANK 1]: Total parameters in model: 1,041,453,167
-# [RANK 0]: Total parameters in model: 1,041,453,167
-# [RANK 0]: | epoch 1 | 10/ 50 batches | lr 5.00 | ms/batch 1414.18 | loss 48.70 | ppl 1406154472673147092992.00
-# [RANK 1]: | epoch 1 | 10/ 50 batches | lr 5.00 | ms/batch 1414.42 | loss 48.49 | ppl 1146707511057334927360.00
-# [RANK 0]: | epoch 1 | 20/ 50 batches | lr 5.00 | ms/batch 1260.76 | loss 42.74 | ppl 3648812398518492672.00
-# [RANK 1]: | epoch 1 | 20/ 50 batches | lr 5.00 | ms/batch 1260.76 | loss 41.51 | ppl 1064844757565813248.00
-# [RANK 0]: | epoch 1 | 30/ 50 batches | lr 5.00 | ms/batch 1246.80 | loss 41.85 | ppl 1497706388552644096.00
-# [RANK 1]: | epoch 1 | 30/ 50 batches | lr 5.00 | ms/batch 1246.80 | loss 40.46 | ppl 373830103285747072.00
-# [RANK 0]: | epoch 1 | 40/ 50 batches | lr 5.00 | ms/batch 1246.69 | loss 39.76 | ppl 185159839078666368.00
-# [RANK 1]: | epoch 1 | 40/ 50 batches | lr 5.00 | ms/batch 1246.69 | loss 39.89 | ppl 211756997625874912.00
-# [RANK 0]: -----------------------------------------------------------------------------------------
-# [RANK 0]: | end of epoch 1 | time: 69.37s | valid loss 2.92 | valid ppl 18.46
-# [RANK 0]: -----------------------------------------------------------------------------------------
+# [RANK 0]: | epoch 1 | 10/ 50 batches | lr 5.00 | ms/batch 778.97 | loss 43.31 | ppl 6432469059895903232.00
+# [RANK 1]: | epoch 1 | 10/ 50 batches | lr 5.00 | ms/batch 778.90 | loss 44.50 | ppl 21245447128217366528.00
+# [RANK 0]: | epoch 1 | 20/ 50 batches | lr 5.00 | ms/batch 699.89 | loss 44.50 | ppl 21176949187407757312.00
+# [RANK 1]: | epoch 1 | 20/ 50 batches | lr 5.00 | ms/batch 699.87 | loss 44.62 | ppl 23975861229620961280.00
+# [RANK 0]: | epoch 1 | 30/ 50 batches | lr 5.00 | ms/batch 698.86 | loss 41.62 | ppl 1193312915629888256.00
+# [RANK 1]: | epoch 1 | 30/ 50 batches | lr 5.00 | ms/batch 698.87 | loss 40.69 | ppl 471605759847546240.00
+# [RANK 0]: | epoch 1 | 40/ 50 batches | lr 5.00 | ms/batch 698.34 | loss 45.20 | ppl 42812308420836458496.00
+# [RANK 1]: | epoch 1 | 40/ 50 batches | lr 5.00 | ms/batch 698.33 | loss 45.68 | ppl 68839569686012223488.00
 # [RANK 1]: -----------------------------------------------------------------------------------------
-# [RANK 1]: | end of epoch 1 | time: 69.39s | valid loss 2.92 | valid ppl 18.46
+# [RANK 1]: | end of epoch 1 | time: 40.08s | valid loss 0.80 | valid ppl 2.22
 # [RANK 1]: -----------------------------------------------------------------------------------------
-# [RANK 1]: | epoch 2 | 10/ 50 batches | lr 4.51 | ms/batch 1373.91 | loss 39.77 | ppl 187532281612905856.00
-# [RANK 0]: | epoch 2 | 10/ 50 batches | lr 4.51 | ms/batch 1375.62 | loss 39.05 | ppl 91344349371016336.00
-# [RANK 0]: | epoch 2 | 20/ 50 batches | lr 4.51 | ms/batch 1250.33 | loss 30.62 | ppl 19917977906884.78
-# [RANK 1]: | epoch 2 | 20/ 50 batches | lr 4.51 | ms/batch 1250.33 | loss 30.48 | ppl 17250186491252.32
-# [RANK 1]: | epoch 2 | 30/ 50 batches | lr 4.51 | ms/batch 1250.73 | loss 29.14 | ppl 4534527326854.47
-# [RANK 0]: | epoch 2 | 30/ 50 batches | lr 4.51 | ms/batch 1250.73 | loss 29.43 | ppl 6035762659681.65
-# [RANK 0]: | epoch 2 | 40/ 50 batches | lr 4.51 | ms/batch 1249.54 | loss 23.11 | ppl 10869828323.89
-# [RANK 1]: | epoch 2 | 40/ 50 batches | lr 4.51 | ms/batch 1249.55 | loss 22.90 | ppl 8785318464.24
 # [RANK 0]: -----------------------------------------------------------------------------------------
-# [RANK 0]: | end of epoch 2 | time: 69.02s | valid loss 0.94 | valid ppl 2.55
+# [RANK 0]: | end of epoch 1 | time: 40.09s | valid loss 0.80 | valid ppl 2.22
+# [RANK 0]: -----------------------------------------------------------------------------------------
+# [RANK 0]: | epoch 2 | 10/ 50 batches | lr 4.75 | ms/batch 768.51 | loss 36.34 | ppl 6063529544668166.00
+# [RANK 1]: | epoch 2 | 10/ 50 batches | lr 4.75 | ms/batch 769.23 | loss 37.41 | ppl 17651211266236086.00
+# [RANK 0]: | epoch 2 | 20/ 50 batches | lr 4.75 | ms/batch 699.57 | loss 28.97 | ppl 3798441739584.11
+# [RANK 1]: | epoch 2 | 20/ 50 batches | lr 4.75 | ms/batch 699.56 | loss 29.28 | ppl 5203636967575.47
+# [RANK 0]: | epoch 2 | 30/ 50 batches | lr 4.75 | ms/batch 699.04 | loss 28.43 | ppl 2212498693571.25
+# [RANK 1]: | epoch 2 | 30/ 50 batches | lr 4.75 | ms/batch 699.05 | loss 28.33 | ppl 2015144761281.48
+# [RANK 0]: | epoch 2 | 40/ 50 batches | lr 4.75 | ms/batch 699.10 | loss 23.30 | ppl 13121380184.92
+# [RANK 1]: | epoch 2 | 40/ 50 batches | lr 4.75 | ms/batch 699.09 | loss 23.41 | ppl 14653799192.87
+# [RANK 0]: -----------------------------------------------------------------------------------------
+# [RANK 0]: | end of epoch 2 | time: 39.97s | valid loss 0.24 | valid ppl 1.27
 # [RANK 0]: -----------------------------------------------------------------------------------------
 # [RANK 1]: -----------------------------------------------------------------------------------------
-# [RANK 1]: | end of epoch 2 | time: 69.05s | valid loss 0.94 | valid ppl 2.55
+# [RANK 1]: | end of epoch 2 | time: 39.98s | valid loss 0.24 | valid ppl 1.27
 # [RANK 1]: -----------------------------------------------------------------------------------------
-# [RANK 0]: | epoch 3 | 10/ 50 batches | lr 4.29 | ms/batch 1380.66 | loss 12.98 | ppl 434052.59
-# [RANK 1]: | epoch 3 | 10/ 50 batches | lr 4.29 | ms/batch 1376.47 | loss 12.92 | ppl 410203.33
-# [RANK 1]: | epoch 3 | 20/ 50 batches | lr 4.29 | ms/batch 1250.88 | loss 9.80 | ppl 18034.58
-# [RANK 0]: | epoch 3 | 20/ 50 batches | lr 4.29 | ms/batch 1250.88 | loss 9.78 | ppl 17741.88
-# [RANK 0]: | epoch 3 | 30/ 50 batches | lr 4.29 | ms/batch 1251.89 | loss 10.37 | ppl 32016.45
-# [RANK 1]: | epoch 3 | 30/ 50 batches | lr 4.29 | ms/batch 1251.90 | loss 10.46 | ppl 34735.08
-# [RANK 0]: | epoch 3 | 40/ 50 batches | lr 4.29 | ms/batch 1250.70 | loss 10.09 | ppl 24147.61
-# [RANK 1]: | epoch 3 | 40/ 50 batches | lr 4.29 | ms/batch 1250.71 | loss 10.08 | ppl 23748.31
+# [RANK 0]: | epoch 3 | 10/ 50 batches | lr 4.51 | ms/batch 769.36 | loss 12.80 | ppl 361681.11
+# [RANK 1]: | epoch 3 | 10/ 50 batches | lr 4.51 | ms/batch 768.97 | loss 12.57 | ppl 287876.61
+# [RANK 0]: | epoch 3 | 20/ 50 batches | lr 4.51 | ms/batch 698.27 | loss 12.01 | ppl 164364.60
+# [RANK 1]: | epoch 3 | 20/ 50 batches | lr 4.51 | ms/batch 698.30 | loss 11.98 | ppl 159095.89
+# [RANK 0]: | epoch 3 | 30/ 50 batches | lr 4.51 | ms/batch 697.75 | loss 10.90 | ppl 54261.91
+# [RANK 1]: | epoch 3 | 30/ 50 batches | lr 4.51 | ms/batch 697.72 | loss 10.89 | ppl 53372.39
+# [RANK 0]: | epoch 3 | 40/ 50 batches | lr 4.51 | ms/batch 699.49 | loss 10.78 | ppl 47948.35
+# [RANK 1]: | epoch 3 | 40/ 50 batches | lr 4.51 | ms/batch 699.50 | loss 10.79 | ppl 48664.42
 # [RANK 0]: -----------------------------------------------------------------------------------------
-# [RANK 0]: | end of epoch 3 | time: 69.12s | valid loss 0.69 | valid ppl 2.00
+# [RANK 0]: | end of epoch 3 | time: 39.96s | valid loss 0.38 | valid ppl 1.46
 # [RANK 0]: -----------------------------------------------------------------------------------------
 # [RANK 1]: -----------------------------------------------------------------------------------------
-# [RANK 1]: | end of epoch 3 | time: 69.12s | valid loss 0.69 | valid ppl 2.00
+# [RANK 1]: | end of epoch 3 | time: 39.96s | valid loss 0.38 | valid ppl 1.46
 # [RANK 1]: -----------------------------------------------------------------------------------------
 # [RANK 0]: =========================================================================================
-# [RANK 0]: | End of training | test loss 0.60 | test ppl 1.83
+# [RANK 0]: | End of training | test loss 0.33 | test ppl 1.39
 # [RANK 0]: =========================================================================================
 # [RANK 1]: =========================================================================================
-# [RANK 1]: | End of training | test loss 0.60 | test ppl 1.83
+# [RANK 1]: | End of training | test loss 0.33 | test ppl 1.39
+# [RANK 1]: =========================================================================================
+#