From 6ea51f3a309b0208b66ed73530a8242e0aff4e75 Mon Sep 17 00:00:00 2001
From: Virgile Mison
Date: Mon, 9 May 2022 15:42:54 -0400
Subject: [PATCH 1/6] Add support for MNLI + add tests

---
 test/datasets/test_mnli.py     | 85 ++++++++++++++++++++++++++++++++
 torchtext/datasets/__init__.py |  2 +
 torchtext/datasets/mnli.py     | 88 ++++++++++++++++++++++++++++++++++
 3 files changed, 175 insertions(+)
 create mode 100644 test/datasets/test_mnli.py
 create mode 100644 torchtext/datasets/mnli.py

diff --git a/test/datasets/test_mnli.py b/test/datasets/test_mnli.py
new file mode 100644
index 0000000000..57dfe28086
--- /dev/null
+++ b/test/datasets/test_mnli.py
@@ -0,0 +1,85 @@
+import os
+import zipfile
+from collections import defaultdict
+from unittest.mock import patch
+
+from parameterized import parameterized
+from torchtext.datasets.mnli import MNLI
+
+from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode
+from ..common.torchtext_test_case import TorchtextTestCase
+
+
+def _get_mock_dataset(root_dir):
+    """
+    root_dir: directory to the mocked dataset
+    """
+    base_dir = os.path.join(root_dir, "MNLI")
+    temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir")
+    os.makedirs(temp_dataset_dir, exist_ok=True)
+
+    seed = 1
+    mocked_data = defaultdict(list)
+    for file_name in [
+        "multinli_1.0_train.txt", "multinli_1.0_dev_matched.txt",
+        "multinli_1.0_dev_mismatched.txt"
+    ]:
+        txt_file = os.path.join(temp_dataset_dir, file_name)
+        with open(txt_file, "w", encoding="utf-8") as f:
+            f.write(
+                "gold_label\tsentence1_binary_parse\tsentence2_binary_parse\tsentence1_parse\tsentence2_parse\tsentence1\tsentence2\tpromptID\tpairID\tgenre\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n"
+            )
+            for i in range(5):
+                # write a real gold-label string so the row survives the
+                # dataset's gold-label filter; expect the mapped int back
+                label = ("entailment", "neutral", "contradiction")[seed % 3]
+                rand_string = get_random_unicode(seed)
+                dataset_line = (seed % 3, rand_string, rand_string)
+                f.write(f"{label}\t{rand_string}\t{rand_string}\t{rand_string}\t{rand_string}\t{rand_string}\t{rand_string}\t{i}\t{i}\t{i}\t{i}\t{i}\t{i}\t{i}\t{i}\n")
+
+                # append line to correct dataset split
+                mocked_data[os.path.splitext(file_name)[0].replace("multinli_1.0_", "")].append(dataset_line)
+                seed += 1
+
+    compressed_dataset_path = os.path.join(base_dir, "multinli_1.0.zip")
+    # create zip file from dataset folder
+    with zipfile.ZipFile(compressed_dataset_path, "w") as zip_file:
+        for file_name in ("multinli_1.0_train.txt", "multinli_1.0_dev_matched.txt",
+                          "multinli_1.0_dev_mismatched.txt"):
+            txt_file = os.path.join(temp_dataset_dir, file_name)
+            zip_file.write(txt_file, arcname=os.path.join("multinli_1.0", file_name))
+
+    return mocked_data
+
+
+class TestMNLI(TempDirMixin, TorchtextTestCase):
+    root_dir = None
+    samples = []
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        cls.root_dir = cls.get_base_temp_dir()
+        cls.samples = _get_mock_dataset(cls.root_dir)
+        cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True)
+        cls.patcher.start()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.patcher.stop()
+        super().tearDownClass()
+
+    @parameterized.expand(["train", "dev_matched", "dev_mismatched"])
+    def test_mnli(self, split):
+        dataset = MNLI(root=self.root_dir, split=split)
+
+        samples = list(dataset)
+        expected_samples = self.samples[split]
+        for sample, expected_sample in zip_equal(samples, expected_samples):
+            self.assertEqual(sample, expected_sample)
+
+    @parameterized.expand(["train", "dev_matched", "dev_mismatched"])
+    def test_mnli_split_argument(self, split):
+        dataset1 = MNLI(root=self.root_dir, split=split)
+        (dataset2,) = MNLI(root=self.root_dir, split=(split,))
+
+        for d1, d2 in zip_equal(dataset1, dataset2):
+            self.assertEqual(d1, d2)
diff --git a/torchtext/datasets/__init__.py b/torchtext/datasets/__init__.py
index d7d33298ad..f957a46a7a 100644
--- a/torchtext/datasets/__init__.py
+++ b/torchtext/datasets/__init__.py
@@ -10,6 +10,7 @@
 from .imdb import IMDB
 from .iwslt2016 import IWSLT2016
 from .iwslt2017 import IWSLT2017
+from .mnli import MNLI
 from .multi30k import Multi30k
 from .penntreebank import PennTreebank
 from .sogounews import SogouNews
@@ -34,6 +35,7 @@
     "IMDB": IMDB,
     "IWSLT2016": IWSLT2016,
     "IWSLT2017": IWSLT2017,
+    "MNLI": MNLI,
     "Multi30k": Multi30k,
     "PennTreebank": PennTreebank,
     "SQuAD1": SQuAD1,
diff --git a/torchtext/datasets/mnli.py b/torchtext/datasets/mnli.py
new file mode 100644
index 0000000000..255856e46d
--- /dev/null
+++ b/torchtext/datasets/mnli.py
@@ -0,0 +1,88 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import os
+import csv
+
+from torchtext._internal.module_utils import is_module_available
+from torchtext.data.datasets_utils import (
+    _create_dataset_directory,
+    _wrap_split_argument,
+)
+
+if is_module_available("torchdata"):
+    from torchdata.datapipes.iter import FileOpener, IterableWrapper
+
+    # we import HttpReader from _download_hooks so we can swap out public URLs
+    # with internal URLs when the dataset is used within Facebook
+    from torchtext._download_hooks import HttpReader
+
+
+URL = "https://cims.nyu.edu/~sbowman/multinli/multinli_1.0.zip"
+
+MD5 = "0f70aaf66293b3c088a864891db51353"
+
+NUM_LINES = {
+    "train": 392702,
+    "dev": 9714,
+    "dev_mismatched": 9832,
+}
+
+_PATH = "multinli_1.0.zip"
+
+DATASET_NAME = "MNLI"
+
+_EXTRACTED_FILES = {
+    "train": "multinli_1.0_train.txt",
+    "dev_matched": "multinli_1.0_dev_matched.txt",
+    "dev_mismatched": "multinli_1.0_dev_mismatched.txt",
+}
+
+LABEL_TO_INT = {
+    "entailment": 0,
+    "neutral": 1,
+    "contradiction": 2
+}
+
+@_create_dataset_directory(dataset_name=DATASET_NAME)
+@_wrap_split_argument(("train", "dev_matched", "dev_mismatched"))
+def MNLI(root, split):
+    """MNLI Dataset
+
+    For additional details refer to https://cims.nyu.edu/~sbowman/multinli/
+
+    Number of lines per split:
+        - train: 392702
+        - dev_matched: 9714
+        - dev_mismatched: 9832
+
+    Args:
+        root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache')
+        split: split or splits to be returned. Can be a string or tuple of strings. Default: (`train`, `dev_matched`, `dev_mismatched`)
+
+    :returns: DataPipe that yields tuple of text and/or label (0 to 2).
+    :rtype: Tuple[int, str, str]
+    """
+    # TODO Remove this after removing conditional dependency
+    if not is_module_available("torchdata"):
+        raise ModuleNotFoundError(
+            "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
+        )
+
+    url_dp = IterableWrapper([URL])
+    cache_compressed_dp = url_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, os.path.basename(x)),
+        hash_dict={os.path.join(root, os.path.basename(URL)): MD5},
+        hash_type="md5",
+    )
+    cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
+
+    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split])
+    )
+    cache_decompressed_dp = (
+        FileOpener(cache_decompressed_dp, mode="b").read_from_zip().filter(lambda x: _EXTRACTED_FILES[split] in x[0])
+    )
+    cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
+
+    data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
+    parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t", quoting=csv.QUOTE_NONE).filter(lambda x: x[0] in LABEL_TO_INT).map(lambda x: (LABEL_TO_INT[x[0]], x[5], x[6]))
+    return parsed_data
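With patch 1 applied, MNLI behaves like the other raw torchtext datasets: each requested split is a datapipe that yields (label, premise, hypothesis) tuples, where the gold label string in column 0 of the TSV is mapped to an integer through LABEL_TO_INT and the sentence pair comes from columns 5 and 6. A minimal usage sketch (assuming torchdata is installed; iterating a split downloads and caches the archive on first use):

    from torchtext.datasets import MNLI

    # a single split name returns a single datapipe
    train_dp = MNLI(split="train")
    for label, premise, hypothesis in train_dp:
        print(label, premise, hypothesis)  # label is an int in {0, 1, 2}
        break

    # a tuple of split names returns one datapipe per split, which is
    # exactly what test_mnli_split_argument exercises above
    (dev_matched_dp,) = MNLI(split=("dev_matched",))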
From 662b42cf589bbe0cdf7ebc25d5de51171f7be986 Mon Sep 17 00:00:00 2001
From: Virgile Mison
Date: Wed, 11 May 2022 13:54:02 -0400
Subject: [PATCH 2/6] Adjust dataset size docstring

---
 torchtext/datasets/mnli.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchtext/datasets/mnli.py b/torchtext/datasets/mnli.py
index 255856e46d..dfa6be3a36 100644
--- a/torchtext/datasets/mnli.py
+++ b/torchtext/datasets/mnli.py
@@ -22,7 +22,7 @@
 
 NUM_LINES = {
     "train": 392702,
-    "dev": 9714,
+    "dev_matched": 9815,
     "dev_mismatched": 9832,
 }
 
@@ -51,7 +51,7 @@ def MNLI(root, split):
 
     Number of lines per split:
         - train: 392702
-        - dev_matched: 9714
+        - dev_matched: 9815
        - dev_mismatched: 9832
 
     Args:

From e1e9b9bd294c1c1293ad13257ed75fd94e305b07 Mon Sep 17 00:00:00 2001
From: Virgile Mison
Date: Thu, 12 May 2022 13:09:28 -0400
Subject: [PATCH 3/6] Remove lambda functions

---
 torchtext/datasets/mnli.py | 27 +++++++++++++++++++++------
 1 file changed, 21 insertions(+), 6 deletions(-)

diff --git a/torchtext/datasets/mnli.py b/torchtext/datasets/mnli.py
index dfa6be3a36..5f6bc4e715 100644
--- a/torchtext/datasets/mnli.py
+++ b/torchtext/datasets/mnli.py
@@ -58,7 +58,7 @@ def MNLI(root, split):
         root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache')
         split: split or splits to be returned. Can be a string or tuple of strings. Default: (`train`, `dev_matched`, `dev_mismatched`)
 
-    :returns: DataPipe that yields tuple of text and/or label (0 to 2).
+    :returns: DataPipe that yields tuple of text and label (0 to 2).
     :rtype: Tuple[int, str, str]
     """
     # TODO Remove this after removing conditional dependency
@@ -67,22 +67,37 @@ def MNLI(root, split):
             "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
         )
 
+    def _filepath_fn(x=None):
+        return os.path.join(root, os.path.basename(x))
+
+    def _extracted_filepath_fn(_=None):
+        return os.path.join(root, _EXTRACTED_FILES[split])
+
+    def _filter_fn(x):
+        return _EXTRACTED_FILES[split] in x[0]
+
+    def _filter_res(x):
+        return x[0] in LABEL_TO_INT
+
+    def _modify_res(x):
+        return (LABEL_TO_INT[x[0]], x[5], x[6])
+
     url_dp = IterableWrapper([URL])
     cache_compressed_dp = url_dp.on_disk_cache(
-        filepath_fn=lambda x: os.path.join(root, os.path.basename(x)),
-        hash_dict={os.path.join(root, os.path.basename(URL)): MD5},
+        filepath_fn=_filepath_fn,
+        hash_dict={_filepath_fn(URL): MD5},
         hash_type="md5",
     )
     cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
 
     cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
-        filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split])
+        filepath_fn=_extracted_filepath_fn
     )
     cache_decompressed_dp = (
-        FileOpener(cache_decompressed_dp, mode="b").read_from_zip().filter(lambda x: _EXTRACTED_FILES[split] in x[0])
+        FileOpener(cache_decompressed_dp, mode="b").read_from_zip().filter(_filter_fn)
     )
     cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
 
     data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
-    parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t", quoting=csv.QUOTE_NONE).filter(lambda x: x[0] in LABEL_TO_INT).map(lambda x: (LABEL_TO_INT[x[0]], x[5], x[6]))
+    parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t", quoting=csv.QUOTE_NONE).filter(_filter_res).map(_modify_res)
     return parsed_data
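The motivation for patch 3 is serialization: a DataPipe graph may need to be pickled, for example when a DataLoader forks worker processes, and torchdata warns when a lambda is attached via .map() or .filter() because plain pickle cannot serialize lambdas. A toy illustration of that constraint, separate from the patch itself (hypothetical function names):

    import pickle

    def double(x):  # named, module-level function
        return 2 * x

    pickle.dumps(double)  # fine: functions pickle by qualified name

    try:
        pickle.dumps(lambda x: 2 * x)
    except Exception as err:  # pickle.PicklingError on CPython
        print(err)  # a lambda has no importable name, so pickling fails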
From 155e46cb67f4f13c2d79a61cc4af14f91936aedd Mon Sep 17 00:00:00 2001
From: Virgile Mison
Date: Tue, 17 May 2022 11:57:07 -0400
Subject: [PATCH 4/6] Add dataset documentation

---
 docs/source/datasets.rst | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst
index 33eb44b21d..2b252ce4cb 100644
--- a/docs/source/datasets.rst
+++ b/docs/source/datasets.rst
@@ -52,6 +52,11 @@ IMDb
 
 .. autofunction:: IMDB
 
+MNLI
+~~~~
+
+.. autofunction:: MNLI
+
 SogouNews
 ~~~~~~~~~
 

From a93279ddcc7cb3233fae7f48705c88e9a17ddb27 Mon Sep 17 00:00:00 2001
From: Virgile Mison
Date: Wed, 18 May 2022 15:49:33 -0400
Subject: [PATCH 5/6] Add shuffle and sharding

---
 torchtext/datasets/mnli.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchtext/datasets/mnli.py b/torchtext/datasets/mnli.py
index 5f6bc4e715..e0b2749526 100644
--- a/torchtext/datasets/mnli.py
+++ b/torchtext/datasets/mnli.py
@@ -100,4 +100,4 @@ def _modify_res(x):
 
     data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
     parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t", quoting=csv.QUOTE_NONE).filter(_filter_res).map(_modify_res)
-    return parsed_data
+    return parsed_data.shuffle().set_shuffle(False).sharding_filter()
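The one-line change in patch 5 packs in two conventions of the torchtext datapipe datasets. shuffle().set_shuffle(False) installs a Shuffler in the graph but leaves it switched off, so iteration stays deterministic unless the consumer opts in, and sharding_filter() marks the point where samples are partitioned so that DataLoader workers do not yield duplicates. A sketch of the intended interaction (assuming a torchdata-aware torch release from this era):

    from torch.utils.data import DataLoader
    from torchtext.datasets import MNLI

    train_dp = MNLI(split="train")

    # shuffle=True switches the pre-installed, default-off Shuffler on;
    # the sharding filter then gives each worker a disjoint slice
    loader = DataLoader(train_dp, batch_size=8, shuffle=True, num_workers=2)
    for labels, premises, hypotheses in loader:
        break  # labels is a tensor; premises/hypotheses are lists of strings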
From 23e0281e3300685fb5e025e04393abf89d259065 Mon Sep 17 00:00:00 2001
From: Virgile Mison
Date: Wed, 18 May 2022 16:04:43 -0400
Subject: [PATCH 6/6] Lint

---
 test/datasets/test_mnli.py | 12 +++++-------
 torchtext/datasets/mnli.py | 21 ++++++++------------
 2 files changed, 13 insertions(+), 20 deletions(-)

diff --git a/test/datasets/test_mnli.py b/test/datasets/test_mnli.py
index 57dfe28086..fed6fca241 100644
--- a/test/datasets/test_mnli.py
+++ b/test/datasets/test_mnli.py
@@ -20,10 +20,7 @@ def _get_mock_dataset(root_dir):
 
     seed = 1
     mocked_data = defaultdict(list)
-    for file_name in [
-        "multinli_1.0_train.txt", "multinli_1.0_dev_matched.txt",
-        "multinli_1.0_dev_mismatched.txt"
-    ]:
+    for file_name in ["multinli_1.0_train.txt", "multinli_1.0_dev_matched.txt", "multinli_1.0_dev_mismatched.txt"]:
         txt_file = os.path.join(temp_dataset_dir, file_name)
         with open(txt_file, "w", encoding="utf-8") as f:
             f.write(
@@ -33,7 +30,9 @@ def _get_mock_dataset(root_dir):
                 label = ("entailment", "neutral", "contradiction")[seed % 3]
                 rand_string = get_random_unicode(seed)
                 dataset_line = (seed % 3, rand_string, rand_string)
-                f.write(f"{label}\t{rand_string}\t{rand_string}\t{rand_string}\t{rand_string}\t{rand_string}\t{rand_string}\t{i}\t{i}\t{i}\t{i}\t{i}\t{i}\t{i}\t{i}\n")
+                f.write(
+                    f"{label}\t{rand_string}\t{rand_string}\t{rand_string}\t{rand_string}\t{rand_string}\t{rand_string}\t{i}\t{i}\t{i}\t{i}\t{i}\t{i}\t{i}\t{i}\n"
+                )
 
                 # append line to correct dataset split
                 mocked_data[os.path.splitext(file_name)[0].replace("multinli_1.0_", "")].append(dataset_line)
@@ -42,8 +41,7 @@ def _get_mock_dataset(root_dir):
 
     compressed_dataset_path = os.path.join(base_dir, "multinli_1.0.zip")
     # create zip file from dataset folder
     with zipfile.ZipFile(compressed_dataset_path, "w") as zip_file:
-        for file_name in ("multinli_1.0_train.txt", "multinli_1.0_dev_matched.txt",
-                          "multinli_1.0_dev_mismatched.txt"):
+        for file_name in ("multinli_1.0_train.txt", "multinli_1.0_dev_matched.txt", "multinli_1.0_dev_mismatched.txt"):
             txt_file = os.path.join(temp_dataset_dir, file_name)
             zip_file.write(txt_file, arcname=os.path.join("multinli_1.0", file_name))
 
diff --git a/torchtext/datasets/mnli.py b/torchtext/datasets/mnli.py
index e0b2749526..43bcfe7d9f 100644
--- a/torchtext/datasets/mnli.py
+++ b/torchtext/datasets/mnli.py
@@ -1,6 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
-import os
 import csv
+import os
 
 from torchtext._internal.module_utils import is_module_available
 from torchtext.data.datasets_utils import (
@@ -36,11 +36,8 @@
     "dev_mismatched": "multinli_1.0_dev_mismatched.txt",
 }
 
-LABEL_TO_INT = {
-    "entailment": 0,
-    "neutral": 1,
-    "contradiction": 2
-}
+LABEL_TO_INT = {"entailment": 0, "neutral": 1, "contradiction": 2}
+
 
 @_create_dataset_directory(dataset_name=DATASET_NAME)
 @_wrap_split_argument(("train", "dev_matched", "dev_mismatched"))
@@ -90,14 +87,12 @@ def _modify_res(x):
     )
     cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
 
-    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
-        filepath_fn=_extracted_filepath_fn
-    )
-    cache_decompressed_dp = (
-        FileOpener(cache_decompressed_dp, mode="b").read_from_zip().filter(_filter_fn)
-    )
+    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
+    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_zip().filter(_filter_fn)
     cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
 
     data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
-    parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t", quoting=csv.QUOTE_NONE).filter(_filter_res).map(_modify_res)
+    parsed_data = (
+        data_dp.parse_csv(skip_lines=1, delimiter="\t", quoting=csv.QUOTE_NONE).filter(_filter_res).map(_modify_res)
+    )
     return parsed_data.shuffle().set_shuffle(False).sharding_filter()
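With the whole series applied, a quick end-to-end sanity check is to count what each split yields. The totals should match the NUM_LINES values documented above, because _filter_res drops the dev pairs whose gold label is "-" (no annotator consensus), and the published counts exclude those rows. A sketch, assuming network access (this downloads and caches the full archive under the default root on first run):

    from torchtext.datasets import MNLI

    for split in ("train", "dev_matched", "dev_mismatched"):
        count = sum(1 for _ in MNLI(split=split))
        print(split, count)  # expected: 392702, 9815, 9832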