Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit e6a9677

Browse files
parmeet authored and facebook-github-bot committed
Import torchtext #1406 1fb2aed
Summary: Import latest from GitHub.
Reviewed By: Nayef211
Differential Revision: D31762288
fbshipit-source-id: f439e04f903d640027660cb969d6d9e00e7ed4a0
1 parent 579c519 commit e6a9677

File tree

15 files changed

+883
-2
lines changed

15 files changed

+883
-2
lines changed

test/asset/xlmr.base.output.pt

24.7 KB
Binary file not shown.

test/asset/xlmr.large.output.pt

32.7 KB
Binary file not shown.

test/models/test_models.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import torchtext
2+
import torch
3+
4+
from ..common.torchtext_test_case import TorchtextTestCase
5+
from ..common.assets import get_asset_path
6+
7+
8+
class TestModels(TorchtextTestCase):
    """Regression tests for the XLM-R bundles exposed by ``torchtext.models``.

    The model tests feed a fixed batch of token ids through a bundle's model
    (eagerly or via ``torch.jit.script``) and compare the result with a
    tensor stored under the test assets.  The transform tests check that the
    bundle's text transform maps ``_TEST_TEXT`` to those same token ids.
    """

    # Sentence used by the transform tests and the ids it is expected to map to.
    # The same ids are the model input for the output-comparison tests.
    _TEST_TEXT = "XLMR base Model Comparison"
    _TEST_TOKEN_IDS = [[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]]

    def _assert_model_output(self, bundle, asset_name, jit):
        """Compare ``bundle``'s model output with the tensor stored in ``asset_name``.

        Args:
            bundle: Model bundle providing ``get_model()``.
            asset_name: File name of the stored expected output, resolved
                with ``get_asset_path``.
            jit: When true, the model is compiled with ``torch.jit.script``
                before being called.
        """
        asset_path = get_asset_path(asset_name)
        model = bundle.get_model()
        model = model.eval()
        if jit:
            model = torch.jit.script(model)
        model_input = torch.tensor(self._TEST_TOKEN_IDS)
        actual = model(model_input)
        expected = torch.load(asset_path)
        torch.testing.assert_close(actual, expected)

    def _assert_transform_output(self, bundle, jit):
        """Check that ``bundle``'s transform maps ``_TEST_TEXT`` to ``_TEST_TOKEN_IDS``.

        Args:
            bundle: Model bundle providing ``transform()``.
            jit: When true, the transform is compiled with
                ``torch.jit.script`` before being called.
        """
        transform = bundle.transform()
        if jit:
            transform = torch.jit.script(transform)
        actual = transform([self._TEST_TEXT])
        expected = self._TEST_TOKEN_IDS
        torch.testing.assert_close(actual, expected)

    def test_xlmr_base_output(self):
        self._assert_model_output(
            torchtext.models.XLMR_BASE_ENCODER, "xlmr.base.output.pt", jit=False
        )

    def test_xlmr_base_jit_output(self):
        self._assert_model_output(
            torchtext.models.XLMR_BASE_ENCODER, "xlmr.base.output.pt", jit=True
        )

    def test_xlmr_large_output(self):
        self._assert_model_output(
            torchtext.models.XLMR_LARGE_ENCODER, "xlmr.large.output.pt", jit=False
        )

    def test_xlmr_large_jit_output(self):
        self._assert_model_output(
            torchtext.models.XLMR_LARGE_ENCODER, "xlmr.large.output.pt", jit=True
        )

    def test_xlmr_transform(self):
        self._assert_transform_output(torchtext.models.XLMR_BASE_ENCODER, jit=False)

    def test_xlmr_transform_jit(self):
        self._assert_transform_output(torchtext.models.XLMR_BASE_ENCODER, jit=True)

