Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 84b719e

Browse files
authored
Merge YelpReviewPolarity and YelpReviewFull Mocked Unit Tests (#1567)
1 parent 99eb1f8 commit 84b719e

File tree

1 file changed

+95
-0
lines changed

1 file changed

+95
-0
lines changed

test/datasets/test_yelpreviews.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import os
2+
import random
3+
import string
4+
import tarfile
5+
from collections import defaultdict
6+
from unittest.mock import patch
7+
8+
from ..common.parameterized_utils import nested_params
9+
from torchtext.datasets.yelpreviewpolarity import YelpReviewPolarity
10+
from torchtext.datasets.yelpreviewfull import YelpReviewFull
11+
12+
from ..common.case_utils import TempDirMixin, zip_equal
13+
from ..common.torchtext_test_case import TorchtextTestCase
14+
15+
16+
def _get_mock_dataset(root_dir, base_dir_name):
17+
"""
18+
root_dir: directory to the mocked dataset
19+
base_dir_name: YelpReviewPolarity or YelpReviewFull
20+
"""
21+
base_dir = os.path.join(root_dir, base_dir_name)
22+
temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir")
23+
os.makedirs(temp_dataset_dir, exist_ok=True)
24+
25+
seed = 1
26+
mocked_data = defaultdict(list)
27+
for file_name in ("train.csv", "test.csv"):
28+
csv_file = os.path.join(temp_dataset_dir, file_name)
29+
mocked_lines = mocked_data[os.path.splitext(file_name)[0]]
30+
with open(csv_file, "w") as f:
31+
for i in range(5):
32+
if base_dir_name == YelpReviewPolarity.__name__:
33+
label = seed % 2 + 1
34+
else:
35+
label = seed % 5 + 1
36+
rand_string = " ".join(
37+
random.choice(string.ascii_letters) for i in range(seed)
38+
)
39+
dataset_line = (label, f"{rand_string}")
40+
f.write(f'"{label}","{rand_string}"\n')
41+
42+
# append line to correct dataset split
43+
mocked_lines.append(dataset_line)
44+
seed += 1
45+
46+
if base_dir_name == YelpReviewPolarity.__name__:
47+
compressed_file = "yelp_review_polarity_csv"
48+
else:
49+
compressed_file = "yelp_review_full_csv"
50+
51+
compressed_dataset_path = os.path.join(base_dir, compressed_file + ".tar.gz")
52+
# create gz file from dataset folder
53+
with tarfile.open(compressed_dataset_path, "w:gz") as tar:
54+
tar.add(temp_dataset_dir, arcname=compressed_file)
55+
56+
return mocked_data
57+
58+
59+
class TestYelpReviews(TempDirMixin, TorchtextTestCase):
60+
root_dir = None
61+
samples = []
62+
63+
@classmethod
64+
def setUpClass(cls):
65+
super().setUpClass()
66+
cls.root_dir = cls.get_base_temp_dir()
67+
cls.patcher = patch(
68+
"torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True
69+
)
70+
cls.patcher.start()
71+
72+
@classmethod
73+
def tearDownClass(cls):
74+
cls.patcher.stop()
75+
super().tearDownClass()
76+
77+
@nested_params([YelpReviewPolarity, YelpReviewFull], ["train", "test"])
78+
def test_yelpreviews(self, yelp_dataset, split):
79+
expected_samples = _get_mock_dataset(self.root_dir, base_dir_name=yelp_dataset.__name__)[split]
80+
81+
dataset = yelp_dataset(root=self.root_dir, split=split)
82+
samples = list(dataset)
83+
for sample, expected_sample in zip_equal(samples, expected_samples):
84+
self.assertEqual(sample, expected_sample)
85+
86+
@nested_params([YelpReviewPolarity, YelpReviewFull], ["train", "test"])
87+
def test_yelpreviews_split_argument(self, yelp_dataset, split):
88+
# call `_get_mock_dataset` to create mock dataset files
89+
_ = _get_mock_dataset(self.root_dir, yelp_dataset.__name__)
90+
91+
dataset1 = yelp_dataset(root=self.root_dir, split=split)
92+
(dataset2,) = yelp_dataset(root=self.root_dir, split=(split,))
93+
94+
for d1, d2 in zip_equal(dataset1, dataset2):
95+
self.assertEqual(d1, d2)

0 commit comments

Comments
 (0)