18 changes: 9 additions & 9 deletions test/datasets/test_cnndm.py
@@ -41,7 +41,7 @@ def _get_mock_dataset(root_dir):
stories.append((txt_file, dataset_line))
seed += 2

# append stories to correct dataset split, must be in legixographic order of filenames per dataset
# append stories to correct dataset split, must be in lexicographic order of filenames per dataset
stories.sort(key=lambda x: x[0])
mocked_data[split] += [t[1] for t in stories]

@@ -70,15 +70,14 @@ def tearDownClass(cls):
cls.patcher.stop()
super().tearDownClass()

def _mock_split_list(split):
def _mock_split_list(source, split):
story_fnames = []
for source in ["cnn", "dailymail"]:
for i in range(5):
url = "_".join([source, split, str(i)])
h = hashlib.sha1()
h.update(url.encode())
filename = h.hexdigest() + ".story"
story_fnames.append(filename)
for i in range(5):
url = "_".join([source, split, str(i)])
h = hashlib.sha1()
h.update(url.encode())
filename = h.hexdigest() + ".story"
story_fnames.append(filename)

return story_fnames

@@ -92,6 +91,7 @@ def test_cnndm(self, split):
self.assertEqual(sample, expected_sample)

@parameterized.expand(["train", "val", "test"])
@patch("torchtext.datasets.cnndm._get_split_list", _mock_split_list)
def test_cnndm_split_argument(self, split):
dataset1 = CNNDM(root=self.root_dir, split=split)
(dataset2,) = CNNDM(root=self.root_dir, split=(split,))
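For reference, a minimal sketch (not part of this PR) of the filename derivation that both `_mock_split_list` above and `_hash_urls` in the dataset rely on; `story_filename` is a hypothetical helper name used only for illustration:

import hashlib

def story_filename(source: str, split: str, i: int) -> str:
    # Build the same "<source>_<split>_<i>" url string the mock uses, then
    # name the story file after its SHA-1 hex digest.
    url = "_".join([source, split, str(i)])
    h = hashlib.sha1()
    h.update(url.encode())
    return h.hexdigest() + ".story"

print(story_filename("cnn", "train", 0))  # 40-character hex digest + ".story"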
63 changes: 41 additions & 22 deletions torchtext/datasets/cnndm.py
@@ -1,5 +1,6 @@
import hashlib
import os
from collections import defaultdict
from functools import partial
from typing import Union, Tuple

@@ -20,9 +21,12 @@
DATASET_NAME = "CNNDM"

URL_LIST = {
"train": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt",
"val": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt",
"test": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt",
"cnn_train": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/cnn_wayback_training_urls.txt",
"cnn_val": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/cnn_wayback_validation_urls.txt",
"cnn_test": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/cnn_wayback_test_urls.txt",
"dailymail_train": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/dailymail_wayback_training_urls.txt",
"dailymail_val": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/dailymail_wayback_validation_urls.txt",
"dailymail_test": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/dailymail_wayback_test_urls.txt",
}

STORIES_LIST = {
@@ -39,24 +43,34 @@

_EXTRACTED_FOLDERS = {
"cnn": os.path.join("cnn", "stories"),
"daily_mail": os.path.join("dailymail", "stories"),
"dailymail": os.path.join("dailymail", "stories"),
}

story_fnames = defaultdict(set)


def _filepath_fn(root: str, source: str, _=None):
return os.path.join(root, PATH_LIST[source])


# this function will be used to cache the contents of the tar file
def _extracted_filepath_fn(root: str, source: str):
return os.path.join(root, _EXTRACTED_FOLDERS[source])
# called once per tar file, therefore no duplicate processing
def _extracted_folder_fn(root: str, source: str, split: str, _=None):
global story_fnames
key = source + "_" + split
story_fnames[key] = set(_get_split_list(source, split))
filepaths = [os.path.join(root, _EXTRACTED_FOLDERS[source], story) for story in story_fnames[key]]
return filepaths


def _extracted_filepath_fn(root: str, source: str, x: str):
return os.path.join(root, _EXTRACTED_FOLDERS[source], os.path.basename(x))

def _filter_fn(story_fnames, x):
return os.path.basename(x[0]) in story_fnames

def _filter_fn(source: str, split: str, x: tuple):
return os.path.basename(x[0]) in story_fnames[source + "_" + split]

def _hash_urls(s):

def _hash_urls(s: tuple):
"""
Returns story filename as a heximal formated SHA1 hash of the input url string.
Code is inspired from https://github.com/abisee/cnn-dailymail/blob/master/make_datafiles.py
@@ -69,23 +83,32 @@ def _hash_urls(s):
return story_fname


def _get_split_list(split: str):
url_dp = IterableWrapper([URL_LIST[split]])
def _get_split_list(source: str, split: str):
url_dp = IterableWrapper([URL_LIST[source + "_" + split]])
online_dp = OnlineReader(url_dp)
return online_dp.readlines().map(fn=_hash_urls)


def _load_stories(root: str, source: str):
def _load_stories(root: str, source: str, split: str):
story_dp = IterableWrapper([STORIES_LIST[source]])
cache_compressed_dp = story_dp.on_disk_cache(
filepath_fn=partial(_filepath_fn, root, source),
hash_dict={_filepath_fn(root, source): STORIES_MD5[source]},
hash_type="md5",
)
cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
# TODO: cache the contents of the extracted tar file
cache_decompressed_dp = FileOpener(cache_compressed_dp, mode="b").load_from_tar()
return cache_decompressed_dp

cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
filepath_fn=partial(_extracted_folder_fn, root, source, split)
)
cache_decompressed_dp = (
FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, source, split))
)
cache_decompressed_dp = cache_decompressed_dp.end_caching(
mode="wb", filepath_fn=partial(_extracted_filepath_fn, root, source)
)
data_dp = FileOpener(cache_decompressed_dp, mode="b")
return data_dp


@_create_dataset_directory(dataset_name=DATASET_NAME)
@@ -119,11 +142,7 @@ def CNNDM(root: str, split: Union[Tuple[str], str]):
"Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
)

cnn_dp = _load_stories(root, "cnn")
dailymail_dp = _load_stories(root, "dailymail")
cnn_dp = _load_stories(root, "cnn", split)
dailymail_dp = _load_stories(root, "dailymail", split)
data_dp = cnn_dp.concat(dailymail_dp)
# TODO: store the .story filenames corresponding to each split on disk so we can pass that into the filepath_fn
# of the on_disk_cache_dp which caches the files extracted from the tar
story_fnames = set(_get_split_list(split))
data_dp = data_dp.filter(partial(_filter_fn, story_fnames))
return data_dp.parse_cnndm_data().shuffle().set_shuffle(False).sharding_filter()
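For context, a hedged usage sketch of the dataset entry point after this change; the field order of the yielded tuples is assumed here rather than shown in this diff, and the download requires the Google Drive hosted archives to be reachable:

from torchtext.datasets import CNNDM

# A single split returns one datapipe; a tuple of splits returns a tuple of
# datapipes, mirroring test_cnndm_split_argument above.
train_dp = CNNDM(root=".data", split="train")
val_dp, test_dp = CNNDM(root=".data", split=("val", "test"))

for article, abstract in train_dp:  # (article, abstract) order assumed
    print(article[:80])
    print(abstract[:80])
    break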