32 changes: 32 additions & 0 deletions torchtext/prototype/t5/modules.py
@@ -418,3 +418,35 @@ def _relative_position_bucket(

relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets


# NOTE: Taken from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L239
class T5LayerNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-6) -> None:
        """
        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.variance_epsilon = eps

    def forward(self, hidden_states: Tensor) -> Tensor:
        r"""
        T5 uses a layer norm that only scales and does not shift, also known as Root Mean
        Square Layer Normalization (https://arxiv.org/abs/1910.07467). The variance is
        therefore computed without subtracting the mean, and there is no bias. Additionally,
        we want to make sure that the accumulation for half-precision inputs is done in fp32.

        Args:
            hidden_states: Tensor to be normalized. The final dimension must be the model
                dimension (i.e. the number of expected features in the input).

        Returns:
            a Tensor with the same shape as ``hidden_states`` after normalization
        """
        # Accumulate the mean square in fp32 for numerical stability, then rescale the input.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # Convert back into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states
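
For reviewers, a minimal usage sketch (not part of the diff), assuming the class is importable from torchtext.prototype.t5.modules as the file path above suggests; the fp32 reference check at the end is purely illustrative:

import torch

from torchtext.prototype.t5.modules import T5LayerNorm  # assumed import path, per the file in this PR

d_model = 512
layer_norm = T5LayerNorm(d_model, eps=1e-6).half()  # fp16 weight to exercise the half-precision path

# Half-precision input: the variance is accumulated in fp32, then the normalized tensor is
# cast back to the weight's dtype (fp16 here) before scaling.
hidden_states = torch.randn(2, 8, d_model, dtype=torch.float16)
out = layer_norm(hidden_states)
assert out.shape == hidden_states.shape and out.dtype == torch.float16

# Illustrative fp32 reference: RMS normalization with unit weights should match closely.
ref = hidden_states.float() * torch.rsqrt(hidden_states.float().pow(2).mean(-1, keepdim=True) + 1e-6)
torch.testing.assert_close(out.float(), ref, rtol=1e-2, atol=1e-2)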