Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions test/data/test_dataset_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from ..common.torchtext_test_case import TorchtextTestCase

from torchtext.data.datasets_utils import _ParseIOBData
from torch.utils.data.datapipes.iter import IterableWrapper

from parameterized import parameterized


class TestDatasetUtils(TorchtextTestCase):
    """Checks the IOB-parsing datapipe through both of its entry points."""

    @parameterized.expand([
        [lambda it: list(_ParseIOBData(IterableWrapper(it), sep=" "))],
        [lambda it: list(IterableWrapper(it).read_iob(sep=" "))]
    ])
    def test_iob_datapipe(self, pipe_fn):
        """Feed space-separated IOB lines to ``pipe_fn`` and check example/column sizes.

        ``pipe_fn`` takes an iterable of ``(filename, line)`` pairs and returns
        the fully materialised list of parsed examples.
        """
        sentence = [
            "Alex I-PER",
            "is O",
            "going O",
            "to O",
            "Los I-LOC",
            "Angeles I-LOC",
            "in O",
            "California I-LOC",
        ]

        # --- Case 1: a single sentence with no blank separator line ---
        iob = list(sentence)
        examples = pipe_fn([("ignored.txt", line) for line in iob])
        # Exactly one example is expected
        self.assertEqual(len(examples), 1)
        # One surface form per input line
        self.assertEqual(len(examples[0][0]), len(iob))
        # One label per input line
        self.assertEqual(len(examples[0][1]), len(iob))

        # --- Case 2: the same sentence twice, separated by one blank line ---
        iob = sentence + [""] + sentence
        examples = pipe_fn([("ignored.txt", line) for line in iob])
        # The blank line splits the stream into two examples
        self.assertEqual(len(examples), 2)
        # First example: surface forms and labels cover everything before the blank line
        self.assertEqual(len(examples[0][0]), iob.index(""))
        self.assertEqual(len(examples[0][1]), iob.index(""))
        # Second example: surface forms and labels cover everything after the blank line
        self.assertEqual(len(examples[1][0]), len(iob) - iob.index("") - 1)
        self.assertEqual(len(examples[1][1]), len(iob) - iob.index("") - 1)
28 changes: 27 additions & 1 deletion torchtext/data/datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
extract_archive,
unicode_csv_reader,
)
from torch.utils.data import IterDataPipe, functional_datapipe
from torch.utils.data import functional_datapipe, IterDataPipe
import codecs
try:
import defusedxml.ElementTree as ET
Expand Down Expand Up @@ -342,3 +342,29 @@ def __iter__(self):
_answers = [""]
_answer_start = [-1]
yield _context, _question, _answers, _answer_start


@functional_datapipe("read_iob")
class _ParseIOBData(IterDataPipe):
"""A datapipe responsible for reading sep-delimited IOB data from a stream.

Used for CONLL 2000 and UDPOS."""
def __init__(self, dp, sep: str = "\t") -> None:
self.dp = dp
self.sep = sep

def __iter__(self):
columns = []
for filename, line in self.dp:
line = line.strip()
if line == "":
if columns:
yield columns
columns = []
else:
for i, column in enumerate(line.split(self.sep)):
if len(columns) < i + 1:
columns.append([])
columns[i].append(column)
if len(columns) > 0:
yield columns
52 changes: 32 additions & 20 deletions torchtext/datasets/conll2000chunking.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from torchtext._internal.module_utils import is_module_available
from typing import Union, Tuple

if is_module_available("torchdata"):
from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper

from torchtext.data.datasets_utils import (
_RawTextIterableDataset,
_wrap_split_argument,
_add_docstring_header,
_download_extract_validate,
_create_dataset_directory,
_create_data_from_iob,
)

import os
import logging

URL = {
'train': "https://www.clips.uantwerpen.be/conll2000/chunking/train.txt.gz",
Expand All @@ -29,24 +32,33 @@
'test': 'test.txt'
}

