From a0bfbcafa860d8e22ddeafc86ff02a526a73d023 Mon Sep 17 00:00:00 2001 From: Virgile Mison Date: Fri, 28 Jan 2022 23:06:19 -0500 Subject: [PATCH 1/3] add test_enwik9 to mock EnWik9 data --- test/datasets/test_enwik9.py | 85 ++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 test/datasets/test_enwik9.py diff --git a/test/datasets/test_enwik9.py b/test/datasets/test_enwik9.py new file mode 100644 index 0000000000..75d82eb6dd --- /dev/null +++ b/test/datasets/test_enwik9.py @@ -0,0 +1,85 @@ +import os +import random +import string +import zipfile +from collections import defaultdict +from unittest.mock import patch + +from parameterized import parameterized +from torchtext.datasets.enwik9 import EnWik9 + +from ..common.case_utils import TempDirMixin, zip_equal +from ..common.torchtext_test_case import TorchtextTestCase + + +def _get_mock_dataset(root_dir): + """ + root_dir: directory to the mocked dataset + """ + base_dir = os.path.join(root_dir, "EnWik9") + temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir") + os.makedirs(temp_dataset_dir, exist_ok=True) + + seed = 1 + mocked_data = defaultdict(list) + file_name = "train" + txt_file = os.path.join(temp_dataset_dir, file_name) + mocked_lines = mocked_data[os.path.splitext(file_name)[0]] + with open(txt_file, "w") as f: + for i in range(5): + label = seed % 2 + rand_string = "<" + " ".join( + random.choice(string.ascii_letters) for i in range(seed) + ) + ">" + dataset_line = (rand_string) + f.write(f"{rand_string}\n") + + # append line to correct dataset split + mocked_lines.append(dataset_line) + seed += 1 + + compressed_dataset_path = os.path.join(base_dir, "enwik9.zip") + # create zip file from dataset folder + with zipfile.ZipFile(compressed_dataset_path, "w") as zip_file: + file_name = "train" + txt_file = os.path.join(temp_dataset_dir, file_name) + zip_file.write(txt_file, arcname=os.path.join("EnWik9", file_name)) + + return mocked_data + + +class TestEnWik9(TempDirMixin, TorchtextTestCase): + root_dir = None + samples = [] + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.root_dir = cls.get_base_temp_dir() + cls.samples = _get_mock_dataset(cls.root_dir) + cls.patcher = patch( + "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True + ) + cls.patcher.start() + + @classmethod + def tearDownClass(cls): + cls.patcher.stop() + super().tearDownClass() + + @parameterized.expand(["train"]) + def test_enwik9(self, split): + dataset = EnWik9(root=self.root_dir, split=split) + + samples = list(dataset) + expected_samples = self.samples[split] + for sample, expected_sample in zip_equal(samples, expected_samples): + self.assertEqual(sample, expected_sample) + + @parameterized.expand(["train"]) + def test_enwik9_split_argument(self, split): + dataset1 = EnWik9(root=self.root_dir, split=split) + (dataset2,) = EnWik9(root=self.root_dir, split=(split,)) + + for d1, d2 in zip_equal(dataset1, dataset2): + self.assertEqual(d1, d2) From 77e2c1deb70408781ebd740b7e8edf7bd9de9220 Mon Sep 17 00:00:00 2001 From: Virgile Mison Date: Wed, 2 Feb 2022 23:19:03 -0500 Subject: [PATCH 2/3] changes from comments --- test/datasets/test_enwik9.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/test/datasets/test_enwik9.py b/test/datasets/test_enwik9.py index 75d82eb6dd..fd988f82a0 100644 --- a/test/datasets/test_enwik9.py +++ b/test/datasets/test_enwik9.py @@ -22,28 +22,28 @@ def _get_mock_dataset(root_dir): seed = 1 mocked_data = defaultdict(list) - file_name = "train" + file_name = "enwik9" txt_file = os.path.join(temp_dataset_dir, file_name) - mocked_lines = mocked_data[os.path.splitext(file_name)[0]] + mocked_lines = mocked_data["train"] with open(txt_file, "w") as f: for i in range(5): - label = seed % 2 rand_string = "<" + " ".join( random.choice(string.ascii_letters) for i in range(seed) ) + ">" - dataset_line = (rand_string) - f.write(f"{rand_string}\n") + dataset_line = (f"'{rand_string}'") + f.write(f"'{rand_string}'\n") # append line to correct dataset split mocked_lines.append(dataset_line) seed += 1 + print("base_dir=") + print(base_dir) compressed_dataset_path = os.path.join(base_dir, "enwik9.zip") # create zip file from dataset folder with zipfile.ZipFile(compressed_dataset_path, "w") as zip_file: - file_name = "train" txt_file = os.path.join(temp_dataset_dir, file_name) - zip_file.write(txt_file, arcname=os.path.join("EnWik9", file_name)) + zip_file.write(txt_file, arcname="enwik9") return mocked_data @@ -56,6 +56,8 @@ class TestEnWik9(TempDirMixin, TorchtextTestCase): def setUpClass(cls): super().setUpClass() cls.root_dir = cls.get_base_temp_dir() + print("cls.root_dir:") + print(cls.root_dir) cls.samples = _get_mock_dataset(cls.root_dir) cls.patcher = patch( "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True From 4c7fcc47f65d4e34650da8a647735fe1860db204 Mon Sep 17 00:00:00 2001 From: Virgile Mison Date: Thu, 3 Feb 2022 10:55:22 -0500 Subject: [PATCH 3/3] resolve last issues --- test/datasets/test_enwik9.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/test/datasets/test_enwik9.py b/test/datasets/test_enwik9.py index fd988f82a0..1b79b201e6 100644 --- a/test/datasets/test_enwik9.py +++ b/test/datasets/test_enwik9.py @@ -37,13 +37,11 @@ def _get_mock_dataset(root_dir): mocked_lines.append(dataset_line) seed += 1 - print("base_dir=") - print(base_dir) compressed_dataset_path = os.path.join(base_dir, "enwik9.zip") # create zip file from dataset folder with zipfile.ZipFile(compressed_dataset_path, "w") as zip_file: txt_file = os.path.join(temp_dataset_dir, file_name) - zip_file.write(txt_file, arcname="enwik9") + zip_file.write(txt_file, arcname=file_name) return mocked_data @@ -56,8 +54,6 @@ class TestEnWik9(TempDirMixin, TorchtextTestCase): def setUpClass(cls): super().setUpClass() cls.root_dir = cls.get_base_temp_dir() - print("cls.root_dir:") - print(cls.root_dir) cls.samples = _get_mock_dataset(cls.root_dir) cls.patcher = patch( "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True