
Commit f450271

Prepare T5 Model for Language Generation (#1862)

* modify T5 model to include linear head for language generation
* create t5_base_generation bundler
* test t5_base_configuration
* test model training
* nit corrections
* loosening tolerance for integration tests
* removing redundant tests
* change target for train testing since original target led to no learning
* remove tie_word_embeddings since not needed
* add example using T5_BASE_GENERATION in bundler docstring

1 parent 4a5f11c commit f450271

File tree: 6 files changed, +160 -25 lines
Binary file (252 KB) not shown.

test/prototype/integration_tests/test_models.py

7 additions, 1 deletion

@@ -4,6 +4,7 @@
 from torchtext.prototype.models import (
     T5_BASE_ENCODER,
     T5_BASE,
+    T5_BASE_GENERATION,
 )


@@ -24,7 +25,7 @@ def _t5_model(self, t5_model, expected_asset_name, test_text):
         actual = model(model_input)["decoder_output"]

         expected = torch.load(expected_asset_path)
-        torch.testing.assert_close(actual, expected)
+        torch.testing.assert_close(actual, expected, atol=1e-04, rtol=2.5e-06)

     def test_t5_base_encoder_model(self):
         expected_asset_name = "t5.base.encoder.output.pt"
@@ -35,3 +36,8 @@ def test_t5_base_model(self):
         expected_asset_name = "t5.base.output.pt"
         test_text = ["Hello world", "Attention rocks!"]
         self._t5_model(t5_model=T5_BASE, expected_asset_name=expected_asset_name, test_text=test_text)
+
+    def test_t5_base_generation_model(self):
+        expected_asset_name = "t5.base.generation.output.pt"
+        test_text = ["Hello world", "Attention rocks!"]
+        self._t5_model(t5_model=T5_BASE_GENERATION, expected_asset_name=expected_asset_name, test_text=test_text)
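The loosened tolerances follow the `torch.testing.assert_close` contract: `actual` and `expected` match when `|actual - expected| <= atol + rtol * |expected|`. A minimal sketch of that check in isolation (the tensors here are made up, not the test assets):

    import torch

    # Values match when |actual - expected| <= atol + rtol * |expected|.
    expected = torch.tensor([1000.0])
    # Perturb by half of the allowed budget so the check clearly passes.
    actual = expected + 0.5 * (1e-04 + 2.5e-06 * expected.abs())
    torch.testing.assert_close(actual, expected, atol=1e-04, rtol=2.5e-06)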

test/prototype/models/test_models.py

83 additions, 17 deletions

@@ -1,68 +1,77 @@
+import copy
 from unittest.mock import patch

+import torch
 from test.common.torchtext_test_case import TorchtextTestCase
+from torch.nn import functional as F


 class TestModels(TorchtextTestCase):
     def test_t5_bundler_build_model(self):
         from torchtext.prototype.models import T5Conf, T5Model, T5Bundle

-        # case: user provide encoder checkpoint state dict
+        # case: user provides encoder checkpoint state dict
         dummy_encoder_conf = T5Conf(
             encoder_only=True,
             vocab_size=10,
             embedding_dim=16,
             ffn_dimension=64,
             num_attention_heads=2,
             num_encoder_layers=2,
+            num_decoder_layers=2,
         )
         dummy_t5_encoder = T5Model(dummy_encoder_conf)
         t5_encoder_model = T5Bundle.build_model(config=dummy_encoder_conf, checkpoint=dummy_t5_encoder.state_dict())
         self.assertEqual(t5_encoder_model.state_dict(), dummy_t5_encoder.state_dict())

-        # case: user provide encoder-decoder checkpoint state dict
+        # case: user provides encoder-decoder checkpoint state dict
         dummy_t5_conf = T5Conf(
             encoder_only=False,
             vocab_size=10,
             embedding_dim=16,
             ffn_dimension=64,
             num_attention_heads=2,
             num_encoder_layers=2,
+            num_decoder_layers=2,
         )
         dummy_t5 = T5Model(dummy_t5_conf)
         t5_model = T5Bundle.build_model(config=dummy_t5_conf, checkpoint=dummy_t5.state_dict())
         self.assertEqual(t5_model.state_dict(), dummy_t5.state_dict())

-    @patch("logging.Logger.warning")
-    def test_t5_bundler_get_model(self, mock):
-        from torchtext.prototype.models import T5Conf, T5Bundle
-
-        # encoder-only
-        dummy_encoder_conf = T5Conf(
-            encoder_only=True,
+        # case: user provides checkpoint state dict for encoder-decoder with generation
+        dummy_t5_generation_conf = T5Conf(
+            encoder_only=False,
+            linear_head=True,
             vocab_size=10,
             embedding_dim=16,
             ffn_dimension=64,
             num_attention_heads=2,
             num_encoder_layers=2,
+            num_decoder_layers=2,
         )
-        encoder_bundle = T5Bundle(dummy_encoder_conf)
-        encoder_bundle.get_model(load_weights=False, freeze_model=True)
-        mock.assert_called_with(
-            "The model is not loaded with pre-trained weights. Setting freeze_model to True will hinder model from learning appropriate weights."
+        dummy_t5_generation = T5Model(dummy_t5_generation_conf)
+        t5_generation_model = T5Bundle.build_model(
+            config=dummy_t5_generation_conf, checkpoint=dummy_t5_generation.state_dict()
         )
+        self.assertEqual(t5_generation_model.state_dict(), dummy_t5_generation.state_dict())

-        # encoder-decoder
-        dummy_t5_conf = T5Conf(
+    @patch("logging.Logger.warning")
+    def test_t5_bundler_get_model(self, mock):
+        from torchtext.prototype.models import T5Conf, T5Bundle
+
+        # encoder-decoder with generation
+        dummy_t5_generation_conf = T5Conf(
             encoder_only=False,
+            linear_head=True,
             vocab_size=10,
             embedding_dim=16,
             ffn_dimension=64,
             num_attention_heads=2,
             num_encoder_layers=2,
+            num_decoder_layers=2,
         )
-        t5_bundle = T5Bundle(dummy_t5_conf)
-        t5_bundle.get_model(load_weights=False, freeze_model=True)
+        t5_generation_bundle = T5Bundle(dummy_t5_generation_conf)
+        t5_generation_bundle.get_model(load_weights=False, freeze_model=True)
         mock.assert_called_with(
             "The model is not loaded with pre-trained weights. Setting freeze_model to True will hinder model from learning appropriate weights."
         )
@@ -79,6 +88,7 @@ def test_t5_bundler_raise_checkpoint(self):
                 ffn_dimension=64,
                 num_attention_heads=2,
                 num_encoder_layers=2,
+                num_decoder_layers=2,
             )
             T5Bundle.build_model(
                 config=dummy_encoder_conf,
@@ -95,13 +105,32 @@ def test_t5_bundler_raise_checkpoint(self):
                 ffn_dimension=64,
                 num_attention_heads=2,
                 num_encoder_layers=2,
+                num_decoder_layers=2,
             )
             T5Bundle.build_model(
                 config=dummy_t5_conf,
                 freeze_model=True,
                 checkpoint=1,
             )

+        # encoder-decoder with generation
+        with self.assertRaises(TypeError):
+            dummy_t5_generation_conf = T5Conf(
+                encoder_only=False,
+                linear_head=True,
+                vocab_size=10,
+                embedding_dim=16,
+                ffn_dimension=64,
+                num_attention_heads=2,
+                num_encoder_layers=2,
+                num_decoder_layers=2,
+            )
+            T5Bundle.build_model(
+                config=dummy_t5_generation_conf,
+                freeze_model=True,
+                checkpoint=1,
+            )
+
     def test_t5_bundler_conf_property(self):
         from torchtext.prototype.models import T5Conf, T5Bundle

@@ -112,6 +141,43 @@ def test_t5_bundler_conf_property(self):
             ffn_dimension=64,
             num_attention_heads=2,
             num_encoder_layers=2,
+            num_decoder_layers=2,
         )
         t5_bundle = T5Bundle(dummy_t5_conf)
         self.assertTrue(isinstance(t5_bundle.config, T5Conf))
+
+    def test_t5_bundler_train(self):
+        from torch.optim import SGD
+        from torchtext.prototype.models import T5Conf, T5Model, T5Bundle
+
+        def _train(model):
+            optim = SGD(model.parameters(), lr=1)
+            model_input = torch.tensor([[1, 2, 3, 4, 5]])
+            target = torch.tensor([1])
+            output = model(model_input)["decoder_output"]
+            logits = F.log_softmax(output[:, -1], dim=-1)
+            loss = F.cross_entropy(logits, target)
+            loss.backward()
+            optim.step()
+
+        dummy_conf = T5Conf(
+            encoder_only=False,
+            linear_head=True,
+            vocab_size=10,
+            embedding_dim=16,
+            ffn_dimension=64,
+            num_attention_heads=2,
+            num_encoder_layers=2,
+            num_decoder_layers=2,
+            training=True,
+        )
+        dummy_model = T5Model(dummy_conf)
+        model = T5Bundle.build_model(
+            config=dummy_conf,
+            freeze_model=False,
+            checkpoint=dummy_model.state_dict(),
+        )
+        current_state_dict = copy.deepcopy(model.state_dict())
+
+        _train(model)
+        self.assertNotEqual(model.state_dict(), current_state_dict)
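The new `test_t5_bundler_train` uses a generic "weights must move" pattern: snapshot the state dict with `copy.deepcopy` (a plain `state_dict()` call returns references to the live tensors, so an un-copied snapshot would silently track the update), take one optimizer step, then assert the snapshot no longer matches. A minimal sketch of the same pattern on a toy module, independent of the torchtext API:

    import copy

    import torch
    from torch import nn
    from torch.optim import SGD

    model = nn.Linear(4, 2)                     # stand-in for the T5 model under test
    before = copy.deepcopy(model.state_dict())  # deepcopy so the snapshot is not aliased

    loss = model(torch.randn(1, 4)).sum()
    loss.backward()
    SGD(model.parameters(), lr=1).step()

    # After one step with a non-zero gradient, some parameter must have changed.
    assert any(not torch.equal(before[k], v) for k, v in model.state_dict().items())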

torchtext/prototype/models/t5/__init__.py

2 additions, 0 deletions

@@ -1,6 +1,7 @@
 from .bundler import (
     T5_BASE_ENCODER,
     T5_BASE,
+    T5_BASE_GENERATION,
     T5Bundle,
 )
 from .model import T5Conf, T5Model
@@ -12,5 +13,6 @@
     "T5Bundle",
     "T5_BASE_ENCODER",
     "T5_BASE",
+    "T5_BASE_GENERATION",
     "T5Transform",
 ]
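With the re-export above, the new bundle sits alongside the existing ones in the public prototype namespace:

    # All three T5 bundles are now importable from the same place.
    from torchtext.prototype.models import T5_BASE, T5_BASE_ENCODER, T5_BASE_GENERATION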

torchtext/prototype/models/t5/bundler.py

48 additions, 2 deletions

@@ -39,6 +39,19 @@ class T5Bundle:
         >>> output.shape
         torch.Size([2, 1, 768])

+    Example - Pretrained base t5 model for generation
+        >>> import torch, torchtext
+        >>> import torch.nn.functional as F
+        >>> t5_base_generation = torchtext.prototype.models.T5_BASE_GENERATION
+        >>> transform = t5_base_generation.transform()
+        >>> input_seq = ["Hello world", "Attention rocks!"]
+        >>> model = t5_base_generation.get_model()
+        >>> model_input = transform(input_seq)
+        >>> output = model(model_input)['decoder_output']
+        >>> logits = F.log_softmax(output[:, -1], dim=-1)
+        >>> logits.shape
+        torch.Size([2, 1, 32128])
+
     Example - User-specified configuration and checkpoint
         >>> from torchtext.prototype.models import T5Conf, T5Bundle
         >>> model_weights_path = "https://download.pytorch.org/models/text/t5.base.encoder.pt"
@@ -137,7 +150,8 @@ def config(self) -> T5Conf:
 )

 T5_BASE_ENCODER.__doc__ = """
-    T5 Encoder with Base configuration
+    T5_BASE_ENCODER is an encoder-only model from a pre-trained T5 model with the base configuration.
+    It returns the normalized output from the final layer of the encoder.

     The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
     <http://jmlr.org/papers/v21/20-074.html>`. It introduces a unified framework that converts text-based
@@ -167,7 +181,39 @@ def config(self) -> T5Conf:
 )

 T5_BASE.__doc__ = """
-    T5 Encoder-Decoder with Base configuration
+    T5_BASE is an encoder-decoder model from a pre-trained T5 model with the base configuration.
+    It returns the normalized output from the final layer of the decoder.
+
+    The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
+    <http://jmlr.org/papers/v21/20-074.html>`. It introduces a unified framework that converts text-based
+    language problems, such as translation, question-answering, and summarization, into a text-to-text format. The
+    Colossal Clean Crawled Corpus (C4) dataset is used to pre-train the model on a masked language modeling task,
+    and various datasets are used to fine-tune the model on each downstream task. The model's architecture is a modified version
+    of the canonical Transformer architecture.
+
+    Originally published by the authors of T5 under Apache License, Version 2.0
+    and redistributed with the same license.
+    [`License <https://github.com/google-research/text-to-text-transfer-transformer/blob/main/LICENSE>`__,
+    `Source <https://github.com/google-research/text-to-text-transfer-transformer#released-model-checkpoints>`__]
+
+    Please refer to :func:`torchtext.prototype.models.T5Bundle` for the usage.
+    """
+
+T5_BASE_GENERATION = T5Bundle(
+    _path=urljoin(_TEXT_BUCKET, "t5.base.generation.pt"),
+    _config=T5Conf(encoder_only=False, linear_head=True),
+    transform=lambda: T5Transform(
+        urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"),
+        max_seq_len=512,
+        eos_idx=1,
+        padding_idx=0,
+    ),
+)
+
+T5_BASE_GENERATION.__doc__ = """
+    T5_BASE_GENERATION is an encoder-decoder model from a pre-trained T5 model with the base configuration.
+    It returns the output of the final layer of the decoder after passing through a linear layer to project the hidden states to
+    the model vocabulary. This output can then be used for language generation.

     The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
     <http://jmlr.org/papers/v21/20-074.html>`. It introduces a unified framework that converts text-based
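Following the new docstring example, here is a minimal sketch of turning the projected logits into a greedy first-token prediction (inference only; the autoregressive loop needed for full generation is omitted):

    import torch
    import torch.nn.functional as F
    import torchtext

    t5 = torchtext.prototype.models.T5_BASE_GENERATION
    transform = t5.transform()
    model = t5.get_model()
    model.eval()

    with torch.no_grad():
        model_input = transform(["Hello world", "Attention rocks!"])
        output = model(model_input)["decoder_output"]  # (2, 1, 32128): one decoder step
        log_probs = F.log_softmax(output[:, -1], dim=-1)
        next_tokens = log_probs.argmax(dim=-1)         # greedy pick of the first generated token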

torchtext/prototype/models/t5/model.py

20 additions, 5 deletions

@@ -11,6 +11,7 @@
 @dataclass
 class T5Conf:
     encoder_only: bool = False
+    linear_head: bool = False
     embedding_dim: int = 768
     num_attention_heads: int = 12
     num_encoder_layers: int = 12
@@ -35,7 +36,8 @@ class T5Model(nn.Module):
         Yanqi Zhou, Wei Li, and Peter J. Liu. 2020. Journal of Machine Learning Research.
         Volume 21 Issue 140 pages 1-67. http://jmlr.org/papers/v21/20-074.html
     Args:
-        config.encoder_only: Whether or not model should consist of only the encoder as opposed to encoder-decoder (required)
+        config.encoder_only: Whether or not model should consist of only the encoder as opposed to encoder-decoder (default=False).
+        config.linear_head: Whether or not a linear layer should be used to project the output of the decoder's last layer to the vocab (default=False).
         config.embedding_dim: Number of expected features in the encoder/decoder inputs (default=768).
         config.num_attention_heads: Number of heads in the multiheadattention models (default=12).
         config.num_encoder_layers: Number of encoder layers in the encoder (default=12).
@@ -55,11 +57,12 @@ class T5Model(nn.Module):
         freeze: Indicates whether or not to freeze the model weights. (default: False)
     Examples:
         >>> from torchtext.prototype.models import T5Conf, T5Model
-        >>> t5_config = T5Conf(encoder_only=False)
+        >>> t5_config = T5Conf(encoder_only=False, linear_head=True)
         >>> t5_model = T5Model(t5_config)
-        >>> encoder_input = torch.rand((32, 10, 512))
-        >>> decoder_input = torch.rand((32, 20, 512))
-        >>> out = t5_model(encoder_input, decoder_input)
+        >>> encoder_input = torch.randint(0, t5_config.vocab_size, (32, 512))
+        >>> out = t5_model(encoder_input)['decoder_output']
+        >>> out.shape
+        torch.Size([32, 1, 32128])
     """

     def __init__(
@@ -73,7 +76,9 @@ def __init__(

         assert isinstance(config, T5Conf)

+        self.config = config
         self.encoder_only = config.encoder_only
+        self.linear_head = config.linear_head
         self.padding_idx = config.padding_idx
         self.training = config.training
         self.dropout = config.dropout if config.training else 0.0
@@ -118,6 +123,9 @@ def __init__(
         self.dropout3 = nn.Dropout(self.dropout)
         self.dropout4 = nn.Dropout(self.dropout)

+        if config.linear_head:
+            self.lm_head = nn.Linear(config.embedding_dim, config.vocab_size, bias=False)
+
         if freeze:
             for p in self.parameters():
                 p.requires_grad = False
@@ -191,6 +199,13 @@ def forward(
             decoder_output = self.dropout4(decoder_output)
             decoder_hidden_states = decoder_hidden_states + (decoder_output,)

+            if self.linear_head:
+                # Rescale output before projecting on vocab. This happens when the encoder and decoder share the
+                # same word embeddings, which is always the case in our t5 implementation.
+                # See https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/models/t5/modeling_t5.py#L1661
+                decoder_output = decoder_output * (self.config.embedding_dim ** -0.5)
+                decoder_output = self.lm_head(decoder_output)
+
             t5_output = {
                 "encoder_output": encoder_output,
                 "encoder_hidden_states": encoder_hidden_states,
