From 8abc6fcd4cbe1e2d436f4988d65db03d531f2de7 Mon Sep 17 00:00:00 2001
From: nayef211
Date: Wed, 19 Jan 2022 15:35:31 -0800
Subject: [PATCH 1/6] First attempt at adding test for amazon review polarity

---
 test/common/case_utils.py                  |  37 ++++++-
 test/datasets/__init__.py                  |   0
 test/datasets/amazonreviewpolarity_test.py | 121 +++++++++++++++++++++
 3 files changed, 157 insertions(+), 1 deletion(-)
 create mode 100644 test/datasets/__init__.py
 create mode 100644 test/datasets/amazonreviewpolarity_test.py

diff --git a/test/common/case_utils.py b/test/common/case_utils.py
index 03eec2627f..99e0290111 100644
--- a/test/common/case_utils.py
+++ b/test/common/case_utils.py
@@ -1,7 +1,42 @@
+import os.path
+import tempfile
 import unittest
+
 from torchtext._internal.module_utils import is_module_available
 
 
+class TempDirMixin:
+    """Mixin to provide easy access to temp dir"""
+
+    temp_dir_ = None
+
+    @classmethod
+    def get_base_temp_dir(cls):
+        # If TORCHTEXT_TEST_TEMP_DIR is set, use it instead of a temporary directory.
+        # This is handy for debugging.
+        key = "TORCHTEXT_TEST_TEMP_DIR"
+        if key in os.environ:
+            return os.environ[key]
+        if cls.temp_dir_ is None:
+            cls.temp_dir_ = tempfile.TemporaryDirectory()
+        return cls.temp_dir_.name
+
+    @classmethod
+    def tearDownClass(cls):
+        super().tearDownClass()
+        if cls.temp_dir_ is not None:
+            cls.temp_dir_.cleanup()
+            cls.temp_dir_ = None
+
+    def get_temp_path(self, *paths):
+        temp_dir = os.path.join(self.get_base_temp_dir(), self.id())
+        path = os.path.join(temp_dir, *paths)
+        os.makedirs(os.path.dirname(path), exist_ok=True)
+        return path
+
+
 def skipIfNoModule(module, display_name=None):
     display_name = display_name or module
-    return unittest.skipIf(not is_module_available(module), f'"{display_name}" is not available')
+    return unittest.skipIf(
+        not is_module_available(module), f'"{display_name}" is not available'
+    )
diff --git a/test/datasets/__init__.py b/test/datasets/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/test/datasets/amazonreviewpolarity_test.py b/test/datasets/amazonreviewpolarity_test.py
new file mode 100644
index 0000000000..fafeee49a6
--- /dev/null
+++ b/test/datasets/amazonreviewpolarity_test.py
@@ -0,0 +1,121 @@
+#!/usr/bin/env python3
+# Note that all the tests in this module require the dataset (either network access or cached)
+import torchtext
+import json
+from parameterized import parameterized
+from ..common.torchtext_test_case import TorchtextTestCase
+from ..common.parameterized_utils import load_params
+from ..common.case_utils import TempDirMixin
+import os.path
+
+from torchtext.datasets.amazonreviewpolarity import AmazonReviewPolarity
+
+
+def get_mock_dataset(root_dir):
+    """
+    root_dir: directory to the mocked dataset
+    """
+    mocked_data = []
+    sample_rate = 16000
+    transcript = "This is a test transcript."
+
+    base_dir = os.path.join(root_dir, "ARCTIC", "cmu_us_aew_arctic")
+    txt_dir = os.path.join(base_dir, "etc")
+    os.makedirs(txt_dir, exist_ok=True)
+    txt_file = os.path.join(txt_dir, "txt.done.data")
+    audio_dir = os.path.join(base_dir, "wav")
+    os.makedirs(audio_dir, exist_ok=True)
+
+    seed = 42
+    with open(txt_file, "w") as txt:
+        for c in ["a", "b"]:
+            for i in range(5):
+                utterance_id = f"arctic_{c}{i:04d}"
+                path = os.path.join(audio_dir, f"{utterance_id}.wav")
+                data = get_whitenoise(
+                    sample_rate=sample_rate,
+                    duration=3,
+                    n_channels=1,
+                    dtype="int16",
+                    seed=seed,
+                )
+                save_wav(path, data, sample_rate)
+                sample = (
+                    normalize_wav(data),
+                    sample_rate,
+                    transcript,
+                    utterance_id.split("_")[1],
+                )
+                mocked_data.append(sample)
+                txt.write(f'( {utterance_id} "{transcript}" )\n')
+                seed += 1
+    return mocked_data
+
+
+class TestAmazonReviewPolarity(TempDirMixin, TorchtextTestCase):
+    root_dir = None
+    samples = []
+
+    @classmethod
+    def setUpClass(cls):
+        cls.root_dir = cls.get_base_temp_dir()
+        cls.samples = get_mock_dataset(cls.root_dir)
+
+    def _test_amazon_review_polarity(self, dataset):
+        n_ite = 0
+        for i, (waveform, sample_rate, transcript, utterance_id) in enumerate(dataset):
+            expected_sample = self.samples[i]
+            assert sample_rate == expected_sample[1]
+            assert transcript == expected_sample[2]
+            assert utterance_id == expected_sample[3]
+            self.assertEqual(expected_sample[0], waveform, atol=5e-5, rtol=1e-8)
+            n_ite += 1
+        assert n_ite == len(self.samples)
+
+    def test_amazon_review_polarity_splits(self, splits):
+        dataset = AmazonReviewPolarity(root=self.root_dir, split=splits)
+        self._test_amazon_review_polarity(dataset)
+
+
+# class TestDataset(TorchtextTestCase):
+#     @classmethod
+#     def setUpClass(cls):
+#         check_cache_status()
+
+#     @parameterized.expand(
+#         load_params('raw_datasets.jsonl'),
+#         name_func=_raw_text_custom_name_func)
+#     def test_raw_text_classification(self, info):
+#         dataset_name = info['dataset_name']
+#         split = info['split']
+
+#         if dataset_name == 'WMT14':
+#             return
+#         else:
+#             data_iter = torchtext.datasets.DATASETS[dataset_name](split=split)
+#             self.assertEqual(hashlib.md5(json.dumps(next(iter(data_iter)), sort_keys=True).encode('utf-8')).hexdigest(), info['first_line'])
+#         if dataset_name == "AG_NEWS":
+#             self.assertEqual(torchtext.datasets.URLS[dataset_name][split], info['URL'])
+#             self.assertEqual(torchtext.datasets.MD5[dataset_name][split], info['MD5'])
+#         elif dataset_name == "WMT14":
+#             return
+#         else:
+#             self.assertEqual(torchtext.datasets.URLS[dataset_name], info['URL'])
+#             self.assertEqual(torchtext.datasets.MD5[dataset_name], info['MD5'])
+#         del data_iter
+
+#     @parameterized.expand(list(sorted(torchtext.datasets.DATASETS.keys())))
+#     def test_raw_datasets_split_argument(self, dataset_name):
+#         if 'statmt' in torchtext.datasets.URLS[dataset_name]:
+#             return
+#         dataset = torchtext.datasets.DATASETS[dataset_name]
+#         train1 = dataset(split='train')
+#         train2, = dataset(split=('train',))
+#         for d1, d2 in zip(train1, train2):
+#             self.assertEqual(d1, d2)
+#             # This test only aims to exercise the argument parsing and uses
+#             # the first line as a litmus test for correctness.
+#             break
+#         # Exercise default constructor
+#         _ = dataset()

From 92e90f81e08d321f3d623ba85ffe7cb36cddef5e Mon Sep 17 00:00:00 2001
From: nayef211
Date: Wed, 19 Jan 2022 23:50:14 -0800
Subject: [PATCH 2/6] Updated dataset to take validate_hash param. Finalized
 tests
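
The new `validate_hash` flag lets the tests point the dataset at locally
mocked files without tripping the md5 check. A minimal sketch of the intended
call (the root path here is hypothetical, not part of this diff):

    from torchtext.datasets.amazonreviewpolarity import AmazonReviewPolarity

    # validate_hash=False skips md5 validation, so a mocked archive
    # placed under `root` is accepted as-is.
    train_iter = AmazonReviewPolarity(root="/tmp/mock", split="train", validate_hash=False)
    label, text = next(iter(train_iter))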
---
 test/datasets/amazonreviewpolarity_test.py | 153 ++++++++-------------
 torchtext/datasets/amazonreviewpolarity.py |   8 +-
 2 files changed, 63 insertions(+), 98 deletions(-)

diff --git a/test/datasets/amazonreviewpolarity_test.py b/test/datasets/amazonreviewpolarity_test.py
index fafeee49a6..e444abf862 100644
--- a/test/datasets/amazonreviewpolarity_test.py
+++ b/test/datasets/amazonreviewpolarity_test.py
@@ -1,54 +1,48 @@
-#!/usr/bin/env python3
-# Note that all the tests in this module require the dataset (either network access or cached)
-import torchtext
-import json
-from parameterized import parameterized
-from ..common.torchtext_test_case import TorchtextTestCase
-from ..common.parameterized_utils import load_params
-from ..common.case_utils import TempDirMixin
 import os.path
+import random
+import string
+import tarfile
+from collections import defaultdict
 
+from parameterized import parameterized
 from torchtext.datasets.amazonreviewpolarity import AmazonReviewPolarity
 
+from ..common.case_utils import TempDirMixin
+from ..common.torchtext_test_case import TorchtextTestCase
+
 
 def get_mock_dataset(root_dir):
     """
     root_dir: directory to the mocked dataset
     """
-    mocked_data = []
-    sample_rate = 16000
-    transcript = "This is a test transcript."
-
-    base_dir = os.path.join(root_dir, "ARCTIC", "cmu_us_aew_arctic")
-    txt_dir = os.path.join(base_dir, "etc")
-    os.makedirs(txt_dir, exist_ok=True)
-    txt_file = os.path.join(txt_dir, "txt.done.data")
-    audio_dir = os.path.join(base_dir, "wav")
-    os.makedirs(audio_dir, exist_ok=True)
-
-    seed = 42
-    with open(txt_file, "w") as txt:
-        for c in ["a", "b"]:
+    base_dir = os.path.join(root_dir, "AmazonReviewPolarity")
+    compressed_dataset_path = os.path.join(
+        base_dir, "amazon_review_polarity_csv.tar.gz"
+    )
+    uncompressed_dataset_dir = os.path.join(base_dir, "amazon_review_polarity_csv")
+    os.makedirs(uncompressed_dataset_dir, exist_ok=True)
+
+    # create empty tar file to skip dataset download
+    with tarfile.open(compressed_dataset_path, "w:gz") as tar:
+        dummy_file_path = os.path.join(base_dir, "dummy_file.txt")
+        with open(dummy_file_path, "w") as f:
+            pass
+        tar.add(dummy_file_path)
+
+    seed = 1
+    mocked_data = defaultdict(list)
+    for file_name in ("train.csv", "test.csv"):
+        txt_file = os.path.join(uncompressed_dataset_dir, file_name)
+        with open(txt_file, "w") as f:
             for i in range(5):
-                utterance_id = f"arctic_{c}{i:04d}"
-                path = os.path.join(audio_dir, f"{utterance_id}.wav")
-                data = get_whitenoise(
-                    sample_rate=sample_rate,
-                    duration=3,
-                    n_channels=1,
-                    dtype="int16",
-                    seed=seed,
+                label = seed % 2 + 1
+                rand_string = " ".join(
+                    random.choice(string.ascii_letters) for i in range(seed)
                 )
-                save_wav(path, data, sample_rate)
-                sample = (
-                    normalize_wav(data),
-                    sample_rate,
-                    transcript,
-                    utterance_id.split("_")[1],
-                )
-                mocked_data.append(sample)
-                txt.write(f'( {utterance_id} "{transcript}" )\n')
+                dataset_line = (label, f"{rand_string} {rand_string}")
+                # append line to correct dataset split
+                mocked_data[os.path.splitext(file_name)[0]].append(dataset_line)
+                f.write(f'"{label}","{rand_string}","{rand_string}"\n')
                 seed += 1
     return mocked_data
@@ -62,60 +56,27 @@ def setUpClass(cls):
         cls.root_dir = cls.get_base_temp_dir()
         cls.samples = get_mock_dataset(cls.root_dir)
 
-    def _test_amazon_review_polarity(self, dataset):
-        n_ite = 0
-        for i, (waveform, sample_rate, transcript, utterance_id) in enumerate(dataset):
-            expected_sample = self.samples[i]
-            assert sample_rate == expected_sample[1]
-            assert transcript == expected_sample[2]
-            assert utterance_id == expected_sample[3]
-            self.assertEqual(expected_sample[0], waveform, atol=5e-5, rtol=1e-8)
-            n_ite += 1
-        assert n_ite == len(self.samples)
-
-    def test_amazon_review_polarity_splits(self, splits):
-        dataset = AmazonReviewPolarity(root=self.root_dir, split=splits)
-        self._test_amazon_review_polarity(dataset)
-
-
-# class TestDataset(TorchtextTestCase):
-#     @classmethod
-#     def setUpClass(cls):
-#         check_cache_status()
-
-#     @parameterized.expand(
-#         load_params('raw_datasets.jsonl'),
-#         name_func=_raw_text_custom_name_func)
-#     def test_raw_text_classification(self, info):
-#         dataset_name = info['dataset_name']
-#         split = info['split']
-
-#         if dataset_name == 'WMT14':
-#             return
-#         else:
-#             data_iter = torchtext.datasets.DATASETS[dataset_name](split=split)
-#             self.assertEqual(hashlib.md5(json.dumps(next(iter(data_iter)), sort_keys=True).encode('utf-8')).hexdigest(), info['first_line'])
-#         if dataset_name == "AG_NEWS":
-#             self.assertEqual(torchtext.datasets.URLS[dataset_name][split], info['URL'])
-#             self.assertEqual(torchtext.datasets.MD5[dataset_name][split], info['MD5'])
-#         elif dataset_name == "WMT14":
-#             return
-#         else:
-#             self.assertEqual(torchtext.datasets.URLS[dataset_name], info['URL'])
-#             self.assertEqual(torchtext.datasets.MD5[dataset_name], info['MD5'])
-#         del data_iter
-
-#     @parameterized.expand(list(sorted(torchtext.datasets.DATASETS.keys())))
-#     def test_raw_datasets_split_argument(self, dataset_name):
-#         if 'statmt' in torchtext.datasets.URLS[dataset_name]:
-#             return
-#         dataset = torchtext.datasets.DATASETS[dataset_name]
-#         train1 = dataset(split='train')
-#         train2, = dataset(split=('train',))
-#         for d1, d2 in zip(train1, train2):
-#             self.assertEqual(d1, d2)
-#             # This test only aims to exercise the argument parsing and uses
-#             # the first line as a litmus test for correctness.
-#             break
-#         # Exercise default constructor
-#         _ = dataset()
+    @parameterized.expand(["train", "test"])
+    def test_amazon_review_polarity(self, split):
+        dataset = AmazonReviewPolarity(
+            root=self.root_dir, split=split, validate_hash=False
+        )
+        n_iter = 0
+        for i, (label, text) in enumerate(dataset):
+            expected_sample = self.samples[split][i]
+            assert label == expected_sample[0]
+            assert text == expected_sample[1]
+            n_iter += 1
+        assert n_iter == len(self.samples[split])
+
+    @parameterized.expand([("train", ("train",)), ("test", ("test",))])
+    def test_amazon_review_polarity_split_argument(self, split1, split2):
+        dataset1 = AmazonReviewPolarity(
+            root=self.root_dir, split=split1, validate_hash=False
+        )
+        (dataset2,) = AmazonReviewPolarity(
+            root=self.root_dir, split=split2, validate_hash=False
+        )
+
+        for d1, d2 in zip(dataset1, dataset2):
+            self.assertEqual(d1, d2)
diff --git a/torchtext/datasets/amazonreviewpolarity.py b/torchtext/datasets/amazonreviewpolarity.py
index e585922c19..577e0185d2 100644
--- a/torchtext/datasets/amazonreviewpolarity.py
+++ b/torchtext/datasets/amazonreviewpolarity.py
@@ -34,14 +34,18 @@
 @_add_docstring_header(num_lines=NUM_LINES, num_classes=2)
 @_create_dataset_directory(dataset_name=DATASET_NAME)
 @_wrap_split_argument(("train", "test"))
-def AmazonReviewPolarity(root: str, split: Union[Tuple[str], str]):
+def AmazonReviewPolarity(root: str, split: Union[Tuple[str], str], validate_hash: bool = True):
+    # Validate integrity of dataset using md5 checksum
+    hash_dict = {os.path.join(root, _PATH): MD5} if validate_hash else None
+    hash_type = "md5" if validate_hash else None
+
     # TODO Remove this after removing conditional dependency
     if not is_module_available("torchdata"):
         raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")
 
     url_dp = IterableWrapper([URL])
     cache_compressed_dp = url_dp.on_disk_cache(
-        filepath_fn=lambda x: os.path.join(root, _PATH), hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5"
+        filepath_fn=lambda x: os.path.join(root, _PATH), hash_dict=hash_dict, hash_type=hash_type
     )
     cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

From 249de07f6d1222cc9964400b52320de0b53ab4cf Mon Sep 17 00:00:00 2001
From: nayef211
Date: Thu, 20 Jan 2022 23:06:01 -0800
Subject: [PATCH 3/6] Created non-empty tar file

---
 test/datasets/amazonreviewpolarity_test.py | 29 +++++++++++-----------
 1 file changed, 15 insertions(+), 14 deletions(-)

diff --git a/test/datasets/amazonreviewpolarity_test.py b/test/datasets/amazonreviewpolarity_test.py
index e444abf862..1f1f8db67e 100644
--- a/test/datasets/amazonreviewpolarity_test.py
+++ b/test/datasets/amazonreviewpolarity_test.py
@@ -1,4 +1,5 @@
-import os.path
+# from pathlib import Path
+import os
 import random
 import string
 import tarfile
@@ -10,29 +11,21 @@
 from ..common.case_utils import TempDirMixin
 from ..common.torchtext_test_case import TorchtextTestCase
 
+# def _create_tar_file(tar_path, data_dir):
+
 
 def get_mock_dataset(root_dir):
     """
     root_dir: directory to the mocked dataset
     """
     base_dir = os.path.join(root_dir, "AmazonReviewPolarity")
-    compressed_dataset_path = os.path.join(
-        base_dir, "amazon_review_polarity_csv.tar.gz"
-    )
-    uncompressed_dataset_dir = os.path.join(base_dir, "amazon_review_polarity_csv")
-    os.makedirs(uncompressed_dataset_dir, exist_ok=True)
-
-    # create empty tar file to skip dataset download
-    with tarfile.open(compressed_dataset_path, "w:gz") as tar:
-        dummy_file_path = os.path.join(base_dir, "dummy_file.txt")
-        with open(dummy_file_path, "w") as f:
-            pass
-        tar.add(dummy_file_path)
+    temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir")
+    os.makedirs(temp_dataset_dir, exist_ok=True)
 
     seed = 1
     mocked_data = defaultdict(list)
     for file_name in ("train.csv", "test.csv"):
-        txt_file = os.path.join(uncompressed_dataset_dir, file_name)
+        txt_file = os.path.join(temp_dataset_dir, file_name)
         with open(txt_file, "w") as f:
             for i in range(5):
                 label = seed % 2 + 1
@@ -44,6 +37,14 @@ def get_mock_dataset(root_dir):
                 mocked_data[os.path.splitext(file_name)[0]].append(dataset_line)
                 f.write(f'"{label}","{rand_string}","{rand_string}"\n')
                 seed += 1
+
+    compressed_dataset_path = os.path.join(
+        base_dir, "amazon_review_polarity_csv.tar.gz"
+    )
+    # create tar file from dataset folder
+    with tarfile.open(compressed_dataset_path, "w:gz") as tar:
+        tar.add(temp_dataset_dir, arcname="amazon_review_polarity_csv")
+
     return mocked_data

From 09ae6dbbd6dcb076cb568023f97e0136bff967ad Mon Sep 17 00:00:00 2001
From: nayef211
Date: Mon, 24 Jan 2022 13:27:45 -0800
Subject: [PATCH 4/6] Remove formatting. Patch _hash_check method from
 torchdata during testing
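
Instead of threading a `validate_hash` argument through the public dataset
API, the tests now stub out torchdata's internal hash check. The pattern, as
used in the test changes below (a sketch; `root_dir` stands in for the test's
temp directory):

    from unittest.mock import patch

    # While patched, on_disk_cache treats the mocked archive on disk
    # as having a valid md5 hash, so no download is triggered.
    with patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True):
        dataset = AmazonReviewPolarity(root=root_dir, split="train")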
---
 test/common/case_utils.py                  |  4 +--
 test/datasets/amazonreviewpolarity_test.py | 39 +++++++++++-----------
 torchtext/datasets/amazonreviewpolarity.py |  8 ++---
 3 files changed, 22 insertions(+), 29 deletions(-)

diff --git a/test/common/case_utils.py b/test/common/case_utils.py
index 99e0290111..f8803894b0 100644
--- a/test/common/case_utils.py
+++ b/test/common/case_utils.py
@@ -37,6 +37,4 @@ def get_temp_path(self, *paths):
 
 def skipIfNoModule(module, display_name=None):
     display_name = display_name or module
-    return unittest.skipIf(
-        not is_module_available(module), f'"{display_name}" is not available'
-    )
+    return unittest.skipIf(not is_module_available(module), f'"{display_name}" is not available')
diff --git a/test/datasets/amazonreviewpolarity_test.py b/test/datasets/amazonreviewpolarity_test.py
index 1f1f8db67e..3437038c11 100644
--- a/test/datasets/amazonreviewpolarity_test.py
+++ b/test/datasets/amazonreviewpolarity_test.py
@@ -4,6 +4,7 @@
 import string
 import tarfile
 from collections import defaultdict
+from unittest.mock import patch
 
 from parameterized import parameterized
 from torchtext.datasets.amazonreviewpolarity import AmazonReviewPolarity
@@ -11,8 +12,6 @@
 from ..common.case_utils import TempDirMixin
 from ..common.torchtext_test_case import TorchtextTestCase
 
-# def _create_tar_file(tar_path, data_dir):
-
 
 def get_mock_dataset(root_dir):
     """
@@ -59,25 +58,25 @@ def setUpClass(cls):
 
     @parameterized.expand(["train", "test"])
     def test_amazon_review_polarity(self, split):
-        dataset = AmazonReviewPolarity(
-            root=self.root_dir, split=split, validate_hash=False
-        )
-        n_iter = 0
-        for i, (label, text) in enumerate(dataset):
-            expected_sample = self.samples[split][i]
-            assert label == expected_sample[0]
-            assert text == expected_sample[1]
-            n_iter += 1
-        assert n_iter == len(self.samples[split])
+        with patch(
+            "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True
+        ):
+            dataset = AmazonReviewPolarity(root=self.root_dir, split=split)
+            n_iter = 0
+            for i, (label, text) in enumerate(dataset):
+                expected_sample = self.samples[split][i]
+                assert label == expected_sample[0]
+                assert text == expected_sample[1]
+                n_iter += 1
+            assert n_iter == len(self.samples[split])
 
     @parameterized.expand([("train", ("train",)), ("test", ("test",))])
     def test_amazon_review_polarity_split_argument(self, split1, split2):
-        dataset1 = AmazonReviewPolarity(
-            root=self.root_dir, split=split1, validate_hash=False
-        )
-        (dataset2,) = AmazonReviewPolarity(
-            root=self.root_dir, split=split2, validate_hash=False
-        )
+        with patch(
+            "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True
+        ):
+            dataset1 = AmazonReviewPolarity(root=self.root_dir, split=split1)
+            (dataset2,) = AmazonReviewPolarity(root=self.root_dir, split=split2)
 
-        for d1, d2 in zip(dataset1, dataset2):
-            self.assertEqual(d1, d2)
+            for d1, d2 in zip(dataset1, dataset2):
+                self.assertEqual(d1, d2)
diff --git a/torchtext/datasets/amazonreviewpolarity.py b/torchtext/datasets/amazonreviewpolarity.py
index 577e0185d2..e585922c19 100644
--- a/torchtext/datasets/amazonreviewpolarity.py
+++ b/torchtext/datasets/amazonreviewpolarity.py
@@ -34,18 +34,14 @@
 @_add_docstring_header(num_lines=NUM_LINES, num_classes=2)
 @_create_dataset_directory(dataset_name=DATASET_NAME)
 @_wrap_split_argument(("train", "test"))
-def AmazonReviewPolarity(root: str, split: Union[Tuple[str], str], validate_hash: bool = True):
-    # Validate integrity of dataset using md5 checksum
-    hash_dict = {os.path.join(root, _PATH): MD5} if validate_hash else None
-    hash_type = "md5" if validate_hash else None
-
+def AmazonReviewPolarity(root: str, split: Union[Tuple[str], str]):
     # TODO Remove this after removing conditional dependency
     if not is_module_available("torchdata"):
         raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")
 
     url_dp = IterableWrapper([URL])
     cache_compressed_dp = url_dp.on_disk_cache(
-        filepath_fn=lambda x: os.path.join(root, _PATH), hash_dict=hash_dict, hash_type=hash_type
+        filepath_fn=lambda x: os.path.join(root, _PATH), hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5"
     )
     cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

From 263a25779f41b207662f45d10e13201c487634ba Mon Sep 17 00:00:00 2001
From: nayef211
Date: Mon, 24 Jan 2022 13:30:32 -0800
Subject: [PATCH 5/6] Added super().setUpClass()

---
 test/datasets/amazonreviewpolarity_test.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/datasets/amazonreviewpolarity_test.py b/test/datasets/amazonreviewpolarity_test.py
index 3437038c11..857d2d2710 100644
--- a/test/datasets/amazonreviewpolarity_test.py
+++ b/test/datasets/amazonreviewpolarity_test.py
@@ -53,6 +53,7 @@ class TestAmazonReviewPolarity(TempDirMixin, TorchtextTestCase):
 
     @classmethod
     def setUpClass(cls):
+        super().setUpClass()
         cls.root_dir = cls.get_base_temp_dir()
         cls.samples = get_mock_dataset(cls.root_dir)

From e4eca13476ff4fe2eb51e076c25694a268c451f7 Mon Sep 17 00:00:00 2001
From: nayef211
Date: Mon, 24 Jan 2022 13:33:40 -0800
Subject: [PATCH 6/6] Remove commented import

---
 test/datasets/amazonreviewpolarity_test.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/datasets/amazonreviewpolarity_test.py b/test/datasets/amazonreviewpolarity_test.py
index 857d2d2710..0d71529ec6 100644
--- a/test/datasets/amazonreviewpolarity_test.py
+++ b/test/datasets/amazonreviewpolarity_test.py
@@ -1,4 +1,3 @@
-# from pathlib import Path
 import os
 import random
 import string
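
A quick way to exercise the new suite locally (a sketch, assuming a dev
install where pytest and parameterized are importable; the exact invocation
may differ per setup):

    import pytest

    # Run only the AmazonReviewPolarity dataset tests.
    pytest.main(["test/datasets/amazonreviewpolarity_test.py", "-v"])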