
Commit 7b7a90d

IWSLT testing to start from compressed file (#1596)

1 parent 18b61fa commit 7b7a90d
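
This commit reworks the IWSLT2016 dataset test so the mock reproduces the compressed artifact a real download provides, instead of pre-extracted, already-clean text files, which had left the pipeline's extraction and cleaning steps untested. A rough sketch of the on-disk layout the new mock produces (directory names follow the diff below; the per-split file names come from _generate_iwslt_files_for_lang_and_split and are elided here):

IWSLT2016/
└── 2016-01.tgz                      (outer archive, extracted by the dataset pipeline)
    └── 2016-01/texts/{src}/{tgt}/
        └── {src}-{tgt}.tgz          (inner archive of raw, XML-tagged files)
            └── {src}-{tgt}/
                └── (uncleaned source and target files for the requested split)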

File tree

1 file changed: +139 -43 lines changed

test/datasets/test_iwslt2016.py

Lines changed: 139 additions & 43 deletions
@@ -1,60 +1,141 @@
 import os
 import random
+import shutil
 import string
+import tarfile
+import itertools
+import tempfile
 from collections import defaultdict
 from unittest.mock import patch
 
 from parameterized import parameterized
-from torchtext.datasets.iwslt2016 import IWSLT2016
+from torchtext.datasets.iwslt2016 import DATASET_NAME, IWSLT2016, SUPPORTED_DATASETS, SET_NOT_EXISTS
 from torchtext.data.datasets_utils import _generate_iwslt_files_for_lang_and_split
 
-from ..common.case_utils import TempDirMixin, zip_equal
+from ..common.case_utils import zip_equal
 from ..common.torchtext_test_case import TorchtextTestCase
 
-
-def _get_mock_dataset(root_dir, split, src, tgt):
+SUPPORTED_LANGPAIRS = [(k, e) for k, v in SUPPORTED_DATASETS["language_pair"].items() for e in v]
+SUPPORTED_DEVTEST_SPLITS = SUPPORTED_DATASETS["valid_test"]
+DEV_TEST_SPLITS = [(dev, test) for dev, test in itertools.product(SUPPORTED_DEVTEST_SPLITS, repeat=2) if dev != test]
+
+
+def _generate_uncleaned_train():
+    """Generate tags files"""
+    file_contents = []
+    examples = []
+    xml_tags = [
+        '<url', '<keywords', '<talkid', '<description', '<reviewer',
+        '<translator', '<title', '<speaker', '<doc', '</doc'
+    ]
+    for i in range(100):
+        rand_string = " ".join(
+            random.choice(string.ascii_letters) for i in range(10)
+        )
+        # With a 10% chance, add one of the XML tags which is cleaned
+        # to ensure cleaning happens appropriately
+        if random.random() < 0.1:
+            open_tag = random.choice(xml_tags) + ">"
+            close_tag = "</" + open_tag[1:] + ">"
+            file_contents.append(open_tag + rand_string + close_tag)
+        else:
+            examples.append(rand_string + "\n")
+            file_contents.append(rand_string)
+    return examples, "\n".join(file_contents)
+
+
+def _generate_uncleaned_valid():
+    file_contents = ["<root>"]
+    examples = []
+
+    for doc_id in range(5):
+        file_contents.append(f'<doc docid="{doc_id}" genre="lectures">')
+        for seg_id in range(100):
+            rand_string = " ".join(
+                random.choice(string.ascii_letters) for i in range(10)
+            )
+            examples.append(rand_string)
+            file_contents.append(f"<seg>{rand_string} </seg>" + "\n")
+        file_contents.append("</doc>")
+    file_contents.append("</root>")
+    return examples, " ".join(file_contents)
+
+
+def _generate_uncleaned_test():
+    return _generate_uncleaned_valid()
+
+
+def _generate_uncleaned_contents(split):
+    return {
+        "train": _generate_uncleaned_train(),
+        "valid": _generate_uncleaned_valid(),
+        "test": _generate_uncleaned_test(),
+    }[split]
+
+
+def _get_mock_dataset(root_dir, split, src, tgt, valid_set, test_set):
     """
     root_dir: directory to the mocked dataset
     """
-    temp_dataset_dir = os.path.join(root_dir, f"IWSLT2016/2016-01/texts/{src}/{tgt}/{src}-{tgt}/")
-    os.makedirs(temp_dataset_dir, exist_ok=True)
 
-    seed = 1
+    base_dir = os.path.join(root_dir, DATASET_NAME)
+    temp_dataset_dir = os.path.join(base_dir, 'temp_dataset_dir')
+    outer_temp_dataset_dir = os.path.join(temp_dataset_dir, f"texts/{src}/{tgt}/")
+    inner_temp_dataset_dir = os.path.join(outer_temp_dataset_dir, f"{src}-{tgt}")
+
+    os.makedirs(outer_temp_dataset_dir, exist_ok=True)
+    os.makedirs(inner_temp_dataset_dir, exist_ok=True)
+
     mocked_data = defaultdict(lambda: defaultdict(list))
-    valid_set = "tst2013"
-    test_set = "tst2014"
-
-    files_for_split, _ = _generate_iwslt_files_for_lang_and_split(16, src, tgt, valid_set, test_set)
-    src_file = files_for_split[src][split]
-    tgt_file = files_for_split[tgt][split]
-    for file_name in (src_file, tgt_file):
-        txt_file = os.path.join(temp_dataset_dir, file_name)
-        with open(txt_file, "w") as f:
-            # Get file extension (i.e., the language) without the . prefix (.en -> en)
-            lang = os.path.splitext(file_name)[1][1:]
-            for i in range(5):
-                rand_string = " ".join(
-                    random.choice(string.ascii_letters) for i in range(seed)
-                )
-                dataset_line = f"{rand_string} {rand_string}\n"
-                # append line to correct dataset split
-                mocked_data[split][lang].append(dataset_line)
-                f.write(f'{rand_string} {rand_string}\n')
-                seed += 1
+
+    cleaned_file_names, uncleaned_file_names = _generate_iwslt_files_for_lang_and_split(16, src, tgt, valid_set, test_set)
+    uncleaned_src_file = uncleaned_file_names[src][split]
+    uncleaned_tgt_file = uncleaned_file_names[tgt][split]
+
+    cleaned_src_file = cleaned_file_names[src][split]
+    cleaned_tgt_file = cleaned_file_names[tgt][split]
+
+    for (unclean_file_name, clean_file_name) in [
+        (uncleaned_src_file, cleaned_src_file),
+        (uncleaned_tgt_file, cleaned_tgt_file)
+    ]:
+        # Get file extension (i.e., the language) without the . prefix (.en -> en)
+        lang = os.path.splitext(unclean_file_name)[1][1:]
+
+        out_file = os.path.join(inner_temp_dataset_dir, unclean_file_name)
+        with open(out_file, "w") as f:
+            mocked_data_for_split, file_contents = _generate_uncleaned_contents(split)
+            mocked_data[split][lang] = mocked_data_for_split
+            f.write(file_contents)
+
+    inner_compressed_dataset_path = os.path.join(
+        outer_temp_dataset_dir, f"{src}-{tgt}.tgz"
+    )
+
+    # create tar file from dataset folder
+    with tarfile.open(inner_compressed_dataset_path, "w:gz") as tar:
+        tar.add(inner_temp_dataset_dir, arcname=f"{src}-{tgt}")
+
+    # this is necessary so that the outer tarball only includes the inner tarball
+    shutil.rmtree(inner_temp_dataset_dir)
+
+    outer_temp_dataset_path = os.path.join(base_dir, "2016-01.tgz")
+
+    with tarfile.open(outer_temp_dataset_path, "w:gz") as tar:
+        tar.add(temp_dataset_dir, arcname="2016-01")
 
     return list(zip(mocked_data[split][src], mocked_data[split][tgt]))
 
 
-class TestIWSLT2016(TempDirMixin, TorchtextTestCase):
+class TestIWSLT2016(TorchtextTestCase):
     root_dir = None
     patcher = None
 
     @classmethod
     def setUpClass(cls):
         super().setUpClass()
-        cls.root_dir = cls.get_base_temp_dir()
         cls.patcher = patch(
-            "torchdata.datapipes.iter.util.cacheholder.OnDiskCacheHolderIterDataPipe._cache_check_fn", return_value=True
+            "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True
         )
         cls.patcher.start()
 
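A note on the setUpClass change above: the mocked archives are filled with random strings, so their checksums cannot match whatever hash the real 2016-01.tgz is validated against. Patching torchdata's cacheholder hash check to always succeed, rather than short-circuiting the whole cache check as before, lets the on-disk cache accept the fabricated archive while the extraction datapipes still run. A minimal sketch of the same idea in isolation (the checksum rationale is inferred from the patch target's name, not stated in this diff):

from unittest.mock import patch

# Treat a locally fabricated archive as a valid cached download;
# everything downstream (extraction, parsing, cleaning) runs for real.
with patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True):
    ...  # construct and iterate IWSLT2016 here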

@@ -63,21 +144,36 @@ def tearDownClass(cls):
         cls.patcher.stop()
         super().tearDownClass()
 
-    @parameterized.expand([("train", "de", "en"), ("valid", "de", "en")])
-    def test_iwslt2016(self, split, src, tgt):
-        expected_samples = _get_mock_dataset(self.root_dir, split, src, tgt)
+    @parameterized.expand([
+        (split, src, tgt, dev_set, test_set)
+        for split in ("train", "valid", "test")
+        for dev_set, test_set in DEV_TEST_SPLITS
+        for src, tgt in SUPPORTED_LANGPAIRS
+        if (dev_set not in SET_NOT_EXISTS[(src, tgt)] and test_set not in SET_NOT_EXISTS[(src, tgt)])
+    ])
+    def test_iwslt2016(self, split, src, tgt, dev_set, test_set):
 
-        dataset = IWSLT2016(root=self.root_dir, split=split)
+        with tempfile.TemporaryDirectory() as root_dir:
+            expected_samples = _get_mock_dataset(root_dir, split, src, tgt, dev_set, test_set)
 
-        samples = list(dataset)
+            dataset = IWSLT2016(
+                root=root_dir, split=split, language_pair=(src, tgt), valid_set=dev_set, test_set=test_set
+            )
 
-        for sample, expected_sample in zip_equal(samples, expected_samples):
-            self.assertEqual(sample, expected_sample)
+            samples = list(dataset)
 
-    @parameterized.expand(["train", "valid"])
-    def test_iwslt2016_split_argument(self, split):
-        dataset1 = IWSLT2016(root=self.root_dir, split=split)
-        (dataset2,) = IWSLT2016(root=self.root_dir, split=(split,))
+            for sample, expected_sample in zip_equal(samples, expected_samples):
+                self.assertEqual(sample, expected_sample)
 
-        for d1, d2 in zip_equal(dataset1, dataset2):
-            self.assertEqual(d1, d2)
+    @parameterized.expand(["train", "valid", "test"])
+    def test_iwslt2016_split_argument(self, split):
+        with tempfile.TemporaryDirectory() as root_dir:
+            language_pair = ("de", "en")
+            valid_set = "tst2013"
+            test_set = "tst2014"
+            _ = _get_mock_dataset(root_dir, split, language_pair[0], language_pair[1], valid_set, test_set)
+            dataset1 = IWSLT2016(root=root_dir, split=split, language_pair=language_pair, valid_set=valid_set, test_set=test_set)
+            (dataset2,) = IWSLT2016(root=root_dir, split=(split,), language_pair=language_pair, valid_set=valid_set, test_set=test_set)
+
+            for d1, d2 in zip_equal(dataset1, dataset2):
+                self.assertEqual(d1, d2)
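
For reference, a condensed version of what each parameterized case does end to end, as a sketch that reuses the helpers defined in this file and the default sets from test_iwslt2016_split_argument:

import tempfile
from unittest.mock import patch

from torchtext.datasets.iwslt2016 import IWSLT2016

with tempfile.TemporaryDirectory() as root_dir:
    # Build the nested mock archive and keep the expected cleaned lines.
    expected = _get_mock_dataset(root_dir, "train", "de", "en", "tst2013", "tst2014")
    with patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True):
        dataset = IWSLT2016(
            root=root_dir, split="train", language_pair=("de", "en"),
            valid_set="tst2013", test_set="tst2014",
        )
        # Each sample is a (source line, target line) pair.
        assert list(dataset) == expected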
