diff --git a/test/asset/t5_tokenizer_base.model b/test/asset/t5_tokenizer_base.model
new file mode 100644
index 0000000000..4e28ff6ebd
Binary files /dev/null and b/test/asset/t5_tokenizer_base.model differ
diff --git a/test/prototype/models/test_transforms.py b/test/prototype/models/test_transforms.py
new file mode 100644
index 0000000000..a74e516fef
--- /dev/null
+++ b/test/prototype/models/test_transforms.py
@@ -0,0 +1,45 @@
+import torch
+from test.common.assets import get_asset_path
+from test.common.torchtext_test_case import TorchtextTestCase
+from torchtext.prototype.models import T5Transform
+
+
+class TestTransforms(TorchtextTestCase):
+    def _t5tokenizer(self, test_scripting):
+        asset_name = "t5_tokenizer_base.model"
+        asset_path = get_asset_path(asset_name)
+        transform = T5Transform(asset_path, max_seq_len=512, eos_idx=1, padding_idx=0)
+        if test_scripting:
+            transform = torch.jit.script(transform)
+
+        # test encode; input is a single string
+        encode_seq = "Hello World!, how are you?"
+        actual = transform(encode_seq)
+        expected = torch.tensor([8774, 1150, 55, 6, 149, 33, 25, 58, 1])
+        self.assertEqual(actual, expected)
+
+        # test encode; input is a batched string
+        encode_seq = ["Hello World!, how are you?"]
+        actual = transform(encode_seq)
+        expected = torch.tensor([[8774, 1150, 55, 6, 149, 33, 25, 58, 1]])
+        self.assertEqual(actual, expected)
+
+        # test decode; input is a list of token ids
+        decode_seq = [8774, 1150, 55, 6, 149, 33, 25, 58, 1]
+        actual = transform.decode(decode_seq)
+        expected = "Hello World!, how are you?"
+        self.assertEqual(actual, expected)
+
+        # test decode; input is a batched list of token ids
+        decode_seq = [[8774, 1150, 55, 6, 149, 33, 25, 58, 1]]
+        actual = transform.decode(decode_seq)
+        expected = ["Hello World!, how are you?"]
+        self.assertEqual(actual, expected)
+
+    def test_t5tokenizer(self):
+        """Test tokenization of string input (encode) and translation from token ids back to strings (decode)."""
+        self._t5tokenizer(test_scripting=False)
+
+    def test_t5tokenizer_jit(self):
+        """Test tokenization of string input (encode) and translation from token ids back to strings (decode) with scripting."""
+        self._t5tokenizer(test_scripting=True)
diff --git a/torchtext/prototype/models/t5/__init__.py b/torchtext/prototype/models/t5/__init__.py
index f69829494d..45c0de5e04 100644
--- a/torchtext/prototype/models/t5/__init__.py
+++ b/torchtext/prototype/models/t5/__init__.py
@@ -4,6 +4,7 @@
     T5Bundle,
 )
 from .model import T5Conf, T5Model
+from .t5_transform import T5Transform
 
 __all__ = [
     "T5Conf",
@@ -11,4 +12,5 @@
     "T5Bundle",
     "T5_BASE_ENCODER",
     "T5_BASE",
+    "T5Transform",
 ]
diff --git a/torchtext/prototype/models/t5/t5_transform.py b/torchtext/prototype/models/t5/t5_transform.py
new file mode 100644
index 0000000000..7154e31d55
--- /dev/null
+++ b/torchtext/prototype/models/t5/t5_transform.py
@@ -0,0 +1,85 @@
+from typing import List, Union
+
+import torch
+import torch.nn as nn
+import torchtext.transforms as T
+from torchtext.data.functional import load_sp_model
+from torchtext.functional import to_tensor
+from torchtext.utils import get_asset_local_path
+
+
+class T5Transform(nn.Module):
+    """
+    This transform uses a pre-trained sentencepiece model to tokenize text input. The resulting output is fed to the T5 model.
+
+    Additional details: https://github.com/google/sentencepiece
+
+    :param sp_model_path: Path to pre-trained sentencepiece model
+    :type sp_model_path: str
+    :param max_seq_len: Maximum sequence length accepted for inputs to the T5 model
+    :type max_seq_len: int
+    :param eos_idx: End-of-sequence token id
+    :type eos_idx: int
+    :param padding_idx: Padding token id
+    :type padding_idx: int
+
+    Example
+        >>> from torchtext.prototype.models import T5Transform
+        >>> transform = T5Transform("spm_model", max_seq_len=10, eos_idx=1, padding_idx=0)
+        >>> transform(["hello world", "attention is all you need!"])
+    """
+
+    def __init__(self, sp_model_path: str, max_seq_len: int, eos_idx: int, padding_idx: int):
+        super().__init__()
+        self.sp_model = load_sp_model(get_asset_local_path(sp_model_path))
+        self.max_seq_len = max_seq_len
+        self.eos_idx = eos_idx
+        self.padding_idx = padding_idx
+        self.pipeline = T.Sequential(T.Truncate(self.max_seq_len), T.AddToken(token=self.eos_idx, begin=False))
+
+    def forward(self, input: Union[str, List[str]]) -> torch.Tensor:
+        """
+        :param input: Input sentence or list of sentences to tokenize.
+        :type input: Union[str, List[str]]
+        :return: Tokenized text that has been truncated, appended with the end-of-sequence token, and padded
+        :rtype: torch.Tensor
+        """
+        tokens = self.encode(input)
+        out = to_tensor(self.pipeline(tokens), padding_value=self.padding_idx)
+        return out
+
+    @torch.jit.export
+    def encode(self, input: Union[str, List[str]]) -> Union[List[int], List[List[int]]]:
+        """
+        :param input: Input sentence or list of sentences to tokenize.
+        :type input: Union[str, List[str]]
+        :return: Tokenized text translated to token ids
+        :rtype: Union[List[int], List[List[int]]]
+        """
+        if torch.jit.isinstance(input, List[str]):
+            tokens: List[List[int]] = []
+            for text in input:
+                tokens.append(self.sp_model.EncodeAsIds(text))
+            return tokens
+        elif torch.jit.isinstance(input, str):
+            return self.sp_model.EncodeAsIds(input)
+        else:
+            raise TypeError("Input type not supported")
+
+    @torch.jit.export
+    def decode(self, input: Union[List[int], List[List[int]]]) -> Union[str, List[str]]:
+        """
+        :param input: List of token ids or list of lists of token ids (i.e. batched).
+        :type input: Union[List[int], List[List[int]]]
+        :return: Sentence or list of sentences translated from the input token ids
+        :rtype: Union[str, List[str]]
+        """
+        if torch.jit.isinstance(input, List[List[int]]):
+            tokens: List[str] = []
+            for ids in input:
+                tokens.append(self.sp_model.DecodeIds(ids))
+            return tokens
+        elif torch.jit.isinstance(input, List[int]):
+            return self.sp_model.DecodeIds(input)
+        else:
+            raise TypeError("Input type not supported")
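Usage sketch (illustrative, not part of the diff): a minimal end-to-end walkthrough of the new transform, with the token ids matching the expected values asserted in test_transforms.py above. The asset path assumes the snippet is run from the repository root; any T5-compatible sentencepiece model file would work in its place.

    import torch
    from torchtext.prototype.models import T5Transform

    # Build the transform from the sentencepiece model added under test/asset/.
    transform = T5Transform(
        "test/asset/t5_tokenizer_base.model", max_seq_len=512, eos_idx=1, padding_idx=0
    )

    # forward() encodes, truncates to max_seq_len, appends the EOS token (id 1),
    # and pads with the padding id (0); batched input yields a 2-D tensor.
    ids = transform(["Hello World!, how are you?"])
    print(ids)  # tensor([[8774, 1150, 55, 6, 149, 33, 25, 58, 1]])

    # decode() maps token ids (or a batch of them) back to text.
    print(transform.decode(ids.tolist()))  # ['Hello World!, how are you?']

    # The transform is TorchScript-compatible, as exercised by test_t5tokenizer_jit.
    scripted = torch.jit.script(transform)
    assert torch.equal(scripted(["Hello World!, how are you?"]), ids)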