Merged
Changes from all commits
Commits (32)
3ab5c1e
Merge pull request #1 from pytorch/master
akurniawan Feb 13, 2019
b068312
add raw for sequence tagging
Jun 3, 2020
41b975f
WIP sequence tagging dataset
Jun 3, 2020
49dec4b
add specialized function to handle None case
Jun 3, 2020
3f988b1
expose raw datasets for sequence tagging
Jun 3, 2020
27bbeb7
finalized sequence tagging dataset
Jun 3, 2020
a99723f
add documentation
Jun 4, 2020
b7ba5b0
expose sequence tagging data
Jun 4, 2020
36d4652
add unit test for sequence tagging
Jun 4, 2020
b79839c
fix linting
Jun 4, 2020
a03eec4
remove filename arguments
Jun 5, 2020
481ea37
[WIP] adding conll test
Jun 5, 2020
8eeeff9
Merge branch 'master' of https://github.com/pytorch/text into new_seq…
Jun 5, 2020
fb3f8f5
move the test order with translation dataset and finalize conll testing
Jun 5, 2020
9a84a8a
add doc string for sequence tagging dataset
Jun 5, 2020
9d11bc1
remove spaces at the end of the file
Jun 5, 2020
b1c5ec4
reformat docstring
Jun 5, 2020
c38a0b8
remove tokenizer
Jun 5, 2020
51f7fbf
Merge branch 'master' of https://github.com/pytorch/text into new_seq…
Jun 12, 2020
12d4482
fix linting
Jun 12, 2020
247a14c
Merge branch 'master' of https://github.com/pytorch/text into new_seq…
Jun 15, 2020
0aad1c4
add cases where we don't have blank by the end of the file
Jun 15, 2020
b662081
- add validation for data_select
Jun 15, 2020
899f872
Merge branch 'master' of github.com:akurniawan/text into new_sequence…
Jun 15, 2020
a0ec2e7
Merge branch 'new_sequence_tagging' of github.com:akurniawan/text int…
Jun 15, 2020
13864ea
modify method name
Jun 15, 2020
e3d4256
add "valid" to data_select option validation
Jun 15, 2020
351ad46
add todo for assert_allclose
Jun 17, 2020
73ce74a
remove duplicate validation for transforms function
Jun 17, 2020
0358f2c
Merge branch 'master' of https://github.com/pytorch/text into new_seq…
Jun 19, 2020
e4ba11c
replace assert_allclose with self.assertEqual
Jun 19, 2020
2d01b15
Merge branch 'master' of https://github.com/pytorch/text into new_seq…
Jun 20, 2020
47 changes: 34 additions & 13 deletions docs/source/experimental_datasets.rst
@@ -34,9 +34,9 @@ IMDb
~~~~

.. autoclass:: IMDB
:members: __init__
:members: __init__


Text Classification
^^^^^^^^^^^^^^^^^^^

@@ -109,8 +109,8 @@ AmazonReviewFull dataset is subclass of ``TextClassificationDataset`` class.

.. autoclass:: AmazonReviewFull
:members: __init__


Language Modeling
^^^^^^^^^^^^^^^^^

@@ -124,28 +124,28 @@ WikiText-2
~~~~~~~~~~

.. autoclass:: WikiText2
:members: __init__
:members: __init__


WikiText103
~~~~~~~~~~~

.. autoclass:: WikiText103
:members: __init__
:members: __init__


PennTreebank
~~~~~~~~~~~~

.. autoclass:: PennTreebank
:members: __init__
:members: __init__


WMTNewsCrawl
~~~~~~~~~~~~

.. autoclass:: WMTNewsCrawl
:members: __init__
.. autoclass:: WMTNewsCrawl
:members: __init__


Machine Translation
@@ -177,24 +177,45 @@ WMT14
.. autoclass:: WMT14
:members: __init__


Sequence Tagging
^^^^^^^^^^^^^^^^

Sequence tagging datasets are subclasses of ``SequenceTaggingDataset`` class.

.. autoclass:: SequenceTaggingDataset
:members: __init__

UDPOS
~~~~~

.. autoclass:: UDPOS
:members: __init__

CoNLL2000Chunking
~~~~~~~~~~~~~~~~~

.. autoclass:: CoNLL2000Chunking
:members: __init__
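
A minimal usage sketch, following the unit tests in this pull request (the
three per-example tensors are the token column and the two tag columns of
the underlying data; the names here are illustrative):

    >>> from torchtext.experimental.datasets import UDPOS
    >>> train, valid, test = UDPOS()
    >>> words, tags_a, tags_b = train[0]
    >>> word_vocab, tag_a_vocab, tag_b_vocab = train.get_vocabs()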

Question Answer
^^^^^^^^^^^^^^^

Question answer datasets are subclasses of ``QuestionAnswerDataset`` class.

.. autoclass:: QuestionAnswerDataset
.. autoclass:: QuestionAnswerDataset
:members: __init__


SQuAD 1.0
~~~~~~~~~

.. autoclass:: SQuAD1
:members: __init__
:members: __init__


SQuAD 2.0
~~~~~~~~~

.. autoclass:: SQuAD2
:members: __init__
:members: __init__
102 changes: 102 additions & 0 deletions test/data/test_builtin_datasets.py
@@ -156,6 +156,108 @@ def test_multi30k(self):
                                "multi30k_task*.tar.gz")
        conditional_remove(datafile)

    def test_udpos_sequence_tagging(self):
        from torchtext.experimental.datasets import UDPOS

        # smoke test to ensure UDPOS works properly
        train_dataset, valid_dataset, test_dataset = UDPOS()
        self.assertEqual(len(train_dataset), 12543)
        self.assertEqual(len(valid_dataset), 2002)
        self.assertEqual(len(test_dataset), 2077)
        self.assertEqual(train_dataset[0][0][:10],
                         torch.tensor([262, 16, 5728, 45, 289, 701, 1160, 4436, 10660, 585]).long())
        self.assertEqual(train_dataset[0][1][:10],
                         torch.tensor([8, 3, 8, 3, 9, 2, 4, 8, 8, 8]).long())
        self.assertEqual(train_dataset[0][2][:10],
                         torch.tensor([5, 34, 5, 27, 7, 11, 14, 5, 5, 5]).long())
        self.assertEqual(train_dataset[-1][0][:10],
                         torch.tensor([9, 32, 169, 436, 59, 192, 30, 6, 117, 17]).long())
        self.assertEqual(train_dataset[-1][1][:10],
                         torch.tensor([5, 10, 11, 4, 11, 11, 3, 12, 11, 4]).long())
        self.assertEqual(train_dataset[-1][2][:10],
                         torch.tensor([6, 20, 8, 10, 8, 8, 24, 13, 8, 15]).long())

        self.assertEqual(valid_dataset[0][0][:10],
                         torch.tensor([746, 3, 10633, 656, 25, 1334, 45]).long())
        self.assertEqual(valid_dataset[0][1][:10],
                         torch.tensor([6, 7, 8, 4, 7, 2, 3]).long())
        self.assertEqual(valid_dataset[0][2][:10],
                         torch.tensor([3, 4, 5, 16, 4, 2, 27]).long())
        self.assertEqual(valid_dataset[-1][0][:10],
                         torch.tensor([354, 4, 31, 17, 141, 421, 148, 6, 7, 78]).long())
        self.assertEqual(valid_dataset[-1][1][:10],
                         torch.tensor([11, 3, 5, 4, 9, 2, 2, 12, 7, 11]).long())
        self.assertEqual(valid_dataset[-1][2][:10],
                         torch.tensor([8, 12, 6, 15, 7, 2, 2, 13, 4, 8]).long())

        self.assertEqual(test_dataset[0][0][:10],
                         torch.tensor([210, 54, 3115, 0, 12229, 0, 33]).long())
        self.assertEqual(test_dataset[0][1][:10],
                         torch.tensor([5, 15, 8, 4, 6, 8, 3]).long())
        self.assertEqual(test_dataset[0][2][:10],
                         torch.tensor([30, 3, 5, 14, 3, 5, 9]).long())
        self.assertEqual(test_dataset[-1][0][:10],
                         torch.tensor([116, 0, 6, 11, 412, 10, 0, 4, 0, 6]).long())
        self.assertEqual(test_dataset[-1][1][:10],
                         torch.tensor([5, 4, 12, 10, 9, 15, 4, 3, 4, 12]).long())
        self.assertEqual(test_dataset[-1][2][:10],
                         torch.tensor([6, 16, 13, 16, 7, 3, 19, 12, 19, 13]).long())

        # Assert vocabs
        self.assertEqual(len(train_dataset.get_vocabs()), 3)
        self.assertEqual(len(train_dataset.get_vocabs()[0]), 19674)
        self.assertEqual(len(train_dataset.get_vocabs()[1]), 19)
        self.assertEqual(len(train_dataset.get_vocabs()[2]), 52)

        # Assert token ids
        word_vocab = train_dataset.get_vocabs()[0]
        tokens_ids = [word_vocab[token] for token in 'Two of them were being run'.split()]
        self.assertEqual(tokens_ids, [1206, 8, 69, 60, 157, 452])

    def test_conll_sequence_tagging(self):
        from torchtext.experimental.datasets import CoNLL2000Chunking

        # smoke test to ensure CoNLL2000Chunking works properly
        train_dataset, test_dataset = CoNLL2000Chunking()
        self.assertEqual(len(train_dataset), 8936)
        self.assertEqual(len(test_dataset), 2012)
        self.assertEqual(train_dataset[0][0][:10],
                         torch.tensor([11556, 9, 3, 1775, 17, 1164, 177, 6, 212, 317]).long())
        self.assertEqual(train_dataset[0][1][:10],
                         torch.tensor([2, 3, 5, 2, 17, 12, 16, 15, 13, 5]).long())
        self.assertEqual(train_dataset[0][2][:10],
                         torch.tensor([3, 6, 3, 2, 5, 7, 7, 7, 7, 3]).long())
        self.assertEqual(train_dataset[-1][0][:10],
                         torch.tensor([85, 17, 59, 6473, 288, 115, 72, 5, 2294, 2502]).long())
        self.assertEqual(train_dataset[-1][1][:10],
                         torch.tensor([18, 17, 12, 19, 10, 6, 3, 3, 4, 4]).long())
        self.assertEqual(train_dataset[-1][2][:10],
                         torch.tensor([3, 5, 7, 7, 3, 2, 6, 6, 3, 2]).long())

        self.assertEqual(test_dataset[0][0][:10],
                         torch.tensor([0, 294, 73, 10, 13582, 194, 18, 24, 2414, 7]).long())
        self.assertEqual(test_dataset[0][1][:10],
                         torch.tensor([4, 4, 4, 23, 4, 2, 11, 18, 11, 5]).long())
        self.assertEqual(test_dataset[0][2][:10],
                         torch.tensor([3, 2, 2, 3, 2, 2, 5, 3, 5, 3]).long())
        self.assertEqual(test_dataset[-1][0][:10],
                         torch.tensor([51, 456, 560, 2, 11, 465, 2, 1413, 36, 60]).long())
        self.assertEqual(test_dataset[-1][1][:10],
                         torch.tensor([3, 4, 4, 8, 3, 2, 8, 4, 17, 16]).long())
        self.assertEqual(test_dataset[-1][2][:10],
                         torch.tensor([6, 3, 2, 4, 6, 3, 4, 3, 5, 7]).long())

        # Assert vocabs
        self.assertEqual(len(train_dataset.get_vocabs()), 3)
        self.assertEqual(len(train_dataset.get_vocabs()[0]), 19124)
        self.assertEqual(len(train_dataset.get_vocabs()[1]), 46)
        self.assertEqual(len(train_dataset.get_vocabs()[2]), 24)

        # Assert token ids
        word_vocab = train_dataset.get_vocabs()[0]
        tokens_ids = [word_vocab[token] for token in 'Two of them were being run'.split()]
        self.assertEqual(tokens_ids, [970, 5, 135, 43, 214, 690])

    def test_squad1(self):
        from torchtext.experimental.datasets import SQuAD1
        from torchtext.vocab import Vocab
3 changes: 3 additions & 0 deletions torchtext/experimental/datasets/__init__.py
@@ -2,6 +2,7 @@
from .text_classification import AG_NEWS, SogouNews, DBpedia, YelpReviewPolarity, \
    YelpReviewFull, YahooAnswers, \
    AmazonReviewPolarity, AmazonReviewFull, IMDB
from .sequence_tagging import UDPOS, CoNLL2000Chunking
from .translation import Multi30k, IWSLT, WMT14
from .question_answer import SQuAD1, SQuAD2

@@ -19,6 +20,8 @@
    'YahooAnswers',
    'AmazonReviewPolarity',
    'AmazonReviewFull',
    'UDPOS',
    'CoNLL2000Chunking',
    'Multi30k',
    'IWSLT',
    'WMT14',
3 changes: 3 additions & 0 deletions torchtext/experimental/datasets/raw/__init__.py
@@ -1,6 +1,7 @@
from .text_classification import AG_NEWS, SogouNews, DBpedia, YelpReviewPolarity, \
    YelpReviewFull, YahooAnswers, \
    AmazonReviewPolarity, AmazonReviewFull, IMDB
from .sequence_tagging import UDPOS, CoNLL2000Chunking
from .translation import Multi30k, IWSLT, WMT14
from .language_modeling import WikiText2, WikiText103, PennTreebank, WMTNewsCrawl
from .question_answer import SQuAD1, SQuAD2
@@ -14,6 +15,8 @@
    'YahooAnswers',
    'AmazonReviewPolarity',
    'AmazonReviewFull',
    'UDPOS',
    'CoNLL2000Chunking',
    'Multi30k',
    'IWSLT',
    'WMT14',
136 changes: 136 additions & 0 deletions torchtext/experimental/datasets/raw/sequence_tagging.py
@@ -0,0 +1,136 @@
import torch

from torchtext.utils import download_from_url, extract_archive

URLS = {
"UDPOS":
'https://bitbucket.org/sivareddyg/public/downloads/en-ud-v2.zip',
"CoNLL2000Chunking": [
'https://www.clips.uantwerpen.be/conll2000/chunking/train.txt.gz',
'https://www.clips.uantwerpen.be/conll2000/chunking/test.txt.gz'
]
}


def _create_data_from_iob(data_path, separator="\t"):
Contributor

Do we return or yield something from this func?

Contributor Author

Yes, please take a look at line 22. Due to the nature of the data, each sentence is separated by an empty line, so we yield one sentence whenever we find one. I also just added a new commit to return the leftovers.

    with open(data_path, encoding="utf-8") as input_file:
        columns = []
        for line in input_file:
            line = line.strip()
            if line == "":
                if columns:
                    yield columns
                columns = []
            else:
                for i, column in enumerate(line.split(separator)):
                    if len(columns) < i + 1:
                        columns.append([])
                    columns[i].append(column)
        if len(columns) > 0:
            yield columns
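
As noted in the review thread above, this generator yields one sentence (a
list of columns) per blank-line-separated block, plus any leftover columns at
end-of-file. A minimal sketch of the behavior, assuming a hypothetical
two-column file toy.txt containing the lines "The<TAB>DT", "dog<TAB>NN", a
blank line, "It<TAB>PRP", and "ran<TAB>VBD":

    >>> list(_create_data_from_iob("toy.txt", separator="\t"))
    [[['The', 'dog'], ['DT', 'NN']], [['It', 'ran'], ['PRP', 'VBD']]]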


def _construct_filepath(paths, file_suffix):
    if file_suffix:
        path = None
        for p in paths:
            path = p if p.endswith(file_suffix) else path
        return path
    return None
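
For clarity, a hypothetical doctest for this helper: it returns the last path
in ``paths`` that ends with ``file_suffix``, or None when nothing matches.

    >>> _construct_filepath(["a/train.txt", "a/test.txt"], "train.txt")
    'a/train.txt'
    >>> _construct_filepath(["a/train.txt"], "dev.txt") is None
    True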


def _setup_datasets(dataset_name, separator, root=".data"):

    extracted_files = []
    if isinstance(URLS[dataset_name], list):
        for f in URLS[dataset_name]:
            dataset_tar = download_from_url(f, root=root)
            extracted_files.extend(extract_archive(dataset_tar))
    elif isinstance(URLS[dataset_name], str):
        dataset_tar = download_from_url(URLS[dataset_name], root=root)
        extracted_files.extend(extract_archive(dataset_tar))
    else:
        raise ValueError(
            "URLS for {} has to be in the form of a list or a string".format(
                dataset_name))

    data_filenames = {
        "train": _construct_filepath(extracted_files, "train.txt"),
        "valid": _construct_filepath(extracted_files, "dev.txt"),
        "test": _construct_filepath(extracted_files, "test.txt")
    }

    datasets = []
    for key in data_filenames.keys():
        if data_filenames[key] is not None:
            datasets.append(
                RawSequenceTaggingIterableDataset(
                    _create_data_from_iob(data_filenames[key], separator)))
        else:
            datasets.append(None)

    return datasets


class RawSequenceTaggingIterableDataset(torch.utils.data.IterableDataset):
    """Defines an abstraction for raw text sequence tagging iterable datasets.
    """
    def __init__(self, iterator):
        super(RawSequenceTaggingIterableDataset, self).__init__()

        self._iterator = iterator
        self.has_setup = False
        self.start = 0
        self.num_lines = None

    def setup_iter(self, start=0, num_lines=None):
        self.start = start
        self.num_lines = num_lines
        self.has_setup = True

    def __iter__(self):
        if not self.has_setup:
            self.setup_iter()

        for i, item in enumerate(self._iterator):
            if i >= self.start:
                yield item
            if (self.num_lines is not None) and (i == (self.start +
                                                       self.num_lines)):
                break

    def get_iterator(self):
        return self._iterator


def UDPOS(*args, **kwargs):
    """ Universal Dependencies English Web Treebank

    Separately returns the training, validation and test dataset

    Arguments:
        root: Directory where the datasets are saved. Default: ".data"

    Examples:
        >>> from torchtext.experimental.datasets.raw import UDPOS
        >>> train_dataset, valid_dataset, test_dataset = UDPOS()
    """
    return _setup_datasets(*(("UDPOS", "\t") + args), **kwargs)


def CoNLL2000Chunking(*args, **kwargs):
    """ CoNLL 2000 Chunking Dataset

    Separately returns the training and test dataset (there is no validation
    split, so the second element of the returned list is None)

    Arguments:
        root: Directory where the datasets are saved. Default: ".data"

    Examples:
        >>> from torchtext.experimental.datasets.raw import CoNLL2000Chunking
        >>> train_dataset, valid_dataset, test_dataset = CoNLL2000Chunking()
    """
    return _setup_datasets(*(("CoNLL2000Chunking", " ") + args), **kwargs)


DATASETS = {"UDPOS": UDPOS, "CoNLL2000Chunking": CoNLL2000Chunking}
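
A brief usage sketch of the raw interface (hypothetical, not part of this
module): each raw dataset is an iterable that yields one sentence as a list
of columns, in the layout produced by _create_data_from_iob above.

    >>> from torchtext.experimental.datasets.raw import UDPOS
    >>> raw_train, raw_valid, raw_test = UDPOS()
    >>> columns = next(iter(raw_train))
    >>> words = columns[0]  # first column: the tokens of one sentence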