diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst
index 077aa27055..5a1cbf8167 100644
--- a/docs/source/datasets.rst
+++ b/docs/source/datasets.rst
@@ -13,7 +13,7 @@ General use cases are as follows: ::
 
     def tokenize(label, line):
         return line.split()
-
+
     tokens = []
     for label, line in train_iter:
        tokens += tokenize(label, line)
@@ -73,6 +73,11 @@ IMDb
 
 .. autofunction:: IMDB
 
+SST2
+~~~~
+
+.. autofunction:: SST2
+
 Language Modeling
 ^^^^^^^^^^^^^^^^^
 
@@ -152,4 +157,3 @@ EnWik9
 ~~~~~~
 
 .. autofunction:: EnWik9
-
diff --git a/test/experimental/test_datasets.py b/test/experimental/test_datasets.py
deleted file mode 100644
index 1868f6ee9d..0000000000
--- a/test/experimental/test_datasets.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import hashlib
-import json
-
-from torchtext.experimental.datasets import sst2
-
-from ..common.assets import _ASSET_DIR
-from ..common.case_utils import skipIfNoModule
-from ..common.torchtext_test_case import TorchtextTestCase
-
-
-class TestDataset(TorchtextTestCase):
-    @skipIfNoModule("torchdata")
-    def test_sst2_dataset(self):
-
-        split = ("train", "dev", "test")
-        train_dataset, dev_dataset, test_dataset = sst2.SST2(
-            split=split, root=_ASSET_DIR, validate_hash=False
-        )
-
-        # verify datasets objects are instances of SST2Dataset
-        for dataset in (train_dataset, dev_dataset, test_dataset):
-            self.assertTrue(isinstance(dataset, sst2.SST2Dataset))
-
-        # verify hashes of first line in dataset
-        self.assertEqual(
-            hashlib.md5(
-                json.dumps(next(iter(train_dataset)), sort_keys=True).encode("utf-8")
-            ).hexdigest(),
-            sst2._FIRST_LINE_MD5["train"],
-        )
-        self.assertEqual(
-            hashlib.md5(
-                json.dumps(next(iter(dev_dataset)), sort_keys=True).encode("utf-8")
-            ).hexdigest(),
-            sst2._FIRST_LINE_MD5["dev"],
-        )
-        self.assertEqual(
-            hashlib.md5(
-                json.dumps(next(iter(test_dataset)), sort_keys=True).encode("utf-8")
-            ).hexdigest(),
-            sst2._FIRST_LINE_MD5["test"],
-        )
diff --git a/torchtext/datasets/__init__.py b/torchtext/datasets/__init__.py
index 995fc96a89..5fda4a8451 100644
--- a/torchtext/datasets/__init__.py
+++ b/torchtext/datasets/__init__.py
@@ -1,4 +1,5 @@
 import importlib
+
 from .ag_news import AG_NEWS
 from .amazonreviewfull import AmazonReviewFull
 from .amazonreviewpolarity import AmazonReviewPolarity
@@ -8,39 +9,41 @@
 from .imdb import IMDB
 from .iwslt2016 import IWSLT2016
 from .iwslt2017 import IWSLT2017
+from .multi30k import Multi30k
 from .penntreebank import PennTreebank
 from .sogounews import SogouNews
 from .squad1 import SQuAD1
 from .squad2 import SQuAD2
+from .sst2 import SST2
 from .udpos import UDPOS
 from .wikitext103 import WikiText103
 from .wikitext2 import WikiText2
 from .yahooanswers import YahooAnswers
 from .yelpreviewfull import YelpReviewFull
 from .yelpreviewpolarity import YelpReviewPolarity
-from .multi30k import Multi30k
 
 DATASETS = {
-    'AG_NEWS': AG_NEWS,
-    'AmazonReviewFull': AmazonReviewFull,
-    'AmazonReviewPolarity': AmazonReviewPolarity,
-    'CoNLL2000Chunking': CoNLL2000Chunking,
-    'DBpedia': DBpedia,
-    'EnWik9': EnWik9,
-    'IMDB': IMDB,
-    'IWSLT2016': IWSLT2016,
-    'IWSLT2017': IWSLT2017,
-    'PennTreebank': PennTreebank,
-    'SQuAD1': SQuAD1,
-    'SQuAD2': SQuAD2,
-    'SogouNews': SogouNews,
-    'UDPOS': UDPOS,
-    'WikiText103': WikiText103,
-    'WikiText2': WikiText2,
-    'YahooAnswers': YahooAnswers,
-    'YelpReviewFull': YelpReviewFull,
-    'YelpReviewPolarity': YelpReviewPolarity,
-    'Multi30k': Multi30k
+    "AG_NEWS": AG_NEWS,
+    "AmazonReviewFull": AmazonReviewFull,
+    "AmazonReviewPolarity": AmazonReviewPolarity,
+    "CoNLL2000Chunking": CoNLL2000Chunking,
+    "DBpedia": DBpedia,
+    "EnWik9": EnWik9,
+    "IMDB": IMDB,
+    "IWSLT2016": IWSLT2016,
+    "IWSLT2017": IWSLT2017,
+    "Multi30k": Multi30k,
+    "PennTreebank": PennTreebank,
+    "SQuAD1": SQuAD1,
+    "SQuAD2": SQuAD2,
+    "SogouNews": SogouNews,
+    "SST2": SST2,
+    "UDPOS": UDPOS,
+    "WikiText103": WikiText103,
+    "WikiText2": WikiText2,
+    "YahooAnswers": YahooAnswers,
+    "YelpReviewFull": YelpReviewFull,
+    "YelpReviewPolarity": YelpReviewPolarity,
 }
 
 URLS = {}
diff --git a/torchtext/datasets/sst2.py b/torchtext/datasets/sst2.py
new file mode 100644
index 0000000000..45cae58c9c
--- /dev/null
+++ b/torchtext/datasets/sst2.py
@@ -0,0 +1,82 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import os
+
+from torchtext._internal.module_utils import is_module_available
+from torchtext.data.datasets_utils import (
+    _add_docstring_header,
+    _create_dataset_directory,
+    _wrap_split_argument,
+)
+
+if is_module_available("torchdata"):
+    from torchdata.datapipes.iter import IterableWrapper, FileOpener
+
+    # we import HttpReader from _download_hooks so we can swap out public URLs
+    # with internal URLs when the dataset is used within Facebook
+    from torchtext._download_hooks import HttpReader
+
+
+URL = "https://dl.fbaipublicfiles.com/glue/data/SST-2.zip"
+
+MD5 = "9f81648d4199384278b86e315dac217c"
+
+NUM_LINES = {
+    "train": 67349,
+    "dev": 872,
+    "test": 1821,
+}
+
+_PATH = "SST-2.zip"
+
+DATASET_NAME = "SST2"
+
+_EXTRACTED_FILES = {
+    "train": os.path.join("SST-2", "train.tsv"),
+    "dev": os.path.join("SST-2", "dev.tsv"),
+    "test": os.path.join("SST-2", "test.tsv"),
+}
+
+
+@_add_docstring_header(num_lines=NUM_LINES, num_classes=2)
+@_create_dataset_directory(dataset_name=DATASET_NAME)
+@_wrap_split_argument(("train", "dev", "test"))
+def SST2(root, split):
+    # TODO Remove this after removing conditional dependency
+    if not is_module_available("torchdata"):
+        raise ModuleNotFoundError(
+            "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
+        )
+
+    url_dp = IterableWrapper([URL])
+    cache_compressed_dp = url_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, os.path.basename(x)),
+        hash_dict={os.path.join(root, os.path.basename(URL)): MD5},
+        hash_type="md5",
+    )
+    cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(
+        mode="wb", same_filepath_fn=True
+    )
+
+    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split])
+    )
+    cache_decompressed_dp = (
+        FileOpener(cache_decompressed_dp, mode="b")
+        .read_from_zip()
+        .filter(lambda x: _EXTRACTED_FILES[split] in x[0])
+    )
+    cache_decompressed_dp = cache_decompressed_dp.end_caching(
+        mode="wb", same_filepath_fn=True
+    )
+
+    data_dp = FileOpener(cache_decompressed_dp, mode="b")
+    # test split for SST2 doesn't have labels
+    if split == "test":
+        parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").map(
+            lambda t: (t[1].strip(),)
+        )
+    else:
+        parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").map(
+            lambda t: (t[0].strip(), int(t[1]))
+        )
+    return parsed_data
diff --git a/torchtext/experimental/datasets/__init__.py b/torchtext/experimental/datasets/__init__.py
index 81bc90a801..08e78af7b9 100644
--- a/torchtext/experimental/datasets/__init__.py
+++ b/torchtext/experimental/datasets/__init__.py
@@ -1,4 +1,3 @@
 from . import raw
-from . import sst2
 
-__all__ = ["raw", "sst2"]
+__all__ = ["raw"]
diff --git a/torchtext/experimental/datasets/sst2.py b/torchtext/experimental/datasets/sst2.py
deleted file mode 100644
index 8ba1dd3ad8..0000000000
--- a/torchtext/experimental/datasets/sst2.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-import os
-
-from torch.utils.data import IterDataPipe
-from torchtext._internal.module_utils import is_module_available
-from torchtext.data.datasets_utils import (
-    _add_docstring_header,
-    _create_dataset_directory,
-    _wrap_split_argument,
-)
-
-if is_module_available("torchdata"):
-    from torchdata.datapipes.iter import IterableWrapper, FileOpener
-
-    # we import HttpReader from _download_hooks so we can swap out public URLs
-    # with interal URLs when the dataset is used within Facebook
-    from torchtext._download_hooks import HttpReader
-
-
-NUM_LINES = {
-    "train": 67349,
-    "dev": 872,
-    "test": 1821,
-}
-
-MD5 = "9f81648d4199384278b86e315dac217c"
-URL = "https://dl.fbaipublicfiles.com/glue/data/SST-2.zip"
-
-_PATH = "SST-2.zip"
-
-_EXTRACTED_FILES = {
-    "train": os.path.join(_PATH, "SST-2", "train.tsv"),
-    "dev": os.path.join(_PATH, "SST-2", "dev.tsv"),
-    "test": os.path.join(_PATH, "SST-2", "test.tsv"),
-}
-
-_EXTRACTED_FILES_MD5 = {
-    "train": "da409a0a939379ed32a470bc0f7fe99a",
-    "dev": "268856b487b2a31a28c0a93daaff7288",
-    "test": "3230e4efec76488b87877a56ae49675a",
-}
-
-_FIRST_LINE_MD5 = {
-    "train": "2552b8cecd57b2e022ef23411c688fa8",
-    "dev": "1b0ffd6aa5f2bf0fd9840a5f6f1a9f07",
-    "test": "3e7ff69ab3fc6d026e3c96cadd8b0b53",
-}
-
-DATASET_NAME = "SST2"
-
-
-@_add_docstring_header(num_lines=NUM_LINES, num_classes=2)
-@_create_dataset_directory(dataset_name=DATASET_NAME)
-@_wrap_split_argument(("train", "dev", "test"))
-def SST2(root, split, validate_hash=True):
-    return SST2Dataset(root, split, validate_hash=validate_hash)
-
-
-class SST2Dataset(IterDataPipe):
-    """The SST2 dataset uses torchdata datapipes end-2-end.
-    To avoid download at every epoch, we cache the data on-disk
-    We do sanity check on dowloaded and extracted data
-    """
-
-    def __init__(self, root, split, validate_hash=True):
-        if not is_module_available("torchdata"):
-            raise ModuleNotFoundError(
-                "Package `torchdata` is required to be installed to use this dataset."
-                "Please refer to https://github.com/pytorch/data for instructions on "
-                "how to install the package."
-            )
-
-        self._dp = self._get_datapipe(root, split, validate_hash)
-
-    def __iter__(self):
-        for data in self._dp:
-            yield data
-
-    def _get_datapipe(self, root, split, validate_hash):
-        # Validate integrity of dataset using md5 checksum
-        hash_dict = {os.path.join(root, "SST-2.zip"): MD5} if validate_hash else None
-        hash_type = "md5" if validate_hash else None
-
-        # cache data on-disk
-        cache_dp = IterableWrapper([URL]).on_disk_cache(
-            filepath_fn=lambda x: os.path.join(root, os.path.basename(x)),
-            hash_dict=hash_dict,
-            hash_type=hash_type,
-        )
-        cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
-
-        # Load from cached file
-        cache_dp = FileOpener(cache_dp, mode="rb")
-        # extract data from zip
-        extracted_files = cache_dp.read_from_zip().filter(
-            lambda x: f"{split}.tsv" in x[0]
-        )
-
-        # Parse CSV file and yield data samples
-        if split == "test":
-            parsed_data = extracted_files.parse_csv(skip_lines=1, delimiter="\t").map(
-                lambda x: (x[1],)
-            )
-        else:
-            parsed_data = extracted_files.parse_csv(skip_lines=1, delimiter="\t").map(
-                lambda x: (x[0], x[1])
-            )
-
-        return parsed_data
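
Note for reviewers: below is a quick smoke test for the migrated dataset. It is an illustrative sketch, not part of the patch, and assumes torchdata is installed and torchtext is built from this branch. The tuple shapes follow the map() lambdas in torchtext/datasets/sst2.py above: (text, label) for "train"/"dev", and a 1-tuple (text,) for the unlabeled "test" split.

    # Illustrative smoke test; not part of this patch.
    # Assumes torchdata is installed and torchtext is built from this branch.
    from torchtext.datasets import SST2

    # "train" and "dev" yield (text, label) tuples, per the map() in sst2.py
    train_dp = SST2(split="train")
    text, label = next(iter(train_dp))
    print(label, text)

    # the GLUE "test" split ships without labels, so it yields (text,) 1-tuples
    test_dp = SST2(split="test")
    (text,) = next(iter(test_dp))
    print(text)

On the first run the datapipe downloads SST-2.zip (validated against the MD5 above), extracts the split's TSV, and caches both on disk via on_disk_cache/end_caching, so subsequent epochs read from the local cache instead of re-downloading.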