From 672227e3ad6031626f9ee4f7cb74e5e580083f1d Mon Sep 17 00:00:00 2001
From: Virgile Mison
Date: Thu, 5 May 2022 11:48:15 -0400
Subject: [PATCH 1/8] Add support for STS-B dataset + unit test

---
 test/datasets/test_stsb.py     | 75 +++++++++++++++++++++++++++++++
 torchtext/datasets/__init__.py |  2 +
 torchtext/datasets/stsb.py     | 81 ++++++++++++++++++++++++++++++++++
 3 files changed, 158 insertions(+)
 create mode 100644 test/datasets/test_stsb.py
 create mode 100644 torchtext/datasets/stsb.py

diff --git a/test/datasets/test_stsb.py b/test/datasets/test_stsb.py
new file mode 100644
index 0000000000..3c8c0a51af
--- /dev/null
+++ b/test/datasets/test_stsb.py
@@ -0,0 +1,75 @@
+import os
+import tarfile
+from collections import defaultdict
+from unittest.mock import patch
+
+from parameterized import parameterized
+from torchtext.datasets.stsb import STSB
+
+from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode
+from ..common.torchtext_test_case import TorchtextTestCase
+
+
+def _get_mock_dataset(root_dir):
+    """
+    root_dir: directory to the mocked dataset
+    """
+    base_dir = os.path.join(root_dir, "STSB")
+    temp_dataset_dir = os.path.join(base_dir, "stsbenchmark")
+    os.makedirs(temp_dataset_dir, exist_ok=True)
+
+    seed = 1
+    mocked_data = defaultdict(list)
+    for file_name, name in zip(["sts-train.csv", "sts-dev.csv", "sts-test.csv"], ["train", "dev", "test"]):
+        txt_file = os.path.join(temp_dataset_dir, file_name)
+        with open(txt_file, "w", encoding="utf-8") as f:
+            for i in range(5):
+                label = seed % 5
+                rand_string = get_random_unicode(seed)
+                dataset_line = (label, label, rand_string, rand_string)
+                # append line to correct dataset split
+                mocked_data[name].append(dataset_line)
+                f.write(f'{rand_string}\t{rand_string}\t{rand_string}\t{label}\t{label}\t{rand_string}\t{rand_string}\n')
+                seed += 1
+
+    compressed_dataset_path = os.path.join(base_dir, "Stsbenchmark.tar.gz")
+    # create tar file from dataset folder
+    with tarfile.open(compressed_dataset_path, 
"w:gz") as tar: + tar.add(temp_dataset_dir, arcname="stsbenchmark") + + return mocked_data + + +class TestSTSB(TempDirMixin, TorchtextTestCase): + root_dir = None + samples = [] + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.root_dir = cls.get_base_temp_dir() + cls.samples = _get_mock_dataset(cls.root_dir) + cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True) + cls.patcher.start() + + @classmethod + def tearDownClass(cls): + cls.patcher.stop() + super().tearDownClass() + + @parameterized.expand(["train", "dev", "test"]) + def test_stsb(self, split): + dataset = STSB(root=self.root_dir, split=split) + + samples = list(dataset) + expected_samples = self.samples[split] + for sample, expected_sample in zip_equal(samples, expected_samples): + self.assertEqual(sample, expected_sample) + + @parameterized.expand(["train", "dev", "test"]) + def test_stsb_split_argument(self, split): + dataset1 = STSB(root=self.root_dir, split=split) + (dataset2,) = STSB(root=self.root_dir, split=(split,)) + + for d1, d2 in zip_equal(dataset1, dataset2): + self.assertEqual(d1, d2) diff --git a/torchtext/datasets/__init__.py b/torchtext/datasets/__init__.py index d7d33298ad..81b5f6b2b7 100644 --- a/torchtext/datasets/__init__.py +++ b/torchtext/datasets/__init__.py @@ -16,6 +16,7 @@ from .squad1 import SQuAD1 from .squad2 import SQuAD2 from .sst2 import SST2 +from .stsb import STSB from .udpos import UDPOS from .wikitext103 import WikiText103 from .wikitext2 import WikiText2 @@ -40,6 +41,7 @@ "SQuAD2": SQuAD2, "SogouNews": SogouNews, "SST2": SST2, + "STSB": STSB, "UDPOS": UDPOS, "WikiText103": WikiText103, "WikiText2": WikiText2, diff --git a/torchtext/datasets/stsb.py b/torchtext/datasets/stsb.py new file mode 100644 index 0000000000..c20284e1b9 --- /dev/null +++ b/torchtext/datasets/stsb.py @@ -0,0 +1,81 @@ +import os + +from torchtext._internal.module_utils import is_module_available +from torchtext.data.datasets_utils 
import (
+    _create_dataset_directory,
+    _wrap_split_argument,
+)
+
+if is_module_available("torchdata"):
+    from torchdata.datapipes.iter import FileOpener, IterableWrapper
+
+    # we import HttpReader from _download_hooks so we can swap out public URLs
+    # with internal URLs when the dataset is used within Facebook
+    from torchtext._download_hooks import HttpReader
+
+
+URL = "http://ixa2.si.ehu.es/stswiki/images/4/48/Stsbenchmark.tar.gz"
+
+MD5 = "4eb0065aba063ef77873d3a9c8088811"
+
+NUM_LINES = {
+    "train": 5749,
+    "dev": 1500,
+    "test": 1379,
+}
+
+_PATH = "Stsbenchmark.tar.gz"
+
+DATASET_NAME = "STSB"
+
+_EXTRACTED_FILES = {
+    "train": os.path.join("stsbenchmark", "sts-train.csv"),
+    "dev": os.path.join("stsbenchmark", "sts-dev.csv"),
+    "test": os.path.join("stsbenchmark", "sts-test.csv"),
+}
+
+
+@_create_dataset_directory(dataset_name=DATASET_NAME)
+@_wrap_split_argument(("train", "dev", "test"))
+def STSB(root, split):
+    """STSB Dataset
+
+    For additional details refer to https://ixa2.si.ehu.eus/stswiki/index.php/STSbenchmark
+
+    Number of lines per split:
+        - train: 5749
+        - dev: 1500
+        - test: 1379
+
+    Args:
+        root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache')
+        split: split or splits to be returned. Can be a string or tuple of strings. Default: (`train`, `dev`, `test`)
+
+    :returns: DataPipe that yields tuple of (index (int), label (float), sentence1 (str), sentence2 (str))
+    :rtype: Union[(int, str), (str,)]
+    """
+    # TODO Remove this after removing conditional dependency
+    if not is_module_available("torchdata"):
+        raise ModuleNotFoundError(
+            "Package `torchdata` not found. 
Please install following instructions at `https://github.com/pytorch/data`" + ) + + url_dp = IterableWrapper([URL]) + cache_compressed_dp = url_dp.on_disk_cache( + filepath_fn=lambda x: os.path.join(root, os.path.basename(x)), + hash_dict={os.path.join(root, os.path.basename(URL)): MD5}, + hash_type="md5", + ) + cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) + + cache_decompressed_dp = cache_compressed_dp.on_disk_cache( + filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) + ) + cache_decompressed_dp = ( + FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(lambda x: _EXTRACTED_FILES[split] in x[0]) + ) + cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) + + data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") + parsed_data = data_dp.parse_csv(delimiter='\t').filter(lambda x: len(x) >= 7).map(lambda x: (int(x[3]), float(x[4]), x[5], x[6])) + return parsed_data From e81d8ec75f737632fbc9c6ea90741b72adafe463 Mon Sep 17 00:00:00 2001 From: Virgile Mison Date: Mon, 9 May 2022 15:50:48 -0400 Subject: [PATCH 2/8] Fix quote issue --- torchtext/datasets/stsb.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchtext/datasets/stsb.py b/torchtext/datasets/stsb.py index c20284e1b9..45cabf2ac1 100644 --- a/torchtext/datasets/stsb.py +++ b/torchtext/datasets/stsb.py @@ -1,4 +1,5 @@ import os +import csv from torchtext._internal.module_utils import is_module_available from torchtext.data.datasets_utils import ( @@ -77,5 +78,5 @@ def STSB(root, split): cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") - parsed_data = data_dp.parse_csv(delimiter='\t').filter(lambda x: len(x) >= 7).map(lambda x: (int(x[3]), float(x[4]), x[5], x[6])) + parsed_data = data_dp.parse_csv(delimiter='\t', quoting=csv.QUOTE_NONE).map(lambda x: 
(int(x[3]), float(x[4]), x[5], x[6])) return parsed_data From d5afa01e25aaf834ac236a9f15fb687dda36425f Mon Sep 17 00:00:00 2001 From: Virgile Mison Date: Wed, 11 May 2022 12:30:35 -0400 Subject: [PATCH 3/8] Modify tests + docstring --- test/datasets/test_stsb.py | 10 +++++++--- torchtext/datasets/stsb.py | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/test/datasets/test_stsb.py b/test/datasets/test_stsb.py index 3c8c0a51af..57468ce0e1 100644 --- a/test/datasets/test_stsb.py +++ b/test/datasets/test_stsb.py @@ -25,11 +25,15 @@ def _get_mock_dataset(root_dir): with open(txt_file, "w", encoding="utf-8") as f: for i in range(5): label = seed % 5 - rand_string = get_random_unicode(seed) - dataset_line = (label, label, rand_string, rand_string) + rand_string_1 = get_random_unicode(seed) + rand_string_2 = get_random_unicode(seed+1) + rand_string_3 = get_random_unicode(seed+2) + rand_string_4 = get_random_unicode(seed+3) + rand_string_5 = get_random_unicode(seed+4) + dataset_line = (i, label, rand_string_4, rand_string_5) # append line to correct dataset split mocked_data[name].append(dataset_line) - f.write(f'{rand_string}\t{rand_string}\t{rand_string}\t{label}\t{label}\t{rand_string}\t{rand_string}\n') + f.write(f'{rand_string_1}\t{rand_string_2}\t{rand_string_3}\t{i}\t{label}\t{rand_string_4}\t{rand_string_5}\n') seed += 1 compressed_dataset_path = os.path.join(base_dir, "Stsbenchmark.tar.gz") diff --git a/torchtext/datasets/stsb.py b/torchtext/datasets/stsb.py index 45cabf2ac1..50ac36e1b8 100644 --- a/torchtext/datasets/stsb.py +++ b/torchtext/datasets/stsb.py @@ -53,7 +53,7 @@ def STSB(root, split): split: split or splits to be returned. Can be a string or tuple of strings. 
Default: (`train`, `dev`, `test`) :returns: DataPipe that yields tuple of (index (int), label (float), sentence1 (str), sentence2 (str)) - :rtype: Union[(int, str), (str,)] + :rtype: (int, float, str, str) """ # TODO Remove this after removing conditional dependency if not is_module_available("torchdata"): From 9cb201e13151df56a86f78ae25839bf60d41ce67 Mon Sep 17 00:00:00 2001 From: Virgile Mison Date: Thu, 12 May 2022 11:10:47 -0400 Subject: [PATCH 4/8] Remove lambda functions --- torchtext/datasets/stsb.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/torchtext/datasets/stsb.py b/torchtext/datasets/stsb.py index 50ac36e1b8..72d6c4ba44 100644 --- a/torchtext/datasets/stsb.py +++ b/torchtext/datasets/stsb.py @@ -61,22 +61,34 @@ def STSB(root, split): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(x=_PATH): + return os.path.join(root, os.path.basename(x)) + + def _extracted_filepath_fn(_=None): + return _filepath_fn(_EXTRACTED_FILES[split]) + + def _filter_fn(x): + return _EXTRACTED_FILES[split] in x[0] + + def _modify_res(x): + return (int(x[3]), float(x[4]), x[5], x[6]) + url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.basename(x)), - hash_dict={os.path.join(root, os.path.basename(URL)): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(URL): MD5}, hash_type="md5", ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) + filepath_fn=_extracted_filepath_fn ) cache_decompressed_dp = ( - FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(lambda x: _EXTRACTED_FILES[split] in x[0]) + FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(_filter_fn) ) 
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") - parsed_data = data_dp.parse_csv(delimiter='\t', quoting=csv.QUOTE_NONE).map(lambda x: (int(x[3]), float(x[4]), x[5], x[6])) + parsed_data = data_dp.parse_csv(delimiter='\t', quoting=csv.QUOTE_NONE).map(_modify_res) return parsed_data From 271a044031b9ad304fda609a46302d70bd698f5e Mon Sep 17 00:00:00 2001 From: Virgile Mison Date: Mon, 16 May 2022 15:38:58 -0400 Subject: [PATCH 5/8] Lint, adjust test float & quote issues in parsing --- test/datasets/test_stsb.py | 23 +++++++++++++++++------ torchtext/datasets/stsb.py | 12 ++++-------- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/test/datasets/test_stsb.py b/test/datasets/test_stsb.py index 57468ce0e1..1795f6efb1 100644 --- a/test/datasets/test_stsb.py +++ b/test/datasets/test_stsb.py @@ -1,4 +1,5 @@ import os +import random import tarfile from collections import defaultdict from unittest.mock import patch @@ -24,17 +25,27 @@ def _get_mock_dataset(root_dir): txt_file = os.path.join(temp_dataset_dir, file_name) with open(txt_file, "w", encoding="utf-8") as f: for i in range(5): - label = seed % 5 + label = random.uniform(0, 5) rand_string_1 = get_random_unicode(seed) - rand_string_2 = get_random_unicode(seed+1) - rand_string_3 = get_random_unicode(seed+2) - rand_string_4 = get_random_unicode(seed+3) - rand_string_5 = get_random_unicode(seed+4) + rand_string_2 = get_random_unicode(seed + 1) + rand_string_3 = get_random_unicode(seed + 2) + rand_string_4 = get_random_unicode(seed + 3) + rand_string_5 = get_random_unicode(seed + 4) dataset_line = (i, label, rand_string_4, rand_string_5) # append line to correct dataset split mocked_data[name].append(dataset_line) - f.write(f'{rand_string_1}\t{rand_string_2}\t{rand_string_3}\t{i}\t{label}\t{rand_string_4}\t{rand_string_5}\n') + f.write( + 
f"{rand_string_1}\t{rand_string_2}\t{rand_string_3}\t{i}\t{label}\t{rand_string_4}\t{rand_string_5}\n" + ) seed += 1 + # case with quotes to test arg `quoting=csv.QUOTE_NONE` + dataset_line = (i, label, rand_string_4, rand_string_5) + # append line to correct dataset split + mocked_data[name].append(dataset_line) + f.write( + f'{rand_string_1}"\t"{rand_string_2}\t{rand_string_3}\t{i}\t{label}\t{rand_string_4}\t{rand_string_5}\n' + ) + compressed_dataset_path = os.path.join(base_dir, "Stsbenchmark.tar.gz") # create tar file from dataset folder diff --git a/torchtext/datasets/stsb.py b/torchtext/datasets/stsb.py index 72d6c4ba44..679436901e 100644 --- a/torchtext/datasets/stsb.py +++ b/torchtext/datasets/stsb.py @@ -1,5 +1,5 @@ -import os import csv +import os from torchtext._internal.module_utils import is_module_available from torchtext.data.datasets_utils import ( @@ -81,14 +81,10 @@ def _modify_res(x): ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=_extracted_filepath_fn - ) - cache_decompressed_dp = ( - FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(_filter_fn) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(_filter_fn) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") - parsed_data = data_dp.parse_csv(delimiter='\t', quoting=csv.QUOTE_NONE).map(_modify_res) + parsed_data = data_dp.parse_csv(delimiter="\t", quoting=csv.QUOTE_NONE).map(_modify_res) return parsed_data From f6e18c9ba523591a0fb7e1b669b6ecceabbd90e1 Mon Sep 17 00:00:00 2001 From: Virgile Mison Date: Tue, 17 May 2022 11:56:22 -0400 Subject: [PATCH 6/8] Add dataset documentation --- docs/source/datasets.rst | 5 
+++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index 33eb44b21d..4100753230 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -62,6 +62,11 @@ SST2 .. autofunction:: SST2 +STSB +~~~~ + +.. autofunction:: STSB + YahooAnswers ~~~~~~~~~~~~ From 4350c4b84879be625d5ec576ec4573a2dcf6472d Mon Sep 17 00:00:00 2001 From: Virgile Mison Date: Wed, 18 May 2022 15:48:54 -0400 Subject: [PATCH 7/8] Add shuffle and sharding --- torchtext/datasets/stsb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtext/datasets/stsb.py b/torchtext/datasets/stsb.py index 679436901e..c239de4465 100644 --- a/torchtext/datasets/stsb.py +++ b/torchtext/datasets/stsb.py @@ -87,4 +87,4 @@ def _modify_res(x): data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") parsed_data = data_dp.parse_csv(delimiter="\t", quoting=csv.QUOTE_NONE).map(_modify_res) - return parsed_data + return parsed_data.shuffle().set_shuffle(False).sharding_filter() From 0dc7a653501fb5b4be547793e43a724a3c9f3a2b Mon Sep 17 00:00:00 2001 From: Virgile Mison Date: Wed, 18 May 2022 16:03:34 -0400 Subject: [PATCH 8/8] Lint --- test/datasets/test_stsb.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/datasets/test_stsb.py b/test/datasets/test_stsb.py index 1795f6efb1..7e95d4f5ca 100644 --- a/test/datasets/test_stsb.py +++ b/test/datasets/test_stsb.py @@ -46,7 +46,6 @@ def _get_mock_dataset(root_dir): f'{rand_string_1}"\t"{rand_string_2}\t{rand_string_3}\t{i}\t{label}\t{rand_string_4}\t{rand_string_5}\n' ) - compressed_dataset_path = os.path.join(base_dir, "Stsbenchmark.tar.gz") # create tar file from dataset folder with tarfile.open(compressed_dataset_path, "w:gz") as tar: