Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit c09eb57

Browse files
committed
fix caching
1 parent a5eb508 commit c09eb57

File tree

1 file changed

+22
-6
lines changed

1 file changed

+22
-6
lines changed

torchtext/datasets/amazonreviewpolarity.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
_PATH = 'amazon_review_polarity_csv.tar.gz'
2424

2525
_EXTRACTED_FILES = {
26-
'train': os.path.join(_PATH, 'amazon_review_polarity_csv', 'train.csv'),
27-
'test': os.path.join(_PATH, 'amazon_review_polarity_csv', 'test.csv'),
26+
'train': os.path.join('amazon_review_polarity_csv', 'train.csv'),
27+
'test': os.path.join('amazon_review_polarity_csv', 'test.csv'),
2828
}
2929

3030

@@ -44,9 +44,25 @@ def AmazonReviewPolarity(root: str, split: Union[Tuple[str], str]):
4444
filepath_fn=lambda x: os.path.join(root, _PATH), hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5"
4545
)
4646
cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
47+
48+
def extracted_filepath_fn(x):
    """Return the cache path of the extracted CSV for the requested split.

    Used as the ``filepath_fn`` of the decompressed on-disk cache. The
    incoming item ``x`` is ignored: the target path is built only from the
    enclosing ``root`` and ``split`` values and the ``_EXTRACTED_FILES``
    mapping. As a side effect, the parent directory of the returned path is
    created if it does not exist yet.

    Args:
        x: Item handed over by the caching step; unused.

    Returns:
        ``os.path.join(root, _EXTRACTED_FILES[split])``.
    """
    file_path = os.path.join(root, _EXTRACTED_FILES[split])
    # exist_ok=True replaces the exists()-then-makedirs() pair, which could
    # raise FileExistsError when another worker created the directory
    # between the check and the call.
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    return file_path
54+
4755
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
48-
filepath_fn=lambda x: os.path.join(root, os.path.dirname(_EXTRACTED_FILES[split]), os.path.basename(x)))
49-
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar()
50-
cache_compressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
51-
data_dp = FileOpener(cache_decompressed_dp.filter(lambda x: _EXTRACTED_FILES[split] in x[0]).map(lambda x: x[0]), mode='b')
56+
filepath_fn=extracted_filepath_fn)
57+
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").\
58+
read_from_tar().\
59+
filter(lambda x: _EXTRACTED_FILES[split] in x[0]).\
60+
map(lambda x: (x[0].replace('_PATH' + '/', ''), x[1]))
61+
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
62+
data_dp = FileOpener(cache_decompressed_dp, mode='b')
63+
64+
# data_dp = FileOpener(cache_compressed_dp, mode='b')
65+
# data_dp = data_dp.read_from_tar()
66+
# data_dp = data_dp.filter(lambda x: _EXTRACTED_FILES[split] in x[0])
67+
5268
return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), ' '.join(t[1:])))

0 commit comments

Comments
 (0)