torchtext/prototype/t5/modules.py (173 additions, 1 deletion)
@@ -15,7 +15,7 @@

import math
import warnings
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union, Callable

import torch
import torch.nn as nn
@@ -450,3 +450,175 @@ def forward(self, hidden_states: Tensor) -> Tensor:
hidden_states = hidden_states.to(self.weight.dtype)

return self.weight * hidden_states


# NOTE: Comparable HuggingFace implementation can be found at https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L622
class T5Layer(nn.Module):
r"""T5Layer is made up of self-attn, cross-attn (decoder only) and feedforward network.
This T5 layer is based on the paper:
"Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer".
Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena,
Yanqi Zhou, Wei Li, and Peter J. Liu. 2020. Journal of Machine Learning Research.
Volume 21 Issue 140 pages 1-67. http://jmlr.org/papers/v21/20-074.html
Users may modify this layer or implement it differently for their own applications.
Args:
is_decoder: Whether or not the layer belongs to the decoder. (required)
d_model: Number of expected features in the input (required).
nhead: Number of heads in the multihead attention models (required).
dim_feedforward: Dimension of the feedforward network model (default=3072).
dropout: Dropout value (default=0.1).
activation: Activation function of the intermediate layer, can be a string
("relu" or "gelu") or a unary callable. (default: relu)
layer_norm_eps: The eps value in layer normalization components (default=1e-6).
relative_attention_num_buckets: Number of relative position buckets (default: 32)
relative_attention_max_distance: Maximum threshold on the relative distance used to
allocate buckets. Anything larger gets placed in the same bucket (default: 128)
compute_relative_attention_bias: Whether or not to compute the relative position embeddings.
This is typically done only in the first layer of the encoder/decoder, and the resulting
position embeddings are returned to be passed up to the higher layers. (default: False)
relative_attention_bias: nn.Embedding object used to compute relative position embeddings. (default: None)

Examples::
>>> decoder_layer = T5Layer(is_decoder=True, d_model=768, nhead=12)
>>> memory = torch.rand(32, 10, 768)
>>> tgt = torch.rand(32, 20, 768)
>>> out = decoder_layer(tgt, memory)
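>>> # An encoder layer accepts the same call signature; the `memory` argument
>>> # is unused when is_decoder=False, so the input can simply be passed twice
>>> encoder_layer = T5Layer(is_decoder=False, d_model=768, nhead=12)
>>> src = torch.rand(32, 10, 768)
>>> out = encoder_layer(src, src)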
"""

def __init__(
self,
is_decoder: bool,
d_model: int,
nhead: int,
dim_feedforward: int = 3072,
dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
layer_norm_eps: float = 1e-6,
relative_attention_num_buckets: int = 32,
relative_attention_max_distance: int = 128,
compute_relative_attention_bias: bool = False,
relative_attention_bias: Optional[Tensor] = None,
device=None,
dtype=None,
) -> None:
super().__init__()

self.is_decoder = is_decoder
self.compute_relative_attention_bias = compute_relative_attention_bias
self.relative_attention_num_buckets = relative_attention_num_buckets
self.relative_attention_max_distance = relative_attention_max_distance
self.relative_attention_bias = relative_attention_bias

self.self_attn = T5MultiheadAttention(
d_model, nhead, is_decoder=is_decoder, dropout=dropout, device=device, dtype=dtype
)
self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)
self.norm1 = T5LayerNorm(d_model, eps=layer_norm_eps)
self.norm2 = T5LayerNorm(d_model, eps=layer_norm_eps)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)

if is_decoder:
self.cross_attn = T5MultiheadAttention(
d_model, nhead, is_decoder=is_decoder, dropout=dropout, device=device, dtype=dtype
)
self.norm3 = T5LayerNorm(d_model, eps=layer_norm_eps)
self.dropout4 = nn.Dropout(dropout)

if isinstance(activation, str):
assert activation in (
"relu",
"gelu",
), f"Do not support '{activation}' activation. Use either 'relu' or 'gelu'"
if activation == "relu":
self.activation = F.relu
elif activation == "gelu":
self.activation = F.gelu
else:
self.activation = activation
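# e.g. activation="gelu" and activation=F.gelu both resolve to F.gelu above;
# any other unary callable (Tensor -> Tensor) is stored as given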

def forward(
self,
tgt: Tensor,
memory: Tensor,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
position_bias: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]:
r"""Pass the inputs (and mask) through the encoder/decoder layer.
Args:
tgt: Input sequence to the encoder/decoder layer. (required).
Must have shape (B, Nt, E) where B is the batch size, Nt is the target sequence
length, and E is the model dimension.
memory: Sequence from the last layer of the encoder (used for decoder only). (required).
Must have shape (B, Ns, E) where B is the batch size, Ns is the source sequence
length, and E is the model dimension.
tgt_mask: Attention mask for self-attention. (optional).
Must have shape (Nt, Nt).
memory_mask: Attention mask for cross-attention (decoder-only) (optional).
Must have shape (Nt, Ns).
tgt_key_padding_mask: Mask for the tgt keys per batch (optional).
Must have shape (B, Nt).
memory_key_padding_mask: Mask for the memory keys per batch (decoder-only) (optional).
Must have shape (B, Ns).
position_bias: Relative attention bias to be used when computing self-attention scores (optional)
Must have shape (B, H, Nt, Nt) where H is the number of heads.
"""

# See Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
x = tgt
sa_out, position_bias, sa_scores = self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, position_bias)
x = x + sa_out
if self.is_decoder:
ca_out, ca_scores = self._ca_block(self.norm3(x), memory, memory_mask, memory_key_padding_mask)
x = x + ca_out
x = x + self._ff_block(self.norm2(x))

return x, position_bias, sa_scores, (ca_scores if self.is_decoder else None)

# Self-attention block
def _sa_block(
self,
x: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
position_bias: Optional[Tensor],
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
attn = self.self_attn(
x,
x,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=True,
compute_relative_attention_bias=self.compute_relative_attention_bias,
relative_attention_num_buckets=self.relative_attention_num_buckets,
relative_attention_max_distance=self.relative_attention_max_distance,
relative_attention_bias=self.relative_attention_bias,
position_bias=position_bias,
)

x = attn[0]
scores = attn[2]
if self.compute_relative_attention_bias and position_bias is None:
position_bias = attn[1]

return self.dropout1(x), position_bias, scores

# Cross-attention block
def _ca_block(
self, x: Tensor, mem: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]
) -> Tuple[Tensor, Optional[Tensor]]:
attn = self.cross_attn(x, mem, mem, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=True)
x = attn[0]
scores = attn[2]
return self.dropout4(x), scores

# Feed-forward block
def _ff_block(self, x: Tensor) -> Tensor:
x = self.linear2(self.dropout2(self.activation(self.linear1(x))))
return self.dropout3(x)
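

# Usage sketch (illustrative only, not part of the diff): composing an encoder
# layer and a decoder layer. Shapes follow the docstrings above; each call
# unpacks the 4-tuple that forward() returns.
#
#   encoder_layer = T5Layer(is_decoder=False, d_model=768, nhead=12)
#   decoder_layer = T5Layer(is_decoder=True, d_model=768, nhead=12)
#   src = torch.rand(32, 10, 768)                     # (B, Ns, E)
#   tgt = torch.rand(32, 20, 768)                     # (B, Nt, E)
#   memory, enc_bias, _, _ = encoder_layer(src, src)  # memory arg unused in encoder
#   out, dec_bias, sa_scores, ca_scores = decoder_layer(tgt, memory)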