From 2951d42753034a647dfa1e64f8de0cfdc2d5f09d Mon Sep 17 00:00:00 2001 From: nayef211 Date: Thu, 3 Feb 2022 06:16:12 -0800 Subject: [PATCH 1/2] Created squad1 test --- test/datasets/test_squad1.py | 83 ++++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 test/datasets/test_squad1.py diff --git a/test/datasets/test_squad1.py b/test/datasets/test_squad1.py new file mode 100644 index 0000000000..9a9a608aaa --- /dev/null +++ b/test/datasets/test_squad1.py @@ -0,0 +1,83 @@ +import os +import random +import string +import tarfile +from collections import defaultdict +from unittest.mock import patch + +from parameterized import parameterized +from torchtext.datasets.squad1 import SQuAD1 + +from ..common.case_utils import TempDirMixin, zip_equal +from ..common.torchtext_test_case import TorchtextTestCase + + +def _get_mock_dataset(root_dir): + """ + root_dir: directory to the mocked dataset + """ + base_dir = os.path.join(root_dir, "SQuAD1") + temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir") + os.makedirs(temp_dataset_dir, exist_ok=True) + + seed = 1 + mocked_data = defaultdict(list) + for file_name in ("train.csv", "test.csv"): + txt_file = os.path.join(temp_dataset_dir, file_name) + with open(txt_file, "w") as f: + for i in range(5): + label = seed % 2 + 1 + rand_string = " ".join( + random.choice(string.ascii_letters) for i in range(seed) + ) + dataset_line = (label, f"{rand_string} {rand_string}") + # append line to correct dataset split + mocked_data[os.path.splitext(file_name)[0]].append(dataset_line) + f.write(f'"{label}","{rand_string}","{rand_string}"\n') + seed += 1 + + compressed_dataset_path = os.path.join( + base_dir, "amazon_review_polarity_csv.tar.gz" + ) + # create tar file from dataset folder + with tarfile.open(compressed_dataset_path, "w:gz") as tar: + tar.add(temp_dataset_dir, arcname="amazon_review_polarity_csv") + + return mocked_data + + +class TestSQuAD1(TempDirMixin, TorchtextTestCase): + root_dir = None + samples = [] + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.root_dir = cls.get_base_temp_dir() + cls.samples = _get_mock_dataset(cls.root_dir) + cls.patcher = patch( + "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True + ) + cls.patcher.start() + + @classmethod + def tearDownClass(cls): + cls.patcher.stop() + super().tearDownClass() + + @parameterized.expand(["train", "dev"]) + def test_squad1(self, split): + dataset = SQuAD1(root=self.root_dir, split=split) + + samples = list(dataset) + expected_samples = self.samples[split] + for sample, expected_sample in zip_equal(samples, expected_samples): + self.assertEqual(sample, expected_sample) + + @parameterized.expand(["train", "dev"]) + def test_squad1_split_argument(self, split): + dataset1 = SQuAD1(root=self.root_dir, split=split) + (dataset2,) = SQuAD1(root=self.root_dir, split=(split,)) + + for d1, d2 in zip_equal(dataset1, dataset2): + self.assertEqual(d1, d2) From a86223b96edda30e1f73e50f7d3fc170b1ca7857 Mon Sep 17 00:00:00 2001 From: nayef211 Date: Thu, 3 Feb 2022 18:00:32 -0800 Subject: [PATCH 2/2] Completed mock tests for squad 1 --- test/datasets/test_squad1.py | 67 +++++++++++++++++++++++------------- 1 file changed, 44 insertions(+), 23 deletions(-) diff --git a/test/datasets/test_squad1.py b/test/datasets/test_squad1.py index 9a9a608aaa..75f1f61639 100644 --- a/test/datasets/test_squad1.py +++ b/test/datasets/test_squad1.py @@ -1,47 +1,68 @@ +import json import os import random import string -import tarfile +import uuid from collections import defaultdict +from random import randint from unittest.mock import patch from parameterized import parameterized +from torchtext.data.datasets_utils import _ParseSQuADQAData from torchtext.datasets.squad1 import SQuAD1 from ..common.case_utils import TempDirMixin, zip_equal from ..common.torchtext_test_case import TorchtextTestCase +def _get_mock_json_data(): + rand_string = " ".join(random.choice(string.ascii_letters) for i in range(10)) + mock_json_data = { + "data": [ + { + "title": rand_string, + "paragraphs": [ + { + "context": rand_string, + "qas": [ + { + "answers": [ + { + "answer_start": randint(1, 1000), + "text": rand_string, + } + ], + "question": rand_string, + "id": uuid.uuid1().hex, + }, + ], + } + ], + } + ] + } + return mock_json_data + + def _get_mock_dataset(root_dir): """ root_dir: directory to the mocked dataset """ base_dir = os.path.join(root_dir, "SQuAD1") - temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir") - os.makedirs(temp_dataset_dir, exist_ok=True) + os.makedirs(base_dir, exist_ok=True) - seed = 1 mocked_data = defaultdict(list) - for file_name in ("train.csv", "test.csv"): - txt_file = os.path.join(temp_dataset_dir, file_name) + for file_name in ("train-v1.1.json", "dev-v1.1.json"): + txt_file = os.path.join(base_dir, file_name) with open(txt_file, "w") as f: - for i in range(5): - label = seed % 2 + 1 - rand_string = " ".join( - random.choice(string.ascii_letters) for i in range(seed) - ) - dataset_line = (label, f"{rand_string} {rand_string}") - # append line to correct dataset split - mocked_data[os.path.splitext(file_name)[0]].append(dataset_line) - f.write(f'"{label}","{rand_string}","{rand_string}"\n') - seed += 1 - - compressed_dataset_path = os.path.join( - base_dir, "amazon_review_polarity_csv.tar.gz" - ) - # create tar file from dataset folder - with tarfile.open(compressed_dataset_path, "w:gz") as tar: - tar.add(temp_dataset_dir, arcname="amazon_review_polarity_csv") + mock_json_data = _get_mock_json_data() + f.write(json.dumps(mock_json_data)) + + split = "train" if "train" in file_name else "dev" + dataset_line = next( + iter(_ParseSQuADQAData([("file_handle", mock_json_data)])) + ) + mocked_data[split].append(dataset_line) return mocked_data