diff --git a/test/datasets/test_cnndm.py b/test/datasets/test_cnndm.py
index 224329a376..951cc5447e 100644
--- a/test/datasets/test_cnndm.py
+++ b/test/datasets/test_cnndm.py
@@ -41,7 +41,7 @@ def _get_mock_dataset(root_dir):
             stories.append((txt_file, dataset_line))
             seed += 2
 
-        # append stories to correct dataset split, must be in legixographic order of filenames per dataset
+        # append stories to correct dataset split, must be in lexicographic order of filenames per dataset
         stories.sort(key=lambda x: x[0])
         mocked_data[split] += [t[1] for t in stories]
 
@@ -70,15 +70,14 @@ def tearDownClass(cls):
         cls.patcher.stop()
         super().tearDownClass()
 
-    def _mock_split_list(split):
+    def _mock_split_list(source, split):
         story_fnames = []
-        for source in ["cnn", "dailymail"]:
-            for i in range(5):
-                url = "_".join([source, split, str(i)])
-                h = hashlib.sha1()
-                h.update(url.encode())
-                filename = h.hexdigest() + ".story"
-                story_fnames.append(filename)
+        for i in range(5):
+            url = "_".join([source, split, str(i)])
+            h = hashlib.sha1()
+            h.update(url.encode())
+            filename = h.hexdigest() + ".story"
+            story_fnames.append(filename)
 
         return story_fnames
 
@@ -92,6 +91,7 @@ def test_cnndm(self, split):
             self.assertEqual(sample, expected_sample)
 
     @parameterized.expand(["train", "val", "test"])
+    @patch("torchtext.datasets.cnndm._get_split_list", _mock_split_list)
     def test_cnndm_split_argument(self, split):
         dataset1 = CNNDM(root=self.root_dir, split=split)
         (dataset2,) = CNNDM(root=self.root_dir, split=(split,))
diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py
index 3950557f5d..cb638e51a5 100644
--- a/torchtext/datasets/cnndm.py
+++ b/torchtext/datasets/cnndm.py
@@ -1,5 +1,6 @@
 import hashlib
 import os
+from collections import defaultdict
 from functools import partial
 from typing import Union, Tuple
 
@@ -20,9 +21,12 @@
 DATASET_NAME = "CNNDM"
 
 URL_LIST = {
-    "train": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt",
-    "val": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt",
-    "test": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt",
+    "cnn_train": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/cnn_wayback_training_urls.txt",
+    "cnn_val": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/cnn_wayback_validation_urls.txt",
+    "cnn_test": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/cnn_wayback_test_urls.txt",
+    "dailymail_train": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/dailymail_wayback_training_urls.txt",
+    "dailymail_val": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/dailymail_wayback_validation_urls.txt",
+    "dailymail_test": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/dailymail_wayback_test_urls.txt",
 }
 
 STORIES_LIST = {
@@ -39,24 +43,34 @@
 
 _EXTRACTED_FOLDERS = {
     "cnn": os.path.join("cnn", "stories"),
-    "daily_mail": os.path.join("dailymail", "stories"),
+    "dailymail": os.path.join("dailymail", "stories"),
 }
 
+story_fnames = defaultdict(set)
+
 
 def _filepath_fn(root: str, source: str, _=None):
     return os.path.join(root, PATH_LIST[source])
 
 
-# this function will be used to cache the contents of the tar file
-def _extracted_filepath_fn(root: str, source: str):
-    return os.path.join(root, _EXTRACTED_FOLDERS[source])
+# called once per tar file, therefore no duplicate processing
+def _extracted_folder_fn(root: str, source: str, split: str, _=None):
+    global story_fnames
+    key = source + "_" + split
+    story_fnames[key] = set(_get_split_list(source, split))
+    filepaths = [os.path.join(root, _EXTRACTED_FOLDERS[source], story) for story in story_fnames[key]]
+    return filepaths
+
+def _extracted_filepath_fn(root: str, source: str, x: str):
+    return os.path.join(root, _EXTRACTED_FOLDERS[source], os.path.basename(x))
 
 
-def _filter_fn(story_fnames, x):
-    return os.path.basename(x[0]) in story_fnames
+def _filter_fn(source: str, split: str, x: tuple):
+    return os.path.basename(x[0]) in story_fnames[source + "_" + split]
+
 
-def _hash_urls(s):
+
+def _hash_urls(s: tuple):
     """
     Returns story filename as a heximal formated SHA1 hash of the input url string.
     Code is inspired from https://github.com/abisee/cnn-dailymail/blob/master/make_datafiles.py
@@ -69,13 +83,13 @@ def _hash_urls(s):
     return story_fname
 
 
-def _get_split_list(split: str):
-    url_dp = IterableWrapper([URL_LIST[split]])
+def _get_split_list(source: str, split: str):
+    url_dp = IterableWrapper([URL_LIST[source + "_" + split]])
     online_dp = OnlineReader(url_dp)
     return online_dp.readlines().map(fn=_hash_urls)
 
 
-def _load_stories(root: str, source: str):
+def _load_stories(root: str, source: str, split: str):
     story_dp = IterableWrapper([STORIES_LIST[source]])
     cache_compressed_dp = story_dp.on_disk_cache(
         filepath_fn=partial(_filepath_fn, root, source),
@@ -83,9 +97,18 @@ def _load_stories(root: str, source: str):
         hash_type="md5",
     )
     cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
-    # TODO: cache the contents of the extracted tar file
-    cache_decompressed_dp = FileOpener(cache_compressed_dp, mode="b").load_from_tar()
-    return cache_decompressed_dp
+
+    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
+        filepath_fn=partial(_extracted_folder_fn, root, source, split)
+    )
+    cache_decompressed_dp = (
+        FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, source, split))
+    )
+    cache_decompressed_dp = cache_decompressed_dp.end_caching(
+        mode="wb", filepath_fn=partial(_extracted_filepath_fn, root, source)
+    )
+    data_dp = FileOpener(cache_decompressed_dp, mode="b")
+    return data_dp
 
 
 @_create_dataset_directory(dataset_name=DATASET_NAME)
@@ -119,11 +142,7 @@ def CNNDM(root: str, split: Union[Tuple[str], str]):
             "Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
         )
 
-    cnn_dp = _load_stories(root, "cnn")
-    dailymail_dp = _load_stories(root, "dailymail")
+    cnn_dp = _load_stories(root, "cnn", split)
+    dailymail_dp = _load_stories(root, "dailymail", split)
     data_dp = cnn_dp.concat(dailymail_dp)
-    # TODO: store the .story filenames corresponding to each split on disk so we can pass that into the filepath_fn
-    # of the on_disk_cache_dp which caches the files extracted from the tar
-    story_fnames = set(_get_split_list(split))
-    data_dp = data_dp.filter(partial(_filter_fn, story_fnames))
    return data_dp.parse_cnndm_data().shuffle().set_shuffle(False).sharding_filter()
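
Reviewer note, not part of the patch: the new two-level cache in `_load_stories` follows torchdata's `on_disk_cache` / `end_caching` idiom. The `filepath_fn` passed to `on_disk_cache` returns the full list of files expected on disk (here, every `.story` filename for the source/split), so once all of them exist the download, tar open, and extraction are skipped entirely; `end_caching` then writes each surviving tar member to the path produced by its own `filepath_fn`. Below is a minimal self-contained sketch of that idiom, with hypothetical names (`archive.tar`, `a.txt`, `b.txt`, `.cache/`) standing in for the real GDrive tarball and story files:

import os

from torchdata.datapipes.iter import FileOpener, IterableWrapper

ROOT = ".cache"                # hypothetical cache directory
EXPECTED = {"a.txt", "b.txt"}  # hypothetical members, analogous to a split's .story files


def _expected_files_fn(tar_path):
    # Returning a list tells on_disk_cache the cache is hit only when every
    # listed path exists; this mirrors _extracted_folder_fn, which returns
    # all .story paths for the split.
    return [os.path.join(ROOT, name) for name in sorted(EXPECTED)]


def _member_filepath_fn(member_path):
    # end_caching maps each tar member's path to its on-disk target,
    # mirroring _extracted_filepath_fn.
    return os.path.join(ROOT, os.path.basename(member_path))


dp = IterableWrapper(["archive.tar"])                         # hypothetical local tar file
dp = dp.on_disk_cache(filepath_fn=_expected_files_fn)         # cache hit -> skip everything below
dp = FileOpener(dp, mode="b").load_from_tar()                 # yields (member_path, stream) tuples
dp = dp.filter(lambda x: os.path.basename(x[0]) in EXPECTED)  # mirrors _filter_fn
dp = dp.end_caching(mode="wb", filepath_fn=_member_filepath_fn)

The first pass over `dp` extracts and writes `a.txt`/`b.txt` under `.cache/` and yields their paths; later passes serve the cached files without touching the tar, which is the behavior the `_extracted_folder_fn` / `_filter_fn` / `_extracted_filepath_fn` trio gives CNNDM per source and split.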