Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit ccf021f

Browse files
authored
Adding secondary caching to datasets (#1594)
1 parent 31434ff commit ccf021f

File tree

2 files changed

+41
-16
lines changed

2 files changed

+41
-16
lines changed

torchtext/datasets/amazonreviewfull.py

Lines changed: 19 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -44,16 +44,26 @@ def AmazonReviewFull(root: str, split: Union[Tuple[str], str]):
4444
raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")
4545

4646
url_dp = IterableWrapper([URL])
47-
48-
cache_dp = url_dp.on_disk_cache(
47+
cache_compressed_dp = url_dp.on_disk_cache(
4948
filepath_fn=lambda x: os.path.join(root, _PATH),
50-
hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5"
49+
hash_dict={os.path.join(root, _PATH): MD5},
50+
hash_type="md5",
51+
)
52+
cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(
53+
mode="wb", same_filepath_fn=True
5154
)
52-
cache_dp = GDriveReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
53-
cache_dp = FileOpener(cache_dp, mode="b")
54-
55-
extracted_files = cache_dp.read_from_tar()
5655

57-
filter_extracted_files = extracted_files.filter(lambda x: _EXTRACTED_FILES[split] in x[0])
56+
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
57+
filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split])
58+
)
59+
cache_decompressed_dp = (
60+
FileOpener(cache_decompressed_dp, mode="b")
61+
.read_from_tar()
62+
.filter(lambda x: _EXTRACTED_FILES[split] in x[0])
63+
)
64+
cache_decompressed_dp = cache_decompressed_dp.end_caching(
65+
mode="wb", same_filepath_fn=True
66+
)
5867

59-
return filter_extracted_files.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:])))
68+
data_dp = FileOpener(cache_decompressed_dp, mode="b")
69+
return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:])))

torchtext/datasets/sogounews.py

Lines changed: 22 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -44,11 +44,26 @@ def SogouNews(root: str, split: Union[Tuple[str], str]):
4444
raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")
4545

4646
url_dp = IterableWrapper([URL])
47-
cache_dp = url_dp.on_disk_cache(
48-
filepath_fn=lambda x: os.path.join(root, _PATH), hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5"
47+
cache_compressed_dp = url_dp.on_disk_cache(
48+
filepath_fn=lambda x: os.path.join(root, _PATH),
49+
hash_dict={os.path.join(root, _PATH): MD5},
50+
hash_type="md5",
4951
)
50-
cache_dp = GDriveReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
51-
cache_dp = FileOpener(cache_dp, mode="b")
52-
extracted_files = cache_dp.read_from_tar()
53-
filter_extracted_files = extracted_files.filter(lambda x: _EXTRACTED_FILES[split] in x[0])
54-
return filter_extracted_files.parse_csv().map(fn=lambda t: (int(t[0]), ' '.join(t[1:])))
52+
cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(
53+
mode="wb", same_filepath_fn=True
54+
)
55+
56+
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
57+
filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split])
58+
)
59+
cache_decompressed_dp = (
60+
FileOpener(cache_decompressed_dp, mode="b")
61+
.read_from_tar()
62+
.filter(lambda x: _EXTRACTED_FILES[split] in x[0])
63+
)
64+
cache_decompressed_dp = cache_decompressed_dp.end_caching(
65+
mode="wb", same_filepath_fn=True
66+
)
67+
68+
data_dp = FileOpener(cache_decompressed_dp, mode="b")
69+
return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:])))

0 commit comments

Comments (0)