Skip to content

Commit e87a1e1

Browse files
authored
Update pipeline_tutorial.py
Use new vocab API defined in pytorch/text#1289
1 parent 9259e7b commit e87a1e1

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

intermediate_source/pipeline_tutorial.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -244,7 +244,7 @@ def get_batch(source, i):
244244
# allows the Pipe to work with only two partitions and avoid any
245245
# cross-partition overheads.
246246

247-
ntokens = len(vocab.stoi) # the size of vocabulary
247+
ntokens = len(vocab.get_stoi()) # the size of vocabulary
248248
emsize = 4096 # embedding dimension
249249
nhid = 4096 # the dimension of the feedforward network model in nn.TransformerEncoder
250250
nlayers = 12 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
@@ -330,7 +330,7 @@ def train():
330330
model.train() # Turn on the train mode
331331
total_loss = 0.
332332
start_time = time.time()
333-
ntokens = len(vocab.stoi)
333+
ntokens = len(vocab.get_stoi())
334334

335335
# Train only for 50 batches to keep script execution time low.
336336
nbatches = min(50 * bptt, train_data.size(0) - 1)
@@ -366,7 +366,7 @@ def train():
366366
def evaluate(eval_model, data_source):
367367
eval_model.eval() # Turn on the evaluation mode
368368
total_loss = 0.
369-
ntokens = len(vocab.stoi)
369+
ntokens = len(vocab.get_stoi())
370370
# Evaluate only for 50 batches to keep script execution time low.
371371
nbatches = min(50 * bptt, data_source.size(0) - 1)
372372
with torch.no_grad():

0 commit comments

Comments (0)