# MD5 hex digests keyed by split name. Presumably these are the checksums of
# the decompressed files listed in _EXTRACTED_FILES — TODO confirm against the
# download/validation helper that consumes them.
_EXTRACTED_FILES_MD5 = {
'train': "2e2f24e90e20fcb910ab2251b5ed8cd0",
'test': "56944df34be553b72a2a634e539a0951"
}


# Name passed to the dataset-directory decorator and to the dataset wrapper.
DATASET_NAME = "CoNLL2000Chunking"


@_add_docstring_header(num_lines=NUM_LINES)
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(('train', 'test'))
def CoNLL2000Chunking(root, split):
    """Build the raw-text iterable dataset for one CoNLL-2000 chunking split.

    The archive and its extracted file live in a ``conll2000chunking``
    subfolder of ``root``; fetching, extraction and MD5 validation are
    delegated to ``_download_extract_validate``. The resulting file is read
    as space-separated IOB data.

    Args:
        root: base directory for dataset files.
        split: split name used to index the URL/MD5/line-count tables.

    Returns:
        A ``_RawTextIterableDataset`` wrapping the parsed IOB data.
    """
    # Use a dataset-specific subfolder because the download filenames are generic.
    dataset_dir = os.path.join(root, 'conll2000chunking')
    archive_path = os.path.join(dataset_dir, split + ".txt.gz")
    data_filename = _download_extract_validate(
        dataset_dir,
        URL[split],
        MD5[split],
        archive_path,
        os.path.join(dataset_dir, _EXTRACTED_FILES[split]),
        _EXTRACTED_FILES_MD5[split],
        hash_type="md5",
    )
    logging.info('Creating {} data'.format(split))
    line_count = NUM_LINES[split]
    examples = _create_data_from_iob(data_filename, " ")
    return _RawTextIterableDataset(DATASET_NAME, line_count, examples)
@_wrap_split_argument(("train", "test"))
def CoNLL2000Chunking(root: str, split: Union[Tuple[str], str]):
    """Assemble the datapipe that serves one CoNLL-2000 chunking split.

    Pipeline stages, all cached on disk under ``<root>/conll2000chunking``:
    download the gzip archive for ``split`` over HTTP (checked against its
    MD5), extract it with the gzip extractor, then read the extracted file
    line by line and parse it as space-separated IOB data.

    Args:
        root: base directory under which files are cached.
        split: split name used to index the URL/MD5/extracted-file tables.

    Returns:
        The datapipe produced by ``read_iob(sep=" ")`` over the decoded lines.

    Raises:
        ModuleNotFoundError: if the ``torchdata`` package is not available.
    """
    if not is_module_available("torchdata"):
        raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")

    download_url = URL[split]
    # Target path of the downloaded archive; also the key of the hash table below.
    archive_path = os.path.join(root, "conll2000chunking", os.path.basename(download_url))

    url_dp = IterableWrapper([download_url])

    # Stage 1: download, cached on disk and validated by MD5.
    download_cache = url_dp.on_disk_cache(
        filepath_fn=lambda _: archive_path,
        hash_dict={archive_path: MD5[split]},
        hash_type="md5",
    )
    downloaded = HttpReader(download_cache).end_caching(mode="wb", same_filepath_fn=True)
    archive_stream = FileOpener(downloaded, mode="b")

    # Stage 2: gzip extraction, cached on disk, keeping only this split's file.
    extract_cache = archive_stream.on_disk_cache(
        filepath_fn=lambda _: os.path.join(root, "conll2000chunking", _EXTRACTED_FILES[split])
    )
    extracted = extract_cache.extract(file_type="gzip").filter(
        lambda item: _EXTRACTED_FILES[split] in item[0]
    )
    extracted = extracted.end_caching(mode="wb")

    # Stage 3: decode the extracted file into lines and group them as IOB examples.
    text_stream = FileOpener(extracted, mode="b")
    return text_stream.readlines(decode=True).read_iob(sep=" ")