
Commit 8d56ed2

Add T5 Model to TorchText (#1845)
* compute relative position buckets for relative attention bias [ghstack-poisoned]
* compute relative position bias for t5 attention [ghstack-poisoned]
* compute attention scores for t5 model using relative attention bias [ghstack-poisoned]
* perform multihead attention using relative attention bias for t5 model [ghstack-poisoned]
* create T5MultiheadAttention module [ghstack-poisoned]
* add layer norm module for t5 model [ghstack-poisoned]
* add t5 layer module that can be used for both encoder or decoder stack [ghstack-poisoned]
* add t5 stack that can function as either the encoder or decoder of a t5 model [ghstack-poisoned]
* Update base for Update on "add t5 model that can function as both encoder-only or encoder-decoder model" [ghstack-poisoned]
* add t5 model that can function as both encoder-only or encoder-decoder model (#1829)
1 parent bb58f6e commit 8d56ed2

File tree

2 files changed: +934 additions, 0 deletions


torchtext/prototype/t5/model.py

Lines changed: 200 additions & 0 deletions
@@ -0,0 +1,200 @@
from typing import Dict, Optional, Tuple, Union, Callable

import torch
import torch.nn as nn
from torch import Tensor

from .modules import T5Stack, T5LayerNorm


# NOTE: Comparable HuggingFace implementation can be found at https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L1269
class T5Model(nn.Module):
    r"""A T5 model. Users are able to modify the attributes as needed. The architecture
    is based on the paper "Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer".
    Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena,
    Yanqi Zhou, Wei Li, and Peter J. Liu. 2020. Journal of Machine Learning Research.
    Volume 21 Issue 140 pages 1-67. http://jmlr.org/papers/v21/20-074.html

    Args:
        encoder_only: Whether or not the model should consist of only the encoder, as opposed to encoder-decoder (required).
        d_model: Number of expected features in the encoder/decoder inputs (default=768).
        nhead: Number of heads in the multihead attention models (default=12).
        num_encoder_layers: Number of encoder layers in the encoder (default=12).
        num_decoder_layers: Number of decoder layers in the decoder (default=12).
        dim_feedforward: Dimension of the feedforward network model (default=3072).
        dropout: Dropout value (default=0.1).
        activation: Activation function of the encoder/decoder intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable (default: relu).
        layer_norm_eps: The eps value in layer normalization components (default=1e-6).
        relative_attention_num_buckets: Number of relative position buckets (default: 32).
        relative_attention_max_distance: Maximum threshold on the relative distance used to
            allocate buckets. Anything larger gets placed in the same bucket (default: 128).
        padding_idx: Index assigned to the padding token in the vocabulary (default: 0).
        max_seq_len: Maximum sequence length (default: 512).
        vocab_size: Size of the vocabulary (default: 32128).

    Examples::
        >>> t5_model = T5Model(encoder_only=False)
        >>> src = torch.randint(0, 32128, (32, 10))
        >>> tgt = torch.randint(0, 32128, (32, 20))
        >>> out = t5_model(src, tgt)
    """

    def __init__(
        self,
        encoder_only: bool,
        d_model: int = 768,
        nhead: int = 12,
        num_encoder_layers: int = 12,
        num_decoder_layers: int = 12,
        dim_feedforward: int = 3072,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = "relu",
        layer_norm_eps: float = 1e-6,
        relative_attention_num_buckets: int = 32,
        relative_attention_max_distance: int = 128,
        padding_idx: int = 0,
        max_seq_len: int = 512,
        vocab_size: int = 32128,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()

        self.encoder_only = encoder_only
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.activation = activation
        self.layer_norm_eps = layer_norm_eps
        self.nhead = nhead
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.relative_attention_max_distance = relative_attention_max_distance
        self.padding_idx = padding_idx
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.device = device
        self.dtype = dtype

        self.token_embeddings = nn.Embedding(vocab_size, d_model, padding_idx)
        self.encoder = T5Stack(
            is_decoder=False,
            d_model=d_model,
            nhead=nhead,
            num_layers=num_encoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            relative_attention_num_buckets=relative_attention_num_buckets,
            relative_attention_max_distance=relative_attention_max_distance,
            device=device,
            dtype=dtype,
        )
        self.norm1 = T5LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        if not encoder_only:
            self.decoder = T5Stack(
                is_decoder=True,
                d_model=d_model,
                nhead=nhead,
                num_layers=num_decoder_layers,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                activation=activation,
                layer_norm_eps=layer_norm_eps,
                relative_attention_num_buckets=relative_attention_num_buckets,
                relative_attention_max_distance=relative_attention_max_distance,
                device=device,
                dtype=dtype,
            )
            self.norm2 = T5LayerNorm(d_model)
            self.dropout3 = nn.Dropout(dropout)
            self.dropout4 = nn.Dropout(dropout)
    def forward(
        self,
        encoder_tokens: Tensor,
        decoder_tokens: Optional[Tensor] = None,
        encoder_mask: Optional[Tensor] = None,
        decoder_mask: Optional[Tensor] = None,
    ) -> Dict[str, Union[Tensor, Tuple[Tensor]]]:
        r"""Pass the inputs (and masks) through the encoder, and then through the decoder when the model is not encoder-only.

        Args:
            encoder_tokens: Tokenized input sequence to the encoder.
                Must be batch first with shape (B, Ne) where B is the batch size and Ne is the
                encoder input sequence length (required).
            decoder_tokens: Tokenized input sequence to the decoder.
                Must be batch first with shape (B, Nd) where B is the batch size and Nd is the
                decoder input sequence length (required when the model is not encoder-only).
            encoder_mask: Self-attention mask for the encoder input sequence.
                Must have shape (Ne, Ne) (optional).
            decoder_mask: Self-attention mask for the decoder input sequence.
                Must have shape (Nd, Nd) (optional).

        Returns:
            encoder_output: Output Tensor from the final layer of the encoder
            encoder_hidden_states: Tuple of output Tensors from each layer of the encoder
            encoder_position_bias: Tensor of relative attention bias computed for the input sequence to the encoder
            encoder_sa_scores: Tuple of self-attention scores computed at each layer of the encoder
            decoder_output: Output Tensor from the final layer of the decoder
            decoder_hidden_states: Tuple of output Tensors from each layer of the decoder
            decoder_position_bias: Tensor of relative attention bias computed for the input sequence to the decoder
            decoder_sa_scores: Tuple of self-attention scores computed at each layer of the decoder
            decoder_ca_scores: Tuple of cross-attention scores computed at each layer of the decoder
        """
        encoder_padding_mask = encoder_tokens.eq(self.padding_idx)
        encoder_embeddings = self.dropout1(self.token_embeddings(encoder_tokens))
        encoder_output, encoder_hidden_states, encoder_position_bias, encoder_sa, _ = self.encoder(
            encoder_embeddings, tgt_mask=encoder_mask, tgt_key_padding_mask=encoder_padding_mask
        )

        encoder_output = self.norm1(encoder_output)
        encoder_output = self.dropout2(encoder_output)
        encoder_hidden_states = encoder_hidden_states + (encoder_output,)

        if not self.encoder_only:
            assert decoder_tokens is not None
            if decoder_mask is None:
                tgt_len = decoder_tokens.shape[1]
                decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1).bool()

            decoder_padding_mask = decoder_tokens.eq(self.padding_idx)
            # The T5 implementation uses the padding idx to start the sequence. Want to ignore this when masking.
            decoder_padding_mask[:, 0] = False

            decoder_embeddings = self.dropout3(self.token_embeddings(decoder_tokens))
            decoder_output, decoder_hidden_states, decoder_position_bias, decoder_sa, decoder_ca = self.decoder(
                decoder_embeddings,
                memory=encoder_output,
                tgt_mask=decoder_mask,
                memory_mask=encoder_mask,
                tgt_key_padding_mask=decoder_padding_mask,
                memory_key_padding_mask=encoder_padding_mask,
            )

            decoder_output = self.norm2(decoder_output)
            decoder_output = self.dropout4(decoder_output)
            decoder_hidden_states = decoder_hidden_states + (decoder_output,)

            t5_output = {
                "encoder_output": encoder_output,
                "encoder_hidden_states": encoder_hidden_states,
                "encoder_position_bias": encoder_position_bias,
                "encoder_sa_scores": encoder_sa,
                "decoder_output": decoder_output,
                "decoder_hidden_states": decoder_hidden_states,
                "decoder_position_bias": decoder_position_bias,
                "decoder_sa_scores": decoder_sa,
                "decoder_ca_scores": decoder_ca,
            }
        else:
            t5_output = {
                "encoder_output": encoder_output,
                "encoder_hidden_states": encoder_hidden_states,
                "encoder_position_bias": encoder_position_bias,
                "encoder_sa_scores": encoder_sa,
            }

        return t5_output
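
For reference, a minimal usage sketch of the wrapper above, assuming the prototype package is importable as torchtext.prototype.t5.model; the batch size, sequence lengths, and token ids below are arbitrary placeholders:

import torch
from torchtext.prototype.t5.model import T5Model

model = T5Model(encoder_only=False)
model.eval()

# Batch-first token ids: (B, Ne) for the encoder, (B, Nd) for the decoder.
encoder_tokens = torch.randint(0, 32128, (4, 16))
decoder_tokens = torch.randint(0, 32128, (4, 8))

with torch.no_grad():
    out = model(encoder_tokens, decoder_tokens)

print(out["encoder_output"].shape)  # expected: torch.Size([4, 16, 768])
print(out["decoder_output"].shape)  # expected: torch.Size([4, 8, 768])
print(sorted(out.keys()))           # encoder/decoder outputs, hidden states, position biases, attention scores

With encoder_only=True, decoder_tokens can be omitted and the returned dict contains only the four encoder_* entries.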
