This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Add CNN-DM dataset to torchtext #1789
Merged
Changes from all commits (24 commits)
d6a985d - Add CNN-DM dataset (pmabbo13)
56f7ac4 - Load url_list and stories (pmabbo13)
9f53689 - Convert cnndm output to datapipes (Nayef211)
9ed7e24 - Merge branch 'feature/add-cnndm' of https://github.com/pmabbo13/text … (pmabbo13)
b9c816e - run pre-commit (pmabbo13)
b3deaa9 - Merge branch 'feature/add-cnndm' into feature/cnndm (pmabbo13)
0fcb0cd - Merge pull request #2 from Nayef211/feature/cnndm (pmabbo13)
a46b28c - Load and filter tar files (pmabbo13)
5588c16 - pre-commit (pmabbo13)
be5a244 - Create datapipe to parse out article and abstract (pmabbo13)
2fa9018 - pre-commit run (pmabbo13)
e04fe45 - Turn url list helper functions into datapipes (pmabbo13)
c9e7ffc - pre-commit (pmabbo13)
b495353 - Add dataset documentation and nit corrections (pmabbo13)
0af3243 - Implement testing (pmabbo13)
4342a4e - Remove empty lines (pmabbo13)
e9381f9 - Test split argument (pmabbo13)
4415bb4 - pre-commit (pmabbo13)
9687dae - Strip newlines to see if passes windows unittest (pmabbo13)
3eeb81c - Remove print statements (pmabbo13)
8b1fd91 - Add more mock examples (pmabbo13)
da45057 - Turn parse_cnndm_split datapipe into hiddent fxn applied using map (pmabbo13)
a5a6a9c - pre-commit (pmabbo13)
75e2760 - Merge branch 'main' into feature/add-cnndm (pmabbo13)
New file (unit tests for the CNNDM dataset): @@ -0,0 +1,100 @@
import hashlib
import os
import tarfile
from collections import defaultdict
from unittest.mock import patch

from parameterized import parameterized
from torchtext.datasets.cnndm import CNNDM

from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode
from ..common.torchtext_test_case import TorchtextTestCase


def _get_mock_dataset(root_dir):
    """
    root_dir: directory where the mocked dataset is created
    """
    base_dir = os.path.join(root_dir, "CNNDM")
    temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir")
    os.makedirs(temp_dataset_dir, exist_ok=True)
    seed = 1
    mocked_data = defaultdict(list)

    for source in ["cnn", "dailymail"]:
        source_dir = os.path.join(temp_dataset_dir, source, "stories")
        os.makedirs(source_dir, exist_ok=True)
        for split in ["train", "val", "test"]:
            stories = []
            for i in range(5):
                url = "_".join([source, split, str(i)])
                h = hashlib.sha1()
                h.update(url.encode())
                filename = h.hexdigest() + ".story"
                txt_file = os.path.join(source_dir, filename)
                with open(txt_file, "w", encoding="utf-8") as f:
                    article = get_random_unicode(seed) + "."
                    abstract = get_random_unicode(seed + 1) + "."
                    dataset_line = (article, abstract)
                    f.writelines([article, "\n@highlight\n", abstract])
                stories.append((txt_file, dataset_line))
                seed += 2

            # append stories to the correct dataset split; filenames must be in lexicographic order per dataset
            stories.sort(key=lambda x: x[0])
            mocked_data[split] += [t[1] for t in stories]

        compressed_dataset_path = os.path.join(base_dir, f"{source}_stories.tgz")
        # create a tar.gz archive from the dataset folder
        with tarfile.open(compressed_dataset_path, "w:gz") as tar:
            tar.add(os.path.join(temp_dataset_dir, source), arcname=source)

    return mocked_data


class TestCNNDM(TempDirMixin, TorchtextTestCase):
    root_dir = None
    samples = []

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.root_dir = cls.get_base_temp_dir()
        cls.samples = _get_mock_dataset(os.path.join(cls.root_dir, "datasets"))
        cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True)
        cls.patcher.start()

    @classmethod
    def tearDownClass(cls):
        cls.patcher.stop()
        super().tearDownClass()

    def _mock_split_list(split):
        story_fnames = []
        for source in ["cnn", "dailymail"]:
            for i in range(5):
                url = "_".join([source, split, str(i)])
                h = hashlib.sha1()
                h.update(url.encode())
                filename = h.hexdigest() + ".story"
                story_fnames.append(filename)

        return story_fnames

    @parameterized.expand(["train", "val", "test"])
    @patch("torchtext.datasets.cnndm._get_split_list", _mock_split_list)
    def test_cnndm(self, split):
        dataset = CNNDM(root=self.root_dir, split=split)
        samples = list(dataset)
        expected_samples = self.samples[split]
        for sample, expected_sample in zip_equal(samples, expected_samples):
            self.assertEqual(sample, expected_sample)

    @parameterized.expand(["train", "val", "test"])
    def test_cnndm_split_argument(self, split):
        dataset1 = CNNDM(root=self.root_dir, split=split)
        (dataset2,) = CNNDM(root=self.root_dir, split=(split,))

        for d1, d2 in zip_equal(dataset1, dataset2):
            self.assertEqual(d1, d2)
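For reference, the mocked .story files written by the test above follow the raw CNN/DailyMail layout: the article body, then an "@highlight" marker introducing each summary sentence. The PR parses real story files with a separate parse_cnndm_data datapipe that is not shown in this diff; the snippet below is only a rough, hypothetical sketch of what splitting such a file into an (article, abstract) pair involves.

def _split_story(raw: str):
    # Hypothetical helper for illustration only; it is NOT the PR's
    # parse_cnndm_data implementation. The article text precedes the first
    # "@highlight"; each later chunk is one sentence of the abstract.
    parts = raw.split("@highlight")
    article = parts[0].strip()
    abstract = " ".join(part.strip() for part in parts[1:])
    return article, abstract


# Matches the layout the mock test writes with f.writelines(...):
print(_split_story("Some article text.\n@highlight\nA one-line summary."))
# -> ('Some article text.', 'A one-line summary.')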
New file (torchtext/datasets/cnndm.py): @@ -0,0 +1,129 @@
import hashlib
import os
from functools import partial
from typing import Union, Tuple

from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
    _wrap_split_argument,
    _create_dataset_directory,
)

if is_module_available("torchdata"):
    from torchdata.datapipes.iter import (
        FileOpener,
        IterableWrapper,
        OnlineReader,
        GDriveReader,
    )

DATASET_NAME = "CNNDM"

URL_LIST = {
    "train": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt",
    "val": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt",
    "test": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt",
}

STORIES_LIST = {
    "cnn": "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTHk4NFg2SndKcjQ",
    "dailymail": "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfM1BxdkxVaTY2bWs",
}

PATH_LIST = {
    "cnn": "cnn_stories.tgz",
    "dailymail": "dailymail_stories.tgz",
}

STORIES_MD5 = {"cnn": "85ac23a1926a831e8f46a6b8eaf57263", "dailymail": "f9c5f565e8abe86c38bfa4ae8f96fd72"}

_EXTRACTED_FOLDERS = {
    "cnn": os.path.join("cnn", "stories"),
    "dailymail": os.path.join("dailymail", "stories"),
}


def _filepath_fn(root: str, source: str, _=None):
    return os.path.join(root, PATH_LIST[source])


# this function will be used to cache the contents of the tar file
def _extracted_filepath_fn(root: str, source: str):
    return os.path.join(root, _EXTRACTED_FOLDERS[source])


def _filter_fn(story_fnames, x):
    return os.path.basename(x[0]) in story_fnames


def _hash_urls(s):
    """
    Returns the story filename as a hexadecimal-formatted SHA1 hash of the input url string.
    Code is inspired by https://github.com/abisee/cnn-dailymail/blob/master/make_datafiles.py
    """
    url = s[1]
    h = hashlib.sha1()
    h.update(url)
    url_hash = h.hexdigest()
    story_fname = url_hash + ".story"
    return story_fname


def _get_split_list(split: str):
    url_dp = IterableWrapper([URL_LIST[split]])
    online_dp = OnlineReader(url_dp)
    return online_dp.readlines().map(fn=_hash_urls)


def _load_stories(root: str, source: str):
    story_dp = IterableWrapper([STORIES_LIST[source]])
    cache_compressed_dp = story_dp.on_disk_cache(
        filepath_fn=partial(_filepath_fn, root, source),
        hash_dict={_filepath_fn(root, source): STORIES_MD5[source]},
        hash_type="md5",
    )
    cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
    # TODO: cache the contents of the extracted tar file
    cache_decompressed_dp = FileOpener(cache_compressed_dp, mode="b").load_from_tar()
    return cache_decompressed_dp


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "val", "test"))
def CNNDM(root: str, split: Union[Tuple[str], str]):
    """CNNDM Dataset

    .. warning::

        Using datapipes is still currently subject to a few caveats. If you wish
        to use this dataset with shuffling, multi-processing, or distributed
        learning, please see :ref:`this note <datapipes_warnings>` for further
        instructions.

    For additional details refer to https://arxiv.org/pdf/1704.04368.pdf

    Number of lines per split:
        - train: 287,227
        - val: 13,368
        - test: 11,490

    Args:
        root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache')
        split: split or splits to be returned. Can be a string or tuple of strings. Default: (`train`, `val`, `test`)

    :returns: DataPipe that yields a tuple of texts containing an article and its abstract (i.e. (article, abstract))
    :rtype: (str, str)
    """
    if not is_module_available("torchdata"):
        raise ModuleNotFoundError(
            "Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data"
        )

    cnn_dp = _load_stories(root, "cnn")
    dailymail_dp = _load_stories(root, "dailymail")
    data_dp = cnn_dp.concat(dailymail_dp)
    # TODO: store the .story filenames corresponding to each split on disk so we can pass that into the filepath_fn
    # of the on_disk_cache_dp which caches the files extracted from the tar
    story_fnames = set(_get_split_list(split))
    data_dp = data_dp.filter(partial(_filter_fn, story_fnames))
    return data_dp.parse_cnndm_data().shuffle().set_shuffle(False).sharding_filter()
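Finally, a minimal usage sketch of the new dataset, assuming torchdata is installed and the CNN and DailyMail archives can be fetched from the Google Drive links above (a several-hundred-megabyte download on first use). This example is illustrative and is not part of the diff.

from torchtext.datasets.cnndm import CNNDM

# Illustrative only: load the validation split and look at the first
# (article, abstract) pair. Passing a tuple of split names instead
# returns one datapipe per split, as exercised in the tests above.
val_dp = CNNDM(split="val")
for article, abstract in val_dp:
    print(article[:80])
    print(abstract[:80])
    break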