diff --git a/test/common/case_utils.py b/test/common/case_utils.py index f8803894b0..9a9340fbff 100644 --- a/test/common/case_utils.py +++ b/test/common/case_utils.py @@ -1,6 +1,7 @@ import os.path import tempfile import unittest +from itertools import zip_longest from torchtext._internal.module_utils import is_module_available @@ -37,4 +38,18 @@ def get_temp_path(self, *paths): def skipIfNoModule(module, display_name=None): display_name = display_name or module - return unittest.skipIf(not is_module_available(module), f'"{display_name}" is not available') + return unittest.skipIf( + not is_module_available(module), f'"{display_name}" is not available' + ) + + +def zip_equal(*iterables): + """With the regular Python `zip` function, if one iterable is longer than the other, + the remainder portions are ignored.This is resolved in Python 3.10 where we can use + `strict=True` in the `zip` function + """ + sentinel = object() + for combo in zip_longest(*iterables, fillvalue=sentinel): + if sentinel in combo: + raise ValueError("Iterables have different lengths") + yield combo diff --git a/test/datasets/test_sst2.py b/test/datasets/test_sst2.py new file mode 100644 index 0000000000..29fdb6fbed --- /dev/null +++ b/test/datasets/test_sst2.py @@ -0,0 +1,92 @@ +import os +import random +import string +import zipfile +from collections import defaultdict +from unittest.mock import patch + +from parameterized import parameterized +from torchtext.datasets.sst2 import SST2 + +from ..common.case_utils import TempDirMixin, zip_equal +from ..common.torchtext_test_case import TorchtextTestCase + + +def _get_mock_dataset(root_dir): + """ + root_dir: directory to the mocked dataset + """ + base_dir = os.path.join(root_dir, "SST2") + temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir") + os.makedirs(temp_dataset_dir, exist_ok=True) + + seed = 1 + mocked_data = defaultdict(list) + for file_name, (col1_name, col2_name) in zip( + ("train.tsv", "test.tsv", "dev.tsv"), + ((("sentence", "label"), ("sentence", "label"), ("index", "sentence"))), + ): + txt_file = os.path.join(temp_dataset_dir, file_name) + with open(txt_file, "w") as f: + f.write(f"{col1_name}\t{col2_name}\n") + for i in range(5): + label = seed % 2 + rand_string = " ".join( + random.choice(string.ascii_letters) for i in range(seed) + ) + if file_name == "test.tsv": + dataset_line = (f"{rand_string} .",) + f.write(f"{i}\t{rand_string} .\n") + else: + dataset_line = (f"{rand_string} .", label) + f.write(f"{rand_string} .\t{label}\n") + + # append line to correct dataset split + mocked_data[os.path.splitext(file_name)[0]].append(dataset_line) + seed += 1 + + compressed_dataset_path = os.path.join(base_dir, "SST-2.zip") + # create zip file from dataset folder + with zipfile.ZipFile(compressed_dataset_path, "w") as zip_file: + for file_name in ("train.tsv", "test.tsv", "dev.tsv"): + txt_file = os.path.join(temp_dataset_dir, file_name) + zip_file.write(txt_file, arcname=os.path.join("SST-2", file_name)) + + return mocked_data + + +class TestSST2(TempDirMixin, TorchtextTestCase): + root_dir = None + samples = [] + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.root_dir = cls.get_base_temp_dir() + cls.samples = _get_mock_dataset(cls.root_dir) + cls.patcher = patch( + "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True + ) + cls.patcher.start() + + @classmethod + def tearDownClass(cls): + cls.patcher.stop() + super().tearDownClass() + + @parameterized.expand(["train", "test", "dev"]) + def test_sst2(self, split): + dataset = SST2(root=self.root_dir, split=split) + + samples = list(dataset) + expected_samples = self.samples[split] + for sample, expected_sample in zip_equal(samples, expected_samples): + self.assertEqual(sample, expected_sample) + + @parameterized.expand(["train", "test", "dev"]) + def test_sst2_split_argument(self, split): + dataset1 = SST2(root=self.root_dir, split=split) + (dataset2,) = SST2(root=self.root_dir, split=(split,)) + + for d1, d2 in zip_equal(dataset1, dataset2): + self.assertEqual(d1, d2)