diff --git a/test/data/test_dataset_utils.py b/test/data/test_dataset_utils.py
new file mode 100644
index 0000000000..74562b7602
--- /dev/null
+++ b/test/data/test_dataset_utils.py
@@ -0,0 +1,63 @@
+from ..common.torchtext_test_case import TorchtextTestCase
+
+from torchtext.data.datasets_utils import _ParseIOBData
+from torch.utils.data.datapipes.iter import IterableWrapper
+
+from parameterized import parameterized
+
+
+class TestDatasetUtils(TorchtextTestCase):
+    @parameterized.expand([
+        [lambda it: list(_ParseIOBData(IterableWrapper(it), sep=" "))],
+        [lambda it: list(IterableWrapper(it).read_iob(sep=" "))]
+    ])
+    def test_iob_datapipe(self, pipe_fn):
+        iob = [
+            "Alex I-PER",
+            "is O",
+            "going O",
+            "to O",
+            "Los I-LOC",
+            "Angeles I-LOC",
+            "in O",
+            "California I-LOC"
+        ]
+        iterable = [("ignored.txt", e) for e in iob]
+        iob_dp = pipe_fn(iterable)
+        # There's only one example in this dataset
+        self.assertEqual(len(iob_dp), 1)
+        # The length of the list of surface forms is the number of lines in the example
+        self.assertEqual(len(iob_dp[0][0]), len(iob))
+        # The length of the list of labels is the number of lines in the example
+        self.assertEqual(len(iob_dp[0][1]), len(iob))
+        iob = [
+            "Alex I-PER",
+            "is O",
+            "going O",
+            "to O",
+            "Los I-LOC",
+            "Angeles I-LOC",
+            "in O",
+            "California I-LOC",
+            "",
+            "Alex I-PER",
+            "is O",
+            "going O",
+            "to O",
+            "Los I-LOC",
+            "Angeles I-LOC",
+            "in O",
+            "California I-LOC",
+        ]
+        iterable = [("ignored.txt", e) for e in iob]
+        iob_dp = pipe_fn(iterable)
+        # There are two examples in this dataset
+        self.assertEqual(len(iob_dp), 2)
+        # The length of the first list of surface forms is the length of everything before the empty line.
+        # The length of the first list of labels is the length of everything before the empty line.
+        self.assertEqual(len(iob_dp[0][0]), iob.index(""))
+        self.assertEqual(len(iob_dp[0][1]), iob.index(""))
+        # The length of the second list of surface forms is the length of everything after the empty line.
+        # The length of the second list of labels is the length of everything after the empty line.
+        self.assertEqual(len(iob_dp[1][0]), len(iob) - iob.index("") - 1)
+        self.assertEqual(len(iob_dp[1][1]), len(iob) - iob.index("") - 1)
diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py
index 8f238177ed..d6daf3a38f 100644
--- a/torchtext/data/datasets_utils.py
+++ b/torchtext/data/datasets_utils.py
@@ -10,7 +10,7 @@
     extract_archive,
     unicode_csv_reader,
 )
-from torch.utils.data import IterDataPipe, functional_datapipe
+from torch.utils.data import functional_datapipe, IterDataPipe
 import codecs
 try:
     import defusedxml.ElementTree as ET
@@ -342,3 +342,29 @@ def __iter__(self):
                         _answers = [""]
                         _answer_start = [-1]
                     yield _context, _question, _answers, _answer_start
+
+
+@functional_datapipe("read_iob")
+class _ParseIOBData(IterDataPipe):
+    """A datapipe responsible for reading sep-delimited IOB data from a stream.
+
+    Used for CONLL 2000 and UDPOS."""
+    def __init__(self, dp, sep: str = "\t") -> None:
+        self.dp = dp
+        self.sep = sep
+
+    def __iter__(self):
+        columns = []
+        for filename, line in self.dp:
+            line = line.strip()
+            if line == "":
+                if columns:
+                    yield columns
+                columns = []
+            else:
+                for i, column in enumerate(line.split(self.sep)):
+                    if len(columns) < i + 1:
+                        columns.append([])
+                    columns[i].append(column)
+        if len(columns) > 0:
+            yield columns
diff --git a/torchtext/datasets/conll2000chunking.py b/torchtext/datasets/conll2000chunking.py
index 3816c1ddb3..513132233c 100644
--- a/torchtext/datasets/conll2000chunking.py
+++ b/torchtext/datasets/conll2000chunking.py
@@ -1,13 +1,16 @@
+from torchtext._internal.module_utils import is_module_available
+from typing import Union, Tuple
+
+if is_module_available("torchdata"):
+    from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper
+
 from torchtext.data.datasets_utils import (
-    _RawTextIterableDataset,
     _wrap_split_argument,
     _add_docstring_header,
-    _download_extract_validate,
     _create_dataset_directory,
-    _create_data_from_iob,
 )
+
 import os
-import logging
 
 URL = {
     'train': "https://www.clips.uantwerpen.be/conll2000/chunking/train.txt.gz",
@@ -29,24 +32,33 @@
     'test': 'test.txt'
 }
 
-_EXTRACTED_FILES_MD5 = {
-    'train': "2e2f24e90e20fcb910ab2251b5ed8cd0",
-    'test': "56944df34be553b72a2a634e539a0951"
-}
-
-
 DATASET_NAME = "CoNLL2000Chunking"
 
 
 @_add_docstring_header(num_lines=NUM_LINES)
 @_create_dataset_directory(dataset_name=DATASET_NAME)
-@_wrap_split_argument(('train', 'test'))
-def CoNLL2000Chunking(root, split):
-    # Create a dataset specific subfolder to deal with generic download filenames
-    root = os.path.join(root, 'conll2000chunking')
-    path = os.path.join(root, split + ".txt.gz")
-    data_filename = _download_extract_validate(root, URL[split], MD5[split], path, os.path.join(root, _EXTRACTED_FILES[split]),
-                                               _EXTRACTED_FILES_MD5[split], hash_type="md5")
-    logging.info('Creating {} data'.format(split))
-    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
-                                   _create_data_from_iob(data_filename, " "))
+@_wrap_split_argument(("train", "test"))
+def CoNLL2000Chunking(root: str, split: Union[Tuple[str], str]):
+    if not is_module_available("torchdata"):
+        raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")
+
+    url_dp = IterableWrapper([URL[split]])
+
+    # Cache and check HTTP response
+    cache_dp = url_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, "conll2000chunking", os.path.basename(URL[split])),
+        hash_dict={os.path.join(root, "conll2000chunking", os.path.basename(URL[split])): MD5[split]},
+        hash_type="md5"
+    )
+    cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
+    cache_dp = FileOpener(cache_dp, mode="b")
+
+    # Cache and check the gzip extraction for relevant split
+    cache_dp = cache_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, "conll2000chunking", _EXTRACTED_FILES[split])
+    )
+    cache_dp = cache_dp.extract(file_type="gzip").filter(lambda x: _EXTRACTED_FILES[split] in x[0])
+    cache_dp = cache_dp.end_caching(mode="wb")
+
+    cache_dp = FileOpener(cache_dp, mode="b")
+    return cache_dp.readlines(decode=True).read_iob(sep=" ")
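
Usage sketch (reviewer note, not part of the patch): because `_ParseIOBData` is registered with `@functional_datapipe("read_iob")`, importing `torchtext.data.datasets_utils` makes `.read_iob(...)` available on any `IterDataPipe` whose elements are `(filename, line)` tuples, which is exactly what the new test exercises. A minimal run on in-memory data, assuming this patch is applied; the filename values here are arbitrary and ignored by the parser:

from torch.utils.data.datapipes.iter import IterableWrapper

import torchtext.data.datasets_utils  # noqa: F401 -- importing registers the "read_iob" functional

# Each element is a (filename, line) tuple; _ParseIOBData ignores the filename.
lines = [
    ("example.txt", "Alex I-PER"),
    ("example.txt", "is O"),
    ("example.txt", ""),  # a blank line closes the first example
    ("example.txt", "California I-LOC"),
]

# Each yielded example is a list of columns: [surface forms, labels].
for words, tags in IterableWrapper(lines).read_iob(sep=" "):
    print(words, tags)
# ['Alex', 'is'] ['I-PER', 'O']
# ['California'] ['I-LOC']

The rewritten `CoNLL2000Chunking` chains onto the same mechanism: `readlines(decode=True)` emits `(path, line)` tuples from the cached, extracted file, so `read_iob(sep=" ")` can consume them directly.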