From 2d12abc959268587049b0db1ae06edda9fd37d7d Mon Sep 17 00:00:00 2001 From: Virgile Mison Date: Fri, 28 Jan 2022 22:15:27 -0500 Subject: [PATCH 1/2] add test_yelpreviewfull.py to mock YelpReviewFull --- test/datasets/test_yelpreviewfull.py | 85 ++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 test/datasets/test_yelpreviewfull.py diff --git a/test/datasets/test_yelpreviewfull.py b/test/datasets/test_yelpreviewfull.py new file mode 100644 index 0000000000..6351151158 --- /dev/null +++ b/test/datasets/test_yelpreviewfull.py @@ -0,0 +1,85 @@ +import os +import random +import string +import tarfile +from collections import defaultdict +from unittest.mock import patch + +from parameterized import parameterized +from torchtext.datasets.yelpreviewfull import YelpReviewFull + +from ..common.case_utils import TempDirMixin, zip_equal +from ..common.torchtext_test_case import TorchtextTestCase + + +def _get_mock_dataset(root_dir): + """ + root_dir: directory to the mocked dataset + """ + base_dir = os.path.join(root_dir, "YelpReviewFull") + temp_dataset_dir = os.path.join(base_dir, "yelp_review_full_csv") + os.makedirs(temp_dataset_dir, exist_ok=True) + + seed = 1 + mocked_data = defaultdict(list) + for file_name in ["train.csv", "test.csv"]: + csv_file = os.path.join(temp_dataset_dir, file_name) + mocked_lines = mocked_data[os.path.splitext(file_name)[0]] + with open(csv_file, "w") as f: + for i in range(5): + label = seed % 11 + rand_string = " ".join( + random.choice(string.ascii_letters) for i in range(seed) + ) + dataset_line = (label, f"{rand_string}") + f.write(f'"{label}","{rand_string}"\n') + + # append line to correct dataset split + mocked_lines.append(dataset_line) + seed += 1 + + compressed_dataset_path = os.path.join(base_dir, "yelp_review_full_csv.tar.gz") + # create gz file from dataset folder + with tarfile.open(compressed_dataset_path, "w:gz") as tar: + for file_name in ("train.csv", "test.csv"): + csv_file = os.path.join(temp_dataset_dir, file_name) + tar.add(csv_file) + + return mocked_data + + +class TestYelpReviewFull(TempDirMixin, TorchtextTestCase): + root_dir = None + samples = [] + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.root_dir = cls.get_base_temp_dir() + cls.samples = _get_mock_dataset(cls.root_dir) + cls.patcher = patch( + "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True + ) + cls.patcher.start() + + @classmethod + def tearDownClass(cls): + cls.patcher.stop() + super().tearDownClass() + + @parameterized.expand(["train", "test"]) + def test_yelpreviewfull(self, split): + dataset = YelpReviewFull(root=self.root_dir, split=split) + + samples = list(dataset) + expected_samples = self.samples[split] + for sample, expected_sample in zip_equal(samples, expected_samples): + self.assertEqual(sample, expected_sample) + + @parameterized.expand(["train", "test"]) + def test_yelpreviewfull_split_argument(self, split): + dataset1 = YelpReviewFull(root=self.root_dir, split=split) + (dataset2,) = YelpReviewFull(root=self.root_dir, split=(split,)) + + for d1, d2 in zip_equal(dataset1, dataset2): + self.assertEqual(d1, d2) From a91e24b6ffcbda69c0f4fc33f24573df72da45ad Mon Sep 17 00:00:00 2001 From: Virgile Mison Date: Thu, 3 Feb 2022 16:34:52 -0500 Subject: [PATCH 2/2] correct label + small miscellaneous changes --- test/datasets/test_yelpreviewfull.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/test/datasets/test_yelpreviewfull.py b/test/datasets/test_yelpreviewfull.py index 6351151158..d0a7e13a7c 100644 --- a/test/datasets/test_yelpreviewfull.py +++ b/test/datasets/test_yelpreviewfull.py @@ -17,17 +17,17 @@ def _get_mock_dataset(root_dir): root_dir: directory to the mocked dataset """ base_dir = os.path.join(root_dir, "YelpReviewFull") - temp_dataset_dir = os.path.join(base_dir, "yelp_review_full_csv") + temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir") os.makedirs(temp_dataset_dir, exist_ok=True) seed = 1 mocked_data = defaultdict(list) - for file_name in ["train.csv", "test.csv"]: + for file_name in ("train.csv", "test.csv"): csv_file = os.path.join(temp_dataset_dir, file_name) mocked_lines = mocked_data[os.path.splitext(file_name)[0]] with open(csv_file, "w") as f: for i in range(5): - label = seed % 11 + label = seed % 5 + 1 rand_string = " ".join( random.choice(string.ascii_letters) for i in range(seed) ) @@ -41,9 +41,7 @@ def _get_mock_dataset(root_dir): compressed_dataset_path = os.path.join(base_dir, "yelp_review_full_csv.tar.gz") # create gz file from dataset folder with tarfile.open(compressed_dataset_path, "w:gz") as tar: - for file_name in ("train.csv", "test.csv"): - csv_file = os.path.join(temp_dataset_dir, file_name) - tar.add(csv_file) + tar.add(temp_dataset_dir, arcname="yelp_review_full_csv") return mocked_data