diff --git a/test/asset/t5.base.encoder.output.pt b/test/asset/t5.base.encoder.output.pt index 9d56557c42..e77270b114 100644 Binary files a/test/asset/t5.base.encoder.output.pt and b/test/asset/t5.base.encoder.output.pt differ diff --git a/test/asset/t5.base.output.pt b/test/asset/t5.base.output.pt index 384b59074d..5789499927 100644 Binary files a/test/asset/t5.base.output.pt and b/test/asset/t5.base.output.pt differ diff --git a/test/prototype/integration_tests/test_models.py b/test/prototype/integration_tests/test_models.py index 01d728b179..27fc313acc 100644 --- a/test/prototype/integration_tests/test_models.py +++ b/test/prototype/integration_tests/test_models.py @@ -8,14 +8,16 @@ class TestT5(TorchtextTestCase): - def _t5_model(self, t5_model, expected_asset_name, model_input): + def _t5_model(self, t5_model, expected_asset_name, test_text): """Verify that pre-trained T5 models in torchtext produce the same output as the HuggingFace reference implementation. """ expected_asset_path = get_asset_path(expected_asset_name) + transform = t5_model.transform() model = t5_model.get_model() model = model.eval() + model_input = transform(test_text) if model.encoder_only: actual = model(model_input)["encoder_output"] else: @@ -26,10 +28,10 @@ def _t5_model(self, t5_model, expected_asset_name, model_input): def test_t5_base_encoder_model(self): expected_asset_name = "t5.base.encoder.output.pt" - model_input = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 0, 0, 0]]) - self._t5_model(t5_model=T5_BASE_ENCODER, expected_asset_name=expected_asset_name, model_input=model_input) + test_text = ["Hello world", "Attention rocks!"] + self._t5_model(t5_model=T5_BASE_ENCODER, expected_asset_name=expected_asset_name, test_text=test_text) def test_t5_base_model(self): expected_asset_name = "t5.base.output.pt" - model_input = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 0, 0, 0]]) - self._t5_model(t5_model=T5_BASE, expected_asset_name=expected_asset_name, model_input=model_input) + test_text = ["Hello world", "Attention rocks!"] + self._t5_model(t5_model=T5_BASE, expected_asset_name=expected_asset_name, test_text=test_text) diff --git a/torchtext/prototype/models/t5/bundler.py b/torchtext/prototype/models/t5/bundler.py index ed2262db82..0a52007aae 100644 --- a/torchtext/prototype/models/t5/bundler.py +++ b/torchtext/prototype/models/t5/bundler.py @@ -1,6 +1,6 @@ import logging from dataclasses import dataclass -from typing import Any, Dict, Optional, Union +from typing import Any, Callable, Dict, Optional, Union from urllib.parse import urljoin import torch @@ -8,28 +8,33 @@ from torchtext._download_hooks import load_state_dict_from_url from .model import T5Conf, T5Model +from .t5_transform import T5Transform logger = logging.getLogger(__name__) @dataclass class T5Bundle: - """T5Bundle(_config: torchtext.prototype.models.T5Conf, _path: Optional[str] = None) + """T5Bundle(_config: torchtext.prototype.models.T5Conf, _path: Optional[str] = None, transform: Optional[Callable] = None) Example - Pretrained base t5 encoder >>> import torch, torchtext >>> t5_encoder_base = torchtext.prototype.models.T5_BASE_ENCODER + >>> transform = t5_encoder_base.transform() + >>> input_seq = ["Hello world", "Attention rocks!"] >>> model = t5_encoder_base.get_model() - >>> model_input = torch.tensor([[1,2,3,4,5,6],[7,8,9,0,0,0]]) + >>> model_input = transform(input_seq) >>> output = model(model_input)['encoder_output'] >>> output.shape - torch.Size([2, 6, 768]) + torch.Size([2, 4, 768]) Example - Pretrained base t5 model >>> import torch, torchtext >>> t5_base = torchtext.prototype.models.T5_BASE + >>> transform = t5_base.transform() + >>> input_seq = ["Hello world", "Attention rocks!"] >>> model = t5_base.get_model() - >>> model_input = torch.tensor([[1,2,3,4,5,6],[7,8,9,0,0,0]]) + >>> model_input = transform(input_seq) >>> output = model(model_input)['decoder_output'] >>> output.shape torch.Size([2, 1, 768]) @@ -43,6 +48,7 @@ class T5Bundle: _config: T5Conf _path: Optional[str] = None + transform: Optional[Callable] = None def get_model( self, @@ -122,6 +128,12 @@ def config(self) -> T5Conf: T5_BASE_ENCODER = T5Bundle( _path=urljoin(_TEXT_BUCKET, "t5.base.encoder.pt"), _config=T5Conf(encoder_only=True), + transform=lambda: T5Transform( + urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"), + max_seq_len=512, + eos_idx=1, + padding_idx=0, + ), ) T5_BASE_ENCODER.__doc__ = """ @@ -146,6 +158,12 @@ def config(self) -> T5Conf: T5_BASE = T5Bundle( _path=urljoin(_TEXT_BUCKET, "t5.base.pt"), _config=T5Conf(encoder_only=False), + transform=lambda: T5Transform( + urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"), + max_seq_len=512, + eos_idx=1, + padding_idx=0, + ), ) T5_BASE.__doc__ = """