
Commit c3f59a5

Add Mock test for IWSLT2017 dataset (#1598)
1 parent: 7b7a90d

2 files changed: +175 -4 lines

test/datasets/test_iwslt2017.py

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
import os
import random
import shutil
import string
import tarfile
import tempfile
from collections import defaultdict
from unittest.mock import patch

from parameterized import parameterized
from torchtext.datasets.iwslt2017 import DATASET_NAME, IWSLT2017, SUPPORTED_DATASETS, _PATH
from torchtext.data.datasets_utils import _generate_iwslt_files_for_lang_and_split

from ..common.case_utils import zip_equal
from ..common.torchtext_test_case import TorchtextTestCase

SUPPORTED_LANGPAIRS = [(k, e) for k, v in SUPPORTED_DATASETS["language_pair"].items() for e in v]
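
# Illustration (hypothetical mapping, not the library's real metadata): if
# SUPPORTED_DATASETS["language_pair"] were {"de": ["en", "it"], "en": ["de"]},
# the comprehension above would yield [("de", "en"), ("de", "it"), ("en", "de")].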


def _generate_uncleaned_train():
    """Generate the contents of an uncleaned, tag-annotated train file."""
    file_contents = []
    examples = []
    xml_tags = [
        '<url', '<keywords', '<talkid', '<description', '<reviewer',
        '<translator', '<title', '<speaker', '<doc', '</doc'
    ]
    for i in range(100):
        rand_string = " ".join(
            random.choice(string.ascii_letters) for i in range(10)
        )
        # With a 10% chance, add one of the XML tags which is cleaned
        # to ensure cleaning happens appropriately
        if random.random() < 0.1:
            open_tag = random.choice(xml_tags) + ">"
            close_tag = "</" + open_tag[1:] + ">"
            file_contents.append(open_tag + rand_string + close_tag)
        else:
            examples.append(rand_string + "\n")
            file_contents.append(rand_string)
    return examples, "\n".join(file_contents)
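
# Note on the generator above: a tagged line such as
#   <talkid>a B c d E f g H i J</talkid>   (string randomized)
# goes only into file_contents, never into examples, because the dataset's
# cleaning step is expected to strip lines wrapped in these XML tags.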


def _generate_uncleaned_valid():
    file_contents = ["<root>"]
    examples = []

    for doc_id in range(5):
        file_contents.append(f'<doc docid="{doc_id}" genre="lectures">')
        for seg_id in range(100):
            rand_string = " ".join(
                random.choice(string.ascii_letters) for i in range(10)
            )
            examples.append(rand_string)
            file_contents.append(f"<seg>{rand_string} </seg>" + "\n")
        file_contents.append("</doc>")
    file_contents.append("</root>")
    return examples, " ".join(file_contents)


def _generate_uncleaned_test():
    return _generate_uncleaned_valid()
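
# Shape of a generated valid/test file (random strings elided):
#   <root> <doc docid="0" genre="lectures"> <seg>... </seg> ... </doc> ... </root>
# Only the text inside <seg> tags is kept as expected examples after cleaning.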


def _generate_uncleaned_contents(split):
    return {
        "train": _generate_uncleaned_train(),
        "valid": _generate_uncleaned_valid(),
        "test": _generate_uncleaned_test(),
    }[split]


def _get_mock_dataset(root_dir, split, src, tgt, valid_set, test_set):
    """
    root_dir: directory in which the mocked dataset archives are created
    """

    base_dir = os.path.join(root_dir, DATASET_NAME)
    temp_dataset_dir = os.path.join(base_dir, 'temp_dataset_dir')
    outer_temp_dataset_dir = os.path.join(temp_dataset_dir, "texts/DeEnItNlRo/DeEnItNlRo")
    inner_temp_dataset_dir = os.path.join(outer_temp_dataset_dir, "DeEnItNlRo-DeEnItNlRo")

    os.makedirs(outer_temp_dataset_dir, exist_ok=True)
    os.makedirs(inner_temp_dataset_dir, exist_ok=True)

    mocked_data = defaultdict(lambda: defaultdict(list))

    cleaned_file_names, uncleaned_file_names = _generate_iwslt_files_for_lang_and_split(17, src, tgt, valid_set, test_set)
    uncleaned_src_file = uncleaned_file_names[src][split]
    uncleaned_tgt_file = uncleaned_file_names[tgt][split]

    cleaned_src_file = cleaned_file_names[src][split]
    cleaned_tgt_file = cleaned_file_names[tgt][split]

    for (unclean_file_name, clean_file_name) in [
        (uncleaned_src_file, cleaned_src_file),
        (uncleaned_tgt_file, cleaned_tgt_file)
    ]:
        # Get file extension (i.e., the language) without the . prefix (.en -> en)
        lang = os.path.splitext(unclean_file_name)[1][1:]

        out_file = os.path.join(inner_temp_dataset_dir, unclean_file_name)
        with open(out_file, "w") as f:
            mocked_data_for_split, file_contents = _generate_uncleaned_contents(split)
            mocked_data[split][lang] = mocked_data_for_split
            f.write(file_contents)

    inner_compressed_dataset_path = os.path.join(
        outer_temp_dataset_dir, "DeEnItNlRo-DeEnItNlRo.tgz"
    )

    # create tar file from dataset folder
    with tarfile.open(inner_compressed_dataset_path, "w:gz") as tar:
        tar.add(inner_temp_dataset_dir, arcname="DeEnItNlRo-DeEnItNlRo")

    # this is necessary so that the outer tarball only includes the inner tarball
    shutil.rmtree(inner_temp_dataset_dir)

    outer_temp_dataset_path = os.path.join(base_dir, _PATH)

    with tarfile.open(outer_temp_dataset_path, "w:gz") as tar:
        tar.add(temp_dataset_dir, arcname=os.path.splitext(_PATH)[0])

    return list(zip(mocked_data[split][src], mocked_data[split][tgt]))
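
# Resulting mock layout (sketch; _PATH is the outer archive's filename):
#   <root_dir>/IWSLT2017/<_PATH>                  # outer .tgz written above
#     └─ <_PATH stem>/texts/DeEnItNlRo/DeEnItNlRo/DeEnItNlRo-DeEnItNlRo.tgz
#          └─ DeEnItNlRo-DeEnItNlRo/<uncleaned src and tgt files>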


class TestIWSLT2017(TorchtextTestCase):
    root_dir = None
    patcher = None

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Bypass hash verification so the locally generated mock archives are accepted.
        cls.patcher = patch(
            "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True
        )
        cls.patcher.start()

    @classmethod
    def tearDownClass(cls):
        cls.patcher.stop()
        super().tearDownClass()

    @parameterized.expand([
        (split, src, tgt)
        for split in ("train", "valid", "test")
        for src, tgt in SUPPORTED_LANGPAIRS
    ])
    def test_iwslt2017(self, split, src, tgt):
        with tempfile.TemporaryDirectory() as root_dir:
            expected_samples = _get_mock_dataset(root_dir, split, src, tgt, "dev2010", "tst2010")

            dataset = IWSLT2017(root=root_dir, split=split, language_pair=(src, tgt))

            samples = list(dataset)

            for sample, expected_sample in zip_equal(samples, expected_samples):
                self.assertEqual(sample, expected_sample)

    @parameterized.expand(["train", "valid", "test"])
    def test_iwslt2017_split_argument(self, split):
        with tempfile.TemporaryDirectory() as root_dir:
            language_pair = ("de", "en")
            valid_set = "dev2010"
            test_set = "tst2010"
            _ = _get_mock_dataset(root_dir, split, language_pair[0], language_pair[1], valid_set, test_set)
            dataset1 = IWSLT2017(root=root_dir, split=split, language_pair=language_pair)
            (dataset2,) = IWSLT2017(root=root_dir, split=(split,), language_pair=language_pair)

            for d1, d2 in zip_equal(dataset1, dataset2):
                self.assertEqual(d1, d2)
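
These mocked tests require no network access: the archives are generated locally and the datapipe hash check is patched out. They should be runnable in isolation, e.g. with pytest test/datasets/test_iwslt2017.py from the repository root (invocation illustrative; it depends on the repo's test setup).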

torchtext/datasets/iwslt2017.py

Lines changed: 3 additions & 4 deletions
@@ -194,16 +194,15 @@ def IWSLT2017(root=".data", split=("train", "valid", "test"), language_pair=("de
     )
 
     cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=lambda x: inner_iwslt_tar)
-    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(
-        lambda x: os.path.basename(inner_iwslt_tar) in x[0])
+    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar()
     cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
 
     src_filename = file_path_by_lang_and_split[src_language][split]
     uncleaned_src_filename = uncleaned_filenames_by_lang_and_split[src_language][split]
 
     # We create the whole filepath here, but only check for the literal filename in the filter
     # because we're lazily extracting from the outer tarfile.
-    full_src_filepath = os.path.join(root, "texts/DeEnItNlRo/DeEnItNlRo/DeEnItNlRo-DeEnItNlRo", src_filename)
+    full_src_filepath = os.path.join(root, os.path.splitext(_PATH)[0], "texts/DeEnItNlRo/DeEnItNlRo/DeEnItNlRo-DeEnItNlRo", src_filename)
 
     cache_inner_src_decompressed_dp = _filter_clean_cache(cache_decompressed_dp, full_src_filepath,
                                                           uncleaned_src_filename)
@@ -213,7 +212,7 @@ def IWSLT2017(root=".data", split=("train", "valid", "test"), language_pair=("de
 
     # We create the whole filepath here, but only check for the literal filename in the filter
     # because we're lazily extracting from the outer tarfile.
-    full_tgt_filepath = os.path.join(root, "texts/DeEnItNlRo/DeEnItNlRo/DeEnItNlRo-DeEnItNlRo", tgt_filename)
+    full_tgt_filepath = os.path.join(root, os.path.splitext(_PATH)[0], "texts/DeEnItNlRo/DeEnItNlRo/DeEnItNlRo-DeEnItNlRo", tgt_filename)
 
     cache_inner_tgt_decompressed_dp = _filter_clean_cache(cache_decompressed_dp, full_tgt_filepath,
                                                           uncleaned_tgt_filename)
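
Both full_*_filepath changes prepend the directory that the outer tarball extracts into, derived from the archive name via os.path.splitext(_PATH)[0]; the dropped .filter(...) presumably leaves member selection to the later _filter_clean_cache stage. A minimal sketch of the path fix, assuming _PATH is the outer archive filename used by this module (the concrete names below are illustrative):

import os

_PATH = "2017-01-trnmted.tgz"  # illustrative outer-archive name; see iwslt2017.py for the real value
root = ".data"
src_filename = "train.de-en.de"  # illustrative cleaned filename

# Before the fix: the extracted-archive directory was missing from the expected path.
before = os.path.join(root, "texts/DeEnItNlRo/DeEnItNlRo/DeEnItNlRo-DeEnItNlRo", src_filename)

# After the fix: os.path.splitext(_PATH)[0] strips ".tgz", yielding the directory
# the outer tarball extracts into, so the cache check looks in the right place.
after = os.path.join(root, os.path.splitext(_PATH)[0], "texts/DeEnItNlRo/DeEnItNlRo/DeEnItNlRo-DeEnItNlRo", src_filename)

print(before)  # .data/texts/DeEnItNlRo/DeEnItNlRo/DeEnItNlRo-DeEnItNlRo/train.de-en.de
print(after)   # .data/2017-01-trnmted/texts/DeEnItNlRo/DeEnItNlRo/DeEnItNlRo-DeEnItNlRo/train.de-en.de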
