Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit e7bcf3c

Browse files
authored
Make T5 model torchscriptable (#1876)
* Type-annotate `device`
* Refactor `relative_attention_bias`
* Break out the encoder and decoder layers and stacks
* Update docstrings
* Correct type annotations
* Update integration tests to also test the scripted versions of the models
1 parent e1b6984 commit e7bcf3c

File tree

3 files changed: 333 additions (+333) and 134 deletions (-134).
Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
from parameterized import parameterized
23
from test.common.assets import get_asset_path
34
from test.common.torchtext_test_case import TorchtextTestCase
45
from torchtext.prototype.models import (
@@ -9,7 +10,7 @@
910

1011

1112
class TestT5(TorchtextTestCase):
12-
def _t5_model(self, t5_model, expected_asset_name, test_text):
13+
def _t5_model(self, is_jit, t5_model, expected_asset_name, test_text):
1314
"""Verify that pre-trained T5 models in torchtext produce
1415
the same output as the HuggingFace reference implementation.
1516
"""
@@ -18,6 +19,10 @@ def _t5_model(self, t5_model, expected_asset_name, test_text):
1819
model = t5_model.get_model()
1920
model = model.eval()
2021

22+
if is_jit:
23+
transform = torch.jit.script(transform)
24+
model = torch.jit.script(model)
25+
2126
model_input = transform(test_text)
2227
if model.encoder_only:
2328
actual = model(model_input)["encoder_output"]
@@ -27,17 +32,24 @@ def _t5_model(self, t5_model, expected_asset_name, test_text):
2732
expected = torch.load(expected_asset_path)
2833
torch.testing.assert_close(actual, expected, atol=1e-04, rtol=2.5e-06)
2934

30-
def test_t5_base_encoder_model(self) -> None:
35+
@parameterized.expand([("jit", True), ("not_jit", False)])
36+
def test_t5_base_encoder_model(self, name, is_jit) -> None:
3137
expected_asset_name = "t5.base.encoder.output.pt"
3238
test_text = ["Hello world", "Attention rocks!"]
33-
self._t5_model(t5_model=T5_BASE_ENCODER, expected_asset_name=expected_asset_name, test_text=test_text)
39+
self._t5_model(
40+
is_jit=is_jit, t5_model=T5_BASE_ENCODER, expected_asset_name=expected_asset_name, test_text=test_text
41+
)
3442

35-
def test_t5_base_model(self) -> None:
43+
@parameterized.expand([("jit", True), ("not_jit", False)])
44+
def test_t5_base_model(self, name, is_jit) -> None:
3645
expected_asset_name = "t5.base.output.pt"
3746
test_text = ["Hello world", "Attention rocks!"]
38-
self._t5_model(t5_model=T5_BASE, expected_asset_name=expected_asset_name, test_text=test_text)
47+
self._t5_model(is_jit=is_jit, t5_model=T5_BASE, expected_asset_name=expected_asset_name, test_text=test_text)
3948

40-
def test_t5_base_generation_model(self) -> None:
49+
@parameterized.expand([("jit", True), ("not_jit", False)])
50+
def test_t5_base_generation_model(self, name, is_jit) -> None:
4151
expected_asset_name = "t5.base.generation.output.pt"
4252
test_text = ["Hello world", "Attention rocks!"]
43-
self._t5_model(t5_model=T5_BASE_GENERATION, expected_asset_name=expected_asset_name, test_text=test_text)
53+
self._t5_model(
54+
is_jit=is_jit, t5_model=T5_BASE_GENERATION, expected_asset_name=expected_asset_name, test_text=test_text
55+
)

torchtext/prototype/models/t5/model.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from dataclasses import dataclass
2-
from typing import Dict, Optional, Tuple, Union, Callable
2+
from typing import Dict, List, Optional, Union, Callable
33

44
import torch
55
import torch.nn as nn
66
from torch import Tensor
77

8-
from .modules import T5Stack, T5LayerNorm
8+
from .modules import T5Encoder, T5Decoder, T5LayerNorm
99

1010

1111
@dataclass
@@ -69,14 +69,15 @@ def __init__(
6969
self,
7070
config: T5Conf,
7171
freeze: bool = False,
72-
device=None,
72+
device: Optional[torch.device] = None,
7373
dtype=None,
7474
) -> None:
7575
super().__init__()
7676

7777
assert isinstance(config, T5Conf)
7878

7979
self.config = config
80+
self.embedding_dim = config.embedding_dim
8081
self.encoder_only = config.encoder_only
8182
self.linear_head = config.linear_head
8283
self.padding_idx = config.padding_idx
@@ -86,8 +87,7 @@ def __init__(
8687
self.dtype = dtype
8788

8889
self.token_embeddings = nn.Embedding(config.vocab_size, config.embedding_dim, config.padding_idx)
89-
self.encoder = T5Stack(
90-
is_decoder=False,
90+
self.encoder = T5Encoder(
9191
d_model=config.embedding_dim,
9292
nhead=config.num_attention_heads,
9393
num_layers=config.num_encoder_layers,
@@ -105,8 +105,7 @@ def __init__(
105105
self.dropout2 = nn.Dropout(self.dropout)
106106

107107
if not config.encoder_only:
108-
self.decoder = T5Stack(
109-
is_decoder=True,
108+
self.decoder = T5Decoder(
110109
d_model=config.embedding_dim,
111110
nhead=config.num_attention_heads,
112111
num_layers=config.num_decoder_layers,
@@ -122,9 +121,13 @@ def __init__(
122121
self.norm2 = T5LayerNorm(config.embedding_dim)
123122
self.dropout3 = nn.Dropout(self.dropout)
124123
self.dropout4 = nn.Dropout(self.dropout)
124+
else:
125+
self.decoder = None
125126

126127
if config.linear_head:
127128
self.lm_head = nn.Linear(config.embedding_dim, config.vocab_size, bias=False)
129+
else:
130+
self.lm_head = None
128131

129132
if freeze:
130133
for p in self.parameters():
@@ -133,10 +136,10 @@ def __init__(
133136
def forward(
134137
self,
135138
encoder_tokens: Tensor,
136-
decoder_tokens: Tensor = None,
139+
decoder_tokens: Optional[Tensor] = None,
137140
encoder_mask: Optional[Tensor] = None,
138141
decoder_mask: Optional[Tensor] = None,
139-
) -> Dict[str, Union[Tensor, Tuple[Tensor]]]:
142+
) -> Dict[str, Union[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]]:
140143
r"""Pass the inputs (and mask) through the decoder layer in turn.
141144
Args:
142145
encoder_tokens: Tokenized input sequence to the encoder.
@@ -163,23 +166,27 @@ def forward(
163166
"""
164167
encoder_padding_mask = encoder_tokens.eq(self.padding_idx)
165168
encoder_embeddings = self.dropout1(self.token_embeddings(encoder_tokens))
166-
encoder_output, encoder_hidden_states, encoder_position_bias, encoder_sa, _ = self.encoder(
169+
encoder_output, encoder_hidden_states, encoder_position_bias, encoder_sa = self.encoder(
167170
encoder_embeddings, tgt_mask=encoder_mask, tgt_key_padding_mask=encoder_padding_mask
168171
)
169172

170173
encoder_output = self.norm1(encoder_output)
171174
encoder_output = self.dropout2(encoder_output)
172-
encoder_hidden_states = encoder_hidden_states + (encoder_output,)
175+
encoder_hidden_states.append(encoder_output)
173176

174177
if not self.encoder_only:
175178

179+
assert self.decoder is not None
180+
176181
# decoder_tokens is None means at start of inference, in which case decoder sequence should begin with padding idx.
177182
if decoder_tokens is None:
178183
decoder_tokens = torch.ones((encoder_tokens.size(0), 1), dtype=torch.long) * self.padding_idx
179184

180185
if decoder_mask is None:
186+
assert decoder_tokens is not None and decoder_tokens.dim() == 2
181187
tgt_len = decoder_tokens.shape[1]
182-
decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1).bool()
188+
decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1)
189+
decoder_mask = decoder_mask.to(torch.bool)
183190

184191
decoder_padding_mask = decoder_tokens.eq(self.padding_idx)
185192
# T5 implementation uses padding idx to start sequence. Want to ignore this when masking
@@ -197,13 +204,14 @@ def forward(
197204

198205
decoder_output = self.norm2(decoder_output)
199206
decoder_output = self.dropout4(decoder_output)
200-
decoder_hidden_states = decoder_hidden_states + (decoder_output,)
207+
decoder_hidden_states.append(decoder_output)
201208

202209
if self.linear_head:
210+
assert self.lm_head is not None
203211
# Rescale output before projecting on vocab. This happens when the encoder and decoder share the
204212
# same word embeddings, which is always the case in our t5 implementation.
205213
# See https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/models/t5/modeling_t5.py#L1661
206-
decoder_output = decoder_output * (self.config.embedding_dim ** -0.5)
214+
decoder_output = decoder_output * (self.embedding_dim ** -0.5)
207215
decoder_output = self.lm_head(decoder_output)
208216

209217
t5_output = {
@@ -225,4 +233,8 @@ def forward(
225233
"encoder_sa_scores": encoder_sa,
226234
}
227235

236+
assert torch.jit.isinstance(
237+
t5_output, Dict[str, Union[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]]
238+
)
239+
228240
return t5_output

0 commit comments

Comments
 (0)