
Commit 7915fef

updated transformer tutorial (#1565)
1 parent 5416309 commit 7915fef

File tree

1 file changed: +10 -15 lines changed


beginner_source/transformer_tutorial.py

Lines changed: 10 additions & 15 deletions
@@ -1,20 +1,20 @@
 """
-Sequence-to-Sequence Modeling with nn.Transformer and TorchText
+Language Modeling with nn.Transformer and TorchText
 ===============================================================
 
 This is a tutorial on how to train a sequence-to-sequence model
 that uses the
-`nn.Transformer <https://pytorch.org/docs/master/nn.html?highlight=nn%20transformer#torch.nn.Transformer>`__ module.
+`nn.Transformer <https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`__ module.
 
 PyTorch 1.2 release includes a standard transformer module based on the
 paper `Attention is All You
 Need <https://arxiv.org/pdf/1706.03762.pdf>`__. The transformer model
 has been proved to be superior in quality for many sequence-to-sequence
 problems while being more parallelizable. The ``nn.Transformer`` module
 relies entirely on an attention mechanism (another module recently
-implemented as `nn.MultiheadAttention <https://pytorch.org/docs/master/nn.html?highlight=multiheadattention#torch.nn.MultiheadAttention>`__) to draw global dependencies
+implemented as `nn.MultiheadAttention <https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html>`__) to draw global dependencies
 between input and output. The ``nn.Transformer`` module is now highly
-modularized such that a single component (like `nn.TransformerEncoder <https://pytorch.org/docs/master/nn.html?highlight=nn%20transformerencoder#torch.nn.TransformerEncoder>`__
+modularized such that a single component (like `nn.TransformerEncoder <https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoder.html>`__
 in this tutorial) can be easily adapted/composed.
 
 .. image:: ../_static/img/transformer_architecture.jpg
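
Note: for readers new to the module linked above, a minimal sketch of calling a bare ``nn.Transformer`` (the sizes and random tensors are illustrative assumptions, not part of the tutorial):

    import torch
    import torch.nn as nn

    # Bare nn.Transformer with its default sizes; inputs are (length, batch, d_model).
    model = nn.Transformer(d_model=512, nhead=8)
    src = torch.rand(10, 32, 512)  # (source length, batch, d_model)
    tgt = torch.rand(20, 32, 512)  # (target length, batch, d_model)
    out = model(src, tgt)          # -> (target length, batch, d_model)

The tutorial itself only uses the encoder side (``nn.TransformerEncoder``), which is why the title changes from sequence-to-sequence modeling to language modeling.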
@@ -35,7 +35,7 @@
 # layer first, followed by a positional encoding layer to account for the order
 # of the word (see the next paragraph for more details). The
 # ``nn.TransformerEncoder`` consists of multiple layers of
-# `nn.TransformerEncoderLayer <https://pytorch.org/docs/master/nn.html?highlight=transformerencoderlayer#torch.nn.TransformerEncoderLayer>`__. Along with the input sequence, a square
+# `nn.TransformerEncoderLayer <https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html>`__. Along with the input sequence, a square
 # attention mask is required because the self-attention layers in
 # ``nn.TransformerEncoder`` are only allowed to attend the earlier positions in
 # the sequence. For the language modeling task, any tokens on the future
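
Note: the square attention mask mentioned in this hunk can be built with ``torch.triu``; a minimal sketch (the tutorial defines an equivalent helper on its model class):

    import torch

    def generate_square_subsequent_mask(sz):
        # Position i may attend to positions <= i; future positions get -inf.
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        return mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))

    print(generate_square_subsequent_mask(3))
    # tensor([[0., -inf, -inf],
    #         [0., 0., -inf],
    #         [0., 0., 0.]])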
@@ -144,23 +144,18 @@ def forward(self, x):
 # efficient batch processing.
 #
 
-import io
 import torch
 from torchtext.datasets import WikiText2
 from torchtext.data.utils import get_tokenizer
-from collections import Counter
-from torchtext.vocab import Vocab
+from torchtext.vocab import build_vocab_from_iterator
 
 train_iter = WikiText2(split='train')
 tokenizer = get_tokenizer('basic_english')
-counter = Counter()
-for line in train_iter:
-    counter.update(tokenizer(line))
-vocab = Vocab(counter)
+vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=["<unk>"])
+vocab.set_default_index(vocab["<unk>"])
 
 def data_process(raw_text_iter):
-    data = [torch.tensor([vocab[token] for token in tokenizer(item)],
-                         dtype=torch.long) for item in raw_text_iter]
+    data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
     return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))
 
 train_iter, val_iter, test_iter = WikiText2()
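
Note: a self-contained sketch of the new vocab API used in this hunk, on a toy corpus instead of WikiText2 (assumes the 0.10-era torchtext, where ``build_vocab_from_iterator`` accepts ``specials`` and the ``Vocab`` object is callable):

    from torchtext.vocab import build_vocab_from_iterator

    toy_corpus = [["the", "cat", "sat"], ["the", "dog", "ran"]]
    vocab = build_vocab_from_iterator(toy_corpus, specials=["<unk>"])
    vocab.set_default_index(vocab["<unk>"])

    # Out-of-vocabulary tokens fall back to the <unk> index instead of raising.
    print(vocab(["the", "cat", "unicorn"]))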
@@ -225,7 +220,7 @@ def get_batch(source, i):
 # equal to the length of the vocab object.
 #
 
-ntokens = len(vocab.stoi) # the size of vocabulary
+ntokens = len(vocab) # the size of vocabulary
 emsize = 200 # embedding dimension
 nhid = 200 # the dimension of the feedforward network model in nn.TransformerEncoder
 nlayers = 2 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
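
Note: a minimal sketch of how ``emsize``, ``nhid``, and ``nlayers`` map onto the encoder stack (``nhead=2`` here is an illustrative assumption; the tutorial wraps this in its own ``TransformerModel`` class):

    import torch
    import torch.nn as nn

    emsize, nhid, nlayers, nhead = 200, 200, 2, 2
    encoder_layer = nn.TransformerEncoderLayer(d_model=emsize, nhead=nhead, dim_feedforward=nhid)
    encoder = nn.TransformerEncoder(encoder_layer, num_layers=nlayers)
    out = encoder(torch.rand(35, 20, emsize))  # (sequence length, batch, emsize)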
