This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
Changes from all commits
Commits
51 commits
45d53de
Move PennTreebank, WikiText103, WikiText2 to torchtext.legacy
Oct 23, 2019
1f95483
Some initial work.
Oct 25, 2019
2d3ebe2
Merge branch 'master' into legacy_language_modeling
Oct 25, 2019
97af9d0
Re-write three datasets.
Oct 29, 2019
544b069
Merge branch 'master' into legacy_language_modeling
Oct 29, 2019
cc127de
Update tests.
Oct 29, 2019
97cfd05
Move legacy docs for language modeling dataset.
Oct 29, 2019
0ac3e18
Update docs.
Oct 29, 2019
56046fa
Minor debug
Oct 31, 2019
9962732
Update test.
Oct 31, 2019
ad7938e
Minor change in tests.
Oct 31, 2019
3ff1cce
Flake8
Oct 31, 2019
361f688
Merge branch 'master' into legacy_language_modeling
Nov 1, 2019
cc1ae4d
Move two functions to data/functional.py.
Nov 5, 2019
f4018cc
Fix <'unk'> compatibility issue.
Nov 5, 2019
ff329f9
Minor changes.
Nov 5, 2019
65c470c
Update unit tests.
Nov 5, 2019
96cd268
Merge branch 'master' into legacy_language_modeling
Nov 11, 2019
25336b9
Minor change
Nov 11, 2019
4819f18
Add flags for train/valid/test.
Nov 18, 2019
48cb0a8
Update docs.
Nov 19, 2019
7d70298
Add returned_dataset flag to determine subset data.
Nov 20, 2019
0588f1d
Fix a small bug.
Nov 20, 2019
f01037d
Remove some printout.
Nov 21, 2019
f2ea3f1
Remove unk token.
Nov 21, 2019
a32712d
Use data_select.
Nov 21, 2019
d217294
Support a string in data_select.
Nov 21, 2019
cb902d4
Use torch.tensor instead of torch.Tensor
Nov 21, 2019
3a05197
remove duplicate code.
Nov 21, 2019
ac99329
Minor change in doc.
Nov 21, 2019
3a342c0
Change the extracted_files.
Nov 21, 2019
149cbc4
Docs.
Nov 21, 2019
6cfe9c9
get_data_path
Nov 21, 2019
297d1cc
Remove <unk> token.
Nov 22, 2019
d548bf6
Replace _data with data.
Nov 22, 2019
e77758e
Change create_data_from_iterator to double iter.
Nov 22, 2019
6d49f40
Add select_to_index.
Nov 22, 2019
1f60293
check subset.
Nov 22, 2019
8bb1cb2
Error if dataset is empty.
Nov 22, 2019
6a50f2a
filter output is iterable.
Nov 25, 2019
a29f4bd
flake8
Nov 25, 2019
9206e63
Add a disclaimer in README.rst
Nov 25, 2019
e2ba8bf
revise create_data_from_iterator
Nov 25, 2019
0993540
Remove a printout.
Nov 25, 2019
81055a0
Remove version num in legacy.
Nov 25, 2019
9dc4752
remove read_text_iterator func
Nov 26, 2019
367a340
Update README.
Nov 26, 2019
b54b883
Update the test case after not using read_text_iterator
Nov 26, 2019
1478d13
rename to numericalize_tokens_from_iterator
Nov 26, 2019
cf7c188
flake8
Nov 26, 2019
03dfc27
minor
Nov 26, 2019
10 changes: 10 additions & 0 deletions README.rst
@@ -129,6 +129,16 @@ Others are planned or a work in progress:

See the ``test`` directory for examples of dataset usage.

Legacy Code
===========

We have retired several datasets and moved them under ``torchtext.legacy``:

* Sentiment analysis: IMDb
* Language modeling: abstract class + WikiText-2, WikiText103, PennTreebank

These datasets have been rewritten following a new pattern introduced in `Release v0.5.0 <https://github.com/pytorch/text/releases>`_.
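
For example, code that imported these datasets from ``torchtext.datasets`` should now import them from ``torchtext.legacy.datasets``, while the rewritten versions return the train/test/valid splits directly (a minimal sketch; see the ``test`` directory for complete usage): ::

    # old-style datasets now live under torchtext.legacy
    from torchtext.legacy.datasets import WikiText2 as LegacyWikiText2

    # new pattern introduced in v0.5.0
    from torchtext.datasets import WikiText2
    train_dataset, test_dataset, valid_dataset = WikiText2()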

Disclaimer on Datasets
======================

69 changes: 69 additions & 0 deletions docs/source/legacy/datasets.rst
@@ -0,0 +1,69 @@
torchtext.legacy.datasets
=========================

.. currentmodule:: torchtext.legacy.datasets

TorchText legacy datasets.

All datasets are subclasses of :class:`torchtext.data.Dataset`, which
inherits from :class:`torch.utils.data.Dataset`, i.e., they have ``split`` and
``iters`` methods implemented.

General use cases are as follows:

Approach 1, ``splits``: ::

# set up fields
TEXT = data.Field(lower=True, include_lengths=True, batch_first=True)
LABEL = data.Field(sequential=False)

# make splits for data
train, test = datasets.IMDB.splits(TEXT, LABEL)

# build the vocabulary
TEXT.build_vocab(train, vectors=GloVe(name='6B', dim=300))
LABEL.build_vocab(train)

# make iterator for splits
train_iter, test_iter = data.BucketIterator.splits(
(train, test), batch_size=3, device=0)

Approach 2, ``iters``: ::

# use default configurations
train_iter, test_iter = datasets.IMDB.iters(batch_size=4)

The following datasets are available:

.. contents:: Datasets
:local:


Language Modeling
^^^^^^^^^^^^^^^^^

Language modeling datasets are subclasses of the ``LanguageModelingDataset`` class.

.. autoclass:: LanguageModelingDataset
:members: __init__


WikiText-2
~~~~~~~~~~

.. autoclass:: WikiText2
:members: splits, iters


WikiText103
~~~~~~~~~~~

.. autoclass:: WikiText103
:members: splits, iters


PennTreebank
~~~~~~~~~~~~

.. autoclass:: PennTreebank
:members: splits, iters
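
A minimal end-to-end sketch of the legacy language modeling workflow (the keyword arguments such as ``batch_size`` and ``bptt_len`` below are illustrative): ::

    # set up a field and load the splits
    TEXT = data.Field(lower=True, batch_first=True)
    train, valid, test = datasets.WikiText2.splits(TEXT)

    # build the vocabulary
    TEXT.build_vocab(train)

    # make BPTT iterators for language modeling
    train_iter, valid_iter, test_iter = data.BPTTIterator.splits(
        (train, valid, test), batch_size=32, bptt_len=30)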
65 changes: 53 additions & 12 deletions test/data/test_builtin_datasets.py
@@ -1,6 +1,6 @@
import os
import shutil
import torchtext.data as data
from torchtext.datasets import WikiText2, PennTreebank
from torchtext.datasets import AG_NEWS

from ..common.test_markers import slow
@@ -10,11 +10,14 @@
def conditional_remove(f):
if os.path.isfile(f):
os.remove(f)
elif os.path.isdir(f):
shutil.rmtree(f)


class TestDataset(TorchtextTestCase):
@slow
def test_wikitext2(self):
def test_wikitext2_legacy(self):
from torchtext.legacy.datasets import WikiText2
# smoke test to ensure wikitext2 works properly
ds = WikiText2
TEXT = data.Field(lower=True, batch_first=True)
@@ -27,12 +30,30 @@ def test_wikitext2(self):
bptt_len=30)

# Delete the dataset after we're done to save disk space on CI
if os.environ.get("TRAVIS") == "true":
datafile = os.path.join(self.project_root, ".data", "wikitext-2")
conditional_remove(datafile)
datafile = os.path.join(self.project_root, ".data", "wikitext-2")
conditional_remove(datafile)

def test_wikitext2(self):
from torchtext.datasets import WikiText2
# smoke test to ensure wikitext2 works properly
train_dataset, test_dataset, valid_dataset = WikiText2()
self.assertEqual(len(train_dataset), 2049990)
self.assertEqual(len(test_dataset), 241859)
self.assertEqual(len(valid_dataset), 214417)

vocab = train_dataset.get_vocab()
tokens_ids = [vocab[token] for token in 'the player characters rest'.split()]
self.assertEqual(tokens_ids, [2, 286, 503, 700])

# Delete the dataset after we're done to save disk space on CI
datafile = os.path.join(self.project_root, ".data", "wikitext-2")
conditional_remove(datafile)
datafile = os.path.join(self.project_root, ".data", "wikitext-2-v1.zip")
conditional_remove(datafile)

@slow
def test_penntreebank(self):
def test_penntreebank_legacy(self):
from torchtext.legacy.datasets import PennTreebank
# smoke test to ensure penn treebank works properly
TEXT = data.Field(lower=True, batch_first=True)
ds = PennTreebank
@@ -45,9 +66,28 @@ def test_penntreebank(self):
bptt_len=30)

# Delete the dataset after we're done to save disk space on CI
if os.environ.get("TRAVIS") == "true":
datafile = os.path.join(self.project_root, ".data", "penn-treebank")
conditional_remove(datafile)
datafile = os.path.join(self.project_root, ".data", "penn-treebank")
conditional_remove(datafile)

def test_penntreebank(self):
from torchtext.datasets import PennTreebank
# smoke test to ensure penn treebank works properly
train_dataset, test_dataset, valid_dataset = PennTreebank()
self.assertEqual(len(train_dataset), 924412)
self.assertEqual(len(test_dataset), 82114)
self.assertEqual(len(valid_dataset), 73339)

vocab = train_dataset.get_vocab()
tokens_ids = [vocab[token] for token in 'the player characters rest'.split()]
self.assertEqual(tokens_ids, [2, 2550, 3344, 1125])

# Delete the dataset after we're done to save disk space on CI
datafile = os.path.join(self.project_root, ".data", 'ptb.train.txt')
conditional_remove(datafile)
datafile = os.path.join(self.project_root, ".data", 'ptb.test.txt')
conditional_remove(datafile)
datafile = os.path.join(self.project_root, ".data", 'ptb.valid.txt')
conditional_remove(datafile)

def test_text_classification(self):
# smoke test to ensure ag_news dataset works properly
@@ -60,6 +100,7 @@ def test_text_classification(self):
self.assertEqual(len(ag_news_test), 7600)

# Delete the dataset after we're done to save disk space on CI
if os.environ.get("TRAVIS") == "true":
datafile = os.path.join(self.project_root, ".data", "AG_NEWS")
conditional_remove(datafile)
datafile = os.path.join(self.project_root, ".data", "ag_news_csv")
conditional_remove(datafile)
datafile = os.path.join(self.project_root, ".data", "ag_news_csv.tar.gz")
conditional_remove(datafile)
4 changes: 3 additions & 1 deletion torchtext/__init__.py
@@ -2,10 +2,12 @@
from . import datasets
from . import utils
from . import vocab
from . import legacy

__version__ = '0.4.0'

__all__ = ['data',
'datasets',
'utils',
'vocab']
'vocab',
'legacy']
6 changes: 4 additions & 2 deletions torchtext/data/__init__.py
@@ -10,7 +10,8 @@
from .functional import generate_sp_model, \
load_sp_model, \
sentencepiece_numericalizer, \
sentencepiece_tokenizer, custom_replace, simple_space_split
sentencepiece_tokenizer, custom_replace, simple_space_split, \
numericalize_tokens_from_iterator

__all__ = ["Batch",
"Dataset", "TabularDataset",
@@ -24,4 +25,5 @@
"get_tokenizer", "interleave_keys",
"generate_sp_model", "load_sp_model",
"sentencepiece_numericalizer", "sentencepiece_tokenizer",
"custom_replace", "simple_space_split"]
"custom_replace", "simple_space_split",
"numericalize_tokens_from_iterator"]
29 changes: 28 additions & 1 deletion torchtext/data/functional.py
@@ -1,7 +1,6 @@
import sentencepiece as spm
import re


__all__ = [
"generate_sp_model", "load_sp_model",
"sentencepiece_numericalizer", "sentencepiece_tokenizer"
@@ -151,3 +150,31 @@ def simple_space_split(iterator):

for line in iterator:
yield line.split()


def numericalize_tokens_from_iterator(vocab, iterator, removed_tokens=None):
r"""Yield a list of ids from an token iterator with a vocab.

Arguments:
vocab: the vocabulary convert token into id.
iterator: the iterator yield a list of tokens.
removed_tokens: removed tokens from output dataset (Default: None)

Examples:
>>> from torchtext.data.functional import simple_space_split
>>> from torchtext.data.functional import numericalize_tokens_from_iterator
>>> vocab = {'Sentencepiece' : 0, 'encode' : 1, 'as' : 2, 'pieces' : 3}
>>> ids_iter = numericalize_tokens_from_iterator(vocab,
>>> simple_space_split(["Sentencepiece as pieces",
>>> "as pieces"]))
>>> for ids in ids_iter:
>>> print([num for num in ids])
>>> [0, 2, 3]
>>> [2, 3]
"""
for tokens in iterator:
Contributor:
One optimization here could be to yield an iterator instead of a list. This way we don't have to materialize the numbers per sentence which could be pretty large (and lists can be very slow).

Contributor Author:
Sure. That's doable. Then, we materialize the token ids outside the function.

Contributor Author:
Done!

if removed_tokens is None:
yield iter(vocab[token] for token in tokens)
else:
yield iter(map(lambda x: vocab[x],
filter(lambda x: x not in removed_tokens, tokens)))
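
Following the review discussion above, each yielded value is itself an iterator, so the token ids are materialized by the caller. A small usage sketch (the vocab and sentences here are hypothetical, for illustration only):

from torchtext.data.functional import simple_space_split, numericalize_tokens_from_iterator

# hypothetical vocab and corpus
vocab = {'the': 0, 'cat': 1, 'sat': 2, '<eos>': 3}
lines = ["the cat sat <eos>", "the cat <eos>"]

ids_iter = numericalize_tokens_from_iterator(
    vocab, simple_space_split(lines), removed_tokens=['<eos>'])
for ids in ids_iter:
    # each `ids` is a lazy iterator of token ids; consume it as needed
    print(list(ids))  # [0, 1, 2] then [0, 1]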