diff --git a/test/data/test_builtin_datasets.py b/test/data/test_builtin_datasets.py index 08c75292c4..45e0d096bb 100644 --- a/test/data/test_builtin_datasets.py +++ b/test/data/test_builtin_datasets.py @@ -78,7 +78,7 @@ def test_penntreebank_legacy(self): def test_penntreebank(self): from torchtext.experimental.datasets import PennTreebank - # smoke test to ensure wikitext2 works properly + # smoke test to ensure penn treebank works properly train_dataset, test_dataset, valid_dataset = PennTreebank() self.assertEqual(len(train_dataset), 924412) self.assertEqual(len(test_dataset), 82114) diff --git a/torchtext/experimental/datasets/__init__.py b/torchtext/experimental/datasets/__init__.py index ac2faa423b..d91f60ff55 100644 --- a/torchtext/experimental/datasets/__init__.py +++ b/torchtext/experimental/datasets/__init__.py @@ -1,4 +1,4 @@ -from .language_modeling import LanguageModelingDataset, WikiText2, WikiText103, PennTreebank # NOQA +from .language_modeling import LanguageModelingDataset, WikiText2, WikiText103, PennTreebank, WMTNewsCrawl # NOQA from .text_classification import AG_NEWS, SogouNews, DBpedia, YelpReviewPolarity, \ YelpReviewFull, YahooAnswers, \ AmazonReviewPolarity, AmazonReviewFull, IMDB @@ -7,6 +7,7 @@ 'WikiText2', 'WikiText103', 'PennTreebank', + 'WMTNewsCrawl', 'IMDB', 'AG_NEWS', 'SogouNews', diff --git a/torchtext/experimental/datasets/language_modeling.py b/torchtext/experimental/datasets/language_modeling.py index 6c0ec5799e..f44de4af88 100644 --- a/torchtext/experimental/datasets/language_modeling.py +++ b/torchtext/experimental/datasets/language_modeling.py @@ -1,22 +1,15 @@ import torch -import logging -import io -from torchtext.utils import download_from_url, extract_archive -from torchtext.vocab import build_vocab_from_iterator from torchtext.data.utils import get_tokenizer -from torchtext.vocab import Vocab -from torchtext.data.functional import numericalize_tokens_from_iterator - -URLS = { - 'WikiText2': - 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip', - 'WikiText103': - 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip', - 'PennTreebank': - ['https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.train.txt', - 'https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.test.txt', - 'https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.valid.txt'] -} +from torchtext.vocab import build_vocab_from_iterator +from torchtext.experimental.datasets.raw import language_modeling as raw +from torchtext.experimental.functional import vocab_func, totensor, sequential_transforms + + +def build_vocab(data, transforms): + tok_list = [] + for txt in data: + tok_list.append(transforms(txt)) + return build_vocab_from_iterator(tok_list) class LanguageModelingDataset(torch.utils.data.Dataset): @@ -26,10 +19,11 @@ class LanguageModelingDataset(torch.utils.data.Dataset): - WikiText2 - WikiText103 - PennTreebank + - WMTNewsCrawl """ - def __init__(self, data, vocab): + def __init__(self, data, vocab, transforms, single_line): """Initiate language modeling dataset. Arguments: @@ -37,22 +31,24 @@ def __init__(self, data, vocab): numericalizing the string tokens. torch.tensor([token_id_1, token_id_2, token_id_3, token_id1]).long() vocab: Vocabulary object used for dataset. 
- - Examples: - >>> from torchtext.vocab import build_vocab_from_iterator - >>> data = torch.tensor([token_id_1, token_id_2, - token_id_3, token_id_1]).long() - >>> vocab = build_vocab_from_iterator([['language', 'modeling']]) - >>> dataset = LanguageModelingDataset(data, vocab) + transforms: Text string transforms. """ super(LanguageModelingDataset, self).__init__() - self.data = data self.vocab = vocab + self.transforms = transforms + self.single_line = single_line + if single_line: + self.data = torch.cat(tuple(transforms(row) for row in data), axis=0) + else: + self.data = data def __getitem__(self, i): - return self.data[i] + if self.single_line: + return self.data[i] + else: + return self.transforms(self.data[i]) def __len__(self): return len(self.data) @@ -65,63 +61,45 @@ def get_vocab(self): return self.vocab -def _get_datafile_path(key, extracted_files): - for fname in extracted_files: - if key in fname: - return fname - - -def _setup_datasets(dataset_name, tokenizer=get_tokenizer("basic_english"), - root='.data', vocab=None, removed_tokens=[], - data_select=('train', 'test', 'valid')): +def _setup_datasets(dataset_name, tokenizer=None, root='.data', vocab=None, + data_select=('train', 'test', 'valid'), single_line=True): + if tokenizer is None: + tokenizer = get_tokenizer('basic_english') + text_transform = sequential_transforms(tokenizer) if isinstance(data_select, str): data_select = [data_select] - if not set(data_select).issubset(set(('train', 'test', 'valid'))): - raise TypeError('data_select is not supported!') - - if dataset_name == 'PennTreebank': - extracted_files = [] - select_to_index = {'train': 0, 'test': 1, 'valid': 2} - extracted_files = [download_from_url(URLS['PennTreebank'][select_to_index[key]], - root=root) for key in data_select] + if not set(data_select).issubset(set(('train', 'valid', 'test'))): + raise TypeError('Given data selection {} is not supported!'.format(data_select)) + + if not single_line and dataset_name != 'WikiText103': + raise TypeError('single_line must be True except for WikiText103') + if dataset_name == 'WMTNewsCrawl': + train, = raw.DATASETS[dataset_name](root=root, data_select=('train',)) + if single_line: + raw_data = {'train': [" ".join([txt for txt in train]), ]} + else: + raw_data = {'train': [txt for txt in train]} else: - dataset_tar = download_from_url(URLS[dataset_name], root=root) - extracted_files = extract_archive(dataset_tar) - - _path = {} - for item in data_select: - _path[item] = _get_datafile_path(item, extracted_files) + train, test, valid = raw.DATASETS[dataset_name](root=root, data_select=('train', 'test', 'valid')) + # Cache raw text iterable dataset + if single_line: + raw_data = {'train': [" ".join([txt for txt in train]), ], + 'valid': [" ".join(txt for txt in valid), ], + 'test': [" ".join(txt for txt in test), ]} + else: + raw_data = {'train': [txt for txt in train], + 'valid': [txt for txt in valid], + 'test': [txt for txt in test]} if vocab is None: - if 'train' not in _path.keys(): + if 'train' not in data_select: raise TypeError("Must pass a vocab if train is not selected.") - logging.info('Building Vocab based on {}'.format(_path['train'])) - txt_iter = iter(tokenizer(row) for row in io.open(_path['train'], - encoding="utf8")) - vocab = build_vocab_from_iterator(txt_iter) - logging.info('Vocab has {} entries'.format(len(vocab))) - else: - if not isinstance(vocab, Vocab): - raise TypeError("Passed vocabulary is not of type Vocab") - - data = {} - for item in _path.keys(): - data[item] = [] - 
logging.info('Creating {} data'.format(item)) - txt_iter = iter(tokenizer(row) for row in io.open(_path[item], - encoding="utf8")) - _iter = numericalize_tokens_from_iterator( - vocab, txt_iter, removed_tokens) - for tokens in _iter: - data[item] += [token_id for token_id in tokens] - - for key in data_select: - if data[key] == []: - raise TypeError('Dataset {} is empty!'.format(key)) - - return tuple(LanguageModelingDataset(torch.tensor(data[d]).long(), vocab) - for d in data_select) + vocab = build_vocab(raw_data['train'], text_transform) + text_transform = sequential_transforms(text_transform, vocab_func(vocab), + totensor(dtype=torch.long)) + return tuple(LanguageModelingDataset(raw_data[item], vocab, text_transform, single_line) + for item in data_select) def WikiText2(*args, **kwargs): @@ -138,7 +116,6 @@ def WikiText2(*args, **kwargs): root: Directory where the datasets are saved. Default: ".data" vocab: Vocabulary used for dataset. If None, it will generate a new vocabulary based on the train data set. - removed_tokens: removed tokens from output dataset (Default: []) data_select: a string or tupel for the returned datasets (Default: ('train', 'test','valid')) By default, all the three datasets (train, test, valid) are generated. Users @@ -146,6 +123,10 @@ def WikiText2(*args, **kwargs): just a string 'train'. If 'train' is not in the tuple or string, a vocab object should be provided which will be used to process valid and/or test data. + single_line: whether to return all tokens in a single line. + (Default: True) + By default, all lines in raw text file are concatenated into a single line. + Use `single_line = False` if one wants to get data line by line. Examples: >>> from torchtext.experimental.datasets import WikiText2 @@ -175,12 +156,6 @@ def WikiText103(*args, **kwargs): root: Directory where the datasets are saved. Default: ".data" vocab: Vocabulary used for dataset. If None, it will generate a new vocabulary based on the train data set. - data_select: the returned datasets (Default: ('train', 'test','valid')) - By default, all the three datasets (train, test, valid) are generated. Users - could also choose any one or two of them, for example ('train', 'test'). - If 'train' is not in the tuple, an vocab object should be provided which will - be used to process valid and/or test data. - removed_tokens: removed tokens from output dataset (Default: []) data_select: a string or tupel for the returned datasets (Default: ('train', 'test','valid')) By default, all the three datasets (train, test, valid) are generated. Users @@ -188,6 +163,10 @@ def WikiText103(*args, **kwargs): just a string 'train'. If 'train' is not in the tuple or string, a vocab object should be provided which will be used to process valid and/or test data. + single_line: whether to return all tokens in a single line. + (Default: True) + By default, all lines in raw text file are concatenated into a single line. + Use `single_line = False` if one wants to get data line by line. Examples: >>> from torchtext.experimental.datasets import WikiText103 @@ -217,7 +196,6 @@ def PennTreebank(*args, **kwargs): root: Directory where the datasets are saved. Default: ".data" vocab: Vocabulary used for dataset. If None, it will generate a new vocabulary based on the train data set. - removed_tokens: removed tokens from output dataset (Default: []) data_select: a string or tupel for the returned datasets (Default: ('train', 'test','valid')) By default, all the three datasets (train, test, valid) are generated. 
Users @@ -225,6 +203,10 @@ def PennTreebank(*args, **kwargs): just a string 'train'. If 'train' is not in the tuple or string, a vocab object should be provided which will be used to process valid and/or test data. + single_line: whether to return all tokens in a single line. + (Default: True) + By default, all lines in raw text file are concatenated into a single line. + Use `single_line = False` if one wants to get data line by line. Examples: >>> from torchtext.experimental.datasets import PennTreebank @@ -238,3 +220,42 @@ def PennTreebank(*args, **kwargs): """ return _setup_datasets(*(("PennTreebank",) + args), **kwargs) + + +def WMTNewsCrawl(*args, **kwargs): + """ Defines WMTNewsCrawl datasets. + + Create language modeling dataset: WMTNewsCrawl + returns the train set + + Arguments: + tokenizer: the tokenizer used to preprocess raw text data. + The default one is basic_english tokenizer in fastText. spacy tokenizer + is supported as well (see example below). A custom tokenizer is callable + function with input of a string and output of a token list. + root: Directory where the datasets are saved. Default: ".data" + vocab: Vocabulary used for dataset. If None, it will generate a new + vocabulary based on the train data set. + data_select: a string or tupel for the returned datasets + (Default: ('train',)) + single_line: whether to return all tokens in a single line. + (Default: True) + By default, all lines in raw text file are concatenated into a single line. + Use `single_line = False` if one wants to get data line by line. + Examples: + >>> from torchtext.experimental.datasets import WMTNewsCrawl + >>> from torchtext.data.utils import get_tokenizer + >>> tokenizer = get_tokenizer("spacy") + >>> train_dataset, = WMTNewsCrawl(tokenizer=tokenizer, data_select='train') + + """ + + return _setup_datasets(*(("WMTNewsCrawl",) + args), **kwargs) + + +DATASETS = { + 'WikiText2': WikiText2, + 'WikiText103': WikiText103, + 'PennTreebank': PennTreebank, + 'WMTNewsCrawl': WMTNewsCrawl +} diff --git a/torchtext/experimental/datasets/raw/__init__.py b/torchtext/experimental/datasets/raw/__init__.py index 61accbe2a1..9c988ec05c 100644 --- a/torchtext/experimental/datasets/raw/__init__.py +++ b/torchtext/experimental/datasets/raw/__init__.py @@ -1,6 +1,7 @@ from .text_classification import AG_NEWS, SogouNews, DBpedia, YelpReviewPolarity, \ YelpReviewFull, YahooAnswers, \ AmazonReviewPolarity, AmazonReviewFull, IMDB +from .language_modeling import WikiText2, WikiText103, PennTreebank, WMTNewsCrawl __all__ = ['IMDB', 'AG_NEWS', @@ -10,4 +11,8 @@ 'YelpReviewFull', 'YahooAnswers', 'AmazonReviewPolarity', - 'AmazonReviewFull'] + 'AmazonReviewFull', + 'WikiText2', + 'WikiText103', + 'PennTreebank', + 'WMTNewsCrawl'] diff --git a/torchtext/experimental/datasets/raw/language_modeling.py b/torchtext/experimental/datasets/raw/language_modeling.py new file mode 100644 index 0000000000..a867108978 --- /dev/null +++ b/torchtext/experimental/datasets/raw/language_modeling.py @@ -0,0 +1,184 @@ +import torch +import logging +import io +from torchtext.utils import download_from_url, extract_archive + +URLS = { + 'WikiText2': + 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip', + 'WikiText103': + 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip', + 'PennTreebank': + ['https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.train.txt', + 'https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.test.txt', + 
'https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.valid.txt'], + 'WMTNewsCrawl': 'http://www.statmt.org/wmt11/training-monolingual-news-2010.tgz' +} + + +class RawTextIterableDataset(torch.utils.data.IterableDataset): + """Defines an abstraction for raw text iterable datasets. + """ + + def __init__(self, iterator, start=0, num_lines=None): + """Initiate language modeling dataset. + """ + super(RawTextIterableDataset, self).__init__() + self._iterator = iterator + self.has_setup = False + self.start = start + self.num_lines = num_lines + + def setup_iter(self, start=0, num_lines=None): + self.start = start + self.num_lines = num_lines + self.has_setup = True + + def __iter__(self): + if not self.has_setup: + self.setup_iter() + for i, item in enumerate(self._iterator): + if i >= self.start: + yield item + if (self.num_lines is not None) and (i == (self.start + self.num_lines)): + break + + def get_iterator(self): + return self._iterator + + +def _setup_datasets(dataset_name, root='.data', data_select=('train', 'test', 'valid'), **kwargs): + if isinstance(data_select, str): + data_select = [data_select] + if not set(data_select).issubset(set(('train', 'test', 'valid'))): + raise TypeError('data_select is not supported!') + + if dataset_name == 'PennTreebank': + extracted_files = [] + select_to_index = {'train': 0, 'test': 1, 'valid': 2} + extracted_files = [download_from_url(URLS['PennTreebank'][select_to_index[key]], + root=root) for key in data_select] + elif dataset_name == 'WMTNewsCrawl': + if not (data_select == ['train'] or set(data_select).issubset(set(('train',)))): + raise ValueError("WMTNewsCrawl only creates a training dataset. " + "data_select should be 'train' " + "or ('train',), got {}.".format(data_select)) + dataset_tar = download_from_url(URLS[dataset_name], root=root) + extracted_files = extract_archive(dataset_tar) + year = kwargs.get('year', 2010) + language = kwargs.get('language', 'en') + file_name = 'news.{}.{}.shuffled'.format(year, language) + extracted_files = [f for f in extracted_files if file_name in f] + else: + dataset_tar = download_from_url(URLS[dataset_name], root=root) + extracted_files = extract_archive(dataset_tar) + + _path = {} + for item in data_select: + for fname in extracted_files: + if item in fname: + _path[item] = fname + + data = {} + for item in _path.keys(): + logging.info('Creating {} data'.format(item)) + data[item] = iter(io.open(_path[item], encoding="utf8")) + + return tuple(RawTextIterableDataset(data[item]) for item in data_select) + + +def WikiText2(*args, **kwargs): + """ Defines WikiText2 datasets. + + Create language modeling dataset: WikiText2 + Separately returns the train/test/valid set + + Arguments: + root: Directory where the datasets are saved. Default: ".data" + data_select: a string or tupel for the returned datasets + (Default: ('train', 'test','valid')) + By default, all the three datasets (train, test, valid) are generated. Users + could also choose any one or two of them, for example ('train', 'test') or + just a string 'train'. If 'train' is not in the tuple or string, a vocab + object should be provided which will be used to process valid and/or test + data. + + Examples: + >>> from torchtext.experimental.raw.datasets import WikiText2 + >>> train_dataset, test_dataset, valid_dataset = WikiText2() + >>> valid_dataset, = WikiText2(data_select='valid') + + """ + + return _setup_datasets(*(("WikiText2",) + args), **kwargs) + + +def WikiText103(*args, **kwargs): + """ Defines WikiText103 datasets. 
+ + Create language modeling dataset: WikiText103 + Separately returns the train/test/valid set + + Arguments: + root: Directory where the datasets are saved. Default: ".data" + data_select: the returned datasets (Default: ('train', 'test','valid')) + By default, all the three datasets (train, test, valid) are generated. Users + could also choose any one or two of them, for example ('train', 'test'). + If 'train' is not in the tuple, an vocab object should be provided which will + be used to process valid and/or test data. + + Examples: + >>> from torchtext.experimental.datasets.raw import WikiText103 + >>> train_dataset, test_dataset, valid_dataset = WikiText103() + >>> valid_dataset, = WikiText103(data_select='valid') + """ + + return _setup_datasets(*(("WikiText103",) + args), **kwargs) + + +def PennTreebank(*args, **kwargs): + """ Defines PennTreebank datasets. + + Create language modeling dataset: PennTreebank + Separately returns the train/test/valid set + + Arguments: + root: Directory where the datasets are saved. Default: ".data" + data_select: a string or tuple for the returned datasets + (Default: ('train', 'test','valid')) + By default, all the three datasets (train, test, valid) are generated. Users + could also choose any one or two of them, for example ('train', 'test') or + just a string 'train'. If 'train' is not in the tuple or string, a vocab + object should be provided which will be used to process valid and/or test + data. + + Examples: + >>> from torchtext.experimental.datasets.raw import PennTreebank + >>> train_dataset, test_dataset, valid_dataset = PennTreebank() + >>> valid_dataset, = PennTreebank(data_select='valid') + + """ + + return _setup_datasets(*(("PennTreebank",) + args), **kwargs) + + +def WMTNewsCrawl(*args, **kwargs): + """ Defines WMT News Crawl. + + Create language modeling dataset: WMTNewsCrawl + + Arguments: + root: Directory where the datasets are saved. Default: ".data" + data_select: a string or tuple for the returned datasets. + (Default: 'train') + """ + + return _setup_datasets(*(("WMTNewsCrawl",) + args), **kwargs) + + +DATASETS = { + 'WikiText2': WikiText2, + 'WikiText103': WikiText103, + 'PennTreebank': PennTreebank, + 'WMTNewsCrawl': WMTNewsCrawl +}
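
Usage note (not part of the patch above): the sketch below shows how the raw and processed dataset layers introduced in this change are expected to fit together. It is a minimal example, assuming the patch is applied as-is, that downloads go to the default `.data` directory, and using only entry points that appear in the diff (`torchtext.experimental.datasets.raw.WikiText2`, `torchtext.experimental.datasets.WikiText2`, and `get_tokenizer`); the `RawWikiText2` alias is just a local rename for clarity.

from torchtext.data.utils import get_tokenizer
from torchtext.experimental.datasets.raw import WikiText2 as RawWikiText2
from torchtext.experimental.datasets import WikiText2

tokenizer = get_tokenizer("basic_english")

# Raw layer: an iterable dataset yielding untokenized lines of text,
# useful for building a vocabulary or writing custom preprocessing.
raw_train, = RawWikiText2(data_select='train')
first_three_lines = [line for _, line in zip(range(3), raw_train)]

# Processed layer: tokenizes, numericalizes against a vocab built from the
# train split, and (with the default single_line=True) concatenates everything
# into one flat 1-D LongTensor of token ids.
train, test, valid = WikiText2(tokenizer=tokenizer)
print(len(train))  # total number of tokens in the training split
print(train[0])    # id of the first token

The split mirrors the design of this patch: the raw datasets only handle download/extraction and expose a line iterator, while tokenization, vocab lookup and tensor conversion live in the non-raw wrappers via sequential_transforms, which is what lets WMTNewsCrawl reuse the same pipeline with a train-only split.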