diff --git a/torchtext/datasets/ag_news.py b/torchtext/datasets/ag_news.py index 17e941681c..4b3533fa08 100644 --- a/torchtext/datasets/ag_news.py +++ b/torchtext/datasets/ag_news.py @@ -52,14 +52,20 @@ def AG_NEWS(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(_=None): + return os.path.join(root, split + ".csv") + + def _modify_res(t): + return int(t[0]), " ".join(t[1:]) + url_dp = IterableWrapper([URL[split]]) cache_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, split + ".csv"), - hash_dict={os.path.join(root, split + ".csv"): MD5[split]}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5[split]}, hash_type="md5", ) cache_dp = HttpReader(cache_dp) cache_dp = cache_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_dp, encoding="utf-8") - return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:]))) + return data_dp.parse_csv().map(fn=_modify_res) diff --git a/torchtext/datasets/amazonreviewfull.py b/torchtext/datasets/amazonreviewfull.py index d490ba7463..3a57db391a 100644 --- a/torchtext/datasets/amazonreviewfull.py +++ b/torchtext/datasets/amazonreviewfull.py @@ -58,21 +58,29 @@ def AmazonReviewFull(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(_=None): + return os.path.join(root, _PATH) + + def _extracted_filepath_fn(_=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + + def _filter_fn(x): + return _EXTRACTED_FILES[split] in x[0] + + def _modify_res(t): + return int(t[0]), " ".join(t[1:]) + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _PATH), - hash_dict={os.path.join(root, _PATH): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) - ) - cache_decompressed_dp = ( - FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(lambda x: _EXTRACTED_FILES[split] in x[0]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(_filter_fn) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") - return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:]))) + return data_dp.parse_csv().map(fn=_modify_res) diff --git a/torchtext/datasets/amazonreviewpolarity.py b/torchtext/datasets/amazonreviewpolarity.py index 24b7278743..4760a93a19 100644 --- a/torchtext/datasets/amazonreviewpolarity.py +++ b/torchtext/datasets/amazonreviewpolarity.py @@ -55,21 +55,29 @@ def AmazonReviewPolarity(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. 
Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(_=None): + return os.path.join(root, _PATH) + + def _extracted_filepath_fn(_=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + + def _filter_fn(x): + return _EXTRACTED_FILES[split] in x[0] + + def _modify_res(t): + return int(t[0]), " ".join(t[1:]) + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _PATH), - hash_dict={os.path.join(root, _PATH): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) - ) - cache_decompressed_dp = ( - FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(lambda x: _EXTRACTED_FILES[split] in x[0]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(_filter_fn) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") - return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:]))) + return data_dp.parse_csv().map(fn=_modify_res) diff --git a/torchtext/datasets/cc100.py b/torchtext/datasets/cc100.py index dc4d3af338..56d31d0e4f 100644 --- a/torchtext/datasets/cc100.py +++ b/torchtext/datasets/cc100.py @@ -151,18 +151,25 @@ def CC100(root: str, language_code: str = "en"): if language_code not in VALID_CODES: raise ValueError(f"Invalid language code {language_code}") + def _filepath_fn(_=None): + return os.path.join(root, os.path.basename(url)) + + def _decompressed_filepath_fn(x): + return os.path.join(root, os.path.basename(x).rstrip(".xz")) + + def _modify_res(x): + return language_code, x + url = URL % language_code url_dp = IterableWrapper([url]) - cache_compressed_dp = url_dp.on_disk_cache(filepath_fn=lambda x: os.path.join(root, os.path.basename(url))) + cache_compressed_dp = url_dp.on_disk_cache(filepath_fn=_filepath_fn) cache_compressed_dp = HttpReader(cache_compressed_dp) cache_compressed_dp = cache_compressed_dp.end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.basename(x).rstrip(".xz")) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_decompressed_filepath_fn) cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_xz() cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb") data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8").readlines(return_path=False) - return data_dp.map(lambda x: (language_code, x)) + return data_dp.map(_modify_res) diff --git a/torchtext/datasets/conll2000chunking.py b/torchtext/datasets/conll2000chunking.py index c763f7fb86..ce4a8737fc 100644 --- a/torchtext/datasets/conll2000chunking.py +++ b/torchtext/datasets/conll2000chunking.py @@ -55,20 +55,24 @@ def CoNLL2000Chunking(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. 
Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(_=None): + return os.path.join(root, os.path.basename(URL[split])) + + def _extracted_filepath_fn(_=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + url_dp = IterableWrapper([URL[split]]) # Cache and check HTTP response cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.basename(URL[split])), - hash_dict={os.path.join(root, os.path.basename(URL[split])): MD5[split]}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5[split]}, hash_type="md5", ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) # Cache and check the gzip extraction for relevant split - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").extract(file_type="gzip") cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) diff --git a/torchtext/datasets/dbpedia.py b/torchtext/datasets/dbpedia.py index 6ea4b64953..1265badd4d 100644 --- a/torchtext/datasets/dbpedia.py +++ b/torchtext/datasets/dbpedia.py @@ -54,21 +54,29 @@ def DBpedia(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(_=None): + return os.path.join(root, _PATH) + + def _extracted_filepath_fn(_=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + + def _filter_fn(x): + return _EXTRACTED_FILES[split] in x[0] + + def _modify_res(t): + return int(t[0]), " ".join(t[1:]) + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _PATH), - hash_dict={os.path.join(root, _PATH): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) - ) - cache_decompressed_dp = ( - FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(lambda x: _EXTRACTED_FILES[split] in x[0]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(_filter_fn) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") - return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:]))) + return data_dp.parse_csv().map(fn=_modify_res) diff --git a/torchtext/datasets/enwik9.py b/torchtext/datasets/enwik9.py index 0940b77760..58b8357676 100644 --- a/torchtext/datasets/enwik9.py +++ b/torchtext/datasets/enwik9.py @@ -37,17 +37,21 @@ def EnWik9(root: str): "Package `torchdata` not found. 
Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(_=None): + return os.path.join(root, _PATH) + + def _extracted_filepath_fn(_=None): + return os.path.join(root, os.path.splitext(_PATH)[0]) + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _PATH), - hash_dict={os.path.join(root, _PATH): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.splitext(_PATH)[0]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_zip() cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) diff --git a/torchtext/datasets/imdb.py b/torchtext/datasets/imdb.py index 80de7134ed..2f0cc64484 100644 --- a/torchtext/datasets/imdb.py +++ b/torchtext/datasets/imdb.py @@ -47,20 +47,39 @@ def IMDB(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(_=None): + return os.path.join(root, _PATH) + + def _decompressed_filepath_fn(_=None): + return [os.path.join(root, decompressed_folder, split, label) for label in labels] + + def _filter_fn(t): + return filter_imdb_data(split, t[0]) + + def _path_map_fn(t): + return Path(t[0]).parts[-2], t[1] + + def _encode_map_fn(x): + return x[0], x[1].encode() + + def _cache_filepath_fn(x): + return os.path.join(root, decompressed_folder, split, x) + + def _modify_res(t): + return Path(t[0]).parts[-1], t[1] + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _PATH), - hash_dict={os.path.join(root, _PATH): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) labels = {"neg", "pos"} decompressed_folder = "aclImdb_v1" - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: [os.path.join(root, decompressed_folder, split, label) for label in labels] - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_decompressed_filepath_fn) cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b") cache_decompressed_dp = cache_decompressed_dp.load_from_tar() @@ -69,17 +88,15 @@ def filter_imdb_data(key, fname): *_, split, label, file = Path(fname).parts return key == split and label in labels - cache_decompressed_dp = cache_decompressed_dp.filter(lambda t: filter_imdb_data(split, t[0])) + cache_decompressed_dp = cache_decompressed_dp.filter(_filter_fn) # eg. 
"aclImdb/train/neg/12416_3.txt" -> "neg" - cache_decompressed_dp = cache_decompressed_dp.map(lambda t: (Path(t[0]).parts[-2], t[1])) + cache_decompressed_dp = cache_decompressed_dp.map(_path_map_fn) cache_decompressed_dp = cache_decompressed_dp.readlines(decode=True) cache_decompressed_dp = cache_decompressed_dp.lines_to_paragraphs() # group by label in cache file - cache_decompressed_dp = cache_decompressed_dp.map(lambda x: (x[0], x[1].encode())) - cache_decompressed_dp = cache_decompressed_dp.end_caching( - mode="wb", filepath_fn=lambda x: os.path.join(root, decompressed_folder, split, x), skip_read=True - ) + cache_decompressed_dp = cache_decompressed_dp.map(_encode_map_fn) + cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", filepath_fn=_cache_filepath_fn, skip_read=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") # get label from cache file, eg. "aclImdb_v1/train/neg" -> "neg" - return data_dp.readlines().map(lambda t: (Path(t[0]).parts[-1], t[1])) + return data_dp.readlines().map(_modify_res) diff --git a/torchtext/datasets/iwslt2016.py b/torchtext/datasets/iwslt2016.py index 9d3443a5bb..63b5e4a6db 100644 --- a/torchtext/datasets/iwslt2016.py +++ b/torchtext/datasets/iwslt2016.py @@ -124,12 +124,19 @@ # TODO: migrate this to dataset_utils.py once torchdata is a hard dependency to # avoid additional conditional imports. def _filter_clean_cache(cache_decompressed_dp, full_filepath, uncleaned_filename): - cache_inner_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=lambda x: full_filepath) + def _return_full_filepath(_=None): + return full_filepath + + def _filter_fn(x): + return os.path.basename(uncleaned_filename) in x[0] + + def _clean_files_wrapper(x): + return _clean_files(full_filepath, x[0], x[1]) + + cache_inner_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=_return_full_filepath) cache_inner_decompressed_dp = cache_inner_decompressed_dp.open_files(mode="b").load_from_tar() - cache_inner_decompressed_dp = cache_inner_decompressed_dp.filter( - lambda x: os.path.basename(uncleaned_filename) in x[0] - ) - cache_inner_decompressed_dp = cache_inner_decompressed_dp.map(lambda x: _clean_files(full_filepath, x[0], x[1])) + cache_inner_decompressed_dp = cache_inner_decompressed_dp.filter(_filter_fn) + cache_inner_decompressed_dp = cache_inner_decompressed_dp.map(_clean_files_wrapper) cache_inner_decompressed_dp = cache_inner_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) return cache_inner_decompressed_dp @@ -234,10 +241,13 @@ def IWSLT2016( SUPPORTED_DATASETS["year"], src_language, tgt_language, valid_set, test_set ) + def _filepath_fn(_=None): + return os.path.join(root, _PATH) + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _PATH), - hash_dict={os.path.join(root, _PATH): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp) @@ -260,9 +270,15 @@ def IWSLT2016( + ".tgz" ) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=lambda x: inner_iwslt_tar) + def _inner_iwslt_tar_filepath_fn(_=None): + return inner_iwslt_tar + + def _filter_fn(x): + return os.path.basename(inner_iwslt_tar) in x[0] + + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_inner_iwslt_tar_filepath_fn) cache_decompressed_dp = cache_decompressed_dp.open_files(mode="b").load_from_tar() - cache_decompressed_dp = 
cache_decompressed_dp.filter(lambda x: os.path.basename(inner_iwslt_tar) in x[0]) + cache_decompressed_dp = cache_decompressed_dp.filter(_filter_fn) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) cache_decompressed_dp_1, cache_decompressed_dp_2 = cache_decompressed_dp.fork(num_instances=2) diff --git a/torchtext/datasets/iwslt2017.py b/torchtext/datasets/iwslt2017.py index 0fb865d4e0..a585a5c604 100644 --- a/torchtext/datasets/iwslt2017.py +++ b/torchtext/datasets/iwslt2017.py @@ -103,12 +103,19 @@ # TODO: migrate this to dataset_utils.py once torchdata is a hard dependency to # avoid additional conditional imports. def _filter_clean_cache(cache_decompressed_dp, full_filepath, uncleaned_filename): - cache_inner_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=lambda x: full_filepath) + def _return_full_filepath(_=None): + return full_filepath + + def _filter_fn(x): + return os.path.basename(uncleaned_filename) in x[0] + + def _clean_files_wrapper(x): + return _clean_files(full_filepath, x[0], x[1]) + + cache_inner_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=_return_full_filepath) cache_inner_decompressed_dp = cache_inner_decompressed_dp.open_files(mode="b").load_from_tar() - cache_inner_decompressed_dp = cache_inner_decompressed_dp.filter( - lambda x: os.path.basename(uncleaned_filename) in x[0] - ) - cache_inner_decompressed_dp = cache_inner_decompressed_dp.map(lambda x: _clean_files(full_filepath, x[0], x[1])) + cache_inner_decompressed_dp = cache_inner_decompressed_dp.filter(_filter_fn) + cache_inner_decompressed_dp = cache_inner_decompressed_dp.map(_clean_files_wrapper) cache_inner_decompressed_dp = cache_inner_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) return cache_inner_decompressed_dp @@ -188,10 +195,13 @@ def IWSLT2017(root=".data", split=("train", "valid", "test"), language_pair=("de SUPPORTED_DATASETS["year"], src_language, tgt_language, valid_set, test_set ) + def _filepath_fn(_=None): + return os.path.join(root, _PATH) + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _PATH), - hash_dict={os.path.join(root, _PATH): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp) @@ -207,7 +217,10 @@ def IWSLT2017(root=".data", split=("train", "valid", "test"), language_pair=("de "texts/DeEnItNlRo/DeEnItNlRo/DeEnItNlRo-DeEnItNlRo.tgz", ) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=lambda x: inner_iwslt_tar) + def _inner_iwslt_tar_filepath_fn(_=None): + return inner_iwslt_tar + + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_inner_iwslt_tar_filepath_fn) cache_decompressed_dp = cache_decompressed_dp.open_files(mode="b").load_from_tar() cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) cache_decompressed_dp_1, cache_decompressed_dp_2 = cache_decompressed_dp.fork(num_instances=2) diff --git a/torchtext/datasets/multi30k.py b/torchtext/datasets/multi30k.py index 26390379ba..6095316412 100644 --- a/torchtext/datasets/multi30k.py +++ b/torchtext/datasets/multi30k.py @@ -1,4 +1,5 @@ import os +from functools import partial from typing import Union, Tuple from torchtext._internal.module_utils import is_module_available @@ -70,34 +71,35 @@ def Multi30k(root: str, split: Union[Tuple[str], str], language_pair: Tuple[str] "Package `torchdata` not 
found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(_=None): + return os.path.join(root, os.path.basename(URL[split])) + url_dp = IterableWrapper([URL[split]]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.basename(URL[split])), - hash_dict={os.path.join(root, os.path.basename(URL[split])): MD5[split]}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5[split]}, hash_type="sha256", ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) cache_compressed_dp_1, cache_compressed_dp_2 = cache_compressed_dp.fork(num_instances=2) - src_cache_decompressed_dp = cache_compressed_dp_1.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, f"{_PREFIX[split]}.{language_pair[0]}") - ) + def _decompressed_filepath_fn(i, _): + return os.path.join(root, f"{_PREFIX[split]}.{language_pair[i]}") + + def _filter_fn(i, x): + return f"{_PREFIX[split]}.{language_pair[i]}" in x[0] + + src_cache_decompressed_dp = cache_compressed_dp_1.on_disk_cache(filepath_fn=partial(_decompressed_filepath_fn, 0)) src_cache_decompressed_dp = ( - FileOpener(src_cache_decompressed_dp, mode="b") - .load_from_tar() - .filter(lambda x: f"{_PREFIX[split]}.{language_pair[0]}" in x[0]) + FileOpener(src_cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, 0)) ) src_cache_decompressed_dp = src_cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) - tgt_cache_decompressed_dp = cache_compressed_dp_2.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, f"{_PREFIX[split]}.{language_pair[1]}") - ) + tgt_cache_decompressed_dp = cache_compressed_dp_2.on_disk_cache(filepath_fn=partial(_decompressed_filepath_fn, 1)) tgt_cache_decompressed_dp = ( - FileOpener(tgt_cache_decompressed_dp, mode="b") - .load_from_tar() - .filter(lambda x: f"{_PREFIX[split]}.{language_pair[1]}" in x[0]) + FileOpener(tgt_cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, 1)) ) tgt_cache_decompressed_dp = tgt_cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) diff --git a/torchtext/datasets/penntreebank.py b/torchtext/datasets/penntreebank.py index a9d7099792..2ba26bfc01 100644 --- a/torchtext/datasets/penntreebank.py +++ b/torchtext/datasets/penntreebank.py @@ -56,14 +56,20 @@ def PennTreebank(root, split: Union[Tuple[str], str]): "Package `torchdata` not found. 
Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(_=None): + return os.path.join(root, os.path.basename(URL[split])) + + def _modify_res(t): + return t.strip() + url_dp = IterableWrapper([URL[split]]) cache_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.basename(x)), - hash_dict={os.path.join(root, os.path.basename(URL[split])): MD5[split]}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5[split]}, hash_type="md5", ) cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_dp, encoding="utf-8") # remove single leading and trailing space from the dataset - return data_dp.readlines(return_path=False).map(lambda t: t.strip()) + return data_dp.readlines(return_path=False).map(_modify_res) diff --git a/torchtext/datasets/sogounews.py b/torchtext/datasets/sogounews.py index 8f023971ec..a93bec0f1e 100644 --- a/torchtext/datasets/sogounews.py +++ b/torchtext/datasets/sogounews.py @@ -58,21 +58,29 @@ def SogouNews(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(_=None): + return os.path.join(root, _PATH) + + def _extracted_filepath_fn(_=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + + def _filter_fn(x): + return _EXTRACTED_FILES[split] in x[0] + + def _modify_res(t): + return int(t[0]), " ".join(t[1:]) + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _PATH), - hash_dict={os.path.join(root, _PATH): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) - ) - cache_decompressed_dp = ( - FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(lambda x: _EXTRACTED_FILES[split] in x[0]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(_filter_fn) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") - return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:]))) + return data_dp.parse_csv().map(fn=_modify_res) diff --git a/torchtext/datasets/squad1.py b/torchtext/datasets/squad1.py index 96f8f5626c..5393355002 100644 --- a/torchtext/datasets/squad1.py +++ b/torchtext/datasets/squad1.py @@ -53,11 +53,14 @@ def SQuAD1(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. 
Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(_=None): + return os.path.join(root, os.path.basename(URL[split])) + url_dp = IterableWrapper([URL[split]]) # cache data on-disk with sanity check cache_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.basename(x)), - hash_dict={os.path.join(root, os.path.basename(URL[split])): MD5[split]}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5[split]}, hash_type="md5", ) cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True) diff --git a/torchtext/datasets/squad2.py b/torchtext/datasets/squad2.py index 25247feac7..7be3d064bd 100644 --- a/torchtext/datasets/squad2.py +++ b/torchtext/datasets/squad2.py @@ -54,11 +54,14 @@ def SQuAD2(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(_=None): + return os.path.join(root, os.path.basename(URL[split])) + url_dp = IterableWrapper([URL[split]]) # cache data on-disk with sanity check cache_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.basename(x)), - hash_dict={os.path.join(root, os.path.basename(URL[split])): MD5[split]}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5[split]}, hash_type="md5", ) cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True) diff --git a/torchtext/datasets/sst2.py b/torchtext/datasets/sst2.py index 23409f0a20..1d357ea3d6 100644 --- a/torchtext/datasets/sst2.py +++ b/torchtext/datasets/sst2.py @@ -61,26 +61,37 @@ def SST2(root, split): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(_=None): + return os.path.join(root, os.path.basename(URL)) + + def _extracted_filepath_fn(_=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + + def _filter_fn(x): + return _EXTRACTED_FILES[split] in x[0] + + def _modify_test_res(t): + return (t[1].strip(),) + + def _modify_res(t): + return t[0].strip(), int(t[1]) + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.basename(x)), - hash_dict={os.path.join(root, os.path.basename(URL)): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) - ) - cache_decompressed_dp = ( - FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(lambda x: _EXTRACTED_FILES[split] in x[0]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(_filter_fn) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") # test split for SST2 doesn't have labels if split == "test": - parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").map(lambda t: (t[1].strip(),)) + parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").map(_modify_test_res) else: - parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").map(lambda t: (t[0].strip(), int(t[1]))) + parsed_data = data_dp.parse_csv(skip_lines=1, 
delimiter="\t").map(_modify_res) return parsed_data diff --git a/torchtext/datasets/udpos.py b/torchtext/datasets/udpos.py index b9f0328690..2ec95bcece 100644 --- a/torchtext/datasets/udpos.py +++ b/torchtext/datasets/udpos.py @@ -49,20 +49,25 @@ def UDPOS(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(_=None): + return os.path.join(root, os.path.basename(URL)) + + def _extracted_filepath_fn(_=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + + def _filter_fn(x): + return _EXTRACTED_FILES[split] in x[0] + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.basename(URL)), - hash_dict={os.path.join(root, os.path.basename(URL)): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) - ) - cache_decompressed_dp = ( - FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(lambda x: _EXTRACTED_FILES[split] in x[0]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(_filter_fn) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") diff --git a/torchtext/datasets/wikitext103.py b/torchtext/datasets/wikitext103.py index 6f437fa7c6..d791572ec9 100644 --- a/torchtext/datasets/wikitext103.py +++ b/torchtext/datasets/wikitext103.py @@ -53,21 +53,27 @@ def WikiText103(root: str, split: Union[Tuple[str], str]): raise ModuleNotFoundError( "Package `torchdata` not found. 
Please install following instructions at `https://github.com/pytorch/data`" ) + + def _filepath_fn(_=None): + return os.path.join(root, os.path.basename(URL)) + + def _extracted_filepath_fn(_=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + + def _filter_fn(x): + return _EXTRACTED_FILES[split] in x[0] + url_dp = IterableWrapper([URL]) # cache data on-disk cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.basename(x)), - hash_dict={os.path.join(root, os.path.basename(URL)): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) # Extract zip and filter the appropriate split file - cache_decompressed_dp = ( - FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(lambda x: _EXTRACTED_FILES[split] in x[0]) - ) + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(_filter_fn) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") return data_dp.readlines(strip_newline=False, return_path=False) diff --git a/torchtext/datasets/wikitext2.py b/torchtext/datasets/wikitext2.py index edbd8faac2..ccd200e3c9 100644 --- a/torchtext/datasets/wikitext2.py +++ b/torchtext/datasets/wikitext2.py @@ -53,21 +53,27 @@ def WikiText2(root: str, split: Union[Tuple[str], str]): raise ModuleNotFoundError( "Package `torchdata` not found. 
Please install following instructions at `https://github.com/pytorch/data`" ) + + def _filepath_fn(_=None): + return os.path.join(root, os.path.basename(URL)) + + def _extracted_filepath_fn(_=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + + def _filter_fn(x): + return _EXTRACTED_FILES[split] in x[0] + url_dp = IterableWrapper([URL]) # cache data on-disk cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.basename(x)), - hash_dict={os.path.join(root, os.path.basename(URL)): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) # Extract zip and filter the appropriate split file - cache_decompressed_dp = ( - FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(lambda x: _EXTRACTED_FILES[split] in x[0]) - ) + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(_filter_fn) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") return data_dp.readlines(strip_newline=False, return_path=False) diff --git a/torchtext/datasets/yahooanswers.py b/torchtext/datasets/yahooanswers.py index dae6a86f2f..16dd47353b 100644 --- a/torchtext/datasets/yahooanswers.py +++ b/torchtext/datasets/yahooanswers.py @@ -54,23 +54,33 @@ def YahooAnswers(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. 
Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(_=None): + return os.path.join(root, _PATH) + + def _extracted_filepath_fn(_=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + + def _filter_fn(x): + return _EXTRACTED_FILES[split] in x[0] + + def _modify_res(t): + return int(t[0]), " ".join(t[1:]) + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _PATH), - hash_dict={os.path.join(root, _PATH): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b") cache_decompressed_dp = cache_decompressed_dp.load_from_tar() - cache_decompressed_dp = cache_decompressed_dp.filter(lambda x: _EXTRACTED_FILES[split] in x[0]) + cache_decompressed_dp = cache_decompressed_dp.filter(_filter_fn) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") - return data_dp.parse_csv().map(lambda t: (int(t[0]), " ".join(t[1:]))) + return data_dp.parse_csv().map(_modify_res) diff --git a/torchtext/datasets/yelpreviewfull.py b/torchtext/datasets/yelpreviewfull.py index 54706e6222..1f56c75c0e 100644 --- a/torchtext/datasets/yelpreviewfull.py +++ b/torchtext/datasets/yelpreviewfull.py @@ -53,22 +53,32 @@ def YelpReviewFull(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. 
Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(_=None): + return os.path.join(root, _PATH) + + def _extracted_filepath_fn(_=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + + def _filter_fn(x): + return _EXTRACTED_FILES[split] in x[0] + + def _modify_res(t): + return int(t[0]), " ".join(t[1:]) + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _PATH), - hash_dict={os.path.join(root, _PATH): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp) cache_compressed_dp = cache_compressed_dp.end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b") - cache_decompressed_dp = cache_decompressed_dp.load_from_tar().filter(lambda x: _EXTRACTED_FILES[split] in x[0]) + cache_decompressed_dp = cache_decompressed_dp.load_from_tar().filter(_filter_fn) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") - return data_dp.parse_csv().map(lambda t: (int(t[0]), " ".join(t[1:]))) + return data_dp.parse_csv().map(_modify_res) diff --git a/torchtext/datasets/yelpreviewpolarity.py b/torchtext/datasets/yelpreviewpolarity.py index 9efc1084b2..40a1508c8a 100644 --- a/torchtext/datasets/yelpreviewpolarity.py +++ b/torchtext/datasets/yelpreviewpolarity.py @@ -53,23 +53,33 @@ def YelpReviewPolarity(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(_=None): + return os.path.join(root, _PATH) + + def _extracted_filepath_fn(_=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + + def _filter_fn(x): + return _EXTRACTED_FILES[split] in x[0] + + def _modify_res(t): + return int(t[0]), " ".join(t[1:]) + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _PATH), - hash_dict={os.path.join(root, _PATH): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b") cache_decompressed_dp = cache_decompressed_dp.load_from_tar() - cache_decompressed_dp = cache_decompressed_dp.filter(lambda x: _EXTRACTED_FILES[split] in x[0]) + cache_decompressed_dp = cache_decompressed_dp.filter(_filter_fn) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") - return data_dp.parse_csv().map(lambda t: (int(t[0]), " ".join(t[1:]))) + return data_dp.parse_csv().map(_modify_res)
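
Every hunk above applies the same mechanical change: inline lambdas passed to datapipe hooks (filepath_fn=, .filter(), .map()) are hoisted into named helper functions, and the one parameterised case (multi30k.py's source/target handling) binds its index through functools.partial instead of capturing it in a lambda. The standalone sketch below is only an illustration of the serialization property that typically motivates this kind of refactor; ROOT, EXTRACTED, SPLIT, filepath_fn, and filter_fn are made-up stand-ins, not torchtext code. The point it demonstrates is that plain pickle can serialize named functions and functools.partial objects by reference but not lambdas, which becomes relevant once a datapipe graph has to be serialized, e.g. for DataLoader worker processes.

import os
import pickle
from functools import partial

ROOT = ".data"                       # illustrative stand-ins, not torchtext values
EXTRACTED = {"train": "train.csv"}
SPLIT = "train"

# Before: an anonymous function. pickle serializes functions by qualified name,
# and a lambda's qualname is "<lambda>", so it cannot be looked up again.
filepath_lambda = lambda _=None: os.path.join(ROOT, EXTRACTED[SPLIT])  # noqa: E731

# After: a named function defined at importable (module) scope round-trips fine.
def filepath_fn(_=None):
    return os.path.join(ROOT, EXTRACTED[SPLIT])

# After, parameterised (the multi30k.py pattern): bind the varying index with
# functools.partial instead of capturing it inside a lambda.
def filter_fn(i, item):
    return EXTRACTED[SPLIT] in item[i]

if __name__ == "__main__":
    for fn in (filepath_fn, partial(filter_fn, 0)):
        restored = pickle.loads(pickle.dumps(fn))  # serializes and restores cleanly
        print("picklable:", restored)
    try:
        pickle.dumps(filepath_lambda)
    except (pickle.PicklingError, AttributeError) as err:
        print("lambda not picklable:", err)

Note also why the diff gives each _filepath_fn a throwaway "_=None" parameter: the same helper then serves both as the filepath_fn callback (called with one argument by the cache datapipe) and, called with no argument, as the key for hash_dict, so the cached path and the hashed path can no longer drift apart.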