89 changes: 44 additions & 45 deletions test/test_functional.py
@@ -4,87 +4,86 @@


class TestFunctional(TorchtextTestCase):
def test_to_tensor(self):
def _to_tensor(self, test_scripting):
input = [[1, 2], [1, 2, 3]]
padding_value = 0

actual = functional.to_tensor(input, padding_value=padding_value)
func = functional.to_tensor
if test_scripting:
func = torch.jit.script(func)
actual = func(input, padding_value=padding_value)
expected = torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.long)
torch.testing.assert_close(actual, expected)

input = [1, 2]
actual = functional.to_tensor(input, padding_value=padding_value)
actual = func(input, padding_value=padding_value)
expected = torch.tensor([1, 2], dtype=torch.long)
torch.testing.assert_close(actual, expected)

def test_to_tensor_jit(self):
input = [[1, 2], [1, 2, 3]]
padding_value = 0
to_tensor_jit = torch.jit.script(functional.to_tensor)
actual = to_tensor_jit(input, padding_value=padding_value)
expected = torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.long)
torch.testing.assert_close(actual, expected)
def test_to_tensor(self):
"""test tensorization on both single sequence and batch of sequence"""
self._to_tensor(test_scripting=False)

input = [1, 2]
actual = to_tensor_jit(input, padding_value=padding_value)
expected = torch.tensor([1, 2], dtype=torch.long)
torch.testing.assert_close(actual, expected)
def test_to_tensor_jit(self):
"""test tensorization with scripting on both single sequence and batch of sequence"""
self._to_tensor(test_scripting=True)

def test_truncate(self):
input = [[1, 2], [1, 2, 3]]
def _truncate(self, test_scripting):
max_seq_len = 2
func = functional.truncate
if test_scripting:
func = torch.jit.script(func)

actual = functional.truncate(input, max_seq_len=max_seq_len)
input = [[1, 2], [1, 2, 3]]
actual = func(input, max_seq_len=max_seq_len)
expected = [[1, 2], [1, 2]]
self.assertEqual(actual, expected)

input = [1, 2, 3]
actual = functional.truncate(input, max_seq_len=max_seq_len)
actual = func(input, max_seq_len=max_seq_len)
expected = [1, 2]
self.assertEqual(actual, expected)

def test_truncate_jit(self):
input = [[1, 2], [1, 2, 3]]
max_seq_len = 2
truncate_jit = torch.jit.script(functional.truncate)
actual = truncate_jit(input, max_seq_len=max_seq_len)
expected = [[1, 2], [1, 2]]
input = [["a", "b"], ["a", "b", "c"]]
actual = func(input, max_seq_len=max_seq_len)
expected = [["a", "b"], ["a", "b"]]
self.assertEqual(actual, expected)

input = [1, 2, 3]
actual = truncate_jit(input, max_seq_len=max_seq_len)
expected = [1, 2]
input = ["a", "b", "c"]
actual = func(input, max_seq_len=max_seq_len)
expected = ["a", "b"]
self.assertEqual(actual, expected)

def test_add_token(self):
input = [[1, 2], [1, 2, 3]]
token_id = 0
actual = functional.add_token(input, token_id=token_id)
expected = [[0, 1, 2], [0, 1, 2, 3]]
self.assertEqual(actual, expected)
def test_truncate(self):
"""test truncation on both sequence and batch of sequence with both str and int types"""
self._truncate(test_scripting=False)

actual = functional.add_token(input, token_id=token_id, begin=False)
expected = [[1, 2, 0], [1, 2, 3, 0]]
self.assertEqual(actual, expected)
def test_truncate_jit(self):
"""test truncation with scripting on both sequence and batch of sequence with both str and int types"""
self._truncate(test_scripting=True)

input = [1, 2]
actual = functional.add_token(input, token_id=token_id, begin=False)
expected = [1, 2, 0]
self.assertEqual(actual, expected)
def _add_token(self, test_scripting):

def test_add_token_jit(self):
func = functional.add_token
if test_scripting:
func = torch.jit.script(func)
input = [[1, 2], [1, 2, 3]]
token_id = 0
add_token_jit = torch.jit.script(functional.add_token)
actual = add_token_jit(input, token_id=token_id)
actual = func(input, token_id=token_id)
expected = [[0, 1, 2], [0, 1, 2, 3]]
self.assertEqual(actual, expected)

actual = add_token_jit(input, token_id=token_id, begin=False)
actual = func(input, token_id=token_id, begin=False)
expected = [[1, 2, 0], [1, 2, 3, 0]]
self.assertEqual(actual, expected)

input = [1, 2]
actual = add_token_jit(input, token_id=token_id, begin=False)
actual = func(input, token_id=token_id, begin=False)
expected = [1, 2, 0]
self.assertEqual(actual, expected)

def test_add_token(self):
self._add_token(test_scripting=False)

def test_add_token_jit(self):
self._add_token(test_scripting=True)
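The tests above all share one parameterization: a private helper takes a test_scripting flag, and the eager and JIT test methods simply delegate to it. As a minimal, self-contained sketch of that pattern (assuming torch and torchtext are importable; check_truncate is an illustrative name, not part of this PR):

import torch
from torchtext import functional

def check_truncate(test_scripting: bool) -> None:
    # Mirror of the refactored pattern: pick the eager function or its
    # scripted counterpart, then run identical assertions against both.
    func = functional.truncate
    if test_scripting:
        func = torch.jit.script(func)
    assert func([[1, 2], [1, 2, 3]], max_seq_len=2) == [[1, 2], [1, 2]]
    assert func(["a", "b", "c"], max_seq_len=2) == ["a", "b"]

check_truncate(test_scripting=False)  # eager path
check_truncate(test_scripting=True)   # TorchScript path

Scripting inside the helper keeps both execution paths covered by the same assertions, so eager and JIT behavior cannot silently diverge.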
91 changes: 41 additions & 50 deletions test/test_transforms.py
@@ -8,10 +8,12 @@


class TestTransforms(TorchtextTestCase):
def test_spmtokenizer(self):
def _spmtokenizer(self, test_scripting):
asset_name = "spm_example.model"
asset_path = get_asset_path(asset_name)
transform = transforms.SentencePieceTokenizer(asset_path)
if test_scripting:
transform = torch.jit.script(transform)

actual = transform(["Hello World!, how are you?"])
expected = [['▁Hello', '▁World', '!', ',', '▁how', '▁are', '▁you', '?']]
@@ -21,24 +23,19 @@ def test_spmtokenizer(self):
expected = ['▁Hello', '▁World', '!', ',', '▁how', '▁are', '▁you', '?']
self.assertEqual(actual, expected)

def test_spmtokenizer_jit(self):
asset_name = "spm_example.model"
asset_path = get_asset_path(asset_name)
transform = transforms.SentencePieceTokenizer(asset_path)
transform_jit = torch.jit.script(transform)

actual = transform_jit(["Hello World!, how are you?"])
expected = [['▁Hello', '▁World', '!', ',', '▁how', '▁are', '▁you', '?']]
self.assertEqual(actual, expected)
def test_spmtokenizer(self):
"""test tokenization on single sentence input as well as batch on sentences"""
self._spmtokenizer(test_scripting=False)

actual = transform_jit("Hello World!, how are you?")
expected = ['▁Hello', '▁World', '!', ',', '▁how', '▁are', '▁you', '?']
self.assertEqual(actual, expected)
def test_spmtokenizer_jit(self):
"""test tokenization with scripting on single sentence input as well as batch on sentences"""
self._spmtokenizer(test_scripting=True)

def test_vocab_transform(self):
def _vocab_transform(self, test_scripting):
vocab_obj = vocab(OrderedDict([('a', 1), ('b', 1), ('c', 1)]))
transform = transforms.VocabTransform(vocab_obj)

if test_scripting:
transform = torch.jit.script(transform)
actual = transform([['a', 'b', 'c']])
expected = [[0, 1, 2]]
self.assertEqual(actual, expected)
@@ -47,22 +44,20 @@ def test_vocab_transform(self):
expected = [0, 1, 2]
self.assertEqual(actual, expected)

def test_vocab_transform_jit(self):
vocab_obj = vocab(OrderedDict([('a', 1), ('b', 1), ('c', 1)]))
transform_jit = torch.jit.script(transforms.VocabTransform(vocab_obj))

actual = transform_jit([['a', 'b', 'c']])
expected = [[0, 1, 2]]
self.assertEqual(actual, expected)
def test_vocab_transform(self):
"""test token to indices on both sequence of input tokens as well as batch of sequence"""
self._vocab_transform(test_scripting=False)

actual = transform_jit(['a', 'b', 'c'])
expected = [0, 1, 2]
self.assertEqual(actual, expected)
def test_vocab_transform_jit(self):
"""test token to indices with scripting on both sequence of input tokens as well as batch of sequence"""
self._vocab_transform(test_scripting=True)

def test_totensor(self):
input = [[1, 2], [1, 2, 3]]
def _totensor(self, test_scripting):
padding_value = 0
transform = transforms.ToTensor(padding_value=padding_value)
if test_scripting:
transform = torch.jit.script(transform)
input = [[1, 2], [1, 2, 3]]

actual = transform(input)
expected = torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.long)
@@ -73,29 +68,26 @@ def test_totensor(self):
expected = torch.tensor([1, 2], dtype=torch.long)
torch.testing.assert_close(actual, expected)

def test_totensor_jit(self):
input = [[1, 2], [1, 2, 3]]
padding_value = 0
transform = transforms.ToTensor(padding_value=padding_value)
transform_jit = torch.jit.script(transform)

actual = transform_jit(input)
expected = torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.long)
torch.testing.assert_close(actual, expected)
def test_totensor(self):
"""test tensorization on both single sequence and batch of sequence"""
self._totensor(test_scripting=False)

input = [1, 2]
actual = transform_jit(input)
expected = torch.tensor([1, 2], dtype=torch.long)
torch.testing.assert_close(actual, expected)
def test_totensor_jit(self):
"""test tensorization with scripting on both single sequence and batch of sequence"""
self._totensor(test_scripting=True)

def test_labeltoindex(self):
def _labeltoindex(self, test_scripting):
label_names = ['test', 'label', 'indices']
transform = transforms.LabelToIndex(label_names=label_names)
if test_scripting:
transform = torch.jit.script(transform)
actual = transform(label_names)
expected = [0, 1, 2]
self.assertEqual(actual, expected)

transform = transforms.LabelToIndex(label_names=label_names, sort_names=True)
if test_scripting:
transform = torch.jit.script(transform)
actual = transform(label_names)
expected = [2, 1, 0]
self.assertEqual(actual, expected)
@@ -107,17 +99,16 @@ def test_labeltoindex(self):
asset_name = "label_names.txt"
asset_path = get_asset_path(asset_name)
transform = transforms.LabelToIndex(label_path=asset_path)
if test_scripting:
transform = torch.jit.script(transform)
actual = transform(label_names)
expected = [0, 1, 2]
self.assertEqual(actual, expected)

def test_labeltoindex_jit(self):
label_names = ['test', 'label', 'indices']
transform_jit = torch.jit.script(transforms.LabelToIndex(label_names=label_names))
actual = transform_jit(label_names)
expected = [0, 1, 2]
self.assertEqual(actual, expected)
def test_labeltoindex(self):
"""test labe to ids on single label input as well as batch of labels"""
self._labeltoindex(test_scripting=False)

actual = transform_jit("test")
expected = 0
self.assertEqual(actual, expected)
def test_labeltoindex_jit(self):
"""test labe to ids with scripting on single label input as well as batch of labels"""
self._labeltoindex(test_scripting=True)
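The transform tests follow the same flag-driven helper, except the unit under script is the transform module rather than a free function. A hedged sketch of that flow (LabelToIndex, the label names, the expected indices, and the single-label call all come from the diff above; the standalone assertions are illustrative):

import torch
from torchtext import transforms

label_names = ["test", "label", "indices"]
transform = transforms.LabelToIndex(label_names=label_names)
# Scripting the transform module should preserve its eager behavior.
scripted = torch.jit.script(transform)
assert transform(label_names) == [0, 1, 2]
assert scripted(label_names) == [0, 1, 2]
assert scripted("test") == 0  # a single label is also accepted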
76 changes: 66 additions & 10 deletions torchtext/functional.py
@@ -1,7 +1,7 @@
import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from typing import List, Optional, Union
from typing import List, Optional, Any

__all__ = [
'to_tensor',
@@ -10,10 +10,20 @@
]


def to_tensor(input: Union[List[int], List[List[int]]], padding_value: Optional[int] = None, dtype: Optional[torch.dtype] = torch.long) -> Tensor:
def to_tensor(input: Any, padding_value: Optional[int] = None, dtype: Optional[torch.dtype] = torch.long) -> Tensor:
r"""Convert input to torch tensor

:param padding_value: Pad value used to make each input in the batch equal in length to the longest sequence in the batch.
:type padding_value: Optional[int]
:param dtype: :class:`torch.dtype` of output tensor
:type dtype: :class:`torch.dtype`
:param input: Sequence or batch of token ids
:type input: Union[List[int], List[List[int]]]
:rtype: Tensor
"""
if torch.jit.isinstance(input, List[int]):
return torch.tensor(input, dtype=dtype)
else:
elif torch.jit.isinstance(input, List[List[int]]):
if padding_value is None:
output = torch.tensor(input, dtype=dtype)
return output
@@ -24,27 +34,61 @@ def to_tensor(input: Union[List[int], List[List[int]]], padding_value: Optional[
padding_value=float(padding_value)
)
return output
else:
raise TypeError("Input type not supported")


def truncate(input: Any, max_seq_len: int) -> Any:
""" Truncate input sequence or batch

def truncate(input: Union[List[int], List[List[int]]], max_seq_len: int) -> Union[List[int], List[List[int]]]:
:param input: Input sequence or batch to be truncated
:type input: Union[List[Union[str, int]], List[List[Union[str, int]]]]
:param max_seq_len: Maximum length beyond which input is discarded
:type max_seq_len: int
:return: Truncated sequence or batch
:rtype: Union[List[Union[str, int]], List[List[Union[str, int]]]]
"""
if torch.jit.isinstance(input, List[int]):
return input[:max_seq_len]
else:
elif torch.jit.isinstance(input, List[str]):
return input[:max_seq_len]
elif torch.jit.isinstance(input, List[List[int]]):
output: List[List[int]] = []

for ids in input:
output.append(ids[:max_seq_len])

return output
elif torch.jit.isinstance(input, List[List[str]]):
output: List[List[str]] = []
for ids in input:
output.append(ids[:max_seq_len])
return output
else:
raise TypeError("Input type not supported")


def add_token(input: Union[List[int], List[List[int]]], token_id: int, begin: bool = True) -> Union[List[int], List[List[int]]]:
if torch.jit.isinstance(input, List[int]):
def add_token(input: Any, token_id: Any, begin: bool = True) -> Any:
"""Add token to start or end of sequence

:param input: Input sequence or batch
:type input: Union[List[Union[str, int]], List[List[Union[str, int]]]]
:param token_id: Token to be added
:type token_id: Union[str, int]
:param begin: Whether to insert token at start or end of sequence, defaults to True
:type begin: bool, optional
:return: Sequence or batch with token_id added to begin or end of input
:rtype: Union[List[Union[str, int]], List[List[Union[str, int]]]]
"""
if torch.jit.isinstance(input, List[int]) and torch.jit.isinstance(token_id, int):
if begin:
return [token_id] + input
else:
return input + [token_id]
else:
elif torch.jit.isinstance(input, List[str]) and torch.jit.isinstance(token_id, str):
if begin:
return [token_id] + input
else:
return input + [token_id]
elif torch.jit.isinstance(input, List[List[int]]) and torch.jit.isinstance(token_id, int):
output: List[List[int]] = []

if begin:
Expand All @@ -55,3 +99,15 @@ def add_token(input: Union[List[int], List[List[int]]], token_id: int, begin: bo
output.append(ids + [token_id])

return output
elif torch.jit.isinstance(input, List[List[str]]) and torch.jit.isinstance(token_id, str):
output: List[List[str]] = []
if begin:
for ids in input:
output.append([token_id] + ids)
else:
for ids in input:
output.append(ids + [token_id])

return output
else:
raise TypeError("Input type not supported")
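The substantive change in functional.py is loosening the annotations from Union[...] to Any and dispatching on torch.jit.isinstance, which lets a single scripted function accept both int and str sequences (and their batched forms). A minimal standalone sketch of that dispatch pattern (first_token is a hypothetical name, not torchtext API; torch.jit.isinstance can check parameterized container types such as List[int], which plain isinstance cannot under TorchScript):

from typing import Any, List

import torch

def first_token(input: Any) -> Any:
    # torch.jit.isinstance refines the otherwise-opaque Any annotation
    # into a concrete container type inside each branch.
    if torch.jit.isinstance(input, List[int]):
        return input[0]
    elif torch.jit.isinstance(input, List[str]):
        return input[0]
    else:
        raise TypeError("Input type not supported")

scripted = torch.jit.script(first_token)
assert scripted([1, 2, 3]) == 1
assert scripted(["a", "b"]) == "a"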