This repository was archived by the owner on Sep 10, 2025. It is now read-only.
T5Transform: text pre-processing for the T5 model #1852
Merged
Changes from all commits

Commits (11)
- 74b6fd3 text pre-processing for t5 model (pmabbo13)
- 2656071 save tokenizer model in asset, to be used during testing (pmabbo13)
- eb05721 moving t5transform and tests under prototype/models (pmabbo13)
- f9a59fa instantiate pipeline when initializing transform (pmabbo13)
- 1bb5302 add testing for decode method (pmabbo13)
- 94b50eb adding docstrings (pmabbo13)
- adf166f script encode method (pmabbo13)
- 75a48d2 coalesce encode and decode tests (pmabbo13)
- a136eea updating docstrings (pmabbo13)
- 1f29d6b type annotations (pmabbo13)
- 270052e Merge branch 'main' into feature/t5-transform (pmabbo13)
New binary asset t5_tokenizer_base.model (binary file not shown).
New test file (@@ -0,0 +1,45 @@):

```python
import torch
from test.common.assets import get_asset_path
from test.common.torchtext_test_case import TorchtextTestCase
from torchtext.prototype.models import T5Transform


class TestTransforms(TorchtextTestCase):
    def _t5tokenizer(self, test_scripting):
        asset_name = "t5_tokenizer_base.model"
        asset_path = get_asset_path(asset_name)
        transform = T5Transform(asset_path, max_seq_len=512, eos_idx=1, padding_idx=0)
        if test_scripting:
            transform = torch.jit.script(transform)

        # test encode; input is a single string
        encode_seq = "Hello World!, how are you?"
        actual = transform(encode_seq)
        expected = torch.tensor([8774, 1150, 55, 6, 149, 33, 25, 58, 1])
        self.assertEqual(actual, expected)

        # test encode; input is a batched string
        encode_seq = ["Hello World!, how are you?"]
        actual = transform(encode_seq)
        expected = torch.tensor([[8774, 1150, 55, 6, 149, 33, 25, 58, 1]])
        self.assertEqual(actual, expected)

        # test decode; input is a list of token ids
        decode_seq = [8774, 1150, 55, 6, 149, 33, 25, 58, 1]
        actual = transform.decode(decode_seq)
        expected = "Hello World!, how are you?"
        self.assertEqual(actual, expected)

        # test decode; input is a batched list of token ids
        decode_seq = [[8774, 1150, 55, 6, 149, 33, 25, 58, 1]]
        actual = transform.decode(decode_seq)
        expected = ["Hello World!, how are you?"]
        self.assertEqual(actual, expected)

    def test_t5tokenizer(self):
        """test tokenization on string input (encode) and translation from token ids to strings (decode)"""
        self._t5tokenizer(test_scripting=False)

    def test_t5tokenizer_jit(self):
        """test tokenization on string input (encode) and translation from token ids to strings (decode) with scripting"""
        self._t5tokenizer(test_scripting=True)
```
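The expected ids in these tests come from the T5 sentencepiece vocabulary shipped in the new t5_tokenizer_base.model asset. For readers who want to poke at the transform outside the test harness, a round-trip sanity check might look like the sketch below; it is not part of the PR, and the asset path is an assumption to be adjusted to wherever the asset lives in your checkout.

```python
# A minimal round-trip sketch (not part of the PR). The asset path is an
# assumption; point it at the local copy of t5_tokenizer_base.model.
from torchtext.prototype.models import T5Transform

transform = T5Transform("t5_tokenizer_base.model", max_seq_len=512, eos_idx=1, padding_idx=0)

ids = transform.encode("Hello World!, how are you?")  # raw token ids, no EOS yet
assert transform.decode(ids) == "Hello World!, how are you?"

# forward() additionally truncates, appends the EOS id (1), and pads
tensor = transform("Hello World!, how are you?")
assert tensor[-1].item() == 1  # last token is EOS
```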
New transform file (@@ -0,0 +1,85 @@):

```python
from typing import List, Union

import torch
import torch.nn as nn
import torchtext.transforms as T
from torchtext.data.functional import load_sp_model
from torchtext.functional import to_tensor
from torchtext.utils import get_asset_local_path


class T5Transform(nn.Module):
    """
    This transform makes use of a pre-trained sentencepiece model to tokenize text input. The resulting output is fed to the T5 model.

    Additional details: https://github.com/google/sentencepiece

    :param sp_model_path: Path to pre-trained sentencepiece model
    :type sp_model_path: str
    :param max_seq_len: Maximum sequence length accepted for inputs to T5 model
    :type max_seq_len: int
    :param eos_idx: End-of-sequence token id
    :type eos_idx: int
    :param padding_idx: Padding token id
    :type padding_idx: int

    Example
        >>> from torchtext.prototype.models import T5Transform
        >>> transform = T5Transform("spm_model", max_seq_len=10, eos_idx=1, padding_idx=0)
        >>> transform(["hello world", "attention is all you need!"])
    """

    def __init__(self, sp_model_path: str, max_seq_len: int, eos_idx: int, padding_idx: int):
        super().__init__()
        self.sp_model = load_sp_model(get_asset_local_path(sp_model_path))
        self.max_seq_len = max_seq_len
        self.eos_idx = eos_idx
        self.padding_idx = padding_idx
        # Truncate to max_seq_len, then append the end-of-sequence token
        self.pipeline = T.Sequential(T.Truncate(self.max_seq_len), T.AddToken(token=self.eos_idx, begin=False))

    def forward(self, input: Union[str, List[str]]) -> torch.Tensor:
        """
        :param input: Input sentence or list of sentences to tokenize.
        :type input: Union[str, List[str]]
        :return: Tokenized text that has been truncated, appended with end-of-sequence token, and padded
        :rtype: torch.Tensor
        """
        tokens = self.encode(input)
        out = to_tensor(self.pipeline(tokens), padding_value=self.padding_idx)
        return out

    @torch.jit.export
    def encode(self, input: Union[str, List[str]]) -> Union[List[int], List[List[int]]]:
        """
        :param input: Input sentence or list of sentences to tokenize.
        :type input: Union[str, List[str]]
        :return: Tokenized text that has been translated to token ids
        :rtype: Union[List[int], List[List[int]]]
        """
        if torch.jit.isinstance(input, List[str]):
            tokens: List[List[int]] = []
            for text in input:
                tokens.append(self.sp_model.EncodeAsIds(text))
            return tokens
        elif torch.jit.isinstance(input, str):
            return self.sp_model.EncodeAsIds(input)
        else:
            raise TypeError("Input type not supported")

    @torch.jit.export
    def decode(self, input: Union[List[int], List[List[int]]]) -> Union[str, List[str]]:
        """
        :param input: List of token ids or list of lists of token ids (i.e. batched).
        :type input: Union[List[int], List[List[int]]]
        :return: Sentence or list of sentences that were translated from the input token ids
        :rtype: Union[str, List[str]]
        """
        if torch.jit.isinstance(input, List[List[int]]):
            tokens: List[str] = []
            for ids in input:
                tokens.append(self.sp_model.DecodeIds(ids))
            return tokens
        elif torch.jit.isinstance(input, List[int]):
            return self.sp_model.DecodeIds(input)
        else:
            raise TypeError("Input type not supported")
```
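To make the forward path concrete: encode() produces raw sentencepiece ids, the pipeline truncates each sequence to max_seq_len and appends eos_idx, and to_tensor right-pads the batch with padding_idx. The sketch below walks through a batched call with illustrative ids; they are not real vocabulary entries, and "spm.model" is a placeholder path.

```python
# Illustrative walk-through of forward() on a batch. The ids are made up and
# "spm.model" is a placeholder; only the truncation/EOS/padding behavior is the point.
transform = T5Transform("spm.model", max_seq_len=5, eos_idx=1, padding_idx=0)

# Suppose encode() returns [[10, 11], [20, 21, 22, 23, 24, 25]]. Then:
#   T.Truncate(5)              -> [[10, 11], [20, 21, 22, 23, 24]]
#   T.AddToken(1, begin=False) -> [[10, 11, 1], [20, 21, 22, 23, 24, 1]]
#   to_tensor(padding_value=0) -> tensor([[10, 11, 1, 0, 0, 0],
#                                         [20, 21, 22, 23, 24, 1]])
batch = transform(["a short input", "a much longer input that gets truncated"])
```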
Review discussion

parmeet: Instead of adding a new asset file, we should probably work with existing assets if available. In this case, shall we try working with spm_example.model?

pmabbo13: @Nayef211 and I were actually debating how best to approach this, because if we used spm_example.model then we'd essentially only be testing functional correctness. But since T5Transform is so similar to SentencePieceTokenizer, except that it includes additional transformations specific to T5, we thought it made more sense to tailor the test towards T5 specifically as opposed to a general spm model.

Nayef211: @parmeet if we use the existing spm_example.model, these tests do not add as much value, as we already have specific tests for the SentencePiece tokenizer. As @pmabbo13 mentioned, if we want to test that the output of the T5Transform is equal to that of the T5 transform in HF, then it would make sense to use the spm model specific to T5. Also, the asset is around 700 KB, which is less than some of the existing assets we've checked in. Lmk what you think!
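A parity check along the lines @Nayef211 describes might look like the sketch below. This is an assumption about the wiring, not a test that exists in this PR: it requires the transformers package, assumes the "t5-base" checkpoint, and uses a placeholder asset path.

```python
# Hedged sketch of an HF parity check (not part of this PR). Requires
# `pip install transformers sentencepiece`; the asset path is a placeholder.
from transformers import T5Tokenizer
from torchtext.prototype.models import T5Transform

hf_tok = T5Tokenizer.from_pretrained("t5-base")
tt_transform = T5Transform("t5_tokenizer_base.model", max_seq_len=512, eos_idx=1, padding_idx=0)

text = "Hello World!, how are you?"
hf_ids = hf_tok(text)["input_ids"]    # HF appends </s> (id 1) by default
tt_ids = tt_transform(text).tolist()  # forward() also appends EOS id 1
assert hf_ids == tt_ids
```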
parmeet: I understand the overall sentiment here, and it's a good argument for adding the actual asset file. But then it makes me wonder whether we are really unit-testing the functional correctness of the transform implementation or actually testing the asset file :).

That said, I think we would also need this for integration testing, since we need real output there instead of dummy output from an arbitrary SPM model file. So I think I agree with you both: adding the actual asset file makes sense!