From 47e51f3cc4809705bf33cba7069e15243111b1b2 Mon Sep 17 00:00:00 2001 From: nayef211 Date: Thu, 10 Feb 2022 14:53:15 -0800 Subject: [PATCH 1/4] Parameterized amazon dataset tests --- test/datasets/test_amazonreviewfull.py | 83 ------------------- ...eviewpolarity.py => test_amazonreviews.py} | 50 ++++++----- 2 files changed, 31 insertions(+), 102 deletions(-) delete mode 100644 test/datasets/test_amazonreviewfull.py rename test/datasets/{test_amazonreviewpolarity.py => test_amazonreviews.py} (55%) diff --git a/test/datasets/test_amazonreviewfull.py b/test/datasets/test_amazonreviewfull.py deleted file mode 100644 index 909c32ab59..0000000000 --- a/test/datasets/test_amazonreviewfull.py +++ /dev/null @@ -1,83 +0,0 @@ -import os -import random -import string -import tarfile -from collections import defaultdict -from unittest.mock import patch - -from parameterized import parameterized -from torchtext.datasets.amazonreviewfull import AmazonReviewFull - -from ..common.case_utils import TempDirMixin, zip_equal -from ..common.torchtext_test_case import TorchtextTestCase - - -def _get_mock_dataset(root_dir): - """ - root_dir: directory to the mocked dataset - """ - base_dir = os.path.join(root_dir, "AmazonReviewFull") - temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir") - os.makedirs(temp_dataset_dir, exist_ok=True) - - seed = 1 - mocked_data = defaultdict(list) - for file_name in ("train.csv", "test.csv"): - txt_file = os.path.join(temp_dataset_dir, file_name) - with open(txt_file, "w") as f: - for i in range(5): - label = seed % 5 + 1 - rand_string = " ".join( - random.choice(string.ascii_letters) for i in range(seed) - ) - dataset_line = (label, f"{rand_string} {rand_string}") - # append line to correct dataset split - mocked_data[os.path.splitext(file_name)[0]].append(dataset_line) - f.write(f'"{label}","{rand_string}","{rand_string}"\n') - seed += 1 - - compressed_dataset_path = os.path.join( - base_dir, "amazon_review_full_csv.tar.gz" - ) - # create tar file from dataset folder - with tarfile.open(compressed_dataset_path, "w:gz") as tar: - tar.add(temp_dataset_dir, arcname="amazon_review_full_csv") - - return mocked_data - - -class TestAmazonReviewFull(TempDirMixin, TorchtextTestCase): - root_dir = None - samples = [] - - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.root_dir = cls.get_base_temp_dir() - cls.samples = _get_mock_dataset(cls.root_dir) - cls.patcher = patch( - "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True - ) - cls.patcher.start() - - @classmethod - def tearDownClass(cls): - cls.patcher.stop() - super().tearDownClass() - - @parameterized.expand(["train", "test"]) - def test_amazon_review_full(self, split): - dataset = AmazonReviewFull(root=self.root_dir, split=split) - - samples = list(dataset) - expected_samples = self.samples[split] - for sample, expected_sample in zip_equal(samples, expected_samples): - self.assertEqual(sample, expected_sample) - - @parameterized.expand(["train", "test"]) - def test_amazon_review_full_split_argument(self, split): - dataset1 = AmazonReviewFull(root=self.root_dir, split=split) - (dataset2,) = AmazonReviewFull(root=self.root_dir, split=(split,)) - - for d1, d2 in zip_equal(dataset1, dataset2): - self.assertEqual(d1, d2) diff --git a/test/datasets/test_amazonreviewpolarity.py b/test/datasets/test_amazonreviews.py similarity index 55% rename from test/datasets/test_amazonreviewpolarity.py rename to test/datasets/test_amazonreviews.py index 11c95ae785..87dd6f9952 100644 --- a/test/datasets/test_amazonreviewpolarity.py +++ b/test/datasets/test_amazonreviews.py @@ -5,18 +5,20 @@ from collections import defaultdict from unittest.mock import patch -from parameterized import parameterized +from torchtext.datasets.amazonreviewfull import AmazonReviewFull from torchtext.datasets.amazonreviewpolarity import AmazonReviewPolarity from ..common.case_utils import TempDirMixin, zip_equal +from ..common.parameterized_utils import nested_params from ..common.torchtext_test_case import TorchtextTestCase -def _get_mock_dataset(root_dir): +def _get_mock_dataset(root_dir, base_dir_name): """ root_dir: directory to the mocked dataset + base_dir_name: AmazonReviewFull or AmazonReviewPolarity """ - base_dir = os.path.join(root_dir, "AmazonReviewPolarity") + base_dir = os.path.join(root_dir, base_dir_name) temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir") os.makedirs(temp_dataset_dir, exist_ok=True) @@ -26,7 +28,10 @@ def _get_mock_dataset(root_dir): txt_file = os.path.join(temp_dataset_dir, file_name) with open(txt_file, "w") as f: for i in range(5): - label = seed % 2 + 1 + if base_dir_name == AmazonReviewFull.__name__: + label = seed % 5 + 1 + else: + label = seed % 2 + 1 rand_string = " ".join( random.choice(string.ascii_letters) for i in range(seed) ) @@ -36,17 +41,20 @@ def _get_mock_dataset(root_dir): f.write(f'"{label}","{rand_string}","{rand_string}"\n') seed += 1 - compressed_dataset_path = os.path.join( - base_dir, "amazon_review_polarity_csv.tar.gz" - ) + if base_dir_name == AmazonReviewFull.__name__: + archive_file_name = "amazon_review_full_csv" + else: + archive_file_name = "amazon_review_polarity_csv" + + compressed_dataset_path = os.path.join(base_dir, f"{archive_file_name}.tar.gz") # create tar file from dataset folder with tarfile.open(compressed_dataset_path, "w:gz") as tar: - tar.add(temp_dataset_dir, arcname="amazon_review_polarity_csv") + tar.add(temp_dataset_dir, arcname=archive_file_name) return mocked_data -class TestAmazonReviewPolarity(TempDirMixin, TorchtextTestCase): +class TestAmazonReviews(TempDirMixin, TorchtextTestCase): root_dir = None samples = [] @@ -54,7 +62,6 @@ class TestAmazonReviewPolarity(TempDirMixin, TorchtextTestCase): def setUpClass(cls): super().setUpClass() cls.root_dir = cls.get_base_temp_dir() - cls.samples = _get_mock_dataset(cls.root_dir) cls.patcher = patch( "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True ) @@ -65,19 +72,24 @@ def tearDownClass(cls): cls.patcher.stop() super().tearDownClass() - @parameterized.expand(["train", "test"]) - def test_amazon_review_polarity(self, split): - dataset = AmazonReviewPolarity(root=self.root_dir, split=split) - + @nested_params([AmazonReviewFull, AmazonReviewPolarity], ["train", "test"]) + def test_amazon_reviews(self, amazon_review_dataset, split): + expected_samples = _get_mock_dataset( + self.root_dir, amazon_review_dataset.__name__ + )[split] + dataset = amazon_review_dataset(root=self.root_dir, split=split) samples = list(dataset) - expected_samples = self.samples[split] + for sample, expected_sample in zip_equal(samples, expected_samples): self.assertEqual(sample, expected_sample) - @parameterized.expand(["train", "test"]) - def test_amazon_review_polarity_split_argument(self, split): - dataset1 = AmazonReviewPolarity(root=self.root_dir, split=split) - (dataset2,) = AmazonReviewPolarity(root=self.root_dir, split=(split,)) + @nested_params([AmazonReviewFull, AmazonReviewPolarity], ["train", "test"]) + def test_amazon_reviews_split_argument(self, amazon_review_dataset, split): + # call `_get_mock_dataset` to create mock dataset files + _ = _get_mock_dataset(self.root_dir, amazon_review_dataset.__name__) + + dataset1 = amazon_review_dataset(root=self.root_dir, split=split) + (dataset2,) = amazon_review_dataset(root=self.root_dir, split=(split,)) for d1, d2 in zip_equal(dataset1, dataset2): self.assertEqual(d1, d2) From b066cde94be23557313566d575ce9157b74df33e Mon Sep 17 00:00:00 2001 From: nayef211 Date: Thu, 10 Feb 2022 14:54:16 -0800 Subject: [PATCH 2/4] Renamed squad test for consistency --- test/datasets/{test_squad.py => test_squads.py} | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) rename test/datasets/{test_squad.py => test_squads.py} (95%) diff --git a/test/datasets/test_squad.py b/test/datasets/test_squads.py similarity index 95% rename from test/datasets/test_squad.py rename to test/datasets/test_squads.py index d44b6637f1..eb1abd19be 100644 --- a/test/datasets/test_squad.py +++ b/test/datasets/test_squads.py @@ -73,7 +73,7 @@ def _get_mock_dataset(root_dir, base_dir_name): return mocked_data -class TestSQuAD(TempDirMixin, TorchtextTestCase): +class TestSQuADs(TempDirMixin, TorchtextTestCase): root_dir = None samples = [] @@ -92,7 +92,7 @@ def tearDownClass(cls): super().tearDownClass() @nested_params([SQuAD1, SQuAD2], ["train", "dev"]) - def test_squad(self, squad_dataset, split): + def test_squads(self, squad_dataset, split): expected_samples = _get_mock_dataset(self.root_dir, squad_dataset.__name__)[ split ] @@ -103,7 +103,7 @@ def test_squad(self, squad_dataset, split): self.assertEqual(sample, expected_sample) @nested_params([SQuAD1, SQuAD2], ["train", "dev"]) - def test_squad_split_argument(self, squad_dataset, split): + def test_squads_split_argument(self, squad_dataset, split): # call `_get_mock_dataset` to create mock dataset files _ = _get_mock_dataset(self.root_dir, squad_dataset.__name__) From 6fe8cfc4af613c31e8894f41738bb3f199348aa6 Mon Sep 17 00:00:00 2001 From: nayef211 Date: Thu, 10 Feb 2022 14:55:48 -0800 Subject: [PATCH 3/4] Deleted YelpReviewFull test since it's already parameterized --- test/datasets/test_yelpreviewfull.py | 83 ---------------------------- 1 file changed, 83 deletions(-) delete mode 100644 test/datasets/test_yelpreviewfull.py diff --git a/test/datasets/test_yelpreviewfull.py b/test/datasets/test_yelpreviewfull.py deleted file mode 100644 index d0a7e13a7c..0000000000 --- a/test/datasets/test_yelpreviewfull.py +++ /dev/null @@ -1,83 +0,0 @@ -import os -import random -import string -import tarfile -from collections import defaultdict -from unittest.mock import patch - -from parameterized import parameterized -from torchtext.datasets.yelpreviewfull import YelpReviewFull - -from ..common.case_utils import TempDirMixin, zip_equal -from ..common.torchtext_test_case import TorchtextTestCase - - -def _get_mock_dataset(root_dir): - """ - root_dir: directory to the mocked dataset - """ - base_dir = os.path.join(root_dir, "YelpReviewFull") - temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir") - os.makedirs(temp_dataset_dir, exist_ok=True) - - seed = 1 - mocked_data = defaultdict(list) - for file_name in ("train.csv", "test.csv"): - csv_file = os.path.join(temp_dataset_dir, file_name) - mocked_lines = mocked_data[os.path.splitext(file_name)[0]] - with open(csv_file, "w") as f: - for i in range(5): - label = seed % 5 + 1 - rand_string = " ".join( - random.choice(string.ascii_letters) for i in range(seed) - ) - dataset_line = (label, f"{rand_string}") - f.write(f'"{label}","{rand_string}"\n') - - # append line to correct dataset split - mocked_lines.append(dataset_line) - seed += 1 - - compressed_dataset_path = os.path.join(base_dir, "yelp_review_full_csv.tar.gz") - # create gz file from dataset folder - with tarfile.open(compressed_dataset_path, "w:gz") as tar: - tar.add(temp_dataset_dir, arcname="yelp_review_full_csv") - - return mocked_data - - -class TestYelpReviewFull(TempDirMixin, TorchtextTestCase): - root_dir = None - samples = [] - - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.root_dir = cls.get_base_temp_dir() - cls.samples = _get_mock_dataset(cls.root_dir) - cls.patcher = patch( - "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True - ) - cls.patcher.start() - - @classmethod - def tearDownClass(cls): - cls.patcher.stop() - super().tearDownClass() - - @parameterized.expand(["train", "test"]) - def test_yelpreviewfull(self, split): - dataset = YelpReviewFull(root=self.root_dir, split=split) - - samples = list(dataset) - expected_samples = self.samples[split] - for sample, expected_sample in zip_equal(samples, expected_samples): - self.assertEqual(sample, expected_sample) - - @parameterized.expand(["train", "test"]) - def test_yelpreviewfull_split_argument(self, split): - dataset1 = YelpReviewFull(root=self.root_dir, split=split) - (dataset2,) = YelpReviewFull(root=self.root_dir, split=(split,)) - - for d1, d2 in zip_equal(dataset1, dataset2): - self.assertEqual(d1, d2) From 2237a9127ed09925da27266e6b66ca1a5d5ac82c Mon Sep 17 00:00:00 2001 From: nayef211 Date: Thu, 10 Feb 2022 15:13:57 -0800 Subject: [PATCH 4/4] Updated formatting for datasets --- test/datasets/test_cc100.py | 7 +- test/datasets/test_conll2000chunking.py | 22 ++++-- test/datasets/test_enwik9.py | 10 +-- test/datasets/test_iwslt2016.py | 92 ++++++++++++++++++------- test/datasets/test_iwslt2017.py | 71 +++++++++++++------ test/datasets/test_multi30k.py | 18 +++-- test/datasets/test_sogounews.py | 4 +- test/datasets/test_udpos.py | 21 ++++-- test/datasets/test_wikitexts.py | 8 ++- test/datasets/test_yelpreviews.py | 8 ++- 10 files changed, 184 insertions(+), 77 deletions(-) diff --git a/test/datasets/test_cc100.py b/test/datasets/test_cc100.py index ef7731738a..978d447f67 100644 --- a/test/datasets/test_cc100.py +++ b/test/datasets/test_cc100.py @@ -1,18 +1,17 @@ +import lzma import os import random import string -import lzma -from parameterized import parameterized from collections import defaultdict from unittest.mock import patch +from parameterized import parameterized from torchtext.datasets import CC100 +from torchtext.datasets.cc100 import VALID_CODES from ..common.case_utils import TempDirMixin, zip_equal from ..common.torchtext_test_case import TorchtextTestCase -from torchtext.datasets.cc100 import VALID_CODES - def _get_mock_dataset(root_dir): """ diff --git a/test/datasets/test_conll2000chunking.py b/test/datasets/test_conll2000chunking.py index cfbfe730ea..29ae228ba2 100644 --- a/test/datasets/test_conll2000chunking.py +++ b/test/datasets/test_conll2000chunking.py @@ -1,7 +1,7 @@ +import gzip import os import random import string -import gzip from collections import defaultdict from unittest.mock import patch @@ -27,11 +27,19 @@ def _get_mock_dataset(root_dir): mocked_lines = mocked_data[os.path.splitext(file_name)[0]] with open(txt_file, "w") as f: for i in range(5): - rand_strings = [random.choice(string.ascii_letters) for i in range(seed)] - rand_label_1 = [random.choice(string.ascii_letters) for i in range(seed)] - rand_label_2 = [random.choice(string.ascii_letters) for i in range(seed)] + rand_strings = [ + random.choice(string.ascii_letters) for i in range(seed) + ] + rand_label_1 = [ + random.choice(string.ascii_letters) for i in range(seed) + ] + rand_label_2 = [ + random.choice(string.ascii_letters) for i in range(seed) + ] # one token per line (each sample ends with an extra \n) - for rand_string, label_1, label_2 in zip(rand_strings, rand_label_1, rand_label_2): + for rand_string, label_1, label_2 in zip( + rand_strings, rand_label_1, rand_label_2 + ): f.write(f"{rand_string} {label_1} {label_2}\n") f.write("\n") dataset_line = (rand_strings, rand_label_1, rand_label_2) @@ -41,7 +49,9 @@ def _get_mock_dataset(root_dir): # create gz file from dataset folder compressed_dataset_path = os.path.join(base_dir, f"{file_name}.gz") - with gzip.open(compressed_dataset_path, "wb") as gz_file, open(txt_file, "rb") as file_in: + with gzip.open(compressed_dataset_path, "wb") as gz_file, open( + txt_file, "rb" + ) as file_in: gz_file.writelines(file_in) return mocked_data diff --git a/test/datasets/test_enwik9.py b/test/datasets/test_enwik9.py index 16fe349bee..71854c4070 100644 --- a/test/datasets/test_enwik9.py +++ b/test/datasets/test_enwik9.py @@ -24,10 +24,12 @@ def _get_mock_dataset(root_dir): mocked_data = [] with open(txt_file, "w") as f: for i in range(5): - rand_string = "<" + " ".join( - random.choice(string.ascii_letters) for i in range(seed) - ) + ">" - dataset_line = (f"'{rand_string}'") + rand_string = ( + "<" + + " ".join(random.choice(string.ascii_letters) for i in range(seed)) + + ">" + ) + dataset_line = f"'{rand_string}'" f.write(f"'{rand_string}'\n") # append line to correct dataset split diff --git a/test/datasets/test_iwslt2016.py b/test/datasets/test_iwslt2016.py index f2e71e8595..03b681f6f1 100644 --- a/test/datasets/test_iwslt2016.py +++ b/test/datasets/test_iwslt2016.py @@ -1,23 +1,34 @@ +import itertools import os import random import shutil import string import tarfile -import itertools import tempfile from collections import defaultdict from unittest.mock import patch from parameterized import parameterized -from torchtext.datasets.iwslt2016 import DATASET_NAME, IWSLT2016, SUPPORTED_DATASETS, SET_NOT_EXISTS from torchtext.data.datasets_utils import _generate_iwslt_files_for_lang_and_split +from torchtext.datasets.iwslt2016 import ( + DATASET_NAME, + IWSLT2016, + SUPPORTED_DATASETS, + SET_NOT_EXISTS, +) from ..common.case_utils import zip_equal from ..common.torchtext_test_case import TorchtextTestCase -SUPPORTED_LANGPAIRS = [(k, e) for k, v in SUPPORTED_DATASETS["language_pair"].items() for e in v] +SUPPORTED_LANGPAIRS = [ + (k, e) for k, v in SUPPORTED_DATASETS["language_pair"].items() for e in v +] SUPPORTED_DEVTEST_SPLITS = SUPPORTED_DATASETS["valid_test"] -DEV_TEST_SPLITS = [(dev, test) for dev, test in itertools.product(SUPPORTED_DEVTEST_SPLITS, repeat=2) if dev != test] +DEV_TEST_SPLITS = [ + (dev, test) + for dev, test in itertools.product(SUPPORTED_DEVTEST_SPLITS, repeat=2) + if dev != test +] def _generate_uncleaned_train(): @@ -25,13 +36,19 @@ def _generate_uncleaned_train(): file_contents = [] examples = [] xml_tags = [ - ' en) lang = os.path.splitext(unclean_file_name)[1][1:] @@ -144,20 +163,31 @@ def tearDownClass(cls): cls.patcher.stop() super().tearDownClass() - @parameterized.expand([ - (split, src, tgt, dev_set, test_set) - for split in ("train", "valid", "test") - for dev_set, test_set in DEV_TEST_SPLITS - for src, tgt in SUPPORTED_LANGPAIRS - if (dev_set not in SET_NOT_EXISTS[(src, tgt)] and test_set not in SET_NOT_EXISTS[(src, tgt)]) - ]) + @parameterized.expand( + [ + (split, src, tgt, dev_set, test_set) + for split in ("train", "valid", "test") + for dev_set, test_set in DEV_TEST_SPLITS + for src, tgt in SUPPORTED_LANGPAIRS + if ( + dev_set not in SET_NOT_EXISTS[(src, tgt)] + and test_set not in SET_NOT_EXISTS[(src, tgt)] + ) + ] + ) def test_iwslt2016(self, split, src, tgt, dev_set, test_set): with tempfile.TemporaryDirectory() as root_dir: - expected_samples = _get_mock_dataset(root_dir, split, src, tgt, dev_set, test_set) + expected_samples = _get_mock_dataset( + root_dir, split, src, tgt, dev_set, test_set + ) dataset = IWSLT2016( - root=root_dir, split=split, language_pair=(src, tgt), valid_set=dev_set, test_set=test_set + root=root_dir, + split=split, + language_pair=(src, tgt), + valid_set=dev_set, + test_set=test_set, ) samples = list(dataset) @@ -171,9 +201,23 @@ def test_iwslt2016_split_argument(self, split): language_pair = ("de", "en") valid_set = "tst2013" test_set = "tst2014" - _ = _get_mock_dataset(root_dir, split, language_pair[0], language_pair[1], valid_set, test_set) - dataset1 = IWSLT2016(root=root_dir, split=split, language_pair=language_pair, valid_set=valid_set, test_set=test_set) - (dataset2,) = IWSLT2016(root=root_dir, split=(split,), language_pair=language_pair, valid_set=valid_set, test_set=test_set) + _ = _get_mock_dataset( + root_dir, split, language_pair[0], language_pair[1], valid_set, test_set + ) + dataset1 = IWSLT2016( + root=root_dir, + split=split, + language_pair=language_pair, + valid_set=valid_set, + test_set=test_set, + ) + (dataset2,) = IWSLT2016( + root=root_dir, + split=(split,), + language_pair=language_pair, + valid_set=valid_set, + test_set=test_set, + ) for d1, d2 in zip_equal(dataset1, dataset2): self.assertEqual(d1, d2) diff --git a/test/datasets/test_iwslt2017.py b/test/datasets/test_iwslt2017.py index e5595821b0..375ca4525d 100644 --- a/test/datasets/test_iwslt2017.py +++ b/test/datasets/test_iwslt2017.py @@ -8,13 +8,20 @@ from unittest.mock import patch from parameterized import parameterized -from torchtext.datasets.iwslt2017 import DATASET_NAME, IWSLT2017, SUPPORTED_DATASETS, _PATH from torchtext.data.datasets_utils import _generate_iwslt_files_for_lang_and_split +from torchtext.datasets.iwslt2017 import ( + DATASET_NAME, + IWSLT2017, + SUPPORTED_DATASETS, + _PATH, +) from ..common.case_utils import zip_equal from ..common.torchtext_test_case import TorchtextTestCase -SUPPORTED_LANGPAIRS = [(k, e) for k, v in SUPPORTED_DATASETS["language_pair"].items() for e in v] +SUPPORTED_LANGPAIRS = [ + (k, e) for k, v in SUPPORTED_DATASETS["language_pair"].items() for e in v +] def _generate_uncleaned_train(): @@ -22,13 +29,19 @@ def _generate_uncleaned_train(): file_contents = [] examples = [] xml_tags = [ - ' en) lang = os.path.splitext(unclean_file_name)[1][1:] @@ -141,15 +160,19 @@ def tearDownClass(cls): cls.patcher.stop() super().tearDownClass() - @parameterized.expand([ - (split, src, tgt) - for split in ("train", "valid", "test") - for src, tgt in SUPPORTED_LANGPAIRS - ]) + @parameterized.expand( + [ + (split, src, tgt) + for split in ("train", "valid", "test") + for src, tgt in SUPPORTED_LANGPAIRS + ] + ) def test_iwslt2017(self, split, src, tgt): with tempfile.TemporaryDirectory() as root_dir: - expected_samples = _get_mock_dataset(root_dir, split, src, tgt, "dev2010", "tst2010") + expected_samples = _get_mock_dataset( + root_dir, split, src, tgt, "dev2010", "tst2010" + ) dataset = IWSLT2017(root=root_dir, split=split, language_pair=(src, tgt)) @@ -164,9 +187,15 @@ def test_iwslt2017_split_argument(self, split): language_pair = ("de", "en") valid_set = "dev2010" test_set = "tst2010" - _ = _get_mock_dataset(root_dir, split, language_pair[0], language_pair[1], valid_set, test_set) - dataset1 = IWSLT2017(root=root_dir, split=split, language_pair=language_pair) - (dataset2,) = IWSLT2017(root=root_dir, split=(split,), language_pair=language_pair) + _ = _get_mock_dataset( + root_dir, split, language_pair[0], language_pair[1], valid_set, test_set + ) + dataset1 = IWSLT2017( + root=root_dir, split=split, language_pair=language_pair + ) + (dataset2,) = IWSLT2017( + root=root_dir, split=(split,), language_pair=language_pair + ) for d1, d2 in zip_equal(dataset1, dataset2): self.assertEqual(d1, d2) diff --git a/test/datasets/test_multi30k.py b/test/datasets/test_multi30k.py index 6782a21753..d0e9e96c04 100644 --- a/test/datasets/test_multi30k.py +++ b/test/datasets/test_multi30k.py @@ -5,10 +5,10 @@ from collections import defaultdict from unittest.mock import patch -from ..common.parameterized_utils import nested_params from torchtext.datasets import Multi30k from ..common.case_utils import TempDirMixin, zip_equal +from ..common.parameterized_utils import nested_params from ..common.torchtext_test_case import TorchtextTestCase @@ -68,14 +68,24 @@ def test_multi30k(self, split, language_pair): if split == "valid": split = "val" samples = list(dataset) - expected_samples = [(d1, d2) for d1, d2 in zip(self.samples[f'{split}.{language_pair[0]}'], self.samples[f'{split}.{language_pair[1]}'])] + expected_samples = [ + (d1, d2) + for d1, d2 in zip( + self.samples[f"{split}.{language_pair[0]}"], + self.samples[f"{split}.{language_pair[1]}"], + ) + ] for sample, expected_sample in zip_equal(samples, expected_samples): self.assertEqual(sample, expected_sample) @nested_params(["train", "valid", "test"], [("de", "en"), ("en", "de")]) def test_multi30k_split_argument(self, split, language_pair): - dataset1 = Multi30k(root=self.root_dir, split=split, language_pair=language_pair) - (dataset2,) = Multi30k(root=self.root_dir, split=(split,), language_pair=language_pair) + dataset1 = Multi30k( + root=self.root_dir, split=split, language_pair=language_pair + ) + (dataset2,) = Multi30k( + root=self.root_dir, split=(split,), language_pair=language_pair + ) for d1, d2 in zip_equal(dataset1, dataset2): self.assertEqual(d1, d2) diff --git a/test/datasets/test_sogounews.py b/test/datasets/test_sogounews.py index b0c06b4e35..95b53f87f1 100644 --- a/test/datasets/test_sogounews.py +++ b/test/datasets/test_sogounews.py @@ -36,9 +36,7 @@ def _get_mock_dataset(root_dir): f.write(f'"{label}","{rand_string}","{rand_string}"\n') seed += 1 - compressed_dataset_path = os.path.join( - base_dir, "sogou_news_csv.tar.gz" - ) + compressed_dataset_path = os.path.join(base_dir, "sogou_news_csv.tar.gz") # create tar file from dataset folder with tarfile.open(compressed_dataset_path, "w:gz") as tar: tar.add(temp_dataset_dir, arcname="sogou_news_csv") diff --git a/test/datasets/test_udpos.py b/test/datasets/test_udpos.py index b66a6c681d..455a7cc019 100644 --- a/test/datasets/test_udpos.py +++ b/test/datasets/test_udpos.py @@ -27,11 +27,20 @@ def _get_mock_dataset(root_dir): mocked_lines = mocked_data[os.path.splitext(file_name)[0]] with open(txt_file, "w") as f: for i in range(5): - rand_strings = ["".join(random.sample(string.ascii_letters, random.randint(1, 10))) for i in range(seed)] - rand_label_1 = [random.choice(string.ascii_letters) for i in range(seed)] - rand_label_2 = [random.choice(string.ascii_letters) for i in range(seed)] + rand_strings = [ + "".join(random.sample(string.ascii_letters, random.randint(1, 10))) + for i in range(seed) + ] + rand_label_1 = [ + random.choice(string.ascii_letters) for i in range(seed) + ] + rand_label_2 = [ + random.choice(string.ascii_letters) for i in range(seed) + ] # one token per line (each sample ends with an extra \n) - for rand_string, label_1, label_2 in zip(rand_strings, rand_label_1, rand_label_2): + for rand_string, label_1, label_2 in zip( + rand_strings, rand_label_1, rand_label_2 + ): f.write(f"{rand_string}\t{label_1}\t{label_2}\n") f.write("\n") dataset_line = (rand_strings, rand_label_1, rand_label_2) @@ -73,7 +82,9 @@ def tearDownClass(cls): def test_udpos(self, split): dataset = UDPOS(root=self.root_dir, split=split) samples = list(dataset) - expected_samples = self.samples[split] if split != "valid" else self.samples["dev"] + expected_samples = ( + self.samples[split] if split != "valid" else self.samples["dev"] + ) for sample, expected_sample in zip_equal(samples, expected_samples): self.assertEqual(sample, expected_sample) diff --git a/test/datasets/test_wikitexts.py b/test/datasets/test_wikitexts.py index 616c98f472..36c26db027 100644 --- a/test/datasets/test_wikitexts.py +++ b/test/datasets/test_wikitexts.py @@ -5,11 +5,11 @@ from collections import defaultdict from unittest.mock import patch -from ..common.parameterized_utils import nested_params from torchtext.datasets.wikitext103 import WikiText103 from torchtext.datasets.wikitext2 import WikiText2 from ..common.case_utils import TempDirMixin, zip_equal +from ..common.parameterized_utils import nested_params from ..common.torchtext_test_case import TorchtextTestCase @@ -34,7 +34,7 @@ def _get_mock_dataset(root_dir, base_dir_name): random.choice(string.ascii_letters) for i in range(seed) ) dataset_line = rand_string - f.write(f'{rand_string}\n') + f.write(f"{rand_string}\n") # append line to correct dataset split mocked_lines.append(dataset_line) @@ -75,7 +75,9 @@ def tearDownClass(cls): @nested_params([WikiText103, WikiText2], ["train", "valid", "test"]) def test_wikitexts(self, wikitext_dataset, split): - expected_samples = _get_mock_dataset(self.root_dir, base_dir_name=wikitext_dataset.__name__)[split] + expected_samples = _get_mock_dataset( + self.root_dir, base_dir_name=wikitext_dataset.__name__ + )[split] dataset = wikitext_dataset(root=self.root_dir, split=split) samples = list(dataset) diff --git a/test/datasets/test_yelpreviews.py b/test/datasets/test_yelpreviews.py index 241b52b7fa..3f4ccde54f 100644 --- a/test/datasets/test_yelpreviews.py +++ b/test/datasets/test_yelpreviews.py @@ -5,11 +5,11 @@ from collections import defaultdict from unittest.mock import patch -from ..common.parameterized_utils import nested_params -from torchtext.datasets.yelpreviewpolarity import YelpReviewPolarity from torchtext.datasets.yelpreviewfull import YelpReviewFull +from torchtext.datasets.yelpreviewpolarity import YelpReviewPolarity from ..common.case_utils import TempDirMixin, zip_equal +from ..common.parameterized_utils import nested_params from ..common.torchtext_test_case import TorchtextTestCase @@ -76,7 +76,9 @@ def tearDownClass(cls): @nested_params([YelpReviewPolarity, YelpReviewFull], ["train", "test"]) def test_yelpreviews(self, yelp_dataset, split): - expected_samples = _get_mock_dataset(self.root_dir, base_dir_name=yelp_dataset.__name__)[split] + expected_samples = _get_mock_dataset( + self.root_dir, base_dir_name=yelp_dataset.__name__ + )[split] dataset = yelp_dataset(root=self.root_dir, split=split) samples = list(dataset)