From c418136fc45c2c7debef5eeda1d0a74ba401d10c Mon Sep 17 00:00:00 2001
From: Gary Lai
Date: Tue, 29 Dec 2020 01:45:50 +0800
Subject: [PATCH 1/5] tweaks code to support new IWSLT url & adds IWSLT test

---
 test/data/test_builtin_datasets.py           | 11 ++++++++
 .../experimental/datasets/raw/translation.py | 26 ++++++++++++-------
 2 files changed, 28 insertions(+), 9 deletions(-)

diff --git a/test/data/test_builtin_datasets.py b/test/data/test_builtin_datasets.py
index 0e2d320f42..5a9f2e2a7d 100644
--- a/test/data/test_builtin_datasets.py
+++ b/test/data/test_builtin_datasets.py
@@ -162,6 +162,17 @@ def test_imdb(self):
         self._helper_test_func(len(test_iter), 25000, next(iter(test_iter))[1][:25], 'I love sci-fi and am will')
         del train_iter, test_iter
 
+    def test_iwslt(self):
+        from torchtext.experimental.datasets import IWSLT
+        from torchtext.data.utils import get_tokenizer
+
+        src_tokenizer = get_tokenizer("spacy", language='de')
+        tgt_tokenizer = get_tokenizer("basic_english")
+        train_dataset, valid_dataset, test_dataset = IWSLT(tokenizer=(src_tokenizer, tgt_tokenizer))
+        self.assertEqual(len(train_dataset),196884)
+        self.assertEqual(len(valid_dataset),993)
+        self.assertEqual(len(test_dataset),1305)
+
     def test_multi30k(self):
         from torchtext.experimental.datasets import Multi30k
         # smoke test to ensure multi30k works properly
diff --git a/torchtext/experimental/datasets/raw/translation.py b/torchtext/experimental/datasets/raw/translation.py
index 3b74c58421..d6f5458cb2 100644
--- a/torchtext/experimental/datasets/raw/translation.py
+++ b/torchtext/experimental/datasets/raw/translation.py
@@ -62,7 +62,7 @@
     'WMT14':
         'https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8',
     'IWSLT':
-        'https://wit3.fbk.eu/archive/2016-01//texts/{}/{}/{}.tgz'
+        'https://drive.google.com/uc?id=1l5y6Giag9aRPwGtuZHswh3w5v3qEz8D8'
 }
@@ -125,14 +125,26 @@ def _setup_datasets(dataset_name,
     src_eval, tgt_eval = valid_filenames
     src_test, tgt_test = test_filenames
 
-    extracted_files = []
+    extracted_files = []  # list of paths to the extracted files
     if isinstance(URLS[dataset_name], list):
         for idx, f in enumerate(URLS[dataset_name]):
-            dataset_tar = download_from_url(f, root=root, hash_value=MD5[dataset_name][idx], hash_type='md5')
+            dataset_tar = download_from_url(
+                f, root=root, hash_value=MD5[dataset_name][idx], hash_type='md5')
             extracted_files.extend(extract_archive(dataset_tar))
     elif isinstance(URLS[dataset_name], str):
-        dataset_tar = download_from_url(URLS[dataset_name], root=root, hash_value=MD5[dataset_name], hash_type='md5')
-        extracted_files.extend(extract_archive(dataset_tar))
+        dataset_tar = download_from_url(URLS[dataset_name])
+        extracted_dataset_tar = extract_archive(dataset_tar)
+        if dataset_name == 'IWSLT':
+            # IWSLT dataset's url downloads a multilingual tgz.
+            # We need to take an extra step to pick out the specific language pair from it.
+            src_language = train_filenames[0].split(".")[-1]
+            tgt_language = train_filenames[1].split(".")[-1]
+            languages = "-".join([src_language, tgt_language])
+            iwslt_tar = '.data/2016-01/texts/{}/{}/{}.tgz'
+            iwslt_tar = iwslt_tar.format(
+                src_language, tgt_language, languages)
+            extracted_dataset_tar = extract_archive(iwslt_tar)
+        extracted_files.extend(extracted_dataset_tar)
     else:
         raise ValueError(
             "URLS for {} has to be in a form or list or string".format(
@@ -418,10 +430,6 @@ def IWSLT(train_filenames=('train.de-en.de', 'train.de-en.en'),
         >>> from torchtext.experimental.datasets.raw import IWSLT
         >>> train_dataset, valid_dataset, test_dataset = IWSLT()
     """
-    src_language = train_filenames[0].split(".")[-1]
-    tgt_language = train_filenames[1].split(".")[-1]
-    languages = "-".join([src_language, tgt_language])
-    URLS["IWSLT"] = URLS["IWSLT"].format(src_language, tgt_language, languages)
     return _setup_datasets("IWSLT", train_filenames, valid_filenames,
                            test_filenames, data_select, root)

From 5867bdfad9a1aa2c6542b06613f5a590ad68171b Mon Sep 17 00:00:00 2001
From: Gary Lai
Date: Wed, 30 Dec 2020 00:59:24 +0800
Subject: [PATCH 2/5] adds iwslt ci test

---
 test/data/test_builtin_datasets.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/test/data/test_builtin_datasets.py b/test/data/test_builtin_datasets.py
index 5a9f2e2a7d..960d0e63dd 100644
--- a/test/data/test_builtin_datasets.py
+++ b/test/data/test_builtin_datasets.py
@@ -169,10 +169,33 @@ def test_iwslt(self):
         src_tokenizer = get_tokenizer("spacy", language='de')
         tgt_tokenizer = get_tokenizer("basic_english")
         train_dataset, valid_dataset, test_dataset = IWSLT(tokenizer=(src_tokenizer, tgt_tokenizer))
+
         self.assertEqual(len(train_dataset),196884)
         self.assertEqual(len(valid_dataset),993)
         self.assertEqual(len(test_dataset),1305)
 
+        de_vocab, en_vocab = train_dataset.get_vocab()
+
+        def assert_nth_pair_is_equal(n, expected_sentence_pair):
+            de_sentence = [de_vocab.itos[index] for index in train_dataset[n][0]]
+            en_sentence = [en_vocab.itos[index] for index in train_dataset[n][1]]
+            expected_de_sentence, expected_en_sentence = expected_sentence_pair
+
+            self.assertEqual(de_sentence, expected_de_sentence)
+            self.assertEqual(en_sentence,expected_en_sentence)
+
+        assert_nth_pair_is_equal(0, (['David', 'Gallo', ':', 'Das', 'ist', 'Bill', 'Lange', '.', 'Ich', 'bin', 'Dave', 'Gallo', '.', '\n'],
+                                     ['david', 'gallo', 'this', 'is', 'bill', 'lange', '.', 'i', "'", 'm', 'dave', 'gallo', '.']))
+        assert_nth_pair_is_equal(10, (['Die', 'meisten', 'Tiere', 'leben', 'in', 'den', 'Ozeanen', '.', '\n'],
+                                     ['most', 'of', 'the', 'animals', 'are', 'in', 'the', 'oceans', '.']))
+        assert_nth_pair_is_equal(20, (['Es', 'ist', 'einer', 'meiner', 'Lieblinge', ',', 'weil', 'es', 'alle', 'möglichen', 'Funktionsteile', 'hat', '.', '\n'],
+                                     ['it', "'", 's', 'one', 'of', 'my', 'favorites', ',', 'because', 'it', "'", 's', 'got', 'all', 'sorts', 'of', 'working', 'parts', '.']))
+
+        datafile = os.path.join(self.project_root, ".data", "2016-01.tgz")
+        conditional_remove(datafile)
+
+
+
     def test_multi30k(self):
         from torchtext.experimental.datasets import Multi30k
         # smoke test to ensure multi30k works properly

From 5ac9a0e17f80802b42cdf5dd864add79086af672 Mon Sep 17 00:00:00 2001
From: Gary Lai
Date: Wed, 30 Dec 2020 01:45:17 +0800
Subject: [PATCH 3/5] switches to the default tokenizer in test instead of spacy

---
 test/data/test_builtin_datasets.py | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/test/data/test_builtin_datasets.py b/test/data/test_builtin_datasets.py
index 960d0e63dd..a5dc70c724 100644
--- a/test/data/test_builtin_datasets.py
+++ b/test/data/test_builtin_datasets.py
@@ -164,11 +164,8 @@ def test_imdb(self):
 
     def test_iwslt(self):
         from torchtext.experimental.datasets import IWSLT
-        from torchtext.data.utils import get_tokenizer
 
-        src_tokenizer = get_tokenizer("spacy", language='de')
-        tgt_tokenizer = get_tokenizer("basic_english")
-        train_dataset, valid_dataset, test_dataset = IWSLT(tokenizer=(src_tokenizer, tgt_tokenizer))
+        train_dataset, valid_dataset, test_dataset = IWSLT()
 
         self.assertEqual(len(train_dataset),196884)
         self.assertEqual(len(valid_dataset),993)
@@ -185,12 +182,11 @@ def assert_nth_pair_is_equal(n, expected_sentence_pair):
             self.assertEqual(en_sentence,expected_en_sentence)
 
         assert_nth_pair_is_equal(0, (['David', 'Gallo', ':', 'Das', 'ist', 'Bill', 'Lange', '.', 'Ich', 'bin', 'Dave', 'Gallo', '.', '\n'],
-                                     ['david', 'gallo', 'this', 'is', 'bill', 'lange', '.', 'i', "'", 'm', 'dave', 'gallo', '.']))
+                                     ['David', 'Gallo', ':', 'This', 'is', 'Bill', 'Lange', '.', 'I', "'m", 'Dave', 'Gallo', '.', '\n']))
         assert_nth_pair_is_equal(10, (['Die', 'meisten', 'Tiere', 'leben', 'in', 'den', 'Ozeanen', '.', '\n'],
-                                     ['most', 'of', 'the', 'animals', 'are', 'in', 'the', 'oceans', '.']))
+                                     ['Most', 'of', 'the', 'animals', 'are', 'in', 'the', 'oceans', '.', '\n']))
         assert_nth_pair_is_equal(20, (['Es', 'ist', 'einer', 'meiner', 'Lieblinge', ',', 'weil', 'es', 'alle', 'möglichen', 'Funktionsteile', 'hat', '.', '\n'],
-                                     ['it', "'", 's', 'one', 'of', 'my', 'favorites', ',', 'because', 'it', "'", 's', 'got', 'all', 'sorts', 'of', 'working', 'parts', '.']))
-
+                                     ['It', "'s", 'one', 'of', 'my', 'favorites', ',', 'because', 'it', "'s", 'got', 'all', 'sorts', 'of', 'working', 'parts', '.', '\n']))
         datafile = os.path.join(self.project_root, ".data", "2016-01.tgz")
         conditional_remove(datafile)

From 9e7121eeb17ff38f176cc425a0f1de1b42d92696 Mon Sep 17 00:00:00 2001
From: Gary Lai
Date: Wed, 30 Dec 2020 02:16:07 +0800
Subject: [PATCH 4/5] fixes style errors from flake8

---
 test/data/test_builtin_datasets.py | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/test/data/test_builtin_datasets.py b/test/data/test_builtin_datasets.py
index a5dc70c724..2c8190c44a 100644
--- a/test/data/test_builtin_datasets.py
+++ b/test/data/test_builtin_datasets.py
@@ -167,9 +167,9 @@ def test_iwslt(self):
 
         train_dataset, valid_dataset, test_dataset = IWSLT()
 
-        self.assertEqual(len(train_dataset),196884)
-        self.assertEqual(len(valid_dataset),993)
-        self.assertEqual(len(test_dataset),1305)
+        self.assertEqual(len(train_dataset), 196884)
+        self.assertEqual(len(valid_dataset), 993)
+        self.assertEqual(len(test_dataset), 1305)
 
         de_vocab, en_vocab = train_dataset.get_vocab()
 
@@ -179,19 +179,17 @@ def assert_nth_pair_is_equal(n, expected_sentence_pair):
             expected_de_sentence, expected_en_sentence = expected_sentence_pair
 
             self.assertEqual(de_sentence, expected_de_sentence)
-            self.assertEqual(en_sentence,expected_en_sentence)
+            self.assertEqual(en_sentence, expected_en_sentence)
 
         assert_nth_pair_is_equal(0, (['David', 'Gallo', ':', 'Das', 'ist', 'Bill', 'Lange', '.', 'Ich', 'bin', 'Dave', 'Gallo', '.', '\n'],
-                                     ['David', 'Gallo', ':', 'This', 'is', 'Bill', 'Lange', '.', 'I', "'m", 'Dave', 'Gallo', '.', '\n']))
+                                      ['David', 'Gallo', ':', 'This', 'is', 'Bill', 'Lange', '.', 'I', "'m", 'Dave', 'Gallo', '.', '\n']))
         assert_nth_pair_is_equal(10, (['Die', 'meisten', 'Tiere', 'leben', 'in', 'den', 'Ozeanen', '.', '\n'],
-                                     ['Most', 'of', 'the', 'animals', 'are', 'in', 'the', 'oceans', '.', '\n']))
+                                      ['Most', 'of', 'the', 'animals', 'are', 'in', 'the', 'oceans', '.', '\n']))
         assert_nth_pair_is_equal(20, (['Es', 'ist', 'einer', 'meiner', 'Lieblinge', ',', 'weil', 'es', 'alle', 'möglichen', 'Funktionsteile', 'hat', '.', '\n'],
-                                     ['It', "'s", 'one', 'of', 'my', 'favorites', ',', 'because', 'it', "'s", 'got', 'all', 'sorts', 'of', 'working', 'parts', '.', '\n']))
+                                      ['It', "'s", 'one', 'of', 'my', 'favorites', ',', 'because', 'it', "'s", 'got', 'all', 'sorts', 'of', 'working', 'parts', '.', '\n']))
         datafile = os.path.join(self.project_root, ".data", "2016-01.tgz")
         conditional_remove(datafile)
-
-
 
     def test_multi30k(self):
         from torchtext.experimental.datasets import Multi30k
         # smoke test to ensure multi30k works properly

From 75ac292a43ed1a6a980834d03f359a15d69666a7 Mon Sep 17 00:00:00 2001
From: Gary Lai
Date: Wed, 30 Dec 2020 02:47:38 +0800
Subject: [PATCH 5/5] updates md5 of iwslt and fixes download_from_url hash check

---
 torchtext/experimental/datasets/raw/translation.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchtext/experimental/datasets/raw/translation.py b/torchtext/experimental/datasets/raw/translation.py
index d6f5458cb2..25960bead8 100644
--- a/torchtext/experimental/datasets/raw/translation.py
+++ b/torchtext/experimental/datasets/raw/translation.py
@@ -132,7 +132,7 @@ def _setup_datasets(dataset_name,
                 f, root=root, hash_value=MD5[dataset_name][idx], hash_type='md5')
             extracted_files.extend(extract_archive(dataset_tar))
     elif isinstance(URLS[dataset_name], str):
-        dataset_tar = download_from_url(URLS[dataset_name])
+        dataset_tar = download_from_url(URLS[dataset_name], root=root, hash_value=MD5[dataset_name], hash_type='md5')
         extracted_dataset_tar = extract_archive(dataset_tar)
         if dataset_name == 'IWSLT':
             # IWSLT dataset's url downloads a multilingual tgz.
@@ -575,6 +575,6 @@ def WMT14(train_filenames=('train.tok.clean.bpe.32000.de',
         'acb5ea26a577ceccfae6337181c31716',
         '873a377a348713d3ab84db1fb57cdede',
         '680816e0938fea5cf5331444bc09a4cf'],
-    'IWSLT': '6ff9ab8ea16fb352597c2784e0391fa8',
+    'IWSLT': 'c393ed3fc2a1b0f004b3331043f615ae',
    'WMT14': '874ab6bbfe9c21ec987ed1b9347f95ec'
 }
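
A note on the two-stage extraction introduced in PATCH 1/5 and hardened in PATCH 5/5: the new Google Drive URL serves a single multilingual 2016-01.tgz, so the requested language pair has to be unpacked from it in a second pass. The sketch below is illustration only, not library code; it assumes the default root='.data' and the default ('train.de-en.de', 'train.de-en.en') filenames, and it reuses the patch's hard-coded inner path, which does not follow a non-default root.

# Minimal sketch of the patched IWSLT download/extract flow
# (assumptions noted above).
from torchtext.utils import download_from_url, extract_archive

url = 'https://drive.google.com/uc?id=1l5y6Giag9aRPwGtuZHswh3w5v3qEz8D8'
md5 = 'c393ed3fc2a1b0f004b3331043f615ae'

# Stage 1: fetch and unpack the outer multilingual archive (2016-01.tgz),
# verifying its MD5 as PATCH 5/5 does.
dataset_tar = download_from_url(url, root='.data', hash_value=md5, hash_type='md5')
extract_archive(dataset_tar)

# Stage 2: derive the language pair from the train filenames and unpack the
# inner per-pair archive; these paths feed the rest of _setup_datasets.
src_language = 'train.de-en.de'.split('.')[-1]      # 'de'
tgt_language = 'train.de-en.en'.split('.')[-1]      # 'en'
languages = '-'.join([src_language, tgt_language])  # 'de-en'
iwslt_tar = '.data/2016-01/texts/{}/{}/{}.tgz'.format(
    src_language, tgt_language, languages)
extracted_files = extract_archive(iwslt_tar)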
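
For completeness, a short usage sketch of the dataset-level API that the CI test exercises once PATCH 3/5 drops the explicit spacy tokenizer. The lengths and tokens in the comments are the values pinned by test_iwslt, reproduced here for illustration.

from torchtext.experimental.datasets import IWSLT

# Build the train/valid/test splits with the default tokenizer.
train_dataset, valid_dataset, test_dataset = IWSLT()
print(len(train_dataset), len(valid_dataset), len(test_dataset))
# 196884 993 1305

# Each sample is a (source, target) pair of index tensors; map the indices
# back to tokens via the per-language vocabs, as assert_nth_pair_is_equal does.
de_vocab, en_vocab = train_dataset.get_vocab()
de_indices, en_indices = train_dataset[0]
print([de_vocab.itos[index] for index in de_indices])
# ['David', 'Gallo', ':', 'Das', 'ist', 'Bill', 'Lange', '.', ...]
print([en_vocab.itos[index] for index in en_indices])
# ['David', 'Gallo', ':', 'This', 'is', 'Bill', 'Lange', '.', ...]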