From 3bbb0208d30902f3bb4cdba28def816ab2df465f Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Thu, 11 Jun 2020 15:31:51 -0400 Subject: [PATCH 01/10] Use torch.testing._internal.common_utils.TestCase --- test/common/torchtext_test_case.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/common/torchtext_test_case.py b/test/common/torchtext_test_case.py index 249e5aa0c9..21ccc4d785 100644 --- a/test/common/torchtext_test_case.py +++ b/test/common/torchtext_test_case.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -from unittest import TestCase +from torch.testing._internal.common_utils import TestCase import json import logging import os From cdc509e7f0ab35f4a69d235f4508b9552ab94e5c Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Thu, 11 Jun 2020 18:45:50 -0400 Subject: [PATCH 02/10] Checkpoint --- test/test_build.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/test/test_build.py b/test/test_build.py index f47d535204..449a0b4aaf 100644 --- a/test/test_build.py +++ b/test/test_build.py @@ -133,13 +133,13 @@ def test_vectors_get_vecs(self): token_vecs = vec.get_vecs_by_tokens(tokens).numpy() self.assertEqual(token_vecs.shape[0], len(tokens)) self.assertEqual(token_vecs.shape[1], vec.dim) - torch.testing.assert_allclose(vec[tokens[0]].numpy(), token_vecs[0]) - torch.testing.assert_allclose(vec[tokens[1]].numpy(), token_vecs[1]) - torch.testing.assert_allclose(vec[''].numpy(), token_vecs[2]) + self.assertEqual(vec[tokens[0]].numpy(), token_vecs[0]) + self.assertEqual(vec[tokens[1]].numpy(), token_vecs[1]) + self.assertEqual(vec[''].numpy(), token_vecs[2]) token_one_vec = vec.get_vecs_by_tokens(tokens[0], lower_case_backup=True).numpy() self.assertEqual(token_one_vec.shape[0], vec.dim) - torch.testing.assert_allclose(vec[tokens[0].lower()].numpy(), token_one_vec) + self.assertEqual(vec[tokens[0].lower()].numpy(), token_one_vec) def test_download_charngram_vectors(self): c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) @@ -167,11 +167,11 @@ def test_download_charngram_vectors(self): } for word in expected_charngram: - torch.testing.assert_allclose( + self.assertEqual( vectors[v.stoi[word], :5], expected_charngram[word]) - torch.testing.assert_allclose(vectors[v.stoi['']], np.zeros(100)) - torch.testing.assert_allclose(vectors[v.stoi['OOV token']], np.zeros(100)) + self.assertEqual(vectors[v.stoi['']], np.zeros(100)) + self.assertEqual(vectors[v.stoi['OOV token']], np.zeros(100)) def test_download_custom_vectors(self): c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) @@ -196,10 +196,10 @@ def test_download_custom_vectors(self): } for word in expected_fasttext_simple_en: - torch.testing.assert_allclose( + self.assertEqual( vectors[v.stoi[word], :5], expected_fasttext_simple_en[word]) - torch.testing.assert_allclose(vectors[v.stoi['']], np.zeros(300)) + self.assertEqual(vectors[v.stoi['']], np.zeros(300)) def test_download_fasttext_vectors(self): c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) @@ -228,11 +228,11 @@ def test_download_fasttext_vectors(self): } for word in expected_fasttext_simple_en: - torch.testing.assert_allclose( + self.assertEqual( vectors[v.stoi[word], :5], expected_fasttext_simple_en[word]) - torch.testing.assert_allclose(vectors[v.stoi['']], np.zeros(300)) - torch.testing.assert_allclose(vectors[v.stoi['OOV token']], np.zeros(300)) + self.assertEqual(vectors[v.stoi['']], np.zeros(300)) + 
self.assertEqual(vectors[v.stoi['OOV token']], np.zeros(300)) def test_download_glove_vectors(self): c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) @@ -262,11 +262,11 @@ def test_download_glove_vectors(self): } for word in expected_twitter: - torch.testing.assert_allclose( + self.assertEqual( vectors[v.stoi[word], :5], expected_twitter[word]) - torch.testing.assert_allclose(vectors[v.stoi['']], np.zeros(25)) - torch.testing.assert_allclose(vectors[v.stoi['OOV token']], np.zeros(25)) + self.assertEqual(vectors[v.stoi['']], np.zeros(25)) + self.assertEqual(vectors[v.stoi['OOV token']], np.zeros(25)) def test_extend(self): c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) @@ -290,10 +290,10 @@ def test_extend(self): } for word in expected_fasttext_simple_en: - torch.testing.assert_allclose( + self.assertEqual( vectors[v.stoi[word], :5], expected_fasttext_simple_en[word]) - torch.testing.assert_allclose(vectors[v.stoi['']], np.zeros(300)) + self.assertEqual(vectors[v.stoi['']], np.zeros(300)) def test_vectors_custom_cache(self): c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) @@ -321,7 +321,7 @@ def test_vectors_custom_cache(self): } for word in expected_fasttext_simple_en: - torch.testing.assert_allclose( + self.assertEqual( vectors[v.stoi[word], :5], expected_fasttext_simple_en[word]) - torch.testing.assert_allclose(vectors[v.stoi['']], np.zeros(300)) + self.assertEqual(vectors[v.stoi['']], np.zeros(300)) From 43887edc6d6a62ddae38dba624a77b1036a4b4b7 Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Thu, 11 Jun 2020 18:46:24 -0400 Subject: [PATCH 03/10] Checkpoint --- test/test_build.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/test_build.py b/test/test_build.py index 449a0b4aaf..196fbac20f 100644 --- a/test/test_build.py +++ b/test/test_build.py @@ -170,8 +170,8 @@ def test_download_charngram_vectors(self): self.assertEqual( vectors[v.stoi[word], :5], expected_charngram[word]) - self.assertEqual(vectors[v.stoi['']], np.zeros(100)) - self.assertEqual(vectors[v.stoi['OOV token']], np.zeros(100)) + self.assertEqual(vectors[v.stoi['']], torch.zeros(100)) + self.assertEqual(vectors[v.stoi['OOV token']], torch.zeros(100)) def test_download_custom_vectors(self): c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) @@ -199,7 +199,7 @@ def test_download_custom_vectors(self): self.assertEqual( vectors[v.stoi[word], :5], expected_fasttext_simple_en[word]) - self.assertEqual(vectors[v.stoi['']], np.zeros(300)) + self.assertEqual(vectors[v.stoi['']], torch.zeros(300)) def test_download_fasttext_vectors(self): c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) @@ -231,8 +231,8 @@ def test_download_fasttext_vectors(self): self.assertEqual( vectors[v.stoi[word], :5], expected_fasttext_simple_en[word]) - self.assertEqual(vectors[v.stoi['']], np.zeros(300)) - self.assertEqual(vectors[v.stoi['OOV token']], np.zeros(300)) + self.assertEqual(vectors[v.stoi['']], torch.zeros(300)) + self.assertEqual(vectors[v.stoi['OOV token']], torch.zeros(300)) def test_download_glove_vectors(self): c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) @@ -265,8 +265,8 @@ def test_download_glove_vectors(self): self.assertEqual( vectors[v.stoi[word], :5], expected_twitter[word]) - self.assertEqual(vectors[v.stoi['']], np.zeros(25)) - self.assertEqual(vectors[v.stoi['OOV token']], np.zeros(25)) + 
self.assertEqual(vectors[v.stoi['']], torch.zeros(25)) + self.assertEqual(vectors[v.stoi['OOV token']], torch.zeros(25)) def test_extend(self): c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) @@ -293,7 +293,7 @@ def test_extend(self): self.assertEqual( vectors[v.stoi[word], :5], expected_fasttext_simple_en[word]) - self.assertEqual(vectors[v.stoi['']], np.zeros(300)) + self.assertEqual(vectors[v.stoi['']], torch.zeros(300)) def test_vectors_custom_cache(self): c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) @@ -324,4 +324,4 @@ def test_vectors_custom_cache(self): self.assertEqual( vectors[v.stoi[word], :5], expected_fasttext_simple_en[word]) - self.assertEqual(vectors[v.stoi['']], np.zeros(300)) + self.assertEqual(vectors[v.stoi['']], torch.zeros(300)) From 3721c24749388dc988a842dc72fb71b2605d76d5 Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Thu, 11 Jun 2020 18:47:03 -0400 Subject: [PATCH 04/10] Checkpoint --- test/test_build.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/test/test_build.py b/test/test_build.py index 196fbac20f..cc89fd9883 100644 --- a/test/test_build.py +++ b/test/test_build.py @@ -130,16 +130,16 @@ def test_vectors_get_vecs(self): self.assertEqual(vec.vectors.shape[0], len(vec)) tokens = ['chip', 'baby', 'Beautiful'] - token_vecs = vec.get_vecs_by_tokens(tokens).numpy() + token_vecs = vec.get_vecs_by_tokens(tokens) self.assertEqual(token_vecs.shape[0], len(tokens)) self.assertEqual(token_vecs.shape[1], vec.dim) - self.assertEqual(vec[tokens[0]].numpy(), token_vecs[0]) - self.assertEqual(vec[tokens[1]].numpy(), token_vecs[1]) - self.assertEqual(vec[''].numpy(), token_vecs[2]) + self.assertEqual(vec[tokens[0]], token_vecs[0]) + self.assertEqual(vec[tokens[1]], token_vecs[1]) + self.assertEqual(vec[''], token_vecs[2]) - token_one_vec = vec.get_vecs_by_tokens(tokens[0], lower_case_backup=True).numpy() + token_one_vec = vec.get_vecs_by_tokens(tokens[0], lower_case_backup=True) self.assertEqual(token_one_vec.shape[0], vec.dim) - self.assertEqual(vec[tokens[0].lower()].numpy(), token_one_vec) + self.assertEqual(vec[tokens[0].lower()], token_one_vec) def test_download_charngram_vectors(self): c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) @@ -157,7 +157,7 @@ def test_download_charngram_vectors(self): expected_stoi = {x: index for index, x in enumerate(expected_itos)} self.assertEqual(v.itos, expected_itos) self.assertEqual(dict(v.stoi), expected_stoi) - vectors = v.vectors.numpy() + vectors = v.vectors # The first 5 entries in each vector. expected_charngram = { @@ -187,7 +187,7 @@ def test_download_custom_vectors(self): self.assertEqual(v.itos, ['', '', '', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']) - vectors = v.vectors.numpy() + vectors = v.vectors # The first 5 entries in each vector. expected_fasttext_simple_en = { @@ -219,7 +219,7 @@ def test_download_fasttext_vectors(self): expected_stoi = {x: index for index, x in enumerate(expected_itos)} self.assertEqual(v.itos, expected_itos) self.assertEqual(dict(v.stoi), expected_stoi) - vectors = v.vectors.numpy() + vectors = v.vectors # The first 5 entries in each vector. expected_fasttext_simple_en = { @@ -253,7 +253,7 @@ def test_download_glove_vectors(self): self.assertEqual(v.itos, expected_itos) self.assertEqual(dict(v.stoi), expected_stoi) - vectors = v.vectors.numpy() + vectors = v.vectors # The first 5 entries in each vector. 
expected_twitter = { @@ -281,7 +281,7 @@ def test_extend(self): self.assertEqual(v.itos[:6], ['', '', '', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']) - vectors = v.vectors.numpy() + vectors = v.vectors # The first 5 entries in each vector. expected_fasttext_simple_en = { @@ -312,7 +312,7 @@ def test_vectors_custom_cache(self): self.assertEqual(v.itos, ['', '', '', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']) - vectors = v.vectors.numpy() + vectors = v.vectors # The first 5 entries in each vector. expected_fasttext_simple_en = { From 8855a9088361e817e92b93e2d7f9158439e159a2 Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Thu, 11 Jun 2020 18:48:18 -0400 Subject: [PATCH 05/10] Checkpoint --- test/data/test_field.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/test/data/test_field.py b/test/data/test_field.py index 11b98bde68..aa76c38567 100644 --- a/test/data/test_field.py +++ b/test/data/test_field.py @@ -2,7 +2,6 @@ from collections import Counter import os -from numpy.testing import assert_allclose import torch import torchtext.data as data import pytest @@ -376,9 +375,9 @@ def test_numerical_features_no_vocab(self): test_float_data = ["1.1", "0.1", "3.91", "0.2", "10.2"] numericalized_int = int_field.numericalize(test_int_data) - assert_allclose(numericalized_int.data.numpy(), [1, 0, 1, 3, 19]) + self.assertEqual(numericalized_int.data, [1, 0, 1, 3, 19]) numericalized_float = float_field.numericalize(test_float_data) - assert_allclose(numericalized_float.data.numpy(), [1.1, 0.1, 3.91, 0.2, 10.2]) + self.assertEqual(numericalized_float.data, [1.1, 0.1, 3.91, 0.2, 10.2]) # Test with postprocessing applied int_field = data.Field(sequential=False, use_vocab=False, @@ -396,9 +395,9 @@ def test_numerical_features_no_vocab(self): test_float_data = ["1.1", "0.1", "3.91", "0.2", "10.2"] numericalized_int = int_field.numericalize(test_int_data) - assert_allclose(numericalized_int.data.numpy(), [2, 1, 2, 4, 20]) + self.assertEqual(numericalized_int.data, [2, 1, 2, 4, 20]) numericalized_float = float_field.numericalize(test_float_data) - assert_allclose(numericalized_float.data.numpy(), [0.55, 0.05, 1.955, 0.1, 5.1]) + self.assertEqual(numericalized_float.data, [0.55, 0.05, 1.955, 0.1, 5.1]) def test_errors(self): # Test that passing a non-tuple (of data and length) to numericalize From db49199f341d5dbbc598183c77083109417f63dd Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Thu, 11 Jun 2020 18:48:51 -0400 Subject: [PATCH 06/10] Checkpoint --- test/test_vocab.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_vocab.py b/test/test_vocab.py index c09bff69f2..c9b863d37a 100644 --- a/test/test_vocab.py +++ b/test/test_vocab.py @@ -5,7 +5,6 @@ import numpy as np -from numpy.testing import assert_allclose import torch from torchtext import vocab @@ -89,7 +88,7 @@ def test_vocab_set_vectors(self): expected_vectors = np.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.1, 0.2], [0.5, 0.6], [0.3, 0.4]]) - assert_allclose(v.vectors.numpy(), expected_vectors) + self.assertEqual(v.vectors, expected_vectors) def test_errors(self): c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) From 117d2f87ccb29af77da3919c51b2f8fd9d19ad86 Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Thu, 11 Jun 2020 18:49:32 -0400 Subject: [PATCH 07/10] Checkpoint --- test/data/test_metrics.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/test/data/test_metrics.py b/test/data/test_metrics.py index 
1d69a20bae..b91393fb11 100644 --- a/test/data/test_metrics.py +++ b/test/data/test_metrics.py @@ -1,5 +1,4 @@ from torchtext.data.metrics import bleu_score -from torch.testing import assert_allclose from ..common.torchtext_test_case import TorchtextTestCase @@ -19,19 +18,19 @@ def test_bleu_score(self): # Partial match candidate = [['My', 'full', 'pytorch', 'test']] refs = [[['My', 'full', 'pytorch', 'test', '!'], ['Different']]] - assert_allclose(bleu_score(candidate, refs), 0.7788007) + self.assertEqual(bleu_score(candidate, refs), 0.7788007) # Bigrams and unigrams only candidate = [['My', 'pytorch', 'test']] refs = [[['My', 'full', 'pytorch', 'test'], ['Different']]] - assert_allclose(bleu_score(candidate, refs, max_n=2, + self.assertEqual(bleu_score(candidate, refs, max_n=2, weights=[0.5, 0.5]), 0.5066641) # Multi-sentence corpus candidate = [['My', 'full', 'pytorch', 'test'], ['Another', 'Sentence']] refs = [[['My', 'full', 'pytorch', 'test'], ['Completely', 'Different']], [['No', 'Match']]] - assert_allclose(bleu_score(candidate, refs), 0.8408964) + self.assertEqual(bleu_score(candidate, refs), 0.8408964) # Empty input candidate = [[]] @@ -52,13 +51,13 @@ def test_bleu_score(self): # The comments below give the code used to get each hardcoded bleu score # nltk.translate.bleu_score.corpus_bleu(refs, candidate) - assert_allclose(bleu_score(candidate, refs), 0.4573199) + self.assertEqual(bleu_score(candidate, refs), 0.4573199) # nltk.translate.bleu_score.corpus_bleu(refs, candidate, weights=[0.33]*3) - assert_allclose(bleu_score(candidate, refs, 3, + self.assertEqual(bleu_score(candidate, refs, 3, weights=[0.33, 0.33, 0.33]), 0.4901113) # nltk.translate.bleu_score.corpus_bleu(refs, candidate, weights=[0.5]*2) - assert_allclose(bleu_score(candidate, refs, 2, + self.assertEqual(bleu_score(candidate, refs, 2, weights=[0.5, 0.5]), 0.5119535) # nltk.translate.bleu_score.corpus_bleu(refs, candidate, weights=[1]) - assert_allclose(bleu_score(candidate, refs, 1, + self.assertEqual(bleu_score(candidate, refs, 1, weights=[1]), 0.5515605) From 16653f09a4ee9b0d2221669d89a9c41566066bd5 Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Thu, 11 Jun 2020 18:50:12 -0400 Subject: [PATCH 08/10] Checkpoint --- test/data/test_builtin_datasets.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/test/data/test_builtin_datasets.py b/test/data/test_builtin_datasets.py index 6acc9e6d0f..6129cbf80e 100644 --- a/test/data/test_builtin_datasets.py +++ b/test/data/test_builtin_datasets.py @@ -6,7 +6,6 @@ import torchtext.data as data from torchtext.datasets import AG_NEWS import torch -from torch.testing import assert_allclose from ..common.torchtext_test_case import TorchtextTestCase @@ -99,9 +98,9 @@ def test_text_classification(self): ag_news_train, ag_news_test = AG_NEWS(root=datadir, ngrams=3) self.assertEqual(len(ag_news_train), 120000) self.assertEqual(len(ag_news_test), 7600) - assert_allclose(ag_news_train[-1][1][:10], + self.assertEqual(ag_news_train[-1][1][:10], torch.tensor([3525, 319, 4053, 34, 5407, 3607, 70, 6798, 10599, 4053]).long()) - assert_allclose(ag_news_test[-1][1][:10], + self.assertEqual(ag_news_test[-1][1][:10], torch.tensor([2351, 758, 96, 38581, 2351, 220, 5, 396, 3, 14786]).long()) def test_imdb(self): @@ -111,13 +110,13 @@ def test_imdb(self): train_dataset, test_dataset = IMDB() self.assertEqual(len(train_dataset), 25000) self.assertEqual(len(test_dataset), 25000) - assert_allclose(train_dataset[0][1][:10], + 
self.assertEqual(train_dataset[0][1][:10], torch.tensor([13, 1568, 13, 246, 35468, 43, 64, 398, 1135, 92]).long()) - assert_allclose(train_dataset[-1][1][:10], + self.assertEqual(train_dataset[-1][1][:10], torch.tensor([2, 71, 4555, 194, 3328, 15144, 42, 227, 148, 8]).long()) - assert_allclose(test_dataset[0][1][:10], + self.assertEqual(test_dataset[0][1][:10], torch.tensor([13, 125, 1051, 5, 246, 1652, 8, 277, 66, 20]).long()) - assert_allclose(test_dataset[-1][1][:10], + self.assertEqual(test_dataset[-1][1][:10], torch.tensor([13, 1035, 14, 21, 28, 2, 1051, 1275, 1008, 3]).long()) # Test API with a vocab input object @@ -164,13 +163,13 @@ def test_squad1(self): train_dataset, dev_dataset = SQuAD1() self.assertEqual(len(train_dataset), 87599) self.assertEqual(len(dev_dataset), 10570) - assert_allclose(train_dataset[100]['question'], + self.assertEqual(train_dataset[100]['question'], torch.tensor([7, 24, 86, 52, 2, 373, 887, 18, 12797, 11090, 1356, 2, 1788, 3273, 16]).long()) - assert_allclose(train_dataset[100]['ans_pos'][0], + self.assertEqual(train_dataset[100]['ans_pos'][0], torch.tensor([72, 72]).long()) - assert_allclose(dev_dataset[100]['question'], + self.assertEqual(dev_dataset[100]['question'], torch.tensor([42, 27, 669, 7438, 17, 2, 1950, 3273, 17252, 389, 16]).long()) - assert_allclose(dev_dataset[100]['ans_pos'][0], + self.assertEqual(dev_dataset[100]['ans_pos'][0], torch.tensor([45, 48]).long()) # Test API with a vocab input object @@ -185,13 +184,13 @@ def test_squad2(self): train_dataset, dev_dataset = SQuAD2() self.assertEqual(len(train_dataset), 130319) self.assertEqual(len(dev_dataset), 11873) - assert_allclose(train_dataset[200]['question'], + self.assertEqual(train_dataset[200]['question'], torch.tensor([84, 50, 1421, 12, 5439, 4569, 17, 30, 2, 15202, 4754, 1421, 16]).long()) - assert_allclose(train_dataset[200]['ans_pos'][0], + self.assertEqual(train_dataset[200]['ans_pos'][0], torch.tensor([9, 9]).long()) - assert_allclose(dev_dataset[200]['question'], + self.assertEqual(dev_dataset[200]['question'], torch.tensor([41, 29, 2, 66, 17016, 30, 0, 1955, 16]).long()) - assert_allclose(dev_dataset[200]['ans_pos'][0], + self.assertEqual(dev_dataset[200]['ans_pos'][0], torch.tensor([40, 46]).long()) # Test API with a vocab input object From 0d804dcb16cec30e6404d11181a30c75e3d089af Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 11 Jun 2020 16:46:43 -0700 Subject: [PATCH 09/10] flake8 --- test/data/test_builtin_datasets.py | 28 ++++++++++++++-------------- test/data/test_metrics.py | 8 ++++---- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/test/data/test_builtin_datasets.py b/test/data/test_builtin_datasets.py index 6129cbf80e..3fa8baeb1b 100644 --- a/test/data/test_builtin_datasets.py +++ b/test/data/test_builtin_datasets.py @@ -99,9 +99,9 @@ def test_text_classification(self): self.assertEqual(len(ag_news_train), 120000) self.assertEqual(len(ag_news_test), 7600) self.assertEqual(ag_news_train[-1][1][:10], - torch.tensor([3525, 319, 4053, 34, 5407, 3607, 70, 6798, 10599, 4053]).long()) + torch.tensor([3525, 319, 4053, 34, 5407, 3607, 70, 6798, 10599, 4053]).long()) self.assertEqual(ag_news_test[-1][1][:10], - torch.tensor([2351, 758, 96, 38581, 2351, 220, 5, 396, 3, 14786]).long()) + torch.tensor([2351, 758, 96, 38581, 2351, 220, 5, 396, 3, 14786]).long()) def test_imdb(self): from torchtext.experimental.datasets import IMDB @@ -111,13 +111,13 @@ def test_imdb(self): self.assertEqual(len(train_dataset), 25000) self.assertEqual(len(test_dataset), 
25000) self.assertEqual(train_dataset[0][1][:10], - torch.tensor([13, 1568, 13, 246, 35468, 43, 64, 398, 1135, 92]).long()) + torch.tensor([13, 1568, 13, 246, 35468, 43, 64, 398, 1135, 92]).long()) self.assertEqual(train_dataset[-1][1][:10], - torch.tensor([2, 71, 4555, 194, 3328, 15144, 42, 227, 148, 8]).long()) + torch.tensor([2, 71, 4555, 194, 3328, 15144, 42, 227, 148, 8]).long()) self.assertEqual(test_dataset[0][1][:10], - torch.tensor([13, 125, 1051, 5, 246, 1652, 8, 277, 66, 20]).long()) + torch.tensor([13, 125, 1051, 5, 246, 1652, 8, 277, 66, 20]).long()) self.assertEqual(test_dataset[-1][1][:10], - torch.tensor([13, 1035, 14, 21, 28, 2, 1051, 1275, 1008, 3]).long()) + torch.tensor([13, 1035, 14, 21, 28, 2, 1051, 1275, 1008, 3]).long()) # Test API with a vocab input object old_vocab = train_dataset.get_vocab() @@ -164,13 +164,13 @@ def test_squad1(self): self.assertEqual(len(train_dataset), 87599) self.assertEqual(len(dev_dataset), 10570) self.assertEqual(train_dataset[100]['question'], - torch.tensor([7, 24, 86, 52, 2, 373, 887, 18, 12797, 11090, 1356, 2, 1788, 3273, 16]).long()) + torch.tensor([7, 24, 86, 52, 2, 373, 887, 18, 12797, 11090, 1356, 2, 1788, 3273, 16]).long()) self.assertEqual(train_dataset[100]['ans_pos'][0], - torch.tensor([72, 72]).long()) + torch.tensor([72, 72]).long()) self.assertEqual(dev_dataset[100]['question'], - torch.tensor([42, 27, 669, 7438, 17, 2, 1950, 3273, 17252, 389, 16]).long()) + torch.tensor([42, 27, 669, 7438, 17, 2, 1950, 3273, 17252, 389, 16]).long()) self.assertEqual(dev_dataset[100]['ans_pos'][0], - torch.tensor([45, 48]).long()) + torch.tensor([45, 48]).long()) # Test API with a vocab input object old_vocab = train_dataset.get_vocab() @@ -185,13 +185,13 @@ def test_squad2(self): self.assertEqual(len(train_dataset), 130319) self.assertEqual(len(dev_dataset), 11873) self.assertEqual(train_dataset[200]['question'], - torch.tensor([84, 50, 1421, 12, 5439, 4569, 17, 30, 2, 15202, 4754, 1421, 16]).long()) + torch.tensor([84, 50, 1421, 12, 5439, 4569, 17, 30, 2, 15202, 4754, 1421, 16]).long()) self.assertEqual(train_dataset[200]['ans_pos'][0], - torch.tensor([9, 9]).long()) + torch.tensor([9, 9]).long()) self.assertEqual(dev_dataset[200]['question'], - torch.tensor([41, 29, 2, 66, 17016, 30, 0, 1955, 16]).long()) + torch.tensor([41, 29, 2, 66, 17016, 30, 0, 1955, 16]).long()) self.assertEqual(dev_dataset[200]['ans_pos'][0], - torch.tensor([40, 46]).long()) + torch.tensor([40, 46]).long()) # Test API with a vocab input object old_vocab = train_dataset.get_vocab() diff --git a/test/data/test_metrics.py b/test/data/test_metrics.py index b91393fb11..af0ae77273 100644 --- a/test/data/test_metrics.py +++ b/test/data/test_metrics.py @@ -24,7 +24,7 @@ def test_bleu_score(self): candidate = [['My', 'pytorch', 'test']] refs = [[['My', 'full', 'pytorch', 'test'], ['Different']]] self.assertEqual(bleu_score(candidate, refs, max_n=2, - weights=[0.5, 0.5]), 0.5066641) + weights=[0.5, 0.5]), 0.5066641) # Multi-sentence corpus candidate = [['My', 'full', 'pytorch', 'test'], ['Another', 'Sentence']] @@ -54,10 +54,10 @@ def test_bleu_score(self): self.assertEqual(bleu_score(candidate, refs), 0.4573199) # nltk.translate.bleu_score.corpus_bleu(refs, candidate, weights=[0.33]*3) self.assertEqual(bleu_score(candidate, refs, 3, - weights=[0.33, 0.33, 0.33]), 0.4901113) + weights=[0.33, 0.33, 0.33]), 0.4901113) # nltk.translate.bleu_score.corpus_bleu(refs, candidate, weights=[0.5]*2) self.assertEqual(bleu_score(candidate, refs, 2, - weights=[0.5, 0.5]), 0.5119535) + 
weights=[0.5, 0.5]), 0.5119535) # nltk.translate.bleu_score.corpus_bleu(refs, candidate, weights=[1]) self.assertEqual(bleu_score(candidate, refs, 1, - weights=[1]), 0.5515605) + weights=[1]), 0.5515605) From ecfe77017a90eec1111b125c2ebd75dafbf469ba Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 11 Jun 2020 16:55:16 -0700 Subject: [PATCH 10/10] flake8 --- test/test_build.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_build.py b/test/test_build.py index cc89fd9883..d61e844280 100644 --- a/test/test_build.py +++ b/test/test_build.py @@ -3,7 +3,6 @@ import os from collections import Counter -import numpy as np import torch import torchtext.data
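
The series above converges on a single assertion pattern: torchtext's test case class now inherits from PyTorch's internal TestCase (torch.testing._internal.common_utils), whose assertEqual performs tolerance-aware comparison of tensors and numbers, so the explicit torch.testing.assert_allclose / numpy.testing.assert_allclose calls and .numpy() conversions are dropped. The standalone sketch below is illustrative only and not part of the patch; the class name, the sample values, and the run_tests entry point are assumptions about a typical PyTorch-style test file rather than code from this repository.

import torch
from torch.testing._internal.common_utils import TestCase, run_tests


class AssertionPatternTest(TestCase):
    def test_tensor_and_scalar_comparison(self):
        # Tensors are compared elementwise with dtype-appropriate default
        # tolerances, so a tiny float discrepancy still passes without an
        # explicit assert_allclose.
        vec = torch.zeros(300) + 1e-7
        self.assertEqual(vec, torch.zeros(300))
        # Plain Python floats take the same tolerance-aware path, which is
        # what lets the BLEU tests above check against rounded constants
        # such as 0.7788007.
        self.assertEqual(0.77880078, 0.7788007)


if __name__ == "__main__":
    run_tests()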