T5Transform text pre-processing for t5 model #1852
Conversation
```python
class TestTransforms(TorchtextTestCase):
    def _t5tokenizer(self, test_scripting):
        asset_name = "t5_tokenizer_base.model"
```
Instead of adding a new asset file, we should probably work with existing assets if available. In this case, shall we try working with spm_example.model?
@Nayef211 and I were actually debating how best to approach this, because if we used spm_example.model then we'd essentially be testing for functional correctness. But since T5Transform is so similar to SentencePieceTokenizer except that it includes additional transformations specific to T5, we thought it made more sense to tailor the test towards t5 specifically as opposed to a general spm model.
@parmeet if we use the existing spm_example.model, these tests do not add as much value as we already have specific tests for the SentencePiece tokenizer. As @pmabbo13 mentioned, if we want to test that the output of the T5Transform is equal to that of the T5Transform in HF, then it would make sense to make use of the spm model specific to T5. Also the asset is around 700 KB which is less than some of the existing assets we've checked in. Lmk what you think!
I understand the overall sentiment here, and it's a good argument for adding the actual asset file. But then this makes me wonder whether we are really unit-testing the functional correctness of the transform implementation or actually testing the asset file :).
That said, I think we would also need this for integration testing, since we need a real output there instead of dummy output from an arbitrary SPM model file. So I think I agree with you both: adding the actual asset file would make sense!
```python
self.padding_idx = padding_idx
self.pipeline = T.Sequential(T.Truncate(self.max_seq_len), T.AddToken(token=self.eos_idx, begin=False))

def forward(self, input: Any) -> Any:
```
Is there a reason we specify the input as `Any` instead of `Union[str, List[str]]`?
I initially had them typed, but noticed that SentencePieceTokenizer had them as Any so deferred to that. I will revert it back.
> Is there a reason we specify the input as `Any` instead of `Union[str, List[str]]`?
I guess @pmabbo13 might have followed what we did in our transform implementations, where we always use `Any` as the input type. The reason for this is to ensure that when transforms are combined in `SequentialTransform`, the overall transform is still scriptable. More details about this issue can be found here. As for `T5Transform`, if we do not expect it to be used in `SequentialTransform` and treat it as a standalone transform, I agree that we could just use the right annotation types as suggested above.
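To make the scriptability point concrete, here is a torch-free sketch of why a permissive `Any` annotation helps a generic sequential container: the container cannot know the intermediate types flowing between stages, so strict per-stage annotations would make a static compiler such as TorchScript reject many valid chains. All class names below are simplified stand-ins for the torchtext transforms, not the actual implementations.

```python
from typing import Any

# Simplified stand-ins for torchtext transforms (assumptions, not the real code).
class Truncate:
    def __init__(self, max_len: int):
        self.max_len = max_len

    def __call__(self, input: Any) -> Any:
        # Accepts any sliceable input: a token list, a string, a batch, ...
        return input[: self.max_len]

class AddToken:
    def __init__(self, token: Any):
        self.token = token

    def __call__(self, input: Any) -> Any:
        return input + [self.token]

class SequentialTransform:
    """Chains stages; because each stage is annotated with Any, the
    container places no constraint on what flows between stages."""
    def __init__(self, *stages):
        self.stages = stages

    def __call__(self, input: Any) -> Any:
        for stage in self.stages:
            input = stage(input)
        return input

pipeline = SequentialTransform(Truncate(3), AddToken(1))
print(pipeline([10, 20, 30, 40]))  # [10, 20, 30, 1]
```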
parmeet
left a comment
LGTM! Thanks @pmabbo13 for adding the transform class.
Description
Add a transformation class that takes string inputs and prepares them to be passed into a T5 encoder. The transformation class should also have a decode method that translates token ids back into strings. This will be used to translate the sequences generated by the T5 decoder.
Process
T5Transform is instantiated by providing a path to a pre-trained SentencePiece model, the maximum sequence length (used for truncation), the padding index, and the end-of-sequence index.
Its forward method accepts a single string, or a batch of strings, and uses the pre-trained SentencePiece model to tokenize the string(s) and translate the tokens to their corresponding ids. Then the resulting sequences are truncated and an end-of-sequence token is added to each. Finally, the sequences are padded to the length of the longest sequence in the batch.
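The truncate / add-EOS / pad steps described above can be sketched in plain Python. The index values and helper name here are illustrative assumptions, not the actual torchtext code:

```python
from typing import List

# Assumed special-token indices for illustration only.
EOS_IDX = 1
PAD_IDX = 0
MAX_SEQ_LEN = 5

def process_batch(token_ids: List[List[int]]) -> List[List[int]]:
    # 1. Truncate each sequence to the maximum length.
    seqs = [ids[:MAX_SEQ_LEN] for ids in token_ids]
    # 2. Append the end-of-sequence token.
    seqs = [ids + [EOS_IDX] for ids in seqs]
    # 3. Pad every sequence to the length of the longest one in the batch.
    longest = max(len(ids) for ids in seqs)
    return [ids + [PAD_IDX] * (longest - len(ids)) for ids in seqs]

batch = [[11, 12, 13, 14, 15, 16, 17], [21, 22]]
print(process_batch(batch))
# [[11, 12, 13, 14, 15, 1], [21, 22, 1, 0, 0, 0]]
```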
Its decode method accepts a single list of token ids, or a batch of these lists to represent multiple sequences. The pre-trained SentencePiece model is then used to translate them back into tokens and merge the tokens into a single string per sequence.
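A hedged sketch of the decode path, with a toy id-to-token vocabulary standing in for the real SentencePiece model. SentencePiece marks word boundaries with "▁"; replacing that marker with a space is a simplified version of what detokenization does:

```python
from typing import List

# Toy vocabulary (an assumption for illustration; the real mapping comes
# from the pre-trained SentencePiece model).
VOCAB = {4: "\u2581Hello", 5: "\u2581world", 6: "!"}

def decode(ids: List[int]) -> str:
    # Skip ids not in the vocabulary (e.g. special tokens like EOS/pad).
    pieces = [VOCAB[i] for i in ids if i in VOCAB]
    # Merge pieces into one string, turning boundary markers into spaces.
    return "".join(pieces).replace("\u2581", " ").strip()

print(decode([4, 5, 6]))     # Hello world!
print(decode([4, 5, 6, 1]))  # same: the EOS id is skipped
```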
Test
We test that the forward method correctly translates an input string (batched and un-batched) into the appropriate token ids, with the special tokens added where necessary. We also test that the decode method correctly translates token ids (batched and un-batched) into the correct strings. The torch-scripted versions of these transforms are also tested.
```shell
pytest test/prototype/models/test_transforms.py
```