66import torchtext .data as data
77from torchtext .datasets import AG_NEWS
88import torch
9- from torch .testing import assert_allclose
109from ..common .torchtext_test_case import TorchtextTestCase
1110
1211
@@ -99,10 +98,10 @@ def test_text_classification(self):
9998 ag_news_train , ag_news_test = AG_NEWS (root = datadir , ngrams = 3 )
10099 self .assertEqual (len (ag_news_train ), 120000 )
101100 self .assertEqual (len (ag_news_test ), 7600 )
102- assert_allclose (ag_news_train [- 1 ][1 ][:10 ],
103- torch .tensor ([3525 , 319 , 4053 , 34 , 5407 , 3607 , 70 , 6798 , 10599 , 4053 ]).long ())
104- assert_allclose (ag_news_test [- 1 ][1 ][:10 ],
105- torch .tensor ([2351 , 758 , 96 , 38581 , 2351 , 220 , 5 , 396 , 3 , 14786 ]).long ())
101+ self . assertEqual (ag_news_train [- 1 ][1 ][:10 ],
102+ torch .tensor ([3525 , 319 , 4053 , 34 , 5407 , 3607 , 70 , 6798 , 10599 , 4053 ]).long ())
103+ self . assertEqual (ag_news_test [- 1 ][1 ][:10 ],
104+ torch .tensor ([2351 , 758 , 96 , 38581 , 2351 , 220 , 5 , 396 , 3 , 14786 ]).long ())
106105
107106 def test_imdb (self ):
108107 from torchtext .experimental .datasets import IMDB
@@ -111,14 +110,14 @@ def test_imdb(self):
111110 train_dataset , test_dataset = IMDB ()
112111 self .assertEqual (len (train_dataset ), 25000 )
113112 self .assertEqual (len (test_dataset ), 25000 )
114- assert_allclose (train_dataset [0 ][1 ][:10 ],
115- torch .tensor ([13 , 1568 , 13 , 246 , 35468 , 43 , 64 , 398 , 1135 , 92 ]).long ())
116- assert_allclose (train_dataset [- 1 ][1 ][:10 ],
117- torch .tensor ([2 , 71 , 4555 , 194 , 3328 , 15144 , 42 , 227 , 148 , 8 ]).long ())
118- assert_allclose (test_dataset [0 ][1 ][:10 ],
119- torch .tensor ([13 , 125 , 1051 , 5 , 246 , 1652 , 8 , 277 , 66 , 20 ]).long ())
120- assert_allclose (test_dataset [- 1 ][1 ][:10 ],
121- torch .tensor ([13 , 1035 , 14 , 21 , 28 , 2 , 1051 , 1275 , 1008 , 3 ]).long ())
113+ self . assertEqual (train_dataset [0 ][1 ][:10 ],
114+ torch .tensor ([13 , 1568 , 13 , 246 , 35468 , 43 , 64 , 398 , 1135 , 92 ]).long ())
115+ self . assertEqual (train_dataset [- 1 ][1 ][:10 ],
116+ torch .tensor ([2 , 71 , 4555 , 194 , 3328 , 15144 , 42 , 227 , 148 , 8 ]).long ())
117+ self . assertEqual (test_dataset [0 ][1 ][:10 ],
118+ torch .tensor ([13 , 125 , 1051 , 5 , 246 , 1652 , 8 , 277 , 66 , 20 ]).long ())
119+ self . assertEqual (test_dataset [- 1 ][1 ][:10 ],
120+ torch .tensor ([13 , 1035 , 14 , 21 , 28 , 2 , 1051 , 1275 , 1008 , 3 ]).long ())
122121
123122 # Test API with a vocab input object
124123 old_vocab = train_dataset .get_vocab ()
@@ -164,14 +163,14 @@ def test_squad1(self):
164163 train_dataset , dev_dataset = SQuAD1 ()
165164 self .assertEqual (len (train_dataset ), 87599 )
166165 self .assertEqual (len (dev_dataset ), 10570 )
167- assert_allclose (train_dataset [100 ]['question' ],
168- torch .tensor ([7 , 24 , 86 , 52 , 2 , 373 , 887 , 18 , 12797 , 11090 , 1356 , 2 , 1788 , 3273 , 16 ]).long ())
169- assert_allclose (train_dataset [100 ]['ans_pos' ][0 ],
170- torch .tensor ([72 , 72 ]).long ())
171- assert_allclose (dev_dataset [100 ]['question' ],
172- torch .tensor ([42 , 27 , 669 , 7438 , 17 , 2 , 1950 , 3273 , 17252 , 389 , 16 ]).long ())
173- assert_allclose (dev_dataset [100 ]['ans_pos' ][0 ],
174- torch .tensor ([45 , 48 ]).long ())
166+ self . assertEqual (train_dataset [100 ]['question' ],
167+ torch .tensor ([7 , 24 , 86 , 52 , 2 , 373 , 887 , 18 , 12797 , 11090 , 1356 , 2 , 1788 , 3273 , 16 ]).long ())
168+ self . assertEqual (train_dataset [100 ]['ans_pos' ][0 ],
169+ torch .tensor ([72 , 72 ]).long ())
170+ self . assertEqual (dev_dataset [100 ]['question' ],
171+ torch .tensor ([42 , 27 , 669 , 7438 , 17 , 2 , 1950 , 3273 , 17252 , 389 , 16 ]).long ())
172+ self . assertEqual (dev_dataset [100 ]['ans_pos' ][0 ],
173+ torch .tensor ([45 , 48 ]).long ())
175174
176175 # Test API with a vocab input object
177176 old_vocab = train_dataset .get_vocab ()
@@ -185,14 +184,14 @@ def test_squad2(self):
185184 train_dataset , dev_dataset = SQuAD2 ()
186185 self .assertEqual (len (train_dataset ), 130319 )
187186 self .assertEqual (len (dev_dataset ), 11873 )
188- assert_allclose (train_dataset [200 ]['question' ],
189- torch .tensor ([84 , 50 , 1421 , 12 , 5439 , 4569 , 17 , 30 , 2 , 15202 , 4754 , 1421 , 16 ]).long ())
190- assert_allclose (train_dataset [200 ]['ans_pos' ][0 ],
191- torch .tensor ([9 , 9 ]).long ())
192- assert_allclose (dev_dataset [200 ]['question' ],
193- torch .tensor ([41 , 29 , 2 , 66 , 17016 , 30 , 0 , 1955 , 16 ]).long ())
194- assert_allclose (dev_dataset [200 ]['ans_pos' ][0 ],
195- torch .tensor ([40 , 46 ]).long ())
187+ self . assertEqual (train_dataset [200 ]['question' ],
188+ torch .tensor ([84 , 50 , 1421 , 12 , 5439 , 4569 , 17 , 30 , 2 , 15202 , 4754 , 1421 , 16 ]).long ())
189+ self . assertEqual (train_dataset [200 ]['ans_pos' ][0 ],
190+ torch .tensor ([9 , 9 ]).long ())
191+ self . assertEqual (dev_dataset [200 ]['question' ],
192+ torch .tensor ([41 , 29 , 2 , 66 , 17016 , 30 , 0 , 1955 , 16 ]).long ())
193+ self . assertEqual (dev_dataset [200 ]['ans_pos' ][0 ],
194+ torch .tensor ([40 , 46 ]).long ())
196195
197196 # Test API with a vocab input object
198197 old_vocab = train_dataset .get_vocab ()
0 commit comments