Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit e22375e

Browse files
Import torchtext from github into fbcode on 1/11/2021
Reviewed By: cpuhrsch Differential Revision: D25873762 fbshipit-source-id: 0d34d36aeb8e7e2ce72fcf345c5e7e713ef3663c
1 parent d1686a9 commit e22375e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+530
-320
lines changed

README.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ Note: we are currently re-designing the torchtext library to make it more compat
1919

2020
pip install --pre torch torchtext -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
2121

22-
For more detail instructions, please refer to `Install PyTorch <https://pytorch.org/get-started/locally/>`_. It should be noted that the new building blocks are still under development, and the APIs have not been solidified.
22+
For more detailed instructions, please refer to `Install PyTorch <https://pytorch.org/get-started/locally/>`_. It should be noted that the new building blocks are still under development, and the APIs have not been solidified.
2323

2424
Installation
2525
============
@@ -81,7 +81,7 @@ To build torchtext from source, you need ``git``, ``CMake`` and C++11 compiler s
8181
**Note**
8282

8383
When building from source, make sure that you have the same C++ compiler as the one used to build PyTorch. A simple way is to build PyTorch from source and use the same environment to build torchtext.
84-
If you are using nightly build of PyTorch, checkout the environment it was built `here (conda) <https://github.com/pytorch/builder/tree/master/conda>`_ and `here (pip) <https://github.com/pytorch/builder/tree/master/manywheel>`_.
84+
If you are using the nightly build of PyTorch, check out the environment it was built with `conda (here) <https://github.com/pytorch/builder/tree/master/conda>`_ and `pip (here) <https://github.com/pytorch/builder/tree/master/manywheel>`_.
8585

8686
Documentation
8787
=============

examples/text_classification/iterable_train.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def train_and_valid(lr_, num_epoch, train_data_, valid_data_):
6060
r"""
6161
Here we use SGD optimizer to train the model.
6262
63-
Arguments:
63+
Args:
6464
lr_: learning rate
6565
num_epoch: the number of epochs for training the model
6666
train_data_: the data used to train the model
@@ -108,7 +108,7 @@ def train_and_valid(lr_, num_epoch, train_data_, valid_data_):
108108

109109
def test(data_):
110110
r"""
111-
Arguments:
111+
Args:
112112
data_: the data used to test the model
113113
"""
114114
data = DataLoader(
@@ -137,7 +137,7 @@ def get_csv_iterator(data_path, ngrams, vocab, start=0, num_lines=None):
137137
Generate an iterator to read CSV file.
138138
The yield values are an integer for the label and a tensor for the text part.
139139
140-
Arguments:
140+
Args:
141141
data_path: a path for the data file.
142142
ngrams: the number used for ngrams.
143143
vocab: a vocab object saving the string-to-index information
@@ -171,7 +171,7 @@ class Dataset(torch.utils.data.IterableDataset):
171171
An iterable dataset to save the data. This dataset supports multi-processing
172172
to load the data.
173173
174-
Arguments:
174+
Args:
175175
iterator: the iterator to read data.
176176
num_lines: the number of lines read by the individual iterator.
177177
"""

examples/text_classification/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def init_weights(self):
3131

