
Commit 5df4f7b: switch to torch TestCase for built-in dataset (#822)
1 parent: de416d7

File tree: 6 files changed, +72 / -77 lines


test/common/torchtext_test_case.py (1 addition, 1 deletion)

@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-from unittest import TestCase
+from torch.testing._internal.common_utils import TestCase
 import json
 import logging
 import os
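
Why this one-line swap is enough (a note added here, not part of the commit): the test files below pick up the new base class through TorchtextTestCase, and torch.testing._internal.common_utils.TestCase overrides assertEqual so that tensor arguments are compared element-wise, with a small default tolerance for floating-point values, which is roughly what the removed assert_allclose calls did by hand. A minimal, self-contained sketch of that behaviour (class and test names are hypothetical):

import torch
from torch.testing._internal.common_utils import TestCase, run_tests


class TensorAssertEqualExample(TestCase):
    def test_tensors_are_compared_elementwise(self):
        a = torch.tensor([1.0, 2.0, 3.0])
        b = torch.tensor([1.0, 2.0, 3.0])
        # With plain unittest.TestCase this comparison would raise, because the
        # truth value of a multi-element comparison tensor is ambiguous;
        # torch's TestCase compares shapes and values instead.
        self.assertEqual(a, b)


if __name__ == '__main__':
    run_tests()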

test/data/test_builtin_datasets.py (28 additions, 29 deletions)

@@ -6,7 +6,6 @@
 import torchtext.data as data
 from torchtext.datasets import AG_NEWS
 import torch
-from torch.testing import assert_allclose
 from ..common.torchtext_test_case import TorchtextTestCase


@@ -99,10 +98,10 @@ def test_text_classification(self):
         ag_news_train, ag_news_test = AG_NEWS(root=datadir, ngrams=3)
         self.assertEqual(len(ag_news_train), 120000)
         self.assertEqual(len(ag_news_test), 7600)
-        assert_allclose(ag_news_train[-1][1][:10],
-                        torch.tensor([3525, 319, 4053, 34, 5407, 3607, 70, 6798, 10599, 4053]).long())
-        assert_allclose(ag_news_test[-1][1][:10],
-                        torch.tensor([2351, 758, 96, 38581, 2351, 220, 5, 396, 3, 14786]).long())
+        self.assertEqual(ag_news_train[-1][1][:10],
+                         torch.tensor([3525, 319, 4053, 34, 5407, 3607, 70, 6798, 10599, 4053]).long())
+        self.assertEqual(ag_news_test[-1][1][:10],
+                         torch.tensor([2351, 758, 96, 38581, 2351, 220, 5, 396, 3, 14786]).long())

     def test_imdb(self):
         from torchtext.experimental.datasets import IMDB
@@ -111,14 +110,14 @@ def test_imdb(self):
         train_dataset, test_dataset = IMDB()
         self.assertEqual(len(train_dataset), 25000)
         self.assertEqual(len(test_dataset), 25000)
-        assert_allclose(train_dataset[0][1][:10],
-                        torch.tensor([13, 1568, 13, 246, 35468, 43, 64, 398, 1135, 92]).long())
-        assert_allclose(train_dataset[-1][1][:10],
-                        torch.tensor([2, 71, 4555, 194, 3328, 15144, 42, 227, 148, 8]).long())
-        assert_allclose(test_dataset[0][1][:10],
-                        torch.tensor([13, 125, 1051, 5, 246, 1652, 8, 277, 66, 20]).long())
-        assert_allclose(test_dataset[-1][1][:10],
-                        torch.tensor([13, 1035, 14, 21, 28, 2, 1051, 1275, 1008, 3]).long())
+        self.assertEqual(train_dataset[0][1][:10],
+                         torch.tensor([13, 1568, 13, 246, 35468, 43, 64, 398, 1135, 92]).long())
+        self.assertEqual(train_dataset[-1][1][:10],
+                         torch.tensor([2, 71, 4555, 194, 3328, 15144, 42, 227, 148, 8]).long())
+        self.assertEqual(test_dataset[0][1][:10],
+                         torch.tensor([13, 125, 1051, 5, 246, 1652, 8, 277, 66, 20]).long())
+        self.assertEqual(test_dataset[-1][1][:10],
+                         torch.tensor([13, 1035, 14, 21, 28, 2, 1051, 1275, 1008, 3]).long())

         # Test API with a vocab input object
         old_vocab = train_dataset.get_vocab()
@@ -164,14 +163,14 @@ def test_squad1(self):
         train_dataset, dev_dataset = SQuAD1()
         self.assertEqual(len(train_dataset), 87599)
         self.assertEqual(len(dev_dataset), 10570)
-        assert_allclose(train_dataset[100]['question'],
-                        torch.tensor([7, 24, 86, 52, 2, 373, 887, 18, 12797, 11090, 1356, 2, 1788, 3273, 16]).long())
-        assert_allclose(train_dataset[100]['ans_pos'][0],
-                        torch.tensor([72, 72]).long())
-        assert_allclose(dev_dataset[100]['question'],
-                        torch.tensor([42, 27, 669, 7438, 17, 2, 1950, 3273, 17252, 389, 16]).long())
-        assert_allclose(dev_dataset[100]['ans_pos'][0],
-                        torch.tensor([45, 48]).long())
+        self.assertEqual(train_dataset[100]['question'],
+                         torch.tensor([7, 24, 86, 52, 2, 373, 887, 18, 12797, 11090, 1356, 2, 1788, 3273, 16]).long())
+        self.assertEqual(train_dataset[100]['ans_pos'][0],
+                         torch.tensor([72, 72]).long())
+        self.assertEqual(dev_dataset[100]['question'],
+                         torch.tensor([42, 27, 669, 7438, 17, 2, 1950, 3273, 17252, 389, 16]).long())
+        self.assertEqual(dev_dataset[100]['ans_pos'][0],
+                         torch.tensor([45, 48]).long())

         # Test API with a vocab input object
         old_vocab = train_dataset.get_vocab()
@@ -185,14 +184,14 @@ def test_squad2(self):
         train_dataset, dev_dataset = SQuAD2()
         self.assertEqual(len(train_dataset), 130319)
         self.assertEqual(len(dev_dataset), 11873)
-        assert_allclose(train_dataset[200]['question'],
-                        torch.tensor([84, 50, 1421, 12, 5439, 4569, 17, 30, 2, 15202, 4754, 1421, 16]).long())
-        assert_allclose(train_dataset[200]['ans_pos'][0],
-                        torch.tensor([9, 9]).long())
-        assert_allclose(dev_dataset[200]['question'],
-                        torch.tensor([41, 29, 2, 66, 17016, 30, 0, 1955, 16]).long())
-        assert_allclose(dev_dataset[200]['ans_pos'][0],
-                        torch.tensor([40, 46]).long())
+        self.assertEqual(train_dataset[200]['question'],
+                         torch.tensor([84, 50, 1421, 12, 5439, 4569, 17, 30, 2, 15202, 4754, 1421, 16]).long())
+        self.assertEqual(train_dataset[200]['ans_pos'][0],
+                         torch.tensor([9, 9]).long())
+        self.assertEqual(dev_dataset[200]['question'],
+                         torch.tensor([41, 29, 2, 66, 17016, 30, 0, 1955, 16]).long())
+        self.assertEqual(dev_dataset[200]['ans_pos'][0],
+                         torch.tensor([40, 46]).long())

         # Test API with a vocab input object
         old_vocab = train_dataset.get_vocab()

test/data/test_field.py (4 additions, 5 deletions)

@@ -2,7 +2,6 @@
 from collections import Counter
 import os

-from numpy.testing import assert_allclose
 import torch
 import torchtext.data as data
 import pytest
@@ -376,9 +375,9 @@ def test_numerical_features_no_vocab(self):
         test_float_data = ["1.1", "0.1", "3.91", "0.2", "10.2"]

         numericalized_int = int_field.numericalize(test_int_data)
-        assert_allclose(numericalized_int.data.numpy(), [1, 0, 1, 3, 19])
+        self.assertEqual(numericalized_int.data, [1, 0, 1, 3, 19])
         numericalized_float = float_field.numericalize(test_float_data)
-        assert_allclose(numericalized_float.data.numpy(), [1.1, 0.1, 3.91, 0.2, 10.2])
+        self.assertEqual(numericalized_float.data, [1.1, 0.1, 3.91, 0.2, 10.2])

         # Test with postprocessing applied
         int_field = data.Field(sequential=False, use_vocab=False,
@@ -396,9 +395,9 @@ def test_numerical_features_no_vocab(self):
         test_float_data = ["1.1", "0.1", "3.91", "0.2", "10.2"]

         numericalized_int = int_field.numericalize(test_int_data)
-        assert_allclose(numericalized_int.data.numpy(), [2, 1, 2, 4, 20])
+        self.assertEqual(numericalized_int.data, [2, 1, 2, 4, 20])
         numericalized_float = float_field.numericalize(test_float_data)
-        assert_allclose(numericalized_float.data.numpy(), [0.55, 0.05, 1.955, 0.1, 5.1])
+        self.assertEqual(numericalized_float.data, [0.55, 0.05, 1.955, 0.1, 5.1])

     def test_errors(self):
         # Test that passing a non-tuple (of data and length) to numericalize

test/data/test_metrics.py (11 additions, 12 deletions)

@@ -1,5 +1,4 @@
 from torchtext.data.metrics import bleu_score
-from torch.testing import assert_allclose
 from ..common.torchtext_test_case import TorchtextTestCase


@@ -19,19 +18,19 @@ def test_bleu_score(self):
         # Partial match
         candidate = [['My', 'full', 'pytorch', 'test']]
         refs = [[['My', 'full', 'pytorch', 'test', '!'], ['Different']]]
-        assert_allclose(bleu_score(candidate, refs), 0.7788007)
+        self.assertEqual(bleu_score(candidate, refs), 0.7788007)

         # Bigrams and unigrams only
         candidate = [['My', 'pytorch', 'test']]
         refs = [[['My', 'full', 'pytorch', 'test'], ['Different']]]
-        assert_allclose(bleu_score(candidate, refs, max_n=2,
-                        weights=[0.5, 0.5]), 0.5066641)
+        self.assertEqual(bleu_score(candidate, refs, max_n=2,
+                         weights=[0.5, 0.5]), 0.5066641)

         # Multi-sentence corpus
         candidate = [['My', 'full', 'pytorch', 'test'], ['Another', 'Sentence']]
         refs = [[['My', 'full', 'pytorch', 'test'], ['Completely', 'Different']],
                 [['No', 'Match']]]
-        assert_allclose(bleu_score(candidate, refs), 0.8408964)
+        self.assertEqual(bleu_score(candidate, refs), 0.8408964)

         # Empty input
         candidate = [[]]
@@ -52,13 +51,13 @@ def test_bleu_score(self):

         # The comments below give the code used to get each hardcoded bleu score
         # nltk.translate.bleu_score.corpus_bleu(refs, candidate)
-        assert_allclose(bleu_score(candidate, refs), 0.4573199)
+        self.assertEqual(bleu_score(candidate, refs), 0.4573199)
         # nltk.translate.bleu_score.corpus_bleu(refs, candidate, weights=[0.33]*3)
-        assert_allclose(bleu_score(candidate, refs, 3,
-                        weights=[0.33, 0.33, 0.33]), 0.4901113)
+        self.assertEqual(bleu_score(candidate, refs, 3,
+                         weights=[0.33, 0.33, 0.33]), 0.4901113)
         # nltk.translate.bleu_score.corpus_bleu(refs, candidate, weights=[0.5]*2)
-        assert_allclose(bleu_score(candidate, refs, 2,
-                        weights=[0.5, 0.5]), 0.5119535)
+        self.assertEqual(bleu_score(candidate, refs, 2,
+                         weights=[0.5, 0.5]), 0.5119535)
         # nltk.translate.bleu_score.corpus_bleu(refs, candidate, weights=[1])
-        assert_allclose(bleu_score(candidate, refs, 1,
-                        weights=[1]), 0.5515605)
+        self.assertEqual(bleu_score(candidate, refs, 1,
+                         weights=[1]), 0.5515605)
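
These bleu_score checks compare plain Python floats against constants rounded to seven decimal places, so they rely on torch's TestCase.assertEqual treating numeric scalars approximately (within a small default precision) rather than requiring exact equality the way unittest does. A hedged illustration of that assumption (the "computed" value below is made up for the example):

from torch.testing._internal.common_utils import TestCase


class ScalarToleranceExample(TestCase):
    def test_rounded_float_matches(self):
        computed = 0.7788007830714049  # differs from the constant only past the 7th decimal
        # Passes under torch's TestCase because the difference is far below
        # the default precision; plain unittest.assertEqual would fail here.
        self.assertEqual(computed, 0.7788007)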

test/test_build.py (27 additions, 28 deletions)

@@ -3,7 +3,6 @@
 import os
 from collections import Counter

-import numpy as np
 import torch
 import torchtext.data

@@ -130,16 +129,16 @@ def test_vectors_get_vecs(self):
         self.assertEqual(vec.vectors.shape[0], len(vec))

         tokens = ['chip', 'baby', 'Beautiful']
-        token_vecs = vec.get_vecs_by_tokens(tokens).numpy()
+        token_vecs = vec.get_vecs_by_tokens(tokens)
         self.assertEqual(token_vecs.shape[0], len(tokens))
         self.assertEqual(token_vecs.shape[1], vec.dim)
-        torch.testing.assert_allclose(vec[tokens[0]].numpy(), token_vecs[0])
-        torch.testing.assert_allclose(vec[tokens[1]].numpy(), token_vecs[1])
-        torch.testing.assert_allclose(vec['<unk>'].numpy(), token_vecs[2])
+        self.assertEqual(vec[tokens[0]], token_vecs[0])
+        self.assertEqual(vec[tokens[1]], token_vecs[1])
+        self.assertEqual(vec['<unk>'], token_vecs[2])

-        token_one_vec = vec.get_vecs_by_tokens(tokens[0], lower_case_backup=True).numpy()
+        token_one_vec = vec.get_vecs_by_tokens(tokens[0], lower_case_backup=True)
         self.assertEqual(token_one_vec.shape[0], vec.dim)
-        torch.testing.assert_allclose(vec[tokens[0].lower()].numpy(), token_one_vec)
+        self.assertEqual(vec[tokens[0].lower()], token_one_vec)

     def test_download_charngram_vectors(self):
         c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2})
@@ -157,7 +156,7 @@ def test_download_charngram_vectors(self):
         expected_stoi = {x: index for index, x in enumerate(expected_itos)}
         self.assertEqual(v.itos, expected_itos)
         self.assertEqual(dict(v.stoi), expected_stoi)
-        vectors = v.vectors.numpy()
+        vectors = v.vectors

         # The first 5 entries in each vector.
         expected_charngram = {
@@ -167,11 +166,11 @@ def test_download_charngram_vectors(self):
         }

         for word in expected_charngram:
-            torch.testing.assert_allclose(
+            self.assertEqual(
                 vectors[v.stoi[word], :5], expected_charngram[word])

-        torch.testing.assert_allclose(vectors[v.stoi['<unk>']], np.zeros(100))
-        torch.testing.assert_allclose(vectors[v.stoi['OOV token']], np.zeros(100))
+        self.assertEqual(vectors[v.stoi['<unk>']], torch.zeros(100))
+        self.assertEqual(vectors[v.stoi['OOV token']], torch.zeros(100))

     def test_download_custom_vectors(self):
         c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2})
@@ -187,7 +186,7 @@ def test_download_custom_vectors(self):

         self.assertEqual(v.itos, ['<unk>', '<pad>', '<bos>',
                                   'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'])
-        vectors = v.vectors.numpy()
+        vectors = v.vectors

         # The first 5 entries in each vector.
         expected_fasttext_simple_en = {
@@ -196,10 +195,10 @@ def test_download_custom_vectors(self):
         }

         for word in expected_fasttext_simple_en:
-            torch.testing.assert_allclose(
+            self.assertEqual(
                 vectors[v.stoi[word], :5], expected_fasttext_simple_en[word])

-        torch.testing.assert_allclose(vectors[v.stoi['<unk>']], np.zeros(300))
+        self.assertEqual(vectors[v.stoi['<unk>']], torch.zeros(300))

     def test_download_fasttext_vectors(self):
         c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2})
@@ -219,7 +218,7 @@ def test_download_fasttext_vectors(self):
         expected_stoi = {x: index for index, x in enumerate(expected_itos)}
         self.assertEqual(v.itos, expected_itos)
         self.assertEqual(dict(v.stoi), expected_stoi)
-        vectors = v.vectors.numpy()
+        vectors = v.vectors

         # The first 5 entries in each vector.
         expected_fasttext_simple_en = {
@@ -228,11 +227,11 @@ def test_download_fasttext_vectors(self):
         }

         for word in expected_fasttext_simple_en:
-            torch.testing.assert_allclose(
+            self.assertEqual(
                 vectors[v.stoi[word], :5], expected_fasttext_simple_en[word])

-        torch.testing.assert_allclose(vectors[v.stoi['<unk>']], np.zeros(300))
-        torch.testing.assert_allclose(vectors[v.stoi['OOV token']], np.zeros(300))
+        self.assertEqual(vectors[v.stoi['<unk>']], torch.zeros(300))
+        self.assertEqual(vectors[v.stoi['OOV token']], torch.zeros(300))

     def test_download_glove_vectors(self):
         c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2})
@@ -253,7 +252,7 @@ def test_download_glove_vectors(self):
         self.assertEqual(v.itos, expected_itos)
         self.assertEqual(dict(v.stoi), expected_stoi)

-        vectors = v.vectors.numpy()
+        vectors = v.vectors

         # The first 5 entries in each vector.
         expected_twitter = {
@@ -262,11 +261,11 @@ def test_download_glove_vectors(self):
         }

         for word in expected_twitter:
-            torch.testing.assert_allclose(
+            self.assertEqual(
                 vectors[v.stoi[word], :5], expected_twitter[word])

-        torch.testing.assert_allclose(vectors[v.stoi['<unk>']], np.zeros(25))
-        torch.testing.assert_allclose(vectors[v.stoi['OOV token']], np.zeros(25))
+        self.assertEqual(vectors[v.stoi['<unk>']], torch.zeros(25))
+        self.assertEqual(vectors[v.stoi['OOV token']], torch.zeros(25))

     def test_extend(self):
         c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2})
@@ -281,7 +280,7 @@ def test_extend(self):

         self.assertEqual(v.itos[:6], ['<unk>', '<pad>', '<bos>',
                                       'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'])
-        vectors = v.vectors.numpy()
+        vectors = v.vectors

         # The first 5 entries in each vector.
         expected_fasttext_simple_en = {
@@ -290,10 +289,10 @@ def test_extend(self):
         }

         for word in expected_fasttext_simple_en:
-            torch.testing.assert_allclose(
+            self.assertEqual(
                 vectors[v.stoi[word], :5], expected_fasttext_simple_en[word])

-        torch.testing.assert_allclose(vectors[v.stoi['<unk>']], np.zeros(300))
+        self.assertEqual(vectors[v.stoi['<unk>']], torch.zeros(300))

     def test_vectors_custom_cache(self):
         c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2})
@@ -312,7 +311,7 @@ def test_vectors_custom_cache(self):

         self.assertEqual(v.itos, ['<unk>', '<pad>', '<bos>',
                                   'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'])
-        vectors = v.vectors.numpy()
+        vectors = v.vectors

         # The first 5 entries in each vector.
         expected_fasttext_simple_en = {
@@ -321,7 +320,7 @@ def test_vectors_custom_cache(self):
         }

         for word in expected_fasttext_simple_en:
-            torch.testing.assert_allclose(
+            self.assertEqual(
                 vectors[v.stoi[word], :5], expected_fasttext_simple_en[word])

-        torch.testing.assert_allclose(vectors[v.stoi['<unk>']], np.zeros(300))
+        self.assertEqual(vectors[v.stoi['<unk>']], torch.zeros(300))
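
With a tensor-aware assertEqual the vector tests above no longer need to round-trip through NumPy: the .numpy() conversions are dropped and the expected zero vectors are built with torch.zeros instead of np.zeros. A small sketch of the resulting pattern (the lookup table here is a stand-in, not the real Vectors object from the tests):

import torch
from torch.testing._internal.common_utils import TestCase


class ZeroVectorExample(TestCase):
    def test_missing_token_maps_to_zero_vector(self):
        # Stand-in for v.vectors: a 3 x 300 table whose last row plays the
        # role of the '<unk>' embedding.
        vectors = torch.zeros(3, 300)
        self.assertEqual(vectors[-1], torch.zeros(300))  # compared element-wise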

test/test_vocab.py (1 addition, 2 deletions)

@@ -5,7 +5,6 @@


 import numpy as np
-from numpy.testing import assert_allclose
 import torch
 from torchtext import vocab

@@ -89,7 +88,7 @@ def test_vocab_set_vectors(self):
         expected_vectors = np.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0],
                                      [0.0, 0.0], [0.1, 0.2], [0.5, 0.6],
                                      [0.3, 0.4]])
-        assert_allclose(v.vectors.numpy(), expected_vectors)
+        self.assertEqual(v.vectors, expected_vectors)

     def test_errors(self):
         c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2})
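
In this last hunk a torch tensor (v.vectors) is compared directly against a NumPy array, with the old .numpy() conversion removed, so the torch TestCase is evidently expected to coerce NumPy inputs as well. A hypothetical illustration of that mixed comparison (not taken from the test suite, and resting on that assumption about assertEqual):

import numpy as np
import torch
from torch.testing._internal.common_utils import TestCase


class MixedArrayCompareExample(TestCase):
    def test_tensor_matches_numpy_array(self):
        expected = np.array([[0.1, 0.2], [0.5, 0.6]])
        actual = torch.tensor([[0.1, 0.2], [0.5, 0.6]])
        # Values are compared element-wise across the two container types.
        self.assertEqual(actual, expected)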
