diff --git a/.circleci/config.yml b/.circleci/config.yml
index 21bcfb9954..bc3f09d043 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -497,7 +497,7 @@ jobs:
             - v1-windows-dataset-vector-{{ checksum ".cachekey" }}
             - v1-windows-dataset-{{ checksum ".cachekey" }}
-
+
       - run:
           name: Run tests
           # Downloading embedding vector takes long time.
diff --git a/.circleci/config.yml.in b/.circleci/config.yml.in
index 971f4eb971..911295217b 100644
--- a/.circleci/config.yml.in
+++ b/.circleci/config.yml.in
@@ -497,7 +497,7 @@ jobs:
             - v1-windows-dataset-vector-{{ checksum ".cachekey" }}
             - v1-windows-dataset-{{ checksum ".cachekey" }}
       {% endraw %}
-
+
       - run:
           name: Run tests
           # Downloading embedding vector takes long time.
diff --git a/.circleci/unittest/linux/scripts/install.sh b/.circleci/unittest/linux/scripts/install.sh
index e9201b266b..a3ecba2770 100755
--- a/.circleci/unittest/linux/scripts/install.sh
+++ b/.circleci/unittest/linux/scripts/install.sh
@@ -13,6 +13,9 @@ conda activate ./env
 printf "* Installing PyTorch\n"
 conda install -y -c "pytorch-${UPLOAD_CHANNEL}" ${CONDA_CHANNEL_FLAGS} pytorch cpuonly
 
+printf "* Installing torchdata from source\n"
+pip install git+https://github.com/pytorch/data.git
+
 printf "* Installing torchtext\n"
 git submodule update --init --recursive
 python setup.py develop
diff --git a/.circleci/unittest/windows/scripts/install.sh b/.circleci/unittest/windows/scripts/install.sh
index 622ebc1cd1..1922b9a78f 100644
--- a/.circleci/unittest/windows/scripts/install.sh
+++ b/.circleci/unittest/windows/scripts/install.sh
@@ -18,6 +18,9 @@ conda activate ./env
 printf "* Installing PyTorch\n"
 conda install -y -c "pytorch-${UPLOAD_CHANNEL}" ${CONDA_CHANNEL_FLAGS} pytorch cpuonly
 
+printf "* Installing torchdata from source\n"
+pip install git+https://github.com/pytorch/data.git
+
 printf "* Installing torchtext\n"
 git submodule update --init --recursive
 "$root_dir/packaging/vc_env_helper.bat" python setup.py develop
diff --git a/test/common/case_utils.py b/test/common/case_utils.py
new file mode 100644
index 0000000000..03eec2627f
--- /dev/null
+++ b/test/common/case_utils.py
@@ -0,0 +1,7 @@
+import unittest
+from torchtext._internal.module_utils import is_module_available
+
+
+def skipIfNoModule(module, display_name=None):
+    display_name = display_name or module
+    return unittest.skipIf(not is_module_available(module), f'"{display_name}" is not available')
diff --git a/test/experimental/test_datasets.py b/test/experimental/test_datasets.py
new file mode 100644
index 0000000000..2a9ff700ff
--- /dev/null
+++ b/test/experimental/test_datasets.py
@@ -0,0 +1,34 @@
+import hashlib
+import json
+
+from torchtext.experimental.datasets import sst2
+
+from ..common.case_utils import skipIfNoModule
+from ..common.torchtext_test_case import TorchtextTestCase
+
+
+class TestDataset(TorchtextTestCase):
+    @skipIfNoModule("torchdata")
+    def test_sst2_dataset(self):
+        split = ("train", "dev", "test")
+        train_dp, dev_dp, test_dp = sst2.SST2(split=split)
+
+        # verify the MD5 hash of the first line of each split
+        self.assertEqual(
+            hashlib.md5(
+                json.dumps(next(iter(train_dp)), sort_keys=True).encode("utf-8")
+            ).hexdigest(),
+            sst2._FIRST_LINE_MD5["train"],
+        )
+        self.assertEqual(
+            hashlib.md5(
+                json.dumps(next(iter(dev_dp)), sort_keys=True).encode("utf-8")
+            ).hexdigest(),
+            sst2._FIRST_LINE_MD5["dev"],
+        )
+        self.assertEqual(
+            hashlib.md5(
+                json.dumps(next(iter(test_dp)), sort_keys=True).encode("utf-8")
+            ).hexdigest(),
+            sst2._FIRST_LINE_MD5["test"],
+        )
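All three assertions in test_sst2_dataset follow the same pattern; as a reading aid, each check is equivalent to this helper (a hypothetical sketch, not part of the change):

    import hashlib
    import json

    def first_line_md5(datapipe):
        # Hash the canonical JSON form of the first sample only, so the test
        # exercises the whole pipeline without iterating a full split.
        first_sample = next(iter(datapipe))
        encoded = json.dumps(first_sample, sort_keys=True).encode("utf-8")
        return hashlib.md5(encoded).hexdigest()

Pinning only the first line keeps the test fast while still catching download, extraction, and parsing regressions.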
diff --git a/torchtext/_internal/__init__.py b/torchtext/_internal/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/torchtext/_internal/module_utils.py b/torchtext/_internal/module_utils.py
new file mode 100644
index 0000000000..33ac388bc4
--- /dev/null
+++ b/torchtext/_internal/module_utils.py
@@ -0,0 +1,11 @@
+import importlib.util
+
+
+def is_module_available(*modules: str) -> bool:
+    r"""Returns whether each of the given top-level :attr:`modules` exists
+    **without** importing it. This is generally safer than a try-except block
+    around an `import X`. It avoids third-party libraries breaking assumptions
+    of some of our tests, e.g., setting the multiprocessing start method when
+    imported (see librosa/#747, torchvision/#544).
+    """
+    return all(importlib.util.find_spec(m) is not None for m in modules)
diff --git a/torchtext/experimental/datasets/__init__.py b/torchtext/experimental/datasets/__init__.py
index bf2cbaa924..81bc90a801 100644
--- a/torchtext/experimental/datasets/__init__.py
+++ b/torchtext/experimental/datasets/__init__.py
@@ -1,3 +1,4 @@
 from . import raw
+from . import sst2
 
-__all__ = ['raw']
+__all__ = ["raw", "sst2"]
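For illustration, the helper acts as an import guard; a minimal usage sketch mirroring the pattern sst2.py uses below (my example, not part of the diff):

    from torchtext._internal.module_utils import is_module_available

    # find_spec() only consults import metadata, so no top-level module code
    # runs here, unlike an actual `import torchdata`.
    if is_module_available("torchdata"):
        from torchdata.datapipes.iter import IterableWrapper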
diff --git a/torchtext/experimental/datasets/sst2.py b/torchtext/experimental/datasets/sst2.py
new file mode 100644
index 0000000000..85b892eb69
--- /dev/null
+++ b/torchtext/experimental/datasets/sst2.py
@@ -0,0 +1,90 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+import os
+
+from torchtext._internal.module_utils import is_module_available
+from torchtext.data.datasets_utils import (
+    _add_docstring_header,
+    _create_dataset_directory,
+    _wrap_split_argument,
+)
+
+logger = logging.getLogger(__name__)
+
+if is_module_available("torchdata"):
+    from torchdata.datapipes.iter import (
+        HttpReader,
+        IterableWrapper,
+    )
+else:
+    logger.warning(
+        "Package `torchdata` is required to use this dataset. "
+        "Please refer to https://github.com/pytorch/data for instructions on "
+        "how to install the package."
+    )
+
+
+NUM_LINES = {
+    "train": 67349,
+    "dev": 872,
+    "test": 1821,
+}
+
+MD5 = "9f81648d4199384278b86e315dac217c"
+URL = "https://dl.fbaipublicfiles.com/glue/data/SST-2.zip"
+
+_EXTRACTED_FILES = {
+    "train": os.path.join("SST-2", "train.tsv"),
+    "dev": os.path.join("SST-2", "dev.tsv"),
+    "test": os.path.join("SST-2", "test.tsv"),
+}
+
+_EXTRACTED_FILES_MD5 = {
+    "train": "da409a0a939379ed32a470bc0f7fe99a",
+    "dev": "268856b487b2a31a28c0a93daaff7288",
+    "test": "3230e4efec76488b87877a56ae49675a",
+}
+
+_FIRST_LINE_MD5 = {
+    "train": "2552b8cecd57b2e022ef23411c688fa8",
+    "dev": "1b0ffd6aa5f2bf0fd9840a5f6f1a9f07",
+    "test": "f838c81fe40bfcd7e42e9ffc4dd004f7",
+}
+
+DATASET_NAME = "SST2"
+
+
+@_add_docstring_header(num_lines=NUM_LINES, num_classes=2)
+@_create_dataset_directory(dataset_name=DATASET_NAME)
+@_wrap_split_argument(("train", "dev", "test"))
+def SST2(root, split):
+    return SST2Dataset(root, split).get_datapipe()
+
+
+class SST2Dataset:
+    """The SST2 dataset uses torchdata datapipes end-to-end.
+    To avoid downloading at every epoch, we cache the data on disk.
+    We do a sanity check on the downloaded and extracted data.
+    """
+
+    def __init__(self, root, split):
+        self.root = root
+        self.split = split
+
+    def get_datapipe(self):
+        # Cache the downloaded archive on disk so it is fetched only once
+        cache_dp = IterableWrapper([URL]).on_disk_cache(
+            HttpReader,
+            op_map=lambda x: (x[0], x[1].read()),
+            filepath_fn=lambda x: os.path.join(self.root, os.path.basename(x)),
+        )
+
+        # Extract the split's TSV file from the zip archive
+        extracted_files = cache_dp.read_from_zip()
+
+        # Parse the TSV file and yield (column 0, column 1) samples
+        return (
+            extracted_files.filter(lambda x: self.split in x[0])
+            .parse_csv(skip_lines=1, delimiter="\t")
+            .map(lambda x: (x[0], x[1]))
+        )
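Since datapipes are lazy, the eager equivalent of get_datapipe may be easier to follow; a minimal sketch using only the standard library (sst2_eager is hypothetical and not part of this change):

    import csv
    import io
    import os
    import urllib.request
    import zipfile

    URL = "https://dl.fbaipublicfiles.com/glue/data/SST-2.zip"  # same constant as sst2.py

    def sst2_eager(root, split):
        # Download once and reuse the on-disk copy, like on_disk_cache does.
        archive = os.path.join(root, os.path.basename(URL))
        if not os.path.exists(archive):
            urllib.request.urlretrieve(URL, archive)
        with zipfile.ZipFile(archive) as zf:
            # filter(...): keep the archive member whose path mentions the
            # requested split, e.g. "SST-2/train.tsv".
            member = next(name for name in zf.namelist() if split in name)
            with zf.open(member) as f:
                rows = csv.reader(io.TextIOWrapper(f, encoding="utf-8"), delimiter="\t")
                next(rows)  # skip_lines=1: drop the TSV header row
                for row in rows:
                    yield (row[0], row[1])  # map(...): keep the first two columns

Calling SST2(split=("train", "dev", "test")), as the new test does, returns one lazy datapipe per split; for the train and dev splits the two columns are the sentence and its label.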