3232
def forward(self, text, offsets):
3333
r"""
34-
Arguments:
34+
Args:
3535
text: 1-D tensor representing a bag of text tensors
3636
offsets: a list of offsets to delimit the 1-D text tensor
3737
into the individual sequences.

examples/text_classification/predict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def predict(text, model, dictionary, ngrams):
1111
The input text is numericalized with the vocab and then sent to
1212
the model for inference.
1313
14-
Arguments:
14+
Args:
1515
text: a sample text string
1616
model: the trained model
1717
dictionary: a vocab object for the information of string-to-index

examples/text_classification/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def train_and_valid(lr_, sub_train_, sub_valid_):
5656
We use a SGD optimizer to train the model here and the learning rate
5757
decreases linearly with the progress of the training process.
5858
59-
Arguments:
59+
Args:
6060
lr_: learning rate
6161
sub_train_: the data used to train the model
6262
sub_valid_: the data used for validation
@@ -94,7 +94,7 @@ def train_and_valid(lr_, sub_train_, sub_valid_):
9494

9595
def test(data_):
9696
r"""
97-
Arguments:
97+
Args:
9898
data_: the data used to test the model
9999
"""
100100
data = DataLoader(data_, batch_size=batch_size, collate_fn=generate_batch)

packaging/torchtext/meta.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ test:
3838

3939
requires:
4040
- pytest
41+
- cpuonly
4142

4243
about:
4344
home: https://github.com/pytorch/text

test/data/test_batch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_batch_iter(self):
3737
batch = next(iter(itr))
3838
(x1, x2), y = batch
3939
x = (x1, x2)[fld_order.index("float")]
40-
self.assertEquals(y.data[0], 1)
41-
self.assertEquals(y.data[1], 12)
40+
self.assertEqual(y.data[0], 1)
41+
self.assertEqual(y.data[1], 12)
4242
self.assertAlmostEqual(x.data[0], 0.1, places=4)
4343
self.assertAlmostEqual(x.data[1], 0.5, places=4)

test/data/test_builtin_datasets.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,40 @@ def test_imdb(self):
162162
self._helper_test_func(len(test_iter), 25000, next(iter(test_iter))[1][:25], 'I love sci-fi and am will')
163163
del train_iter, test_iter
164164

165+
def test_iwslt(self):
166+
from torchtext.experimental.datasets import IWSLT
167+
168+
train_dataset, valid_dataset, test_dataset = IWSLT()
169+
170+
self.assertEqual(len(train_dataset), 196884)
171+
self.assertEqual(len(valid_dataset), 993)
172+
self.assertEqual(len(test_dataset), 1305)
173+
174+
de_vocab, en_vocab = train_dataset.get_vocab()
175+
176+
def assert_nth_pair_is_equal(n, expected_sentence_pair):
177+
de_sentence = [de_vocab.itos[index] for index in train_dataset[n][0]]
178+
en_sentence = [en_vocab.itos[index] for index in train_dataset[n][1]]
179+
expected_de_sentence, expected_en_sentence = expected_sentence_pair
180+
181+
self.assertEqual(de_sentence, expected_de_sentence)
182+
self.assertEqual(en_sentence, expected_en_sentence)
183+
184+
assert_nth_pair_is_equal(0, (['David', 'Gallo', ':', 'Das', 'ist', 'Bill', 'Lange',
185+
'.', 'Ich', 'bin', 'Dave', 'Gallo', '.', '\n'],
186+
['David', 'Gallo', ':', 'This', 'is', 'Bill', 'Lange',
187+
'.', 'I', "'m", 'Dave', 'Gallo', '.', '\n']))
188+
assert_nth_pair_is_equal(10, (['Die', 'meisten', 'Tiere', 'leben', 'in',
189+
'den', 'Ozeanen', '.', '\n'],
190+
['Most', 'of', 'the', 'animals', 'are', 'in',
191+
'the', 'oceans', '.', '\n']))
192+
assert_nth_pair_is_equal(20, (['Es', 'ist', 'einer', 'meiner', 'Lieblinge', ',', 'weil', 'es',
193+
'alle', 'möglichen', 'Funktionsteile', 'hat', '.', '\n'],
194+
['It', "'s", 'one', 'of', 'my', 'favorites', ',', 'because', 'it', "'s",
195+
'got', 'all', 'sorts', 'of', 'working', 'parts', '.', '\n']))
196+
datafile = os.path.join(self.project_root, ".data", "2016-01.tgz")
197+
conditional_remove(datafile)
198+
165199
def test_multi30k(self):
166200
from torchtext.experimental.datasets import Multi30k
167201
# smoke test to ensure multi30k works properly

test/data/test_functional.py

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,24 @@ def test_BasicEnglishNormalize(self):
107107
self.assertEqual(eager_tokens, ref_results)
108108
self.assertEqual(experimental_eager_tokens, ref_results)
109109

110-
# test load and save
111-
save_path = os.path.join(self.test_dir, 'basic_english_normalize.pt')
112-
torch.save(basic_eng_norm.to_ivalue(), save_path)
113-
loaded_basic_eng_norm = torch.load(save_path)
110+
def test_basicEnglishNormalize_load_and_save(self):
111+
test_sample = '\'".<br />,()!?;: Basic English Normalization for a Line of Text \'".<br />,()!?;:'
112+
ref_results = ["'", '.', ',', '(', ')', '!', '?', 'basic', 'english', 'normalization',
113+
'for', 'a', 'line', 'of', 'text', "'", '.', ',', '(', ')', '!', '?']
114114

115-
loaded_eager_tokens = loaded_basic_eng_norm(test_sample)
116-
self.assertEqual(loaded_eager_tokens, ref_results)
115+
with self.subTest('pybind'):
116+
save_path = os.path.join(self.test_dir, 'ben_pybind.pt')
117+
ben = basic_english_normalize()
118+
torch.save(ben, save_path)
119+
loaded_ben = torch.load(save_path)
120+
self.assertEqual(loaded_ben(test_sample), ref_results)
121+
122+
with self.subTest('torchscript'):
123+
save_path = os.path.join(self.test_dir, 'ben_torchscrip.pt')
124+
ben = basic_english_normalize().to_ivalue()
125+
torch.save(ben, save_path)
126+
loaded_ben = torch.load(save_path)
127+
self.assertEqual(loaded_ben(test_sample), ref_results)
117128

118129
# TODO(Nayef211): remove decorator once https://github.com/pytorch/pytorch/issues/38207 is closed
119130
@unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")
@@ -147,13 +158,39 @@ def test_RegexTokenizer(self):
147158
self.assertEqual(eager_tokens, ref_results)
148159
self.assertEqual(jit_tokens, ref_results)
149160

150-
# test load and save
151-
save_path = os.path.join(self.test_dir, 'regex.pt')
152-
torch.save(r_tokenizer.to_ivalue(), save_path)
153-
loaded_r_tokenizer = torch.load(save_path)
161+
def test_load_and_save(self):
162+
test_sample = '\'".<br />,()!?;: Basic Regex Tokenization for a Line of Text \'".<br />,()!?;:'
163+
ref_results = ["'", '.', ',', '(', ')', '!', '?', 'Basic', 'Regex', 'Tokenization',
164+
'for', 'a', 'Line', 'of', 'Text', "'", '.', ',', '(', ')', '!', '?']
165+
patterns_list = [
166+
(r'\'', ' \' '),
167+
(r'\"', ''),
168+
(r'\.', ' . '),
169+
(r'<br \/>', ' '),
170+
(r',', ' , '),
171+
(r'\(', ' ( '),
172+
(r'\)', ' ) '),
173+
(r'\!', ' ! '),
174+
(r'\?', ' ? '),
175+
(r'\;', ' '),
176+
(r'\:', ' '),
177+
(r'\s+', ' ')]
154178

155-
loaded_eager_tokens = loaded_r_tokenizer(test_sample)
156-
self.assertEqual(loaded_eager_tokens, ref_results)
179+
with self.subTest('pybind'):
180+
save_path = os.path.join(self.test_dir, 'regex_pybind.pt')
181+
tokenizer = regex_tokenizer(patterns_list)
182+
torch.save(tokenizer, save_path)
183+
loaded_tokenizer = torch.load(save_path)
184+
results = loaded_tokenizer(test_sample)
185+
self.assertEqual(results, ref_results)
186+
187+
with self.subTest('torchscript'):
188+
save_path = os.path.join(self.test_dir, 'regex_torchscript.pt')
189+
tokenizer = regex_tokenizer(patterns_list).to_ivalue()
190+
torch.save(tokenizer, save_path)
191+
loaded_tokenizer = torch.load(save_path)
192+
results = loaded_tokenizer(test_sample)
193+
self.assertEqual(results, ref_results)
157194

158195
def test_custom_replace(self):
159196
custom_replace_transform = custom_replace([(r'S', 's'), (r'\s+', ' ')])

test/experimental/test_transforms.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,27 @@ def test_vector_transform(self):
5454
[-0.32423, -0.098845, -0.0073467]])
5555
self.assertEqual(vector_transform(['the', 'world'])[:, 0:3], expected_fasttext_simple_en)
5656
self.assertEqual(jit_vector_transform(['the', 'world'])[:, 0:3], expected_fasttext_simple_en)
57+
58+
def test_sentencepiece_load_and_save(self):
59+
model_path = get_asset_path('spm_example.model')
60+
input = 'SentencePiece is an unsupervised text tokenizer and detokenizer'
61+
expected = [
62+
'▁Sent', 'ence', 'P', 'ie', 'ce', '▁is',
63+
'▁an', '▁un', 'super', 'vis', 'ed', '▁text',
64+
'▁to', 'ken', 'izer', '▁and',
65+
'▁de', 'to', 'ken', 'izer',
66+
]
67+
68+
with self.subTest('pybind'):
69+
save_path = os.path.join(self.test_dir, 'spm_pybind.pt')
70+
spm = sentencepiece_tokenizer((model_path))
71+
torch.save(spm, save_path)
72+
loaded_spm = torch.load(save_path)
73+
self.assertEqual(expected, loaded_spm(input))
74+
75+
with self.subTest('torchscript'):
76+
save_path = os.path.join(self.test_dir, 'spm_torchscript.pt')
77+
spm = sentencepiece_tokenizer((model_path)).to_ivalue()
78+
torch.save(spm, save_path)
79+
loaded_spm = torch.load(save_path)
80+
self.assertEqual(expected, loaded_spm(input))

0 commit comments

Comments
 (0)