From 433d50e7e58a4a772ba2c7b518beb747af0263d3 Mon Sep 17 00:00:00 2001
From: Elijah Rippeth
Date: Sun, 16 Jan 2022 08:37:25 -0500
Subject: [PATCH 1/5] add initial pass at migrating CoNLL 2000 to datapipes.

---
 test/data/test_dataset_utils.py         | 110 ++++++++++++++++++++++++
 torchtext/data/datasets_utils.py        |  35 ++++++-
 torchtext/datasets/conll2000chunking.py |  34 +++++---
 3 files changed, 165 insertions(+), 14 deletions(-)
 create mode 100644 test/data/test_dataset_utils.py

diff --git a/test/data/test_dataset_utils.py b/test/data/test_dataset_utils.py
new file mode 100644
index 0000000000..2afca0e55f
--- /dev/null
+++ b/test/data/test_dataset_utils.py
@@ -0,0 +1,110 @@
+from ..common.torchtext_test_case import TorchtextTestCase
+
+from torchtext.data.datasets_utils import _ParseIOBData
+from torch.utils.data.datapipes.iter import IterableWrapper
+
+
+class TestDatasetUtils(TorchtextTestCase):
+    def test_iob_datapipe_basic(self):
+        iob = [
+            "Alex I-PER",
+            "is O",
+            "going O",
+            "to O",
+            "Los I-LOC",
+            "Angeles I-LOC",
+            "in O",
+            "California I-LOC"
+        ]
+        iterable = [("ignored.txt", e) for e in iob]
+        iterable = IterableWrapper(iterable)
+        iob_dp = list(_ParseIOBData(iterable, sep=" "))
+        # There's only one example in this dataset
+        self.assertEqual(len(iob_dp), 1)
+        # The length of the list of surface forms is the number of lines in the example
+        self.assertEqual(len(iob_dp[0][0]), len(iob))
+        # The length of the list of labels is the number of lines in the example
+        self.assertEqual(len(iob_dp[0][1]), len(iob))
+        iob = [
+            "Alex I-PER",
+            "is O",
+            "going O",
+            "to O",
+            "Los I-LOC",
+            "Angeles I-LOC",
+            "in O",
+            "California I-LOC",
+            "",
+            "Alex I-PER",
+            "is O",
+            "going O",
+            "to O",
+            "Los I-LOC",
+            "Angeles I-LOC",
+            "in O",
+            "California I-LOC",
+        ]
+        iterable = [("ignored.txt", e) for e in iob]
+        iterable = IterableWrapper(iterable)
+        iob_dp = list(_ParseIOBData(iterable, sep=" "))
+        # There's only one example in this dataset
+        self.assertEqual(len(iob_dp), 2)
+        # The length of the first list of surface forms is the length of everything before the empty line.
+        # The length of the first list of labels is the length of everything before the empty line.
+        self.assertEqual(len(iob_dp[0][0]), iob.index(""))
+        self.assertEqual(len(iob_dp[0][1]), iob.index(""))
+        # The length of the second list of surface forms is the length of everything after the empty line.
+        # The length of the second list of labels is the length of everything after the empty line.
+        self.assertEqual(len(iob_dp[1][0]), len(iob) - iob.index("") - 1)
+        self.assertEqual(len(iob_dp[1][1]), len(iob) - iob.index("") - 1)
+
+    def test_iob_datapipe_functional(self):
+        iob = [
+            "Alex I-PER",
+            "is O",
+            "going O",
+            "to O",
+            "Los I-LOC",
+            "Angeles I-LOC",
+            "in O",
+            "California I-LOC"
+        ]
+        iterable = [("ignored.txt", e) for e in iob]
+        iob_dp = list(IterableWrapper(iterable).read_iob(sep=" "))
+        # There's only one example in this dataset
+        self.assertEqual(len(iob_dp), 1)
+        # The length of the list of surface forms is the number of lines in the example
+        self.assertEqual(len(iob_dp[0][0]), len(iob))
+        # The length of the list of labels is the number of lines in the example
+        self.assertEqual(len(iob_dp[0][1]), len(iob))
+        iob = [
+            "Alex I-PER",
+            "is O",
+            "going O",
+            "to O",
+            "Los I-LOC",
+            "Angeles I-LOC",
+            "in O",
+            "California I-LOC",
+            "",
+            "Alex I-PER",
+            "is O",
+            "going O",
+            "to O",
+            "Los I-LOC",
+            "Angeles I-LOC",
+            "in O",
+            "California I-LOC",
+        ]
+        iterable = [("ignored.txt", e) for e in iob]
+        iob_dp = list(IterableWrapper(iterable).read_iob(sep=" "))
+        # There are two examples in this dataset
+        self.assertEqual(len(iob_dp), 2)
+        # The length of the first list of surface forms is the length of everything before the empty line.
+        # The length of the first list of labels is the length of everything before the empty line.
+        self.assertEqual(len(iob_dp[0][0]), iob.index(""))
+        self.assertEqual(len(iob_dp[0][1]), iob.index(""))
+        # The length of the second list of surface forms is the length of everything after the empty line.
+        # The length of the second list of labels is the length of everything after the empty line.
+        self.assertEqual(len(iob_dp[1][0]), len(iob) - iob.index("") - 1)
+        self.assertEqual(len(iob_dp[1][1]), len(iob) - iob.index("") - 1)
diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py
index 8f238177ed..d6daf3a38f 100644
--- a/torchtext/data/datasets_utils.py
+++ b/torchtext/data/datasets_utils.py
@@ -10,7 +10,7 @@
     extract_archive,
     unicode_csv_reader,
 )
-from torch.utils.data import IterDataPipe, functional_datapipe
+from torch.utils.data import functional_datapipe, IterDataPipe
 import codecs
 try:
     import defusedxml.ElementTree as ET
@@ -342,3 +342,36 @@ def __iter__(self):
                             _answers = [""]
                             _answer_start = [-1]
                         yield _context, _question, _answers, _answer_start
+
+
+@functional_datapipe("read_iob")
+class _ParseIOBData(IterDataPipe):
+    """A datapipe responsible for reading sep-delimited IOB data from a stream.
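+
+    A minimal sketch of the framing this pipe expects (illustrative input):
+    each incoming item is a ``(filename, line)`` pair, non-empty lines hold
+    ``sep``-delimited columns, and a blank line closes the current example::
+
+        ("f.txt", "Alex I-PER"), ("f.txt", "is O"), ("f.txt", "")
+        # yields one example: [["Alex", "is"], ["I-PER", "O"]]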
+
+    Used for CoNLL 2000 and UDPOS."""
+    def __init__(self, dp, sep: str = "\t") -> None:
+        self.dp = dp
+        self.sep = sep
+
+    def __iter__(self):
+        columns = []
+        for filename, line in self.dp:
+            line = line.strip()
+            if line == "":
+                if columns:
+                    yield columns
+                columns = []
+            else:
+                for i, column in enumerate(line.split(self.sep)):
+                    if len(columns) < i + 1:
+                        columns.append([])
+                    columns[i].append(column)
+        if len(columns) > 0:
+            yield columns
diff --git a/torchtext/datasets/conll2000chunking.py b/torchtext/datasets/conll2000chunking.py
index 3816c1ddb3..3c9dea24a9 100644
--- a/torchtext/datasets/conll2000chunking.py
+++ b/torchtext/datasets/conll2000chunking.py
@@ -1,13 +1,16 @@
+from torchtext._internal.module_utils import is_module_available
+from typing import Union, Tuple
+
+if is_module_available("torchdata"):
+    from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper
+
 from torchtext.data.datasets_utils import (
-    _RawTextIterableDataset,
     _wrap_split_argument,
     _add_docstring_header,
-    _download_extract_validate,
     _create_dataset_directory,
-    _create_data_from_iob,
 )
+
 import os
-import logging

 URL = {
     'train': "https://www.clips.uantwerpen.be/conll2000/chunking/train.txt.gz",
@@ -41,12 +44,17 @@
 @_add_docstring_header(num_lines=NUM_LINES)
 @_create_dataset_directory(dataset_name=DATASET_NAME)
 @_wrap_split_argument(('train', 'test'))
-def CoNLL2000Chunking(root, split):
-    # Create a dataset specific subfolder to deal with generic download filenames
-    root = os.path.join(root, 'conll2000chunking')
-    path = os.path.join(root, split + ".txt.gz")
-    data_filename = _download_extract_validate(root, URL[split], MD5[split], path, os.path.join(root, _EXTRACTED_FILES[split]),
-                                               _EXTRACTED_FILES_MD5[split], hash_type="md5")
-    logging.info('Creating {} data'.format(split))
-    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
-                                   _create_data_from_iob(data_filename, " "))
+def CoNLL2000Chunking(root: str, split: Union[Tuple[str], str]):
+    if not is_module_available("torchdata"):
+        raise ModuleNotFoundError("Package `torchdata` not found. Please install it following the instructions at `https://github.com/pytorch/data`")
+
+    url_dp = IterableWrapper([URL[split]])
+    cache_dp = url_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, "conll2000chunking", os.path.basename(URL[split])),
+        hash_dict={os.path.join(root, "conll2000chunking", os.path.basename(URL[split])): MD5[split]},
+        hash_type="md5"
+    )
+    cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
+    cache_dp = FileOpener(cache_dp, mode="b")
+    cache_dp = cache_dp.extract(file_type="gzip").readlines(decode=True)
+    return cache_dp.read_iob(sep=" ")

From 8c55fba05a74d95d05af67ae1aff7813f3403d1e Mon Sep 17 00:00:00 2001
From: Elijah Rippeth
Date: Tue, 18 Jan 2022 10:01:49 -0500
Subject: [PATCH 2/5] fix typo in test.

---
 test/data/test_dataset_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/data/test_dataset_utils.py b/test/data/test_dataset_utils.py
index 2afca0e55f..f613df4ce4 100644
--- a/test/data/test_dataset_utils.py
+++ b/test/data/test_dataset_utils.py
@@ -47,7 +47,7 @@ def test_iob_datapipe_basic(self):
         iterable = [("ignored.txt", e) for e in iob]
         iterable = IterableWrapper(iterable)
         iob_dp = list(_ParseIOBData(iterable, sep=" "))
-        # There's only one example in this dataset
+        # There are two examples in this dataset
         self.assertEqual(len(iob_dp), 2)
         # The length of the first list of surface forms is the length of everything before the empty line.
         # The length of the first list of labels is the length of everything before the empty line.

From c7e128278470ed5978968db9ad4e0fbf9331c0f3 Mon Sep 17 00:00:00 2001
From: Elijah Rippeth
Date: Tue, 18 Jan 2022 11:09:50 -0500
Subject: [PATCH 3/5] add caching for the extraction from raw gzip too.

---
 torchtext/datasets/conll2000chunking.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/torchtext/datasets/conll2000chunking.py b/torchtext/datasets/conll2000chunking.py
index 3c9dea24a9..0dd07ea30e 100644
--- a/torchtext/datasets/conll2000chunking.py
+++ b/torchtext/datasets/conll2000chunking.py
@@ -49,6 +49,8 @@ def CoNLL2000Chunking(root: str, split: Union[Tuple[str], str]):
         raise ModuleNotFoundError("Package `torchdata` not found. Please install it following the instructions at `https://github.com/pytorch/data`")

     url_dp = IterableWrapper([URL[split]])
+
+    # Cache and check HTTP response
     cache_dp = url_dp.on_disk_cache(
         filepath_fn=lambda x: os.path.join(root, "conll2000chunking", os.path.basename(URL[split])),
         hash_dict={os.path.join(root, "conll2000chunking", os.path.basename(URL[split])): MD5[split]},
@@ -56,5 +58,15 @@ def CoNLL2000Chunking(root: str, split: Union[Tuple[str], str]):
     )
     cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
     cache_dp = FileOpener(cache_dp, mode="b")
-    cache_dp = cache_dp.extract(file_type="gzip").readlines(decode=True)
-    return cache_dp.read_iob(sep=" ")
+
+    # Cache and check the gzip extraction for relevant split
+    cache_dp = cache_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, "conll2000chunking", _EXTRACTED_FILES[split]),
+        hash_dict={os.path.join(root, "conll2000chunking", _EXTRACTED_FILES[split]): _EXTRACTED_FILES_MD5[split]},
+        hash_type="md5"
+    )
+    cache_dp = cache_dp.extract(file_type="gzip").filter(lambda x: _EXTRACTED_FILES[split] in x[0])
+    cache_dp = cache_dp.end_caching(mode="wb")
+
+    cache_dp = FileOpener(cache_dp, mode="b")
+    return cache_dp.readlines(decode=True).read_iob(sep=" ")

From 0977d456265adfc73cd6a26a2b3bef0002ea0412 Mon Sep 17 00:00:00 2001
From: Elijah Rippeth
Date: Tue, 18 Jan 2022 11:32:55 -0500
Subject: [PATCH 4/5] parameterize test for minimal duplication.
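
Run the same assertions against both ways of building the pipe (the
explicit `_ParseIOBData` constructor and the functional `read_iob`
form) so the two cannot drift apart. The shape of the change is
sketched here, assuming the `parameterized` package is available in
the test environment; the full diff follows:

    @parameterized.expand([
        [lambda it: list(_ParseIOBData(IterableWrapper(it), sep=" "))],
        [lambda it: list(IterableWrapper(it).read_iob(sep=" "))],
    ])
    def test_iob_datapipe(self, pipe_fn):
        ...  # identical assertions exercise whichever variant pipe_fn builds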
---
 test/data/test_dataset_utils.py | 65 +++++----------------------------
 1 file changed, 9 insertions(+), 56 deletions(-)

diff --git a/test/data/test_dataset_utils.py b/test/data/test_dataset_utils.py
index f613df4ce4..74562b7602 100644
--- a/test/data/test_dataset_utils.py
+++ b/test/data/test_dataset_utils.py
@@ -3,9 +3,15 @@
 from torchtext.data.datasets_utils import _ParseIOBData
 from torch.utils.data.datapipes.iter import IterableWrapper

+from parameterized import parameterized
+

 class TestDatasetUtils(TorchtextTestCase):
-    def test_iob_datapipe_basic(self):
+    @parameterized.expand([
+        [lambda it: list(_ParseIOBData(IterableWrapper(it), sep=" "))],
+        [lambda it: list(IterableWrapper(it).read_iob(sep=" "))]
+    ])
+    def test_iob_datapipe(self, pipe_fn):
         iob = [
             "Alex I-PER",
             "is O",
             "going O",
             "to O",
             "Los I-LOC",
             "Angeles I-LOC",
             "in O",
             "California I-LOC"
         ]
         iterable = [("ignored.txt", e) for e in iob]
-        iterable = IterableWrapper(iterable)
-        iob_dp = list(_ParseIOBData(iterable, sep=" "))
+        iob_dp = pipe_fn(iterable)
         # There's only one example in this dataset
         self.assertEqual(len(iob_dp), 1)
         # The length of the list of surface forms is the number of lines in the example
         self.assertEqual(len(iob_dp[0][0]), len(iob))
         # The length of the list of labels is the number of lines in the example
         self.assertEqual(len(iob_dp[0][1]), len(iob))
         iob = [
             "Alex I-PER",
             "is O",
             "going O",
             "to O",
             "Los I-LOC",
             "Angeles I-LOC",
             "in O",
             "California I-LOC",
             "",
             "Alex I-PER",
             "is O",
             "going O",
             "to O",
             "Los I-LOC",
             "Angeles I-LOC",
             "in O",
             "California I-LOC",
         ]
         iterable = [("ignored.txt", e) for e in iob]
-        iterable = IterableWrapper(iterable)
-        iob_dp = list(_ParseIOBData(iterable, sep=" "))
+        iob_dp = pipe_fn(iterable)
         # There are two examples in this dataset
         self.assertEqual(len(iob_dp), 2)
         # The length of the first list of surface forms is the length of everything before the empty line.
         # The length of the first list of labels is the length of everything before the empty line.
         self.assertEqual(len(iob_dp[0][0]), iob.index(""))
         self.assertEqual(len(iob_dp[0][1]), iob.index(""))
         # The length of the second list of surface forms is the length of everything after the empty line.
         # The length of the second list of labels is the length of everything after the empty line.
         self.assertEqual(len(iob_dp[1][0]), len(iob) - iob.index("") - 1)
         self.assertEqual(len(iob_dp[1][1]), len(iob) - iob.index("") - 1)
-
-    def test_iob_datapipe_functional(self):
-        iob = [
-            "Alex I-PER",
-            "is O",
-            "going O",
-            "to O",
-            "Los I-LOC",
-            "Angeles I-LOC",
-            "in O",
-            "California I-LOC"
-        ]
-        iterable = [("ignored.txt", e) for e in iob]
-        iob_dp = list(IterableWrapper(iterable).read_iob(sep=" "))
-        # There's only one example in this dataset
-        self.assertEqual(len(iob_dp), 1)
-        # The length of the list of surface forms is the number of lines in the example
-        self.assertEqual(len(iob_dp[0][0]), len(iob))
-        # The length of the list of labels is the number of lines in the example
-        self.assertEqual(len(iob_dp[0][1]), len(iob))
-        iob = [
-            "Alex I-PER",
-            "is O",
-            "going O",
-            "to O",
-            "Los I-LOC",
-            "Angeles I-LOC",
-            "in O",
-            "California I-LOC",
-            "",
-            "Alex I-PER",
-            "is O",
-            "going O",
-            "to O",
-            "Los I-LOC",
-            "Angeles I-LOC",
-            "in O",
-            "California I-LOC",
-        ]
-        iterable = [("ignored.txt", e) for e in iob]
-        iob_dp = list(IterableWrapper(iterable).read_iob(sep=" "))
-        # There are two examples in this dataset
-        self.assertEqual(len(iob_dp), 2)
-        # The length of the first list of surface forms is the length of everything before the empty line.
-        # The length of the first list of labels is the length of everything before the empty line.
-        self.assertEqual(len(iob_dp[0][0]), iob.index(""))
-        self.assertEqual(len(iob_dp[0][1]), iob.index(""))
-        # The length of the second list of surface forms is the length of everything after the empty line.
-        # The length of the second list of labels is the length of everything after the empty line.
-        self.assertEqual(len(iob_dp[1][0]), len(iob) - iob.index("") - 1)
-        self.assertEqual(len(iob_dp[1][1]), len(iob) - iob.index("") - 1)

From 67f0509591e94e72d871efb3fb14e9e7dcc0703e Mon Sep 17 00:00:00 2001
From: Elijah Rippeth
Date: Tue, 18 Jan 2022 12:46:58 -0500
Subject: [PATCH 5/5] remove hash-check from extracted files since the
 archived hash has already been checked.

---
 torchtext/datasets/conll2000chunking.py | 12 ++----------
 1 file changed, 2 insertions(+), 10 deletions(-)

diff --git a/torchtext/datasets/conll2000chunking.py b/torchtext/datasets/conll2000chunking.py
index 0dd07ea30e..513132233c 100644
--- a/torchtext/datasets/conll2000chunking.py
+++ b/torchtext/datasets/conll2000chunking.py
@@ -32,18 +32,12 @@
     'test': 'test.txt'
 }

-_EXTRACTED_FILES_MD5 = {
-    'train': "2e2f24e90e20fcb910ab2251b5ed8cd0",
-    'test': "56944df34be553b72a2a634e539a0951"
-}
-
-
 DATASET_NAME = "CoNLL2000Chunking"


 @_add_docstring_header(num_lines=NUM_LINES)
 @_create_dataset_directory(dataset_name=DATASET_NAME)
-@_wrap_split_argument(('train', 'test'))
+@_wrap_split_argument(("train", "test"))
 def CoNLL2000Chunking(root: str, split: Union[Tuple[str], str]):
     if not is_module_available("torchdata"):
         raise ModuleNotFoundError("Package `torchdata` not found. Please install it following the instructions at `https://github.com/pytorch/data`")
@@ -61,9 +55,7 @@ def CoNLL2000Chunking(root: str, split: Union[Tuple[str], str]):

     # Cache and check the gzip extraction for relevant split
     cache_dp = cache_dp.on_disk_cache(
-        filepath_fn=lambda x: os.path.join(root, "conll2000chunking", _EXTRACTED_FILES[split]),
-        hash_dict={os.path.join(root, "conll2000chunking", _EXTRACTED_FILES[split]): _EXTRACTED_FILES_MD5[split]},
-        hash_type="md5"
+        filepath_fn=lambda x: os.path.join(root, "conll2000chunking", _EXTRACTED_FILES[split])
     )
     cache_dp = cache_dp.extract(file_type="gzip").filter(lambda x: _EXTRACTED_FILES[split] in x[0])
     cache_dp = cache_dp.end_caching(mode="wb")
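
Usage sketch for the migrated dataset after this series (illustrative;
assumes `torchdata` is installed and the CoNLL 2000 URL is reachable):

    from torchtext.datasets import CoNLL2000Chunking

    train_dp = CoNLL2000Chunking(split="train")
    # Each yielded example is one sentence as per-column lists; CoNLL 2000
    # has three space-separated columns: token, POS tag, chunk tag.
    tokens, pos_tags, chunk_tags = next(iter(train_dp))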