test/test_functional.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import torch
2+
from torchtext import functional
3+
from .common.torchtext_test_case import TorchtextTestCase
4+
5+
6+
class TestFunctional(TorchtextTestCase):
    """Tests for ``torchtext.functional`` in eager mode and under TorchScript.

    Every public test delegates to a ``_check_*`` helper that receives the
    callable under test, so the eager function and its ``torch.jit.script``
    compiled counterpart are checked against exactly the same expectations.
    """

    def _check_to_tensor(self, to_tensor_fn):
        """Ragged batch plus a padding value yields a padded ``torch.long`` tensor."""
        batch = [[1, 2], [1, 2, 3]]
        padding_value = 0
        actual = to_tensor_fn(batch, padding_value=padding_value)
        expected = torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.long)
        torch.testing.assert_close(actual, expected)

    def _check_truncate(self, truncate_fn):
        """Sequences longer than ``max_seq_len`` are cut; shorter ones are kept."""
        batch = [[1, 2], [1, 2, 3]]
        max_seq_len = 2
        actual = truncate_fn(batch, max_seq_len=max_seq_len)
        expected = [[1, 2], [1, 2]]
        self.assertEqual(actual, expected)

    def _check_add_token(self, add_token_fn):
        """The token is prepended by default and appended when ``begin=False``."""
        batch = [[1, 2], [1, 2, 3]]
        token_id = 0

        actual = add_token_fn(batch, token_id=token_id)
        expected = [[0, 1, 2], [0, 1, 2, 3]]
        self.assertEqual(actual, expected)

        actual = add_token_fn(batch, token_id=token_id, begin=False)
        expected = [[1, 2, 0], [1, 2, 3, 0]]
        self.assertEqual(actual, expected)

    def test_to_tensor(self):
        self._check_to_tensor(functional.to_tensor)

    def test_to_tensor_jit(self):
        self._check_to_tensor(torch.jit.script(functional.to_tensor))

    def test_truncate(self):
        self._check_truncate(functional.truncate)

    def test_truncate_jit(self):
        self._check_truncate(torch.jit.script(functional.truncate))

    def test_add_token(self):
        self._check_add_token(functional.add_token)

    def test_add_token_jit(self):
        self._check_add_token(torch.jit.script(functional.add_token))

test/test_transforms.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import torch
2+
from torchtext import transforms
3+
from torchtext.vocab import vocab
4+
from collections import OrderedDict
5+
6+
from .common.torchtext_test_case import TorchtextTestCase
7+
from .common.assets import get_asset_path
8+
9+
10+
class TestTransforms(TorchtextTestCase):
    """Tests for the transform classes in ``torchtext.transforms``."""

    def _check_spm_tokenizer(self, jit):
        """Tokenize a fixed sentence with the ``spm_example.model`` asset.

        Args:
            jit: When true, the transform is compiled with
                ``torch.jit.script`` before being called.
        """
        asset_name = "spm_example.model"
        asset_path = get_asset_path(asset_name)
        transform = transforms.SpmTokenizerTransform(asset_path)
        if jit:
            transform = torch.jit.script(transform)
        actual = transform(["Hello World!, how are you?"])
        expected = [['▁Hello', '▁World', '!', ',', '▁how', '▁are', '▁you', '?']]
        self.assertEqual(actual, expected)

    def test_spmtokenizer_transform(self):
        self._check_spm_tokenizer(jit=False)

    def test_spmtokenizer_transform_jit(self):
        self._check_spm_tokenizer(jit=True)

    def test_vocab_transform(self):
        """Tokens are mapped to their indices in a vocab built from an ordered dict."""
        vocab_obj = vocab(OrderedDict([('a', 1), ('b', 1), ('c', 1)]))
        transform = transforms.VocabTransform(vocab_obj)
        actual = transform([['a', 'b', 'c']])
        expected = [[0, 1, 2]]
        self.assertEqual(actual, expected)

torchtext/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1+
import os
2+
_TEXT_BUCKET = 'https://download.pytorch.org/models/text'
3+
_CACHE_DIR = os.path.expanduser('~/.torchtext/cache')
4+
15
from . import data
26
from . import nn
37
from . import datasets
48
from . import utils
59
from . import vocab
10+
from . import transforms
11+
from . import functional
12+
from . import models
613
from . import experimental
714
from . import legacy
815
from ._extension import _init_extension
@@ -18,6 +25,9 @@
1825
'datasets',
1926
'utils',
2027
'vocab',
28+
'transforms',
29+
'functional',
30+
'models',
2131
'experimental',
2232
'legacy']
2333

