
Commit e710e3a

Nayef211 and nayef211 authored

[FORMATTING] Update formatting for dataset tests (#1601)

* Parameterized amazon dataset tests
* Renamed squad test for consistency
* Deleted YelpReviewFull test since it's already parameterized
* Updated formatting for datasets

Co-authored-by: nayef211 <[email protected]>
1 parent 9686e0d · commit e710e3a

File tree

10 files changed: +184 -77 lines changed
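Most of the touched tests share one pattern: the test matrix is built as a list of argument tuples and handed to `parameterized.expand`, which generates one test method per tuple. A minimal sketch of that pattern (illustrative only, not code from this commit):

```python
import unittest

from parameterized import parameterized


class DatasetSmokeTest(unittest.TestCase):
    # Each tuple becomes its own generated test method, named roughly
    # test_split_0_train, test_split_1_valid, test_split_2_test.
    @parameterized.expand([("train",), ("valid",), ("test",)])
    def test_split(self, split):
        self.assertIn(split, {"train", "valid", "test"})


if __name__ == "__main__":
    unittest.main()
```

The formatting changes themselves are consistent with what black produces at its default 88-character line length, which is why long calls are exploded one argument per line with a trailing comma.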

test/datasets/test_cc100.py

Lines changed: 3 additions & 4 deletions

```diff
@@ -1,18 +1,17 @@
+import lzma
 import os
 import random
 import string
-import lzma
-from parameterized import parameterized
 from collections import defaultdict
 from unittest.mock import patch
 
+from parameterized import parameterized
 from torchtext.datasets import CC100
+from torchtext.datasets.cc100 import VALID_CODES
 
 from ..common.case_utils import TempDirMixin, zip_equal
 from ..common.torchtext_test_case import TorchtextTestCase
 
-from torchtext.datasets.cc100 import VALID_CODES
-
 
 def _get_mock_dataset(root_dir):
     """
```

test/datasets/test_conll2000chunking.py

Lines changed: 16 additions & 6 deletions

```diff
@@ -1,7 +1,7 @@
+import gzip
 import os
 import random
 import string
-import gzip
 from collections import defaultdict
 from unittest.mock import patch
@@ -27,11 +27,19 @@ def _get_mock_dataset(root_dir):
         mocked_lines = mocked_data[os.path.splitext(file_name)[0]]
         with open(txt_file, "w") as f:
             for i in range(5):
-                rand_strings = [random.choice(string.ascii_letters) for i in range(seed)]
-                rand_label_1 = [random.choice(string.ascii_letters) for i in range(seed)]
-                rand_label_2 = [random.choice(string.ascii_letters) for i in range(seed)]
+                rand_strings = [
+                    random.choice(string.ascii_letters) for i in range(seed)
+                ]
+                rand_label_1 = [
+                    random.choice(string.ascii_letters) for i in range(seed)
+                ]
+                rand_label_2 = [
+                    random.choice(string.ascii_letters) for i in range(seed)
+                ]
                 # one token per line (each sample ends with an extra \n)
-                for rand_string, label_1, label_2 in zip(rand_strings, rand_label_1, rand_label_2):
+                for rand_string, label_1, label_2 in zip(
+                    rand_strings, rand_label_1, rand_label_2
+                ):
                     f.write(f"{rand_string} {label_1} {label_2}\n")
                 f.write("\n")
                 dataset_line = (rand_strings, rand_label_1, rand_label_2)
@@ -41,7 +49,9 @@ def _get_mock_dataset(root_dir):
 
     # create gz file from dataset folder
    compressed_dataset_path = os.path.join(base_dir, f"{file_name}.gz")
-    with gzip.open(compressed_dataset_path, "wb") as gz_file, open(txt_file, "rb") as file_in:
+    with gzip.open(compressed_dataset_path, "wb") as gz_file, open(
+        txt_file, "rb"
+    ) as file_in:
         gz_file.writelines(file_in)
 
 return mocked_data
```
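The last hunk wraps the gzip call but keeps the idiom: `gzip.open` returns a file-like object, so `gz_file.writelines(file_in)` streams the raw bytes of the source file straight into the archive. A self-contained sketch of the same pattern (paths are illustrative):

```python
import gzip
import os
import tempfile

with tempfile.TemporaryDirectory() as base_dir:
    txt_file = os.path.join(base_dir, "train.txt")
    with open(txt_file, "w") as f:
        f.write("token label1 label2\n")

    # Same idiom as the diff: stream the raw file into the .gz archive.
    compressed = os.path.join(base_dir, "train.txt.gz")
    with gzip.open(compressed, "wb") as gz_file, open(txt_file, "rb") as file_in:
        gz_file.writelines(file_in)

    # Round-trip check: decompress and compare.
    with gzip.open(compressed, "rt") as gz_file:
        assert gz_file.read() == "token label1 label2\n"
```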

test/datasets/test_enwik9.py

Lines changed: 6 additions & 4 deletions

```diff
@@ -24,10 +24,12 @@ def _get_mock_dataset(root_dir):
     mocked_data = []
     with open(txt_file, "w") as f:
         for i in range(5):
-            rand_string = "<" + " ".join(
-                random.choice(string.ascii_letters) for i in range(seed)
-            ) + ">"
-            dataset_line = (f"'{rand_string}'")
+            rand_string = (
+                "<"
+                + " ".join(random.choice(string.ascii_letters) for i in range(seed))
+                + ">"
+            )
+            dataset_line = f"'{rand_string}'"
             f.write(f"'{rand_string}'\n")
 
         # append line to correct dataset split
```
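This hunk is pure line-wrapping: the long concatenation gains surrounding parentheses so it can continue across lines without backslashes, and the redundant parentheses around the f-string assignment are dropped. The wrapped form evaluates exactly as the old one-liner did, e.g.:

```python
import random
import string

seed = 5
# Parentheses allow implicit line continuation; the value is identical
# to the old single-line "<" + " ".join(...) + ">" expression.
rand_string = (
    "<"
    + " ".join(random.choice(string.ascii_letters) for _ in range(seed))
    + ">"
)
print(rand_string)  # e.g. <q W e R t>
```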

test/datasets/test_iwslt2016.py

Lines changed: 68 additions & 24 deletions

```diff
@@ -1,37 +1,54 @@
+import itertools
 import os
 import random
 import shutil
 import string
 import tarfile
-import itertools
 import tempfile
 from collections import defaultdict
 from unittest.mock import patch
 
 from parameterized import parameterized
-from torchtext.datasets.iwslt2016 import DATASET_NAME, IWSLT2016, SUPPORTED_DATASETS, SET_NOT_EXISTS
 from torchtext.data.datasets_utils import _generate_iwslt_files_for_lang_and_split
+from torchtext.datasets.iwslt2016 import (
+    DATASET_NAME,
+    IWSLT2016,
+    SUPPORTED_DATASETS,
+    SET_NOT_EXISTS,
+)
 
 from ..common.case_utils import zip_equal
 from ..common.torchtext_test_case import TorchtextTestCase
 
-SUPPORTED_LANGPAIRS = [(k, e) for k, v in SUPPORTED_DATASETS["language_pair"].items() for e in v]
+SUPPORTED_LANGPAIRS = [
+    (k, e) for k, v in SUPPORTED_DATASETS["language_pair"].items() for e in v
+]
 SUPPORTED_DEVTEST_SPLITS = SUPPORTED_DATASETS["valid_test"]
-DEV_TEST_SPLITS = [(dev, test) for dev, test in itertools.product(SUPPORTED_DEVTEST_SPLITS, repeat=2) if dev != test]
+DEV_TEST_SPLITS = [
+    (dev, test)
+    for dev, test in itertools.product(SUPPORTED_DEVTEST_SPLITS, repeat=2)
+    if dev != test
+]
 
 
 def _generate_uncleaned_train():
     """Generate tags files"""
     file_contents = []
     examples = []
     xml_tags = [
-        '<url', '<keywords', '<talkid', '<description', '<reviewer',
-        '<translator', '<title', '<speaker', '<doc', '</doc'
+        "<url",
+        "<keywords",
+        "<talkid",
+        "<description",
+        "<reviewer",
+        "<translator",
+        "<title",
+        "<speaker",
+        "<doc",
+        "</doc",
     ]
     for i in range(100):
-        rand_string = " ".join(
-            random.choice(string.ascii_letters) for i in range(10)
-        )
+        rand_string = " ".join(random.choice(string.ascii_letters) for i in range(10))
         # With a 10% change, add one of the XML tags which is cleaned
         # to ensure cleaning happens appropriately
         if random.random() < 0.1:
@@ -79,7 +96,7 @@ def _get_mock_dataset(root_dir, split, src, tgt, valid_set, test_set):
     """
 
     base_dir = os.path.join(root_dir, DATASET_NAME)
-    temp_dataset_dir = os.path.join(base_dir, 'temp_dataset_dir')
+    temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir")
     outer_temp_dataset_dir = os.path.join(temp_dataset_dir, f"texts/{src}/{tgt}/")
     inner_temp_dataset_dir = os.path.join(outer_temp_dataset_dir, f"{src}-{tgt}")
 
@@ -88,7 +105,9 @@ def _get_mock_dataset(root_dir, split, src, tgt, valid_set, test_set):
 
     mocked_data = defaultdict(lambda: defaultdict(list))
 
-    cleaned_file_names, uncleaned_file_names = _generate_iwslt_files_for_lang_and_split(16, src, tgt, valid_set, test_set)
+    cleaned_file_names, uncleaned_file_names = _generate_iwslt_files_for_lang_and_split(
+        16, src, tgt, valid_set, test_set
+    )
     uncleaned_src_file = uncleaned_file_names[src][split]
     uncleaned_tgt_file = uncleaned_file_names[tgt][split]
 
@@ -97,7 +116,7 @@ def _get_mock_dataset(root_dir, split, src, tgt, valid_set, test_set):
 
     for (unclean_file_name, clean_file_name) in [
         (uncleaned_src_file, cleaned_src_file),
-        (uncleaned_tgt_file, cleaned_tgt_file)
+        (uncleaned_tgt_file, cleaned_tgt_file),
     ]:
         # Get file extension (i.e., the language) without the . prefix (.en -> en)
         lang = os.path.splitext(unclean_file_name)[1][1:]
@@ -144,20 +163,31 @@ def tearDownClass(cls):
         cls.patcher.stop()
         super().tearDownClass()
 
-    @parameterized.expand([
-        (split, src, tgt, dev_set, test_set)
-        for split in ("train", "valid", "test")
-        for dev_set, test_set in DEV_TEST_SPLITS
-        for src, tgt in SUPPORTED_LANGPAIRS
-        if (dev_set not in SET_NOT_EXISTS[(src, tgt)] and test_set not in SET_NOT_EXISTS[(src, tgt)])
-    ])
+    @parameterized.expand(
+        [
+            (split, src, tgt, dev_set, test_set)
+            for split in ("train", "valid", "test")
+            for dev_set, test_set in DEV_TEST_SPLITS
+            for src, tgt in SUPPORTED_LANGPAIRS
+            if (
+                dev_set not in SET_NOT_EXISTS[(src, tgt)]
+                and test_set not in SET_NOT_EXISTS[(src, tgt)]
+            )
+        ]
+    )
     def test_iwslt2016(self, split, src, tgt, dev_set, test_set):
 
         with tempfile.TemporaryDirectory() as root_dir:
-            expected_samples = _get_mock_dataset(root_dir, split, src, tgt, dev_set, test_set)
+            expected_samples = _get_mock_dataset(
+                root_dir, split, src, tgt, dev_set, test_set
+            )
 
             dataset = IWSLT2016(
-                root=root_dir, split=split, language_pair=(src, tgt), valid_set=dev_set, test_set=test_set
+                root=root_dir,
+                split=split,
+                language_pair=(src, tgt),
+                valid_set=dev_set,
+                test_set=test_set,
             )
 
             samples = list(dataset)
@@ -171,9 +201,23 @@ def test_iwslt2016_split_argument(self, split):
             language_pair = ("de", "en")
             valid_set = "tst2013"
             test_set = "tst2014"
-            _ = _get_mock_dataset(root_dir, split, language_pair[0], language_pair[1], valid_set, test_set)
-            dataset1 = IWSLT2016(root=root_dir, split=split, language_pair=language_pair, valid_set=valid_set, test_set=test_set)
-            (dataset2,) = IWSLT2016(root=root_dir, split=(split,), language_pair=language_pair, valid_set=valid_set, test_set=test_set)
+            _ = _get_mock_dataset(
+                root_dir, split, language_pair[0], language_pair[1], valid_set, test_set
+            )
+            dataset1 = IWSLT2016(
+                root=root_dir,
+                split=split,
+                language_pair=language_pair,
+                valid_set=valid_set,
+                test_set=test_set,
+            )
+            (dataset2,) = IWSLT2016(
+                root=root_dir,
+                split=(split,),
+                language_pair=language_pair,
+                valid_set=valid_set,
+                test_set=test_set,
+            )
 
             for d1, d2 in zip_equal(dataset1, dataset2):
                 self.assertEqual(d1, d2)
```
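The decorator rewrite is where most of this file's churn comes from: the test matrix is a comprehension over splits, dev/test pairs, and language pairs, filtered so that any dev or test set listed in SET_NOT_EXISTS for a given language pair is skipped. A stripped-down sketch of that filtering with made-up stand-in tables (the real constants live in torchtext.datasets.iwslt2016):

```python
import itertools

# Hypothetical stand-ins for the real SUPPORTED_DATASETS / SET_NOT_EXISTS tables.
LANG_PAIRS = [("de", "en"), ("en", "de"), ("fr", "en")]
DEVTEST_SETS = ["dev2010", "tst2013", "tst2014"]
SET_NOT_EXISTS = {
    ("de", "en"): set(),
    ("en", "de"): set(),
    ("fr", "en"): {"tst2014"},  # pretend this set was never released for fr-en
}

# All ordered (dev, test) pairs where the two sets differ.
DEV_TEST_SPLITS = [
    (dev, test)
    for dev, test in itertools.product(DEVTEST_SETS, repeat=2)
    if dev != test
]

cases = [
    (split, src, tgt, dev_set, test_set)
    for split in ("train", "valid", "test")
    for dev_set, test_set in DEV_TEST_SPLITS
    for src, tgt in LANG_PAIRS
    if (
        dev_set not in SET_NOT_EXISTS[(src, tgt)]
        and test_set not in SET_NOT_EXISTS[(src, tgt)]
    )
]
# Each tuple in `cases` would become one generated test method.
print(len(cases))  # 42 with the stand-in tables above
```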

test/datasets/test_iwslt2017.py

Lines changed: 50 additions & 21 deletions

```diff
@@ -8,27 +8,40 @@
 from unittest.mock import patch
 
 from parameterized import parameterized
-from torchtext.datasets.iwslt2017 import DATASET_NAME, IWSLT2017, SUPPORTED_DATASETS, _PATH
 from torchtext.data.datasets_utils import _generate_iwslt_files_for_lang_and_split
+from torchtext.datasets.iwslt2017 import (
+    DATASET_NAME,
+    IWSLT2017,
+    SUPPORTED_DATASETS,
+    _PATH,
+)
 
 from ..common.case_utils import zip_equal
 from ..common.torchtext_test_case import TorchtextTestCase
 
-SUPPORTED_LANGPAIRS = [(k, e) for k, v in SUPPORTED_DATASETS["language_pair"].items() for e in v]
+SUPPORTED_LANGPAIRS = [
+    (k, e) for k, v in SUPPORTED_DATASETS["language_pair"].items() for e in v
+]
 
 
 def _generate_uncleaned_train():
     """Generate tags files"""
     file_contents = []
     examples = []
     xml_tags = [
-        '<url', '<keywords', '<talkid', '<description', '<reviewer',
-        '<translator', '<title', '<speaker', '<doc', '</doc'
+        "<url",
+        "<keywords",
+        "<talkid",
+        "<description",
+        "<reviewer",
+        "<translator",
+        "<title",
+        "<speaker",
+        "<doc",
+        "</doc",
     ]
     for i in range(100):
-        rand_string = " ".join(
-            random.choice(string.ascii_letters) for i in range(10)
-        )
+        rand_string = " ".join(random.choice(string.ascii_letters) for i in range(10))
         # With a 10% change, add one of the XML tags which is cleaned
         # to ensure cleaning happens appropriately
         if random.random() < 0.1:
@@ -76,16 +89,22 @@ def _get_mock_dataset(root_dir, split, src, tgt, valid_set, test_set):
     """
 
     base_dir = os.path.join(root_dir, DATASET_NAME)
-    temp_dataset_dir = os.path.join(base_dir, 'temp_dataset_dir')
-    outer_temp_dataset_dir = os.path.join(temp_dataset_dir, "texts/DeEnItNlRo/DeEnItNlRo")
-    inner_temp_dataset_dir = os.path.join(outer_temp_dataset_dir, "DeEnItNlRo-DeEnItNlRo")
+    temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir")
+    outer_temp_dataset_dir = os.path.join(
+        temp_dataset_dir, "texts/DeEnItNlRo/DeEnItNlRo"
+    )
+    inner_temp_dataset_dir = os.path.join(
+        outer_temp_dataset_dir, "DeEnItNlRo-DeEnItNlRo"
+    )
 
     os.makedirs(outer_temp_dataset_dir, exist_ok=True)
     os.makedirs(inner_temp_dataset_dir, exist_ok=True)
 
     mocked_data = defaultdict(lambda: defaultdict(list))
 
-    cleaned_file_names, uncleaned_file_names = _generate_iwslt_files_for_lang_and_split(17, src, tgt, valid_set, test_set)
+    cleaned_file_names, uncleaned_file_names = _generate_iwslt_files_for_lang_and_split(
+        17, src, tgt, valid_set, test_set
+    )
     uncleaned_src_file = uncleaned_file_names[src][split]
     uncleaned_tgt_file = uncleaned_file_names[tgt][split]
 
@@ -94,7 +113,7 @@ def _get_mock_dataset(root_dir, split, src, tgt, valid_set, test_set):
 
     for (unclean_file_name, clean_file_name) in [
         (uncleaned_src_file, cleaned_src_file),
-        (uncleaned_tgt_file, cleaned_tgt_file)
+        (uncleaned_tgt_file, cleaned_tgt_file),
     ]:
         # Get file extension (i.e., the language) without the . prefix (.en -> en)
         lang = os.path.splitext(unclean_file_name)[1][1:]
@@ -141,15 +160,19 @@ def tearDownClass(cls):
         cls.patcher.stop()
         super().tearDownClass()
 
-    @parameterized.expand([
-        (split, src, tgt)
-        for split in ("train", "valid", "test")
-        for src, tgt in SUPPORTED_LANGPAIRS
-    ])
+    @parameterized.expand(
+        [
+            (split, src, tgt)
+            for split in ("train", "valid", "test")
+            for src, tgt in SUPPORTED_LANGPAIRS
+        ]
+    )
     def test_iwslt2017(self, split, src, tgt):
 
         with tempfile.TemporaryDirectory() as root_dir:
-            expected_samples = _get_mock_dataset(root_dir, split, src, tgt, "dev2010", "tst2010")
+            expected_samples = _get_mock_dataset(
+                root_dir, split, src, tgt, "dev2010", "tst2010"
+            )
 
             dataset = IWSLT2017(root=root_dir, split=split, language_pair=(src, tgt))
 
@@ -164,9 +187,15 @@ def test_iwslt2017_split_argument(self, split):
             language_pair = ("de", "en")
             valid_set = "dev2010"
             test_set = "tst2010"
-            _ = _get_mock_dataset(root_dir, split, language_pair[0], language_pair[1], valid_set, test_set)
-            dataset1 = IWSLT2017(root=root_dir, split=split, language_pair=language_pair)
-            (dataset2,) = IWSLT2017(root=root_dir, split=(split,), language_pair=language_pair)
+            _ = _get_mock_dataset(
+                root_dir, split, language_pair[0], language_pair[1], valid_set, test_set
+            )
+            dataset1 = IWSLT2017(
+                root=root_dir, split=split, language_pair=language_pair
+            )
+            (dataset2,) = IWSLT2017(
+                root=root_dir, split=(split,), language_pair=language_pair
+            )
 
             for d1, d2 in zip_equal(dataset1, dataset2):
                 self.assertEqual(d1, d2)
```
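Both *_split_argument tests compare the dataset returned for split="train" against the one returned for split=("train",) using zip_equal from test/common/case_utils.py. That helper is not part of this diff; a plausible sketch of its behavior (unlike builtin zip, it should raise when the iterables disagree in length, so a silently truncated comparison cannot pass):

```python
from itertools import zip_longest

_SENTINEL = object()


def zip_equal(*iterables):
    """Sketch only: like zip(), but fail loudly on unequal lengths."""
    for combo in zip_longest(*iterables, fillvalue=_SENTINEL):
        # The sentinel only appears when one iterable ran out early.
        if any(item is _SENTINEL for item in combo):
            raise ValueError("iterables have different lengths")
        yield combo


assert list(zip_equal([1, 2], ["a", "b"])) == [(1, "a"), (2, "b")]
```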
