diff --git a/test/datasets/test_cc100.py b/test/datasets/test_cc100.py index ef7731738a..978d447f67 100644 --- a/test/datasets/test_cc100.py +++ b/test/datasets/test_cc100.py @@ -1,18 +1,17 @@ +import lzma import os import random import string -import lzma -from parameterized import parameterized from collections import defaultdict from unittest.mock import patch +from parameterized import parameterized from torchtext.datasets import CC100 +from torchtext.datasets.cc100 import VALID_CODES from ..common.case_utils import TempDirMixin, zip_equal from ..common.torchtext_test_case import TorchtextTestCase -from torchtext.datasets.cc100 import VALID_CODES - def _get_mock_dataset(root_dir): """ diff --git a/test/datasets/test_conll2000chunking.py b/test/datasets/test_conll2000chunking.py index cfbfe730ea..29ae228ba2 100644 --- a/test/datasets/test_conll2000chunking.py +++ b/test/datasets/test_conll2000chunking.py @@ -1,7 +1,7 @@ +import gzip import os import random import string -import gzip from collections import defaultdict from unittest.mock import patch @@ -27,11 +27,19 @@ def _get_mock_dataset(root_dir): mocked_lines = mocked_data[os.path.splitext(file_name)[0]] with open(txt_file, "w") as f: for i in range(5): - rand_strings = [random.choice(string.ascii_letters) for i in range(seed)] - rand_label_1 = [random.choice(string.ascii_letters) for i in range(seed)] - rand_label_2 = [random.choice(string.ascii_letters) for i in range(seed)] + rand_strings = [ + random.choice(string.ascii_letters) for i in range(seed) + ] + rand_label_1 = [ + random.choice(string.ascii_letters) for i in range(seed) + ] + rand_label_2 = [ + random.choice(string.ascii_letters) for i in range(seed) + ] # one token per line (each sample ends with an extra \n) - for rand_string, label_1, label_2 in zip(rand_strings, rand_label_1, rand_label_2): + for rand_string, label_1, label_2 in zip( + rand_strings, rand_label_1, rand_label_2 + ): f.write(f"{rand_string} {label_1} {label_2}\n") f.write("\n") dataset_line = (rand_strings, rand_label_1, rand_label_2) @@ -41,7 +49,9 @@ def _get_mock_dataset(root_dir): # create gz file from dataset folder compressed_dataset_path = os.path.join(base_dir, f"{file_name}.gz") - with gzip.open(compressed_dataset_path, "wb") as gz_file, open(txt_file, "rb") as file_in: + with gzip.open(compressed_dataset_path, "wb") as gz_file, open( + txt_file, "rb" + ) as file_in: gz_file.writelines(file_in) return mocked_data diff --git a/test/datasets/test_enwik9.py b/test/datasets/test_enwik9.py index 16fe349bee..71854c4070 100644 --- a/test/datasets/test_enwik9.py +++ b/test/datasets/test_enwik9.py @@ -24,10 +24,12 @@ def _get_mock_dataset(root_dir): mocked_data = [] with open(txt_file, "w") as f: for i in range(5): - rand_string = "<" + " ".join( - random.choice(string.ascii_letters) for i in range(seed) - ) + ">" - dataset_line = (f"'{rand_string}'") + rand_string = ( + "<" + + " ".join(random.choice(string.ascii_letters) for i in range(seed)) + + ">" + ) + dataset_line = f"'{rand_string}'" f.write(f"'{rand_string}'\n") # append line to correct dataset split diff --git a/test/datasets/test_iwslt2016.py b/test/datasets/test_iwslt2016.py index f2e71e8595..03b681f6f1 100644 --- a/test/datasets/test_iwslt2016.py +++ b/test/datasets/test_iwslt2016.py @@ -1,23 +1,34 @@ +import itertools import os import random import shutil import string import tarfile -import itertools import tempfile from collections import defaultdict from unittest.mock import patch from parameterized import parameterized -from torchtext.datasets.iwslt2016 import DATASET_NAME, IWSLT2016, SUPPORTED_DATASETS, SET_NOT_EXISTS from torchtext.data.datasets_utils import _generate_iwslt_files_for_lang_and_split +from torchtext.datasets.iwslt2016 import ( + DATASET_NAME, + IWSLT2016, + SUPPORTED_DATASETS, + SET_NOT_EXISTS, +) from ..common.case_utils import zip_equal from ..common.torchtext_test_case import TorchtextTestCase -SUPPORTED_LANGPAIRS = [(k, e) for k, v in SUPPORTED_DATASETS["language_pair"].items() for e in v] +SUPPORTED_LANGPAIRS = [ + (k, e) for k, v in SUPPORTED_DATASETS["language_pair"].items() for e in v +] SUPPORTED_DEVTEST_SPLITS = SUPPORTED_DATASETS["valid_test"] -DEV_TEST_SPLITS = [(dev, test) for dev, test in itertools.product(SUPPORTED_DEVTEST_SPLITS, repeat=2) if dev != test] +DEV_TEST_SPLITS = [ + (dev, test) + for dev, test in itertools.product(SUPPORTED_DEVTEST_SPLITS, repeat=2) + if dev != test +] def _generate_uncleaned_train(): @@ -25,13 +36,19 @@ def _generate_uncleaned_train(): file_contents = [] examples = [] xml_tags = [ - ' en) lang = os.path.splitext(unclean_file_name)[1][1:] @@ -144,20 +163,31 @@ def tearDownClass(cls): cls.patcher.stop() super().tearDownClass() - @parameterized.expand([ - (split, src, tgt, dev_set, test_set) - for split in ("train", "valid", "test") - for dev_set, test_set in DEV_TEST_SPLITS - for src, tgt in SUPPORTED_LANGPAIRS - if (dev_set not in SET_NOT_EXISTS[(src, tgt)] and test_set not in SET_NOT_EXISTS[(src, tgt)]) - ]) + @parameterized.expand( + [ + (split, src, tgt, dev_set, test_set) + for split in ("train", "valid", "test") + for dev_set, test_set in DEV_TEST_SPLITS + for src, tgt in SUPPORTED_LANGPAIRS + if ( + dev_set not in SET_NOT_EXISTS[(src, tgt)] + and test_set not in SET_NOT_EXISTS[(src, tgt)] + ) + ] + ) def test_iwslt2016(self, split, src, tgt, dev_set, test_set): with tempfile.TemporaryDirectory() as root_dir: - expected_samples = _get_mock_dataset(root_dir, split, src, tgt, dev_set, test_set) + expected_samples = _get_mock_dataset( + root_dir, split, src, tgt, dev_set, test_set + ) dataset = IWSLT2016( - root=root_dir, split=split, language_pair=(src, tgt), valid_set=dev_set, test_set=test_set + root=root_dir, + split=split, + language_pair=(src, tgt), + valid_set=dev_set, + test_set=test_set, ) samples = list(dataset) @@ -171,9 +201,23 @@ def test_iwslt2016_split_argument(self, split): language_pair = ("de", "en") valid_set = "tst2013" test_set = "tst2014" - _ = _get_mock_dataset(root_dir, split, language_pair[0], language_pair[1], valid_set, test_set) - dataset1 = IWSLT2016(root=root_dir, split=split, language_pair=language_pair, valid_set=valid_set, test_set=test_set) - (dataset2,) = IWSLT2016(root=root_dir, split=(split,), language_pair=language_pair, valid_set=valid_set, test_set=test_set) + _ = _get_mock_dataset( + root_dir, split, language_pair[0], language_pair[1], valid_set, test_set + ) + dataset1 = IWSLT2016( + root=root_dir, + split=split, + language_pair=language_pair, + valid_set=valid_set, + test_set=test_set, + ) + (dataset2,) = IWSLT2016( + root=root_dir, + split=(split,), + language_pair=language_pair, + valid_set=valid_set, + test_set=test_set, + ) for d1, d2 in zip_equal(dataset1, dataset2): self.assertEqual(d1, d2) diff --git a/test/datasets/test_iwslt2017.py b/test/datasets/test_iwslt2017.py index e5595821b0..375ca4525d 100644 --- a/test/datasets/test_iwslt2017.py +++ b/test/datasets/test_iwslt2017.py @@ -8,13 +8,20 @@ from unittest.mock import patch from parameterized import parameterized -from torchtext.datasets.iwslt2017 import DATASET_NAME, IWSLT2017, SUPPORTED_DATASETS, _PATH from torchtext.data.datasets_utils import _generate_iwslt_files_for_lang_and_split +from torchtext.datasets.iwslt2017 import ( + DATASET_NAME, + IWSLT2017, + SUPPORTED_DATASETS, + _PATH, +) from ..common.case_utils import zip_equal from ..common.torchtext_test_case import TorchtextTestCase -SUPPORTED_LANGPAIRS = [(k, e) for k, v in SUPPORTED_DATASETS["language_pair"].items() for e in v] +SUPPORTED_LANGPAIRS = [ + (k, e) for k, v in SUPPORTED_DATASETS["language_pair"].items() for e in v +] def _generate_uncleaned_train(): @@ -22,13 +29,19 @@ def _generate_uncleaned_train(): file_contents = [] examples = [] xml_tags = [ - ' en) lang = os.path.splitext(unclean_file_name)[1][1:] @@ -141,15 +160,19 @@ def tearDownClass(cls): cls.patcher.stop() super().tearDownClass() - @parameterized.expand([ - (split, src, tgt) - for split in ("train", "valid", "test") - for src, tgt in SUPPORTED_LANGPAIRS - ]) + @parameterized.expand( + [ + (split, src, tgt) + for split in ("train", "valid", "test") + for src, tgt in SUPPORTED_LANGPAIRS + ] + ) def test_iwslt2017(self, split, src, tgt): with tempfile.TemporaryDirectory() as root_dir: - expected_samples = _get_mock_dataset(root_dir, split, src, tgt, "dev2010", "tst2010") + expected_samples = _get_mock_dataset( + root_dir, split, src, tgt, "dev2010", "tst2010" + ) dataset = IWSLT2017(root=root_dir, split=split, language_pair=(src, tgt)) @@ -164,9 +187,15 @@ def test_iwslt2017_split_argument(self, split): language_pair = ("de", "en") valid_set = "dev2010" test_set = "tst2010" - _ = _get_mock_dataset(root_dir, split, language_pair[0], language_pair[1], valid_set, test_set) - dataset1 = IWSLT2017(root=root_dir, split=split, language_pair=language_pair) - (dataset2,) = IWSLT2017(root=root_dir, split=(split,), language_pair=language_pair) + _ = _get_mock_dataset( + root_dir, split, language_pair[0], language_pair[1], valid_set, test_set + ) + dataset1 = IWSLT2017( + root=root_dir, split=split, language_pair=language_pair + ) + (dataset2,) = IWSLT2017( + root=root_dir, split=(split,), language_pair=language_pair + ) for d1, d2 in zip_equal(dataset1, dataset2): self.assertEqual(d1, d2) diff --git a/test/datasets/test_multi30k.py b/test/datasets/test_multi30k.py index 6782a21753..d0e9e96c04 100644 --- a/test/datasets/test_multi30k.py +++ b/test/datasets/test_multi30k.py @@ -5,10 +5,10 @@ from collections import defaultdict from unittest.mock import patch -from ..common.parameterized_utils import nested_params from torchtext.datasets import Multi30k from ..common.case_utils import TempDirMixin, zip_equal +from ..common.parameterized_utils import nested_params from ..common.torchtext_test_case import TorchtextTestCase @@ -68,14 +68,24 @@ def test_multi30k(self, split, language_pair): if split == "valid": split = "val" samples = list(dataset) - expected_samples = [(d1, d2) for d1, d2 in zip(self.samples[f'{split}.{language_pair[0]}'], self.samples[f'{split}.{language_pair[1]}'])] + expected_samples = [ + (d1, d2) + for d1, d2 in zip( + self.samples[f"{split}.{language_pair[0]}"], + self.samples[f"{split}.{language_pair[1]}"], + ) + ] for sample, expected_sample in zip_equal(samples, expected_samples): self.assertEqual(sample, expected_sample) @nested_params(["train", "valid", "test"], [("de", "en"), ("en", "de")]) def test_multi30k_split_argument(self, split, language_pair): - dataset1 = Multi30k(root=self.root_dir, split=split, language_pair=language_pair) - (dataset2,) = Multi30k(root=self.root_dir, split=(split,), language_pair=language_pair) + dataset1 = Multi30k( + root=self.root_dir, split=split, language_pair=language_pair + ) + (dataset2,) = Multi30k( + root=self.root_dir, split=(split,), language_pair=language_pair + ) for d1, d2 in zip_equal(dataset1, dataset2): self.assertEqual(d1, d2) diff --git a/test/datasets/test_sogounews.py b/test/datasets/test_sogounews.py index b0c06b4e35..95b53f87f1 100644 --- a/test/datasets/test_sogounews.py +++ b/test/datasets/test_sogounews.py @@ -36,9 +36,7 @@ def _get_mock_dataset(root_dir): f.write(f'"{label}","{rand_string}","{rand_string}"\n') seed += 1 - compressed_dataset_path = os.path.join( - base_dir, "sogou_news_csv.tar.gz" - ) + compressed_dataset_path = os.path.join(base_dir, "sogou_news_csv.tar.gz") # create tar file from dataset folder with tarfile.open(compressed_dataset_path, "w:gz") as tar: tar.add(temp_dataset_dir, arcname="sogou_news_csv") diff --git a/test/datasets/test_udpos.py b/test/datasets/test_udpos.py index b66a6c681d..455a7cc019 100644 --- a/test/datasets/test_udpos.py +++ b/test/datasets/test_udpos.py @@ -27,11 +27,20 @@ def _get_mock_dataset(root_dir): mocked_lines = mocked_data[os.path.splitext(file_name)[0]] with open(txt_file, "w") as f: for i in range(5): - rand_strings = ["".join(random.sample(string.ascii_letters, random.randint(1, 10))) for i in range(seed)] - rand_label_1 = [random.choice(string.ascii_letters) for i in range(seed)] - rand_label_2 = [random.choice(string.ascii_letters) for i in range(seed)] + rand_strings = [ + "".join(random.sample(string.ascii_letters, random.randint(1, 10))) + for i in range(seed) + ] + rand_label_1 = [ + random.choice(string.ascii_letters) for i in range(seed) + ] + rand_label_2 = [ + random.choice(string.ascii_letters) for i in range(seed) + ] # one token per line (each sample ends with an extra \n) - for rand_string, label_1, label_2 in zip(rand_strings, rand_label_1, rand_label_2): + for rand_string, label_1, label_2 in zip( + rand_strings, rand_label_1, rand_label_2 + ): f.write(f"{rand_string}\t{label_1}\t{label_2}\n") f.write("\n") dataset_line = (rand_strings, rand_label_1, rand_label_2) @@ -73,7 +82,9 @@ def tearDownClass(cls): def test_udpos(self, split): dataset = UDPOS(root=self.root_dir, split=split) samples = list(dataset) - expected_samples = self.samples[split] if split != "valid" else self.samples["dev"] + expected_samples = ( + self.samples[split] if split != "valid" else self.samples["dev"] + ) for sample, expected_sample in zip_equal(samples, expected_samples): self.assertEqual(sample, expected_sample) diff --git a/test/datasets/test_wikitexts.py b/test/datasets/test_wikitexts.py index 616c98f472..36c26db027 100644 --- a/test/datasets/test_wikitexts.py +++ b/test/datasets/test_wikitexts.py @@ -5,11 +5,11 @@ from collections import defaultdict from unittest.mock import patch -from ..common.parameterized_utils import nested_params from torchtext.datasets.wikitext103 import WikiText103 from torchtext.datasets.wikitext2 import WikiText2 from ..common.case_utils import TempDirMixin, zip_equal +from ..common.parameterized_utils import nested_params from ..common.torchtext_test_case import TorchtextTestCase @@ -34,7 +34,7 @@ def _get_mock_dataset(root_dir, base_dir_name): random.choice(string.ascii_letters) for i in range(seed) ) dataset_line = rand_string - f.write(f'{rand_string}\n') + f.write(f"{rand_string}\n") # append line to correct dataset split mocked_lines.append(dataset_line) @@ -75,7 +75,9 @@ def tearDownClass(cls): @nested_params([WikiText103, WikiText2], ["train", "valid", "test"]) def test_wikitexts(self, wikitext_dataset, split): - expected_samples = _get_mock_dataset(self.root_dir, base_dir_name=wikitext_dataset.__name__)[split] + expected_samples = _get_mock_dataset( + self.root_dir, base_dir_name=wikitext_dataset.__name__ + )[split] dataset = wikitext_dataset(root=self.root_dir, split=split) samples = list(dataset) diff --git a/test/datasets/test_yelpreviews.py b/test/datasets/test_yelpreviews.py index 241b52b7fa..3f4ccde54f 100644 --- a/test/datasets/test_yelpreviews.py +++ b/test/datasets/test_yelpreviews.py @@ -5,11 +5,11 @@ from collections import defaultdict from unittest.mock import patch -from ..common.parameterized_utils import nested_params -from torchtext.datasets.yelpreviewpolarity import YelpReviewPolarity from torchtext.datasets.yelpreviewfull import YelpReviewFull +from torchtext.datasets.yelpreviewpolarity import YelpReviewPolarity from ..common.case_utils import TempDirMixin, zip_equal +from ..common.parameterized_utils import nested_params from ..common.torchtext_test_case import TorchtextTestCase @@ -76,7 +76,9 @@ def tearDownClass(cls): @nested_params([YelpReviewPolarity, YelpReviewFull], ["train", "test"]) def test_yelpreviews(self, yelp_dataset, split): - expected_samples = _get_mock_dataset(self.root_dir, base_dir_name=yelp_dataset.__name__)[split] + expected_samples = _get_mock_dataset( + self.root_dir, base_dir_name=yelp_dataset.__name__ + )[split] dataset = yelp_dataset(root=self.root_dir, split=split) samples = list(dataset)