diff --git a/docs/source/experimental_datasets.rst b/docs/source/experimental_datasets.rst
index 044b731496..20cde70f2e 100644
--- a/docs/source/experimental_datasets.rst
+++ b/docs/source/experimental_datasets.rst
@@ -34,9 +34,9 @@ IMDb
 ~~~~
 
 .. autoclass:: IMDB
-    :members: __init__
-
-
+    :members: __init__
+
+
 Text Classification
 ^^^^^^^^^^^^^^^^^^^
 
@@ -109,8 +109,8 @@ AmazonReviewFull dataset is subclass of ``TextClassificationDataset`` class.
 
 .. autoclass:: AmazonReviewFull
     :members: __init__
-
-
+
+
 Language Modeling
 ^^^^^^^^^^^^^^^^^
 
@@ -124,28 +124,28 @@ WikiText-2
 ~~~~~~~~~~
 
 .. autoclass:: WikiText2
-    :members: __init__
+    :members: __init__
 
 WikiText103
 ~~~~~~~~~~~
 
 .. autoclass:: WikiText103
-    :members: __init__
+    :members: __init__
 
 PennTreebank
 ~~~~~~~~~~~~
 
 .. autoclass:: PennTreebank
-    :members: __init__
+    :members: __init__
 
 WMTNewsCrawl
 ~~~~~~~~~~~~
 
-.. autoclass:: WMTNewsCrawl
-    :members: __init__
+.. autoclass:: WMTNewsCrawl
+    :members: __init__
 
 
 Machine Translation
@@ -177,12 +177,33 @@ WMT14
 
 .. autoclass:: WMT14
     :members: __init__
 
+
+Sequence Tagging
+^^^^^^^^^^^^^^^^
+
+Sequence tagging datasets are subclasses of ``SequenceTaggingDataset`` class.
+
+.. autoclass:: SequenceTaggingDataset
+    :members: __init__
+
+UDPOS
+~~~~~
+
+.. autoclass:: UDPOS
+    :members: __init__
+
+CoNLL2000Chunking
+~~~~~~~~~~~~~~~~~
+
+.. autoclass:: CoNLL2000Chunking
+    :members: __init__
+
 Question Answer
 ^^^^^^^^^^^^^^^
 
 Question answer datasets are subclasses of ``QuestionAnswerDataset`` class.
 
-.. autoclass:: QuestionAnswerDataset
+.. autoclass:: QuestionAnswerDataset
     :members: __init__
 
@@ -190,11 +211,11 @@ SQuAD 1.0
 ~~~~~~~~~
 
 .. autoclass:: SQuAD1
-    :members: __init__
+    :members: __init__
 
 SQuAD 2.0
 ~~~~~~~~~
 
 .. autoclass:: SQuAD2
-    :members: __init__
+    :members: __init__
 
diff --git a/test/data/test_builtin_datasets.py b/test/data/test_builtin_datasets.py
index aa0fff33a5..1c120da16c 100644
--- a/test/data/test_builtin_datasets.py
+++ b/test/data/test_builtin_datasets.py
@@ -156,6 +156,108 @@ def test_multi30k(self):
                                 "multi30k_task*.tar.gz")
             conditional_remove(datafile)
 
+    def test_udpos_sequence_tagging(self):
+        from torchtext.experimental.datasets import UDPOS
+
+        # smoke test to ensure UDPOS works properly
+        train_dataset, valid_dataset, test_dataset = UDPOS()
+        self.assertEqual(len(train_dataset), 12543)
+        self.assertEqual(len(valid_dataset), 2002)
+        self.assertEqual(len(test_dataset), 2077)
+        self.assertEqual(train_dataset[0][0][:10],
+                         torch.tensor([262, 16, 5728, 45, 289, 701, 1160, 4436, 10660, 585]).long())
+        self.assertEqual(train_dataset[0][1][:10],
+                         torch.tensor([8, 3, 8, 3, 9, 2, 4, 8, 8, 8]).long())
+        self.assertEqual(train_dataset[0][2][:10],
+                         torch.tensor([5, 34, 5, 27, 7, 11, 14, 5, 5, 5]).long())
+        self.assertEqual(train_dataset[-1][0][:10],
+                         torch.tensor([9, 32, 169, 436, 59, 192, 30, 6, 117, 17]).long())
+        self.assertEqual(train_dataset[-1][1][:10],
+                         torch.tensor([5, 10, 11, 4, 11, 11, 3, 12, 11, 4]).long())
+        self.assertEqual(train_dataset[-1][2][:10],
+                         torch.tensor([6, 20, 8, 10, 8, 8, 24, 13, 8, 15]).long())
+
+        self.assertEqual(valid_dataset[0][0][:10],
+                         torch.tensor([746, 3, 10633, 656, 25, 1334, 45]).long())
+        self.assertEqual(valid_dataset[0][1][:10],
+                         torch.tensor([6, 7, 8, 4, 7, 2, 3]).long())
+        self.assertEqual(valid_dataset[0][2][:10],
+                         torch.tensor([3, 4, 5, 16, 4, 2, 27]).long())
+        self.assertEqual(valid_dataset[-1][0][:10],
+                         torch.tensor([354, 4, 31, 17, 141, 421, 148, 6, 7, 78]).long())
+        self.assertEqual(valid_dataset[-1][1][:10],
+                         torch.tensor([11, 3, 5, 4, 9, 2, 2, 12, 7, 11]).long())
+        self.assertEqual(valid_dataset[-1][2][:10],
+                         torch.tensor([8, 12, 6, 15, 7, 2, 2, 13, 4, 8]).long())
+
+        self.assertEqual(test_dataset[0][0][:10],
+                         torch.tensor([210, 54, 3115, 0, 12229, 0, 33]).long())
+        self.assertEqual(test_dataset[0][1][:10],
+                         torch.tensor([5, 15, 8, 4, 6, 8, 3]).long())
+        self.assertEqual(test_dataset[0][2][:10],
+                         torch.tensor([30, 3, 5, 14, 3, 5, 9]).long())
+        self.assertEqual(test_dataset[-1][0][:10],
+                         torch.tensor([116, 0, 6, 11, 412, 10, 0, 4, 0, 6]).long())
+        self.assertEqual(test_dataset[-1][1][:10],
+                         torch.tensor([5, 4, 12, 10, 9, 15, 4, 3, 4, 12]).long())
+        self.assertEqual(test_dataset[-1][2][:10],
+                         torch.tensor([6, 16, 13, 16, 7, 3, 19, 12, 19, 13]).long())
+
+        # Assert vocabs
+        self.assertEqual(len(train_dataset.get_vocabs()), 3)
+        self.assertEqual(len(train_dataset.get_vocabs()[0]), 19674)
+        self.assertEqual(len(train_dataset.get_vocabs()[1]), 19)
+        self.assertEqual(len(train_dataset.get_vocabs()[2]), 52)
+
+        # Assert token ids
+        word_vocab = train_dataset.get_vocabs()[0]
+        tokens_ids = [word_vocab[token] for token in 'Two of them were being run'.split()]
+        self.assertEqual(tokens_ids, [1206, 8, 69, 60, 157, 452])
+
+    def test_conll_sequence_tagging(self):
+        from torchtext.experimental.datasets import CoNLL2000Chunking
+
+        # smoke test to ensure CoNLL2000Chunking works properly
+        train_dataset, test_dataset = CoNLL2000Chunking()
+        self.assertEqual(len(train_dataset), 8936)
+        self.assertEqual(len(test_dataset), 2012)
+        self.assertEqual(train_dataset[0][0][:10],
+                         torch.tensor([11556, 9, 3, 1775, 17, 1164, 177, 6, 212, 317]).long())
+        self.assertEqual(train_dataset[0][1][:10],
+                         torch.tensor([2, 3, 5, 2, 17, 12, 16, 15, 13, 5]).long())
+        self.assertEqual(train_dataset[0][2][:10],
+                         torch.tensor([3, 6, 3, 2, 5, 7, 7, 7, 7, 3]).long())
+        self.assertEqual(train_dataset[-1][0][:10],
+                         torch.tensor([85, 17, 59, 6473, 288, 115, 72, 5, 2294, 2502]).long())
+        self.assertEqual(train_dataset[-1][1][:10],
+                         torch.tensor([18, 17, 12, 19, 10, 6, 3, 3, 4, 4]).long())
+        self.assertEqual(train_dataset[-1][2][:10],
+                         torch.tensor([3, 5, 7, 7, 3, 2, 6, 6, 3, 2]).long())
+
+        self.assertEqual(test_dataset[0][0][:10],
+                         torch.tensor([0, 294, 73, 10, 13582, 194, 18, 24, 2414, 7]).long())
+        self.assertEqual(test_dataset[0][1][:10],
+                         torch.tensor([4, 4, 4, 23, 4, 2, 11, 18, 11, 5]).long())
+        self.assertEqual(test_dataset[0][2][:10],
+                         torch.tensor([3, 2, 2, 3, 2, 2, 5, 3, 5, 3]).long())
+        self.assertEqual(test_dataset[-1][0][:10],
+                         torch.tensor([51, 456, 560, 2, 11, 465, 2, 1413, 36, 60]).long())
+        self.assertEqual(test_dataset[-1][1][:10],
+                         torch.tensor([3, 4, 4, 8, 3, 2, 8, 4, 17, 16]).long())
+        self.assertEqual(test_dataset[-1][2][:10],
+                         torch.tensor([6, 3, 2, 4, 6, 3, 4, 3, 5, 7]).long())
+
+        # Assert vocabs
+        self.assertEqual(len(train_dataset.get_vocabs()), 3)
+        self.assertEqual(len(train_dataset.get_vocabs()[0]), 19124)
+        self.assertEqual(len(train_dataset.get_vocabs()[1]), 46)
+        self.assertEqual(len(train_dataset.get_vocabs()[2]), 24)
+
+        # Assert token ids
+        word_vocab = train_dataset.get_vocabs()[0]
+        tokens_ids = [word_vocab[token] for token in 'Two of them were being run'.split()]
+        self.assertEqual(tokens_ids, [970, 5, 135, 43, 214, 690])
+
     def test_squad1(self):
         from torchtext.experimental.datasets import SQuAD1
         from torchtext.vocab import Vocab
diff --git a/torchtext/experimental/datasets/__init__.py b/torchtext/experimental/datasets/__init__.py
index 7929f5f733..b86e9f4756 100644
--- a/torchtext/experimental/datasets/__init__.py
+++ b/torchtext/experimental/datasets/__init__.py
@@ -2,6 +2,7 @@
 from .text_classification import AG_NEWS, SogouNews, DBpedia, YelpReviewPolarity, \
     YelpReviewFull, YahooAnswers, \
     AmazonReviewPolarity, AmazonReviewFull, IMDB
+from .sequence_tagging import UDPOS, CoNLL2000Chunking
 from .translation import Multi30k, IWSLT, WMT14
 from .question_answer import SQuAD1, SQuAD2
 
@@ -19,6 +20,8 @@
     'YahooAnswers',
     'AmazonReviewPolarity',
     'AmazonReviewFull',
+    'UDPOS',
+    'CoNLL2000Chunking',
     'Multi30k',
     'IWSLT',
     'WMT14',
diff --git a/torchtext/experimental/datasets/raw/__init__.py b/torchtext/experimental/datasets/raw/__init__.py
index ad100429b0..ec3371de26 100644
--- a/torchtext/experimental/datasets/raw/__init__.py
+++ b/torchtext/experimental/datasets/raw/__init__.py
@@ -1,6 +1,7 @@
 from .text_classification import AG_NEWS, SogouNews, DBpedia, YelpReviewPolarity, \
     YelpReviewFull, YahooAnswers, \
     AmazonReviewPolarity, AmazonReviewFull, IMDB
+from .sequence_tagging import UDPOS, CoNLL2000Chunking
 from .translation import Multi30k, IWSLT, WMT14
 from .language_modeling import WikiText2, WikiText103, PennTreebank, WMTNewsCrawl
 from .question_answer import SQuAD1, SQuAD2
@@ -14,6 +15,8 @@
     'YahooAnswers',
     'AmazonReviewPolarity',
     'AmazonReviewFull',
+    'UDPOS',
+    'CoNLL2000Chunking',
     'Multi30k',
     'IWSLT',
     'WMT14',
diff --git a/torchtext/experimental/datasets/raw/sequence_tagging.py b/torchtext/experimental/datasets/raw/sequence_tagging.py
new file mode 100644
index 0000000000..b4576e7f81
--- /dev/null
+++ b/torchtext/experimental/datasets/raw/sequence_tagging.py
@@ -0,0 +1,136 @@
+import torch
+
+from torchtext.utils import download_from_url, extract_archive
+
+URLS = {
+    "UDPOS":
+    'https://bitbucket.org/sivareddyg/public/downloads/en-ud-v2.zip',
+    "CoNLL2000Chunking": [
+        'https://www.clips.uantwerpen.be/conll2000/chunking/train.txt.gz',
+        'https://www.clips.uantwerpen.be/conll2000/chunking/test.txt.gz'
+    ]
+}
+
+
+def _create_data_from_iob(data_path, separator="\t"):
+    with open(data_path, encoding="utf-8") as input_file:
+        columns = []
+        for line in input_file:
+            line = line.strip()
+            if line == "":
+                if columns:
+                    yield columns
+                columns = []
+            else:
+                for i, column in enumerate(line.split(separator)):
+                    if len(columns) < i + 1:
+                        columns.append([])
+                    columns[i].append(column)
+        if len(columns) > 0:
+            yield columns
+
+
+def _construct_filepath(paths, file_suffix):
+    if file_suffix:
+        path = None
+        for p in paths:
+            path = p if p.endswith(file_suffix) else path
+        return path
+    return None
+
+
+def _setup_datasets(dataset_name, separator, root=".data"):
+
+    extracted_files = []
+    if isinstance(URLS[dataset_name], list):
+        for f in URLS[dataset_name]:
+            dataset_tar = download_from_url(f, root=root)
+            extracted_files.extend(extract_archive(dataset_tar))
+    elif isinstance(URLS[dataset_name], str):
+        dataset_tar = download_from_url(URLS[dataset_name], root=root)
+        extracted_files.extend(extract_archive(dataset_tar))
+    else:
+        raise ValueError(
+            "URLS for {} has to be in a form of list or string".format(
+                dataset_name))
+
+    data_filenames = {
+        "train": _construct_filepath(extracted_files, "train.txt"),
+        "valid": _construct_filepath(extracted_files, "dev.txt"),
+        "test": _construct_filepath(extracted_files, "test.txt")
+    }
+
+    datasets = []
+    for key in data_filenames.keys():
+        if data_filenames[key] is not None:
+            datasets.append(
+                RawSequenceTaggingIterableDataset(
+                    _create_data_from_iob(data_filenames[key], separator)))
+        else:
+            datasets.append(None)
+
+    return datasets
+
+
+class RawSequenceTaggingIterableDataset(torch.utils.data.IterableDataset):
+    """Defines an abstraction for raw text sequence tagging iterable datasets.
+    """
+    def __init__(self, iterator):
+        super(RawSequenceTaggingIterableDataset, self).__init__()
+
+        self._iterator = iterator
+        self.has_setup = False
+        self.start = 0
+        self.num_lines = None
+
+    def setup_iter(self, start=0, num_lines=None):
+        self.start = start
+        self.num_lines = num_lines
+        self.has_setup = True
+
+    def __iter__(self):
+        if not self.has_setup:
+            self.setup_iter()
+
+        for i, item in enumerate(self._iterator):
+            if i >= self.start:
+                yield item
+            if (self.num_lines is not None) and (i == (self.start +
+                                                       self.num_lines)):
+                break
+
+    def get_iterator(self):
+        return self._iterator
+
+
+def UDPOS(*args, **kwargs):
+    """ Universal Dependencies English Web Treebank
+
+    Separately returns the training, validation, and test dataset
+
+    Arguments:
+        root: Directory where the datasets are saved. Default: ".data"
+
+    Examples:
+        >>> from torchtext.experimental.datasets.raw import UDPOS
+        >>> train_dataset, valid_dataset, test_dataset = UDPOS()
+    """
+    return _setup_datasets(*(("UDPOS", "\t") + args), **kwargs)
+
+
+def CoNLL2000Chunking(*args, **kwargs):
+    """ CoNLL 2000 Chunking Dataset
+
+    Separately returns the training and test dataset
+
+    Arguments:
+        root: Directory where the datasets are saved. Default: ".data"
+
+    Examples:
+        >>> from torchtext.experimental.datasets.raw import CoNLL2000Chunking
+        >>> train_dataset, valid_dataset, test_dataset = CoNLL2000Chunking()
+    """
+    return _setup_datasets(*(("CoNLL2000Chunking", " ") + args), **kwargs)
+
+
+DATASETS = {"UDPOS": UDPOS, "CoNLL2000Chunking": CoNLL2000Chunking}
diff --git a/torchtext/experimental/datasets/sequence_tagging.py b/torchtext/experimental/datasets/sequence_tagging.py
new file mode 100644
index 0000000000..8a6a97a55e
--- /dev/null
+++ b/torchtext/experimental/datasets/sequence_tagging.py
@@ -0,0 +1,169 @@
+import torch
+
+from torchtext.experimental.datasets import raw
+from torchtext.vocab import build_vocab_from_iterator
+from torchtext.experimental.functional import (
+    vocab_func,
+    totensor,
+    sequential_transforms,
+)
+
+
+def _build_vocab(data):
+    total_columns = len(data[0])
+    data_list = [[] for _ in range(total_columns)]
+    vocabs = []
+
+    for line in data:
+        for idx, col in enumerate(line):
+            data_list[idx].append(col)
+
+    for it in data_list:
+        vocabs.append(build_vocab_from_iterator(it))
+
+    return vocabs
+
+
+def _setup_datasets(dataset_name,
+                    root=".data",
+                    vocabs=None,
+                    data_select=("train", "valid", "test")):
+    if isinstance(data_select, str):
+        data_select = [data_select]
+    if not set(data_select).issubset(set(("train", "valid", "test"))):
+        raise TypeError("Given data selection {} is not supported!".format(data_select))
+
+    train, val, test = DATASETS[dataset_name](root=root)
+    raw_data = {
+        "train": [line for line in train] if train else None,
+        "valid": [line for line in val] if val else None,
+        "test": [line for line in test] if test else None
+    }
+
+    if vocabs is None:
+        if "train" not in data_select:
+            raise TypeError("Must pass a vocab if train is not selected.")
+        vocabs = _build_vocab(raw_data["train"])
+    else:
+        if not isinstance(vocabs, list):
+            raise TypeError("vocabs must be an instance of list")
+
+        # Find data that's not None
+        notnone_data = None
+        for key in raw_data.keys():
+            if raw_data[key] is not None:
+                notnone_data = raw_data[key]
+                break
+        if len(vocabs) != len(notnone_data[0]):
+            raise ValueError(
+                "Number of vocabs must match the number of columns "
+                "in the data")
+
+    transformers = [
+        sequential_transforms(vocab_func(vocabs[idx]),
+                              totensor(dtype=torch.long))
+        for idx in range(len(vocabs))
+    ]
+
+    datasets = []
+    for item in data_select:
+        if raw_data[item] is not None:
+            datasets.append(
+                SequenceTaggingDataset(raw_data[item], vocabs, transformers))
+
+    return datasets
+
+
+class SequenceTaggingDataset(torch.utils.data.Dataset):
+    """Defines an abstraction for sequence tagging datasets.
+       Currently, we only support the following datasets:
+           - UDPOS
+           - CoNLL2000Chunking
+    """
+    def __init__(self, data, vocabs, transforms):
+        """Initiate sequence tagging dataset.
+        Arguments:
+            data: a list of word and its respective tags. Example:
+                [[word, POS, dep_parsing label, ...]]
+            vocabs: a list of vocabularies for its respective tags.
+                The number of vocabs must be the same as the number of columns
+                found in the data.
+            transforms: a list of string transforms for words and tags.
+                The number of transforms must be the same as the number of columns
+                found in the data.
+        """
+
+        super(SequenceTaggingDataset, self).__init__()
+        self.data = data
+        self.vocabs = vocabs
+        self.transforms = transforms
+
+        if len(self.data[0]) != len(self.vocabs):
+            raise ValueError("vocabs must have the same number of columns "
+                             "as the data")
+
+    def __getitem__(self, i):
+        curr_data = self.data[i]
+        if len(curr_data) != len(self.transforms):
+            raise ValueError("data must have the same number of columns "
+                             "as the transforms")
+        return [self.transforms[idx](curr_data[idx]) for idx in range(len(self.transforms))]
+
+    def __len__(self):
+        return len(self.data)
+
+    def get_vocabs(self):
+        return self.vocabs
+
+
+def UDPOS(*args, **kwargs):
+    """ Universal Dependencies English Web Treebank
+
+    Separately returns the training, validation, and test dataset
+
+    Arguments:
+        root: Directory where the datasets are saved. Default: ".data"
+        vocabs: A list of vocabularies for each column in the dataset. Must be
+            an instance of List
+            Default: None
+        data_select: a string or tuple for the returned datasets
+            (Default: ('train', 'valid', 'test'))
+            By default, all the three datasets (train, test, valid) are generated. Users
+            could also choose any one or two of them, for example ('train', 'test') or
+            just a string 'train'. If 'train' is not in the tuple or string, a vocab
+            object should be provided which will be used to process valid and/or test
+            data.
+
+    Examples:
+        >>> from torchtext.experimental.datasets import UDPOS
+        >>> train_dataset, valid_dataset, test_dataset = UDPOS()
+    """
+    return _setup_datasets(*(("UDPOS", ) + args), **kwargs)
+
+
+def CoNLL2000Chunking(*args, **kwargs):
+    """ CoNLL 2000 Chunking Dataset
+
+    Separately returns the training and test dataset
+
+    Arguments:
+        root: Directory where the datasets are saved. Default: ".data"
+        vocabs: A list of vocabularies for each column in the dataset. Must be
+            an instance of List
+            Default: None
+        data_select: a string or tuple for the returned datasets
+            (Default: ('train', 'valid', 'test'))
+            By default, all the available datasets (train, test) are generated. Users
+            could also choose any one of them, for example ('train', 'test') or
+            just a string 'train'. If 'train' is not in the tuple or string, a vocab
+            object should be provided which will be used to process valid and/or test
+            data.
+
+    Examples:
+        >>> from torchtext.experimental.datasets import CoNLL2000Chunking
+        >>> train_dataset, test_dataset = CoNLL2000Chunking()
+    """
+    return _setup_datasets(*(("CoNLL2000Chunking", ) + args), **kwargs)
+
+
+DATASETS = {"UDPOS": raw.UDPOS, "CoNLL2000Chunking": raw.CoNLL2000Chunking}
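
For readers of this patch, the following is a minimal usage sketch (not part of the diff) showing how the new experimental UDPOS dataset might be consumed. It relies only on the API added above (UDPOS(), get_vocabs(), and the data_select/vocabs keyword arguments) plus standard PyTorch utilities; the '<pad>' special token, the single padding index shared across columns, and batch_size=32 are illustrative assumptions rather than anything defined by the patch.

# Minimal usage sketch for the sequence tagging datasets added in this patch.
# Assumptions (not from the diff): '<pad>' is among the vocab specials, one
# padding index is reused for every column, and batch_size=32 is arbitrary.
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torchtext.experimental.datasets import UDPOS

# Each sample is a list of three LongTensors: the token ids and two tag columns.
train, valid, test = UDPOS()
word_vocab, tag_vocab_1, tag_vocab_2 = train.get_vocabs()

# Vocabs built on the training split can be reused so a test-only load shares ids.
test_only, = UDPOS(data_select='test', vocabs=list(train.get_vocabs()))

pad_idx = word_vocab['<pad>']  # assumes a '<pad>' special exists in the vocab


def collate(batch):
    """Pad each column of a batch of [words, tags_1, tags_2] samples."""
    columns = list(zip(*batch))
    return [pad_sequence(list(col), batch_first=True, padding_value=pad_idx)
            for col in columns]


loader = DataLoader(train, batch_size=32, shuffle=True, collate_fn=collate)
words, tags_1, tags_2 = next(iter(loader))
print(words.shape, tags_1.shape, tags_2.shape)  # each is [32, max_len_in_batch]

Reusing the training vocabs for the test-only load keeps token and tag indices consistent across splits, which is the purpose of the vocabs argument in _setup_datasets above.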