
Commit 6b6003d

add t5 model that can function as either an encoder-only or encoder-decoder model
ghstack-source-id: 946c573
Pull Request resolved: #1829
1 parent: c8d1725

File tree

1 file changed: +178 lines, -0 lines


torchtext/prototype/t5/model.py

Lines changed: 178 additions & 0 deletions
@@ -0,0 +1,178 @@
from typing import Optional, Union, Callable

import torch
import torch.nn as nn
from torch import Tensor

from .modules import T5Stack, T5LayerNorm


class T5Model(nn.Module):
    r"""A T5 model. Users can modify the attributes as needed. The architecture
    is based on the paper "Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer".
    Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena,
    Yanqi Zhou, Wei Li, and Peter J. Liu. 2020. Journal of Machine Learning Research.
    Volume 21 Issue 140 pages 1-67. http://jmlr.org/papers/v21/20-074.html
    Args:
        encoder_only: whether the model should consist of only the encoder as opposed to encoder-decoder (required)
        d_model: the number of expected features in the encoder/decoder inputs (default=768).
        nhead: the number of heads in the multiheadattention models (default=12).
        num_encoder_layers: the number of sub-encoder-layers in the encoder (default=12).
        num_decoder_layers: the number of sub-decoder-layers in the decoder (default=12).
        dim_feedforward: the dimension of the feedforward network model (default=3072).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the encoder/decoder intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: the eps value in layer normalization components (default=1e-6).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``True``.
        relative_attention_num_buckets: the number of relative position buckets (default: 32)
        relative_attention_max_distance: maximum threshold on the relative distance used to
            allocate buckets. Anything larger than that gets placed in the same bucket (default: 128)
        padding_idx: index assigned to the padding token in the vocabulary (default: 0)
        max_seq_len: maximum sequence length (default: 512)
        vocab_size: size of the vocabulary (default: 32128)
    Examples::
        >>> t5_model = T5Model(encoder_only=False)
        >>> src = torch.randint(0, 32128, (32, 10))
        >>> tgt = torch.randint(0, 32128, (32, 20))
        >>> out = t5_model(src, tgt)
    """

    def __init__(
        self,
        encoder_only: bool,
        d_model: int = 768,
        nhead: int = 12,
        num_encoder_layers: int = 12,
        num_decoder_layers: int = 12,
        dim_feedforward: int = 3072,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = "relu",
        layer_norm_eps: float = 1e-6,
        batch_first: bool = True,
        relative_attention_num_buckets: int = 32,
        relative_attention_max_distance: int = 128,
        padding_idx: int = 0,
        max_seq_len: int = 512,
        vocab_size: int = 32128,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()

        self.encoder_only = encoder_only
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.activation = activation
        self.layer_norm_eps = layer_norm_eps
        self.nhead = nhead
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.batch_first = batch_first
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.relative_attention_max_distance = relative_attention_max_distance
        self.padding_idx = padding_idx
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.device = device
        self.dtype = dtype

        # Token embeddings are shared between the encoder and decoder inputs.
        self.token_embeddings = nn.Embedding(vocab_size, d_model, padding_idx)
        self.encoder = T5Stack(
            is_decoder=False,
            d_model=d_model,
            nhead=nhead,
            num_layers=num_encoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            batch_first=batch_first,
            relative_attention_num_buckets=relative_attention_num_buckets,
            relative_attention_max_distance=relative_attention_max_distance,
            device=device,
            dtype=dtype,
        )
        self.norm1 = T5LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # The decoder stack is only built when the model is used as an encoder-decoder.
        if not encoder_only:
            self.decoder = T5Stack(
                is_decoder=True,
                d_model=d_model,
                nhead=nhead,
                num_layers=num_decoder_layers,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                activation=activation,
                layer_norm_eps=layer_norm_eps,
                batch_first=batch_first,
                relative_attention_num_buckets=relative_attention_num_buckets,
                relative_attention_max_distance=relative_attention_max_distance,
                device=device,
                dtype=dtype,
            )
            self.norm2 = T5LayerNorm(d_model)
            self.dropout3 = nn.Dropout(dropout)
            self.dropout4 = nn.Dropout(dropout)

    def forward(
        self,
        encoder_tokens: Tensor,
        decoder_tokens: Optional[Tensor] = None,
        encoder_mask: Optional[Tensor] = None,
        decoder_mask: Optional[Tensor] = None,
    ):
        # Encoder pass: embed the tokens, mask padding positions, and run the encoder stack.
        encoder_padding_mask = encoder_tokens.eq(self.padding_idx)
        encoder_embeddings = self.dropout1(self.token_embeddings(encoder_tokens))
        encoder_output, encoder_hidden_states, encoder_position_bias, encoder_sa, _ = self.encoder(
            encoder_embeddings, tgt_mask=encoder_mask, tgt_key_padding_mask=encoder_padding_mask
        )

        encoder_output = self.norm1(encoder_output)
        encoder_output = self.dropout2(encoder_output)
        encoder_hidden_states = encoder_hidden_states + (encoder_output,)

        decoder_output = None
        decoder_hidden_states = None
        decoder_position_bias = None
        decoder_sa = None
        decoder_ca = None

        if not self.encoder_only:
            assert decoder_tokens is not None
            # Default to a causal mask so each target position only attends to earlier positions.
            if decoder_mask is None:
                tgt_len = decoder_tokens.shape[1]
                decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1).bool()

            decoder_padding_mask = decoder_tokens.eq(self.padding_idx)
            decoder_embeddings = self.dropout3(self.token_embeddings(decoder_tokens))
            decoder_output, decoder_hidden_states, decoder_position_bias, decoder_sa, decoder_ca = self.decoder(
                decoder_embeddings,
                memory=encoder_output,
                tgt_mask=decoder_mask,
                memory_mask=encoder_mask,
                tgt_key_padding_mask=decoder_padding_mask,
                memory_key_padding_mask=encoder_padding_mask,
            )

            decoder_output = self.norm2(decoder_output)
            decoder_output = self.dropout4(decoder_output)
            decoder_hidden_states = decoder_hidden_states + (decoder_output,)

        t5_output = {
            "encoder_output": encoder_output,
            "encoder_hidden_states": encoder_hidden_states,
            "encoder_position_bias": encoder_position_bias,
            "encoder_sa_scores": encoder_sa,
            "decoder_output": decoder_output,
            "decoder_hidden_states": decoder_hidden_states,
            "decoder_position_bias": decoder_position_bias,
            "decoder_sa_scores": decoder_sa,
            "decoder_ca_scores": decoder_ca,
        }

        return t5_output
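
For reference, a minimal usage sketch of the module added above. It assumes the import path implied by the file location (torchtext/prototype/t5/model.py) and randomly initialized weights (no pretrained checkpoint is loaded), so the outputs only illustrate the encoder-only versus encoder-decoder behavior; the shapes in the comments assume T5Stack preserves the default batch-first (batch, seq, d_model) layout, with the default vocab_size=32128 and padding_idx=0.

    import torch

    from torchtext.prototype.t5.model import T5Model

    # Encoder-decoder configuration: both stacks are built and run.
    seq2seq = T5Model(encoder_only=False)
    src = torch.randint(1, 32128, (4, 16))  # (batch, src_len) token ids; 0 is reserved for padding
    tgt = torch.randint(1, 32128, (4, 8))   # (batch, tgt_len) token ids
    out = seq2seq(src, tgt)
    print(out["encoder_output"].shape)      # expected: torch.Size([4, 16, 768])
    print(out["decoder_output"].shape)      # expected: torch.Size([4, 8, 768])

    # Encoder-only configuration: no decoder stack is created and the
    # decoder_* entries of the returned dict stay None.
    encoder_model = T5Model(encoder_only=True)
    out = encoder_model(src)
    print(out["encoder_output"].shape)      # expected: torch.Size([4, 16, 768])
    print(out["decoder_output"])            # None

Because the decoder stack is only constructed when encoder_only=False, the encoder-only configuration can serve as a standalone text encoder without allocating decoder parameters.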
