
Commit 568259f

add layer norm module for t5 model (#1826)
1 parent 6f5aa67

File tree

1 file changed: +32 −0 lines changed


torchtext/prototype/t5/modules.py

```diff
@@ -418,3 +418,35 @@ def _relative_position_bucket(
 
         relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
         return relative_buckets
+
+
+# NOTE: Taken from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L239
+class T5LayerNorm(nn.Module):
+    def __init__(self, d_model, eps=1e-6) -> None:
+        """
+        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(d_model))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states: Tensor) -> Tensor:
+        r"""
+        T5 uses a layer norm that only scales and doesn't shift, also known as Root Mean
+        Square Layer Normalization (https://arxiv.org/abs/1910.07467); thus the variance is
+        calculated without the mean and there is no bias. Additionally, we want to make sure
+        that the accumulation for half-precision inputs is done in fp32.
+        Args:
+            hidden_states: Tensor to be normalized. Final dimension must be the model dimension (i.e. the number of expected features in the input)
+        Returns:
+            a Tensor with the same shape as hidden_states after having been normalized
+        """
+
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+        # Convert into half-precision if necessary
+        if self.weight.dtype in [torch.float16, torch.bfloat16]:
+            hidden_states = hidden_states.to(self.weight.dtype)
+
+        return self.weight * hidden_states
```
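For reference, a minimal usage sketch (not part of the commit). It assumes torch is installed and that T5LayerNorm is importable from torchtext.prototype.t5.modules, the file this commit touches; the shapes and dtype below are illustrative:

```python
import torch

from torchtext.prototype.t5.modules import T5LayerNorm  # module added in this commit

# RMS-style norm: y = w * x / sqrt(mean(x**2) + eps); no mean subtraction, no bias.
layer_norm = T5LayerNorm(d_model=512).to(torch.float16)

# Half-precision input: forward still accumulates the variance in fp32,
# then casts the result back to the weight's dtype (float16 here).
hidden_states = torch.randn(2, 8, 512, dtype=torch.float16)
out = layer_norm(hidden_states)

print(out.shape, out.dtype)  # torch.Size([2, 8, 512]) torch.float16
```

The fp32 accumulation matters because squaring fp16 activations can under- or overflow; computing the mean of squares in fp32 keeps the scale estimate stable while the module's parameters and output stay in half precision.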
