From 857d82d5ddfb833af649278d88be88ee2784eea6 Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Thu, 27 Jan 2022 19:12:10 -0500 Subject: [PATCH 01/15] migrate IWSLT2016 to datapipes. --- torchtext/datasets/iwslt2016.py | 116 +++++++++++++++++++------------- 1 file changed, 71 insertions(+), 45 deletions(-) diff --git a/torchtext/datasets/iwslt2016.py b/torchtext/datasets/iwslt2016.py index 6083502079..9518ee32ec 100644 --- a/torchtext/datasets/iwslt2016.py +++ b/torchtext/datasets/iwslt2016.py @@ -1,3 +1,9 @@ +from torchtext._internal.module_utils import is_module_available +from typing import Union, Tuple + +if is_module_available("torchdata"): + from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper, FileLister + import os from torchtext.utils import (download_from_url, extract_archive) from torchtext.data.datasets_utils import ( @@ -9,11 +15,14 @@ ) from torchtext.data.datasets_utils import _create_dataset_directory +URL = 'https://drive.google.com/uc?id=1l5y6Giag9aRPwGtuZHswh3w5v3qEz8D8' + +_PATH = '2016-01.tgz' + +MD5 = 'c393ed3fc2a1b0f004b3331043f615ae' SUPPORTED_DATASETS = { - 'URL': 'https://drive.google.com/uc?id=1l5y6Giag9aRPwGtuZHswh3w5v3qEz8D8', - '_PATH': '2016-01.tgz', - 'MD5': 'c393ed3fc2a1b0f004b3331043f615ae', + 'valid_test': ['dev2010', 'tst2010', 'tst2011', 'tst2012', 'tst2013', 'tst2014'], 'language_pair': { 'en': ['ar', 'de', 'fr', 'cs'], @@ -26,9 +35,6 @@ } -URL = SUPPORTED_DATASETS['URL'] -MD5 = SUPPORTED_DATASETS['MD5'] - NUM_LINES = { 'train': { 'train': { @@ -133,21 +139,28 @@ def _construct_filenames(filename, languages): return filenames +def _construct_filepath(path, src_filename, tgt_filename): + src_path = None + tgt_path = None + src_path = path if src_filename in path else src_path + tgt_path = path if tgt_filename in path else tgt_path + return src_path, tgt_path + + def _construct_filepaths(paths, src_filename, tgt_filename): src_path = None tgt_path = None for p in paths: - src_path = p if src_filename in p else src_path - tgt_path = p if tgt_filename in p else tgt_path - return (src_path, tgt_path) + src_path, tgt_path = _construct_filepath(p, src_filename, tgt_filename) + return src_path, tgt_path DATASET_NAME = "IWSLT2016" @_create_dataset_directory(dataset_name=DATASET_NAME) -@_wrap_split_argument(('train', 'valid', 'test')) -def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de', 'en'), valid_set='tst2013', test_set='tst2014'): +@_wrap_split_argument(("train", "valid", "test")) +def IWSLT2016(root = '.data', split=('train', 'valid', 'test'), language_pair=('de', 'en'), valid_set='tst2013', test_set='tst2014'): """IWSLT2016 dataset The available datasets include following: @@ -191,6 +204,9 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de 'test': test_set } + if not is_module_available("torchdata"): + raise ModuleNotFoundError("Package `torchdata` not found. 
Please install following instructions at `https://github.com/pytorch/data`") + if not isinstance(language_pair, list) and not isinstance(language_pair, tuple): raise ValueError("language_pair must be list or tuple but got {} instead".format(type(language_pair))) @@ -225,50 +241,60 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de src_eval, tgt_eval = valid_filenames src_test, tgt_test = test_filenames - extracted_files = [] # list of paths to the extracted files - dataset_tar = download_from_url(SUPPORTED_DATASETS['URL'], root=root, hash_value=SUPPORTED_DATASETS['MD5'], - path=os.path.join(root, SUPPORTED_DATASETS['_PATH']), hash_type='md5') - extracted_dataset_tar = extract_archive(dataset_tar) - # IWSLT dataset's url downloads a multilingual tgz. - # We need to take an extra step to pick out the specific language pair from it. + url_dp = IterableWrapper([URL]) + cache_compressed_dp = url_dp.on_disk_cache( + filepath_fn=lambda x: os.path.join(root, _PATH), + hash_dict={os.path.join(root, _PATH): MD5}, + hash_type="md5" + ) + cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) + cache_compressed_dp = FileOpener(cache_compressed_dp, mode="b") src_language = train_filenames[0].split(".")[-1] tgt_language = train_filenames[1].split(".")[-1] languages = "-".join([src_language, tgt_language]) - iwslt_tar = '{}/{}/texts/{}/{}/{}.tgz' - iwslt_tar = iwslt_tar.format( - root, SUPPORTED_DATASETS['_PATH'].split(".")[0], src_language, tgt_language, languages) - extracted_dataset_tar = extract_archive(iwslt_tar) - extracted_files.extend(extracted_dataset_tar) + iwslt_tar = os.path.join( + "texts", src_language, tgt_language, languages + ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache( + filepath_fn=lambda x: os.path.join(os.path.splitext(x[0])[0], iwslt_tar) + ) + cache_decompressed_dp = cache_decompressed_dp.read_from_tar() + cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb") - # Clean the xml and tag file in the archives - file_archives = [] - for fname in extracted_files: + def clean_files(fname): if 'xml' in fname: _clean_xml_file(fname) - file_archives.append(os.path.splitext(fname)[0]) + return os.path.splitext(fname)[0] elif "tags" in fname: _clean_tags_file(fname) - file_archives.append(fname.replace('.tags', '')) - else: - file_archives.append(fname) - - data_filenames = { - "train": _construct_filepaths(file_archives, src_train, tgt_train), - "valid": _construct_filepaths(file_archives, src_eval, tgt_eval), - "test": _construct_filepaths(file_archives, src_test, tgt_test) - } + return fname.replace('.tags', '') + return fname + + cache_decompressed_dp = cache_decompressed_dp.flatmap(FileLister) + + def get_filepath(f): + src_file, tgt_file = { + "train": _construct_filepath(f, src_train, tgt_train), + "valid": _construct_filepath(f, src_eval, tgt_eval), + "test": _construct_filepath(f, src_test, tgt_test) + }[split] + + return src_file, tgt_file + + cleaned_cache_decompressed_dp = cache_decompressed_dp.map(clean_files).map(get_filepath) - for key in data_filenames.keys(): - if len(data_filenames[key]) == 0 or data_filenames[key] is None: - raise FileNotFoundError( - "Files are not found for data type {}".format(key)) + # pairs of filenames are either both None or one of src/tgt is None. 
+ # filter out both None since they're not relevant + cleaned_cache_decompressed_dp = cleaned_cache_decompressed_dp.filter(lambda x: x != (None, None)) - src_data_iter = _read_text_iterator(data_filenames[split][0]) - tgt_data_iter = _read_text_iterator(data_filenames[split][1]) + # (None, tgt) => 1, (src, None) => 0 + tgt_data_dp, src_data_dp = cleaned_cache_decompressed_dp.demux(2, lambda x: x.index(None)) - def _iter(src_data_iter, tgt_data_iter): - for item in zip(src_data_iter, tgt_data_iter): - yield item + # Pull out the non-None element (i.e., filename) from the tuple + tgt_data_dp = FileOpener(tgt_data_dp.map(lambda x: x[1]), mode="r") + src_data_dp = FileOpener(src_data_dp.map(lambda x: x[0]), mode="r") - return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split][num_lines_set_identifier[split]][tuple(sorted(language_pair))], _iter(src_data_iter, tgt_data_iter)) + src_lines = src_data_dp.readlines(return_path=False, strip_newline=False) + tgt_lines = tgt_data_dp.readlines(return_path=False, strip_newline=False) + return src_lines.zip(tgt_lines) From b4507be3b2ee024e404f411d587242352ee8de0b Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Thu, 27 Jan 2022 19:58:21 -0500 Subject: [PATCH 02/15] try to fix style. --- torchtext/datasets/iwslt2016.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/torchtext/datasets/iwslt2016.py b/torchtext/datasets/iwslt2016.py index 9518ee32ec..c8daef3766 100644 --- a/torchtext/datasets/iwslt2016.py +++ b/torchtext/datasets/iwslt2016.py @@ -1,17 +1,13 @@ from torchtext._internal.module_utils import is_module_available -from typing import Union, Tuple if is_module_available("torchdata"): from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper, FileLister import os -from torchtext.utils import (download_from_url, extract_archive) from torchtext.data.datasets_utils import ( - _RawTextIterableDataset, _wrap_split_argument, _clean_xml_file, _clean_tags_file, - _read_text_iterator, ) from torchtext.data.datasets_utils import _create_dataset_directory @@ -160,7 +156,7 @@ def _construct_filepaths(paths, src_filename, tgt_filename): @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "valid", "test")) -def IWSLT2016(root = '.data', split=('train', 'valid', 'test'), language_pair=('de', 'en'), valid_set='tst2013', test_set='tst2014'): +def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de', 'en'), valid_set='tst2013', test_set='tst2014'): """IWSLT2016 dataset The available datasets include following: @@ -198,12 +194,6 @@ def IWSLT2016(root = '.data', split=('train', 'valid', 'test'), language_pair=(' >>> src_sentence, tgt_sentence = next(train_iter) """ - num_lines_set_identifier = { - 'train': 'train', - 'valid': valid_set, - 'test': test_set - } - if not is_module_available("torchdata"): raise ModuleNotFoundError("Package `torchdata` not found. 
Please install following instructions at `https://github.com/pytorch/data`") From b31bf25c4b8460c4c39ecd98c598096e0c502f4e Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Fri, 28 Jan 2022 08:34:54 -0500 Subject: [PATCH 03/15] clean up logic --- torchtext/datasets/iwslt2016.py | 70 +++++++++++---------------------- 1 file changed, 23 insertions(+), 47 deletions(-) diff --git a/torchtext/datasets/iwslt2016.py b/torchtext/datasets/iwslt2016.py index c8daef3766..c2bd7f7a03 100644 --- a/torchtext/datasets/iwslt2016.py +++ b/torchtext/datasets/iwslt2016.py @@ -18,7 +18,6 @@ MD5 = 'c393ed3fc2a1b0f004b3331043f615ae' SUPPORTED_DATASETS = { - 'valid_test': ['dev2010', 'tst2010', 'tst2011', 'tst2012', 'tst2013', 'tst2014'], 'language_pair': { 'en': ['ar', 'de', 'fr', 'cs'], @@ -28,7 +27,6 @@ 'cs': ['en'], }, 'year': 16, - } NUM_LINES = { @@ -127,30 +125,6 @@ ('cs', 'en'): ['tst2014'] } - -def _construct_filenames(filename, languages): - filenames = [] - for lang in languages: - filenames.append(filename + "." + lang) - return filenames - - -def _construct_filepath(path, src_filename, tgt_filename): - src_path = None - tgt_path = None - src_path = path if src_filename in path else src_path - tgt_path = path if tgt_filename in path else tgt_path - return src_path, tgt_path - - -def _construct_filepaths(paths, src_filename, tgt_filename): - src_path = None - tgt_path = None - for p in paths: - src_path, tgt_path = _construct_filepath(p, src_filename, tgt_filename) - return src_path, tgt_path - - DATASET_NAME = "IWSLT2016" @@ -247,6 +221,7 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de "texts", src_language, tgt_language, languages ) cache_decompressed_dp = cache_compressed_dp.on_disk_cache( + # Convert /path/to/downloaded/foo.tgz to /path/to/downloaded/foo/rest/of/path filepath_fn=lambda x: os.path.join(os.path.splitext(x[0])[0], iwslt_tar) ) cache_decompressed_dp = cache_decompressed_dp.read_from_tar() @@ -263,27 +238,28 @@ def clean_files(fname): cache_decompressed_dp = cache_decompressed_dp.flatmap(FileLister) - def get_filepath(f): - src_file, tgt_file = { - "train": _construct_filepath(f, src_train, tgt_train), - "valid": _construct_filepath(f, src_eval, tgt_eval), - "test": _construct_filepath(f, src_test, tgt_test) - }[split] - - return src_file, tgt_file - - cleaned_cache_decompressed_dp = cache_decompressed_dp.map(clean_files).map(get_filepath) - - # pairs of filenames are either both None or one of src/tgt is None. 
- # filter out both None since they're not relevant - cleaned_cache_decompressed_dp = cleaned_cache_decompressed_dp.filter(lambda x: x != (None, None)) - - # (None, tgt) => 1, (src, None) => 0 - tgt_data_dp, src_data_dp = cleaned_cache_decompressed_dp.demux(2, lambda x: x.index(None)) - - # Pull out the non-None element (i.e., filename) from the tuple - tgt_data_dp = FileOpener(tgt_data_dp.map(lambda x: x[1]), mode="r") - src_data_dp = FileOpener(src_data_dp.map(lambda x: x[0]), mode="r") + def get_filepath(split, lang): + return { + src_language: { + "train": src_train, + "valid": src_eval, + "test": src_test, + }, + tgt_language: { + "train": tgt_train, + "valid": tgt_eval, + "test": tgt_test, + } + }[lang][split] + + cleaned_cache_decompressed_dp = cache_decompressed_dp.map(clean_files) + + # Filters out irrelevant file given the filename templates filled with split and src/tgt codes + src_data_dp = cleaned_cache_decompressed_dp.filter(lambda x: get_filepath(split, src_language) in x) + tgt_data_dp = cleaned_cache_decompressed_dp.filter(lambda x: get_filepath(split, tgt_language) in x) + + tgt_data_dp = FileOpener(tgt_data_dp, mode="r") + src_data_dp = FileOpener(src_data_dp, mode="r") src_lines = src_data_dp.readlines(return_path=False, strip_newline=False) tgt_lines = tgt_data_dp.readlines(return_path=False, strip_newline=False) From c565d1bfdce8f15245b4636e8f92e45344c9877f Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Fri, 28 Jan 2022 15:26:15 -0500 Subject: [PATCH 04/15] add inner-tar cleaners. --- torchtext/data/datasets_utils.py | 40 ++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py index d6daf3a38f..47882f67fa 100644 --- a/torchtext/data/datasets_utils.py +++ b/torchtext/data/datasets_utils.py @@ -33,6 +33,18 @@ def _clean_xml_file(f_xml): fd_txt.write(e.text.strip() + '\n') +def _clean_inner_xml_file(f_xml, base, stream): + f_txt = os.path.basename(os.path.splitext(f_xml)[0]) + os.makedirs(base, exist_ok=True) + out_file = os.path.join(base, f_txt) + with codecs.open(out_file, mode='w', encoding='utf-8') as fd_txt: + root = ET.fromstring(stream.read().decode("utf-8"))[0] + for doc in root.findall('doc'): + for e in doc.findall('seg'): + fd_txt.write(e.text.strip() + '\n') + return os.path.join(base, f_txt) + + def _clean_tags_file(f_orig): xml_tags = [ ' Date: Fri, 28 Jan 2022 15:26:59 -0500 Subject: [PATCH 05/15] simplify logic for IWSLT2016. 
--- torchtext/datasets/iwslt2016.py | 65 +++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 28 deletions(-) diff --git a/torchtext/datasets/iwslt2016.py b/torchtext/datasets/iwslt2016.py index c2bd7f7a03..b3e57cdead 100644 --- a/torchtext/datasets/iwslt2016.py +++ b/torchtext/datasets/iwslt2016.py @@ -6,10 +6,11 @@ import os from torchtext.data.datasets_utils import ( _wrap_split_argument, - _clean_xml_file, - _clean_tags_file, + _clean_inner_xml_file, + _clean_inner_tags_file, + _create_dataset_directory, + _rewrite_text_file, ) -from torchtext.data.datasets_utils import _create_dataset_directory URL = 'https://drive.google.com/uc?id=1l5y6Giag9aRPwGtuZHswh3w5v3qEz8D8' @@ -211,32 +212,27 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5" ) - cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_compressed_dp = FileOpener(cache_compressed_dp, mode="b") - src_language = train_filenames[0].split(".")[-1] - tgt_language = train_filenames[1].split(".")[-1] + cache_compressed_dp = GDriveReader(cache_compressed_dp) + cache_compressed_dp = cache_compressed_dp.end_caching(mode="wb", same_filepath_fn=True) + languages = "-".join([src_language, tgt_language]) iwslt_tar = os.path.join( "texts", src_language, tgt_language, languages - ) + ) + ".tgz" + cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - # Convert /path/to/downloaded/foo.tgz to /path/to/downloaded/foo/rest/of/path - filepath_fn=lambda x: os.path.join(os.path.splitext(x[0])[0], iwslt_tar) + filepath_fn=lambda x: os.path.join(root, os.path.splitext(_PATH)[0], iwslt_tar) ) - cache_decompressed_dp = cache_decompressed_dp.read_from_tar() - cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb") + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(lambda x: iwslt_tar in x[0]) + cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) - def clean_files(fname): + def clean_files(fname, base, stream): if 'xml' in fname: - _clean_xml_file(fname) - return os.path.splitext(fname)[0] + return _clean_inner_xml_file(fname, base, stream) elif "tags" in fname: - _clean_tags_file(fname) - return fname.replace('.tags', '') - return fname - - cache_decompressed_dp = cache_decompressed_dp.flatmap(FileLister) + return _clean_inner_tags_file(fname, base, stream) + return _rewrite_text_file(fname, base, stream) def get_filepath(split, lang): return { @@ -252,15 +248,28 @@ def get_filepath(split, lang): } }[lang][split] - cleaned_cache_decompressed_dp = cache_decompressed_dp.map(clean_files) + cache_inner_src_decompressed_dp = cache_decompressed_dp.on_disk_cache( + filepath_fn=lambda x: os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, get_filepath(split, src_language)) + ) + cache_inner_src_decompressed_dp = FileOpener(cache_inner_src_decompressed_dp, mode="b").read_from_tar() + cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.map(lambda x: clean_files(x[0], os.path.splitext(os.path.dirname(os.path.dirname(x[0])))[0], x[1])) + cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.filter(lambda x: get_filepath(split, src_language) in x) + cache_inner_src_decompressed_dp = FileOpener(cache_inner_src_decompressed_dp, mode="b") + cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) + + 
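Note on the idiom used throughout this patch: the two-level on_disk_cache/end_caching pattern is the heart of the migration. The first level caches the downloaded archive, the second caches each file extracted from it, and the whole upstream pipeline is skipped whenever the target path already exists. A minimal standalone sketch, assuming only that torchdata provides HttpReader, on_disk_cache, and end_caching as used in this series (the URL and target path below are hypothetical):

    import os
    from torchdata.datapipes.iter import HttpReader, IterableWrapper

    def cached_download(url, target_path):
        # If target_path already exists on disk, everything below is skipped.
        cache_dp = IterableWrapper([url]).on_disk_cache(filepath_fn=lambda _: target_path)
        # Otherwise the bytes are streamed once and persisted by end_caching.
        return HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)

    dp = cached_download("https://example.com/corpus.tgz", os.path.join(".data", "corpus.tgz"))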
cache_inner_tgt_decompressed_dp = cache_decompressed_dp.on_disk_cache( + filepath_fn=lambda x: os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, get_filepath(split, tgt_language)) + ) + cache_inner_tgt_decompressed_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="b").read_from_tar() + cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.map(lambda x: clean_files(x[0], os.path.splitext(os.path.dirname(os.path.dirname(x[0])))[0], x[1])) + cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.filter(lambda x: get_filepath(split, tgt_language) in x) + cache_inner_tgt_decompressed_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="b") + cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) - # Filters out irrelevant file given the filename templates filled with split and src/tgt codes - src_data_dp = cleaned_cache_decompressed_dp.filter(lambda x: get_filepath(split, src_language) in x) - tgt_data_dp = cleaned_cache_decompressed_dp.filter(lambda x: get_filepath(split, tgt_language) in x) + tgt_data_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="r") + src_data_dp = FileOpener(cache_inner_src_decompressed_dp, mode="r") - tgt_data_dp = FileOpener(tgt_data_dp, mode="r") - src_data_dp = FileOpener(src_data_dp, mode="r") + src_lines = src_data_dp.readlines(return_path=False, strip_newline=False, decode=True) + tgt_lines = tgt_data_dp.readlines(return_path=False, strip_newline=False, decode=True) - src_lines = src_data_dp.readlines(return_path=False, strip_newline=False) - tgt_lines = tgt_data_dp.readlines(return_path=False, strip_newline=False) return src_lines.zip(tgt_lines) From 9381a88e95c401639380e13d5907b1109ac9b2cb Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Fri, 28 Jan 2022 15:32:22 -0500 Subject: [PATCH 06/15] add missing base path to output file so files are not written to cwd --- torchtext/data/datasets_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py index 47882f67fa..3c4702deaa 100644 --- a/torchtext/data/datasets_utils.py +++ b/torchtext/data/datasets_utils.py @@ -67,7 +67,7 @@ def _clean_inner_tags_file(f_orig, base, stream): ' Date: Fri, 28 Jan 2022 15:40:49 -0500 Subject: [PATCH 07/15] refactor for consistency. 
--- torchtext/data/datasets_utils.py | 6 ++++ torchtext/datasets/iwslt2016.py | 53 +++++++++++++++----------------- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py index 3c4702deaa..d27685b78c 100644 --- a/torchtext/data/datasets_utils.py +++ b/torchtext/data/datasets_utils.py @@ -89,6 +89,12 @@ def _rewrite_text_file(file, base, stream): f.write(line.decode("utf-8")) return out_file +def _clean_files(fname, base, stream): + if 'xml' in fname: + return _clean_inner_xml_file(fname, base, stream) + elif "tags" in fname: + return _clean_inner_tags_file(fname, base, stream) + return _rewrite_text_file(fname, base, stream) def _create_data_from_json(data_path): with open(data_path) as json_file: diff --git a/torchtext/datasets/iwslt2016.py b/torchtext/datasets/iwslt2016.py index b3e57cdead..ff3bfbf186 100644 --- a/torchtext/datasets/iwslt2016.py +++ b/torchtext/datasets/iwslt2016.py @@ -6,10 +6,8 @@ import os from torchtext.data.datasets_utils import ( _wrap_split_argument, - _clean_inner_xml_file, - _clean_inner_tags_file, + _clean_files, _create_dataset_directory, - _rewrite_text_file, ) URL = 'https://drive.google.com/uc?id=1l5y6Giag9aRPwGtuZHswh3w5v3qEz8D8' @@ -227,42 +225,39 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(lambda x: iwslt_tar in x[0]) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) - def clean_files(fname, base, stream): - if 'xml' in fname: - return _clean_inner_xml_file(fname, base, stream) - elif "tags" in fname: - return _clean_inner_tags_file(fname, base, stream) - return _rewrite_text_file(fname, base, stream) - - def get_filepath(split, lang): - return { - src_language: { - "train": src_train, - "valid": src_eval, - "test": src_test, - }, - tgt_language: { - "train": tgt_train, - "valid": tgt_eval, - "test": tgt_test, - } - }[lang][split] + file_path_by_lang_and_split = { + src_language: { + "train": src_train, + "valid": src_eval, + "test": src_test, + }, + tgt_language: { + "train": tgt_train, + "valid": tgt_eval, + "test": tgt_test, + } + } + + src_filepath = file_path_by_lang_and_split[src_language][split] cache_inner_src_decompressed_dp = cache_decompressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, get_filepath(split, src_language)) + filepath_fn=lambda x: os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, src_filepath) ) cache_inner_src_decompressed_dp = FileOpener(cache_inner_src_decompressed_dp, mode="b").read_from_tar() - cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.map(lambda x: clean_files(x[0], os.path.splitext(os.path.dirname(os.path.dirname(x[0])))[0], x[1])) - cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.filter(lambda x: get_filepath(split, src_language) in x) + cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.map(lambda x: _clean_files(x[0], os.path.splitext(os.path.dirname(os.path.dirname(x[0])))[0], x[1])) + cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.filter(lambda x: src_filepath in x) cache_inner_src_decompressed_dp = FileOpener(cache_inner_src_decompressed_dp, mode="b") cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) + tgt_filepath = 
file_path_by_lang_and_split[tgt_language][split] + + cache_inner_tgt_decompressed_dp = cache_decompressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, get_filepath(split, tgt_language)) + filepath_fn=lambda x: os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, tgt_filepath) ) cache_inner_tgt_decompressed_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="b").read_from_tar() - cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.map(lambda x: clean_files(x[0], os.path.splitext(os.path.dirname(os.path.dirname(x[0])))[0], x[1])) - cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.filter(lambda x: get_filepath(split, tgt_language) in x) + cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.map(lambda x: _clean_files(x[0], os.path.splitext(os.path.dirname(os.path.dirname(x[0])))[0], x[1])) + cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.filter(lambda x: tgt_filepath in x) cache_inner_tgt_decompressed_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="b") cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) From 77839ce0b4c99a0ee1335fef8229ec54c98a87dc Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Sat, 29 Jan 2022 13:09:43 -0500 Subject: [PATCH 08/15] address initial style reviews. --- torchtext/datasets/iwslt2016.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/torchtext/datasets/iwslt2016.py b/torchtext/datasets/iwslt2016.py index ff3bfbf186..4516161315 100644 --- a/torchtext/datasets/iwslt2016.py +++ b/torchtext/datasets/iwslt2016.py @@ -1,7 +1,7 @@ from torchtext._internal.module_utils import is_module_available if is_module_available("torchdata"): - from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper, FileLister + from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper import os from torchtext.data.datasets_utils import ( @@ -215,14 +215,14 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de languages = "-".join([src_language, tgt_language]) - iwslt_tar = os.path.join( - "texts", src_language, tgt_language, languages + inner_iwslt_tar = os.path.join( + root, os.path.splitext(_PATH)[0], "texts", src_language, tgt_language, languages ) + ".tgz" cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.splitext(_PATH)[0], iwslt_tar) + filepath_fn=lambda x: inner_iwslt_tar ) - cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(lambda x: iwslt_tar in x[0]) + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(lambda x: inner_iwslt_tar in x[0]) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) file_path_by_lang_and_split = { @@ -238,26 +238,25 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de } } - src_filepath = file_path_by_lang_and_split[src_language][split] + src_filename = file_path_by_lang_and_split[src_language][split] cache_inner_src_decompressed_dp = cache_decompressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, src_filepath) + filepath_fn=lambda x: os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, src_filename) ) 
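The os.path.splitext(os.path.dirname(os.path.dirname(x[0])))[0] expression that feeds _clean_files above is dense enough to deserve a worked example. A sketch, assuming read_from_tar yields paths of the form "outer archive path/member name" (the concrete path below is hypothetical):

    import os

    p = ".data/2016-01/texts/de/en/de-en.tgz/de-en/train.tags.de-en.de"
    step1 = os.path.dirname(p)         # ".data/2016-01/texts/de/en/de-en.tgz/de-en"
    step2 = os.path.dirname(step1)     # ".data/2016-01/texts/de/en/de-en.tgz"
    base = os.path.splitext(step2)[0]  # ".data/2016-01/texts/de/en/de-en"
    # `base` is the directory the cleaned file is written into.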
     cache_inner_src_decompressed_dp = FileOpener(cache_inner_src_decompressed_dp, mode="b").read_from_tar()
     cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.map(lambda x: _clean_files(x[0], os.path.splitext(os.path.dirname(os.path.dirname(x[0])))[0], x[1]))
-    cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.filter(lambda x: src_filepath in x)
+    cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.filter(lambda x: src_filename in x)
     cache_inner_src_decompressed_dp = FileOpener(cache_inner_src_decompressed_dp, mode="b")
     cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
 
-    tgt_filepath = file_path_by_lang_and_split[tgt_language][split]
-
+    tgt_filename = file_path_by_lang_and_split[tgt_language][split]
 
     cache_inner_tgt_decompressed_dp = cache_decompressed_dp.on_disk_cache(
-        filepath_fn=lambda x: os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, tgt_filepath)
+        filepath_fn=lambda x: os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, tgt_filename)
     )
     cache_inner_tgt_decompressed_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="b").read_from_tar()
     cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.map(lambda x: _clean_files(x[0], os.path.splitext(os.path.dirname(os.path.dirname(x[0])))[0], x[1]))
-    cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.filter(lambda x: tgt_filepath in x)
+    cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.filter(lambda x: tgt_filename in x)
     cache_inner_tgt_decompressed_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="b")
     cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

From a2dc38ef42480a8676031acfc11f295f98362baa Mon Sep 17 00:00:00 2001
From: Elijah Rippeth
Date: Sat, 29 Jan 2022 13:14:21 -0500
Subject: [PATCH 09/15] pull out common absolute paths for filtering src/tgt files.
--- torchtext/datasets/iwslt2016.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/torchtext/datasets/iwslt2016.py b/torchtext/datasets/iwslt2016.py index 4516161315..6070fe3dc9 100644 --- a/torchtext/datasets/iwslt2016.py +++ b/torchtext/datasets/iwslt2016.py @@ -239,24 +239,22 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de } src_filename = file_path_by_lang_and_split[src_language][split] + full_src_filepath = os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, src_filename) - cache_inner_src_decompressed_dp = cache_decompressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, src_filename) - ) + cache_inner_src_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=lambda x: full_src_filepath) cache_inner_src_decompressed_dp = FileOpener(cache_inner_src_decompressed_dp, mode="b").read_from_tar() cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.map(lambda x: _clean_files(x[0], os.path.splitext(os.path.dirname(os.path.dirname(x[0])))[0], x[1])) - cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.filter(lambda x: src_filename in x) + cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.filter(lambda x: full_src_filepath in x) cache_inner_src_decompressed_dp = FileOpener(cache_inner_src_decompressed_dp, mode="b") cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) tgt_filename = file_path_by_lang_and_split[tgt_language][split] + full_tgt_filepath = os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, tgt_filename) - cache_inner_tgt_decompressed_dp = cache_decompressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, tgt_filename) - ) + cache_inner_tgt_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=lambda x: full_tgt_filepath) cache_inner_tgt_decompressed_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="b").read_from_tar() cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.map(lambda x: _clean_files(x[0], os.path.splitext(os.path.dirname(os.path.dirname(x[0])))[0], x[1])) - cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.filter(lambda x: tgt_filename in x) + cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.filter(lambda x: full_tgt_filepath in x) cache_inner_tgt_decompressed_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="b") cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) From b668624c801cfe2ec1f8ff2423c6f5985d20ddd2 Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Sat, 29 Jan 2022 13:19:20 -0500 Subject: [PATCH 10/15] add docstring for new functions. --- torchtext/data/datasets_utils.py | 35 ++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py index d27685b78c..2a41298d50 100644 --- a/torchtext/data/datasets_utils.py +++ b/torchtext/data/datasets_utils.py @@ -34,6 +34,17 @@ def _clean_xml_file(f_xml): def _clean_inner_xml_file(f_xml, base, stream): + """Accepts an XML filename within a tarball and a stream of the byte contents + within that file and writes the cleaned contents to a new, untarred file + found in the provided base directory. 
+ + Args: + f_orig: the full path of the XML file in the archive + base: the directory to which the new file should be written + stream: the byte datapipe of the contents of f_orig + + Returns: the path to the newly-written file + """ f_txt = os.path.basename(os.path.splitext(f_xml)[0]) os.makedirs(base, exist_ok=True) out_file = os.path.join(base, f_txt) @@ -63,6 +74,17 @@ def _clean_tags_file(f_orig): def _clean_inner_tags_file(f_orig, base, stream): + """Accepts a tags filename within a tarball and a stream of the byte contents + within that file and writes the cleaned contents to a new, untarred file + found in the provided base directory. + + Args: + f_orig: the full path of the tags file in the archive + base: the directory to which the new file should be written + stream: the byte datapipe of the contents of f_orig + + Returns: the path to the newly-written file + """ xml_tags = [ ' Date: Sat, 29 Jan 2022 16:10:45 -0500 Subject: [PATCH 11/15] refactors some of the caching logic and cleaners --- torchtext/data/datasets_utils.py | 69 ++++++++++++++------------------ torchtext/datasets/iwslt2016.py | 59 ++++++++++++++++++++------- 2 files changed, 74 insertions(+), 54 deletions(-) diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py index 2a41298d50..14b9cd803a 100644 --- a/torchtext/data/datasets_utils.py +++ b/torchtext/data/datasets_utils.py @@ -33,27 +33,23 @@ def _clean_xml_file(f_xml): fd_txt.write(e.text.strip() + '\n') -def _clean_inner_xml_file(f_xml, base, stream): - """Accepts an XML filename within a tarball and a stream of the byte contents - within that file and writes the cleaned contents to a new, untarred file - found in the provided base directory. +def _clean_inner_xml_file(outfile, stream): + """Accepts an output filename and a stream of the byte contents of an XML file + within a tarball and writes the cleaned contents to a new, untarred file. Args: - f_orig: the full path of the XML file in the archive - base: the directory to which the new file should be written - stream: the byte datapipe of the contents of f_orig + outfile: the path to which the modified stream should be written + stream: the byte datapipe of the contents of the archived XML file Returns: the path to the newly-written file """ - f_txt = os.path.basename(os.path.splitext(f_xml)[0]) - os.makedirs(base, exist_ok=True) - out_file = os.path.join(base, f_txt) - with codecs.open(out_file, mode='w', encoding='utf-8') as fd_txt: + os.makedirs(os.path.dirname(outfile), exist_ok=True) + with codecs.open(outfile, mode='w', encoding='utf-8') as fd_txt: root = ET.fromstring(stream.read().decode("utf-8"))[0] for doc in root.findall('doc'): for e in doc.findall('seg'): fd_txt.write(e.text.strip() + '\n') - return os.path.join(base, f_txt) + return outfile def _clean_tags_file(f_orig): @@ -73,15 +69,13 @@ def _clean_tags_file(f_orig): fd_txt.write(line.strip() + '\n') -def _clean_inner_tags_file(f_orig, base, stream): - """Accepts a tags filename within a tarball and a stream of the byte contents - within that file and writes the cleaned contents to a new, untarred file - found in the provided base directory. +def _clean_inner_tags_file(outfile, stream): + """Accepts an output filename and a stream of the byte contents of a tags file + within a tarball and writes the cleaned contents to a new, untarred file. 
Args: - f_orig: the full path of the tags file in the archive - base: the directory to which the new file should be written - stream: the byte datapipe of the contents of f_orig + outfile: the path to which the modified stream should be written + stream: the byte datapipe of the contents of the archived tags file Returns: the path to the newly-written file """ @@ -89,9 +83,8 @@ def _clean_inner_tags_file(f_orig, base, stream): '>> from torchtext.datasets import IWSLT2016 >>> train_iter, valid_iter, test_iter = IWSLT2016() - >>> src_sentence, tgt_sentence = next(train_iter) + >>> src_sentence, tgt_sentence = next(iter(train_iter)) """ if not is_module_available("torchdata"): @@ -204,6 +204,17 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de src_eval, tgt_eval = valid_filenames src_test, tgt_test = test_filenames + uncleaned_train_filenames = ('train.tags.{}-{}.{}'.format(src_language, tgt_language, src_language), + 'train.tags.{}-{}.{}'.format(src_language, tgt_language, tgt_language)) + uncleaed_valid_filenames = ('IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, src_language), + 'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, tgt_language)) + uncleaned_test_filenames = ('IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, src_language), + 'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, tgt_language)) + + uncleaned_src_train, uncleaned_tgt_train = uncleaned_train_filenames + uncleaned_src_eval, uncleaned_tgt_eval = uncleaed_valid_filenames + uncleaned_src_test, uncleaned_tgt_test = uncleaned_test_filenames + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( filepath_fn=lambda x: os.path.join(root, _PATH), @@ -215,14 +226,13 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de languages = "-".join([src_language, tgt_language]) - inner_iwslt_tar = os.path.join( - root, os.path.splitext(_PATH)[0], "texts", src_language, tgt_language, languages - ) + ".tgz" + # We create the whole filepath here, but only check for the literal filename in the filter + # because we're lazily extracting from the outer tarfile. 
Thus, + # /root/2016-01/texts/.../src-tgt.tgz will never be in /root/2016-01.tgz/texts/.../src-tgt.tgz + inner_iwslt_tar = os.path.join(root, os.path.splitext(_PATH)[0], "texts", src_language, tgt_language, languages) + ".tgz" - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: inner_iwslt_tar - ) - cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(lambda x: inner_iwslt_tar in x[0]) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=lambda x: inner_iwslt_tar) + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(lambda x: os.path.basename(inner_iwslt_tar) in x[0]) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) file_path_by_lang_and_split = { @@ -238,28 +248,49 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de } } + uncleaned_filenames = { + src_language: { + "train": uncleaned_src_train, + "valid": uncleaned_src_eval, + "test": uncleaned_src_test, + }, + tgt_language: { + "train": uncleaned_tgt_train, + "valid": uncleaned_tgt_eval, + "test": uncleaned_tgt_test, + } + } + src_filename = file_path_by_lang_and_split[src_language][split] + uncleaned_src_filename = uncleaned_filenames[src_language][split] + + # We create the whole filepath here, but only check for the literal filename in the filter + # because we're lazily extracting from the outer tarfile. full_src_filepath = os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, src_filename) cache_inner_src_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=lambda x: full_src_filepath) cache_inner_src_decompressed_dp = FileOpener(cache_inner_src_decompressed_dp, mode="b").read_from_tar() - cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.map(lambda x: _clean_files(x[0], os.path.splitext(os.path.dirname(os.path.dirname(x[0])))[0], x[1])) - cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.filter(lambda x: full_src_filepath in x) + cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.filter(lambda x: os.path.basename(uncleaned_src_filename) in x[0]) + cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.map(lambda x: _clean_files(full_src_filepath, x[0], x[1])) cache_inner_src_decompressed_dp = FileOpener(cache_inner_src_decompressed_dp, mode="b") cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) tgt_filename = file_path_by_lang_and_split[tgt_language][split] + uncleaned_tgt_filename = uncleaned_filenames[tgt_language][split] + + # We create the whole filepath here, but only check for the literal filename in the filter + # because we're lazily extracting from the outer tarfile. 
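The comment above is worth making concrete: because extraction is lazy, the pathnames seen by the filter still contain the enclosing archive's name, so a comparison against the full eventual cache location can never match. A toy illustration with hypothetical paths:

    import os

    # Path of the inner tarball as seen while lazily reading the outer archive:
    member = ".data/2016-01.tgz/texts/de/en/de-en.tgz"
    # Path the inner tarball is cached to after extraction:
    target = ".data/2016-01/texts/de/en/de-en.tgz"
    assert target not in member                 # a full-path check always fails
    assert os.path.basename(target) in member   # the basename check succeeds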
full_tgt_filepath = os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, tgt_filename) cache_inner_tgt_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=lambda x: full_tgt_filepath) cache_inner_tgt_decompressed_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="b").read_from_tar() - cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.map(lambda x: _clean_files(x[0], os.path.splitext(os.path.dirname(os.path.dirname(x[0])))[0], x[1])) - cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.filter(lambda x: full_tgt_filepath in x) + cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.filter(lambda x: os.path.basename(uncleaned_tgt_filename) in x[0]) + cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.map(lambda x: _clean_files(full_tgt_filepath, x[0], x[1])) cache_inner_tgt_decompressed_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="b") cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) - tgt_data_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="r") - src_data_dp = FileOpener(cache_inner_src_decompressed_dp, mode="r") + tgt_data_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="b") + src_data_dp = FileOpener(cache_inner_src_decompressed_dp, mode="b") src_lines = src_data_dp.readlines(return_path=False, strip_newline=False, decode=True) tgt_lines = tgt_data_dp.readlines(return_path=False, strip_newline=False, decode=True) From 9b25db021331d0b17632525b945a7ab275d256e7 Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Sat, 29 Jan 2022 16:20:19 -0500 Subject: [PATCH 12/15] fix caching by returning the new StreamWrapper for newly-created cached file. --- torchtext/data/datasets_utils.py | 13 +++++++------ torchtext/datasets/iwslt2016.py | 2 -- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py index 14b9cd803a..650da8722c 100644 --- a/torchtext/data/datasets_utils.py +++ b/torchtext/data/datasets_utils.py @@ -11,6 +11,7 @@ unicode_csv_reader, ) from torch.utils.data import functional_datapipe, IterDataPipe +from torch.utils.data.datapipes.utils.common import StreamWrapper import codecs try: import defusedxml.ElementTree as ET @@ -41,7 +42,7 @@ def _clean_inner_xml_file(outfile, stream): outfile: the path to which the modified stream should be written stream: the byte datapipe of the contents of the archived XML file - Returns: the path to the newly-written file + Returns: the path to the newly-written file and the new StreamWrapper for appropriate caching """ os.makedirs(os.path.dirname(outfile), exist_ok=True) with codecs.open(outfile, mode='w', encoding='utf-8') as fd_txt: @@ -49,7 +50,7 @@ def _clean_inner_xml_file(outfile, stream): for doc in root.findall('doc'): for e in doc.findall('seg'): fd_txt.write(e.text.strip() + '\n') - return outfile + return outfile, StreamWrapper(open(outfile, "rb")) def _clean_tags_file(f_orig): @@ -77,7 +78,7 @@ def _clean_inner_tags_file(outfile, stream): outfile: the path to which the modified stream should be written stream: the byte datapipe of the contents of the archived tags file - Returns: the path to the newly-written file + Returns: the path to the newly-written file and the new StreamWrapper for appropriate caching """ xml_tags = [ ' Date: Sat, 29 Jan 2022 16:23:24 -0500 Subject: [PATCH 13/15] no need to re-decode when we can read in text mode. 
:-) --- torchtext/datasets/iwslt2016.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchtext/datasets/iwslt2016.py b/torchtext/datasets/iwslt2016.py index ae952b9cd9..60029bd06c 100644 --- a/torchtext/datasets/iwslt2016.py +++ b/torchtext/datasets/iwslt2016.py @@ -287,10 +287,10 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.map(lambda x: _clean_files(full_tgt_filepath, x[0], x[1])) cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) - tgt_data_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="b") - src_data_dp = FileOpener(cache_inner_src_decompressed_dp, mode="b") + tgt_data_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="r") + src_data_dp = FileOpener(cache_inner_src_decompressed_dp, mode="r") - src_lines = src_data_dp.readlines(return_path=False, strip_newline=False, decode=True) - tgt_lines = tgt_data_dp.readlines(return_path=False, strip_newline=False, decode=True) + src_lines = src_data_dp.readlines(return_path=False, strip_newline=False) + tgt_lines = tgt_data_dp.readlines(return_path=False, strip_newline=False) return src_lines.zip(tgt_lines) From 700ff18bc0292e608da88a2e1369acb85506ec56 Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Sat, 29 Jan 2022 18:01:18 -0500 Subject: [PATCH 14/15] DRY up the inner caching logic. --- torchtext/datasets/iwslt2016.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/torchtext/datasets/iwslt2016.py b/torchtext/datasets/iwslt2016.py index 60029bd06c..29239494ef 100644 --- a/torchtext/datasets/iwslt2016.py +++ b/torchtext/datasets/iwslt2016.py @@ -127,6 +127,17 @@ DATASET_NAME = "IWSLT2016" +def _filter_clean_cache(cache_decompressed_dp, full_filepath, uncleaned_filename): + cache_inner_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=lambda x: full_filepath) + cache_inner_decompressed_dp = FileOpener(cache_inner_decompressed_dp, mode="b").read_from_tar() + cache_inner_decompressed_dp = cache_inner_decompressed_dp.filter( + lambda x: os.path.basename(uncleaned_filename) in x[0]) + cache_inner_decompressed_dp = cache_inner_decompressed_dp.map( + lambda x: _clean_files(full_filepath, x[0], x[1])) + cache_inner_decompressed_dp = cache_inner_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) + return cache_inner_decompressed_dp + + @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "valid", "test")) def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de', 'en'), valid_set='tst2013', test_set='tst2014'): @@ -268,11 +279,7 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de # because we're lazily extracting from the outer tarfile. 
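A side note on the mode="r" change in patch 13: a text-mode FileOpener already yields decoded strings, so the decode=True flag on readlines becomes redundant. A minimal sketch of the difference, using a hypothetical cleaned file on disk:

    from torchdata.datapipes.iter import FileOpener, IterableWrapper

    files = IterableWrapper([".data/2016-01/texts/de/en/de-en/train.de"])
    # mode="b" would yield bytes that need an explicit decode step;
    # mode="r" yields str lines directly.
    lines = FileOpener(files, mode="r").readlines(return_path=False, strip_newline=False)
    for line in lines:
        print(line, end="")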
full_src_filepath = os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, src_filename) - cache_inner_src_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=lambda x: full_src_filepath) - cache_inner_src_decompressed_dp = FileOpener(cache_inner_src_decompressed_dp, mode="b").read_from_tar() - cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.filter(lambda x: os.path.basename(uncleaned_src_filename) in x[0]) - cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.map(lambda x: _clean_files(full_src_filepath, x[0], x[1])) - cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) + cache_inner_src_decompressed_dp = _filter_clean_cache(cache_decompressed_dp, full_src_filepath, uncleaned_src_filename) tgt_filename = file_path_by_lang_and_split[tgt_language][split] uncleaned_tgt_filename = uncleaned_filenames[tgt_language][split] @@ -281,11 +288,7 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de # because we're lazily extracting from the outer tarfile. full_tgt_filepath = os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, tgt_filename) - cache_inner_tgt_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=lambda x: full_tgt_filepath) - cache_inner_tgt_decompressed_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="b").read_from_tar() - cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.filter(lambda x: os.path.basename(uncleaned_tgt_filename) in x[0]) - cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.map(lambda x: _clean_files(full_tgt_filepath, x[0], x[1])) - cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) + cache_inner_tgt_decompressed_dp = _filter_clean_cache(cache_decompressed_dp, full_tgt_filepath, uncleaned_tgt_filename) tgt_data_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="r") src_data_dp = FileOpener(cache_inner_src_decompressed_dp, mode="r") From 60f863ad748a58ea1d19a2b0af8c1c1ee3f4945b Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Sat, 29 Jan 2022 23:20:10 -0500 Subject: [PATCH 15/15] fix docstring. --- torchtext/data/datasets_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py index 650da8722c..fa363150ac 100644 --- a/torchtext/data/datasets_utils.py +++ b/torchtext/data/datasets_utils.py @@ -36,11 +36,11 @@ def _clean_xml_file(f_xml): def _clean_inner_xml_file(outfile, stream): """Accepts an output filename and a stream of the byte contents of an XML file - within a tarball and writes the cleaned contents to a new, untarred file. + and writes the cleaned contents to a new file on disk. Args: outfile: the path to which the modified stream should be written - stream: the byte datapipe of the contents of the archived XML file + stream: the byte datapipe of the contents of the XML file Returns: the path to the newly-written file and the new StreamWrapper for appropriate caching """ @@ -72,11 +72,11 @@ def _clean_tags_file(f_orig): def _clean_inner_tags_file(outfile, stream): """Accepts an output filename and a stream of the byte contents of a tags file - within a tarball and writes the cleaned contents to a new, untarred file. + and writes the cleaned contents to a new file on disk. 
Args: outfile: the path to which the modified stream should be written - stream: the byte datapipe of the contents of the archived tags file + stream: the byte datapipe of the contents of the tags file Returns: the path to the newly-written file and the new StreamWrapper for appropriate caching """ @@ -98,11 +98,11 @@ def _clean_inner_tags_file(outfile, stream): def _rewrite_text_file(outfile, stream): """Accepts an output filename and a stream of the byte contents of a text file - within a tarball and writes the cleaned contents to a new, untarred file. + and writes the cleaned contents to a new file on disk. Args: outfile: the path to which the modified stream should be written - stream: the byte datapipe of the contents of the archived text file + stream: the byte datapipe of the contents of the text file Returns: the path to the newly-written file and the new StreamWrapper for appropriate caching """
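After patch 15, the migrated dataset returns a datapipe of (source, target) string pairs end to end. A usage sketch that mirrors the updated docstring example (torchdata must be installed; the first call downloads and caches everything under root):

    from torchtext.datasets import IWSLT2016

    train_dp = IWSLT2016(split="train", language_pair=("de", "en"))
    src_sentence, tgt_sentence = next(iter(train_dp))
    print(src_sentence, tgt_sentence)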