From 284f131967a8e2408d5a8068665a6ebab835cd53 Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Fri, 28 Jan 2022 09:30:24 -0500 Subject: [PATCH 1/8] migrate IWSLT2017 to datapipes. --- torchtext/datasets/iwslt2017.py | 134 ++++++++++++++++---------------- 1 file changed, 66 insertions(+), 68 deletions(-) diff --git a/torchtext/datasets/iwslt2017.py b/torchtext/datasets/iwslt2017.py index 051f1377d5..30b12253bc 100644 --- a/torchtext/datasets/iwslt2017.py +++ b/torchtext/datasets/iwslt2017.py @@ -1,19 +1,21 @@ +from torchtext._internal.module_utils import is_module_available + +if is_module_available("torchdata"): + from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper, FileLister + import os -from torchtext.utils import (download_from_url, extract_archive) from torchtext.data.datasets_utils import ( - _RawTextIterableDataset, _wrap_split_argument, _clean_xml_file, _clean_tags_file, - _read_text_iterator, ) from torchtext.data.datasets_utils import _create_dataset_directory -SUPPORTED_DATASETS = { +URL = 'https://drive.google.com/u/0/uc?id=12ycYSzLIG253AFN35Y6qoyf9wtkOjakp' +_PATH = '2017-01-trnmted.tgz' +MD5 = 'aca701032b1c4411afc4d9fa367796ba' - 'URL': 'https://drive.google.com/u/0/uc?id=12ycYSzLIG253AFN35Y6qoyf9wtkOjakp', - '_PATH': '2017-01-trnmted.tgz', - 'MD5': 'aca701032b1c4411afc4d9fa367796ba', +SUPPORTED_DATASETS = { 'valid_test': ['dev2010', 'tst2010'], 'language_pair': { 'en': ['nl', 'de', 'it', 'ro'], @@ -25,9 +27,6 @@ 'year': 17, } -URL = SUPPORTED_DATASETS['URL'] -MD5 = SUPPORTED_DATASETS['MD5'] - NUM_LINES = { 'train': { 'train': { @@ -97,28 +96,11 @@ } } - -def _construct_filenames(filename, languages): - filenames = [] - for lang in languages: - filenames.append(filename + "." + lang) - return filenames - - -def _construct_filepaths(paths, src_filename, tgt_filename): - src_path = None - tgt_path = None - for p in paths: - src_path = p if src_filename in p else src_path - tgt_path = p if tgt_filename in p else tgt_path - return (src_path, tgt_path) - - DATASET_NAME = "IWSLT2017" @_create_dataset_directory(dataset_name=DATASET_NAME) -@_wrap_split_argument(('train', 'valid', 'test')) +@_wrap_split_argument(("train", "valid", "test")) def IWSLT2017(root='.data', split=('train', 'valid', 'test'), language_pair=('de', 'en')): """IWSLT2017 dataset @@ -154,16 +136,12 @@ def IWSLT2017(root='.data', split=('train', 'valid', 'test'), language_pair=('de >>> src_sentence, tgt_sentence = next(train_iter) """ + if not is_module_available("torchdata"): + raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`") valid_set = 'dev2010' test_set = 'tst2010' - num_lines_set_identifier = { - 'train': 'train', - 'valid': valid_set, - 'test': test_set - } - if not isinstance(language_pair, list) and not isinstance(language_pair, tuple): raise ValueError("language_pair must be list or tuple but got {} instead".format(type(language_pair))) @@ -190,45 +168,65 @@ def IWSLT2017(root='.data', split=('train', 'valid', 'test'), language_pair=('de src_eval, tgt_eval = valid_filenames src_test, tgt_test = test_filenames - extracted_files = [] # list of paths to the extracted files - dataset_tar = download_from_url(SUPPORTED_DATASETS['URL'], root=root, hash_value=SUPPORTED_DATASETS['MD5'], path=os.path.join(root, SUPPORTED_DATASETS['_PATH']), hash_type='md5') - extracted_dataset_tar = extract_archive(dataset_tar) - # IWSLT dataset's url downloads a multilingual tgz. 
- # We need to take an extra step to pick out the specific language pair from it. + url_dp = IterableWrapper([URL]) + cache_compressed_dp = url_dp.on_disk_cache( + filepath_fn=lambda x: os.path.join(root, _PATH), + hash_dict={os.path.join(root, _PATH): MD5}, + hash_type="md5" + ) + cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) + cache_compressed_dp = FileOpener(cache_compressed_dp, mode="b") src_language = train_filenames[0].split(".")[-1] tgt_language = train_filenames[1].split(".")[-1] - iwslt_tar = os.path.join(root, SUPPORTED_DATASETS['_PATH'].split(".")[0], 'texts/DeEnItNlRo/DeEnItNlRo', 'DeEnItNlRo-DeEnItNlRo.tgz') - extracted_dataset_tar = extract_archive(iwslt_tar) - extracted_files.extend(extracted_dataset_tar) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache( + # Convert /path/to/downloaded/foo.tgz to /path/to/downloaded/foo/rest/of/path + filepath_fn=lambda x: os.path.join(root, os.path.splitext(_PATH)[0], "texts/DeEnItNlRo/DeEnItNlRo/DeEnItNlRo-DeEnItNlRo.tgz") + ) + cache_decompressed_dp = cache_decompressed_dp.read_from_tar() + cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) - # Clean the xml and tag file in the archives - file_archives = [] - for fname in extracted_files: + def clean_files(fname): if 'xml' in fname: _clean_xml_file(fname) - file_archives.append(os.path.splitext(fname)[0]) + return os.path.splitext(fname)[0] elif "tags" in fname: _clean_tags_file(fname) - file_archives.append(fname.replace('.tags', '')) - else: - file_archives.append(fname) - - data_filenames = { - "train": _construct_filepaths(file_archives, src_train, tgt_train), - "valid": _construct_filepaths(file_archives, src_eval, tgt_eval), - "test": _construct_filepaths(file_archives, src_test, tgt_test) - } - for key in data_filenames: - if len(data_filenames[key]) == 0 or data_filenames[key] is None: - raise FileNotFoundError( - "Files are not found for data type {}".format(key)) - - src_data_iter = _read_text_iterator(data_filenames[split][0]) - tgt_data_iter = _read_text_iterator(data_filenames[split][1]) - - def _iter(src_data_iter, tgt_data_iter): - for item in zip(src_data_iter, tgt_data_iter): - yield item - - return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split][num_lines_set_identifier[split]][tuple(sorted(language_pair))], _iter(src_data_iter, tgt_data_iter)) + return fname.replace('.tags', '') + return fname + + cache_decompressed_dp = cache_decompressed_dp.on_disk_cache( + # Convert /path/to/downloaded/foo.tgz to /path/to/downloaded/foo/rest/of/path + filepath_fn=lambda x: os.path.join(root, os.path.splitext(_PATH)[0], "texts/DeEnItNlRo/DeEnItNlRo/DeEnItNlRo-DeEnItNlRo/") + ) + + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar() + cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=False) + + def get_filepath(split, lang): + return { + src_language: { + "train": src_train, + "valid": src_eval, + "test": src_test, + }, + tgt_language: { + "train": tgt_train, + "valid": tgt_eval, + "test": tgt_test, + } + }[lang][split] + + cache_decompressed_dp = cache_decompressed_dp.flatmap(FileLister) + cleaned_cache_decompressed_dp = cache_decompressed_dp.map(clean_files) + + # Filters out irrelevant file given the filename templates filled with split and src/tgt codes + src_data_dp = cleaned_cache_decompressed_dp.filter(lambda x: get_filepath(split, src_language) in x) + tgt_data_dp = 
cleaned_cache_decompressed_dp.filter(lambda x: get_filepath(split, tgt_language) in x) + + tgt_data_dp = FileOpener(tgt_data_dp, mode="r") + src_data_dp = FileOpener(src_data_dp, mode="r") + + src_lines = src_data_dp.readlines(return_path=False, strip_newline=False) + tgt_lines = tgt_data_dp.readlines(return_path=False, strip_newline=False) + return src_lines.zip(tgt_lines) From eb1869fe463a4d7d107ba37e721eedc777e566b0 Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Sun, 30 Jan 2022 08:05:41 -0500 Subject: [PATCH 2/8] refactor IWSLT2017 to use feedback from IWSLT2016. --- torchtext/datasets/iwslt2017.py | 146 +++++++++++++++++++++----------- 1 file changed, 95 insertions(+), 51 deletions(-) diff --git a/torchtext/datasets/iwslt2017.py b/torchtext/datasets/iwslt2017.py index 30b12253bc..32fbaea528 100644 --- a/torchtext/datasets/iwslt2017.py +++ b/torchtext/datasets/iwslt2017.py @@ -5,11 +5,10 @@ import os from torchtext.data.datasets_utils import ( + _clean_files, + _create_dataset_directory, _wrap_split_argument, - _clean_xml_file, - _clean_tags_file, ) -from torchtext.data.datasets_utils import _create_dataset_directory URL = 'https://drive.google.com/u/0/uc?id=12ycYSzLIG253AFN35Y6qoyf9wtkOjakp' _PATH = '2017-01-trnmted.tgz' @@ -99,6 +98,17 @@ DATASET_NAME = "IWSLT2017" +def _filter_clean_cache(cache_decompressed_dp, full_filepath, uncleaned_filename): + cache_inner_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=lambda x: full_filepath) + cache_inner_decompressed_dp = FileOpener(cache_inner_decompressed_dp, mode="b").read_from_tar() + cache_inner_decompressed_dp = cache_inner_decompressed_dp.filter( + lambda x: os.path.basename(uncleaned_filename) in x[0]) + cache_inner_decompressed_dp = cache_inner_decompressed_dp.map( + lambda x: _clean_files(full_filepath, x[0], x[1])) + cache_inner_decompressed_dp = cache_inner_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) + return cache_inner_decompressed_dp + + @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "valid", "test")) def IWSLT2017(root='.data', split=('train', 'valid', 'test'), language_pair=('de', 'en')): @@ -168,65 +178,99 @@ def IWSLT2017(root='.data', split=('train', 'valid', 'test'), language_pair=('de src_eval, tgt_eval = valid_filenames src_test, tgt_test = test_filenames + uncleaned_train_filenames = ('train.tags.{}-{}.{}'.format(src_language, tgt_language, src_language), + 'train.tags.{}-{}.{}'.format(src_language, tgt_language, tgt_language)) + uncleaed_valid_filenames = ( + 'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, + src_language), + 'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, + tgt_language)) + uncleaned_test_filenames = ( + 'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, + src_language), + 'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, + tgt_language)) + + uncleaned_src_train, uncleaned_tgt_train = uncleaned_train_filenames + uncleaned_src_eval, uncleaned_tgt_eval = uncleaed_valid_filenames + uncleaned_src_test, uncleaned_tgt_test = uncleaned_test_filenames + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( filepath_fn=lambda x: os.path.join(root, _PATH), hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5" ) - cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", 
same_filepath_fn=True) - cache_compressed_dp = FileOpener(cache_compressed_dp, mode="b") - src_language = train_filenames[0].split(".")[-1] - tgt_language = train_filenames[1].split(".")[-1] - - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - # Convert /path/to/downloaded/foo.tgz to /path/to/downloaded/foo/rest/of/path - filepath_fn=lambda x: os.path.join(root, os.path.splitext(_PATH)[0], "texts/DeEnItNlRo/DeEnItNlRo/DeEnItNlRo-DeEnItNlRo.tgz") + cache_compressed_dp = GDriveReader(cache_compressed_dp) + cache_compressed_dp = cache_compressed_dp.end_caching(mode="wb", same_filepath_fn=True) + + languages = "-".join([src_language, tgt_language]) + + # We create the whole filepath here, but only check for the literal filename in the filter + # because we're lazily extracting from the outer tarfile. Thus, + # /root/2017-01-trnmted/texts/.../src-tgt.tgz will never be in + # /root/2017-01-trnmted.tgz/texts/.../src-tgt.tgz + inner_iwslt_tar = os.path.join( + root, + os.path.splitext(_PATH)[0], + "texts/DeEnItNlRo/DeEnItNlRo/DeEnItNlRo-DeEnItNlRo.tgz" ) - cache_decompressed_dp = cache_decompressed_dp.read_from_tar() + + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=lambda x: inner_iwslt_tar) + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter( + lambda x: os.path.basename(inner_iwslt_tar) in x[0]) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) - def clean_files(fname): - if 'xml' in fname: - _clean_xml_file(fname) - return os.path.splitext(fname)[0] - elif "tags" in fname: - _clean_tags_file(fname) - return fname.replace('.tags', '') - return fname - - cache_decompressed_dp = cache_decompressed_dp.on_disk_cache( - # Convert /path/to/downloaded/foo.tgz to /path/to/downloaded/foo/rest/of/path - filepath_fn=lambda x: os.path.join(root, os.path.splitext(_PATH)[0], "texts/DeEnItNlRo/DeEnItNlRo/DeEnItNlRo-DeEnItNlRo/") - ) + file_path_by_lang_and_split = { + src_language: { + "train": src_train, + "valid": src_eval, + "test": src_test, + }, + tgt_language: { + "train": tgt_train, + "valid": tgt_eval, + "test": tgt_test, + } + } - cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar() - cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=False) - - def get_filepath(split, lang): - return { - src_language: { - "train": src_train, - "valid": src_eval, - "test": src_test, - }, - tgt_language: { - "train": tgt_train, - "valid": tgt_eval, - "test": tgt_test, - } - }[lang][split] - - cache_decompressed_dp = cache_decompressed_dp.flatmap(FileLister) - cleaned_cache_decompressed_dp = cache_decompressed_dp.map(clean_files) - - # Filters out irrelevant file given the filename templates filled with split and src/tgt codes - src_data_dp = cleaned_cache_decompressed_dp.filter(lambda x: get_filepath(split, src_language) in x) - tgt_data_dp = cleaned_cache_decompressed_dp.filter(lambda x: get_filepath(split, tgt_language) in x) - - tgt_data_dp = FileOpener(tgt_data_dp, mode="r") - src_data_dp = FileOpener(src_data_dp, mode="r") + uncleaned_filenames = { + src_language: { + "train": uncleaned_src_train, + "valid": uncleaned_src_eval, + "test": uncleaned_src_test, + }, + tgt_language: { + "train": uncleaned_tgt_train, + "valid": uncleaned_tgt_eval, + "test": uncleaned_tgt_test, + } + } + + src_filename = file_path_by_lang_and_split[src_language][split] + uncleaned_src_filename = uncleaned_filenames[src_language][split] + + # We create the 
whole filepath here, but only check for the literal filename in the filter + # because we're lazily extracting from the outer tarfile. + full_src_filepath = os.path.join(root, "texts/DeEnItNlRo/DeEnItNlRo/DeEnItNlRo-DeEnItNlRo", src_filename) + + cache_inner_src_decompressed_dp = _filter_clean_cache(cache_decompressed_dp, full_src_filepath, + uncleaned_src_filename) + + tgt_filename = file_path_by_lang_and_split[tgt_language][split] + uncleaned_tgt_filename = uncleaned_filenames[tgt_language][split] + + # We create the whole filepath here, but only check for the literal filename in the filter + # because we're lazily extracting from the outer tarfile. + full_tgt_filepath = os.path.join(root, "texts/DeEnItNlRo/DeEnItNlRo/DeEnItNlRo-DeEnItNlRo", tgt_filename) + + cache_inner_tgt_decompressed_dp = _filter_clean_cache(cache_decompressed_dp, full_tgt_filepath, + uncleaned_tgt_filename) + + tgt_data_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="r") + src_data_dp = FileOpener(cache_inner_src_decompressed_dp, mode="r") src_lines = src_data_dp.readlines(return_path=False, strip_newline=False) tgt_lines = tgt_data_dp.readlines(return_path=False, strip_newline=False) + return src_lines.zip(tgt_lines) From d081f37937a818eec7114dad7d466e5a76f9bca3 Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Sun, 30 Jan 2022 08:06:50 -0500 Subject: [PATCH 3/8] remove unused import. --- torchtext/datasets/iwslt2017.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtext/datasets/iwslt2017.py b/torchtext/datasets/iwslt2017.py index 32fbaea528..a0767246fc 100644 --- a/torchtext/datasets/iwslt2017.py +++ b/torchtext/datasets/iwslt2017.py @@ -1,7 +1,7 @@ from torchtext._internal.module_utils import is_module_available if is_module_available("torchdata"): - from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper, FileLister + from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper import os from torchtext.data.datasets_utils import ( From f211fb094de6c8453938b5f3767033b0df18bf2c Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Sun, 30 Jan 2022 09:52:09 -0500 Subject: [PATCH 4/8] fix flake. 
--- torchtext/datasets/iwslt2017.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/torchtext/datasets/iwslt2017.py b/torchtext/datasets/iwslt2017.py index a0767246fc..891ef1962f 100644 --- a/torchtext/datasets/iwslt2017.py +++ b/torchtext/datasets/iwslt2017.py @@ -178,18 +178,18 @@ def IWSLT2017(root='.data', split=('train', 'valid', 'test'), language_pair=('de src_eval, tgt_eval = valid_filenames src_test, tgt_test = test_filenames - uncleaned_train_filenames = ('train.tags.{}-{}.{}'.format(src_language, tgt_language, src_language), - 'train.tags.{}-{}.{}'.format(src_language, tgt_language, tgt_language)) + uncleaned_train_filenames = ( + 'train.tags.{}-{}.{}'.format(src_language, tgt_language, src_language), + 'train.tags.{}-{}.{}'.format(src_language, tgt_language, tgt_language) + ) uncleaed_valid_filenames = ( - 'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, - src_language), - 'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, - tgt_language)) + 'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, src_language), + 'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, tgt_language) + ) uncleaned_test_filenames = ( - 'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, - src_language), - 'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, - tgt_language)) + 'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, src_language), + 'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, tgt_language) + ) uncleaned_src_train, uncleaned_tgt_train = uncleaned_train_filenames uncleaned_src_eval, uncleaned_tgt_eval = uncleaed_valid_filenames @@ -204,8 +204,6 @@ def IWSLT2017(root='.data', split=('train', 'valid', 'test'), language_pair=('de cache_compressed_dp = GDriveReader(cache_compressed_dp) cache_compressed_dp = cache_compressed_dp.end_caching(mode="wb", same_filepath_fn=True) - languages = "-".join([src_language, tgt_language]) - # We create the whole filepath here, but only check for the literal filename in the filter # because we're lazily extracting from the outer tarfile. Thus, # /root/2017-01-trnmted/texts/.../src-tgt.tgz will never be in From 468e91bbddd11d0b604be0938b257a5c8aca1483 Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Sun, 30 Jan 2022 09:56:13 -0500 Subject: [PATCH 5/8] fix typo in comment. --- torchtext/datasets/iwslt2017.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtext/datasets/iwslt2017.py b/torchtext/datasets/iwslt2017.py index 891ef1962f..29e0708700 100644 --- a/torchtext/datasets/iwslt2017.py +++ b/torchtext/datasets/iwslt2017.py @@ -206,8 +206,8 @@ def IWSLT2017(root='.data', split=('train', 'valid', 'test'), language_pair=('de # We create the whole filepath here, but only check for the literal filename in the filter # because we're lazily extracting from the outer tarfile. 
Thus, - # /root/2017-01-trnmted/texts/.../src-tgt.tgz will never be in - # /root/2017-01-trnmted.tgz/texts/.../src-tgt.tgz + # /root/2017-01-trnmted/texts/.../DeEnItNlRo-DeEnItNlRo.tgz will never be in + # /root/2017-01-trnmted.tgz/texts/.../DeEnItNlRo-DeEnItNlRo.tgz inner_iwslt_tar = os.path.join( root, os.path.splitext(_PATH)[0], From 54a663a8d959676c70626678183629c398551019 Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Mon, 31 Jan 2022 07:40:23 -0500 Subject: [PATCH 6/8] add TODOs to IWSLT datasets. --- torchtext/datasets/iwslt2016.py | 2 ++ torchtext/datasets/iwslt2017.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/torchtext/datasets/iwslt2016.py b/torchtext/datasets/iwslt2016.py index 4ddaba49ef..6d7a0f9102 100644 --- a/torchtext/datasets/iwslt2016.py +++ b/torchtext/datasets/iwslt2016.py @@ -127,6 +127,8 @@ DATASET_NAME = "IWSLT2016" +# TODO: migrate this to dataset_utils.py once torchdata is a hard dependency to +# avoid additional conditional imports. def _filter_clean_cache(cache_decompressed_dp, full_filepath, uncleaned_filename): cache_inner_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=lambda x: full_filepath) cache_inner_decompressed_dp = FileOpener(cache_inner_decompressed_dp, mode="b").read_from_tar() diff --git a/torchtext/datasets/iwslt2017.py b/torchtext/datasets/iwslt2017.py index 29e0708700..67b3345685 100644 --- a/torchtext/datasets/iwslt2017.py +++ b/torchtext/datasets/iwslt2017.py @@ -98,6 +98,8 @@ DATASET_NAME = "IWSLT2017" +# TODO: migrate this to dataset_utils.py once torchdata is a hard dependency to +# avoid additional conditional imports. def _filter_clean_cache(cache_decompressed_dp, full_filepath, uncleaned_filename): cache_inner_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=lambda x: full_filepath) cache_inner_decompressed_dp = FileOpener(cache_inner_decompressed_dp, mode="b").read_from_tar() From f9ea533ae5cf7b44868968e0b363ccc0834fbc27 Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Mon, 31 Jan 2022 13:21:50 -0500 Subject: [PATCH 7/8] refactor common code out of IWSLTs and convert single quotes to double. 
--- torchtext/data/datasets_utils.py | 64 ++++++++ torchtext/datasets/iwslt2016.py | 269 ++++++++++++------------------- torchtext/datasets/iwslt2017.py | 227 ++++++++++---------------- 3 files changed, 260 insertions(+), 300 deletions(-) diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py index fa363150ac..844e9c9245 100644 --- a/torchtext/data/datasets_utils.py +++ b/torchtext/data/datasets_utils.py @@ -392,6 +392,70 @@ def __str__(self): return self.description +def _generate_iwslt_files_for_lang_and_split(year, src_language, tgt_language, valid_set, test_set): + train_filenames = ( + "train.{}-{}.{}".format(src_language, tgt_language, src_language), + "train.{}-{}.{}".format(src_language, tgt_language, tgt_language) + ) + valid_filenames = ( + "IWSLT{}.TED.{}.{}-{}.{}".format(year, valid_set, src_language, tgt_language, src_language), + "IWSLT{}.TED.{}.{}-{}.{}".format(year, valid_set, src_language, tgt_language, tgt_language) + ) + test_filenames = ( + "IWSLT{}.TED.{}.{}-{}.{}".format(year, test_set, src_language, tgt_language, src_language), + "IWSLT{}.TED.{}.{}-{}.{}".format(year, test_set, src_language, tgt_language, tgt_language) + ) + + src_train, tgt_train = train_filenames + src_eval, tgt_eval = valid_filenames + src_test, tgt_test = test_filenames + + uncleaned_train_filenames = ( + "train.tags.{}-{}.{}".format(src_language, tgt_language, src_language), + "train.tags.{}-{}.{}".format(src_language, tgt_language, tgt_language) + ) + uncleaed_valid_filenames = ( + "IWSLT{}.TED.{}.{}-{}.{}.xml".format(year, valid_set, src_language, tgt_language, src_language), + "IWSLT{}.TED.{}.{}-{}.{}.xml".format(year, valid_set, src_language, tgt_language, tgt_language) + ) + uncleaned_test_filenames = ( + "IWSLT{}.TED.{}.{}-{}.{}.xml".format(year, test_set, src_language, tgt_language, src_language), + "IWSLT{}.TED.{}.{}-{}.{}.xml".format(year, test_set, src_language, tgt_language, tgt_language) + ) + + uncleaned_src_train, uncleaned_tgt_train = uncleaned_train_filenames + uncleaned_src_eval, uncleaned_tgt_eval = uncleaed_valid_filenames + uncleaned_src_test, uncleaned_tgt_test = uncleaned_test_filenames + + file_path_by_lang_and_split = { + src_language: { + "train": src_train, + "valid": src_eval, + "test": src_test, + }, + tgt_language: { + "train": tgt_train, + "valid": tgt_eval, + "test": tgt_test, + } + } + + uncleaned_filenames_by_lang_and_split = { + src_language: { + "train": uncleaned_src_train, + "valid": uncleaned_src_eval, + "test": uncleaned_src_test, + }, + tgt_language: { + "train": uncleaned_tgt_train, + "valid": uncleaned_tgt_eval, + "test": uncleaned_tgt_test, + } + } + + return file_path_by_lang_and_split, uncleaned_filenames_by_lang_and_split + + @functional_datapipe("read_squad") class _ParseSQuADQAData(IterDataPipe): r"""Iterable DataPipe to parse the contents of a stream of JSON objects diff --git a/torchtext/datasets/iwslt2016.py b/torchtext/datasets/iwslt2016.py index 6d7a0f9102..3073a9c8c4 100644 --- a/torchtext/datasets/iwslt2016.py +++ b/torchtext/datasets/iwslt2016.py @@ -8,120 +8,121 @@ _wrap_split_argument, _clean_files, _create_dataset_directory, + _generate_iwslt_files_for_lang_and_split, ) -URL = 'https://drive.google.com/uc?id=1l5y6Giag9aRPwGtuZHswh3w5v3qEz8D8' +URL = "https://drive.google.com/uc?id=1l5y6Giag9aRPwGtuZHswh3w5v3qEz8D8" -_PATH = '2016-01.tgz' +_PATH = "2016-01.tgz" -MD5 = 'c393ed3fc2a1b0f004b3331043f615ae' +MD5 = "c393ed3fc2a1b0f004b3331043f615ae" SUPPORTED_DATASETS = { - 'valid_test': ['dev2010', 'tst2010', 
'tst2011', 'tst2012', 'tst2013', 'tst2014'], - 'language_pair': { - 'en': ['ar', 'de', 'fr', 'cs'], - 'ar': ['en'], - 'fr': ['en'], - 'de': ['en'], - 'cs': ['en'], + "valid_test": ["dev2010", "tst2010", "tst2011", "tst2012", "tst2013", "tst2014"], + "language_pair": { + "en": ["ar", "de", "fr", "cs"], + "ar": ["en"], + "fr": ["en"], + "de": ["en"], + "cs": ["en"], }, - 'year': 16, + "year": 16, } NUM_LINES = { - 'train': { - 'train': { - ('ar', 'en'): 224126, - ('de', 'en'): 196884, - ('en', 'fr'): 220400, - ('cs', 'en'): 114390 + "train": { + "train": { + ("ar", "en"): 224126, + ("de", "en"): 196884, + ("en", "fr"): 220400, + ("cs", "en"): 114390 } }, - 'valid': { - 'dev2010': { - ('ar', 'en'): 887, - ('de', 'en'): 887, - ('en', 'fr'): 887, - ('cs', 'en'): 480 + "valid": { + "dev2010": { + ("ar", "en"): 887, + ("de", "en"): 887, + ("en", "fr"): 887, + ("cs", "en"): 480 }, - 'tst2010': { - ('ar', 'en'): 1569, - ('de', 'en'): 1565, - ('en', 'fr'): 1664, - ('cs', 'en'): 1511 + "tst2010": { + ("ar", "en"): 1569, + ("de", "en"): 1565, + ("en", "fr"): 1664, + ("cs", "en"): 1511 }, - 'tst2011': { - ('ar', 'en'): 1199, - ('de', 'en'): 1433, - ('en', 'fr'): 818, - ('cs', 'en'): 1013 + "tst2011": { + ("ar", "en"): 1199, + ("de", "en"): 1433, + ("en", "fr"): 818, + ("cs", "en"): 1013 }, - 'tst2012': { - ('ar', 'en'): 1702, - ('de', 'en'): 1700, - ('en', 'fr'): 1124, - ('cs', 'en'): 1385 + "tst2012": { + ("ar", "en"): 1702, + ("de", "en"): 1700, + ("en", "fr"): 1124, + ("cs", "en"): 1385 }, - 'tst2013': { - ('ar', 'en'): 1169, - ('de', 'en'): 993, - ('en', 'fr'): 1026, - ('cs', 'en'): 1327 + "tst2013": { + ("ar", "en"): 1169, + ("de", "en"): 993, + ("en", "fr"): 1026, + ("cs", "en"): 1327 }, - 'tst2014': { - ('ar', 'en'): 1107, - ('de', 'en'): 1305, - ('en', 'fr'): 1305 + "tst2014": { + ("ar", "en"): 1107, + ("de", "en"): 1305, + ("en", "fr"): 1305 } }, - 'test': { - 'dev2010': { - ('ar', 'en'): 887, - ('de', 'en'): 887, - ('en', 'fr'): 887, - ('cs', 'en'): 480 + "test": { + "dev2010": { + ("ar", "en"): 887, + ("de", "en"): 887, + ("en", "fr"): 887, + ("cs", "en"): 480 }, - 'tst2010': { - ('ar', 'en'): 1569, - ('de', 'en'): 1565, - ('en', 'fr'): 1664, - ('cs', 'en'): 1511 + "tst2010": { + ("ar", "en"): 1569, + ("de", "en"): 1565, + ("en", "fr"): 1664, + ("cs", "en"): 1511 }, - 'tst2011': { - ('ar', 'en'): 1199, - ('de', 'en'): 1433, - ('en', 'fr'): 818, - ('cs', 'en'): 1013 + "tst2011": { + ("ar", "en"): 1199, + ("de", "en"): 1433, + ("en", "fr"): 818, + ("cs", "en"): 1013 }, - 'tst2012': { - ('ar', 'en'): 1702, - ('de', 'en'): 1700, - ('en', 'fr'): 1124, - ('cs', 'en'): 1385 + "tst2012": { + ("ar", "en"): 1702, + ("de", "en"): 1700, + ("en", "fr"): 1124, + ("cs", "en"): 1385 }, - 'tst2013': { - ('ar', 'en'): 1169, - ('de', 'en'): 993, - ('en', 'fr'): 1026, - ('cs', 'en'): 1327 + "tst2013": { + ("ar", "en"): 1169, + ("de", "en"): 993, + ("en", "fr"): 1026, + ("cs", "en"): 1327 }, - 'tst2014': { - ('ar', 'en'): 1107, - ('de', 'en'): 1305, - ('en', 'fr'): 1305 + "tst2014": { + ("ar", "en"): 1107, + ("de", "en"): 1305, + ("en", "fr"): 1305 } } } SET_NOT_EXISTS = { - ('en', 'ar'): [], - ('en', 'de'): [], - ('en', 'fr'): [], - ('en', 'cs'): ['tst2014'], - ('ar', 'en'): [], - ('fr', 'en'): [], - ('de', 'en'): [], - ('cs', 'en'): ['tst2014'] + ("en", "ar"): [], + ("en", "de"): [], + ("en", "fr"): [], + ("en", "cs"): ["tst2014"], + ("ar", "en"): [], + ("fr", "en"): [], + ("de", "en"): [], + ("cs", "en"): ["tst2014"] } DATASET_NAME = "IWSLT2016" @@ -142,7 +143,7 @@ def 
_filter_clean_cache(cache_decompressed_dp, full_filepath, uncleaned_filename @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "valid", "test")) -def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de', 'en'), valid_set='tst2013', test_set='tst2014'): +def IWSLT2016(root=".data", split=("train", "valid", "test"), language_pair=("de", "en"), valid_set="tst2013", test_set="tst2014"): """IWSLT2016 dataset The available datasets include following: @@ -150,20 +151,20 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de **Language pairs**: +-----+-----+-----+-----+-----+-----+ - | |'en' |'fr' |'de' |'cs' |'ar' | + | |"en" |"fr" |"de" |"cs" |"ar" | +-----+-----+-----+-----+-----+-----+ - |'en' | | x | x | x | x | + |"en" | | x | x | x | x | +-----+-----+-----+-----+-----+-----+ - |'fr' | x | | | | | + |"fr" | x | | | | | +-----+-----+-----+-----+-----+-----+ - |'de' | x | | | | | + |"de" | x | | | | | +-----+-----+-----+-----+-----+-----+ - |'cs' | x | | | | | + |"cs" | x | | | | | +-----+-----+-----+-----+-----+-----+ - |'ar' | x | | | | | + |"ar" | x | | | | | +-----+-----+-----+-----+-----+-----+ - **valid/test sets**: ['dev2010', 'tst2010', 'tst2011', 'tst2012', 'tst2013', 'tst2014'] + **valid/test sets**: ["dev2010", "tst2010", "tst2011", "tst2012", "tst2013", "tst2014"] For additional details refer to source website: https://wit3.fbk.eu/2016-01 @@ -186,59 +187,29 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de if not isinstance(language_pair, list) and not isinstance(language_pair, tuple): raise ValueError("language_pair must be list or tuple but got {} instead".format(type(language_pair))) - assert (len(language_pair) == 2), 'language_pair must contain only 2 elements: src and tgt language respectively' + assert (len(language_pair) == 2), "language_pair must contain only 2 elements: src and tgt language respectively" src_language, tgt_language = language_pair[0], language_pair[1] - if src_language not in SUPPORTED_DATASETS['language_pair']: + if src_language not in SUPPORTED_DATASETS["language_pair"]: raise ValueError("src_language '{}' is not valid. Supported source languages are {}". - format(src_language, list(SUPPORTED_DATASETS['language_pair']))) + format(src_language, list(SUPPORTED_DATASETS["language_pair"]))) - if tgt_language not in SUPPORTED_DATASETS['language_pair'][src_language]: + if tgt_language not in SUPPORTED_DATASETS["language_pair"][src_language]: raise ValueError("tgt_language '{}' is not valid for give src_language '{}'. Supported target language are {}". - format(tgt_language, src_language, SUPPORTED_DATASETS['language_pair'][src_language])) + format(tgt_language, src_language, SUPPORTED_DATASETS["language_pair"][src_language])) - if valid_set not in SUPPORTED_DATASETS['valid_test'] or valid_set in SET_NOT_EXISTS[language_pair]: + if valid_set not in SUPPORTED_DATASETS["valid_test"] or valid_set in SET_NOT_EXISTS[language_pair]: raise ValueError("valid_set '{}' is not valid for given language pair {}. Supported validation sets are {}". 
- format(valid_set, language_pair, [s for s in SUPPORTED_DATASETS['valid_test'] if s not in SET_NOT_EXISTS[language_pair]])) + format(valid_set, language_pair, [s for s in SUPPORTED_DATASETS["valid_test"] if s not in SET_NOT_EXISTS[language_pair]])) - if test_set not in SUPPORTED_DATASETS['valid_test'] or test_set in SET_NOT_EXISTS[language_pair]: + if test_set not in SUPPORTED_DATASETS["valid_test"] or test_set in SET_NOT_EXISTS[language_pair]: raise ValueError("test_set '{}' is not valid for give language pair {}. Supported test sets are {}". - format(valid_set, language_pair, [s for s in SUPPORTED_DATASETS['valid_test'] if s not in SET_NOT_EXISTS[language_pair]])) + format(valid_set, language_pair, [s for s in SUPPORTED_DATASETS["valid_test"] if s not in SET_NOT_EXISTS[language_pair]])) - train_filenames = ( - 'train.{}-{}.{}'.format(src_language, tgt_language, src_language), - 'train.{}-{}.{}'.format(src_language, tgt_language, tgt_language) + file_path_by_lang_and_split, uncleaned_filenames_by_lang_and_split = _generate_iwslt_files_for_lang_and_split( + SUPPORTED_DATASETS["year"], src_language, tgt_language, valid_set, test_set ) - valid_filenames = ( - 'IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, src_language), - 'IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, tgt_language) - ) - test_filenames = ( - 'IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, src_language), - 'IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, tgt_language) - ) - - src_train, tgt_train = train_filenames - src_eval, tgt_eval = valid_filenames - src_test, tgt_test = test_filenames - - uncleaned_train_filenames = ( - 'train.tags.{}-{}.{}'.format(src_language, tgt_language, src_language), - 'train.tags.{}-{}.{}'.format(src_language, tgt_language, tgt_language) - ) - uncleaed_valid_filenames = ( - 'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, src_language), - 'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, tgt_language) - ) - uncleaned_test_filenames = ( - 'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, src_language), - 'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, tgt_language) - ) - - uncleaned_src_train, uncleaned_tgt_train = uncleaned_train_filenames - uncleaned_src_eval, uncleaned_tgt_eval = uncleaed_valid_filenames - uncleaned_src_test, uncleaned_tgt_test = uncleaned_test_filenames url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( @@ -260,34 +231,8 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(lambda x: os.path.basename(inner_iwslt_tar) in x[0]) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) - file_path_by_lang_and_split = { - src_language: { - "train": src_train, - "valid": src_eval, - "test": src_test, - }, - tgt_language: { - "train": tgt_train, - "valid": tgt_eval, - "test": tgt_test, - } - } - - uncleaned_filenames = { - src_language: { - "train": uncleaned_src_train, - "valid": uncleaned_src_eval, - "test": uncleaned_src_test, - }, - tgt_language: { - "train": uncleaned_tgt_train, - 
"valid": uncleaned_tgt_eval, - "test": uncleaned_tgt_test, - } - } - src_filename = file_path_by_lang_and_split[src_language][split] - uncleaned_src_filename = uncleaned_filenames[src_language][split] + uncleaned_src_filename = uncleaned_filenames_by_lang_and_split[src_language][split] # We create the whole filepath here, but only check for the literal filename in the filter # because we're lazily extracting from the outer tarfile. @@ -296,7 +241,7 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de cache_inner_src_decompressed_dp = _filter_clean_cache(cache_decompressed_dp, full_src_filepath, uncleaned_src_filename) tgt_filename = file_path_by_lang_and_split[tgt_language][split] - uncleaned_tgt_filename = uncleaned_filenames[tgt_language][split] + uncleaned_tgt_filename = uncleaned_filenames_by_lang_and_split[tgt_language][split] # We create the whole filepath here, but only check for the literal filename in the filter # because we're lazily extracting from the outer tarfile. diff --git a/torchtext/datasets/iwslt2017.py b/torchtext/datasets/iwslt2017.py index 67b3345685..57d325d65e 100644 --- a/torchtext/datasets/iwslt2017.py +++ b/torchtext/datasets/iwslt2017.py @@ -8,89 +8,90 @@ _clean_files, _create_dataset_directory, _wrap_split_argument, + _generate_iwslt_files_for_lang_and_split, ) -URL = 'https://drive.google.com/u/0/uc?id=12ycYSzLIG253AFN35Y6qoyf9wtkOjakp' -_PATH = '2017-01-trnmted.tgz' -MD5 = 'aca701032b1c4411afc4d9fa367796ba' +URL = "https://drive.google.com/u/0/uc?id=12ycYSzLIG253AFN35Y6qoyf9wtkOjakp" +_PATH = "2017-01-trnmted.tgz" +MD5 = "aca701032b1c4411afc4d9fa367796ba" SUPPORTED_DATASETS = { - 'valid_test': ['dev2010', 'tst2010'], - 'language_pair': { - 'en': ['nl', 'de', 'it', 'ro'], - 'ro': ['de', 'en', 'nl', 'it'], - 'de': ['ro', 'en', 'nl', 'it'], - 'it': ['en', 'nl', 'de', 'ro'], - 'nl': ['de', 'en', 'it', 'ro'], + "valid_test": ["dev2010", "tst2010"], + "language_pair": { + "en": ["nl", "de", "it", "ro"], + "ro": ["de", "en", "nl", "it"], + "de": ["ro", "en", "nl", "it"], + "it": ["en", "nl", "de", "ro"], + "nl": ["de", "en", "it", "ro"], }, - 'year': 17, + "year": 17, } NUM_LINES = { - 'train': { - 'train': { - ('en', 'nl'): 237240, - ('de', 'en'): 206112, - ('en', 'it'): 231619, - ('en', 'ro'): 220538, - ('de', 'ro'): 201455, - ('nl', 'ro'): 206920, - ('it', 'ro'): 217551, - ('de', 'nl'): 213628, - ('de', 'it'): 205465, - ('it', 'nl'): 233415 + "train": { + "train": { + ("en", "nl"): 237240, + ("de", "en"): 206112, + ("en", "it"): 231619, + ("en", "ro"): 220538, + ("de", "ro"): 201455, + ("nl", "ro"): 206920, + ("it", "ro"): 217551, + ("de", "nl"): 213628, + ("de", "it"): 205465, + ("it", "nl"): 233415 } }, - 'valid': { - 'dev2010': { - ('en', 'nl'): 1003, - ('de', 'en'): 888, - ('en', 'it'): 929, - ('en', 'ro'): 914, - ('de', 'ro'): 912, - ('nl', 'ro'): 913, - ('it', 'ro'): 914, - ('de', 'nl'): 1001, - ('de', 'it'): 923, - ('it', 'nl'): 1001 + "valid": { + "dev2010": { + ("en", "nl"): 1003, + ("de", "en"): 888, + ("en", "it"): 929, + ("en", "ro"): 914, + ("de", "ro"): 912, + ("nl", "ro"): 913, + ("it", "ro"): 914, + ("de", "nl"): 1001, + ("de", "it"): 923, + ("it", "nl"): 1001 }, - 'tst2010': { - ('en', 'nl'): 1777, - ('de', 'en'): 1568, - ('en', 'it'): 1566, - ('en', 'ro'): 1678, - ('de', 'ro'): 1677, - ('nl', 'ro'): 1680, - ('it', 'ro'): 1643, - ('de', 'nl'): 1779, - ('de', 'it'): 1567, - ('it', 'nl'): 1669 + "tst2010": { + ("en", "nl"): 1777, + ("de", "en"): 1568, + ("en", "it"): 1566, + ("en", "ro"): 1678, + ("de", "ro"): 1677, + 
("nl", "ro"): 1680, + ("it", "ro"): 1643, + ("de", "nl"): 1779, + ("de", "it"): 1567, + ("it", "nl"): 1669 } }, - 'test': { - 'dev2010': { - ('en', 'nl'): 1003, - ('de', 'en'): 888, - ('en', 'it'): 929, - ('en', 'ro'): 914, - ('de', 'ro'): 912, - ('nl', 'ro'): 913, - ('it', 'ro'): 914, - ('de', 'nl'): 1001, - ('de', 'it'): 923, - ('it', 'nl'): 1001 + "test": { + "dev2010": { + ("en", "nl"): 1003, + ("de", "en"): 888, + ("en", "it"): 929, + ("en", "ro"): 914, + ("de", "ro"): 912, + ("nl", "ro"): 913, + ("it", "ro"): 914, + ("de", "nl"): 1001, + ("de", "it"): 923, + ("it", "nl"): 1001 }, - 'tst2010': { - ('en', 'nl'): 1777, - ('de', 'en'): 1568, - ('en', 'it'): 1566, - ('en', 'ro'): 1678, - ('de', 'ro'): 1677, - ('nl', 'ro'): 1680, - ('it', 'ro'): 1643, - ('de', 'nl'): 1779, - ('de', 'it'): 1567, - ('it', 'nl'): 1669 + "tst2010": { + ("en", "nl"): 1777, + ("de", "en"): 1568, + ("en", "it"): 1566, + ("en", "ro"): 1678, + ("de", "ro"): 1677, + ("nl", "ro"): 1680, + ("it", "ro"): 1643, + ("de", "nl"): 1779, + ("de", "it"): 1567, + ("it", "nl"): 1669 } } } @@ -113,7 +114,7 @@ def _filter_clean_cache(cache_decompressed_dp, full_filepath, uncleaned_filename @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "valid", "test")) -def IWSLT2017(root='.data', split=('train', 'valid', 'test'), language_pair=('de', 'en')): +def IWSLT2017(root=".data", split=("train", "valid", "test"), language_pair=("de", "en")): """IWSLT2017 dataset The available datasets include following: @@ -121,17 +122,17 @@ def IWSLT2017(root='.data', split=('train', 'valid', 'test'), language_pair=('de **Language pairs**: +-----+-----+-----+-----+-----+-----+ - | |'en' |'nl' |'de' |'it' |'ro' | + | |"en" |"nl" |"de" |"it" |"ro" | +-----+-----+-----+-----+-----+-----+ - |'en' | | x | x | x | x | + |"en" | | x | x | x | x | +-----+-----+-----+-----+-----+-----+ - |'nl' | x | | x | x | x | + |"nl" | x | | x | x | x | +-----+-----+-----+-----+-----+-----+ - |'de' | x | x | | x | x | + |"de" | x | x | | x | x | +-----+-----+-----+-----+-----+-----+ - |'it' | x | x | x | | x | + |"it" | x | x | x | | x | +-----+-----+-----+-----+-----+-----+ - |'ro' | x | x | x | x | | + |"ro" | x | x | x | x | | +-----+-----+-----+-----+-----+-----+ @@ -151,51 +152,27 @@ def IWSLT2017(root='.data', split=('train', 'valid', 'test'), language_pair=('de if not is_module_available("torchdata"): raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`") - valid_set = 'dev2010' - test_set = 'tst2010' + valid_set = "dev2010" + test_set = "tst2010" if not isinstance(language_pair, list) and not isinstance(language_pair, tuple): raise ValueError("language_pair must be list or tuple but got {} instead".format(type(language_pair))) - assert (len(language_pair) == 2), 'language_pair must contain only 2 elements: src and tgt language respectively' + assert (len(language_pair) == 2), "language_pair must contain only 2 elements: src and tgt language respectively" src_language, tgt_language = language_pair[0], language_pair[1] - if src_language not in SUPPORTED_DATASETS['language_pair']: + if src_language not in SUPPORTED_DATASETS["language_pair"]: raise ValueError("src_language '{}' is not valid. Supported source languages are {}". 
- format(src_language, list(SUPPORTED_DATASETS['language_pair']))) + format(src_language, list(SUPPORTED_DATASETS["language_pair"]))) - if tgt_language not in SUPPORTED_DATASETS['language_pair'][src_language]: + if tgt_language not in SUPPORTED_DATASETS["language_pair"][src_language]: raise ValueError("tgt_language '{}' is not valid for give src_language '{}'. Supported target language are {}". - format(tgt_language, src_language, SUPPORTED_DATASETS['language_pair'][src_language])) - - train_filenames = ('train.{}-{}.{}'.format(src_language, tgt_language, src_language), - 'train.{}-{}.{}'.format(src_language, tgt_language, tgt_language)) - valid_filenames = ('IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, src_language), - 'IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, tgt_language)) - test_filenames = ('IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, src_language), - 'IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, tgt_language)) - - src_train, tgt_train = train_filenames - src_eval, tgt_eval = valid_filenames - src_test, tgt_test = test_filenames - - uncleaned_train_filenames = ( - 'train.tags.{}-{}.{}'.format(src_language, tgt_language, src_language), - 'train.tags.{}-{}.{}'.format(src_language, tgt_language, tgt_language) - ) - uncleaed_valid_filenames = ( - 'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, src_language), - 'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, tgt_language) - ) - uncleaned_test_filenames = ( - 'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, src_language), - 'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, tgt_language) - ) + format(tgt_language, src_language, SUPPORTED_DATASETS["language_pair"][src_language])) - uncleaned_src_train, uncleaned_tgt_train = uncleaned_train_filenames - uncleaned_src_eval, uncleaned_tgt_eval = uncleaed_valid_filenames - uncleaned_src_test, uncleaned_tgt_test = uncleaned_test_filenames + file_path_by_lang_and_split, uncleaned_filenames_by_lang_and_split = _generate_iwslt_files_for_lang_and_split( + SUPPORTED_DATASETS["year"], src_language, tgt_language, valid_set, test_set + ) url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( @@ -221,34 +198,8 @@ def IWSLT2017(root='.data', split=('train', 'valid', 'test'), language_pair=('de lambda x: os.path.basename(inner_iwslt_tar) in x[0]) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) - file_path_by_lang_and_split = { - src_language: { - "train": src_train, - "valid": src_eval, - "test": src_test, - }, - tgt_language: { - "train": tgt_train, - "valid": tgt_eval, - "test": tgt_test, - } - } - - uncleaned_filenames = { - src_language: { - "train": uncleaned_src_train, - "valid": uncleaned_src_eval, - "test": uncleaned_src_test, - }, - tgt_language: { - "train": uncleaned_tgt_train, - "valid": uncleaned_tgt_eval, - "test": uncleaned_tgt_test, - } - } - src_filename = file_path_by_lang_and_split[src_language][split] - uncleaned_src_filename = uncleaned_filenames[src_language][split] + uncleaned_src_filename = uncleaned_filenames_by_lang_and_split[src_language][split] # We create the whole filepath here, but 
only check for the literal filename in the filter # because we're lazily extracting from the outer tarfile. @@ -258,7 +209,7 @@ def IWSLT2017(root='.data', split=('train', 'valid', 'test'), language_pair=('de uncleaned_src_filename) tgt_filename = file_path_by_lang_and_split[tgt_language][split] - uncleaned_tgt_filename = uncleaned_filenames[tgt_language][split] + uncleaned_tgt_filename = uncleaned_filenames_by_lang_and_split[tgt_language][split] # We create the whole filepath here, but only check for the literal filename in the filter # because we're lazily extracting from the outer tarfile. From f5c83bddeddfe62bda83de217da41bd49a99cd21 Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Mon, 31 Jan 2022 14:03:52 -0500 Subject: [PATCH 8/8] fix typo. --- torchtext/data/datasets_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py index 844e9c9245..55ca5de78e 100644 --- a/torchtext/data/datasets_utils.py +++ b/torchtext/data/datasets_utils.py @@ -414,7 +414,7 @@ def _generate_iwslt_files_for_lang_and_split(year, src_language, tgt_language, v "train.tags.{}-{}.{}".format(src_language, tgt_language, src_language), "train.tags.{}-{}.{}".format(src_language, tgt_language, tgt_language) ) - uncleaed_valid_filenames = ( + uncleaned_valid_filenames = ( "IWSLT{}.TED.{}.{}-{}.{}.xml".format(year, valid_set, src_language, tgt_language, src_language), "IWSLT{}.TED.{}.{}-{}.{}.xml".format(year, valid_set, src_language, tgt_language, tgt_language) ) @@ -424,7 +424,7 @@ def _generate_iwslt_files_for_lang_and_split(year, src_language, tgt_language, v ) uncleaned_src_train, uncleaned_tgt_train = uncleaned_train_filenames - uncleaned_src_eval, uncleaned_tgt_eval = uncleaed_valid_filenames + uncleaned_src_eval, uncleaned_tgt_eval = uncleaned_valid_filenames uncleaned_src_test, uncleaned_tgt_test = uncleaned_test_filenames file_path_by_lang_and_split = {
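
---

Editor's note: for readers following the migration above outside the diff context, the sketch below restates the three nested on_disk_cache levels that patches 1/8 through 8/8 converge on. It is an illustrative sketch, not part of the patch series: the constants and every datapipe call are copied from torchtext/datasets/iwslt2017.py as patched, the de-en "train" filenames are one concrete instantiation of the filename templates, and _clean_files is the private helper these patches add to datasets_utils.py.

    import os

    from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper
    from torchtext.data.datasets_utils import _clean_files

    # Constants copied from torchtext/datasets/iwslt2017.py.
    URL = "https://drive.google.com/u/0/uc?id=12ycYSzLIG253AFN35Y6qoyf9wtkOjakp"
    _PATH = "2017-01-trnmted.tgz"
    MD5 = "aca701032b1c4411afc4d9fa367796ba"
    root = ".data"

    # Level 1: cache the downloaded outer archive on disk, verified by MD5.
    # on_disk_cache() skips the download when a file with a matching hash
    # already exists at the target path.
    url_dp = IterableWrapper([URL])
    cache_compressed_dp = url_dp.on_disk_cache(
        filepath_fn=lambda x: os.path.join(root, _PATH),
        hash_dict={os.path.join(root, _PATH): MD5},
        hash_type="md5",
    )
    cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(
        mode="wb", same_filepath_fn=True
    )

    # Level 2: cache the inner multilingual archive extracted from the outer
    # one. read_from_tar() yields (member-path, stream) pairs whose paths are
    # rooted at the archive path ("...2017-01-trnmted.tgz/texts/..."), so the
    # filter matches on the basename rather than on the final extracted path.
    inner_iwslt_tar = os.path.join(
        root,
        os.path.splitext(_PATH)[0],
        "texts/DeEnItNlRo/DeEnItNlRo/DeEnItNlRo-DeEnItNlRo.tgz",
    )
    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
        filepath_fn=lambda x: inner_iwslt_tar
    )
    cache_decompressed_dp = (
        FileOpener(cache_decompressed_dp, mode="b")
        .read_from_tar()
        .filter(lambda x: os.path.basename(inner_iwslt_tar) in x[0])
    )
    cache_decompressed_dp = cache_decompressed_dp.end_caching(
        mode="wb", same_filepath_fn=True
    )

    # Level 3 (the body of _filter_clean_cache): extract one raw text file
    # from the inner archive, strip its XML/tag markup with _clean_files, and
    # cache the cleaned result.
    full_src_filepath = os.path.join(
        root, "texts/DeEnItNlRo/DeEnItNlRo/DeEnItNlRo-DeEnItNlRo", "train.de-en.de"
    )
    uncleaned_src_filename = "train.tags.de-en.de"
    cache_inner_src_dp = cache_decompressed_dp.on_disk_cache(
        filepath_fn=lambda x: full_src_filepath
    )
    cache_inner_src_dp = (
        FileOpener(cache_inner_src_dp, mode="b")
        .read_from_tar()
        .filter(lambda x: os.path.basename(uncleaned_src_filename) in x[0])
    )
    cache_inner_src_dp = cache_inner_src_dp.map(
        lambda x: _clean_files(full_src_filepath, x[0], x[1])
    )
    cache_inner_src_dp = cache_inner_src_dp.end_caching(
        mode="wb", same_filepath_fn=True
    )

    # Finally, read the cached plain-text file line by line.
    src_lines = FileOpener(cache_inner_src_dp, mode="r").readlines(
        return_path=False, strip_newline=False
    )

The dataset function builds the level-3 chain twice, once for the source file and once for the target file, and returns src_lines.zip(tgt_lines), so iterating the resulting datapipe yields (source sentence, target sentence) pairs.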