From 34a4aeee9bb3f715cf973f7b75309cb0242af895 Mon Sep 17 00:00:00 2001 From: Kevin Tse Date: Wed, 11 May 2022 14:39:18 -0400 Subject: [PATCH 1/4] Replace lambda functions with regular functions in all datasets --- torchtext/datasets/ag_news.py | 12 +++++-- torchtext/datasets/amazonreviewfull.py | 26 +++++++++----- torchtext/datasets/amazonreviewpolarity.py | 26 +++++++++----- torchtext/datasets/cc100.py | 17 ++++++--- torchtext/datasets/conll2000chunking.py | 14 +++++--- torchtext/datasets/dbpedia.py | 26 +++++++++----- torchtext/datasets/enwik9.py | 14 +++++--- torchtext/datasets/imdb.py | 41 +++++++++++++++------- torchtext/datasets/iwslt2016.py | 34 +++++++++++++----- torchtext/datasets/iwslt2017.py | 29 ++++++++++----- torchtext/datasets/multi30k.py | 30 ++++++++-------- torchtext/datasets/penntreebank.py | 12 +++++-- torchtext/datasets/sogounews.py | 24 ++++++++----- torchtext/datasets/squad1.py | 7 ++-- torchtext/datasets/squad2.py | 7 ++-- torchtext/datasets/sst2.py | 31 ++++++++++------ torchtext/datasets/udpos.py | 21 ++++++----- torchtext/datasets/wikitext103.py | 22 +++++++----- torchtext/datasets/wikitext2.py | 22 +++++++----- torchtext/datasets/yahooanswers.py | 24 +++++++++---- torchtext/datasets/yelpreviewfull.py | 24 +++++++++---- torchtext/datasets/yelpreviewpolarity.py | 24 +++++++++---- 22 files changed, 329 insertions(+), 158 deletions(-) diff --git a/torchtext/datasets/ag_news.py b/torchtext/datasets/ag_news.py index 17e941681c..63fb47ef3f 100644 --- a/torchtext/datasets/ag_news.py +++ b/torchtext/datasets/ag_news.py @@ -52,14 +52,20 @@ def AG_NEWS(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(): + return os.path.join(root, split + ".csv") + + def _modify_res(t): + return int(t[0]), " ".join(t[1:]) + url_dp = IterableWrapper([URL[split]]) cache_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, split + ".csv"), - hash_dict={os.path.join(root, split + ".csv"): MD5[split]}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5[split]}, hash_type="md5", ) cache_dp = HttpReader(cache_dp) cache_dp = cache_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_dp, encoding="utf-8") - return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:]))) + return data_dp.parse_csv().map(fn=_modify_res) diff --git a/torchtext/datasets/amazonreviewfull.py b/torchtext/datasets/amazonreviewfull.py index d490ba7463..e546ae0bb1 100644 --- a/torchtext/datasets/amazonreviewfull.py +++ b/torchtext/datasets/amazonreviewfull.py @@ -58,21 +58,29 @@ def AmazonReviewFull(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(): + return os.path.join(root, _PATH) + + def _extracted_filepath_fn(): + return os.path.join(root, _EXTRACTED_FILES[split]) + + def _filter_fn(x): + return _EXTRACTED_FILES[split] in x[0] + + def _modify_res(t): + return int(t[0]), " ".join(t[1:]) + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _PATH), - hash_dict={os.path.join(root, _PATH): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) - ) - cache_decompressed_dp = ( - FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(lambda x: _EXTRACTED_FILES[split] in x[0]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(_filter_fn) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") - return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:]))) + return data_dp.parse_csv().map(fn=_modify_res) diff --git a/torchtext/datasets/amazonreviewpolarity.py b/torchtext/datasets/amazonreviewpolarity.py index 24b7278743..02a8d76393 100644 --- a/torchtext/datasets/amazonreviewpolarity.py +++ b/torchtext/datasets/amazonreviewpolarity.py @@ -55,21 +55,29 @@ def AmazonReviewPolarity(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(): + return os.path.join(root, _PATH) + + def _extracted_filepath_fn(): + return os.path.join(root, _EXTRACTED_FILES[split]) + + def _filter_fn(x): + return _EXTRACTED_FILES[split] in x[0] + + def _modify_res(t): + return int(t[0]), " ".join(t[1:]) + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _PATH), - hash_dict={os.path.join(root, _PATH): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) - ) - cache_decompressed_dp = ( - FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(lambda x: _EXTRACTED_FILES[split] in x[0]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(_filter_fn) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") - return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:]))) + return data_dp.parse_csv().map(fn=_modify_res) diff --git a/torchtext/datasets/cc100.py b/torchtext/datasets/cc100.py index dc4d3af338..7414cc0b5d 100644 --- a/torchtext/datasets/cc100.py +++ b/torchtext/datasets/cc100.py @@ -151,18 +151,25 @@ def CC100(root: str, language_code: str = "en"): if language_code not in VALID_CODES: raise ValueError(f"Invalid language code {language_code}") + def _filepath_fn(): + return os.path.join(root, os.path.basename(url)) + + def _decompressed_filepath_fn(x): + return os.path.join(root, os.path.basename(x).rstrip(".xz")) + + def _modify_res(x): + return language_code, x + url = URL % language_code url_dp = IterableWrapper([url]) - cache_compressed_dp = url_dp.on_disk_cache(filepath_fn=lambda x: os.path.join(root, os.path.basename(url))) + cache_compressed_dp = url_dp.on_disk_cache(filepath_fn=_filepath_fn) cache_compressed_dp = HttpReader(cache_compressed_dp) cache_compressed_dp = cache_compressed_dp.end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.basename(x).rstrip(".xz")) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_decompressed_filepath_fn) cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_xz() cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb") data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8").readlines(return_path=False) - return data_dp.map(lambda x: (language_code, x)) + return data_dp.map(_modify_res) diff --git a/torchtext/datasets/conll2000chunking.py b/torchtext/datasets/conll2000chunking.py index c763f7fb86..e0f9400039 100644 --- a/torchtext/datasets/conll2000chunking.py +++ b/torchtext/datasets/conll2000chunking.py @@ -55,20 +55,24 @@ def CoNLL2000Chunking(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(): + return os.path.join(root, os.path.basename(URL[split])) + + def _extracted_filepath_fn(): + return os.path.join(root, _EXTRACTED_FILES[split]) + url_dp = IterableWrapper([URL[split]]) # Cache and check HTTP response cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.basename(URL[split])), - hash_dict={os.path.join(root, os.path.basename(URL[split])): MD5[split]}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5[split]}, hash_type="md5", ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) # Cache and check the gzip extraction for relevant split - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").extract(file_type="gzip") cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) diff --git a/torchtext/datasets/dbpedia.py b/torchtext/datasets/dbpedia.py index 6ea4b64953..cf88628b75 100644 --- a/torchtext/datasets/dbpedia.py +++ b/torchtext/datasets/dbpedia.py @@ -54,21 +54,29 @@ def DBpedia(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(): + return os.path.join(root, _PATH) + + def _extracted_filepath_fn(): + return os.path.join(root, _EXTRACTED_FILES[split]) + + def _filter_fn(x): + return _EXTRACTED_FILES[split] in x[0] + + def _modify_res(t): + return int(t[0]), " ".join(t[1:]) + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _PATH), - hash_dict={os.path.join(root, _PATH): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) - ) - cache_decompressed_dp = ( - FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(lambda x: _EXTRACTED_FILES[split] in x[0]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(_filter_fn) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") - return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:]))) + return data_dp.parse_csv().map(fn=_modify_res) diff --git a/torchtext/datasets/enwik9.py b/torchtext/datasets/enwik9.py index 0940b77760..b908e95ea5 100644 --- a/torchtext/datasets/enwik9.py +++ b/torchtext/datasets/enwik9.py @@ -37,17 +37,21 @@ def EnWik9(root: str): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(): + return os.path.join(root, _PATH) + + def _extracted_filepath_fn(): + return os.path.join(root, os.path.splitext(_PATH)[0]) + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _PATH), - hash_dict={os.path.join(root, _PATH): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.splitext(_PATH)[0]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_zip() cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) diff --git a/torchtext/datasets/imdb.py b/torchtext/datasets/imdb.py index 80de7134ed..0d37bf9826 100644 --- a/torchtext/datasets/imdb.py +++ b/torchtext/datasets/imdb.py @@ -47,20 +47,39 @@ def IMDB(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(): + return os.path.join(root, _PATH) + + def _decompressed_filepath_fn(): + return [os.path.join(root, decompressed_folder, split, label) for label in labels] + + def _filter_fn(t): + return filter_imdb_data(split, t[0]) + + def _path_map_fn(t): + return Path(t[0]).parts[-2], t[1] + + def _encode_map_fn(x): + return x[0], x[1].encode() + + def _cache_filepath_fn(x): + return os.path.join(root, decompressed_folder, split, x) + + def _modify_res(t): + return Path(t[0]).parts[-1], t[1] + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _PATH), - hash_dict={os.path.join(root, _PATH): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) labels = {"neg", "pos"} decompressed_folder = "aclImdb_v1" - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: [os.path.join(root, decompressed_folder, split, label) for label in labels] - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_decompressed_filepath_fn) cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b") cache_decompressed_dp = cache_decompressed_dp.load_from_tar() @@ -69,17 +88,15 @@ def filter_imdb_data(key, fname): *_, split, label, file = Path(fname).parts return key == split and label in labels - cache_decompressed_dp = cache_decompressed_dp.filter(lambda t: filter_imdb_data(split, t[0])) + cache_decompressed_dp = cache_decompressed_dp.filter(_filter_fn) # eg. "aclImdb/train/neg/12416_3.txt" -> "neg" - cache_decompressed_dp = cache_decompressed_dp.map(lambda t: (Path(t[0]).parts[-2], t[1])) + cache_decompressed_dp = cache_decompressed_dp.map(_path_map_fn) cache_decompressed_dp = cache_decompressed_dp.readlines(decode=True) cache_decompressed_dp = cache_decompressed_dp.lines_to_paragraphs() # group by label in cache file - cache_decompressed_dp = cache_decompressed_dp.map(lambda x: (x[0], x[1].encode())) - cache_decompressed_dp = cache_decompressed_dp.end_caching( - mode="wb", filepath_fn=lambda x: os.path.join(root, decompressed_folder, split, x), skip_read=True - ) + cache_decompressed_dp = cache_decompressed_dp.map(_encode_map_fn) + cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", filepath_fn=_cache_filepath_fn, skip_read=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") # get label from cache file, eg. "aclImdb_v1/train/neg" -> "neg" - return data_dp.readlines().map(lambda t: (Path(t[0]).parts[-1], t[1])) + return data_dp.readlines().map(_modify_res) diff --git a/torchtext/datasets/iwslt2016.py b/torchtext/datasets/iwslt2016.py index 9d3443a5bb..3de0019bd8 100644 --- a/torchtext/datasets/iwslt2016.py +++ b/torchtext/datasets/iwslt2016.py @@ -124,12 +124,19 @@ # TODO: migrate this to dataset_utils.py once torchdata is a hard dependency to # avoid additional conditional imports. def _filter_clean_cache(cache_decompressed_dp, full_filepath, uncleaned_filename): - cache_inner_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=lambda x: full_filepath) + def _return_full_filepath(): + return full_filepath + + def _filter_fn(x): + return os.path.basename(uncleaned_filename) in x[0] + + def _clean_files_wrapper(x): + return _clean_files(full_filepath, x[0], x[1]) + + cache_inner_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=_return_full_filepath) cache_inner_decompressed_dp = cache_inner_decompressed_dp.open_files(mode="b").load_from_tar() - cache_inner_decompressed_dp = cache_inner_decompressed_dp.filter( - lambda x: os.path.basename(uncleaned_filename) in x[0] - ) - cache_inner_decompressed_dp = cache_inner_decompressed_dp.map(lambda x: _clean_files(full_filepath, x[0], x[1])) + cache_inner_decompressed_dp = cache_inner_decompressed_dp.filter(_filter_fn) + cache_inner_decompressed_dp = cache_inner_decompressed_dp.map(_clean_files_wrapper) cache_inner_decompressed_dp = cache_inner_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) return cache_inner_decompressed_dp @@ -234,10 +241,13 @@ def IWSLT2016( SUPPORTED_DATASETS["year"], src_language, tgt_language, valid_set, test_set ) + def _filepath_fn(): + return os.path.join(root, _PATH) + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _PATH), - hash_dict={os.path.join(root, _PATH): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp) @@ -260,9 +270,15 @@ def IWSLT2016( + ".tgz" ) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=lambda x: inner_iwslt_tar) + def _inner_iwslt_tar_filepath_fn(): + return inner_iwslt_tar + + def _filter_fn(x): + return os.path.basename(inner_iwslt_tar) in x[0] + + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_inner_iwslt_tar_filepath_fn) cache_decompressed_dp = cache_decompressed_dp.open_files(mode="b").load_from_tar() - cache_decompressed_dp = cache_decompressed_dp.filter(lambda x: os.path.basename(inner_iwslt_tar) in x[0]) + cache_decompressed_dp = cache_decompressed_dp.filter(_filter_fn) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) cache_decompressed_dp_1, cache_decompressed_dp_2 = cache_decompressed_dp.fork(num_instances=2) diff --git a/torchtext/datasets/iwslt2017.py b/torchtext/datasets/iwslt2017.py index 0fb865d4e0..3b6ad4595d 100644 --- a/torchtext/datasets/iwslt2017.py +++ b/torchtext/datasets/iwslt2017.py @@ -103,12 +103,19 @@ # TODO: migrate this to dataset_utils.py once torchdata is a hard dependency to # avoid additional conditional imports. def _filter_clean_cache(cache_decompressed_dp, full_filepath, uncleaned_filename): - cache_inner_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=lambda x: full_filepath) + def _return_full_filepath(): + return full_filepath + + def _filter_fn(x): + return os.path.basename(uncleaned_filename) in x[0] + + def _clean_files_wrapper(x): + return _clean_files(full_filepath, x[0], x[1]) + + cache_inner_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=_return_full_filepath) cache_inner_decompressed_dp = cache_inner_decompressed_dp.open_files(mode="b").load_from_tar() - cache_inner_decompressed_dp = cache_inner_decompressed_dp.filter( - lambda x: os.path.basename(uncleaned_filename) in x[0] - ) - cache_inner_decompressed_dp = cache_inner_decompressed_dp.map(lambda x: _clean_files(full_filepath, x[0], x[1])) + cache_inner_decompressed_dp = cache_inner_decompressed_dp.filter(_filter_fn) + cache_inner_decompressed_dp = cache_inner_decompressed_dp.map(_clean_files_wrapper) cache_inner_decompressed_dp = cache_inner_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) return cache_inner_decompressed_dp @@ -188,10 +195,13 @@ def IWSLT2017(root=".data", split=("train", "valid", "test"), language_pair=("de SUPPORTED_DATASETS["year"], src_language, tgt_language, valid_set, test_set ) + def _filepath_fn(): + return os.path.join(root, _PATH) + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _PATH), - hash_dict={os.path.join(root, _PATH): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp) @@ -207,7 +217,10 @@ def IWSLT2017(root=".data", split=("train", "valid", "test"), language_pair=("de "texts/DeEnItNlRo/DeEnItNlRo/DeEnItNlRo-DeEnItNlRo.tgz", ) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=lambda x: inner_iwslt_tar) + def _inner_iwslt_tar_filepath_fn(): + return inner_iwslt_tar + + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_inner_iwslt_tar_filepath_fn) cache_decompressed_dp = cache_decompressed_dp.open_files(mode="b").load_from_tar() cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) cache_decompressed_dp_1, cache_decompressed_dp_2 = cache_decompressed_dp.fork(num_instances=2) diff --git a/torchtext/datasets/multi30k.py b/torchtext/datasets/multi30k.py index 26390379ba..b6d35978f2 100644 --- a/torchtext/datasets/multi30k.py +++ b/torchtext/datasets/multi30k.py @@ -1,4 +1,5 @@ import os +from functools import partial from typing import Union, Tuple from torchtext._internal.module_utils import is_module_available @@ -70,34 +71,35 @@ def Multi30k(root: str, split: Union[Tuple[str], str], language_pair: Tuple[str] "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(): + return os.path.join(root, os.path.basename(URL[split])) + url_dp = IterableWrapper([URL[split]]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.basename(URL[split])), - hash_dict={os.path.join(root, os.path.basename(URL[split])): MD5[split]}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5[split]}, hash_type="sha256", ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) cache_compressed_dp_1, cache_compressed_dp_2 = cache_compressed_dp.fork(num_instances=2) - src_cache_decompressed_dp = cache_compressed_dp_1.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, f"{_PREFIX[split]}.{language_pair[0]}") - ) + def _decompressed_filepath_fn(i, _): + return os.path.join(root, f"{_PREFIX[split]}.{language_pair[i]}") + + def _filter_fn(i, x): + return f"{_PREFIX[split]}.{language_pair[i]}" in x[0] + + src_cache_decompressed_dp = cache_compressed_dp_1.on_disk_cache(filepath_fn=partial(_decompressed_filepath_fn, i=0)) src_cache_decompressed_dp = ( - FileOpener(src_cache_decompressed_dp, mode="b") - .load_from_tar() - .filter(lambda x: f"{_PREFIX[split]}.{language_pair[0]}" in x[0]) + FileOpener(src_cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, i=0)) ) src_cache_decompressed_dp = src_cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) - tgt_cache_decompressed_dp = cache_compressed_dp_2.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, f"{_PREFIX[split]}.{language_pair[1]}") - ) + tgt_cache_decompressed_dp = cache_compressed_dp_2.on_disk_cache(filepath_fn=partial(_decompressed_filepath_fn, i=1)) tgt_cache_decompressed_dp = ( - FileOpener(tgt_cache_decompressed_dp, mode="b") - .load_from_tar() - .filter(lambda x: f"{_PREFIX[split]}.{language_pair[1]}" in x[0]) + FileOpener(tgt_cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, i=1)) ) tgt_cache_decompressed_dp = tgt_cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) diff --git a/torchtext/datasets/penntreebank.py b/torchtext/datasets/penntreebank.py index a9d7099792..0086e7bd71 100644 --- a/torchtext/datasets/penntreebank.py +++ b/torchtext/datasets/penntreebank.py @@ -56,14 +56,20 @@ def PennTreebank(root, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(): + return os.path.join(root, os.path.basename(URL[split])) + + def _modify_res(t): + return t.strip() + url_dp = IterableWrapper([URL[split]]) cache_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.basename(x)), - hash_dict={os.path.join(root, os.path.basename(URL[split])): MD5[split]}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5[split]}, hash_type="md5", ) cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_dp, encoding="utf-8") # remove single leading and trailing space from the dataset - return data_dp.readlines(return_path=False).map(lambda t: t.strip()) + return data_dp.readlines(return_path=False).map(_modify_res) diff --git a/torchtext/datasets/sogounews.py b/torchtext/datasets/sogounews.py index 8f023971ec..c61cc48a13 100644 --- a/torchtext/datasets/sogounews.py +++ b/torchtext/datasets/sogounews.py @@ -58,21 +58,29 @@ def SogouNews(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(): + return os.path.join(root, _PATH) + + def _extracted_filepath_fn(): + return os.path.join(root, _EXTRACTED_FILES[split]) + + def _filter_fn(x): + return _EXTRACTED_FILES[split] in x[0] + + def _modify_res(t): + return int(t[0]), " ".join(t[1:]) + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _PATH), + filepath_fn=_filepath_fn, hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) - ) - cache_decompressed_dp = ( - FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(lambda x: _EXTRACTED_FILES[split] in x[0]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(_filter_fn) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") - return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:]))) + return data_dp.parse_csv().map(fn=_modify_res) diff --git a/torchtext/datasets/squad1.py b/torchtext/datasets/squad1.py index 96f8f5626c..d491e3192a 100644 --- a/torchtext/datasets/squad1.py +++ b/torchtext/datasets/squad1.py @@ -53,11 +53,14 @@ def SQuAD1(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(): + return os.path.join(root, os.path.basename(URL[split])) + url_dp = IterableWrapper([URL[split]]) # cache data on-disk with sanity check cache_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.basename(x)), - hash_dict={os.path.join(root, os.path.basename(URL[split])): MD5[split]}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5[split]}, hash_type="md5", ) cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True) diff --git a/torchtext/datasets/squad2.py b/torchtext/datasets/squad2.py index 25247feac7..7bfc91a1b9 100644 --- a/torchtext/datasets/squad2.py +++ b/torchtext/datasets/squad2.py @@ -54,11 +54,14 @@ def SQuAD2(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(): + return os.path.join(root, os.path.basename(URL[split])) + url_dp = IterableWrapper([URL[split]]) # cache data on-disk with sanity check cache_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.basename(x)), - hash_dict={os.path.join(root, os.path.basename(URL[split])): MD5[split]}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5[split]}, hash_type="md5", ) cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True) diff --git a/torchtext/datasets/sst2.py b/torchtext/datasets/sst2.py index 23409f0a20..678136c970 100644 --- a/torchtext/datasets/sst2.py +++ b/torchtext/datasets/sst2.py @@ -61,26 +61,37 @@ def SST2(root, split): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(): + return os.path.join(root, os.path.basename(URL)) + + def _extracted_filepath_fn(): + return os.path.join(root, _EXTRACTED_FILES[split]) + + def _filter_fn(x): + return _EXTRACTED_FILES[split] in x[0] + + def _modify_test_res(t): + return (t[1].strip(),) + + def _modify_res(t): + return t[0].strip(), int(t[1]) + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.basename(x)), - hash_dict={os.path.join(root, os.path.basename(URL)): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) - ) - cache_decompressed_dp = ( - FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(lambda x: _EXTRACTED_FILES[split] in x[0]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(_filter_fn) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") # test split for SST2 doesn't have labels if split == "test": - parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").map(lambda t: (t[1].strip(),)) + parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").map(_modify_test_res) else: - parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").map(lambda t: (t[0].strip(), int(t[1]))) + parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").map(_modify_res) return parsed_data diff --git a/torchtext/datasets/udpos.py b/torchtext/datasets/udpos.py index b9f0328690..4c8c183e41 100644 --- a/torchtext/datasets/udpos.py +++ b/torchtext/datasets/udpos.py @@ -49,20 +49,25 @@ def UDPOS(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(): + return os.path.join(root, os.path.basename(URL)) + + def _extracted_filepath_fn(): + return os.path.join(root, _EXTRACTED_FILES[split]) + + def _filter_fn(x): + return _EXTRACTED_FILES[split] in x[0] + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.basename(URL)), - hash_dict={os.path.join(root, os.path.basename(URL)): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) - ) - cache_decompressed_dp = ( - FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(lambda x: _EXTRACTED_FILES[split] in x[0]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(_filter_fn) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") diff --git a/torchtext/datasets/wikitext103.py b/torchtext/datasets/wikitext103.py index 6f437fa7c6..2fcdd08b4c 100644 --- a/torchtext/datasets/wikitext103.py +++ b/torchtext/datasets/wikitext103.py @@ -53,21 +53,27 @@ def WikiText103(root: str, split: Union[Tuple[str], str]): raise ModuleNotFoundError( "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + + def _filepath_fn(): + return os.path.join(root, os.path.basename(URL)) + + def _extracted_filepath_fn(): + return os.path.join(root, _EXTRACTED_FILES[split]) + + def _filter_fn(x): + return _EXTRACTED_FILES[split] in x[0] + url_dp = IterableWrapper([URL]) # cache data on-disk cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.basename(x)), - hash_dict={os.path.join(root, os.path.basename(URL)): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) # Extract zip and filter the appropriate split file - cache_decompressed_dp = ( - FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(lambda x: _EXTRACTED_FILES[split] in x[0]) - ) + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(_filter_fn) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") return data_dp.readlines(strip_newline=False, return_path=False) diff --git a/torchtext/datasets/wikitext2.py b/torchtext/datasets/wikitext2.py index edbd8faac2..eb39a0335b 100644 --- a/torchtext/datasets/wikitext2.py +++ b/torchtext/datasets/wikitext2.py @@ -53,21 +53,27 @@ def WikiText2(root: str, split: Union[Tuple[str], str]): raise ModuleNotFoundError( "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + + def _filepath_fn(): + return os.path.join(root, os.path.basename(URL)) + + def _extracted_filepath_fn(): + return os.path.join(root, _EXTRACTED_FILES[split]) + + def _filter_fn(x): + return _EXTRACTED_FILES[split] in x[0] + url_dp = IterableWrapper([URL]) # cache data on-disk cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.basename(x)), - hash_dict={os.path.join(root, os.path.basename(URL)): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) # Extract zip and filter the appropriate split file - cache_decompressed_dp = ( - FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(lambda x: _EXTRACTED_FILES[split] in x[0]) - ) + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(_filter_fn) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") return data_dp.readlines(strip_newline=False, return_path=False) diff --git a/torchtext/datasets/yahooanswers.py b/torchtext/datasets/yahooanswers.py index dae6a86f2f..beac7a3fe5 100644 --- a/torchtext/datasets/yahooanswers.py +++ b/torchtext/datasets/yahooanswers.py @@ -54,23 +54,33 @@ def YahooAnswers(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(): + return os.path.join(root, _PATH) + + def _extracted_filepath_fn(): + return os.path.join(root, _EXTRACTED_FILES[split]) + + def _filter_fn(x): + return _EXTRACTED_FILES[split] in x[0] + + def _modify_res(t): + return int(t[0]), " ".join(t[1:]) + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _PATH), - hash_dict={os.path.join(root, _PATH): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b") cache_decompressed_dp = cache_decompressed_dp.load_from_tar() - cache_decompressed_dp = cache_decompressed_dp.filter(lambda x: _EXTRACTED_FILES[split] in x[0]) + cache_decompressed_dp = cache_decompressed_dp.filter(_filter_fn) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") - return data_dp.parse_csv().map(lambda t: (int(t[0]), " ".join(t[1:]))) + return data_dp.parse_csv().map(_modify_res) diff --git a/torchtext/datasets/yelpreviewfull.py b/torchtext/datasets/yelpreviewfull.py index 54706e6222..99fbc2480e 100644 --- a/torchtext/datasets/yelpreviewfull.py +++ b/torchtext/datasets/yelpreviewfull.py @@ -53,22 +53,32 @@ def YelpReviewFull(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(): + return os.path.join(root, _PATH) + + def _extracted_filepath_fn(): + return os.path.join(root, _EXTRACTED_FILES[split]) + + def _filter_fn(x): + return _EXTRACTED_FILES[split] in x[0] + + def _modify_res(t): + return int(t[0]), " ".join(t[1:]) + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _PATH), - hash_dict={os.path.join(root, _PATH): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp) cache_compressed_dp = cache_compressed_dp.end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b") - cache_decompressed_dp = cache_decompressed_dp.load_from_tar().filter(lambda x: _EXTRACTED_FILES[split] in x[0]) + cache_decompressed_dp = cache_decompressed_dp.load_from_tar().filter(_filter_fn) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") - return data_dp.parse_csv().map(lambda t: (int(t[0]), " ".join(t[1:]))) + return data_dp.parse_csv().map(_modify_res) diff --git a/torchtext/datasets/yelpreviewpolarity.py b/torchtext/datasets/yelpreviewpolarity.py index 9efc1084b2..9d93059eae 100644 --- a/torchtext/datasets/yelpreviewpolarity.py +++ b/torchtext/datasets/yelpreviewpolarity.py @@ -53,23 +53,33 @@ def YelpReviewPolarity(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(): + return os.path.join(root, _PATH) + + def _extracted_filepath_fn(): + return os.path.join(root, _EXTRACTED_FILES[split]) + + def _filter_fn(x): + return _EXTRACTED_FILES[split] in x[0] + + def _modify_res(t): + return int(t[0]), " ".join(t[1:]) + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _PATH), - hash_dict={os.path.join(root, _PATH): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b") cache_decompressed_dp = cache_decompressed_dp.load_from_tar() - cache_decompressed_dp = cache_decompressed_dp.filter(lambda x: _EXTRACTED_FILES[split] in x[0]) + cache_decompressed_dp = cache_decompressed_dp.filter(_filter_fn) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") - return data_dp.parse_csv().map(lambda t: (int(t[0]), " ".join(t[1:]))) + return data_dp.parse_csv().map(_modify_res) From 01085ab916ab9397649da3724f4eba0cdc15a480 Mon Sep 17 00:00:00 2001 From: Kevin Tse Date: Wed, 11 May 2022 16:11:59 -0400 Subject: [PATCH 2/4] Fix errors in function call --- torchtext/datasets/ag_news.py | 2 +- torchtext/datasets/amazonreviewfull.py | 4 ++-- torchtext/datasets/amazonreviewpolarity.py | 4 ++-- torchtext/datasets/cc100.py | 2 +- torchtext/datasets/conll2000chunking.py | 4 ++-- torchtext/datasets/dbpedia.py | 4 ++-- torchtext/datasets/enwik9.py | 4 ++-- torchtext/datasets/imdb.py | 4 ++-- torchtext/datasets/iwslt2016.py | 6 +++--- torchtext/datasets/iwslt2017.py | 6 +++--- torchtext/datasets/multi30k.py | 2 +- torchtext/datasets/penntreebank.py | 2 +- torchtext/datasets/sogounews.py | 4 ++-- torchtext/datasets/squad1.py | 2 +- torchtext/datasets/squad2.py | 2 +- torchtext/datasets/sst2.py | 4 ++-- torchtext/datasets/udpos.py | 4 ++-- torchtext/datasets/wikitext103.py | 4 ++-- torchtext/datasets/wikitext2.py | 4 ++-- torchtext/datasets/yahooanswers.py | 4 ++-- torchtext/datasets/yelpreviewfull.py | 4 ++-- torchtext/datasets/yelpreviewpolarity.py | 4 ++-- 22 files changed, 40 insertions(+), 40 deletions(-) diff --git a/torchtext/datasets/ag_news.py b/torchtext/datasets/ag_news.py index 63fb47ef3f..4b3533fa08 100644 --- a/torchtext/datasets/ag_news.py +++ b/torchtext/datasets/ag_news.py @@ -52,7 +52,7 @@ def AG_NEWS(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(): + def _filepath_fn(_=None): return os.path.join(root, split + ".csv") def _modify_res(t): diff --git a/torchtext/datasets/amazonreviewfull.py b/torchtext/datasets/amazonreviewfull.py index e546ae0bb1..3a57db391a 100644 --- a/torchtext/datasets/amazonreviewfull.py +++ b/torchtext/datasets/amazonreviewfull.py @@ -58,10 +58,10 @@ def AmazonReviewFull(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(): + def _filepath_fn(_=None): return os.path.join(root, _PATH) - def _extracted_filepath_fn(): + def _extracted_filepath_fn(_=None): return os.path.join(root, _EXTRACTED_FILES[split]) def _filter_fn(x): diff --git a/torchtext/datasets/amazonreviewpolarity.py b/torchtext/datasets/amazonreviewpolarity.py index 02a8d76393..4760a93a19 100644 --- a/torchtext/datasets/amazonreviewpolarity.py +++ b/torchtext/datasets/amazonreviewpolarity.py @@ -55,10 +55,10 @@ def AmazonReviewPolarity(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(): + def _filepath_fn(_=None): return os.path.join(root, _PATH) - def _extracted_filepath_fn(): + def _extracted_filepath_fn(_=None): return os.path.join(root, _EXTRACTED_FILES[split]) def _filter_fn(x): diff --git a/torchtext/datasets/cc100.py b/torchtext/datasets/cc100.py index 7414cc0b5d..56d31d0e4f 100644 --- a/torchtext/datasets/cc100.py +++ b/torchtext/datasets/cc100.py @@ -151,7 +151,7 @@ def CC100(root: str, language_code: str = "en"): if language_code not in VALID_CODES: raise ValueError(f"Invalid language code {language_code}") - def _filepath_fn(): + def _filepath_fn(_=None): return os.path.join(root, os.path.basename(url)) def _decompressed_filepath_fn(x): diff --git a/torchtext/datasets/conll2000chunking.py b/torchtext/datasets/conll2000chunking.py index e0f9400039..ce4a8737fc 100644 --- a/torchtext/datasets/conll2000chunking.py +++ b/torchtext/datasets/conll2000chunking.py @@ -55,10 +55,10 @@ def CoNLL2000Chunking(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(): + def _filepath_fn(_=None): return os.path.join(root, os.path.basename(URL[split])) - def _extracted_filepath_fn(): + def _extracted_filepath_fn(_=None): return os.path.join(root, _EXTRACTED_FILES[split]) url_dp = IterableWrapper([URL[split]]) diff --git a/torchtext/datasets/dbpedia.py b/torchtext/datasets/dbpedia.py index cf88628b75..1265badd4d 100644 --- a/torchtext/datasets/dbpedia.py +++ b/torchtext/datasets/dbpedia.py @@ -54,10 +54,10 @@ def DBpedia(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(): + def _filepath_fn(_=None): return os.path.join(root, _PATH) - def _extracted_filepath_fn(): + def _extracted_filepath_fn(_=None): return os.path.join(root, _EXTRACTED_FILES[split]) def _filter_fn(x): diff --git a/torchtext/datasets/enwik9.py b/torchtext/datasets/enwik9.py index b908e95ea5..58b8357676 100644 --- a/torchtext/datasets/enwik9.py +++ b/torchtext/datasets/enwik9.py @@ -37,10 +37,10 @@ def EnWik9(root: str): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(): + def _filepath_fn(_=None): return os.path.join(root, _PATH) - def _extracted_filepath_fn(): + def _extracted_filepath_fn(_=None): return os.path.join(root, os.path.splitext(_PATH)[0]) url_dp = IterableWrapper([URL]) diff --git a/torchtext/datasets/imdb.py b/torchtext/datasets/imdb.py index 0d37bf9826..2f0cc64484 100644 --- a/torchtext/datasets/imdb.py +++ b/torchtext/datasets/imdb.py @@ -47,10 +47,10 @@ def IMDB(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(): + def _filepath_fn(_=None): return os.path.join(root, _PATH) - def _decompressed_filepath_fn(): + def _decompressed_filepath_fn(_=None): return [os.path.join(root, decompressed_folder, split, label) for label in labels] def _filter_fn(t): diff --git a/torchtext/datasets/iwslt2016.py b/torchtext/datasets/iwslt2016.py index 3de0019bd8..63b5e4a6db 100644 --- a/torchtext/datasets/iwslt2016.py +++ b/torchtext/datasets/iwslt2016.py @@ -124,7 +124,7 @@ # TODO: migrate this to dataset_utils.py once torchdata is a hard dependency to # avoid additional conditional imports. def _filter_clean_cache(cache_decompressed_dp, full_filepath, uncleaned_filename): - def _return_full_filepath(): + def _return_full_filepath(_=None): return full_filepath def _filter_fn(x): @@ -241,7 +241,7 @@ def IWSLT2016( SUPPORTED_DATASETS["year"], src_language, tgt_language, valid_set, test_set ) - def _filepath_fn(): + def _filepath_fn(_=None): return os.path.join(root, _PATH) url_dp = IterableWrapper([URL]) @@ -270,7 +270,7 @@ def _filepath_fn(): + ".tgz" ) - def _inner_iwslt_tar_filepath_fn(): + def _inner_iwslt_tar_filepath_fn(_=None): return inner_iwslt_tar def _filter_fn(x): diff --git a/torchtext/datasets/iwslt2017.py b/torchtext/datasets/iwslt2017.py index 3b6ad4595d..a585a5c604 100644 --- a/torchtext/datasets/iwslt2017.py +++ b/torchtext/datasets/iwslt2017.py @@ -103,7 +103,7 @@ # TODO: migrate this to dataset_utils.py once torchdata is a hard dependency to # avoid additional conditional imports. def _filter_clean_cache(cache_decompressed_dp, full_filepath, uncleaned_filename): - def _return_full_filepath(): + def _return_full_filepath(_=None): return full_filepath def _filter_fn(x): @@ -195,7 +195,7 @@ def IWSLT2017(root=".data", split=("train", "valid", "test"), language_pair=("de SUPPORTED_DATASETS["year"], src_language, tgt_language, valid_set, test_set ) - def _filepath_fn(): + def _filepath_fn(_=None): return os.path.join(root, _PATH) url_dp = IterableWrapper([URL]) @@ -217,7 +217,7 @@ def _filepath_fn(): "texts/DeEnItNlRo/DeEnItNlRo/DeEnItNlRo-DeEnItNlRo.tgz", ) - def _inner_iwslt_tar_filepath_fn(): + def _inner_iwslt_tar_filepath_fn(_=None): return inner_iwslt_tar cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_inner_iwslt_tar_filepath_fn) diff --git a/torchtext/datasets/multi30k.py b/torchtext/datasets/multi30k.py index b6d35978f2..189b058d06 100644 --- a/torchtext/datasets/multi30k.py +++ b/torchtext/datasets/multi30k.py @@ -71,7 +71,7 @@ def Multi30k(root: str, split: Union[Tuple[str], str], language_pair: Tuple[str] "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(): + def _filepath_fn(_=None): return os.path.join(root, os.path.basename(URL[split])) url_dp = IterableWrapper([URL[split]]) diff --git a/torchtext/datasets/penntreebank.py b/torchtext/datasets/penntreebank.py index 0086e7bd71..2ba26bfc01 100644 --- a/torchtext/datasets/penntreebank.py +++ b/torchtext/datasets/penntreebank.py @@ -56,7 +56,7 @@ def PennTreebank(root, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(): + def _filepath_fn(_=None): return os.path.join(root, os.path.basename(URL[split])) def _modify_res(t): diff --git a/torchtext/datasets/sogounews.py b/torchtext/datasets/sogounews.py index c61cc48a13..c4c9aca88d 100644 --- a/torchtext/datasets/sogounews.py +++ b/torchtext/datasets/sogounews.py @@ -58,10 +58,10 @@ def SogouNews(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(): + def _filepath_fn(_=None): return os.path.join(root, _PATH) - def _extracted_filepath_fn(): + def _extracted_filepath_fn(_=None): return os.path.join(root, _EXTRACTED_FILES[split]) def _filter_fn(x): diff --git a/torchtext/datasets/squad1.py b/torchtext/datasets/squad1.py index d491e3192a..5393355002 100644 --- a/torchtext/datasets/squad1.py +++ b/torchtext/datasets/squad1.py @@ -53,7 +53,7 @@ def SQuAD1(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(): + def _filepath_fn(_=None): return os.path.join(root, os.path.basename(URL[split])) url_dp = IterableWrapper([URL[split]]) diff --git a/torchtext/datasets/squad2.py b/torchtext/datasets/squad2.py index 7bfc91a1b9..7be3d064bd 100644 --- a/torchtext/datasets/squad2.py +++ b/torchtext/datasets/squad2.py @@ -54,7 +54,7 @@ def SQuAD2(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(): + def _filepath_fn(_=None): return os.path.join(root, os.path.basename(URL[split])) url_dp = IterableWrapper([URL[split]]) diff --git a/torchtext/datasets/sst2.py b/torchtext/datasets/sst2.py index 678136c970..1d357ea3d6 100644 --- a/torchtext/datasets/sst2.py +++ b/torchtext/datasets/sst2.py @@ -61,10 +61,10 @@ def SST2(root, split): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(): + def _filepath_fn(_=None): return os.path.join(root, os.path.basename(URL)) - def _extracted_filepath_fn(): + def _extracted_filepath_fn(_=None): return os.path.join(root, _EXTRACTED_FILES[split]) def _filter_fn(x): diff --git a/torchtext/datasets/udpos.py b/torchtext/datasets/udpos.py index 4c8c183e41..2ec95bcece 100644 --- a/torchtext/datasets/udpos.py +++ b/torchtext/datasets/udpos.py @@ -49,10 +49,10 @@ def UDPOS(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(): + def _filepath_fn(_=None): return os.path.join(root, os.path.basename(URL)) - def _extracted_filepath_fn(): + def _extracted_filepath_fn(_=None): return os.path.join(root, _EXTRACTED_FILES[split]) def _filter_fn(x): diff --git a/torchtext/datasets/wikitext103.py b/torchtext/datasets/wikitext103.py index 2fcdd08b4c..d791572ec9 100644 --- a/torchtext/datasets/wikitext103.py +++ b/torchtext/datasets/wikitext103.py @@ -54,10 +54,10 @@ def WikiText103(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(): + def _filepath_fn(_=None): return os.path.join(root, os.path.basename(URL)) - def _extracted_filepath_fn(): + def _extracted_filepath_fn(_=None): return os.path.join(root, _EXTRACTED_FILES[split]) def _filter_fn(x): diff --git a/torchtext/datasets/wikitext2.py b/torchtext/datasets/wikitext2.py index eb39a0335b..ccd200e3c9 100644 --- a/torchtext/datasets/wikitext2.py +++ b/torchtext/datasets/wikitext2.py @@ -54,10 +54,10 @@ def WikiText2(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(): + def _filepath_fn(_=None): return os.path.join(root, os.path.basename(URL)) - def _extracted_filepath_fn(): + def _extracted_filepath_fn(_=None): return os.path.join(root, _EXTRACTED_FILES[split]) def _filter_fn(x): diff --git a/torchtext/datasets/yahooanswers.py b/torchtext/datasets/yahooanswers.py index beac7a3fe5..16dd47353b 100644 --- a/torchtext/datasets/yahooanswers.py +++ b/torchtext/datasets/yahooanswers.py @@ -54,10 +54,10 @@ def YahooAnswers(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(): + def _filepath_fn(_=None): return os.path.join(root, _PATH) - def _extracted_filepath_fn(): + def _extracted_filepath_fn(_=None): return os.path.join(root, _EXTRACTED_FILES[split]) def _filter_fn(x): diff --git a/torchtext/datasets/yelpreviewfull.py b/torchtext/datasets/yelpreviewfull.py index 99fbc2480e..1f56c75c0e 100644 --- a/torchtext/datasets/yelpreviewfull.py +++ b/torchtext/datasets/yelpreviewfull.py @@ -53,10 +53,10 @@ def YelpReviewFull(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(): + def _filepath_fn(_=None): return os.path.join(root, _PATH) - def _extracted_filepath_fn(): + def _extracted_filepath_fn(_=None): return os.path.join(root, _EXTRACTED_FILES[split]) def _filter_fn(x): diff --git a/torchtext/datasets/yelpreviewpolarity.py b/torchtext/datasets/yelpreviewpolarity.py index 9d93059eae..40a1508c8a 100644 --- a/torchtext/datasets/yelpreviewpolarity.py +++ b/torchtext/datasets/yelpreviewpolarity.py @@ -53,10 +53,10 @@ def YelpReviewPolarity(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(): + def _filepath_fn(_=None): return os.path.join(root, _PATH) - def _extracted_filepath_fn(): + def _extracted_filepath_fn(_=None): return os.path.join(root, _EXTRACTED_FILES[split]) def _filter_fn(x): From 2bffdc54ecf30c6d32e2a266bf0910dbbdc6135e Mon Sep 17 00:00:00 2001 From: Kevin Tse Date: Wed, 11 May 2022 16:19:09 -0400 Subject: [PATCH 3/4] Fix errors in function call --- torchtext/datasets/multi30k.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchtext/datasets/multi30k.py b/torchtext/datasets/multi30k.py index 189b058d06..6095316412 100644 --- a/torchtext/datasets/multi30k.py +++ b/torchtext/datasets/multi30k.py @@ -91,15 +91,15 @@ def _decompressed_filepath_fn(i, _): def _filter_fn(i, x): return f"{_PREFIX[split]}.{language_pair[i]}" in x[0] - src_cache_decompressed_dp = cache_compressed_dp_1.on_disk_cache(filepath_fn=partial(_decompressed_filepath_fn, i=0)) + src_cache_decompressed_dp = cache_compressed_dp_1.on_disk_cache(filepath_fn=partial(_decompressed_filepath_fn, 0)) src_cache_decompressed_dp = ( - FileOpener(src_cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, i=0)) + FileOpener(src_cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, 0)) ) src_cache_decompressed_dp = src_cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) - tgt_cache_decompressed_dp = cache_compressed_dp_2.on_disk_cache(filepath_fn=partial(_decompressed_filepath_fn, i=1)) + tgt_cache_decompressed_dp = cache_compressed_dp_2.on_disk_cache(filepath_fn=partial(_decompressed_filepath_fn, 1)) tgt_cache_decompressed_dp = ( - FileOpener(tgt_cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, i=1)) + FileOpener(tgt_cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, 1)) ) tgt_cache_decompressed_dp = tgt_cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) From 1ad3fd38040691a0e59375400847b5028ed4835e Mon Sep 17 00:00:00 2001 From: Kevin Tse Date: Wed, 11 May 2022 16:38:55 -0400 Subject: [PATCH 4/4] Fix errors in function call --- torchtext/datasets/sogounews.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtext/datasets/sogounews.py b/torchtext/datasets/sogounews.py index c4c9aca88d..a93bec0f1e 100644 --- a/torchtext/datasets/sogounews.py +++ b/torchtext/datasets/sogounews.py @@ -73,7 +73,7 @@ def _modify_res(t): url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( filepath_fn=_filepath_fn, - hash_dict={os.path.join(root, _PATH): MD5}, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)