import os
import random
import shutil
import string
import tarfile
import tempfile
from collections import defaultdict
from unittest.mock import patch

from parameterized import parameterized
from torchtext.data.datasets_utils import _generate_iwslt_files_for_lang_and_split
from torchtext.datasets.iwslt2017 import DATASET_NAME, IWSLT2017, SUPPORTED_DATASETS, _PATH

from ..common.case_utils import zip_equal
from ..common.torchtext_test_case import TorchtextTestCase

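# Flatten the supported (source -> [targets]) mapping into (src, tgt) pairs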
SUPPORTED_LANGPAIRS = [(k, e) for k, v in SUPPORTED_DATASETS["language_pair"].items() for e in v]


def _generate_uncleaned_train():
    """Generate the contents of an uncleaned train (tags) file along with the
    examples expected to survive cleaning.
    """
    file_contents = []
    examples = []
    xml_tags = [
        "<url", "<keywords", "<talkid", "<description", "<reviewer",
        "<translator", "<title", "<speaker", "<doc", "</doc"
    ]
    for _ in range(100):
        rand_string = " ".join(
            random.choice(string.ascii_letters) for _ in range(10)
        )
        # With a 10% chance, wrap the line in one of the XML tags that gets
        # cleaned out, to ensure cleaning happens appropriately
        if random.random() < 0.1:
            open_tag = random.choice(xml_tags) + ">"
            close_tag = "</" + open_tag[1:]
            file_contents.append(open_tag + rand_string + close_tag)
        else:
            examples.append(rand_string + "\n")
            file_contents.append(rand_string)
    return examples, "\n".join(file_contents)


def _generate_uncleaned_valid():
    """Generate the contents of an uncleaned XML valid file along with the
    examples expected after cleaning.
    """
    file_contents = ["<root>"]
    examples = []

    for doc_id in range(5):
        file_contents.append(f'<doc docid="{doc_id}" genre="lectures">')
        for seg_id in range(100):
            rand_string = " ".join(
                random.choice(string.ascii_letters) for _ in range(10)
            )
            examples.append(rand_string + "\n")
            file_contents.append(f"<seg>{rand_string} </seg>")
        file_contents.append("</doc>")
    file_contents.append("</root>")
    return examples, "\n".join(file_contents)


def _generate_uncleaned_test():
    """Test files share the uncleaned XML format of valid files."""
    return _generate_uncleaned_valid()


def _generate_uncleaned_contents(split):
    # Dispatch so that only the requested split's contents are generated
    return {
        "train": _generate_uncleaned_train,
        "valid": _generate_uncleaned_valid,
        "test": _generate_uncleaned_test,
    }[split]()


def _get_mock_dataset(root_dir, split, src, tgt, valid_set, test_set):
    """Build a mocked IWSLT2017 archive under root_dir and return the expected
    (src, tgt) sample pairs for the given split.
    """

    base_dir = os.path.join(root_dir, DATASET_NAME)
    temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir")
    outer_temp_dataset_dir = os.path.join(temp_dataset_dir, "texts/DeEnItNlRo/DeEnItNlRo")
    inner_temp_dataset_dir = os.path.join(outer_temp_dataset_dir, "DeEnItNlRo-DeEnItNlRo")

    os.makedirs(outer_temp_dataset_dir, exist_ok=True)
    os.makedirs(inner_temp_dataset_dir, exist_ok=True)

    mocked_data = defaultdict(lambda: defaultdict(list))

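    # Look up the uncleaned (raw tags/XML) and cleaned file names the dataset
    # expects for this language pair and split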
    cleaned_file_names, uncleaned_file_names = _generate_iwslt_files_for_lang_and_split(17, src, tgt, valid_set, test_set)
    uncleaned_src_file = uncleaned_file_names[src][split]
    uncleaned_tgt_file = uncleaned_file_names[tgt][split]

    cleaned_src_file = cleaned_file_names[src][split]
    cleaned_tgt_file = cleaned_file_names[tgt][split]

    for (unclean_file_name, clean_file_name) in [
        (uncleaned_src_file, cleaned_src_file),
        (uncleaned_tgt_file, cleaned_tgt_file)
    ]:
        # Get the file extension (i.e., the language) without the "." prefix (.en -> en)
        lang = os.path.splitext(unclean_file_name)[1][1:]

        out_file = os.path.join(inner_temp_dataset_dir, unclean_file_name)
        with open(out_file, "w") as f:
            mocked_data_for_split, file_contents = _generate_uncleaned_contents(split)
            mocked_data[split][lang] = mocked_data_for_split
            f.write(file_contents)

    inner_compressed_dataset_path = os.path.join(
        outer_temp_dataset_dir, "DeEnItNlRo-DeEnItNlRo.tgz"
    )

    # Create the inner tar file from the dataset folder
    with tarfile.open(inner_compressed_dataset_path, "w:gz") as tar:
        tar.add(inner_temp_dataset_dir, arcname="DeEnItNlRo-DeEnItNlRo")

    # This is necessary so that the outer tarball only includes the inner tarball
    shutil.rmtree(inner_temp_dataset_dir)

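    # Name the outer archive _PATH, the file name the dataset expects to find
    # in its download/cache location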
    outer_temp_dataset_path = os.path.join(base_dir, _PATH)

    with tarfile.open(outer_temp_dataset_path, "w:gz") as tar:
        tar.add(temp_dataset_dir, arcname=os.path.splitext(_PATH)[0])

    return list(zip(mocked_data[split][src], mocked_data[split][tgt]))


class TestIWSLT2017(TorchtextTestCase):
    root_dir = None
    patcher = None

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
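        # Bypass the datapipe cache's hash check: the mocked archives will not
        # match the real dataset's recorded hashes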
        cls.patcher = patch(
            "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True
        )
        cls.patcher.start()

    @classmethod
    def tearDownClass(cls):
        cls.patcher.stop()
        super().tearDownClass()

    @parameterized.expand([
        (split, src, tgt)
        for split in ("train", "valid", "test")
        for src, tgt in SUPPORTED_LANGPAIRS
    ])
    def test_iwslt2017(self, split, src, tgt):
        with tempfile.TemporaryDirectory() as root_dir:
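            # IWSLT2017 always uses dev2010/tst2010 as its validation and test
            # sets, so the mock mirrors those names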
            expected_samples = _get_mock_dataset(root_dir, split, src, tgt, "dev2010", "tst2010")

            dataset = IWSLT2017(root=root_dir, split=split, language_pair=(src, tgt))

            samples = list(dataset)

            for sample, expected_sample in zip_equal(samples, expected_samples):
                self.assertEqual(sample, expected_sample)

    @parameterized.expand(["train", "valid", "test"])
    def test_iwslt2017_split_argument(self, split):
        with tempfile.TemporaryDirectory() as root_dir:
            language_pair = ("de", "en")
            valid_set = "dev2010"
            test_set = "tst2010"
            _ = _get_mock_dataset(root_dir, split, language_pair[0], language_pair[1], valid_set, test_set)
            # split accepts both a single string and a tuple; both forms
            # should yield identical datasets
            dataset1 = IWSLT2017(root=root_dir, split=split, language_pair=language_pair)
            (dataset2,) = IWSLT2017(root=root_dir, split=(split,), language_pair=language_pair)

            for d1, d2 in zip_equal(dataset1, dataset2):
                self.assertEqual(d1, d2)