Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 8eee23c

Browse files
authored
IWSLT dataset with new url (#1115)
1 parent adc489b commit 8eee23c

File tree

2 files changed

+45
-9
lines changed

2 files changed

+45
-9
lines changed

test/data/test_builtin_datasets.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,34 @@ def test_imdb(self):
162162
self._helper_test_func(len(test_iter), 25000, next(iter(test_iter))[1][:25], 'I love sci-fi and am will')
163163
del train_iter, test_iter
164164

165+
def test_iwslt(self):
166+
from torchtext.experimental.datasets import IWSLT
167+
168+
train_dataset, valid_dataset, test_dataset = IWSLT()
169+
170+
self.assertEqual(len(train_dataset), 196884)
171+
self.assertEqual(len(valid_dataset), 993)
172+
self.assertEqual(len(test_dataset), 1305)
173+
174+
de_vocab, en_vocab = train_dataset.get_vocab()
175+
176+
def assert_nth_pair_is_equal(n, expected_sentence_pair):
177+
de_sentence = [de_vocab.itos[index] for index in train_dataset[n][0]]
178+
en_sentence = [en_vocab.itos[index] for index in train_dataset[n][1]]
179+
expected_de_sentence, expected_en_sentence = expected_sentence_pair
180+
181+
self.assertEqual(de_sentence, expected_de_sentence)
182+
self.assertEqual(en_sentence, expected_en_sentence)
183+
184+
assert_nth_pair_is_equal(0, (['David', 'Gallo', ':', 'Das', 'ist', 'Bill', 'Lange', '.', 'Ich', 'bin', 'Dave', 'Gallo', '.', '\n'],
185+
['David', 'Gallo', ':', 'This', 'is', 'Bill', 'Lange', '.', 'I', "'m", 'Dave', 'Gallo', '.', '\n']))
186+
assert_nth_pair_is_equal(10, (['Die', 'meisten', 'Tiere', 'leben', 'in', 'den', 'Ozeanen', '.', '\n'],
187+
['Most', 'of', 'the', 'animals', 'are', 'in', 'the', 'oceans', '.', '\n']))
188+
assert_nth_pair_is_equal(20, (['Es', 'ist', 'einer', 'meiner', 'Lieblinge', ',', 'weil', 'es', 'alle', 'möglichen', 'Funktionsteile', 'hat', '.', '\n'],
189+
['It', "'s", 'one', 'of', 'my', 'favorites', ',', 'because', 'it', "'s", 'got', 'all', 'sorts', 'of', 'working', 'parts', '.', '\n']))
190+
datafile = os.path.join(self.project_root, ".data", "2016-01.tgz")
191+
conditional_remove(datafile)
192+
165193
def test_multi30k(self):
166194
from torchtext.experimental.datasets import Multi30k
167195
# smoke test to ensure multi30k works properly

torchtext/experimental/datasets/raw/translation.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
'WMT14':
6363
'https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8',
6464
'IWSLT':
65-
'https://wit3.fbk.eu/archive/2016-01//texts/{}/{}/{}.tgz'
65+
'https://drive.google.com/uc?id=1l5y6Giag9aRPwGtuZHswh3w5v3qEz8D8'
6666
}
6767

6868

@@ -125,14 +125,26 @@ def _setup_datasets(dataset_name,
125125
src_eval, tgt_eval = valid_filenames
126126
src_test, tgt_test = test_filenames
127127

128-
extracted_files = []
128+
extracted_files = [] # list of paths to the extracted files
129129
if isinstance(URLS[dataset_name], list):
130130
for idx, f in enumerate(URLS[dataset_name]):
131-
dataset_tar = download_from_url(f, root=root, hash_value=MD5[dataset_name][idx], hash_type='md5')
131+
dataset_tar = download_from_url(
132+
f, root=root, hash_value=MD5[dataset_name][idx], hash_type='md5')
132133
extracted_files.extend(extract_archive(dataset_tar))
133134
elif isinstance(URLS[dataset_name], str):
134135
dataset_tar = download_from_url(URLS[dataset_name], root=root, hash_value=MD5[dataset_name], hash_type='md5')
135-
extracted_files.extend(extract_archive(dataset_tar))
136+
extracted_dataset_tar = extract_archive(dataset_tar)
137+
if dataset_name == 'IWSLT':
138+
# IWSLT dataset's url downloads a multilingual tgz.
139+
# We need to take an extra step to pick out the specific language pair from it.
140+
src_language = train_filenames[0].split(".")[-1]
141+
tgt_language = train_filenames[1].split(".")[-1]
142+
languages = "-".join([src_language, tgt_language])
143+
iwslt_tar = '.data/2016-01/texts/{}/{}/{}.tgz'
144+
iwslt_tar = iwslt_tar.format(
145+
src_language, tgt_language, languages)
146+
extracted_dataset_tar = extract_archive(iwslt_tar)
147+
extracted_files.extend(extracted_dataset_tar)
136148
else:
137149
raise ValueError(
138150
"URLS for {} has to be in a form or list or string".format(
@@ -418,10 +430,6 @@ def IWSLT(train_filenames=('train.de-en.de', 'train.de-en.en'),
418430
>>> from torchtext.experimental.datasets.raw import IWSLT
419431
>>> train_dataset, valid_dataset, test_dataset = IWSLT()
420432
"""
421-
src_language = train_filenames[0].split(".")[-1]
422-
tgt_language = train_filenames[1].split(".")[-1]
423-
languages = "-".join([src_language, tgt_language])
424-
URLS["IWSLT"] = URLS["IWSLT"].format(src_language, tgt_language, languages)
425433
return _setup_datasets("IWSLT", train_filenames, valid_filenames, test_filenames, data_select, root)
426434

427435

@@ -567,6 +575,6 @@ def WMT14(train_filenames=('train.tok.clean.bpe.32000.de',
567575
'acb5ea26a577ceccfae6337181c31716',
568576
'873a377a348713d3ab84db1fb57cdede',
569577
'680816e0938fea5cf5331444bc09a4cf'],
570-
'IWSLT': '6ff9ab8ea16fb352597c2784e0391fa8',
578+
'IWSLT': 'c393ed3fc2a1b0f004b3331043f615ae',
571579
'WMT14': '874ab6bbfe9c21ec987ed1b9347f95ec'
572580
}

0 commit comments

Comments (0)