diff --git a/test/torchtext_unittest/datasets/test_cnndm.py b/test/torchtext_unittest/datasets/test_cnndm.py
index a2de23cc3c..c3b5561e36 100644
--- a/test/torchtext_unittest/datasets/test_cnndm.py
+++ b/test/torchtext_unittest/datasets/test_cnndm.py
@@ -87,8 +87,7 @@ def test_cnndm(self, split):
         dataset = CNNDM(root=self.root_dir, split=split)
         samples = list(dataset)
         expected_samples = self.samples[split]
-        for sample, expected_sample in zip_equal(samples, expected_samples):
-            self.assertEqual(sample, expected_sample)
+        self.assertEqual(expected_samples, samples)
 
     @parameterized.expand(["train", "val", "test"])
     @patch("torchtext.datasets.cnndm._get_split_list", _mock_split_list)
diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py
index 3b5f6aa42b..db65680d17 100644
--- a/torchtext/datasets/cnndm.py
+++ b/torchtext/datasets/cnndm.py
@@ -1,8 +1,7 @@
 import hashlib
 import os
-from collections import defaultdict
 from functools import partial
-from typing import Union, Tuple
+from typing import Union, Set, Tuple
 
 from torchtext._internal.module_utils import is_module_available
 from torchtext.data.datasets_utils import (
@@ -52,8 +51,6 @@
     "test": 11490,
 }
 
-story_fnames = defaultdict(set)
-
 
 def _filepath_fn(root: str, source: str, _=None):
     return os.path.join(root, PATH_LIST[source])
@@ -61,19 +58,17 @@ def _filepath_fn(root: str, source: str, _=None):
 
 # called once per tar file, therefore no duplicate processing
 def _extracted_folder_fn(root: str, source: str, split: str, _=None):
-    global story_fnames
     key = source + "_" + split
-    story_fnames[key] = set(_get_split_list(source, split))
-    filepaths = [os.path.join(root, _EXTRACTED_FOLDERS[source], story) for story in story_fnames[key]]
-    return filepaths
+    filepath = os.path.join(root, key)
+    return filepath
 
 
 def _extracted_filepath_fn(root: str, source: str, x: str):
     return os.path.join(root, _EXTRACTED_FOLDERS[source], os.path.basename(x))
 
 
-def _filter_fn(source: str, split: str, x: tuple):
-    return os.path.basename(x[0]) in story_fnames[source + "_" + split]
+def _filter_fn(split_list: Set[str], x: tuple):
+    return os.path.basename(x[0]) in split_list
 
 
 def _hash_urls(s: tuple):
@@ -96,6 +91,7 @@ def _get_split_list(source: str, split: str):
 
 
 def _load_stories(root: str, source: str, split: str):
+    split_list = set(_get_split_list(source, split))
     story_dp = IterableWrapper([URL[source]])
     cache_compressed_dp = story_dp.on_disk_cache(
         filepath_fn=partial(_filepath_fn, root, source),
@@ -108,7 +104,7 @@ def _load_stories(root: str, source: str, split: str):
         filepath_fn=partial(_extracted_folder_fn, root, source, split)
     )
     cache_decompressed_dp = (
-        FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, source, split))
+        FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, split_list))
     )
     cache_decompressed_dp = cache_decompressed_dp.end_caching(
         mode="wb", filepath_fn=partial(_extracted_filepath_fn, root, source)
diff --git a/torchtext/datasets/imdb.py b/torchtext/datasets/imdb.py
index debb5c06f3..d9962342b4 100644
--- a/torchtext/datasets/imdb.py
+++ b/torchtext/datasets/imdb.py
@@ -32,7 +32,7 @@ def _filepath_fn(root, _=None):
 
 
 def _decompressed_filepath_fn(root, decompressed_folder, split, labels, _=None):
-    return [os.path.join(root, decompressed_folder, split, label) for label in labels]
+    return os.path.join(root, decompressed_folder, split)
 
 
 def _filter_fn(filter_imdb_data, split, t):
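Note: the cnndm.py change above replaces the module-level story_fnames defaultdict
with a split-membership set computed once in _load_stories and bound into the
filter with functools.partial. A minimal standalone sketch of that pattern, with
a hypothetical membership set standing in for the real split list:

    import os
    from functools import partial
    from typing import Set, Tuple

    def _filter_fn(split_list: Set[str], x: Tuple[str, object]) -> bool:
        # Keep only archive members whose basename belongs to the requested split.
        return os.path.basename(x[0]) in split_list

    split_list = {"a.story", "b.story"}     # hypothetical split membership
    keep = partial(_filter_fn, split_list)  # bind the state; no global needed
    print(keep(("stories/a.story", None)))  # True
    print(keep(("stories/junk.txt", None))) # False

Binding the set through partial keeps the filter pure and removes the hidden
coupling between _extracted_folder_fn and _filter_fn.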
diff --git a/torchtext/datasets/iwslt2017.py b/torchtext/datasets/iwslt2017.py
index 5d51c75a62..1691e0c89c 100644
--- a/torchtext/datasets/iwslt2017.py
+++ b/torchtext/datasets/iwslt2017.py
@@ -240,7 +240,24 @@ def IWSLT2017(root=".data", split=("train", "valid", "test"), language_pair=("de
         filepath_fn=partial(_inner_iwslt_tar_filepath_fn, inner_iwslt_tar)
     )
     cache_decompressed_dp = cache_decompressed_dp.open_files(mode="b").load_from_tar()
-    cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
+    # Filenames inside the archive are duplicated, so stray files could otherwise be cached under .tgz names
+
+    def extracted_file_name(inner_iwslt_tar, inner_tar_name):
+        name = os.path.basename(inner_tar_name)
+        path = os.path.dirname(inner_iwslt_tar)
+        return os.path.join(path, name)
+
+    cache_decompressed_dp = cache_decompressed_dp.end_caching(
+        mode="wb", filepath_fn=partial(extracted_file_name, inner_iwslt_tar)
+    )
+    # Now that the paths are corrected, keep only .tgz files and drop dot files
+
+    def leave_only_tgz(file_name):
+        name = os.path.basename(file_name)
+        _, file_extension = os.path.splitext(file_name)
+        return file_extension == ".tgz" and name[0] != "."
+
+    cache_decompressed_dp = cache_decompressed_dp.filter(leave_only_tgz)
     cache_decompressed_dp_1, cache_decompressed_dp_2 = cache_decompressed_dp.fork(num_instances=2)
 
     src_filename = file_path_by_lang_and_split[src_language][split]
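Note: the leave_only_tgz predicate added above is self-contained and easy to
sanity-check outside the datapipe. A quick sketch with hypothetical paths (a
macOS "._" resource-fork entry is the kind of dot file the filter drops):

    import os

    def leave_only_tgz(file_name):
        name = os.path.basename(file_name)
        _, file_extension = os.path.splitext(file_name)
        return file_extension == ".tgz" and name[0] != "."

    candidates = [
        ".data/texts/de/en/de-en.tgz",    # real inner archive, kept
        ".data/texts/de/en/._de-en.tgz",  # dot file, dropped
        ".data/texts/de/en/README",       # wrong extension, dropped
    ]
    print([p for p in candidates if leave_only_tgz(p)])
    # ['.data/texts/de/en/de-en.tgz']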