diff --git a/test/data/test_builtin_datasets.py b/test/data/test_builtin_datasets.py
index 0e2d320f42..2c8190c44a 100644
--- a/test/data/test_builtin_datasets.py
+++ b/test/data/test_builtin_datasets.py
@@ -162,6 +162,34 @@ def test_imdb(self):
         self._helper_test_func(len(test_iter), 25000, next(iter(test_iter))[1][:25], 'I love sci-fi and am will')
         del train_iter, test_iter
 
+    def test_iwslt(self):
+        from torchtext.experimental.datasets import IWSLT
+
+        train_dataset, valid_dataset, test_dataset = IWSLT()
+
+        self.assertEqual(len(train_dataset), 196884)
+        self.assertEqual(len(valid_dataset), 993)
+        self.assertEqual(len(test_dataset), 1305)
+
+        de_vocab, en_vocab = train_dataset.get_vocab()
+
+        def assert_nth_pair_is_equal(n, expected_sentence_pair):
+            de_sentence = [de_vocab.itos[index] for index in train_dataset[n][0]]
+            en_sentence = [en_vocab.itos[index] for index in train_dataset[n][1]]
+            expected_de_sentence, expected_en_sentence = expected_sentence_pair
+
+            self.assertEqual(de_sentence, expected_de_sentence)
+            self.assertEqual(en_sentence, expected_en_sentence)
+
+        assert_nth_pair_is_equal(0, (['David', 'Gallo', ':', 'Das', 'ist', 'Bill', 'Lange', '.', 'Ich', 'bin', 'Dave', 'Gallo', '.', '\n'],
+                                     ['David', 'Gallo', ':', 'This', 'is', 'Bill', 'Lange', '.', 'I', "'m", 'Dave', 'Gallo', '.', '\n']))
+        assert_nth_pair_is_equal(10, (['Die', 'meisten', 'Tiere', 'leben', 'in', 'den', 'Ozeanen', '.', '\n'],
+                                      ['Most', 'of', 'the', 'animals', 'are', 'in', 'the', 'oceans', '.', '\n']))
+        assert_nth_pair_is_equal(20, (['Es', 'ist', 'einer', 'meiner', 'Lieblinge', ',', 'weil', 'es', 'alle', 'möglichen', 'Funktionsteile', 'hat', '.', '\n'],
+                                      ['It', "'s", 'one', 'of', 'my', 'favorites', ',', 'because', 'it', "'s", 'got', 'all', 'sorts', 'of', 'working', 'parts', '.', '\n']))
+        datafile = os.path.join(self.project_root, ".data", "2016-01.tgz")
+        conditional_remove(datafile)
+
     def test_multi30k(self):
         from torchtext.experimental.datasets import Multi30k
         # smoke test to ensure multi30k works properly
diff --git a/torchtext/experimental/datasets/raw/translation.py b/torchtext/experimental/datasets/raw/translation.py
index 3b74c58421..25960bead8 100644
--- a/torchtext/experimental/datasets/raw/translation.py
+++ b/torchtext/experimental/datasets/raw/translation.py
@@ -62,7 +62,7 @@
     'WMT14':
         'https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8',
     'IWSLT':
-        'https://wit3.fbk.eu/archive/2016-01//texts/{}/{}/{}.tgz'
+        'https://drive.google.com/uc?id=1l5y6Giag9aRPwGtuZHswh3w5v3qEz8D8'
 }
 
 
@@ -125,14 +125,26 @@ def _setup_datasets(dataset_name,
         src_eval, tgt_eval = valid_filenames
     src_test, tgt_test = test_filenames
 
-    extracted_files = []
+    extracted_files = []  # list of paths to the extracted files
     if isinstance(URLS[dataset_name], list):
         for idx, f in enumerate(URLS[dataset_name]):
-            dataset_tar = download_from_url(f, root=root, hash_value=MD5[dataset_name][idx], hash_type='md5')
+            dataset_tar = download_from_url(
+                f, root=root, hash_value=MD5[dataset_name][idx], hash_type='md5')
             extracted_files.extend(extract_archive(dataset_tar))
     elif isinstance(URLS[dataset_name], str):
         dataset_tar = download_from_url(URLS[dataset_name], root=root, hash_value=MD5[dataset_name], hash_type='md5')
-        extracted_files.extend(extract_archive(dataset_tar))
+        extracted_dataset_tar = extract_archive(dataset_tar)
+        if dataset_name == 'IWSLT':
+            # IWSLT dataset's url downloads a multilingual tgz.
+            # We need to take an extra step to pick out the specific language pair from it.
+            src_language = train_filenames[0].split(".")[-1]
+            tgt_language = train_filenames[1].split(".")[-1]
+            languages = "-".join([src_language, tgt_language])
+            iwslt_tar = '.data/2016-01/texts/{}/{}/{}.tgz'
+            iwslt_tar = iwslt_tar.format(
+                src_language, tgt_language, languages)
+            extracted_dataset_tar = extract_archive(iwslt_tar)
+        extracted_files.extend(extracted_dataset_tar)
     else:
         raise ValueError(
             "URLS for {} has to be in a form or list or string".format(
@@ -418,10 +430,6 @@ def IWSLT(train_filenames=('train.de-en.de', 'train.de-en.en'),
         >>> from torchtext.experimental.datasets.raw import IWSLT
         >>> train_dataset, valid_dataset, test_dataset = IWSLT()
     """
-    src_language = train_filenames[0].split(".")[-1]
-    tgt_language = train_filenames[1].split(".")[-1]
-    languages = "-".join([src_language, tgt_language])
-    URLS["IWSLT"] = URLS["IWSLT"].format(src_language, tgt_language, languages)
     return _setup_datasets("IWSLT", train_filenames, valid_filenames,
                            test_filenames, data_select, root)
 
@@ -567,6 +575,6 @@ def WMT14(train_filenames=('train.tok.clean.bpe.32000.de',
                'acb5ea26a577ceccfae6337181c31716',
                '873a377a348713d3ab84db1fb57cdede',
                '680816e0938fea5cf5331444bc09a4cf'],
-    'IWSLT': '6ff9ab8ea16fb352597c2784e0391fa8',
+    'IWSLT': 'c393ed3fc2a1b0f004b3331043f615ae',
     'WMT14': '874ab6bbfe9c21ec987ed1b9347f95ec'
 }
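Note for reviewers, not part of the patch: the new IWSLT branch in _setup_datasets no longer formats the download URL per language pair. It downloads the single multilingual archive and then derives the inner per-pair tarball path from the train filenames. A minimal standalone sketch of that derivation, assuming the default filenames from IWSLT() and the default .data root:

    # Sketch of the path derivation done by the IWSLT branch above.
    # 'train.de-en.de' / 'train.de-en.en' are IWSLT()'s default filenames.
    train_filenames = ('train.de-en.de', 'train.de-en.en')

    src_language = train_filenames[0].split(".")[-1]    # 'de'
    tgt_language = train_filenames[1].split(".")[-1]    # 'en'
    languages = "-".join([src_language, tgt_language])  # 'de-en'

    # The multilingual 2016-01.tgz expands under .data/2016-01/texts/, and
    # the language-pair tarball sits at texts/<src>/<tgt>/<src>-<tgt>.tgz.
    iwslt_tar = '.data/2016-01/texts/{}/{}/{}.tgz'.format(
        src_language, tgt_language, languages)
    print(iwslt_tar)  # .data/2016-01/texts/de/en/de-en.tgz

The old code applied this formatting to URLS["IWSLT"] itself, mutating the module-level dict on every call; moving it onto a local path after download avoids that and lets one downloaded archive serve any language pair.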