Commit 5416309

updated text sentiment tutorial (#1563)
1 parent 5b39f4b commit 5416309

File tree

1 file changed: +17 −91 lines changed
beginner_source/text_sentiment_ngrams_tutorial.py

Lines changed: 17 additions & 91 deletions
@@ -49,32 +49,35 @@
 #
 # We have revisited the very basic components of the torchtext library, including vocab, word vectors, tokenizer. Those are the basic data processing building blocks for raw text string.
 #
-# Here is an example for typical NLP data processing with tokenizer and vocabulary. The first step is to build a vocabulary with the raw training dataset. Users can have a customized vocab by setting up arguments in the constructor of the Vocab class. For example, the minimum frequency ``min_freq`` for the tokens to be included.
+# Here is an example of typical NLP data processing with the tokenizer and vocabulary. The first step is to build a vocabulary with the raw training dataset. Here we use the built-in
+# factory function ``build_vocab_from_iterator``, which accepts an iterator that yields a list or iterator of tokens. Users can also pass any special symbols to be added to the
+# vocabulary.


 from torchtext.data.utils import get_tokenizer
-from collections import Counter
-from torchtext.vocab import Vocab
+from torchtext.vocab import build_vocab_from_iterator

 tokenizer = get_tokenizer('basic_english')
 train_iter = AG_NEWS(split='train')
-counter = Counter()
-for (label, line) in train_iter:
-    counter.update(tokenizer(line))
-vocab = Vocab(counter, min_freq=1)

+def yield_tokens(data_iter):
+    for _, text in data_iter:
+        yield tokenizer(text)
+
+vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
+vocab.set_default_index(vocab["<unk>"])

 ######################################################################
 # The vocabulary block converts a list of tokens into integers.
 #
 # ::
 #
-#     [vocab[token] for token in ['here', 'is', 'an', 'example']]
-#     >>> [476, 22, 31, 5298]
+#     vocab(['here', 'is', 'an', 'example'])
+#     >>> [475, 21, 30, 5286]
 #
 # Prepare the text processing pipeline with the tokenizer and vocabulary. The text and label pipelines will be used to process the raw data strings from the dataset iterators.

-text_pipeline = lambda x: [vocab[token] for token in tokenizer(x)]
+text_pipeline = lambda x: vocab(tokenizer(x))
 label_pipeline = lambda x: int(x) - 1
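For readers skimming the diff, below is a minimal, self-contained sketch of the updated vocabulary and pipeline flow. It is not part of the committed file; it assumes a torchtext release where ``build_vocab_from_iterator`` and callable ``Vocab`` objects are available (0.10 or later), and the printed indices are illustrative since they depend on the built vocabulary.

# Minimal sketch, assuming torchtext >= 0.10; not part of the committed file.
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import AG_NEWS
from torchtext.vocab import build_vocab_from_iterator

tokenizer = get_tokenizer('basic_english')
train_iter = AG_NEWS(split='train')

def yield_tokens(data_iter):
    # One list of tokens per raw (label, text) sample.
    for _, text in data_iter:
        yield tokenizer(text)

# "<unk>" is registered as a special symbol and used as the fallback index
# for out-of-vocabulary tokens.
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

text_pipeline = lambda x: vocab(tokenizer(x))   # raw string -> list of token ids
label_pipeline = lambda x: int(x) - 1           # 1-based label -> 0-based class index

print(text_pipeline('here is an example'))      # e.g. [475, 21, 30, 5286]
print(label_pipeline('3'))                      # 2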

@@ -246,6 +249,7 @@ def evaluate(dataloader):
 
 
 from torch.utils.data.dataset import random_split
+from torchtext.data.functional import to_map_style_dataset
 # Hyperparameters
 EPOCHS = 10 # epoch
 LR = 5  # learning rate
@@ -256,8 +260,8 @@ def evaluate(dataloader):
 scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
 total_accu = None
 train_iter, test_iter = AG_NEWS()
-train_dataset = list(train_iter)
-test_dataset = list(test_iter)
+train_dataset = to_map_style_dataset(train_iter)
+test_dataset = to_map_style_dataset(test_iter)
 num_train = int(len(train_dataset) * 0.95)
 split_train_, split_valid_ = \
     random_split(train_dataset, [num_train, len(train_dataset) - num_train])
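A note on the ``to_map_style_dataset`` change: ``random_split`` needs a dataset that supports ``len()`` and integer indexing, which the raw ``AG_NEWS`` iterators do not provide. The wrapper plays the same role as the earlier ``list(...)`` conversion but uses torchtext's own utility. A minimal sketch under that assumption (the sample counts in the final comment are illustrative):

# Minimal sketch of the map-style conversion and split, assuming torchtext >= 0.10.
from torch.utils.data.dataset import random_split
from torchtext.datasets import AG_NEWS
from torchtext.data.functional import to_map_style_dataset

train_iter, test_iter = AG_NEWS()
train_dataset = to_map_style_dataset(train_iter)   # adds len() and __getitem__
test_dataset = to_map_style_dataset(test_iter)

num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = random_split(
    train_dataset, [num_train, len(train_dataset) - num_train])

print(len(split_train_), len(split_valid_))        # e.g. 114000 6000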
@@ -285,72 +289,6 @@ def evaluate(dataloader):
 print('-' * 59)
 
 
-######################################################################
-# Running the model on GPU with the following printout:
-#
-# ::
-#
-#     | epoch 1 | 500/ 1782 batches | accuracy 0.684
-#     | epoch 1 | 1000/ 1782 batches | accuracy 0.852
-#     | epoch 1 | 1500/ 1782 batches | accuracy 0.877
-#     -----------------------------------------------------------
-#     | end of epoch 1 | time: 8.33s | valid accuracy 0.867
-#     -----------------------------------------------------------
-#     | epoch 2 | 500/ 1782 batches | accuracy 0.895
-#     | epoch 2 | 1000/ 1782 batches | accuracy 0.900
-#     | epoch 2 | 1500/ 1782 batches | accuracy 0.903
-#     -----------------------------------------------------------
-#     | end of epoch 2 | time: 8.18s | valid accuracy 0.890
-#     -----------------------------------------------------------
-#     | epoch 3 | 500/ 1782 batches | accuracy 0.914
-#     | epoch 3 | 1000/ 1782 batches | accuracy 0.914
-#     | epoch 3 | 1500/ 1782 batches | accuracy 0.916
-#     -----------------------------------------------------------
-#     | end of epoch 3 | time: 8.20s | valid accuracy 0.897
-#     -----------------------------------------------------------
-#     | epoch 4 | 500/ 1782 batches | accuracy 0.926
-#     | epoch 4 | 1000/ 1782 batches | accuracy 0.924
-#     | epoch 4 | 1500/ 1782 batches | accuracy 0.921
-#     -----------------------------------------------------------
-#     | end of epoch 4 | time: 8.18s | valid accuracy 0.895
-#     -----------------------------------------------------------
-#     | epoch 5 | 500/ 1782 batches | accuracy 0.938
-#     | epoch 5 | 1000/ 1782 batches | accuracy 0.935
-#     | epoch 5 | 1500/ 1782 batches | accuracy 0.937
-#     -----------------------------------------------------------
-#     | end of epoch 5 | time: 8.16s | valid accuracy 0.902
-#     -----------------------------------------------------------
-#     | epoch 6 | 500/ 1782 batches | accuracy 0.939
-#     | epoch 6 | 1000/ 1782 batches | accuracy 0.939
-#     | epoch 6 | 1500/ 1782 batches | accuracy 0.938
-#     -----------------------------------------------------------
-#     | end of epoch 6 | time: 8.16s | valid accuracy 0.906
-#     -----------------------------------------------------------
-#     | epoch 7 | 500/ 1782 batches | accuracy 0.941
-#     | epoch 7 | 1000/ 1782 batches | accuracy 0.939
-#     | epoch 7 | 1500/ 1782 batches | accuracy 0.939
-#     -----------------------------------------------------------
-#     | end of epoch 7 | time: 8.19s | valid accuracy 0.903
-#     -----------------------------------------------------------
-#     | epoch 8 | 500/ 1782 batches | accuracy 0.942
-#     | epoch 8 | 1000/ 1782 batches | accuracy 0.941
-#     | epoch 8 | 1500/ 1782 batches | accuracy 0.942
-#     -----------------------------------------------------------
-#     | end of epoch 8 | time: 8.16s | valid accuracy 0.904
-#     -----------------------------------------------------------
-#     | epoch 9 | 500/ 1782 batches | accuracy 0.942
-#     | epoch 9 | 1000/ 1782 batches | accuracy 0.941
-#     | epoch 9 | 1500/ 1782 batches | accuracy 0.942
-#     -----------------------------------------------------------
-#     | end of epoch 9 | time: 8.16s | valid accuracy 0.904
-#     -----------------------------------------------------------
-#     | epoch 10 | 500/ 1782 batches | accuracy 0.940
-#     | epoch 10 | 1000/ 1782 batches | accuracy 0.942
-#     | epoch 10 | 1500/ 1782 batches | accuracy 0.942
-#     -----------------------------------------------------------
-#     | end of epoch 10 | time: 8.15s | valid accuracy 0.904
-#     -----------------------------------------------------------
-
 
 
 ######################################################################
 # Evaluate the model with test dataset
@@ -366,12 +304,7 @@ def evaluate(dataloader):
 accu_test = evaluate(test_dataloader)
 print('test accuracy {:8.3f}'.format(accu_test))
 
-################################################
-#
-# ::
-#
-#     test accuracy 0.906
-#
+
 
 
 ######################################################################
@@ -409,10 +342,3 @@ def predict(text, text_pipeline):
 
 print("This is a %s news" %ag_news_label[predict(ex_text_str, text_pipeline)])
 
-
-################################################
-#
-# ::
-#
-#     This is a Sports news
-#
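The removed block above was only a hard-coded copy of the prediction output. For context, here is a rough sketch of how the ``predict`` helper defined earlier in the tutorial is typically wired up; ``model``, ``vocab``, and ``text_pipeline`` are assumed to come from the preceding sections, the example string is illustrative, and the helper body is a plausible reconstruction rather than a quote of the tutorial's code.

# Rough sketch of the prediction step; model, vocab, and text_pipeline are assumed
# to be the trained objects from earlier in the tutorial.
import torch

ag_news_label = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tec"}

def predict(text, text_pipeline):
    # Map a raw string to a 1-based AG_NEWS class id with the trained model.
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text))
        output = model(text, torch.tensor([0]))  # single sample, offset 0
        return output.argmax(1).item() + 1

ex_text_str = "Golf leaderboard tightens as the final round begins."  # any raw news text
print("This is a %s news" % ag_news_label[predict(ex_text_str, text_pipeline)])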

Comments (0)