diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py
index fa363150ac..55ca5de78e 100644
--- a/torchtext/data/datasets_utils.py
+++ b/torchtext/data/datasets_utils.py
@@ -392,6 +392,70 @@ def __str__(self):
         return self.description
 
 
+def _generate_iwslt_files_for_lang_and_split(year, src_language, tgt_language, valid_set, test_set):
+    train_filenames = (
+        "train.{}-{}.{}".format(src_language, tgt_language, src_language),
+        "train.{}-{}.{}".format(src_language, tgt_language, tgt_language)
+    )
+    valid_filenames = (
+        "IWSLT{}.TED.{}.{}-{}.{}".format(year, valid_set, src_language, tgt_language, src_language),
+        "IWSLT{}.TED.{}.{}-{}.{}".format(year, valid_set, src_language, tgt_language, tgt_language)
+    )
+    test_filenames = (
+        "IWSLT{}.TED.{}.{}-{}.{}".format(year, test_set, src_language, tgt_language, src_language),
+        "IWSLT{}.TED.{}.{}-{}.{}".format(year, test_set, src_language, tgt_language, tgt_language)
+    )
+
+    src_train, tgt_train = train_filenames
+    src_eval, tgt_eval = valid_filenames
+    src_test, tgt_test = test_filenames
+
+    uncleaned_train_filenames = (
+        "train.tags.{}-{}.{}".format(src_language, tgt_language, src_language),
+        "train.tags.{}-{}.{}".format(src_language, tgt_language, tgt_language)
+    )
+    uncleaned_valid_filenames = (
+        "IWSLT{}.TED.{}.{}-{}.{}.xml".format(year, valid_set, src_language, tgt_language, src_language),
+        "IWSLT{}.TED.{}.{}-{}.{}.xml".format(year, valid_set, src_language, tgt_language, tgt_language)
+    )
+    uncleaned_test_filenames = (
+        "IWSLT{}.TED.{}.{}-{}.{}.xml".format(year, test_set, src_language, tgt_language, src_language),
+        "IWSLT{}.TED.{}.{}-{}.{}.xml".format(year, test_set, src_language, tgt_language, tgt_language)
+    )
+
+    uncleaned_src_train, uncleaned_tgt_train = uncleaned_train_filenames
+    uncleaned_src_eval, uncleaned_tgt_eval = uncleaned_valid_filenames
+    uncleaned_src_test, uncleaned_tgt_test = uncleaned_test_filenames
+
+    file_path_by_lang_and_split = {
+        src_language: {
+            "train": src_train,
+            "valid": src_eval,
+            "test": src_test,
+        },
+        tgt_language: {
+            "train": tgt_train,
+            "valid": tgt_eval,
+            "test": tgt_test,
+        }
+    }
+
+    uncleaned_filenames_by_lang_and_split = {
+        src_language: {
+            "train": uncleaned_src_train,
+            "valid": uncleaned_src_eval,
+            "test": uncleaned_src_test,
+        },
+        tgt_language: {
+            "train": uncleaned_tgt_train,
+            "valid": uncleaned_tgt_eval,
+            "test": uncleaned_tgt_test,
+        }
+    }
+
+    return file_path_by_lang_and_split, uncleaned_filenames_by_lang_and_split
+
+
 @functional_datapipe("read_squad")
 class _ParseSQuADQAData(IterDataPipe):
     r"""Iterable DataPipe to parse the contents of a stream of JSON objects
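
[Note, illustrative and not part of the patch: with the IWSLT2016 defaults, the new
helper returns plain filename strings keyed first by language and then by split.
The variable names below are ours.]

    files, uncleaned = _generate_iwslt_files_for_lang_and_split(
        16, "de", "en", "tst2013", "tst2014"
    )
    files["de"]["train"]      # "train.de-en.de"
    files["en"]["valid"]      # "IWSLT16.TED.tst2013.de-en.en"
    uncleaned["de"]["train"]  # "train.tags.de-en.de"
    uncleaned["en"]["test"]   # "IWSLT16.TED.tst2014.de-en.en.xml"
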
"tst2014"], + "language_pair": { + "en": ["ar", "de", "fr", "cs"], + "ar": ["en"], + "fr": ["en"], + "de": ["en"], + "cs": ["en"], }, - 'year': 16, + "year": 16, } NUM_LINES = { - 'train': { - 'train': { - ('ar', 'en'): 224126, - ('de', 'en'): 196884, - ('en', 'fr'): 220400, - ('cs', 'en'): 114390 + "train": { + "train": { + ("ar", "en"): 224126, + ("de", "en"): 196884, + ("en", "fr"): 220400, + ("cs", "en"): 114390 } }, - 'valid': { - 'dev2010': { - ('ar', 'en'): 887, - ('de', 'en'): 887, - ('en', 'fr'): 887, - ('cs', 'en'): 480 + "valid": { + "dev2010": { + ("ar", "en"): 887, + ("de", "en"): 887, + ("en", "fr"): 887, + ("cs", "en"): 480 }, - 'tst2010': { - ('ar', 'en'): 1569, - ('de', 'en'): 1565, - ('en', 'fr'): 1664, - ('cs', 'en'): 1511 + "tst2010": { + ("ar", "en"): 1569, + ("de", "en"): 1565, + ("en", "fr"): 1664, + ("cs", "en"): 1511 }, - 'tst2011': { - ('ar', 'en'): 1199, - ('de', 'en'): 1433, - ('en', 'fr'): 818, - ('cs', 'en'): 1013 + "tst2011": { + ("ar", "en"): 1199, + ("de", "en"): 1433, + ("en", "fr"): 818, + ("cs", "en"): 1013 }, - 'tst2012': { - ('ar', 'en'): 1702, - ('de', 'en'): 1700, - ('en', 'fr'): 1124, - ('cs', 'en'): 1385 + "tst2012": { + ("ar", "en"): 1702, + ("de", "en"): 1700, + ("en", "fr"): 1124, + ("cs", "en"): 1385 }, - 'tst2013': { - ('ar', 'en'): 1169, - ('de', 'en'): 993, - ('en', 'fr'): 1026, - ('cs', 'en'): 1327 + "tst2013": { + ("ar", "en"): 1169, + ("de", "en"): 993, + ("en", "fr"): 1026, + ("cs", "en"): 1327 }, - 'tst2014': { - ('ar', 'en'): 1107, - ('de', 'en'): 1305, - ('en', 'fr'): 1305 + "tst2014": { + ("ar", "en"): 1107, + ("de", "en"): 1305, + ("en", "fr"): 1305 } }, - 'test': { - 'dev2010': { - ('ar', 'en'): 887, - ('de', 'en'): 887, - ('en', 'fr'): 887, - ('cs', 'en'): 480 + "test": { + "dev2010": { + ("ar", "en"): 887, + ("de", "en"): 887, + ("en", "fr"): 887, + ("cs", "en"): 480 }, - 'tst2010': { - ('ar', 'en'): 1569, - ('de', 'en'): 1565, - ('en', 'fr'): 1664, - ('cs', 'en'): 1511 + "tst2010": { + ("ar", "en"): 1569, + ("de", "en"): 1565, + ("en", "fr"): 1664, + ("cs", "en"): 1511 }, - 'tst2011': { - ('ar', 'en'): 1199, - ('de', 'en'): 1433, - ('en', 'fr'): 818, - ('cs', 'en'): 1013 + "tst2011": { + ("ar", "en"): 1199, + ("de", "en"): 1433, + ("en", "fr"): 818, + ("cs", "en"): 1013 }, - 'tst2012': { - ('ar', 'en'): 1702, - ('de', 'en'): 1700, - ('en', 'fr'): 1124, - ('cs', 'en'): 1385 + "tst2012": { + ("ar", "en"): 1702, + ("de", "en"): 1700, + ("en", "fr"): 1124, + ("cs", "en"): 1385 }, - 'tst2013': { - ('ar', 'en'): 1169, - ('de', 'en'): 993, - ('en', 'fr'): 1026, - ('cs', 'en'): 1327 + "tst2013": { + ("ar", "en"): 1169, + ("de", "en"): 993, + ("en", "fr"): 1026, + ("cs", "en"): 1327 }, - 'tst2014': { - ('ar', 'en'): 1107, - ('de', 'en'): 1305, - ('en', 'fr'): 1305 + "tst2014": { + ("ar", "en"): 1107, + ("de", "en"): 1305, + ("en", "fr"): 1305 } } } SET_NOT_EXISTS = { - ('en', 'ar'): [], - ('en', 'de'): [], - ('en', 'fr'): [], - ('en', 'cs'): ['tst2014'], - ('ar', 'en'): [], - ('fr', 'en'): [], - ('de', 'en'): [], - ('cs', 'en'): ['tst2014'] + ("en", "ar"): [], + ("en", "de"): [], + ("en", "fr"): [], + ("en", "cs"): ["tst2014"], + ("ar", "en"): [], + ("fr", "en"): [], + ("de", "en"): [], + ("cs", "en"): ["tst2014"] } DATASET_NAME = "IWSLT2016" +# TODO: migrate this to dataset_utils.py once torchdata is a hard dependency to +# avoid additional conditional imports. 
@@ -140,7 +143,7 @@ def _filter_clean_cache(cache_decompressed_dp, full_filepath, uncleaned_filename
 
 @_create_dataset_directory(dataset_name=DATASET_NAME)
 @_wrap_split_argument(("train", "valid", "test"))
-def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de', 'en'), valid_set='tst2013', test_set='tst2014'):
+def IWSLT2016(root=".data", split=("train", "valid", "test"), language_pair=("de", "en"), valid_set="tst2013", test_set="tst2014"):
     """IWSLT2016 dataset
 
     The available datasets include following:
@@ -148,20 +151,20 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de
     **Language pairs**:
 
     +-----+-----+-----+-----+-----+-----+
-    |     |'en' |'fr' |'de' |'cs' |'ar' |
+    |     |"en" |"fr" |"de" |"cs" |"ar" |
     +-----+-----+-----+-----+-----+-----+
-    |'en' |     |   x |   x |   x |   x |
+    |"en" |     |   x |   x |   x |   x |
     +-----+-----+-----+-----+-----+-----+
-    |'fr' |   x |     |     |     |     |
+    |"fr" |   x |     |     |     |     |
     +-----+-----+-----+-----+-----+-----+
-    |'de' |   x |     |     |     |     |
+    |"de" |   x |     |     |     |     |
     +-----+-----+-----+-----+-----+-----+
-    |'cs' |   x |     |     |     |     |
+    |"cs" |   x |     |     |     |     |
     +-----+-----+-----+-----+-----+-----+
-    |'ar' |   x |     |     |     |     |
+    |"ar" |   x |     |     |     |     |
     +-----+-----+-----+-----+-----+-----+
 
-    **valid/test sets**: ['dev2010', 'tst2010', 'tst2011', 'tst2012', 'tst2013', 'tst2014']
+    **valid/test sets**: ["dev2010", "tst2010", "tst2011", "tst2012", "tst2013", "tst2014"]
 
     For additional details refer to source website: https://wit3.fbk.eu/2016-01
 
@@ -184,59 +187,29 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de
     if not isinstance(language_pair, list) and not isinstance(language_pair, tuple):
         raise ValueError("language_pair must be list or tuple but got {} instead".format(type(language_pair)))
 
-    assert (len(language_pair) == 2), 'language_pair must contain only 2 elements: src and tgt language respectively'
+    assert (len(language_pair) == 2), "language_pair must contain only 2 elements: src and tgt language respectively"
 
     src_language, tgt_language = language_pair[0], language_pair[1]
 
-    if src_language not in SUPPORTED_DATASETS['language_pair']:
+    if src_language not in SUPPORTED_DATASETS["language_pair"]:
         raise ValueError("src_language '{}' is not valid. Supported source languages are {}".
-                         format(src_language, list(SUPPORTED_DATASETS['language_pair'])))
+                         format(src_language, list(SUPPORTED_DATASETS["language_pair"])))
 
-    if tgt_language not in SUPPORTED_DATASETS['language_pair'][src_language]:
+    if tgt_language not in SUPPORTED_DATASETS["language_pair"][src_language]:
         raise ValueError("tgt_language '{}' is not valid for give src_language '{}'. Supported target language are {}".
-                         format(tgt_language, src_language, SUPPORTED_DATASETS['language_pair'][src_language]))
+                         format(tgt_language, src_language, SUPPORTED_DATASETS["language_pair"][src_language]))
 
-    if valid_set not in SUPPORTED_DATASETS['valid_test'] or valid_set in SET_NOT_EXISTS[language_pair]:
+    if valid_set not in SUPPORTED_DATASETS["valid_test"] or valid_set in SET_NOT_EXISTS[language_pair]:
         raise ValueError("valid_set '{}' is not valid for given language pair {}. Supported validation sets are {}".
-                         format(valid_set, language_pair, [s for s in SUPPORTED_DATASETS['valid_test'] if s not in SET_NOT_EXISTS[language_pair]]))
+                         format(valid_set, language_pair, [s for s in SUPPORTED_DATASETS["valid_test"] if s not in SET_NOT_EXISTS[language_pair]]))
 
-    if test_set not in SUPPORTED_DATASETS['valid_test'] or test_set in SET_NOT_EXISTS[language_pair]:
+    if test_set not in SUPPORTED_DATASETS["valid_test"] or test_set in SET_NOT_EXISTS[language_pair]:
         raise ValueError("test_set '{}' is not valid for give language pair {}. Supported test sets are {}".
-                         format(valid_set, language_pair, [s for s in SUPPORTED_DATASETS['valid_test'] if s not in SET_NOT_EXISTS[language_pair]]))
+                         format(valid_set, language_pair, [s for s in SUPPORTED_DATASETS["valid_test"] if s not in SET_NOT_EXISTS[language_pair]]))
 
-    train_filenames = (
-        'train.{}-{}.{}'.format(src_language, tgt_language, src_language),
-        'train.{}-{}.{}'.format(src_language, tgt_language, tgt_language)
+    file_path_by_lang_and_split, uncleaned_filenames_by_lang_and_split = _generate_iwslt_files_for_lang_and_split(
+        SUPPORTED_DATASETS["year"], src_language, tgt_language, valid_set, test_set
     )
-    valid_filenames = (
-        'IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, src_language),
-        'IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, tgt_language)
-    )
-    test_filenames = (
-        'IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, src_language),
-        'IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, tgt_language)
-    )
-
-    src_train, tgt_train = train_filenames
-    src_eval, tgt_eval = valid_filenames
-    src_test, tgt_test = test_filenames
-
-    uncleaned_train_filenames = (
-        'train.tags.{}-{}.{}'.format(src_language, tgt_language, src_language),
-        'train.tags.{}-{}.{}'.format(src_language, tgt_language, tgt_language)
-    )
-    uncleaed_valid_filenames = (
-        'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, src_language),
-        'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, tgt_language)
-    )
-    uncleaned_test_filenames = (
-        'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, src_language),
-        'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, tgt_language)
-    )
-
-    uncleaned_src_train, uncleaned_tgt_train = uncleaned_train_filenames
-    uncleaned_src_eval, uncleaned_tgt_eval = uncleaed_valid_filenames
-    uncleaned_src_test, uncleaned_tgt_test = uncleaned_test_filenames
 
     url_dp = IterableWrapper([URL])
     cache_compressed_dp = url_dp.on_disk_cache(
@@ -258,34 +231,8 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de
     cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(lambda x: os.path.basename(inner_iwslt_tar) in x[0])
     cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
 
-    file_path_by_lang_and_split = {
-        src_language: {
-            "train": src_train,
-            "valid": src_eval,
-            "test": src_test,
-        },
-        tgt_language: {
-            "train": tgt_train,
-            "valid": tgt_eval,
-            "test": tgt_test,
-        }
-    }
-
-    uncleaned_filenames = {
-        src_language: {
-            "train": uncleaned_src_train,
-            "valid": uncleaned_src_eval,
-            "test": uncleaned_src_test,
-        },
-        tgt_language: {
-            "train": uncleaned_tgt_train,
-            "valid": uncleaned_tgt_eval,
-            "test": uncleaned_tgt_test,
-        }
-    }
-
     src_filename = file_path_by_lang_and_split[src_language][split]
-    uncleaned_src_filename = uncleaned_filenames[src_language][split]
+    uncleaned_src_filename = uncleaned_filenames_by_lang_and_split[src_language][split]
 
     # We create the whole filepath here, but only check for the literal filename in the filter
     # because we're lazily extracting from the outer tarfile.
@@ -294,7 +241,7 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de
     cache_inner_src_decompressed_dp = _filter_clean_cache(cache_decompressed_dp, full_src_filepath, uncleaned_src_filename)
 
     tgt_filename = file_path_by_lang_and_split[tgt_language][split]
-    uncleaned_tgt_filename = uncleaned_filenames[tgt_language][split]
+    uncleaned_tgt_filename = uncleaned_filenames_by_lang_and_split[tgt_language][split]
 
     # We create the whole filepath here, but only check for the literal filename in the filter
     # because we're lazily extracting from the outer tarfile.
+ ("en", "nl"): 1003, + ("de", "en"): 888, + ("en", "it"): 929, + ("en", "ro"): 914, + ("de", "ro"): 912, + ("nl", "ro"): 913, + ("it", "ro"): 914, + ("de", "nl"): 1001, + ("de", "it"): 923, + ("it", "nl"): 1001 }, - 'tst2010': { - ('en', 'nl'): 1777, - ('de', 'en'): 1568, - ('en', 'it'): 1566, - ('en', 'ro'): 1678, - ('de', 'ro'): 1677, - ('nl', 'ro'): 1680, - ('it', 'ro'): 1643, - ('de', 'nl'): 1779, - ('de', 'it'): 1567, - ('it', 'nl'): 1669 + "tst2010": { + ("en", "nl"): 1777, + ("de", "en"): 1568, + ("en", "it"): 1566, + ("en", "ro"): 1678, + ("de", "ro"): 1677, + ("nl", "ro"): 1680, + ("it", "ro"): 1643, + ("de", "nl"): 1779, + ("de", "it"): 1567, + ("it", "nl"): 1669 } }, - 'test': { - 'dev2010': { - ('en', 'nl'): 1003, - ('de', 'en'): 888, - ('en', 'it'): 929, - ('en', 'ro'): 914, - ('de', 'ro'): 912, - ('nl', 'ro'): 913, - ('it', 'ro'): 914, - ('de', 'nl'): 1001, - ('de', 'it'): 923, - ('it', 'nl'): 1001 + "test": { + "dev2010": { + ("en", "nl"): 1003, + ("de", "en"): 888, + ("en", "it"): 929, + ("en", "ro"): 914, + ("de", "ro"): 912, + ("nl", "ro"): 913, + ("it", "ro"): 914, + ("de", "nl"): 1001, + ("de", "it"): 923, + ("it", "nl"): 1001 }, - 'tst2010': { - ('en', 'nl'): 1777, - ('de', 'en'): 1568, - ('en', 'it'): 1566, - ('en', 'ro'): 1678, - ('de', 'ro'): 1677, - ('nl', 'ro'): 1680, - ('it', 'ro'): 1643, - ('de', 'nl'): 1779, - ('de', 'it'): 1567, - ('it', 'nl'): 1669 + "tst2010": { + ("en", "nl"): 1777, + ("de", "en"): 1568, + ("en", "it"): 1566, + ("en", "ro"): 1678, + ("de", "ro"): 1677, + ("nl", "ro"): 1680, + ("it", "ro"): 1643, + ("de", "nl"): 1779, + ("de", "it"): 1567, + ("it", "nl"): 1669 } } } - -def _construct_filenames(filename, languages): - filenames = [] - for lang in languages: - filenames.append(filename + "." + lang) - return filenames - - -def _construct_filepaths(paths, src_filename, tgt_filename): - src_path = None - tgt_path = None - for p in paths: - src_path = p if src_filename in p else src_path - tgt_path = p if tgt_filename in p else tgt_path - return (src_path, tgt_path) +DATASET_NAME = "IWSLT2017" -DATASET_NAME = "IWSLT2017" +# TODO: migrate this to dataset_utils.py once torchdata is a hard dependency to +# avoid additional conditional imports. 
+def _filter_clean_cache(cache_decompressed_dp, full_filepath, uncleaned_filename): + cache_inner_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=lambda x: full_filepath) + cache_inner_decompressed_dp = FileOpener(cache_inner_decompressed_dp, mode="b").read_from_tar() + cache_inner_decompressed_dp = cache_inner_decompressed_dp.filter( + lambda x: os.path.basename(uncleaned_filename) in x[0]) + cache_inner_decompressed_dp = cache_inner_decompressed_dp.map( + lambda x: _clean_files(full_filepath, x[0], x[1])) + cache_inner_decompressed_dp = cache_inner_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) + return cache_inner_decompressed_dp @_create_dataset_directory(dataset_name=DATASET_NAME) -@_wrap_split_argument(('train', 'valid', 'test')) -def IWSLT2017(root='.data', split=('train', 'valid', 'test'), language_pair=('de', 'en')): +@_wrap_split_argument(("train", "valid", "test")) +def IWSLT2017(root=".data", split=("train", "valid", "test"), language_pair=("de", "en")): """IWSLT2017 dataset The available datasets include following: @@ -127,17 +122,17 @@ def IWSLT2017(root='.data', split=('train', 'valid', 'test'), language_pair=('de **Language pairs**: +-----+-----+-----+-----+-----+-----+ - | |'en' |'nl' |'de' |'it' |'ro' | + | |"en" |"nl" |"de" |"it" |"ro" | +-----+-----+-----+-----+-----+-----+ - |'en' | | x | x | x | x | + |"en" | | x | x | x | x | +-----+-----+-----+-----+-----+-----+ - |'nl' | x | | x | x | x | + |"nl" | x | | x | x | x | +-----+-----+-----+-----+-----+-----+ - |'de' | x | x | | x | x | + |"de" | x | x | | x | x | +-----+-----+-----+-----+-----+-----+ - |'it' | x | x | x | | x | + |"it" | x | x | x | | x | +-----+-----+-----+-----+-----+-----+ - |'ro' | x | x | x | x | | + |"ro" | x | x | x | x | | +-----+-----+-----+-----+-----+-----+ @@ -154,81 +149,79 @@ def IWSLT2017(root='.data', split=('train', 'valid', 'test'), language_pair=('de >>> src_sentence, tgt_sentence = next(train_iter) """ + if not is_module_available("torchdata"): + raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`") - valid_set = 'dev2010' - test_set = 'tst2010' - - num_lines_set_identifier = { - 'train': 'train', - 'valid': valid_set, - 'test': test_set - } + valid_set = "dev2010" + test_set = "tst2010" if not isinstance(language_pair, list) and not isinstance(language_pair, tuple): raise ValueError("language_pair must be list or tuple but got {} instead".format(type(language_pair))) - assert (len(language_pair) == 2), 'language_pair must contain only 2 elements: src and tgt language respectively' + assert (len(language_pair) == 2), "language_pair must contain only 2 elements: src and tgt language respectively" src_language, tgt_language = language_pair[0], language_pair[1] - if src_language not in SUPPORTED_DATASETS['language_pair']: + if src_language not in SUPPORTED_DATASETS["language_pair"]: raise ValueError("src_language '{}' is not valid. Supported source languages are {}". - format(src_language, list(SUPPORTED_DATASETS['language_pair']))) + format(src_language, list(SUPPORTED_DATASETS["language_pair"]))) - if tgt_language not in SUPPORTED_DATASETS['language_pair'][src_language]: + if tgt_language not in SUPPORTED_DATASETS["language_pair"][src_language]: raise ValueError("tgt_language '{}' is not valid for give src_language '{}'. Supported target language are {}". 
- format(tgt_language, src_language, SUPPORTED_DATASETS['language_pair'][src_language])) - - train_filenames = ('train.{}-{}.{}'.format(src_language, tgt_language, src_language), - 'train.{}-{}.{}'.format(src_language, tgt_language, tgt_language)) - valid_filenames = ('IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, src_language), - 'IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, tgt_language)) - test_filenames = ('IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, src_language), - 'IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, tgt_language)) - - src_train, tgt_train = train_filenames - src_eval, tgt_eval = valid_filenames - src_test, tgt_test = test_filenames - - extracted_files = [] # list of paths to the extracted files - dataset_tar = download_from_url(SUPPORTED_DATASETS['URL'], root=root, hash_value=SUPPORTED_DATASETS['MD5'], path=os.path.join(root, SUPPORTED_DATASETS['_PATH']), hash_type='md5') - extracted_dataset_tar = extract_archive(dataset_tar) - # IWSLT dataset's url downloads a multilingual tgz. - # We need to take an extra step to pick out the specific language pair from it. - src_language = train_filenames[0].split(".")[-1] - tgt_language = train_filenames[1].split(".")[-1] - - iwslt_tar = os.path.join(root, SUPPORTED_DATASETS['_PATH'].split(".")[0], 'texts/DeEnItNlRo/DeEnItNlRo', 'DeEnItNlRo-DeEnItNlRo.tgz') - extracted_dataset_tar = extract_archive(iwslt_tar) - extracted_files.extend(extracted_dataset_tar) - - # Clean the xml and tag file in the archives - file_archives = [] - for fname in extracted_files: - if 'xml' in fname: - _clean_xml_file(fname) - file_archives.append(os.path.splitext(fname)[0]) - elif "tags" in fname: - _clean_tags_file(fname) - file_archives.append(fname.replace('.tags', '')) - else: - file_archives.append(fname) - - data_filenames = { - "train": _construct_filepaths(file_archives, src_train, tgt_train), - "valid": _construct_filepaths(file_archives, src_eval, tgt_eval), - "test": _construct_filepaths(file_archives, src_test, tgt_test) - } - for key in data_filenames: - if len(data_filenames[key]) == 0 or data_filenames[key] is None: - raise FileNotFoundError( - "Files are not found for data type {}".format(key)) - - src_data_iter = _read_text_iterator(data_filenames[split][0]) - tgt_data_iter = _read_text_iterator(data_filenames[split][1]) - - def _iter(src_data_iter, tgt_data_iter): - for item in zip(src_data_iter, tgt_data_iter): - yield item - - return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split][num_lines_set_identifier[split]][tuple(sorted(language_pair))], _iter(src_data_iter, tgt_data_iter)) + format(tgt_language, src_language, SUPPORTED_DATASETS["language_pair"][src_language])) + + file_path_by_lang_and_split, uncleaned_filenames_by_lang_and_split = _generate_iwslt_files_for_lang_and_split( + SUPPORTED_DATASETS["year"], src_language, tgt_language, valid_set, test_set + ) + + url_dp = IterableWrapper([URL]) + cache_compressed_dp = url_dp.on_disk_cache( + filepath_fn=lambda x: os.path.join(root, _PATH), + hash_dict={os.path.join(root, _PATH): MD5}, + hash_type="md5" + ) + cache_compressed_dp = GDriveReader(cache_compressed_dp) + cache_compressed_dp = cache_compressed_dp.end_caching(mode="wb", same_filepath_fn=True) + + # We create the whole filepath here, but only check for the literal filename in the filter + # because we're lazily 
extracting from the outer tarfile. Thus, + # /root/2017-01-trnmted/texts/.../DeEnItNlRo-DeEnItNlRo.tgz will never be in + # /root/2017-01-trnmted.tgz/texts/.../DeEnItNlRo-DeEnItNlRo.tgz + inner_iwslt_tar = os.path.join( + root, + os.path.splitext(_PATH)[0], + "texts/DeEnItNlRo/DeEnItNlRo/DeEnItNlRo-DeEnItNlRo.tgz" + ) + + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=lambda x: inner_iwslt_tar) + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter( + lambda x: os.path.basename(inner_iwslt_tar) in x[0]) + cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) + + src_filename = file_path_by_lang_and_split[src_language][split] + uncleaned_src_filename = uncleaned_filenames_by_lang_and_split[src_language][split] + + # We create the whole filepath here, but only check for the literal filename in the filter + # because we're lazily extracting from the outer tarfile. + full_src_filepath = os.path.join(root, "texts/DeEnItNlRo/DeEnItNlRo/DeEnItNlRo-DeEnItNlRo", src_filename) + + cache_inner_src_decompressed_dp = _filter_clean_cache(cache_decompressed_dp, full_src_filepath, + uncleaned_src_filename) + + tgt_filename = file_path_by_lang_and_split[tgt_language][split] + uncleaned_tgt_filename = uncleaned_filenames_by_lang_and_split[tgt_language][split] + + # We create the whole filepath here, but only check for the literal filename in the filter + # because we're lazily extracting from the outer tarfile. + full_tgt_filepath = os.path.join(root, "texts/DeEnItNlRo/DeEnItNlRo/DeEnItNlRo-DeEnItNlRo", tgt_filename) + + cache_inner_tgt_decompressed_dp = _filter_clean_cache(cache_decompressed_dp, full_tgt_filepath, + uncleaned_tgt_filename) + + tgt_data_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="r") + src_data_dp = FileOpener(cache_inner_src_decompressed_dp, mode="r") + + src_lines = src_data_dp.readlines(return_path=False, strip_newline=False) + tgt_lines = tgt_data_dp.readlines(return_path=False, strip_newline=False) + + return src_lines.zip(tgt_lines)
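
[Note, not part of the patch: with these changes, IWSLT2016 and IWSLT2017 both
return a datapipe of zipped (src, tgt) line pairs, so usage follows the existing
docstrings. The first call downloads the archive from the Google Drive URL, so
network access is required.]

    from torchtext.datasets import IWSLT2017

    train_iter, valid_iter, test_iter = IWSLT2017(language_pair=("de", "en"))
    src_sentence, tgt_sentence = next(iter(train_iter))
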