Binary file added test/asset/t5.base.generation.output.pt
Binary file not shown.
8 changes: 7 additions & 1 deletion test/prototype/integration_tests/test_models.py
@@ -4,6 +4,7 @@
from torchtext.prototype.models import (
T5_BASE_ENCODER,
T5_BASE,
T5_BASE_GENERATION,
)


@@ -24,7 +25,7 @@ def _t5_model(self, t5_model, expected_asset_name, test_text):
actual = model(model_input)["decoder_output"]

expected = torch.load(expected_asset_path)
torch.testing.assert_close(actual, expected)
torch.testing.assert_close(actual, expected, atol=1e-04, rtol=2.5e-06)

def test_t5_base_encoder_model(self):
expected_asset_name = "t5.base.encoder.output.pt"
@@ -35,3 +36,8 @@ def test_t5_base_model(self):
expected_asset_name = "t5.base.output.pt"
test_text = ["Hello world", "Attention rocks!"]
self._t5_model(t5_model=T5_BASE, expected_asset_name=expected_asset_name, test_text=test_text)

def test_t5_base_generation_model(self):
expected_asset_name = "t5.base.generation.output.pt"
test_text = ["Hello world", "Attention rocks!"]
self._t5_model(t5_model=T5_BASE_GENERATION, expected_asset_name=expected_asset_name, test_text=test_text)
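
For context, the new `test_t5_base_generation_model` case above compares the generation bundle's decoder output against the stored `t5.base.generation.output.pt` asset. The PR does not include the script that produced that asset, so the following is only a minimal sketch of how such a reference tensor could be regenerated; the save path and the `eval()`/`no_grad()` usage are assumptions, and only the bundle API mirrors the test.

```python
# Minimal sketch (not part of this PR) of how a reference asset such as
# test/asset/t5.base.generation.output.pt could be produced. The save path and
# the eval()/no_grad() usage are assumptions; the bundle API mirrors the test above.
import torch
from torchtext.prototype.models import T5_BASE_GENERATION

test_text = ["Hello world", "Attention rocks!"]
transform = T5_BASE_GENERATION.transform()
model = T5_BASE_GENERATION.get_model()
model.eval()

with torch.no_grad():
    model_input = transform(test_text)
    expected = model(model_input)["decoder_output"]

torch.save(expected, "test/asset/t5.base.generation.output.pt")  # assumed destination
```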
100 changes: 83 additions & 17 deletions test/prototype/models/test_models.py
@@ -1,68 +1,77 @@
import copy
from unittest.mock import patch

import torch
from test.common.torchtext_test_case import TorchtextTestCase
from torch.nn import functional as F


class TestModels(TorchtextTestCase):
def test_t5_bundler_build_model(self):
from torchtext.prototype.models import T5Conf, T5Model, T5Bundle

# case: user provide encoder checkpoint state dict
# case: user provides encoder checkpoint state dict
dummy_encoder_conf = T5Conf(
encoder_only=True,
vocab_size=10,
embedding_dim=16,
ffn_dimension=64,
num_attention_heads=2,
num_encoder_layers=2,
num_decoder_layers=2,
)
dummy_t5_encoder = T5Model(dummy_encoder_conf)
t5_encoder_model = T5Bundle.build_model(config=dummy_encoder_conf, checkpoint=dummy_t5_encoder.state_dict())
self.assertEqual(t5_encoder_model.state_dict(), dummy_t5_encoder.state_dict())

# case: user provide encoder-decoder checkpoint state dict
# case: user provides encoder-decoder checkpoint state dict
dummy_t5_conf = T5Conf(
encoder_only=False,
vocab_size=10,
embedding_dim=16,
ffn_dimension=64,
num_attention_heads=2,
num_encoder_layers=2,
num_decoder_layers=2,
)
dummy_t5 = T5Model(dummy_t5_conf)
t5_model = T5Bundle.build_model(config=dummy_t5_conf, checkpoint=dummy_t5.state_dict())
self.assertEqual(t5_model.state_dict(), dummy_t5.state_dict())

@patch("logging.Logger.warning")
def test_t5_bundler_get_model(self, mock):
from torchtext.prototype.models import T5Conf, T5Bundle

# encoder-only
dummy_encoder_conf = T5Conf(
encoder_only=True,
# case: user provides checkpoint state dict for encoder-decoder with generation
dummy_t5_generation_conf = T5Conf(
encoder_only=False,
linear_head=True,
vocab_size=10,
embedding_dim=16,
ffn_dimension=64,
num_attention_heads=2,
num_encoder_layers=2,
num_decoder_layers=2,
)
encoder_bundle = T5Bundle(dummy_encoder_conf)
encoder_bundle.get_model(load_weights=False, freeze_model=True)
mock.assert_called_with(
"The model is not loaded with pre-trained weights. Setting freeze_model to True will hinder model from learning appropriate weights."
dummy_t5_generation = T5Model(dummy_t5_generation_conf)
t5_generation_model = T5Bundle.build_model(
config=dummy_t5_generation_conf, checkpoint=dummy_t5_generation.state_dict()
)
self.assertEqual(t5_generation_model.state_dict(), dummy_t5_generation.state_dict())

# encoder-decoder
dummy_t5_conf = T5Conf(
@patch("logging.Logger.warning")
def test_t5_bundler_get_model(self, mock):
from torchtext.prototype.models import T5Conf, T5Bundle

# encoder-decoder with generation
dummy_t5_generation_conf = T5Conf(
encoder_only=False,
linear_head=True,
vocab_size=10,
embedding_dim=16,
ffn_dimension=64,
num_attention_heads=2,
num_encoder_layers=2,
num_decoder_layers=2,
)
t5_bundle = T5Bundle(dummy_t5_conf)
t5_bundle.get_model(load_weights=False, freeze_model=True)
t5_generation_bundle = T5Bundle(dummy_t5_generation_conf)
t5_generation_bundle.get_model(load_weights=False, freeze_model=True)
mock.assert_called_with(
"The model is not loaded with pre-trained weights. Setting freeze_model to True will hinder model from learning appropriate weights."
)
@@ -79,6 +88,7 @@ def test_t5_bundler_raise_checkpoint(self):
ffn_dimension=64,
num_attention_heads=2,
num_encoder_layers=2,
num_decoder_layers=2,
)
T5Bundle.build_model(
config=dummy_encoder_conf,
@@ -95,13 +105,32 @@ def test_t5_bundler_raise_checkpoint(self):
ffn_dimension=64,
num_attention_heads=2,
num_encoder_layers=2,
num_decoder_layers=2,
)
T5Bundle.build_model(
config=dummy_t5_conf,
freeze_model=True,
checkpoint=1,
)

# encoder-decoder with generation
with self.assertRaises(TypeError):
dummy_t5_generation_conf = T5Conf(
encoder_only=False,
linear_head=True,
vocab_size=10,
embedding_dim=16,
ffn_dimension=64,
num_attention_heads=2,
num_encoder_layers=2,
num_decoder_layers=2,
)
T5Bundle.build_model(
config=dummy_t5_generation_conf,
freeze_model=True,
checkpoint=1,
)

def test_t5_bundler_conf_property(self):
from torchtext.prototype.models import T5Conf, T5Bundle

@@ -112,6 +141,43 @@ def test_t5_bundler_conf_property(self):
ffn_dimension=64,
num_attention_heads=2,
num_encoder_layers=2,
num_decoder_layers=2,
)
t5_bundle = T5Bundle(dummy_t5_conf)
self.assertTrue(isinstance(t5_bundle.config, T5Conf))

def test_t5_bundler_train(self):
from torch.optim import SGD
from torchtext.prototype.models import T5Conf, T5Model, T5Bundle

def _train(model):
optim = SGD(model.parameters(), lr=1)
model_input = torch.tensor([[1, 2, 3, 4, 5]])
target = torch.tensor([1])
output = model(model_input)["decoder_output"]
logits = F.log_softmax(output[:, -1], dim=-1)
loss = F.cross_entropy(logits, target)
loss.backward()
optim.step()

dummy_conf = T5Conf(
encoder_only=False,
linear_head=True,
vocab_size=10,
embedding_dim=16,
ffn_dimension=64,
num_attention_heads=2,
num_encoder_layers=2,
num_decoder_layers=2,
training=True,
)
dummy_model = T5Model(dummy_conf)
model = T5Bundle.build_model(
config=dummy_conf,
freeze_model=False,
checkpoint=dummy_model.state_dict(),
)
current_state_dict = copy.deepcopy(model.state_dict())

_train(model)
self.assertNotEqual(model.state_dict(), current_state_dict)
2 changes: 2 additions & 0 deletions torchtext/prototype/models/t5/__init__.py
@@ -1,6 +1,7 @@
from .bundler import (
T5_BASE_ENCODER,
T5_BASE,
T5_BASE_GENERATION,
T5Bundle,
)
from .model import T5Conf, T5Model
Expand All @@ -12,5 +13,6 @@
"T5Bundle",
"T5_BASE_ENCODER",
"T5_BASE",
"T5_BASE_GENERATION",
"T5Transform",
]
50 changes: 48 additions & 2 deletions torchtext/prototype/models/t5/bundler.py
@@ -39,6 +39,19 @@ class T5Bundle:
>>> output.shape
torch.Size([2, 1, 768])

Example - Pretrained base t5 model for generation
>>> import torch, torchtext
>>> import torch.nn.functional as F
>>> t5_base_generation = torchtext.prototype.models.T5_BASE_GENERATION
>>> transform = t5_base_generation.transform()
>>> input_seq = ["Hello world", "Attention rocks!"]
>>> model = t5_base_generation.get_model()
>>> model_input = transform(input_seq)
>>> output = model(model_input)['decoder_output']
>>> logits = F.log_softmax(output[:, -1], dim=-1)
>>> logits.shape
torch.Size([2, 32128])

Example - User-specified configuration and checkpoint
>>> from torchtext.prototype.models import T5Conf, T5Bundle
>>> model_weights_path = "https://download.pytorch.org/models/text/t5.base.encoder.pt"
@@ -137,7 +150,8 @@ def config(self) -> T5Conf:
)

T5_BASE_ENCODER.__doc__ = """
T5 Encoder with Base configuration
T5_BASE_ENCODER is an encoder-only model from a pre-trained T5 model with the base configuration.
It returns the normalized output from the final layer of the encoder.

The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
<http://jmlr.org/papers/v21/20-074.html>`. It introduces a unified framework that converts text-based
@@ -167,7 +181,39 @@ def config(self) -> T5Conf:
)

T5_BASE.__doc__ = """
T5 Encoder-Decoder with Base configuration
T5_BASE is an encoder-decoder model from a pre-trained T5 model with the base configuration.
It returns the normalized output from the final layer of the decoder.

The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
<http://jmlr.org/papers/v21/20-074.html>`. It introduces a unified framework that converts text-based
language problems, such as translation, question-answering, and summarization, into a text-to-text format. The
Colossal Clean Crawled Corpus (C4) dataset is used to pre-train the model on a masked language modeling task,
and various datasets are used to fine-tune the model on each downstream task. The model's architecture is a modified version
of the canonical Transformer architecture.

Originally published by the authors of T5 under Apache License, Version 2.0
and redistributed with the same license.
[`License <https://github.com/google-research/text-to-text-transfer-transformer/blob/main/LICENSE>`__,
`Source <https://github.com/google-research/text-to-text-transfer-transformer#released-model-checkpoints>`__]

Please refer to :func:`torchtext.prototype.models.T5Bundle` for the usage.
"""

T5_BASE_GENERATION = T5Bundle(
_path=urljoin(_TEXT_BUCKET, "t5.base.generation.pt"),
_config=T5Conf(encoder_only=False, linear_head=True),
transform=lambda: T5Transform(
urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"),
max_seq_len=512,
eos_idx=1,
padding_idx=0,
),
)

T5_BASE_GENERATION.__doc__ = """
T5_BASE_GENERATION is an encoder-decoder model from a pre-trained T5 model with the base configuration.
It returns the output of the final layer of the decoder after it has been passed through a linear layer that projects the
hidden states onto the model vocabulary. This output can then be used for language generation.

The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
<http://jmlr.org/papers/v21/20-074.html>`. It introduces a unified framework that converts text-based
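
The `T5_BASE_GENERATION` docstring above notes that the projected decoder output can be used for language generation, but this diff does not ship a generation utility. The loop below is therefore only an illustrative greedy-decoding sketch: it assumes `forward()` accepts the decoder tokens as its second argument (the signature is not shown in this diff), and that the decoder start token is the padding index 0 and EOS is index 1, matching the `T5Transform` configuration above.

```python
# Illustrative greedy decoding with T5_BASE_GENERATION (not part of this PR).
# Assumptions: forward() takes decoder tokens as its second argument, the decoder
# start token is padding_idx (0), and eos_idx is 1, per the transform config above.
import torch
import torch.nn.functional as F
from torchtext.prototype.models import T5_BASE_GENERATION

transform = T5_BASE_GENERATION.transform()
model = T5_BASE_GENERATION.get_model()
model.eval()

encoder_tokens = transform(["Hello world"])
decoder_tokens = torch.zeros((encoder_tokens.size(0), 1), dtype=torch.long)  # assumed start token

with torch.no_grad():
    for _ in range(20):  # small fixed budget for the sketch
        output = model(encoder_tokens, decoder_tokens)["decoder_output"]
        logits = F.log_softmax(output[:, -1], dim=-1)
        next_token = logits.argmax(dim=-1, keepdim=True)
        decoder_tokens = torch.cat([decoder_tokens, next_token], dim=-1)
        if (next_token == 1).all():  # assumed eos_idx = 1
            break
```

Greedy argmax is used only to keep the sketch short; beam search or sampling would operate on the same projected logits.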
25 changes: 20 additions & 5 deletions torchtext/prototype/models/t5/model.py
@@ -11,6 +11,7 @@
@dataclass
class T5Conf:
encoder_only: bool = False
linear_head: bool = False
embedding_dim: int = 768
num_attention_heads: int = 12
num_encoder_layers: int = 12
@@ -35,7 +36,8 @@ class T5Model(nn.Module):
Yanqi Zhou, Wei Li, and Peter J. Liu. 2020. Journal of Machine Learning Research.
Volume 21 Issue 140 pages 1-67. http://jmlr.org/papers/v21/20-074.html
Args:
config.encoder_only: Whether or not model should consist of only the encoder as opposed to encoder-decoder (required)
config.encoder_only: Whether or not the model should consist of only the encoder, as opposed to an encoder-decoder (default=False).
config.linear_head: Whether or not a linear layer should be used to project the output of the decoder's last layer to the vocab (default=False).
config.embedding_dim: Number of expected features in the encoder/decoder inputs (default=768).
config.num_attention_heads: Number of heads in the multiheadattention models (default=12).
config.num_encoder_layers: Number of encoder layers in the encoder (default=12).
@@ -55,11 +57,12 @@ class T5Model(nn.Module):
freeze: Indicates whether or not to freeze the model weights. (default: False)
Examples:
>>> from torchtext.prototype.models import T5Conf, T5Model
>>> t5_config = T5Conf(encoder_only=False)
>>> t5_config = T5Conf(encoder_only=False, linear_head=True)
>>> t5_model = T5Model(t5_config)
>>> encoder_input = torch.rand((32, 10, 512))
>>> decoder_input = torch.rand((32, 20, 512))
>>> out = t5_model(encoder_input, decoder_input)
>>> encoder_input = torch.randint(0, t5_config.vocab_size, (32, 512))
>>> out = t5_model(encoder_input)['decoder_output']
>>> out.shape
torch.Size([32, 1, 32128])
"""

def __init__(
@@ -73,7 +76,9 @@ def __init__(

assert isinstance(config, T5Conf)

self.config = config
self.encoder_only = config.encoder_only
self.linear_head = config.linear_head
self.padding_idx = config.padding_idx
self.training = config.training
self.dropout = config.dropout if config.training else 0.0
@@ -118,6 +123,9 @@ def __init__(
self.dropout3 = nn.Dropout(self.dropout)
self.dropout4 = nn.Dropout(self.dropout)

if config.linear_head:
self.lm_head = nn.Linear(config.embedding_dim, config.vocab_size, bias=False)

if freeze:
for p in self.parameters():
p.requires_grad = False
@@ -191,6 +199,13 @@ def forward(
decoder_output = self.dropout4(decoder_output)
decoder_hidden_states = decoder_hidden_states + (decoder_output,)

if self.linear_head:
# Rescale output before projecting on vocab. This happens when the encoder and decoder share the
# same word embeddings, which is always the case in our t5 implementation.
# See https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/models/t5/modeling_t5.py#L1661
decoder_output = decoder_output * (self.config.embedding_dim ** -0.5)
decoder_output = self.lm_head(decoder_output)

t5_output = {
"encoder_output": encoder_output,
"encoder_hidden_states": encoder_hidden_states,
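
To isolate the new projection step added to `forward()`, here is a standalone sketch of the rescale-then-project logic with the default base dimensions (embedding_dim=768, vocab_size=32128); the random tensor stands in for a real decoder output, so treat this as an illustration rather than model code.

```python
# Standalone sketch of the new linear-head projection (mirrors the diff above).
# Dimensions follow the default base config: embedding_dim=768, vocab_size=32128.
import torch
import torch.nn as nn

embedding_dim, vocab_size = 768, 32128
lm_head = nn.Linear(embedding_dim, vocab_size, bias=False)

decoder_output = torch.rand(2, 1, embedding_dim)  # (batch, target_len, embedding_dim)

# Rescale before projecting onto the vocabulary, compensating for the shared
# input/output word embeddings noted in the forward() comment above.
decoder_output = decoder_output * (embedding_dim ** -0.5)
logits = lm_head(decoder_output)
print(logits.shape)  # torch.Size([2, 1, 32128])
```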