diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index bb873db7d8..91df97cfeb 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -67,6 +67,11 @@ MRPC .. autofunction:: MRPC +QNLI +~~~~ + +.. autofunction:: QNLI + QQP ~~~~ diff --git a/test/datasets/test_qnli.py b/test/datasets/test_qnli.py new file mode 100644 index 0000000000..c74cbfd4d0 --- /dev/null +++ b/test/datasets/test_qnli.py @@ -0,0 +1,81 @@ +import os +import zipfile +from collections import defaultdict +from unittest.mock import patch + +from parameterized import parameterized +from torchtext.datasets.qnli import QNLI + +from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode +from ..common.torchtext_test_case import TorchtextTestCase + + +def _get_mock_dataset(root_dir): + """ + root_dir: directory to the mocked dataset + """ + base_dir = os.path.join(root_dir, "QNLI") + temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir") + os.makedirs(temp_dataset_dir, exist_ok=True) + + seed = 1 + mocked_data = defaultdict(list) + for file_name in ("train.tsv", "dev.tsv", "test.tsv"): + txt_file = os.path.join(temp_dataset_dir, file_name) + with open(txt_file, "w", encoding="utf-8") as f: + f.write("index\tquestion\tsentence\tlabel\n") + for i in range(5): + label = seed % 2 + rand_string_1 = get_random_unicode(seed) + rand_string_2 = get_random_unicode(seed + 1) + dataset_line = (label, rand_string_1, rand_string_2) + label_str = "entailment" if label == 1 else "not_entailment" + f.write(f"{i}\t{rand_string_1}\t{rand_string_2}\t{label_str}\n") + + # append line to correct dataset split + mocked_data[os.path.splitext(file_name)[0]].append(dataset_line) + seed += 1 + + compressed_dataset_path = os.path.join(base_dir, "QNLIv2.zip") + # create zip file from dataset folder + with zipfile.ZipFile(compressed_dataset_path, "w") as zip_file: + for file_name in ("train.tsv", "dev.tsv", "test.tsv"): + txt_file = os.path.join(temp_dataset_dir, file_name) + zip_file.write(txt_file, arcname=os.path.join("QNLI", file_name)) + + return mocked_data + + +class TestQNLI(TempDirMixin, TorchtextTestCase): + root_dir = None + samples = [] + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.root_dir = cls.get_base_temp_dir() + cls.samples = _get_mock_dataset(cls.root_dir) + cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True) + cls.patcher.start() + + @classmethod + def tearDownClass(cls): + cls.patcher.stop() + super().tearDownClass() + + @parameterized.expand(["train", "test", "dev"]) + def test_qnli(self, split): + dataset = QNLI(root=self.root_dir, split=split) + + samples = list(dataset) + expected_samples = self.samples[split] + for sample, expected_sample in zip_equal(samples, expected_samples): + self.assertEqual(sample, expected_sample) + + @parameterized.expand(["train", "test", "dev"]) + def test_qnli_split_argument(self, split): + dataset1 = QNLI(root=self.root_dir, split=split) + (dataset2,) = QNLI(root=self.root_dir, split=(split,)) + + for d1, d2 in zip_equal(dataset1, dataset2): + self.assertEqual(d1, d2) diff --git a/torchtext/datasets/__init__.py b/torchtext/datasets/__init__.py index 5efab8816a..c25eda736b 100644 --- a/torchtext/datasets/__init__.py +++ b/torchtext/datasets/__init__.py @@ -15,6 +15,7 @@ from .mrpc import MRPC from .multi30k import Multi30k from .penntreebank import PennTreebank +from .qnli import QNLI from .qqp import QQP from .sogounews import SogouNews from .squad1 import SQuAD1 @@ -44,6 +45,7 @@ "MRPC": MRPC, "Multi30k": Multi30k, "PennTreebank": PennTreebank, + "QNLI": QNLI, "QQP": QQP, "SQuAD1": SQuAD1, "SQuAD2": SQuAD2, diff --git a/torchtext/datasets/qnli.py b/torchtext/datasets/qnli.py new file mode 100644 index 0000000000..096c746208 --- /dev/null +++ b/torchtext/datasets/qnli.py @@ -0,0 +1,98 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import csv +import os +from functools import partial + +from torchtext._internal.module_utils import is_module_available +from torchtext.data.datasets_utils import ( + _create_dataset_directory, + _wrap_split_argument, +) + +if is_module_available("torchdata"): + from torchdata.datapipes.iter import FileOpener, IterableWrapper + + # we import HttpReader from _download_hooks so we can swap out public URLs + # with interal URLs when the dataset is used within Facebook + from torchtext._download_hooks import HttpReader + + +URL = "https://dl.fbaipublicfiles.com/glue/data/QNLIv2.zip" + +MD5 = "b4efd6554440de1712e9b54e14760e82" + +NUM_LINES = { + "train": 104743, + "dev": 5463, + "test": 5463, +} + +_PATH = "QNLIv2.zip" + +DATASET_NAME = "QNLI" + +_EXTRACTED_FILES = { + "train": os.path.join("QNLI", "train.tsv"), + "dev": os.path.join("QNLI", "dev.tsv"), + "test": os.path.join("QNLI", "test.tsv"), +} + + +def _filepath_fn(root, x=None): + return os.path.join(root, os.path.basename(x)) + + +def _extracted_filepath_fn(root, split, _=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + + +def _filter_fn(split, x): + return _EXTRACTED_FILES[split] in x[0] + + +def _modify_res(x): + return (int(x[3] == "entailment"), x[1], x[2]) + + +@_create_dataset_directory(dataset_name=DATASET_NAME) +@_wrap_split_argument(("train", "dev", "test")) +def QNLI(root, split): + """QNLI Dataset + + For additional details refer to https://arxiv.org/pdf/1804.07461.pdf (from GLUE paper) + + Number of lines per split: + - train: 104743 + - dev: 5463 + - test: 5463 + + Args: + root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache') + split: split or splits to be returned. Can be a string or tuple of strings. Default: (`train`, `dev`, `test`) + + :returns: DataPipe that yields tuple of text and label (0 and 1). + :rtype: (int, str, str) + """ + # TODO Remove this after removing conditional dependency + if not is_module_available("torchdata"): + raise ModuleNotFoundError( + "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" + ) + + url_dp = IterableWrapper([URL]) + cache_compressed_dp = url_dp.on_disk_cache( + filepath_fn=partial(_filepath_fn, root), + hash_dict={_filepath_fn(root, URL): MD5}, + hash_type="md5", + ) + cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) + + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split)) + cache_decompressed_dp = ( + FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(partial(_filter_fn, split)) + ) + cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) + + data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") + parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t", quoting=csv.QUOTE_NONE).map(_modify_res) + return parsed_data.shuffle().set_shuffle(False).sharding_filter()