This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 6b37a7f

add t5 model that can function as either an encoder-only or an encoder-decoder model
ghstack-source-id: b5a8a5e Pull Request resolved: #1829
1 parent f7e328b commit 6b37a7f

File tree

1 file changed (+189, -0 lines)


torchtext/prototype/t5/model.py

Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
from typing import Any, Callable, Dict, Optional, Union

import torch
import torch.nn as nn
from torch import Tensor

from .modules import T5Stack, T5LayerNorm


# NOTE: Comparable HuggingFace implementation can be found at https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L1269
class T5Model(nn.Module):
    r"""A T5 model. Users can modify the attributes as needed. The architecture
    is based on the paper "Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer".
    Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena,
    Yanqi Zhou, Wei Li, and Peter J. Liu. 2020. Journal of Machine Learning Research.
    Volume 21 Issue 140 pages 1-67. http://jmlr.org/papers/v21/20-074.html

    Args:
        encoder_only: Whether the model should consist of only the encoder, as opposed to encoder-decoder (required).
        d_model: Number of expected features in the encoder/decoder inputs (default=768).
        nhead: Number of heads in the multiheadattention models (default=12).
        num_encoder_layers: Number of encoder layers in the encoder (default=12).
        num_decoder_layers: Number of decoder layers in the decoder (default=12).
        dim_feedforward: Dimension of the feedforward network model (default=3072).
        dropout: Dropout value (default=0.1).
        activation: Activation function of encoder/decoder intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: The eps value in layer normalization components (default=1e-6).
        relative_attention_num_buckets: Number of relative position buckets (default: 32)
        relative_attention_max_distance: Maximum threshold on the relative distance used to
            allocate buckets. Anything larger gets placed in the same bucket (default: 128)
        padding_idx: Index assigned to padding token in vocabulary (default: 0)
        max_seq_len: Maximum sequence length (default: 512)
        vocab_size: Size of vocabulary (default: 32128)

    Examples::
        >>> t5_model = T5Model(encoder_only=False)
        >>> src = torch.randint(0, 32128, (32, 10))
        >>> tgt = torch.randint(0, 32128, (32, 20))
        >>> out = t5_model(src, tgt)
    """

    def __init__(
        self,
        encoder_only: bool,
        d_model: int = 768,
        nhead: int = 12,
        num_encoder_layers: int = 12,
        num_decoder_layers: int = 12,
        dim_feedforward: int = 3072,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = "relu",
        layer_norm_eps: float = 1e-6,
        relative_attention_num_buckets: int = 32,
        relative_attention_max_distance: int = 128,
        padding_idx: int = 0,
        max_seq_len: int = 512,
        vocab_size: int = 32128,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()

        self.encoder_only = encoder_only
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.activation = activation
        self.layer_norm_eps = layer_norm_eps
        self.nhead = nhead
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.relative_attention_max_distance = relative_attention_max_distance
        self.padding_idx = padding_idx
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.device = device
        self.dtype = dtype

        self.token_embeddings = nn.Embedding(vocab_size, d_model, padding_idx)
        self.encoder = T5Stack(
            is_decoder=False,
            d_model=d_model,
            nhead=nhead,
            num_layers=num_encoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            relative_attention_num_buckets=relative_attention_num_buckets,
            relative_attention_max_distance=relative_attention_max_distance,
            device=device,
            dtype=dtype,
        )
        self.norm1 = T5LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        if not encoder_only:
            self.decoder = T5Stack(
                is_decoder=True,
                d_model=d_model,
                nhead=nhead,
                num_layers=num_decoder_layers,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                activation=activation,
                layer_norm_eps=layer_norm_eps,
                relative_attention_num_buckets=relative_attention_num_buckets,
                relative_attention_max_distance=relative_attention_max_distance,
                device=device,
                dtype=dtype,
            )
            self.norm2 = T5LayerNorm(d_model)
            self.dropout3 = nn.Dropout(dropout)
            self.dropout4 = nn.Dropout(dropout)

    def forward(
        self,
        encoder_tokens: Tensor,
        decoder_tokens: Optional[Tensor] = None,
        encoder_mask: Optional[Tensor] = None,
        decoder_mask: Optional[Tensor] = None,
    ) -> Dict[str, Any]:
        r"""Pass the token inputs (and masks) through the encoder, and through the decoder
        when the model is not encoder-only.

        Args:
            encoder_tokens: Tokenized input sequence to the encoder.
                Must be batch first with shape (B, Ne) where B is the batch size and Ne is the
                encoder input sequence length (required).
            decoder_tokens: Tokenized input sequence to the decoder.
                Must be batch first with shape (B, Nd) where B is the batch size and Nd is the
                decoder input sequence length (required unless the model is encoder-only).
            encoder_mask: Self-attention mask for the encoder input sequence.
                Must have shape (Ne, Ne) (optional).
            decoder_mask: Self-attention mask for the decoder input sequence.
                Must have shape (Nd, Nd) (optional).

        Returns:
            A dictionary containing the encoder/decoder outputs, hidden states, position biases,
            and self-/cross-attention scores. Decoder entries are None when the model is encoder-only.
        """
        encoder_padding_mask = encoder_tokens.eq(self.padding_idx)
        encoder_embeddings = self.dropout1(self.token_embeddings(encoder_tokens))
        encoder_output, encoder_hidden_states, encoder_position_bias, encoder_sa, _ = self.encoder(
            encoder_embeddings, tgt_mask=encoder_mask, tgt_key_padding_mask=encoder_padding_mask
        )

        encoder_output = self.norm1(encoder_output)
        encoder_output = self.dropout2(encoder_output)
        encoder_hidden_states = encoder_hidden_states + (encoder_output,)

        decoder_output = None
        decoder_hidden_states = None
        decoder_position_bias = None
        decoder_sa = None
        decoder_ca = None

        if not self.encoder_only:
            assert decoder_tokens is not None
            if decoder_mask is None:
                # Default to a causal (upper-triangular) mask for decoder self-attention.
                tgt_len = decoder_tokens.shape[1]
                decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1).bool()

            decoder_padding_mask = decoder_tokens.eq(self.padding_idx)
            # The T5 implementation uses the padding index to start the sequence; ignore it when masking.
            decoder_padding_mask[:, 0] = False

            decoder_embeddings = self.dropout3(self.token_embeddings(decoder_tokens))
            decoder_output, decoder_hidden_states, decoder_position_bias, decoder_sa, decoder_ca = self.decoder(
                decoder_embeddings,
                memory=encoder_output,
                tgt_mask=decoder_mask,
                memory_mask=encoder_mask,
                tgt_key_padding_mask=decoder_padding_mask,
                memory_key_padding_mask=encoder_padding_mask,
            )

            decoder_output = self.norm2(decoder_output)
            decoder_output = self.dropout4(decoder_output)
            decoder_hidden_states = decoder_hidden_states + (decoder_output,)

        # Collect encoder (and, if present, decoder) outputs for the caller.
        t5_output = {
            "encoder_output": encoder_output,
            "encoder_hidden_states": encoder_hidden_states,
            "encoder_position_bias": encoder_position_bias,
            "encoder_sa_scores": encoder_sa,
            "decoder_output": decoder_output,
            "decoder_hidden_states": decoder_hidden_states,
            "decoder_position_bias": decoder_position_bias,
            "decoder_sa_scores": decoder_sa,
            "decoder_ca_scores": decoder_ca,
        }

        return t5_output
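
For context, a minimal usage sketch of the new module follows. It assumes the file is importable as torchtext.prototype.t5.model (matching the path in this diff) and that the sibling T5Stack/T5LayerNorm modules are available; the batch sizes, sequence lengths, and token values are illustrative only, and the printed shapes assume T5Stack preserves the (batch, seq, d_model) layout.

import torch
from torchtext.prototype.t5.model import T5Model  # path assumed from this diff

# Encoder-decoder configuration: decoder entries in the output dict are populated.
model = T5Model(encoder_only=False)
src = torch.randint(1, 32128, (4, 10))                        # (B, Ne) token ids; 0 is the padding index
bos = torch.zeros((4, 1), dtype=torch.long)                   # decoder input starts with the padding index (0)
tgt = torch.cat([bos, torch.randint(1, 32128, (4, 6))], dim=1)  # (B, Nd) token ids
out = model(src, tgt)
print(out["encoder_output"].shape)   # expected torch.Size([4, 10, 768]) if T5Stack keeps (B, N, d_model)
print(out["decoder_output"].shape)   # expected torch.Size([4, 7, 768])

# Encoder-only configuration: decoder entries in the output dict stay None.
encoder_model = T5Model(encoder_only=True)
enc_out = encoder_model(src)
print(enc_out["decoder_output"])     # None

Because forward returns a plain dictionary rather than a tensor, a downstream head would consume entries such as out["decoder_output"] or out["encoder_output"] directly.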
