@@ -148,27 +148,24 @@ def forward(self, x):
 # efficient batch processing.
 #
 
-import io
 import torch
-from torchtext.utils import download_from_url, extract_archive
+from torchtext.datasets import WikiText2
 from torchtext.data.utils import get_tokenizer
 from torchtext.vocab import build_vocab_from_iterator
 
-url = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip'
-test_filepath, valid_filepath, train_filepath = extract_archive(download_from_url(url))
+train_iter = WikiText2(split='train')
 tokenizer = get_tokenizer('basic_english')
-vocab = build_vocab_from_iterator(map(tokenizer,
-                                      iter(io.open(train_filepath,
-                                                   encoding="utf8"))))
+vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=["<unk>"])
+vocab.set_default_index(vocab["<unk>"])
 
 def data_process(raw_text_iter):
-    data = [torch.tensor([vocab[token] for token in tokenizer(item)],
-                         dtype=torch.long) for item in raw_text_iter]
+    data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
     return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))
 
-train_data = data_process(iter(io.open(train_filepath, encoding="utf8")))
-val_data = data_process(iter(io.open(valid_filepath, encoding="utf8")))
-test_data = data_process(iter(io.open(test_filepath, encoding="utf8")))
+train_iter, val_iter, test_iter = WikiText2()
+train_data = data_process(train_iter)
+val_data = data_process(val_iter)
+test_data = data_process(test_iter)
 
 device = torch.device("cuda")
 
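Pulled out of the diff markers, the migrated data pipeline reads as the standalone sketch below. It assumes torchtext >= 0.10, where the WikiText2 dataset, the specials= argument, and Vocab.set_default_index are available:

import torch
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# Build the vocabulary from the training split; unseen tokens fall back to "<unk>".
train_iter = WikiText2(split='train')
tokenizer = get_tokenizer('basic_english')
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

def data_process(raw_text_iter):
    # vocab(tokens) maps a list of tokens to a list of indices in a single call.
    data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
    return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

# The iterators are exhausted by data_process, so fresh ones are requested here.
train_iter, val_iter, test_iter = WikiText2()
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)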
@@ -244,7 +241,7 @@ def get_batch(source, i):
 # allows the Pipe to work with only two partitions and avoid any
 # cross-partition overheads.
 
-ntokens = len(vocab.stoi) # the size of vocabulary
+ntokens = len(vocab) # the size of vocabulary
 emsize = 4096 # embedding dimension
 nhid = 4096 # the dimension of the feedforward network model in nn.TransformerEncoder
 nlayers = 12 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
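The two-partition arrangement the comment above refers to can be reproduced in isolation. The sketch below is illustrative only (not part of this commit); it assumes two CUDA devices and the torch.distributed.pipeline.sync.Pipe API, and uses made-up layer sizes rather than the tutorial's model:

import os
import torch
import torch.nn as nn
from torch.distributed import rpc
from torch.distributed.pipeline.sync import Pipe

# Pipe requires the RPC framework to be initialized, even in a single process.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
rpc.init_rpc("worker", rank=0, world_size=1)

# One nn.Sequential per device: exactly two partitions, so the only
# cross-partition traffic is the single activation handed from cuda:0 to cuda:1.
part0 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()).to("cuda:0")
part1 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()).to("cuda:1")

model = Pipe(nn.Sequential(part0, part1), chunks=8)  # 8 micro-batches per mini-batch
output = model(torch.randn(16, 1024, device="cuda:0")).local_value()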
@@ -330,7 +327,7 @@ def train():
     model.train() # Turn on the train mode
     total_loss = 0.
     start_time = time.time()
-    ntokens = len(vocab.stoi)
+    ntokens = len(vocab)
 
     # Train only for 50 batches to keep script execution time low.
     nbatches = min(50 * bptt, train_data.size(0) - 1)
@@ -366,7 +363,7 @@ def train():
 def evaluate(eval_model, data_source):
     eval_model.eval() # Turn on the evaluation mode
     total_loss = 0.
-    ntokens = len(vocab.stoi)
+    ntokens = len(vocab)
     # Evaluate only for 50 batches to keep script execution time low.
     nbatches = min(50 * bptt, data_source.size(0) - 1)
     with torch.no_grad():
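The len(vocab.stoi) -> len(vocab) substitution in this hunk and the two before it follows from the new torchtext Vocab interface. Roughly, assuming a vocab built as in the first hunk:

ntokens = len(vocab)           # new API: Vocab supports len() directly (was len(vocab.stoi))
idx = vocab["the"]             # single-token lookup (was vocab.stoi["the"])
ids = vocab(["the", "dog"])    # list-of-tokens lookup, returns a list of indices
unk = vocab["zzz_unseen"]      # unknown tokens map to the index set via set_default_index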
@@ -418,39 +415,3 @@ def evaluate(eval_model, data_source):
 print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
     test_loss, math.exp(test_loss)))
 print('=' * 89)
-
-
-######################################################################
-# Output
-# ------
-#
-
-
-######################################################################
-#.. code-block:: py
-#
-#    Total parameters in model: 1,847,087,215
-#    | epoch 1 | 10/ 50 batches | lr 5.00 | ms/batch 2387.45 | loss 42.16 | ppl 2036775646369743616.00
-#    | epoch 1 | 20/ 50 batches | lr 5.00 | ms/batch 2150.93 | loss 48.24 | ppl 891334049215401558016.00
-#    | epoch 1 | 30/ 50 batches | lr 5.00 | ms/batch 2155.23 | loss 34.66 | ppl 1125676483188404.62
-#    | epoch 1 | 40/ 50 batches | lr 5.00 | ms/batch 2158.42 | loss 38.87 | ppl 76287208340888368.00
-#    -----------------------------------------------------------------------------------------
-#    | end of epoch 1 | time: 119.65s | valid loss 2.95 | valid ppl 19.15
-#    -----------------------------------------------------------------------------------------
-#    | epoch 2 | 10/ 50 batches | lr 4.51 | ms/batch 2376.16 | loss 34.92 | ppl 1458001430957104.00
-#    | epoch 2 | 20/ 50 batches | lr 4.51 | ms/batch 2160.96 | loss 34.75 | ppl 1232463826541886.50
-#    | epoch 2 | 30/ 50 batches | lr 4.51 | ms/batch 2160.66 | loss 28.10 | ppl 1599598251136.51
-#    | epoch 2 | 40/ 50 batches | lr 4.51 | ms/batch 2160.07 | loss 20.25 | ppl 621174306.77
-#    -----------------------------------------------------------------------------------------
-#    | end of epoch 2 | time: 119.76s | valid loss 0.87 | valid ppl 2.38
-#    -----------------------------------------------------------------------------------------
-#    | epoch 3 | 10/ 50 batches | lr 4.29 | ms/batch 2376.49 | loss 13.20 | ppl 537727.23
-#    | epoch 3 | 20/ 50 batches | lr 4.29 | ms/batch 2160.12 | loss 10.98 | ppl 58548.58
-#    | epoch 3 | 30/ 50 batches | lr 4.29 | ms/batch 2160.05 | loss 12.01 | ppl 164152.79
-#    | epoch 3 | 40/ 50 batches | lr 4.29 | ms/batch 2160.03 | loss 10.63 | ppl 41348.00
-#    -----------------------------------------------------------------------------------------
-#    | end of epoch 3 | time: 119.76s | valid loss 0.78 | valid ppl 2.17
-#    -----------------------------------------------------------------------------------------
-#    =========================================================================================
-#    | End of training | test loss 0.69 | test ppl 1.99
-#    =========================================================================================