Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit d19a77e

Browse files
authored
add double caching for yelp full to speed up extracted reading. (#1529)
1 parent 437eea8 commit d19a77e

File tree

1 file changed: +10 additions, -8 deletions

torchtext/datasets/yelpreviewfull.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@_add_docstring_header(num_lines=NUM_LINES, num_classes=5)
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "test"))
def YelpReviewFull(root: str, split: Union[Tuple[str], str]):
    # torchdata datapipes are required for the streaming/caching pipeline below.
    if not is_module_available("torchdata"):
        raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")

    url_dp = IterableWrapper([URL])

    # First cache level: the compressed archive downloaded from Google Drive,
    # validated against its MD5 checksum and stored at root/_PATH.
    cache_compressed_dp = url_dp.on_disk_cache(
        filepath_fn=lambda x: os.path.join(root, _PATH),
        hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5"
    )
    cache_compressed_dp = GDriveReader(cache_compressed_dp)
    cache_compressed_dp = cache_compressed_dp.end_caching(mode="wb", same_filepath_fn=True)

    # Second cache level: the single CSV for the requested split, extracted
    # from the tar archive once and reused on subsequent reads.
    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]))
    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b")
    cache_decompressed_dp = cache_decompressed_dp.read_from_tar().filter(lambda x: _EXTRACTED_FILES[split] in x[0])
    cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

    # Parse each CSV row into (label, text): column 0 is the integer label,
    # the remaining columns are joined into one review string.
    data_dp = FileOpener(cache_decompressed_dp, mode="b")
    return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:])))

0 commit comments

Comments (0)