diff --git a/README.rst b/README.rst index 6169e26a04..df6dc9a29c 100644 --- a/README.rst +++ b/README.rst @@ -129,6 +129,16 @@ Others are planned or a work in progress: See the ``test`` directory for examples of dataset usage. +Legacy Code +=========== + +We have retired several datasets and moved them under ``torchtext.legacy``: + +* Sentiment analysis: IMDb +* Language modeling: abstract class + WikiText-2, WikiText103, PennTreebank + +These datasets have been rewritten with the new pattern introduced in `Release v0.5.0 `_. + Disclaimer on Datasets ====================== diff --git a/docs/source/legacy/datasets.rst b/docs/source/legacy/datasets.rst new file mode 100644 index 0000000000..859b5df5da --- /dev/null +++ b/docs/source/legacy/datasets.rst @@ -0,0 +1,69 @@ +torchtext.legacy.datasets +========================== + +.. currentmodule:: torchtext.legacy.datasets + +TorchText legacy datasets. + +All datasets are subclasses of :class:`torchtext.data.Dataset`, which +inherits from :class:`torch.utils.data.Dataset`, i.e., they have ``splits`` and +``iters`` methods implemented. + +General use cases are as follows: + +Approach 1, ``splits``: :: + + # set up fields + TEXT = data.Field(lower=True, include_lengths=True, batch_first=True) + LABEL = data.Field(sequential=False) + + # make splits for data + train, test = datasets.IMDB.splits(TEXT, LABEL) + + # build the vocabulary + TEXT.build_vocab(train, vectors=GloVe(name='6B', dim=300)) + LABEL.build_vocab(train) + + # make iterator for splits + train_iter, test_iter = data.BucketIterator.splits( + (train, test), batch_size=3, device=0) + +Approach 2, ``iters``: :: + + # use default configurations + train_iter, test_iter = datasets.IMDB.iters(batch_size=4) + +The following datasets are available: + +.. contents:: Datasets + :local: + + +Language Modeling +^^^^^^^^^^^^^^^^^ + +Language modeling datasets are subclasses of the ``LanguageModelingDataset`` class. + +.. autoclass:: LanguageModelingDataset + :members: __init__ + + +WikiText-2 +~~~~~~~~~~ + +.. autoclass:: WikiText2 + :members: splits, iters + + +WikiText103 +~~~~~~~~~~~ + +.. autoclass:: WikiText103 + :members: splits, iters + + +PennTreebank +~~~~~~~~~~~~ + +.. 
autoclass:: PennTreebank + :members: splits, iters diff --git a/test/data/test_builtin_datasets.py b/test/data/test_builtin_datasets.py index c16777c214..3e1102ef44 100644 --- a/test/data/test_builtin_datasets.py +++ b/test/data/test_builtin_datasets.py @@ -1,6 +1,6 @@ import os +import shutil import torchtext.data as data -from torchtext.datasets import WikiText2, PennTreebank from torchtext.datasets import AG_NEWS from ..common.test_markers import slow @@ -10,11 +10,14 @@ def conditional_remove(f): if os.path.isfile(f): os.remove(f) + elif os.path.isdir(f): + shutil.rmtree(f) class TestDataset(TorchtextTestCase): @slow - def test_wikitext2(self): + def test_wikitext2_legacy(self): + from torchtext.legacy.datasets import WikiText2 # smoke test to ensure wikitext2 works properly ds = WikiText2 TEXT = data.Field(lower=True, batch_first=True) @@ -27,12 +30,30 @@ def test_wikitext2(self): bptt_len=30) # Delete the dataset after we're done to save disk space on CI - if os.environ.get("TRAVIS") == "true": - datafile = os.path.join(self.project_root, ".data", "wikitext-2") - conditional_remove(datafile) + datafile = os.path.join(self.project_root, ".data", "wikitext-2") + conditional_remove(datafile) + + def test_wikitext2(self): + from torchtext.datasets import WikiText2 + # smoke test to ensure wikitext2 works properly + train_dataset, test_dataset, valid_dataset = WikiText2() + self.assertEqual(len(train_dataset), 2049990) + self.assertEqual(len(test_dataset), 241859) + self.assertEqual(len(valid_dataset), 214417) + + vocab = train_dataset.get_vocab() + tokens_ids = [vocab[token] for token in 'the player characters rest'.split()] + self.assertEqual(tokens_ids, [2, 286, 503, 700]) + + # Delete the dataset after we're done to save disk space on CI + datafile = os.path.join(self.project_root, ".data", "wikitext-2") + conditional_remove(datafile) + datafile = os.path.join(self.project_root, ".data", "wikitext-2-v1.zip") + conditional_remove(datafile) @slow - def test_penntreebank(self): + def test_penntreebank_legacy(self): + from torchtext.legacy.datasets import PennTreebank # smoke test to ensure penn treebank works properly TEXT = data.Field(lower=True, batch_first=True) ds = PennTreebank @@ -45,9 +66,28 @@ def test_penntreebank(self): bptt_len=30) # Delete the dataset after we're done to save disk space on CI - if os.environ.get("TRAVIS") == "true": - datafile = os.path.join(self.project_root, ".data", "penn-treebank") - conditional_remove(datafile) + datafile = os.path.join(self.project_root, ".data", "penn-treebank") + conditional_remove(datafile) + + def test_penntreebank(self): + from torchtext.datasets import PennTreebank + # smoke test to ensure wikitext2 works properly + train_dataset, test_dataset, valid_dataset = PennTreebank() + self.assertEqual(len(train_dataset), 924412) + self.assertEqual(len(test_dataset), 82114) + self.assertEqual(len(valid_dataset), 73339) + + vocab = train_dataset.get_vocab() + tokens_ids = [vocab[token] for token in 'the player characters rest'.split()] + self.assertEqual(tokens_ids, [2, 2550, 3344, 1125]) + + # Delete the dataset after we're done to save disk space on CI + datafile = os.path.join(self.project_root, ".data", 'ptb.train.txt') + conditional_remove(datafile) + datafile = os.path.join(self.project_root, ".data", 'ptb.test.txt') + conditional_remove(datafile) + datafile = os.path.join(self.project_root, ".data", 'ptb.valid.txt') + conditional_remove(datafile) def test_text_classification(self): # smoke test to ensure ag_news dataset works 
properly @@ -60,6 +100,7 @@ def test_text_classification(self): self.assertEqual(len(ag_news_test), 7600) # Delete the dataset after we're done to save disk space on CI - if os.environ.get("TRAVIS") == "true": - datafile = os.path.join(self.project_root, ".data", "AG_NEWS") - conditional_remove(datafile) + datafile = os.path.join(self.project_root, ".data", "ag_news_csv") + conditional_remove(datafile) + datafile = os.path.join(self.project_root, ".data", "ag_news_csv.tar.gz") + conditional_remove(datafile) diff --git a/torchtext/__init__.py b/torchtext/__init__.py index e799058a88..2701580acb 100644 --- a/torchtext/__init__.py +++ b/torchtext/__init__.py @@ -2,10 +2,12 @@ from . import datasets from . import utils from . import vocab +from . import legacy __version__ = '0.4.0' __all__ = ['data', 'datasets', 'utils', - 'vocab'] + 'vocab', + 'legacy'] diff --git a/torchtext/data/__init__.py b/torchtext/data/__init__.py index ea0096c6c2..20ca10a7a3 100644 --- a/torchtext/data/__init__.py +++ b/torchtext/data/__init__.py @@ -10,7 +10,8 @@ from .functional import generate_sp_model, \ load_sp_model, \ sentencepiece_numericalizer, \ - sentencepiece_tokenizer, custom_replace, simple_space_split + sentencepiece_tokenizer, custom_replace, simple_space_split, \ + numericalize_tokens_from_iterator __all__ = ["Batch", "Dataset", "TabularDataset", @@ -24,4 +25,5 @@ "get_tokenizer", "interleave_keys", "generate_sp_model", "load_sp_model", "sentencepiece_numericalizer", "sentencepiece_tokenizer", - "custom_replace", "simple_space_split"] + "custom_replace", "simple_space_split", + "numericalize_tokens_from_iterator"] diff --git a/torchtext/data/functional.py b/torchtext/data/functional.py index 07c16e5a2d..3aad8aa7fc 100644 --- a/torchtext/data/functional.py +++ b/torchtext/data/functional.py @@ -1,7 +1,6 @@ import sentencepiece as spm import re - __all__ = [ "generate_sp_model", "load_sp_model", "sentencepiece_numericalizer", "sentencepiece_tokenizer" ] @@ -151,3 +150,31 @@ def simple_space_split(iterator): for line in iterator: yield line.split() + + +def numericalize_tokens_from_iterator(vocab, iterator, removed_tokens=None): + r"""Yield a list of ids from a token iterator with a vocab. + + Arguments: + vocab: the vocabulary used to convert tokens into ids. + iterator: the iterator yielding lists of tokens. + removed_tokens: tokens to remove from the output (Default: None) + + Examples: + >>> from torchtext.data.functional import simple_space_split + >>> from torchtext.data.functional import numericalize_tokens_from_iterator + >>> vocab = {'Sentencepiece' : 0, 'encode' : 1, 'as' : 2, 'pieces' : 3} + >>> ids_iter = numericalize_tokens_from_iterator(vocab, + >>> simple_space_split(["Sentencepiece as pieces", + >>> "as pieces"])) + >>> for ids in ids_iter: + >>> print([num for num in ids]) + >>> [0, 2, 3] + >>> [2, 3] + """ + for tokens in iterator: + if removed_tokens is None: + yield iter(vocab[token] for token in tokens) + else: + yield iter(map(lambda x: vocab[x], + filter(lambda x: x not in removed_tokens, tokens))) diff --git a/torchtext/datasets/language_modeling.py b/torchtext/datasets/language_modeling.py index 775fe51d9e..04c7f5c00d 100644 --- a/torchtext/datasets/language_modeling.py +++ b/torchtext/datasets/language_modeling.py @@ -1,217 +1,241 @@ -from ..
import data +import torch +import logging +import os import io +from torchtext.utils import download_from_url, extract_archive +from torchtext.vocab import build_vocab_from_iterator +from torchtext.data.utils import get_tokenizer +from torchtext.vocab import Vocab +from torchtext.data.functional import numericalize_tokens_from_iterator + +URLS = { + 'WikiText2': + 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip', + 'WikiText103': + 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip', + 'PennTreebank': + ['https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.train.txt', + 'https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.test.txt', + 'https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.valid.txt'] +} + + +class LanguageModelingDataset(torch.utils.data.Dataset): + """Defines a dataset for language modeling. + Currently, we only support the following datasets: + + - WikiText2 + - WikiText103 + - PennTreebank + """ -class LanguageModelingDataset(data.Dataset): - """Defines a dataset for language modeling.""" - - def __init__(self, path, text_field, newline_eos=True, - encoding='utf-8', **kwargs): - """Create a LanguageModelingDataset given a path and a field. - - Arguments: - path: Path to the data file. - text_field: The field that will be used for text data. - newline_eos: Whether to add an token for every newline in the - data file. Default: True. - Remaining keyword arguments: Passed to the constructor of - data.Dataset. - """ - fields = [('text', text_field)] - text = [] - with io.open(path, encoding=encoding) as f: - for line in f: - text += text_field.preprocess(line) - if newline_eos: - text.append(u'') - - examples = [data.Example.fromlist([text], fields)] - super(LanguageModelingDataset, self).__init__( - examples, fields, **kwargs) - - -class WikiText2(LanguageModelingDataset): - - urls = ['https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip'] - name = 'wikitext-2' - dirname = 'wikitext-2' - - @classmethod - def splits(cls, text_field, root='.data', train='wiki.train.tokens', - validation='wiki.valid.tokens', test='wiki.test.tokens', - **kwargs): - """Create dataset objects for splits of the WikiText-2 dataset. - - This is the most flexible way to use the dataset. - - Arguments: - text_field: The field that will be used for text data. - root: The root directory that the dataset's zip archive will be - expanded into; therefore the directory in whose wikitext-2 - subdirectory the data files will be stored. - train: The filename of the train data. Default: 'wiki.train.tokens'. - validation: The filename of the validation data, or None to not - load the validation set. Default: 'wiki.valid.tokens'. - test: The filename of the test data, or None to not load the test - set. Default: 'wiki.test.tokens'. - """ - return super(WikiText2, cls).splits( - root=root, train=train, validation=validation, test=test, - text_field=text_field, **kwargs) - - @classmethod - def iters(cls, batch_size=32, bptt_len=35, device=0, root='.data', - vectors=None, **kwargs): - """Create iterator objects for splits of the WikiText-2 dataset. - - This is the simplest way to use the dataset, and assumes common - defaults for field, vocabulary, and iterator parameters. - - Arguments: - batch_size: Batch size. - bptt_len: Length of sequences for backpropagation through time. - device: Device to create batches on. Use -1 for CPU and None for - the currently active GPU device. 
- root: The root directory that the dataset's zip archive will be - expanded into; therefore the directory in whose wikitext-2 - subdirectory the data files will be stored. - wv_dir, wv_type, wv_dim: Passed to the Vocab constructor for the - text field. The word vectors are accessible as - train.dataset.fields['text'].vocab.vectors. - Remaining keyword arguments: Passed to the splits method. - """ - TEXT = data.Field() - - train, val, test = cls.splits(TEXT, root=root, **kwargs) - - TEXT.build_vocab(train, vectors=vectors) - - return data.BPTTIterator.splits( - (train, val, test), batch_size=batch_size, bptt_len=bptt_len, - device=device) - - -class WikiText103(LanguageModelingDataset): - - urls = ['https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip'] - name = 'wikitext-103' - dirname = 'wikitext-103' - - @classmethod - def splits(cls, text_field, root='.data', train='wiki.train.tokens', - validation='wiki.valid.tokens', test='wiki.test.tokens', - **kwargs): - """Create dataset objects for splits of the WikiText-103 dataset. - - This is the most flexible way to use the dataset. + def __init__(self, data, vocab): + """Initiate language modeling dataset. Arguments: - text_field: The field that will be used for text data. - root: The root directory that the dataset's zip archive will be - expanded into; therefore the directory in whose wikitext-103 - subdirectory the data files will be stored. - train: The filename of the train data. Default: 'wiki.train.tokens'. - validation: The filename of the validation data, or None to not - load the validation set. Default: 'wiki.valid.tokens'. - test: The filename of the test data, or None to not load the test - set. Default: 'wiki.test.tokens'. - """ - return super(WikiText103, cls).splits( - root=root, train=train, validation=validation, test=test, - text_field=text_field, **kwargs) - - @classmethod - def iters(cls, batch_size=32, bptt_len=35, device=0, root='.data', - vectors=None, **kwargs): - """Create iterator objects for splits of the WikiText-103 dataset. + data: a tensor of tokens. tokens are ids after + numericalizing the string tokens. + torch.tensor([token_id_1, token_id_2, token_id_3, token_id1]).long() + vocab: Vocabulary object used for dataset. + + Examples: + >>> from torchtext.vocab import build_vocab_from_iterator + >>> data = torch.tensor([token_id_1, token_id_2, + token_id_3, token_id_1]).long() + >>> vocab = build_vocab_from_iterator([['language', 'modeling']]) + >>> dataset = LanguageModelingDataset(data, vocab) - This is the simplest way to use the dataset, and assumes common - defaults for field, vocabulary, and iterator parameters. - - Arguments: - batch_size: Batch size. - bptt_len: Length of sequences for backpropagation through time. - device: Device to create batches on. Use -1 for CPU and None for - the currently active GPU device. - root: The root directory that the dataset's zip archive will be - expanded into; therefore the directory in whose wikitext-2 - subdirectory the data files will be stored. - wv_dir, wv_type, wv_dim: Passed to the Vocab constructor for the - text field. The word vectors are accessible as - train.dataset.fields['text'].vocab.vectors. - Remaining keyword arguments: Passed to the splits method. 
""" - TEXT = data.Field() - - train, val, test = cls.splits(TEXT, root=root, **kwargs) - - TEXT.build_vocab(train, vectors=vectors) - - return data.BPTTIterator.splits( - (train, val, test), batch_size=batch_size, bptt_len=bptt_len, - device=device) + super(LanguageModelingDataset, self).__init__() + self.data = data + self.vocab = vocab + + def __getitem__(self, i): + return self.data[i] + + def __len__(self): + return len(self.data) + + def __iter__(self): + for x in self.data: + yield x + + def get_vocab(self): + return self.vocab + + +def _get_datafile_path(key, extracted_files): + for fname in extracted_files: + if key in fname: + return fname + + +def _setup_datasets(dataset_name, tokenizer=get_tokenizer("basic_english"), + root='.data', vocab=None, removed_tokens=[], + data_select=('train', 'test', 'valid')): + + if isinstance(data_select, str): + data_select = [data_select] + if not set(data_select).issubset(set(('train', 'test', 'valid'))): + raise TypeError('data_select is not supported!') + + if dataset_name == 'PennTreebank': + extracted_files = [] + select_to_index = {'train': 0, 'test': 1, 'valid': 2} + extracted_files = [download_from_url(URLS['PennTreebank'][select_to_index[key]], + root=root) for key in data_select] + else: + dataset_tar = download_from_url(URLS[dataset_name], root=root) + extracted_files = [os.path.join(root, d) for d in extract_archive(dataset_tar)] + + _path = {} + for item in data_select: + _path[item] = _get_datafile_path(item, extracted_files) + + if vocab is None: + if 'train' not in _path.keys(): + raise TypeError("Must pass a vocab if train is not selected.") + logging.info('Building Vocab based on {}'.format(_path['train'])) + txt_iter = iter(tokenizer(row) for row in io.open(_path['train'], + encoding="utf8")) + vocab = build_vocab_from_iterator(txt_iter) + logging.info('Vocab has {} entries'.format(len(vocab))) + else: + if not isinstance(vocab, Vocab): + raise TypeError("Passed vocabulary is not of type Vocab") + + data = {} + for item in _path.keys(): + data[item] = [] + logging.info('Creating {} data'.format(item)) + txt_iter = iter(tokenizer(row) for row in io.open(_path[item], + encoding="utf8")) + _iter = numericalize_tokens_from_iterator( + vocab, txt_iter, removed_tokens) + for tokens in _iter: + data[item] += [token_id for token_id in tokens] + + for key in data_select: + if data[key] == []: + raise TypeError('Dataset {} is empty!'.format(key)) + + return tuple(LanguageModelingDataset(torch.tensor(data[d]).long(), vocab) + for d in data_select) + + +def WikiText2(*args, **kwargs): + """ Defines WikiText2 datasets. + + Create language modeling dataset: WikiText2 + Separately returns the train/test/valid set + + Arguments: + tokenizer: the tokenizer used to preprocess raw text data. + The default one is basic_english tokenizer in fastText. spacy tokenizer + is supported as well (see example below). A custom tokenizer is callable + function with input of a string and output of a token list. + root: Directory where the datasets are saved. Default: ".data" + vocab: Vocabulary used for dataset. If None, it will generate a new + vocabulary based on the train data set. + removed_tokens: removed tokens from output dataset (Default: []) + data_select: a string or tupel for the returned datasets + (Default: ('train', 'test','valid')) + By default, all the three datasets (train, test, valid) are generated. Users + could also choose any one or two of them, for example ('train', 'test') or + just a string 'train'. 
If 'train' is not in the tuple or string, a vocab + object should be provided which will be used to process valid and/or test + data. + + Examples: + >>> from torchtext.datasets import WikiText2 + >>> from torchtext.data.utils import get_tokenizer + >>> tokenizer = get_tokenizer("spacy") + >>> train_dataset, test_dataset, valid_dataset = WikiText2(tokenizer=tokenizer) + >>> vocab = train_dataset.get_vocab() + >>> valid_dataset, = WikiText2(tokenizer=tokenizer, vocab=vocab, + data_select='valid') -class PennTreebank(LanguageModelingDataset): - """The Penn Treebank dataset. - A relatively small dataset originally created for POS tagging. - - References - ---------- - Marcus, Mitchell P., Marcinkiewicz, Mary Ann & Santorini, Beatrice (1993). - Building a Large Annotated Corpus of English: The Penn Treebank """ - urls = ['https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.train.txt', - 'https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.valid.txt', - 'https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.test.txt'] - name = 'penn-treebank' - dirname = '' - - @classmethod - def splits(cls, text_field, root='.data', train='ptb.train.txt', - validation='ptb.valid.txt', test='ptb.test.txt', - **kwargs): - """Create dataset objects for splits of the Penn Treebank dataset. - - Arguments: - text_field: The field that will be used for text data. - root: The root directory where the data files will be stored. - train: The filename of the train data. Default: 'ptb.train.txt'. - validation: The filename of the validation data, or None to not - load the validation set. Default: 'ptb.valid.txt'. - test: The filename of the test data, or None to not load the test - set. Default: 'ptb.test.txt'. - """ - return super(PennTreebank, cls).splits( - root=root, train=train, validation=validation, test=test, - text_field=text_field, **kwargs) - - @classmethod - def iters(cls, batch_size=32, bptt_len=35, device=0, root='.data', - vectors=None, **kwargs): - """Create iterator objects for splits of the Penn Treebank dataset. + return _setup_datasets(*(("WikiText2",) + args), **kwargs) + + +def WikiText103(*args, **kwargs): + """ Defines WikiText103 datasets. + + Create language modeling dataset: WikiText103 + Separately returns the train/test/valid set + + Arguments: + tokenizer: the tokenizer used to preprocess raw text data. + The default one is the basic_english tokenizer in fastText. The spacy tokenizer + is supported as well (see the example below). A custom tokenizer is a callable + function with a string as input and a list of tokens as output. + root: Directory where the datasets are saved. Default: ".data" + vocab: Vocabulary used for dataset. If None, it will generate a new + vocabulary based on the train data set. + removed_tokens: tokens to remove from the output dataset (Default: []) + data_select: a string or tuple for the returned datasets + (Default: ('train', 'test', 'valid')) + By default, all three datasets (train, test, valid) are generated. Users + could also choose any one or two of them, for example ('train', 'test') or + just a string 'train'. 
If 'train' is not in the tuple or string, a vocab + object should be provided which will be used to process valid and/or test + data. + + Examples: + >>> from torchtext.datasets import WikiText103 + >>> from torchtext.data.utils import get_tokenizer + >>> tokenizer = get_tokenizer("spacy") + >>> train_dataset, test_dataset, valid_dataset = WikiText103(tokenizer=tokenizer) + >>> vocab = train_dataset.get_vocab() + >>> valid_dataset, = WikiText103(tokenizer=tokenizer, vocab=vocab, + data_select='valid') - This is the simplest way to use the dataset, and assumes common - defaults for field, vocabulary, and iterator parameters. - - Arguments: - batch_size: Batch size. - bptt_len: Length of sequences for backpropagation through time. - device: Device to create batches on. Use -1 for CPU and None for - the currently active GPU device. - root: The root directory where the data files will be stored. - wv_dir, wv_type, wv_dim: Passed to the Vocab constructor for the - text field. The word vectors are accessible as - train.dataset.fields['text'].vocab.vectors. - Remaining keyword arguments: Passed to the splits method. - """ - TEXT = data.Field() + """ - train, val, test = cls.splits(TEXT, root=root, **kwargs) + return _setup_datasets(*(("WikiText103",) + args), **kwargs) + + +def PennTreebank(*args, **kwargs): + """ Defines PennTreebank datasets. + + Create language modeling dataset: PennTreebank + Separately returns the train/test/valid set + + Arguments: + tokenizer: the tokenizer used to preprocess raw text data. + The default one is the basic_english tokenizer in fastText. The spacy tokenizer + is supported as well (see the example below). A custom tokenizer is a callable + function with a string as input and a list of tokens as output. + root: Directory where the datasets are saved. Default: ".data" + vocab: Vocabulary used for dataset. If None, it will generate a new + vocabulary based on the train data set. + removed_tokens: tokens to remove from the output dataset (Default: []) + data_select: a string or tuple for the returned datasets + (Default: ('train', 'test', 'valid')) + By default, all three datasets (train, test, valid) are generated. Users + could also choose any one or two of them, for example ('train', 'test') or + just a string 'train'. If 'train' is not in the tuple or string, a vocab + object should be provided which will be used to process valid and/or test + data. + + Examples: + >>> from torchtext.datasets import PennTreebank + >>> from torchtext.data.utils import get_tokenizer + >>> tokenizer = get_tokenizer("spacy") + >>> train_dataset, test_dataset, valid_dataset = PennTreebank(tokenizer=tokenizer) + >>> vocab = train_dataset.get_vocab() + >>> valid_dataset, = PennTreebank(tokenizer=tokenizer, vocab=vocab, + data_select='valid') - TEXT.build_vocab(train, vectors=vectors) + """ - return data.BPTTIterator.splits( - (train, val, test), batch_size=batch_size, bptt_len=bptt_len, - device=device) + return _setup_datasets(*(("PennTreebank",) + args), **kwargs) diff --git a/torchtext/legacy/__init__.py b/torchtext/legacy/__init__.py new file mode 100644 index 0000000000..d7fa116bab --- /dev/null +++ b/torchtext/legacy/__init__.py @@ -0,0 +1,3 @@ +from . 
import datasets + +__all__ = ['datasets'] diff --git a/torchtext/legacy/datasets/__init__.py b/torchtext/legacy/datasets/__init__.py new file mode 100644 index 0000000000..0320fc3d56 --- /dev/null +++ b/torchtext/legacy/datasets/__init__.py @@ -0,0 +1,4 @@ +from .language_modeling import LanguageModelingDataset, WikiText2, WikiText103, PennTreebank # NOQA + + +__all__ = ['LanguageModelingDataset', 'WikiText2', 'WikiText103', 'PennTreebank'] diff --git a/torchtext/legacy/datasets/language_modeling.py b/torchtext/legacy/datasets/language_modeling.py new file mode 100644 index 0000000000..ed7b912efc --- /dev/null +++ b/torchtext/legacy/datasets/language_modeling.py @@ -0,0 +1,220 @@ +from torchtext import data +import io +import warnings + + +class LanguageModelingDataset(data.Dataset): + """Defines a dataset for language modeling.""" + + def __init__(self, path, text_field, newline_eos=True, + encoding='utf-8', **kwargs): + """Create a LanguageModelingDataset given a path and a field. + + Arguments: + path: Path to the data file. + text_field: The field that will be used for text data. + newline_eos: Whether to add an <eos> token for every newline in the + data file. Default: True. + Remaining keyword arguments: Passed to the constructor of + data.Dataset. + """ + warnings.warn("You are using legacy code, which is no longer " + "maintained by the PyTorch team!") + fields = [('text', text_field)] + text = [] + with io.open(path, encoding=encoding) as f: + for line in f: + text += text_field.preprocess(line) + if newline_eos: + text.append(u'<eos>') + + examples = [data.Example.fromlist([text], fields)] + super(LanguageModelingDataset, self).__init__( + examples, fields, **kwargs) + + +class WikiText2(LanguageModelingDataset): + + urls = ['https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip'] + name = 'wikitext-2' + dirname = 'wikitext-2' + + @classmethod + def splits(cls, text_field, root='.data', train='wiki.train.tokens', + validation='wiki.valid.tokens', test='wiki.test.tokens', + **kwargs): + """Create dataset objects for splits of the WikiText-2 dataset. + + This is the most flexible way to use the dataset. + + Arguments: + text_field: The field that will be used for text data. + root: The root directory that the dataset's zip archive will be + expanded into; therefore the directory in whose wikitext-2 + subdirectory the data files will be stored. + train: The filename of the train data. Default: 'wiki.train.tokens'. + validation: The filename of the validation data, or None to not + load the validation set. Default: 'wiki.valid.tokens'. + test: The filename of the test data, or None to not load the test + set. Default: 'wiki.test.tokens'. + """ + return super(WikiText2, cls).splits( + root=root, train=train, validation=validation, test=test, + text_field=text_field, **kwargs) + + @classmethod + def iters(cls, batch_size=32, bptt_len=35, device=0, root='.data', + vectors=None, **kwargs): + """Create iterator objects for splits of the WikiText-2 dataset. + + This is the simplest way to use the dataset, and assumes common + defaults for field, vocabulary, and iterator parameters. + + Arguments: + batch_size: Batch size. + bptt_len: Length of sequences for backpropagation through time. + device: Device to create batches on. Use -1 for CPU and None for + the currently active GPU device. + root: The root directory that the dataset's zip archive will be + expanded into; therefore the directory in whose wikitext-2 + subdirectory the data files will be stored. 
+ wv_dir, wv_type, wv_dim: Passed to the Vocab constructor for the + text field. The word vectors are accessible as + train.dataset.fields['text'].vocab.vectors. + Remaining keyword arguments: Passed to the splits method. + """ + TEXT = data.Field() + + train, val, test = cls.splits(TEXT, root=root, **kwargs) + + TEXT.build_vocab(train, vectors=vectors) + + return data.BPTTIterator.splits( + (train, val, test), batch_size=batch_size, bptt_len=bptt_len, + device=device) + + +class WikiText103(LanguageModelingDataset): + + urls = ['https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip'] + name = 'wikitext-103' + dirname = 'wikitext-103' + + @classmethod + def splits(cls, text_field, root='.data', train='wiki.train.tokens', + validation='wiki.valid.tokens', test='wiki.test.tokens', + **kwargs): + """Create dataset objects for splits of the WikiText-103 dataset. + + This is the most flexible way to use the dataset. + + Arguments: + text_field: The field that will be used for text data. + root: The root directory that the dataset's zip archive will be + expanded into; therefore the directory in whose wikitext-103 + subdirectory the data files will be stored. + train: The filename of the train data. Default: 'wiki.train.tokens'. + validation: The filename of the validation data, or None to not + load the validation set. Default: 'wiki.valid.tokens'. + test: The filename of the test data, or None to not load the test + set. Default: 'wiki.test.tokens'. + """ + return super(WikiText103, cls).splits( + root=root, train=train, validation=validation, test=test, + text_field=text_field, **kwargs) + + @classmethod + def iters(cls, batch_size=32, bptt_len=35, device=0, root='.data', + vectors=None, **kwargs): + """Create iterator objects for splits of the WikiText-103 dataset. + + This is the simplest way to use the dataset, and assumes common + defaults for field, vocabulary, and iterator parameters. + + Arguments: + batch_size: Batch size. + bptt_len: Length of sequences for backpropagation through time. + device: Device to create batches on. Use -1 for CPU and None for + the currently active GPU device. + root: The root directory that the dataset's zip archive will be + expanded into; therefore the directory in whose wikitext-2 + subdirectory the data files will be stored. + wv_dir, wv_type, wv_dim: Passed to the Vocab constructor for the + text field. The word vectors are accessible as + train.dataset.fields['text'].vocab.vectors. + Remaining keyword arguments: Passed to the splits method. + """ + TEXT = data.Field() + + train, val, test = cls.splits(TEXT, root=root, **kwargs) + + TEXT.build_vocab(train, vectors=vectors) + + return data.BPTTIterator.splits( + (train, val, test), batch_size=batch_size, bptt_len=bptt_len, + device=device) + + +class PennTreebank(LanguageModelingDataset): + """The Penn Treebank dataset. + A relatively small dataset originally created for POS tagging. + + References + ---------- + Marcus, Mitchell P., Marcinkiewicz, Mary Ann & Santorini, Beatrice (1993). 
+ Building a Large Annotated Corpus of English: The Penn Treebank + """ + + urls = ['https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.train.txt', + 'https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.valid.txt', + 'https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.test.txt'] + name = 'penn-treebank' + dirname = '' + + @classmethod + def splits(cls, text_field, root='.data', train='ptb.train.txt', + validation='ptb.valid.txt', test='ptb.test.txt', + **kwargs): + """Create dataset objects for splits of the Penn Treebank dataset. + + Arguments: + text_field: The field that will be used for text data. + root: The root directory where the data files will be stored. + train: The filename of the train data. Default: 'ptb.train.txt'. + validation: The filename of the validation data, or None to not + load the validation set. Default: 'ptb.valid.txt'. + test: The filename of the test data, or None to not load the test + set. Default: 'ptb.test.txt'. + """ + return super(PennTreebank, cls).splits( + root=root, train=train, validation=validation, test=test, + text_field=text_field, **kwargs) + + @classmethod + def iters(cls, batch_size=32, bptt_len=35, device=0, root='.data', + vectors=None, **kwargs): + """Create iterator objects for splits of the Penn Treebank dataset. + + This is the simplest way to use the dataset, and assumes common + defaults for field, vocabulary, and iterator parameters. + + Arguments: + batch_size: Batch size. + bptt_len: Length of sequences for backpropagation through time. + device: Device to create batches on. Use -1 for CPU and None for + the currently active GPU device. + root: The root directory where the data files will be stored. + wv_dir, wv_type, wv_dim: Passed to the Vocab constructor for the + text field. The word vectors are accessible as + train.dataset.fields['text'].vocab.vectors. + Remaining keyword arguments: Passed to the splits method. + """ + TEXT = data.Field() + + train, val, test = cls.splits(TEXT, root=root, **kwargs) + + TEXT.build_vocab(train, vectors=vectors) + + return data.BPTTIterator.splits( + (train, val, test), batch_size=batch_size, bptt_len=bptt_len, + device=device)
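For readers of this diff, here is a minimal usage sketch of the new ``torchtext.datasets`` language modeling API introduced above. It only relies on pieces added in this change (``WikiText2``, ``get_vocab``, the ``vocab`` and ``data_select`` keyword arguments, and the ``data`` tensor stored by ``LanguageModelingDataset``); the BPTT-style batching loop at the end is an illustrative assumption, not part of the diff, and the ``bptt_len=35`` value simply mirrors the legacy ``iters()`` default. The script downloads the WikiText-2 archive on first run: ::

    import torch
    from torchtext.datasets import WikiText2

    # Download, tokenize (basic_english by default) and numericalize WikiText-2.
    train_dataset, test_dataset, valid_dataset = WikiText2()

    # Each split wraps a flat 1-D tensor of token ids plus the shared vocab.
    vocab = train_dataset.get_vocab()
    print(len(train_dataset), len(vocab))

    # Reuse the training vocab to rebuild only the validation split.
    valid_only, = WikiText2(vocab=vocab, data_select='valid')

    # Illustrative BPTT-style batching over the flat token stream
    # (assumed here for demonstration; the diff itself provides no batching).
    data = train_dataset.data  # tensor stored by LanguageModelingDataset.__init__
    bptt_len = 35
    for i in range(0, data.size(0) - 1, bptt_len):
        seq_len = min(bptt_len, data.size(0) - 1 - i)
        inputs = data[i:i + seq_len]
        targets = data[i + 1:i + 1 + seq_len]
        # feed (inputs, targets) to a language model here
        break  # demonstrate only the first batch

The legacy ``Field``-based ``splits``/``iters`` workflow shown in ``torchtext/legacy/datasets/language_modeling.py`` remains available for code that still depends on the old pipeline.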