Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit a38f37e

Browse files
Fixed on_disk_cache issues
ghstack-source-id: b560715 Pull Request resolved: #1942
1 parent 4d88d4e commit a38f37e

File tree

4 files changed

+29
-15
lines changed

4 files changed

+29
-15
lines changed

test/torchtext_unittest/datasets/test_cnndm.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,7 @@ def test_cnndm(self, split):
8787
dataset = CNNDM(root=self.root_dir, split=split)
8888
samples = list(dataset)
8989
expected_samples = self.samples[split]
90-
for sample, expected_sample in zip_equal(samples, expected_samples):
91-
self.assertEqual(sample, expected_sample)
90+
self.assertEqual(expected_samples, samples)
9291

9392
@parameterized.expand(["train", "val", "test"])
9493
@patch("torchtext.datasets.cnndm._get_split_list", _mock_split_list)

torchtext/datasets/cnndm.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import hashlib
22
import os
3-
from collections import defaultdict
43
from functools import partial
5-
from typing import Union, Tuple
4+
from typing import Union, Set, Tuple
65

76
from torchtext._internal.module_utils import is_module_available
87
from torchtext.data.datasets_utils import (
@@ -52,28 +51,24 @@
5251
"test": 11490,
5352
}
5453

55-
story_fnames = defaultdict(set)
56-
5754

5855
def _filepath_fn(root: str, source: str, _=None):
5956
return os.path.join(root, PATH_LIST[source])
6057

6158

6259
# called once per tar file, therefore no duplicate processing
6360
def _extracted_folder_fn(root: str, source: str, split: str, _=None):
64-
global story_fnames
6561
key = source + "_" + split
66-
story_fnames[key] = set(_get_split_list(source, split))
67-
filepaths = [os.path.join(root, _EXTRACTED_FOLDERS[source], story) for story in story_fnames[key]]
68-
return filepaths
62+
filepath = os.path.join(root, key)
63+
return filepath
6964

7065

7166
def _extracted_filepath_fn(root: str, source: str, x: str):
7267
return os.path.join(root, _EXTRACTED_FOLDERS[source], os.path.basename(x))
7368

7469

75-
def _filter_fn(source: str, split: str, x: tuple):
76-
return os.path.basename(x[0]) in story_fnames[source + "_" + split]
70+
def _filter_fn(split_list: Set[str], x: tuple):
71+
return os.path.basename(x[0]) in split_list
7772

7873

7974
def _hash_urls(s: tuple):
@@ -96,6 +91,9 @@ def _get_split_list(source: str, split: str):
9691

9792

9893
def _load_stories(root: str, source: str, split: str):
94+
95+
split_list = set(_get_split_list(source, split))
96+
9997
story_dp = IterableWrapper([URL[source]])
10098
cache_compressed_dp = story_dp.on_disk_cache(
10199
filepath_fn=partial(_filepath_fn, root, source),
@@ -108,7 +106,7 @@ def _load_stories(root: str, source: str, split: str):
108106
filepath_fn=partial(_extracted_folder_fn, root, source, split)
109107
)
110108
cache_decompressed_dp = (
111-
FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, source, split))
109+
FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, split_list))
112110
)
113111
cache_decompressed_dp = cache_decompressed_dp.end_caching(
114112
mode="wb", filepath_fn=partial(_extracted_filepath_fn, root, source)

torchtext/datasets/imdb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def _filepath_fn(root, _=None):
3232

3333

3434
def _decompressed_filepath_fn(root, decompressed_folder, split, labels, _=None):
35-
return [os.path.join(root, decompressed_folder, split, label) for label in labels]
35+
return os.path.join(root, decompressed_folder, split)
3636

3737

3838
def _filter_fn(filter_imdb_data, split, t):

torchtext/datasets/iwslt2017.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,24 @@ def IWSLT2017(root=".data", split=("train", "valid", "test"), language_pair=("de
240240
filepath_fn=partial(_inner_iwslt_tar_filepath_fn, inner_iwslt_tar)
241241
)
242242
cache_decompressed_dp = cache_decompressed_dp.open_files(mode="b").load_from_tar()
243-
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
243+
# Because filenames in the archive were duplicated, stray files could otherwise be mistaken for tgz archives
244+
245+
def extracted_file_name(inner_iwslt_tar, inner_tar_name):
246+
name = os.path.basename(inner_tar_name)
247+
path = os.path.dirname(inner_iwslt_tar)
248+
return os.path.join(path, name)
249+
250+
cache_decompressed_dp = cache_decompressed_dp.end_caching(
251+
mode="wb", filepath_fn=partial(extracted_file_name, inner_iwslt_tar)
252+
)
253+
# Now that the paths are corrected, keep only .tgz files and drop hidden dot-files
254+
255+
def leave_only_tgz(file_name):
256+
name = os.path.basename(file_name)
257+
_, file_extension = os.path.splitext(file_name)
258+
return file_extension == ".tgz" and name[0] != "."
259+
260+
cache_decompressed_dp = cache_decompressed_dp.filter(leave_only_tgz)
244261
cache_decompressed_dp_1, cache_decompressed_dp_2 = cache_decompressed_dp.fork(num_instances=2)
245262

246263
src_filename = file_path_by_lang_and_split[src_language][split]

0 commit comments

Comments
 (0)