Binary file added test/asset/t5_tokenizer_base.model
45 changes: 45 additions & 0 deletions test/prototype/models/test_transforms.py
@@ -0,0 +1,45 @@
import torch
from test.common.assets import get_asset_path
from test.common.torchtext_test_case import TorchtextTestCase
from torchtext.prototype.models import T5Transform


class TestTransforms(TorchtextTestCase):
    def _t5tokenizer(self, test_scripting):
        asset_name = "t5_tokenizer_base.model"
Contributor
Instead of adding a new asset file, we should probably work with existing assets where available. In this case, shall we try working with spm_example.model?

Contributor Author

@Nayef211 and I were actually debating how best to approach this, because if we used spm_example.model then we'd essentially just be testing functional correctness. But since T5Transform is so similar to SentencePieceTokenizer, except that it includes additional transformations specific to T5, we thought it made more sense to tailor the test toward T5 specifically rather than a general SPM model.

Contributor

@parmeet if we use the existing spm_example.model, these tests would not add much value, since we already have specific tests for the SentencePiece tokenizer. As @pmabbo13 mentioned, if we want to test that the output of the T5Transform is equal to that of the corresponding T5 tokenizer in HF, then it makes sense to use the SPM model specific to T5. Also, the asset is around 700 KB, which is smaller than some of the existing assets we've checked in. Lmk what you think!

Contributor

I understand the overall sentiment here, and it's a good argument for adding the actual asset file. But then this makes me wonder whether we are really unit-testing the functional correctness of the transform implementation or actually testing the asset file :).

That said, I think we would also need this for integration testing, since we need real outputs there rather than dummy outputs from an arbitrary SPM model file. So I think I agree with you both; adding the actual asset file would make sense!
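For reference, the parity check being discussed might look roughly like the sketch below (not part of this PR; it assumes the HF transformers package is installed and that "t5-base" shares the SPM vocab in t5_tokenizer_base.model):

# Hedged parity sketch against the HF T5 tokenizer
from transformers import T5Tokenizer
from torchtext.prototype.models import T5Transform

hf_tok = T5Tokenizer.from_pretrained("t5-base")
tt_transform = T5Transform("t5_tokenizer_base.model", max_seq_len=512, eos_idx=1, padding_idx=0)

text = "Hello World!, how are you?"
# both append the </s> (eos) id by default, so the id sequences should match exactly
assert tt_transform(text).tolist() == hf_tok(text)["input_ids"]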

        asset_path = get_asset_path(asset_name)
        transform = T5Transform(asset_path, max_seq_len=512, eos_idx=1, padding_idx=0)
        if test_scripting:
            transform = torch.jit.script(transform)

        # test encode; input is a single string
        encode_seq = "Hello World!, how are you?"
        actual = transform(encode_seq)
        expected = torch.tensor([8774, 1150, 55, 6, 149, 33, 25, 58, 1])
        self.assertEqual(actual, expected)

        # test encode; input is a batched string
        encode_seq = ["Hello World!, how are you?"]
        actual = transform(encode_seq)
        expected = torch.tensor([[8774, 1150, 55, 6, 149, 33, 25, 58, 1]])
        self.assertEqual(actual, expected)

        # test decode; input is a list of token ids
        decode_seq = [8774, 1150, 55, 6, 149, 33, 25, 58, 1]
        actual = transform.decode(decode_seq)
        expected = "Hello World!, how are you?"
        self.assertEqual(actual, expected)

        # test decode; input is a batched list of token ids
        decode_seq = [[8774, 1150, 55, 6, 149, 33, 25, 58, 1]]
        actual = transform.decode(decode_seq)
        expected = ["Hello World!, how are you?"]
        self.assertEqual(actual, expected)

    def test_t5tokenizer(self):
        """test tokenization on string input (encode) and translation from token ids to strings (decode)"""
        self._t5tokenizer(test_scripting=False)

    def test_t5tokenizer_jit(self):
        """test tokenization on string input (encode) and translation from token ids to strings (decode) with scripting"""
        self._t5tokenizer(test_scripting=True)
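One behavior the tests above leave uncovered is padding across a ragged batch. A minimal sketch of what such a check could look like (an assumption, not part of this PR; only shape and padding are asserted, since the exact ids for the shorter string aren't pinned down here):

# Hedged sketch: rows shorter than the longest sequence are right-padded with padding_idx
transform = T5Transform(get_asset_path("t5_tokenizer_base.model"), max_seq_len=512, eos_idx=1, padding_idx=0)
out = transform(["Hello World!, how are you?", "Hello"])
assert out.shape[0] == 2
assert out[1, -1].item() == 0  # "Hello" yields fewer ids, so its row ends in pad ids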
2 changes: 2 additions & 0 deletions torchtext/prototype/models/t5/__init__.py
@@ -4,11 +4,13 @@
    T5Bundle,
)
from .model import T5Conf, T5Model
from .t5_transform import T5Transform

__all__ = [
    "T5Conf",
    "T5Model",
    "T5Bundle",
    "T5_BASE_ENCODER",
    "T5_BASE",
    "T5Transform",
]
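Since T5Transform is now exported at the package level and is scriptable (as the jit test above exercises), it can also be serialized on its own. A minimal sketch, assuming a local copy of the SPM asset and a scratch output path:

import torch
from torchtext.prototype.models import T5Transform

transform = T5Transform("t5_tokenizer_base.model", max_seq_len=512, eos_idx=1, padding_idx=0)
scripted = torch.jit.script(transform)
scripted.save("t5_transform.pt")  # hypothetical output path
restored = torch.jit.load("t5_transform.pt")
assert restored("hello world").tolist() == transform("hello world").tolist()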
85 changes: 85 additions & 0 deletions torchtext/prototype/models/t5/t5_transform.py
@@ -0,0 +1,85 @@
from typing import List, Union

import torch
import torch.nn as nn
import torchtext.transforms as T
from torchtext.data.functional import load_sp_model
from torchtext.functional import to_tensor
from torchtext.utils import get_asset_local_path


class T5Transform(nn.Module):
"""
This transform makes use of a pre-trained sentencepiece model to tokenize text input. The resulting output is fed to the T5 model.

Additional details: https://github.com/google/sentencepiece

:param sp_model_path: Path to pre-trained sentencepiece model
:type sp_model_path: str
:param max_seq_len: Maximum sequence length accepted for inputs to T5 model
:type max_seq_len: int
:param eos_idx: End-of-sequence token id
:type eos_idx: int
:param padding_idx: Padding token id
:type padding_idx: int

Example
>>> from torchtext.prototype.models import T5Transform
>>> transform = T5Transform("spm_model", max_seq_len = 10, eos_idx = 1, padding_idx = 0)
>>> transform(["hello world", "attention is all you need!"])
"""

    def __init__(self, sp_model_path: str, max_seq_len: int, eos_idx: int, padding_idx: int):
        super().__init__()
        self.sp_model = load_sp_model(get_asset_local_path(sp_model_path))
        self.max_seq_len = max_seq_len
        self.eos_idx = eos_idx
        self.padding_idx = padding_idx
        # applied in forward(): truncate token ids to max_seq_len, then append the eos id
        self.pipeline = T.Sequential(T.Truncate(self.max_seq_len), T.AddToken(token=self.eos_idx, begin=False))

    def forward(self, input: Union[str, List[str]]) -> torch.Tensor:
        """
        :param input: Input sentence or list of sentences to tokenize.
        :type input: Union[str, List[str]]
        :return: Tokenized text that has been truncated, appended with end-of-sequence token, and padded
        :rtype: torch.Tensor
        """
        # tokenize, truncate and append eos via the pipeline, then pad to a rectangular tensor
        tokens = self.encode(input)
        out = to_tensor(self.pipeline(tokens), padding_value=self.padding_idx)
        return out

    @torch.jit.export
    def encode(self, input: Union[str, List[str]]) -> Union[List[int], List[List[int]]]:
        """
        :param input: Input sentence or list of sentences to tokenize.
        :type input: Union[str, List[str]]
        :return: Tokenized text that has been translated to token ids
        :rtype: Union[List[int], List[List[int]]]
        """
        # torch.jit.isinstance performs the runtime type refinement needed under TorchScript
        if torch.jit.isinstance(input, List[str]):
            tokens: List[List[int]] = []
            for text in input:
                tokens.append(self.sp_model.EncodeAsIds(text))
            return tokens
        elif torch.jit.isinstance(input, str):
            return self.sp_model.EncodeAsIds(input)
        else:
            raise TypeError("Input type not supported")

    @torch.jit.export
    def decode(self, input: Union[List[int], List[List[int]]]) -> Union[str, List[str]]:
        """
        :param input: List of token ids or list of lists of token ids (i.e. batched).
        :type input: Union[List[int], List[List[int]]]
        :return: Sentence or list of sentences that were translated from the input token ids
        :rtype: Union[str, List[str]]
        """
        if torch.jit.isinstance(input, List[List[int]]):
            sentences: List[str] = []
            for ids in input:
                sentences.append(self.sp_model.DecodeIds(ids))
            return sentences
        elif torch.jit.isinstance(input, List[int]):
            return self.sp_model.DecodeIds(input)
        else:
            raise TypeError("Input type not supported")
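Putting the pieces together, a round trip through the transform looks like the sketch below (ids taken from the test above; the asset path is assumed to resolve locally):

from torchtext.prototype.models import T5Transform

transform = T5Transform("t5_tokenizer_base.model", max_seq_len=512, eos_idx=1, padding_idx=0)

ids = transform("Hello World!, how are you?")  # tensor([8774, 1150, 55, 6, 149, 33, 25, 58, 1])
text = transform.decode(ids.tolist())          # DecodeIds skips the trailing control (eos) id
assert text == "Hello World!, how are you?"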