diff --git a/test/datasets/test_squad1.py b/test/datasets/test_squad.py similarity index 69% rename from test/datasets/test_squad1.py rename to test/datasets/test_squad.py index 75f1f61639..d44b6637f1 100644 --- a/test/datasets/test_squad1.py +++ b/test/datasets/test_squad.py @@ -7,11 +7,12 @@ from random import randint from unittest.mock import patch -from parameterized import parameterized from torchtext.data.datasets_utils import _ParseSQuADQAData from torchtext.datasets.squad1 import SQuAD1 +from torchtext.datasets.squad2 import SQuAD2 from ..common.case_utils import TempDirMixin, zip_equal +from ..common.parameterized_utils import nested_params from ..common.torchtext_test_case import TorchtextTestCase @@ -44,15 +45,20 @@ def _get_mock_json_data(): return mock_json_data -def _get_mock_dataset(root_dir): +def _get_mock_dataset(root_dir, base_dir_name): """ root_dir: directory to the mocked dataset """ - base_dir = os.path.join(root_dir, "SQuAD1") + base_dir = os.path.join(root_dir, base_dir_name) os.makedirs(base_dir, exist_ok=True) + if base_dir_name == SQuAD1.__name__: + file_names = ("train-v1.1.json", "dev-v1.1.json") + else: + file_names = ("train-v2.0.json", "dev-v2.0.json") + mocked_data = defaultdict(list) - for file_name in ("train-v1.1.json", "dev-v1.1.json"): + for file_name in file_names: txt_file = os.path.join(base_dir, file_name) with open(txt_file, "w") as f: mock_json_data = _get_mock_json_data() @@ -67,7 +73,7 @@ def _get_mock_dataset(root_dir): return mocked_data -class TestSQuAD1(TempDirMixin, TorchtextTestCase): +class TestSQuAD(TempDirMixin, TorchtextTestCase): root_dir = None samples = [] @@ -75,7 +81,6 @@ class TestSQuAD1(TempDirMixin, TorchtextTestCase): def setUpClass(cls): super().setUpClass() cls.root_dir = cls.get_base_temp_dir() - cls.samples = _get_mock_dataset(cls.root_dir) cls.patcher = patch( "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True ) @@ -86,19 +91,24 @@ def tearDownClass(cls): cls.patcher.stop() super().tearDownClass() - @parameterized.expand(["train", "dev"]) - def test_squad1(self, split): - dataset = SQuAD1(root=self.root_dir, split=split) - + @nested_params([SQuAD1, SQuAD2], ["train", "dev"]) + def test_squad(self, squad_dataset, split): + expected_samples = _get_mock_dataset(self.root_dir, squad_dataset.__name__)[ + split + ] + dataset = squad_dataset(root=self.root_dir, split=split) samples = list(dataset) - expected_samples = self.samples[split] + for sample, expected_sample in zip_equal(samples, expected_samples): self.assertEqual(sample, expected_sample) - @parameterized.expand(["train", "dev"]) - def test_squad1_split_argument(self, split): - dataset1 = SQuAD1(root=self.root_dir, split=split) - (dataset2,) = SQuAD1(root=self.root_dir, split=(split,)) + @nested_params([SQuAD1, SQuAD2], ["train", "dev"]) + def test_squad_split_argument(self, squad_dataset, split): + # call `_get_mock_dataset` to create mock dataset files + _ = _get_mock_dataset(self.root_dir, squad_dataset.__name__) + + dataset1 = squad_dataset(root=self.root_dir, split=split) + (dataset2,) = squad_dataset(root=self.root_dir, split=(split,)) for d1, d2 in zip_equal(dataset1, dataset2): self.assertEqual(d1, d2)