Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 08f49f9

Browse files
Nayef211 and nayef211 authored
Parameterize tests for similar datasets (#1600)
* Parameterized amazon dataset tests
* Renamed squad test for consistency
* Deleted YelpReviewFull test since it's already parameterized

Co-authored-by: nayef211 <[email protected]>
1 parent c3f59a5 commit 08f49f9

File tree

4 files changed: 34 additions and 188 deletions.

test/datasets/test_amazonreviewfull.py

Lines changed: 0 additions & 83 deletions
This file was deleted.

test/datasets/test_amazonreviewpolarity.py renamed to test/datasets/test_amazonreviews.py

Lines changed: 31 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -5,18 +5,20 @@
55
from collections import defaultdict
66
from unittest.mock import patch
77

8-
from parameterized import parameterized
8+
from torchtext.datasets.amazonreviewfull import AmazonReviewFull
99
from torchtext.datasets.amazonreviewpolarity import AmazonReviewPolarity
1010

1111
from ..common.case_utils import TempDirMixin, zip_equal
12+
from ..common.parameterized_utils import nested_params
1213
from ..common.torchtext_test_case import TorchtextTestCase
1314

1415

15-
def _get_mock_dataset(root_dir):
16+
def _get_mock_dataset(root_dir, base_dir_name):
1617
"""
1718
root_dir: directory to the mocked dataset
19+
base_dir_name: AmazonReviewFull or AmazonReviewPolarity
1820
"""
19-
base_dir = os.path.join(root_dir, "AmazonReviewPolarity")
21+
base_dir = os.path.join(root_dir, base_dir_name)
2022
temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir")
2123
os.makedirs(temp_dataset_dir, exist_ok=True)
2224

@@ -26,7 +28,10 @@ def _get_mock_dataset(root_dir):
2628
txt_file = os.path.join(temp_dataset_dir, file_name)
2729
with open(txt_file, "w") as f:
2830
for i in range(5):
29-
label = seed % 2 + 1
31+
if base_dir_name == AmazonReviewFull.__name__:
32+
label = seed % 5 + 1
33+
else:
34+
label = seed % 2 + 1
3035
rand_string = " ".join(
3136
random.choice(string.ascii_letters) for i in range(seed)
3237
)
@@ -36,25 +41,27 @@ def _get_mock_dataset(root_dir):
3641
f.write(f'"{label}","{rand_string}","{rand_string}"\n')
3742
seed += 1
3843

39-
compressed_dataset_path = os.path.join(
40-
base_dir, "amazon_review_polarity_csv.tar.gz"
41-
)
44+
if base_dir_name == AmazonReviewFull.__name__:
45+
archive_file_name = "amazon_review_full_csv"
46+
else:
47+
archive_file_name = "amazon_review_polarity_csv"
48+
49+
compressed_dataset_path = os.path.join(base_dir, f"{archive_file_name}.tar.gz")
4250
# create tar file from dataset folder
4351
with tarfile.open(compressed_dataset_path, "w:gz") as tar:
44-
tar.add(temp_dataset_dir, arcname="amazon_review_polarity_csv")
52+
tar.add(temp_dataset_dir, arcname=archive_file_name)
4553

4654
return mocked_data
4755

4856

49-
class TestAmazonReviewPolarity(TempDirMixin, TorchtextTestCase):
57+
class TestAmazonReviews(TempDirMixin, TorchtextTestCase):
5058
root_dir = None
5159
samples = []
5260

5361
@classmethod
5462
def setUpClass(cls):
5563
super().setUpClass()
5664
cls.root_dir = cls.get_base_temp_dir()
57-
cls.samples = _get_mock_dataset(cls.root_dir)
5865
cls.patcher = patch(
5966
"torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True
6067
)
@@ -65,19 +72,24 @@ def tearDownClass(cls):
6572
cls.patcher.stop()
6673
super().tearDownClass()
6774

68-
@parameterized.expand(["train", "test"])
69-
def test_amazon_review_polarity(self, split):
70-
dataset = AmazonReviewPolarity(root=self.root_dir, split=split)
71-
75+
@nested_params([AmazonReviewFull, AmazonReviewPolarity], ["train", "test"])
76+
def test_amazon_reviews(self, amazon_review_dataset, split):
77+
expected_samples = _get_mock_dataset(
78+
self.root_dir, amazon_review_dataset.__name__
79+
)[split]
80+
dataset = amazon_review_dataset(root=self.root_dir, split=split)
7281
samples = list(dataset)
73-
expected_samples = self.samples[split]
82+
7483
for sample, expected_sample in zip_equal(samples, expected_samples):
7584
self.assertEqual(sample, expected_sample)
7685

77-
@parameterized.expand(["train", "test"])
78-
def test_amazon_review_polarity_split_argument(self, split):
79-
dataset1 = AmazonReviewPolarity(root=self.root_dir, split=split)
80-
(dataset2,) = AmazonReviewPolarity(root=self.root_dir, split=(split,))
86+
@nested_params([AmazonReviewFull, AmazonReviewPolarity], ["train", "test"])
87+
def test_amazon_reviews_split_argument(self, amazon_review_dataset, split):
88+
# call `_get_mock_dataset` to create mock dataset files
89+
_ = _get_mock_dataset(self.root_dir, amazon_review_dataset.__name__)
90+
91+
dataset1 = amazon_review_dataset(root=self.root_dir, split=split)
92+
(dataset2,) = amazon_review_dataset(root=self.root_dir, split=(split,))
8193

8294
for d1, d2 in zip_equal(dataset1, dataset2):
8395
self.assertEqual(d1, d2)

test/datasets/test_squad.py renamed to test/datasets/test_squads.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -73,7 +73,7 @@ def _get_mock_dataset(root_dir, base_dir_name):
7373
return mocked_data
7474

7575

76-
class TestSQuAD(TempDirMixin, TorchtextTestCase):
76+
class TestSQuADs(TempDirMixin, TorchtextTestCase):
7777
root_dir = None
7878
samples = []
7979

@@ -92,7 +92,7 @@ def tearDownClass(cls):
9292
super().tearDownClass()
9393

9494
@nested_params([SQuAD1, SQuAD2], ["train", "dev"])
95-
def test_squad(self, squad_dataset, split):
95+
def test_squads(self, squad_dataset, split):
9696
expected_samples = _get_mock_dataset(self.root_dir, squad_dataset.__name__)[
9797
split
9898
]
@@ -103,7 +103,7 @@ def test_squad(self, squad_dataset, split):
103103
self.assertEqual(sample, expected_sample)
104104

105105
@nested_params([SQuAD1, SQuAD2], ["train", "dev"])
106-
def test_squad_split_argument(self, squad_dataset, split):
106+
def test_squads_split_argument(self, squad_dataset, split):
107107
# call `_get_mock_dataset` to create mock dataset files
108108
_ = _get_mock_dataset(self.root_dir, squad_dataset.__name__)
109109

test/datasets/test_yelpreviewfull.py

Lines changed: 0 additions & 83 deletions
This file was deleted.

0 commit comments

Comments (0)