
Commit 7431f04

Merge branch 'master' into 1.9-RC-TEST
2 parents ad99a81 + 973193b commit 7431f04

3 files changed: +322 -343 lines changed


advanced_source/ddp_pipeline.py

Lines changed: 50 additions & 54 deletions
@@ -169,26 +169,24 @@ def run_worker(rank, world_size):
     def print_with_rank(msg):
         print('[RANK {}]: {}'.format(rank, msg))
 
-    import io
-    from torchtext.utils import download_from_url, extract_archive
+    from torchtext.datasets import WikiText2
     from torchtext.data.utils import get_tokenizer
     from torchtext.vocab import build_vocab_from_iterator
 
-    url = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip'
-    test_filepath, valid_filepath, train_filepath = extract_archive(download_from_url(url, root=".data{}".format(rank)))
+    train_iter = WikiText2(split='train')
     tokenizer = get_tokenizer('basic_english')
-    vocab = build_vocab_from_iterator(map(tokenizer,
-                                          iter(io.open(train_filepath,
-                                                       encoding="utf8"))))
+    vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=["<unk>"])
+    vocab.set_default_index(vocab["<unk>"])
 
     def data_process(raw_text_iter):
-        data = [torch.tensor([vocab[token] for token in tokenizer(item)],
-                             dtype=torch.long) for item in raw_text_iter]
+        data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
         return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))
 
-    train_data = data_process(iter(io.open(train_filepath, encoding="utf8")))
-    val_data = data_process(iter(io.open(valid_filepath, encoding="utf8")))
-    test_data = data_process(iter(io.open(test_filepath, encoding="utf8")))
+    train_iter, val_iter, test_iter = WikiText2()
+    train_data = data_process(train_iter)
+    val_data = data_process(val_iter)
+    test_data = data_process(test_iter)
+
     device = torch.device(2 * rank)
 
     def batchify(data, bsz, rank, world_size, is_train=False):
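Note: the hunk above swaps the manual download_from_url/extract_archive path for torchtext's built-in WikiText2 dataset and the newer Vocab API. As a rough standalone sketch of the same pipeline (assuming torchtext 0.10+ and an environment that can download WikiText-2), it behaves roughly like:

    import torch
    from torchtext.datasets import WikiText2
    from torchtext.data.utils import get_tokenizer
    from torchtext.vocab import build_vocab_from_iterator

    tokenizer = get_tokenizer('basic_english')

    # Build the vocabulary from the training split; out-of-vocabulary tokens
    # fall back to the "<unk>" index via set_default_index.
    train_iter = WikiText2(split='train')
    vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=["<unk>"])
    vocab.set_default_index(vocab["<unk>"])

    def data_process(raw_text_iter):
        # vocab(...) maps a list of tokens to a list of indices in one call.
        data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long)
                for item in raw_text_iter]
        return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

    # The dataset iterators are single-pass, so fresh ones are requested per split.
    train_iter, val_iter, test_iter = WikiText2()
    train_data = data_process(train_iter)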
@@ -264,7 +262,7 @@ def get_batch(source, i):
 # another across GPUs 2 and 3. Both pipes are then replicated using DistributedDataParallel.
 
     # In 'run_worker'
-    ntokens = len(vocab.stoi) # the size of vocabulary
+    ntokens = len(vocab) # the size of vocabulary
     emsize = 4096 # embedding dimension
     nhid = 4096 # the dimension of the feedforward network model in nn.TransformerEncoder
     nlayers = 8 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
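Note: vocab.stoi belongs to the legacy torchtext Vocab; the new torchtext.vocab.Vocab supports len(vocab) directly, which is what this hunk (and the matching ones in train() and evaluate() below) relies on. A small illustrative sketch of the new lookup interface, using an arbitrary toy corpus:

    from torchtext.vocab import build_vocab_from_iterator

    # Toy corpus, purely for illustration.
    corpus = [["the", "cat", "sat"], ["the", "dog", "ran"]]
    vocab = build_vocab_from_iterator(corpus, specials=["<unk>"])
    vocab.set_default_index(vocab["<unk>"])

    print(len(vocab))             # vocabulary size, replaces len(vocab.stoi)
    print(vocab["cat"])           # single-token lookup
    print(vocab(["the", "dog"]))  # list of tokens -> list of indices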
@@ -361,7 +359,7 @@ def train():
         model.train() # Turn on the train mode
         total_loss = 0.
         start_time = time.time()
-        ntokens = len(vocab.stoi)
+        ntokens = len(vocab)
 
         # Train only for 50 batches to keep script execution time low.
         nbatches = min(50 * bptt, train_data.size(0) - 1)
@@ -388,7 +386,7 @@ def train():
             print_with_rank('| epoch {:3d} | {:5d}/{:5d} batches | '
                             'lr {:02.2f} | ms/batch {:5.2f} | '
                             'loss {:5.2f} | ppl {:8.2f}'.format(
-                                epoch, batch, nbatches // bptt, scheduler.get_lr()[0],
+                                epoch, batch, nbatches // bptt, scheduler.get_last_lr()[0],
                                 elapsed * 1000 / log_interval,
                                 cur_loss, math.exp(cur_loss)))
             total_loss = 0
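Note: scheduler.get_lr() is meant for internal use and warns when called outside of step() in recent PyTorch releases; get_last_lr() (available since PyTorch 1.4) returns the most recently computed learning rate, which is what the log line needs. A minimal sketch of the replacement call, with an arbitrary toy model and StepLR settings:

    import torch

    model = torch.nn.Linear(4, 4)  # arbitrary toy model
    optimizer = torch.optim.SGD(model.parameters(), lr=5.0)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

    for epoch in range(3):
        # ... forward/backward/optimizer.step() would happen here ...
        optimizer.step()
        scheduler.step()
        # get_last_lr() returns a list with the last lr computed by the scheduler.
        print(epoch, scheduler.get_last_lr()[0])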
@@ -397,7 +395,7 @@ def train():
     def evaluate(eval_model, data_source):
         eval_model.eval() # Turn on the evaluation mode
         total_loss = 0.
-        ntokens = len(vocab.stoi)
+        ntokens = len(vocab)
         # Evaluate only for 50 batches to keep script execution time low.
         nbatches = min(50 * bptt, data_source.size(0) - 1)
         with torch.no_grad():
@@ -455,8 +453,6 @@ def evaluate(eval_model, data_source):
 if __name__=="__main__":
     world_size = 2
     mp.spawn(run_worker, args=(world_size, ), nprocs=world_size, join=True)
-
-
 ######################################################################
 # Output
 # ------
@@ -466,52 +462,52 @@ def evaluate(eval_model, data_source):
 ######################################################################
 #.. code-block:: py
 #
-# [RANK 1]: Total parameters in model: 1,041,453,167
-# [RANK 0]: Total parameters in model: 1,041,453,167
-# [RANK 0]: | epoch 1 | 10/ 50 batches | lr 5.00 | ms/batch 1414.18 | loss 48.70 | ppl 1406154472673147092992.00
-# [RANK 1]: | epoch 1 | 10/ 50 batches | lr 5.00 | ms/batch 1414.42 | loss 48.49 | ppl 1146707511057334927360.00
-# [RANK 0]: | epoch 1 | 20/ 50 batches | lr 5.00 | ms/batch 1260.76 | loss 42.74 | ppl 3648812398518492672.00
-# [RANK 1]: | epoch 1 | 20/ 50 batches | lr 5.00 | ms/batch 1260.76 | loss 41.51 | ppl 1064844757565813248.00
-# [RANK 0]: | epoch 1 | 30/ 50 batches | lr 5.00 | ms/batch 1246.80 | loss 41.85 | ppl 1497706388552644096.00
-# [RANK 1]: | epoch 1 | 30/ 50 batches | lr 5.00 | ms/batch 1246.80 | loss 40.46 | ppl 373830103285747072.00
-# [RANK 0]: | epoch 1 | 40/ 50 batches | lr 5.00 | ms/batch 1246.69 | loss 39.76 | ppl 185159839078666368.00
-# [RANK 1]: | epoch 1 | 40/ 50 batches | lr 5.00 | ms/batch 1246.69 | loss 39.89 | ppl 211756997625874912.00
-# [RANK 0]: -----------------------------------------------------------------------------------------
-# [RANK 0]: | end of epoch 1 | time: 69.37s | valid loss 2.92 | valid ppl 18.46
-# [RANK 0]: -----------------------------------------------------------------------------------------
+# [RANK 0]: | epoch 1 | 10/ 50 batches | lr 5.00 | ms/batch 778.97 | loss 43.31 | ppl 6432469059895903232.00
+# [RANK 1]: | epoch 1 | 10/ 50 batches | lr 5.00 | ms/batch 778.90 | loss 44.50 | ppl 21245447128217366528.00
+# [RANK 0]: | epoch 1 | 20/ 50 batches | lr 5.00 | ms/batch 699.89 | loss 44.50 | ppl 21176949187407757312.00
+# [RANK 1]: | epoch 1 | 20/ 50 batches | lr 5.00 | ms/batch 699.87 | loss 44.62 | ppl 23975861229620961280.00
+# [RANK 0]: | epoch 1 | 30/ 50 batches | lr 5.00 | ms/batch 698.86 | loss 41.62 | ppl 1193312915629888256.00
+# [RANK 1]: | epoch 1 | 30/ 50 batches | lr 5.00 | ms/batch 698.87 | loss 40.69 | ppl 471605759847546240.00
+# [RANK 0]: | epoch 1 | 40/ 50 batches | lr 5.00 | ms/batch 698.34 | loss 45.20 | ppl 42812308420836458496.00
+# [RANK 1]: | epoch 1 | 40/ 50 batches | lr 5.00 | ms/batch 698.33 | loss 45.68 | ppl 68839569686012223488.00
 # [RANK 1]: -----------------------------------------------------------------------------------------
-# [RANK 1]: | end of epoch 1 | time: 69.39s | valid loss 2.92 | valid ppl 18.46
+# [RANK 1]: | end of epoch 1 | time: 40.08s | valid loss 0.80 | valid ppl 2.22
 # [RANK 1]: -----------------------------------------------------------------------------------------
-# [RANK 1]: | epoch 2 | 10/ 50 batches | lr 4.51 | ms/batch 1373.91 | loss 39.77 | ppl 187532281612905856.00
-# [RANK 0]: | epoch 2 | 10/ 50 batches | lr 4.51 | ms/batch 1375.62 | loss 39.05 | ppl 91344349371016336.00
-# [RANK 0]: | epoch 2 | 20/ 50 batches | lr 4.51 | ms/batch 1250.33 | loss 30.62 | ppl 19917977906884.78
-# [RANK 1]: | epoch 2 | 20/ 50 batches | lr 4.51 | ms/batch 1250.33 | loss 30.48 | ppl 17250186491252.32
-# [RANK 1]: | epoch 2 | 30/ 50 batches | lr 4.51 | ms/batch 1250.73 | loss 29.14 | ppl 4534527326854.47
-# [RANK 0]: | epoch 2 | 30/ 50 batches | lr 4.51 | ms/batch 1250.73 | loss 29.43 | ppl 6035762659681.65
-# [RANK 0]: | epoch 2 | 40/ 50 batches | lr 4.51 | ms/batch 1249.54 | loss 23.11 | ppl 10869828323.89
-# [RANK 1]: | epoch 2 | 40/ 50 batches | lr 4.51 | ms/batch 1249.55 | loss 22.90 | ppl 8785318464.24
 # [RANK 0]: -----------------------------------------------------------------------------------------
-# [RANK 0]: | end of epoch 2 | time: 69.02s | valid loss 0.94 | valid ppl 2.55
+# [RANK 0]: | end of epoch 1 | time: 40.09s | valid loss 0.80 | valid ppl 2.22
+# [RANK 0]: -----------------------------------------------------------------------------------------
+# [RANK 0]: | epoch 2 | 10/ 50 batches | lr 4.75 | ms/batch 768.51 | loss 36.34 | ppl 6063529544668166.00
+# [RANK 1]: | epoch 2 | 10/ 50 batches | lr 4.75 | ms/batch 769.23 | loss 37.41 | ppl 17651211266236086.00
+# [RANK 0]: | epoch 2 | 20/ 50 batches | lr 4.75 | ms/batch 699.57 | loss 28.97 | ppl 3798441739584.11
+# [RANK 1]: | epoch 2 | 20/ 50 batches | lr 4.75 | ms/batch 699.56 | loss 29.28 | ppl 5203636967575.47
+# [RANK 0]: | epoch 2 | 30/ 50 batches | lr 4.75 | ms/batch 699.04 | loss 28.43 | ppl 2212498693571.25
+# [RANK 1]: | epoch 2 | 30/ 50 batches | lr 4.75 | ms/batch 699.05 | loss 28.33 | ppl 2015144761281.48
+# [RANK 0]: | epoch 2 | 40/ 50 batches | lr 4.75 | ms/batch 699.10 | loss 23.30 | ppl 13121380184.92
+# [RANK 1]: | epoch 2 | 40/ 50 batches | lr 4.75 | ms/batch 699.09 | loss 23.41 | ppl 14653799192.87
+# [RANK 0]: -----------------------------------------------------------------------------------------
+# [RANK 0]: | end of epoch 2 | time: 39.97s | valid loss 0.24 | valid ppl 1.27
 # [RANK 0]: -----------------------------------------------------------------------------------------
 # [RANK 1]: -----------------------------------------------------------------------------------------
-# [RANK 1]: | end of epoch 2 | time: 69.05s | valid loss 0.94 | valid ppl 2.55
+# [RANK 1]: | end of epoch 2 | time: 39.98s | valid loss 0.24 | valid ppl 1.27
 # [RANK 1]: -----------------------------------------------------------------------------------------
-# [RANK 0]: | epoch 3 | 10/ 50 batches | lr 4.29 | ms/batch 1380.66 | loss 12.98 | ppl 434052.59
-# [RANK 1]: | epoch 3 | 10/ 50 batches | lr 4.29 | ms/batch 1376.47 | loss 12.92 | ppl 410203.33
-# [RANK 1]: | epoch 3 | 20/ 50 batches | lr 4.29 | ms/batch 1250.88 | loss 9.80 | ppl 18034.58
-# [RANK 0]: | epoch 3 | 20/ 50 batches | lr 4.29 | ms/batch 1250.88 | loss 9.78 | ppl 17741.88
-# [RANK 0]: | epoch 3 | 30/ 50 batches | lr 4.29 | ms/batch 1251.89 | loss 10.37 | ppl 32016.45
-# [RANK 1]: | epoch 3 | 30/ 50 batches | lr 4.29 | ms/batch 1251.90 | loss 10.46 | ppl 34735.08
-# [RANK 0]: | epoch 3 | 40/ 50 batches | lr 4.29 | ms/batch 1250.70 | loss 10.09 | ppl 24147.61
-# [RANK 1]: | epoch 3 | 40/ 50 batches | lr 4.29 | ms/batch 1250.71 | loss 10.08 | ppl 23748.31
+# [RANK 0]: | epoch 3 | 10/ 50 batches | lr 4.51 | ms/batch 769.36 | loss 12.80 | ppl 361681.11
+# [RANK 1]: | epoch 3 | 10/ 50 batches | lr 4.51 | ms/batch 768.97 | loss 12.57 | ppl 287876.61
+# [RANK 0]: | epoch 3 | 20/ 50 batches | lr 4.51 | ms/batch 698.27 | loss 12.01 | ppl 164364.60
+# [RANK 1]: | epoch 3 | 20/ 50 batches | lr 4.51 | ms/batch 698.30 | loss 11.98 | ppl 159095.89
+# [RANK 0]: | epoch 3 | 30/ 50 batches | lr 4.51 | ms/batch 697.75 | loss 10.90 | ppl 54261.91
+# [RANK 1]: | epoch 3 | 30/ 50 batches | lr 4.51 | ms/batch 697.72 | loss 10.89 | ppl 53372.39
+# [RANK 0]: | epoch 3 | 40/ 50 batches | lr 4.51 | ms/batch 699.49 | loss 10.78 | ppl 47948.35
+# [RANK 1]: | epoch 3 | 40/ 50 batches | lr 4.51 | ms/batch 699.50 | loss 10.79 | ppl 48664.42
 # [RANK 0]: -----------------------------------------------------------------------------------------
-# [RANK 0]: | end of epoch 3 | time: 69.12s | valid loss 0.69 | valid ppl 2.00
+# [RANK 0]: | end of epoch 3 | time: 39.96s | valid loss 0.38 | valid ppl 1.46
 # [RANK 0]: -----------------------------------------------------------------------------------------
 # [RANK 1]: -----------------------------------------------------------------------------------------
-# [RANK 1]: | end of epoch 3 | time: 69.12s | valid loss 0.69 | valid ppl 2.00
+# [RANK 1]: | end of epoch 3 | time: 39.96s | valid loss 0.38 | valid ppl 1.46
 # [RANK 1]: -----------------------------------------------------------------------------------------
 # [RANK 0]: =========================================================================================
-# [RANK 0]: | End of training | test loss 0.60 | test ppl 1.83
+# [RANK 0]: | End of training | test loss 0.33 | test ppl 1.39
 # [RANK 0]: =========================================================================================
 # [RANK 1]: =========================================================================================
-# [RANK 1]: | End of training | test loss 0.60 | test ppl 1.83
+# [RANK 1]: | End of training | test loss 0.33 | test ppl 1.39
+# [RANK 1]: =========================================================================================
+#
