From ecf8012d5b69816d682ee8653cd7f3f3d8fb7bad Mon Sep 17 00:00:00 2001 From: Virgile Mison Date: Wed, 4 May 2022 23:02:17 -0400 Subject: [PATCH 1/6] Add QQP dataset + unit test --- test/datasets/test_qqp.py | 60 ++++++++++++++++++++++++++++++++++ torchtext/datasets/__init__.py | 2 ++ torchtext/datasets/qqp.py | 48 +++++++++++++++++++++++++++ 3 files changed, 110 insertions(+) create mode 100644 test/datasets/test_qqp.py create mode 100644 torchtext/datasets/qqp.py diff --git a/test/datasets/test_qqp.py b/test/datasets/test_qqp.py new file mode 100644 index 0000000000..9a99254bd6 --- /dev/null +++ b/test/datasets/test_qqp.py @@ -0,0 +1,60 @@ +import os +from unittest.mock import patch + +from torchtext.datasets.qqp import QQP + +from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode +from ..common.torchtext_test_case import TorchtextTestCase + + +def _get_mock_dataset(root_dir): + """ + root_dir: directory to the mocked dataset + """ + base_dir = os.path.join(root_dir, "QQP") + os.makedirs(base_dir, exist_ok=True) + + seed = 1 + file_name = "quora_duplicate_questions.tsv" + txt_file = os.path.join(base_dir, file_name) + mocked_data = [] + print(txt_file) + with open(txt_file, "w", encoding="utf-8") as f: + f.write(f"id\tqid1\tqid2\tquestion1\tquestion2\tis_duplicate\n") + for i in range(5): + label = seed % 2 + rand_string = get_random_unicode(seed) + dataset_line = (i, label, rand_string, rand_string) + # append line to correct dataset split + mocked_data.append(dataset_line) + f.write(f'{i}\t{i}\t{i}\t{rand_string}\t{rand_string}\t{label}\n') + seed += 1 + + return mocked_data + + +class TestQQP(TempDirMixin, TorchtextTestCase): + root_dir = None + samples = [] + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.root_dir = cls.get_base_temp_dir() + print(cls.root_dir) + cls.samples = _get_mock_dataset(cls.root_dir) + cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True) + cls.patcher.start() + + @classmethod + def tearDownClass(cls): + cls.patcher.stop() + super().tearDownClass() + + def test_qqp(self): + dataset = QQP(root=self.root_dir) + + samples = list(dataset) + expected_samples = self.samples + for sample, expected_sample in zip_equal(samples, expected_samples): + self.assertEqual(sample, expected_sample) \ No newline at end of file diff --git a/torchtext/datasets/__init__.py b/torchtext/datasets/__init__.py index d7d33298ad..a87ca30559 100644 --- a/torchtext/datasets/__init__.py +++ b/torchtext/datasets/__init__.py @@ -12,6 +12,7 @@ from .iwslt2017 import IWSLT2017 from .multi30k import Multi30k from .penntreebank import PennTreebank +from .qqp import QQP from .sogounews import SogouNews from .squad1 import SQuAD1 from .squad2 import SQuAD2 @@ -36,6 +37,7 @@ "IWSLT2017": IWSLT2017, "Multi30k": Multi30k, "PennTreebank": PennTreebank, + "QQP": QQP, "SQuAD1": SQuAD1, "SQuAD2": SQuAD2, "SogouNews": SogouNews, diff --git a/torchtext/datasets/qqp.py b/torchtext/datasets/qqp.py new file mode 100644 index 0000000000..8058026cd0 --- /dev/null +++ b/torchtext/datasets/qqp.py @@ -0,0 +1,48 @@ +import os + +from torchtext._internal.module_utils import is_module_available +from torchtext.data.datasets_utils import _create_dataset_directory, _wrap_split_argument +from typing import Union, Tuple + +if is_module_available("torchdata"): + from torchdata.datapipes.iter import FileOpener, IterableWrapper + from torchtext._download_hooks import HttpReader + +URL = "http://qim.fs.quoracdn.net/quora_duplicate_questions.tsv" + +MD5 = "b6d5672bd9dc1e66ab2bb020ebeafb8d" + +_PATH = "quora_duplicate_questions.tsv" + +NUM_LINES = {"train": 404290} + +DATASET_NAME = "QQP" + + +@_create_dataset_directory(dataset_name=DATASET_NAME) +def QQP(root: str): + """QQP dataset + For additional details refer to https://quoradata.quora.com/First-Quora-Dataset-Release-Question-Pairs + + Args: + root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache') + + :returns: DataPipe that yields rows from CoLA dataset (idx (int), label (int), question1 (str), question2 (str)) + :rtype: str + """ + if not is_module_available("torchdata"): + raise ModuleNotFoundError( + "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" + ) + + url_dp = IterableWrapper([URL]) + cache_dp = url_dp.on_disk_cache( + filepath_fn=lambda x: os.path.join(root, _PATH), + hash_dict={os.path.join(root, _PATH): MD5}, + hash_type="md5", + ) + cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True) + cache_dp = FileOpener(cache_dp, encoding="utf-8") + # some context stored at top of the file needs to be removed + parsed_data = cache_dp.parse_csv(skip_lines=1, delimiter="\t").map(lambda x: (int(x[0]), int(x[-1]), x[3], x[4])) + return parsed_data \ No newline at end of file From c0303734cfd86a2daa56ad5e26dc33951e250422 Mon Sep 17 00:00:00 2001 From: Virgile Mison Date: Wed, 11 May 2022 12:22:58 -0400 Subject: [PATCH 2/6] Adjust output + add different strings for tests --- test/datasets/test_qqp.py | 7 ++++--- torchtext/datasets/qqp.py | 6 +++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/test/datasets/test_qqp.py b/test/datasets/test_qqp.py index 9a99254bd6..0541693848 100644 --- a/test/datasets/test_qqp.py +++ b/test/datasets/test_qqp.py @@ -23,11 +23,12 @@ def _get_mock_dataset(root_dir): f.write(f"id\tqid1\tqid2\tquestion1\tquestion2\tis_duplicate\n") for i in range(5): label = seed % 2 - rand_string = get_random_unicode(seed) - dataset_line = (i, label, rand_string, rand_string) + rand_string_1 = get_random_unicode(seed) + rand_string_2 = get_random_unicode(seed+1) + dataset_line = (label, rand_string_1, rand_string_2) # append line to correct dataset split mocked_data.append(dataset_line) - f.write(f'{i}\t{i}\t{i}\t{rand_string}\t{rand_string}\t{label}\n') + f.write(f'{i}\t{i}\t{i}\t{rand_string_1}\t{rand_string_2}\t{label}\n') seed += 1 return mocked_data diff --git a/torchtext/datasets/qqp.py b/torchtext/datasets/qqp.py index 8058026cd0..6d41051060 100644 --- a/torchtext/datasets/qqp.py +++ b/torchtext/datasets/qqp.py @@ -27,8 +27,8 @@ def QQP(root: str): Args: root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache') - :returns: DataPipe that yields rows from CoLA dataset (idx (int), label (int), question1 (str), question2 (str)) - :rtype: str + :returns: DataPipe that yields rows from CoLA dataset (label (int), question1 (str), question2 (str)) + :rtype: (int, str, str) """ if not is_module_available("torchdata"): raise ModuleNotFoundError( @@ -44,5 +44,5 @@ def QQP(root: str): cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True) cache_dp = FileOpener(cache_dp, encoding="utf-8") # some context stored at top of the file needs to be removed - parsed_data = cache_dp.parse_csv(skip_lines=1, delimiter="\t").map(lambda x: (int(x[0]), int(x[-1]), x[3], x[4])) + parsed_data = cache_dp.parse_csv(skip_lines=1, delimiter="\t").map(lambda x: (int(x[-1]), x[3], x[4])) return parsed_data \ No newline at end of file From d35d2f41414352015cf5748fe9adf67ebc5616bc Mon Sep 17 00:00:00 2001 From: Virgile Mison Date: Thu, 12 May 2022 11:03:27 -0400 Subject: [PATCH 3/6] Remove lambda functions + correct docstring --- torchtext/datasets/qqp.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/torchtext/datasets/qqp.py b/torchtext/datasets/qqp.py index 6d41051060..4f9dc3e922 100644 --- a/torchtext/datasets/qqp.py +++ b/torchtext/datasets/qqp.py @@ -1,8 +1,7 @@ import os from torchtext._internal.module_utils import is_module_available -from torchtext.data.datasets_utils import _create_dataset_directory, _wrap_split_argument -from typing import Union, Tuple +from torchtext.data.datasets_utils import _create_dataset_directory if is_module_available("torchdata"): from torchdata.datapipes.iter import FileOpener, IterableWrapper @@ -27,7 +26,7 @@ def QQP(root: str): Args: root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache') - :returns: DataPipe that yields rows from CoLA dataset (label (int), question1 (str), question2 (str)) + :returns: DataPipe that yields rows from QQP dataset (label (int), question1 (str), question2 (str)) :rtype: (int, str, str) """ if not is_module_available("torchdata"): @@ -35,14 +34,20 @@ def QQP(root: str): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) + def _filepath_fn(_=None): + return os.path.join(root, _PATH) + + def _modify_res(x): + return (int(x[-1]), x[3], x[4]) + url_dp = IterableWrapper([URL]) cache_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _PATH), - hash_dict={os.path.join(root, _PATH): MD5}, + filepath_fn=_filepath_fn, + hash_dict={_filepath_fn(): MD5}, hash_type="md5", ) cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True) cache_dp = FileOpener(cache_dp, encoding="utf-8") # some context stored at top of the file needs to be removed - parsed_data = cache_dp.parse_csv(skip_lines=1, delimiter="\t").map(lambda x: (int(x[-1]), x[3], x[4])) + parsed_data = cache_dp.parse_csv(skip_lines=1, delimiter="\t").map(_modify_res) return parsed_data \ No newline at end of file From 1950d57e164e3f39abc104a52ae0e615eec69c61 Mon Sep 17 00:00:00 2001 From: Virgile Mison Date: Mon, 16 May 2022 09:58:32 -0400 Subject: [PATCH 4/6] Fix lint --- test/datasets/test_qqp.py | 8 ++++---- torchtext/datasets/qqp.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/datasets/test_qqp.py b/test/datasets/test_qqp.py index 0541693848..4c782040ea 100644 --- a/test/datasets/test_qqp.py +++ b/test/datasets/test_qqp.py @@ -20,15 +20,15 @@ def _get_mock_dataset(root_dir): mocked_data = [] print(txt_file) with open(txt_file, "w", encoding="utf-8") as f: - f.write(f"id\tqid1\tqid2\tquestion1\tquestion2\tis_duplicate\n") + f.write("id\tqid1\tqid2\tquestion1\tquestion2\tis_duplicate\n") for i in range(5): label = seed % 2 rand_string_1 = get_random_unicode(seed) - rand_string_2 = get_random_unicode(seed+1) + rand_string_2 = get_random_unicode(seed + 1) dataset_line = (label, rand_string_1, rand_string_2) # append line to correct dataset split mocked_data.append(dataset_line) - f.write(f'{i}\t{i}\t{i}\t{rand_string_1}\t{rand_string_2}\t{label}\n') + f.write(f"{i}\t{i}\t{i}\t{rand_string_1}\t{rand_string_2}\t{label}\n") seed += 1 return mocked_data @@ -58,4 +58,4 @@ def test_qqp(self): samples = list(dataset) expected_samples = self.samples for sample, expected_sample in zip_equal(samples, expected_samples): - self.assertEqual(sample, expected_sample) \ No newline at end of file + self.assertEqual(sample, expected_sample) diff --git a/torchtext/datasets/qqp.py b/torchtext/datasets/qqp.py index 4f9dc3e922..2c940e7afb 100644 --- a/torchtext/datasets/qqp.py +++ b/torchtext/datasets/qqp.py @@ -50,4 +50,4 @@ def _modify_res(x): cache_dp = FileOpener(cache_dp, encoding="utf-8") # some context stored at top of the file needs to be removed parsed_data = cache_dp.parse_csv(skip_lines=1, delimiter="\t").map(_modify_res) - return parsed_data \ No newline at end of file + return parsed_data From b53de4a91765f24bd4e7672c8aeba9c509d14b33 Mon Sep 17 00:00:00 2001 From: Virgile Mison Date: Tue, 17 May 2022 11:55:22 -0400 Subject: [PATCH 5/6] Add dataset documentation --- docs/source/datasets.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index 33eb44b21d..6002c79ff6 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -52,6 +52,11 @@ IMDb .. autofunction:: IMDB +QQP +~~~~ + +.. autofunction:: QQP + SogouNews ~~~~~~~~~ From 2f0fc5297323225ec52ee52e9f9e085c6d794e0a Mon Sep 17 00:00:00 2001 From: Virgile Mison Date: Wed, 18 May 2022 15:46:07 -0400 Subject: [PATCH 6/6] Add shuffle and sharding --- torchtext/datasets/qqp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtext/datasets/qqp.py b/torchtext/datasets/qqp.py index 2c940e7afb..387cbffaa5 100644 --- a/torchtext/datasets/qqp.py +++ b/torchtext/datasets/qqp.py @@ -50,4 +50,4 @@ def _modify_res(x): cache_dp = FileOpener(cache_dp, encoding="utf-8") # some context stored at top of the file needs to be removed parsed_data = cache_dp.parse_csv(skip_lines=1, delimiter="\t").map(_modify_res) - return parsed_data + return parsed_data.shuffle().set_shuffle(False).sharding_filter()