100 changes: 100 additions & 0 deletions test/datasets/test_cnndm.py
@@ -0,0 +1,100 @@
import hashlib
import os
import tarfile
from collections import defaultdict
from unittest.mock import patch

from parameterized import parameterized
from torchtext.datasets.cnndm import CNNDM

from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode
from ..common.torchtext_test_case import TorchtextTestCase


def _get_mock_dataset(root_dir):
"""
root_dir: directory to the mocked dataset
"""

base_dir = os.path.join(root_dir, "CNNDM")
temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir")
os.makedirs(temp_dataset_dir, exist_ok=True)
seed = 1
mocked_data = defaultdict(list)

for source in ["cnn", "dailymail"]:
source_dir = os.path.join(temp_dataset_dir, source, "stories")
os.makedirs(source_dir, exist_ok=True)
for split in ["train", "val", "test"]:
stories = []
for i in range(5):
url = "_".join([source, split, str(i)])
h = hashlib.sha1()
h.update(url.encode())
filename = h.hexdigest() + ".story"
txt_file = os.path.join(source_dir, filename)
                with open(txt_file, "w", encoding="utf-8") as f:
article = get_random_unicode(seed) + "."
abstract = get_random_unicode(seed + 1) + "."
dataset_line = (article, abstract)
f.writelines([article, "\n@highlight\n", abstract])
stories.append((txt_file, dataset_line))
seed += 2

            # append stories to the correct dataset split; filenames must be in lexicographic order per dataset
stories.sort(key=lambda x: x[0])
mocked_data[split] += [t[1] for t in stories]

compressed_dataset_path = os.path.join(base_dir, f"{source}_stories.tgz")
        # create a tar.gz archive from the dataset folder
with tarfile.open(compressed_dataset_path, "w:gz") as tar:
tar.add(os.path.join(temp_dataset_dir, source), arcname=source)

return mocked_data


class TestCNNDM(TempDirMixin, TorchtextTestCase):
root_dir = None
samples = []

@classmethod
def setUpClass(cls):
super().setUpClass()
cls.root_dir = cls.get_base_temp_dir()
cls.samples = _get_mock_dataset(os.path.join(cls.root_dir, "datasets"))
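        # torchdata verifies archive hashes before extraction; patch the check to always
        # pass so the mocked tarballs created by _get_mock_dataset are accepted.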
cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True)
cls.patcher.start()

@classmethod
def tearDownClass(cls):
cls.patcher.stop()
super().tearDownClass()

def _mock_split_list(split):
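        # Mirrors torchtext.datasets.cnndm._get_split_list: regenerate the SHA1-based
        # .story filenames that _get_mock_dataset wrote for this split.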
story_fnames = []
for source in ["cnn", "dailymail"]:
for i in range(5):
url = "_".join([source, split, str(i)])
h = hashlib.sha1()
h.update(url.encode())
filename = h.hexdigest() + ".story"
story_fnames.append(filename)

return story_fnames

@parameterized.expand(["train", "val", "test"])
@patch("torchtext.datasets.cnndm._get_split_list", _mock_split_list)
def test_cnndm(self, split):
dataset = CNNDM(root=self.root_dir, split=split)
samples = list(dataset)
expected_samples = self.samples[split]
for sample, expected_sample in zip_equal(samples, expected_samples):
self.assertEqual(sample, expected_sample)

@parameterized.expand(["train", "val", "test"])
def test_cnndm_split_argument(self, split):
dataset1 = CNNDM(root=self.root_dir, split=split)
(dataset2,) = CNNDM(root=self.root_dir, split=(split,))

for d1, d2 in zip_equal(dataset1, dataset2):
self.assertEqual(d1, d2)
52 changes: 52 additions & 0 deletions torchtext/data/datasets_utils.py
@@ -310,3 +310,55 @@ def __iter__(self):
columns[i].append(column)
if len(columns) > 0:
yield columns


@functional_datapipe("parse_cnndm_data")
class _ParseCNNDMData(IterDataPipe):
"""Iterable DataPipe to parse the article and abstract from a CNNDM data stream.
Code is inspired from https://github.com/abisee/cnn-dailymail/blob/master/make_datafiles.py"""
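
    # A .story file lists the article sentences first, followed by one or more
    # "@highlight" markers, each introducing a single summary (abstract) sentence.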

    dm_single_close_quote = "\u2019"  # unicode right single quotation mark
    dm_double_close_quote = "\u201d"  # unicode right double quotation mark
# acceptable ways to end a sentence
END_TOKENS = [".", "!", "?", "...", "'", "`", '"', dm_single_close_quote, dm_double_close_quote, ")", "\n"]

def __init__(self, source_datapipe) -> None:
self.source_datapipe = source_datapipe

def _fix_missing_period(self, line):
"""Adds a period to a line that is missing a period"""
if "@highlight" in line:
return line
if line == "":
return line
if line[-1] in self.END_TOKENS:
return line
return line + " ."

def __iter__(self):
for _, stream in self.source_datapipe:
lines = stream.readlines()
lines = [line.decode("utf-8").strip() for line in lines]

            # Put periods at the ends of lines that are missing them. This is a problem in
            # the dataset because many image captions don't end in periods; consequently,
            # they end up in the body of the article as run-on sentences.
lines = [self._fix_missing_period(line) for line in lines]

# Separate out article and abstract sentences
article_lines = []
highlights = []
next_is_highlight = False
for idx, line in enumerate(lines):
if line == "":
continue # empty line
elif line.startswith("@highlight"):
next_is_highlight = True
elif next_is_highlight:
highlights.append(line)
else:
article_lines.append(line)

article = " ".join(article_lines)
abstract = " ".join(highlights)
yield article, abstract
129 changes: 129 additions & 0 deletions torchtext/datasets/cnndm.py
@@ -0,0 +1,129 @@
import hashlib
import os
from functools import partial
from typing import Union, Tuple

from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
_wrap_split_argument,
_create_dataset_directory,
)

if is_module_available("torchdata"):
from torchdata.datapipes.iter import (
FileOpener,
IterableWrapper,
OnlineReader,
GDriveReader,
)

DATASET_NAME = "CNNDM"

URL_LIST = {
"train": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt",
"val": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt",
"test": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt",
}

STORIES_LIST = {
"cnn": "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTHk4NFg2SndKcjQ",
"dailymail": "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfM1BxdkxVaTY2bWs",
}

PATH_LIST = {
"cnn": "cnn_stories.tgz",
"dailymail": "dailymail_stories.tgz",
}

STORIES_MD5 = {"cnn": "85ac23a1926a831e8f46a6b8eaf57263", "dailymail": "f9c5f565e8abe86c38bfa4ae8f96fd72"}

_EXTRACTED_FOLDERS = {
    "cnn": os.path.join("cnn", "stories"),
    "dailymail": os.path.join("dailymail", "stories"),
}


def _filepath_fn(root: str, source: str, _=None):
return os.path.join(root, PATH_LIST[source])


# this function will be used to cache the contents of the tar file
def _extracted_filepath_fn(root: str, source: str):
return os.path.join(root, _EXTRACTED_FOLDERS[source])


def _filter_fn(story_fnames, x):
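    # `x` is a (path-within-archive, stream) pair yielded by load_from_tar, so the
    # basename is the .story filename to check against the split's filename set.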
return os.path.basename(x[0]) in story_fnames


def _hash_urls(s):
"""
    Returns the story filename as a hexadecimal-formatted SHA1 hash of the input url string.
    Code is adapted from https://github.com/abisee/cnn-dailymail/blob/master/make_datafiles.py
"""
url = s[1]
h = hashlib.sha1()
h.update(url)
url_hash = h.hexdigest()
story_fname = url_hash + ".story"
return story_fname


def _get_split_list(split: str):
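    # Download the url list for this split and map every URL to its .story filename.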
url_dp = IterableWrapper([URL_LIST[split]])
online_dp = OnlineReader(url_dp)
return online_dp.readlines().map(fn=_hash_urls)


def _load_stories(root: str, source: str):
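    # Download (and cache) the Google Drive tarball for `source`, verifying its MD5,
    # then open the archive and yield (member path, stream) pairs for each story file.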
story_dp = IterableWrapper([STORIES_LIST[source]])
cache_compressed_dp = story_dp.on_disk_cache(
filepath_fn=partial(_filepath_fn, root, source),
hash_dict={_filepath_fn(root, source): STORIES_MD5[source]},
hash_type="md5",
)
cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
# TODO: cache the contents of the extracted tar file
cache_decompressed_dp = FileOpener(cache_compressed_dp, mode="b").load_from_tar()
return cache_decompressed_dp


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "val", "test"))
def CNNDM(root: str, split: Union[Tuple[str], str]):
"""CNNDM Dataset

.. warning::

Using datapipes is still currently subject to a few caveats. If you wish
to use this dataset with shuffling, multi-processing, or distributed
learning, please see :ref:`this note <datapipes_warnings>` for further
instructions.

For additional details refer to https://arxiv.org/pdf/1704.04368.pdf

Number of lines per split:
- train: 287,227
- val: 13,368
- test: 11,490

Args:
root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache')
split: split or splits to be returned. Can be a string or tuple of strings. Default: (`train`, `val`, `test`)

:returns: DataPipe that yields a tuple of texts containing an article and its abstract (i.e. (article, abstract))
:rtype: (str, str)
"""
if not is_module_available("torchdata"):
raise ModuleNotFoundError(
"Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
)

cnn_dp = _load_stories(root, "cnn")
dailymail_dp = _load_stories(root, "dailymail")
data_dp = cnn_dp.concat(dailymail_dp)
# TODO: store the .story filenames corresponding to each split on disk so we can pass that into the filepath_fn
# of the on_disk_cache_dp which caches the files extracted from the tar
story_fnames = set(_get_split_list(split))
data_dp = data_dp.filter(partial(_filter_fn, story_fnames))
return data_dp.parse_cnndm_data().shuffle().set_shuffle(False).sharding_filter()
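

# A minimal usage sketch, assuming `torchdata` is installed. The first run downloads the
# CNN and DailyMail story archives from Google Drive, so it may take a while.
if __name__ == "__main__":
    train_dp = CNNDM(split="train")  # `root` defaults to ~/.torchtext/cache
    article, abstract = next(iter(train_dp))
    print(article[:80])
    print(abstract[:80])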