torchtext/data/datasets_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import defusedxml.ElementTree as ET
1616
except ImportError:
1717
import xml.etree.ElementTree as ET
18+
19+
from torchtext import _CACHE_DIR
1820
"""
1921
These functions and classes are meant solely for use in torchtext.datasets and not
2022
for public consumption yet.
@@ -213,7 +215,7 @@ def _wrap_split_argument_with_fn(fn, splits):
213215
raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn))
214216

215217
@functools.wraps(fn)
216-
def new_fn(root=os.path.expanduser('~/.torchtext/cache'), split=splits, **kwargs):
218+
def new_fn(root=_CACHE_DIR, split=splits, **kwargs):
217219
result = []
218220
for item in _check_default_set(split, splits, fn.__name__):
219221
result.append(fn(root, item, **kwargs))
@@ -250,7 +252,7 @@ def decorator(func):
250252
raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn))
251253

252254
@functools.wraps(func)
253-
def wrapper(root=os.path.expanduser('~/.torchtext/cache'), *args, **kwargs):
255+
def wrapper(root=_CACHE_DIR, *args, **kwargs):
254256
new_root = os.path.join(root, dataset_name)
255257
if not os.path.exists(new_root):
256258
os.makedirs(new_root)

torchtext/functional.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import torch
2+
from torch import Tensor
3+
from torch.nn.utils.rnn import pad_sequence
4+
from typing import List, Optional
5+
6+
# Names exported by ``from torchtext.functional import *``.
__all__ = [
    'to_tensor',
    'truncate',
    'add_token',
]
11+
12+
13+
def to_tensor(input: List[List[int]], padding_value: Optional[int] = None) -> Tensor:
    """Convert a batch of integer-id sequences into a ``torch.long`` tensor.

    Args:
        input: Batch of sequences of integer ids.
        padding_value: Id used to right-pad shorter sequences to the length
            of the longest one.  When ``None`` no padding is performed and
            ``input`` is handed to ``torch.tensor`` unchanged.

    Returns:
        A ``torch.long`` tensor with one row per input sequence.
    """
    if padding_value is None:
        return torch.tensor(input, dtype=torch.long)

    rows = [torch.tensor(ids, dtype=torch.long) for ids in input]
    # ``pad_sequence`` declares ``padding_value`` as a float; the explicit
    # cast keeps the call well-typed when this function is compiled with
    # ``torch.jit.script``.
    return pad_sequence(
        rows,
        batch_first=True,
        padding_value=float(padding_value),
    )
24+
25+
26+
def truncate(input: List[List[int]], max_seq_len: int) -> List[List[int]]:
    """Cut every sequence in a batch down to its first ``max_seq_len`` ids.

    Args:
        input: Batch of sequences of integer ids.
        max_seq_len: Upper bound applied as the slice ``ids[:max_seq_len]``
            to each sequence; sequences that are already short enough are
            returned in full.

    Returns:
        A new outer list containing a new (sliced) list per input sequence;
        ``input`` itself is not modified.
    """
    return [ids[:max_seq_len] for ids in input]
33+
34+
35+
def add_token(input: List[List[int]], token_id: int, begin: bool = True) -> List[List[int]]:
    """Attach ``token_id`` to one end of every sequence in a batch.

    Args:
        input: Batch of sequences of integer ids.
        token_id: Id to insert.
        begin: When true (the default) the id is placed in front of each
            sequence; otherwise it is placed after the last element.

    Returns:
        A new outer list containing a new list per input sequence;
        ``input`` itself is not modified.
    """
    if begin:
        return [[token_id] + ids for ids in input]
    return [ids + [token_id] for ids in input]

torchtext/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .roberta import * # noqa: F401, F403
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from .model import (
2+
RobertaEncoderParams,
3+
RobertaClassificationHead,
4+
)
5+
6+
from .bundler import (
7+
RobertaModelBundle,
8+
XLMR_BASE_ENCODER,
9+
XLMR_LARGE_ENCODER,
10+
)
11+
12+
# Public names re-exported from the ``model`` and ``bundler`` submodules.
__all__ = [
    "RobertaEncoderParams",
    "RobertaClassificationHead",
    "RobertaModelBundle",
    "XLMR_BASE_ENCODER",
    "XLMR_LARGE_ENCODER",
]

0 commit comments

Comments
 (0)