diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py
index f4b3ffec62..720f21f399 100644
--- a/torchtext/prototype/t5/modules.py
+++ b/torchtext/prototype/t5/modules.py
@@ -418,3 +418,35 @@ def _relative_position_bucket(
 
         relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
         return relative_buckets
+
+
+# NOTE: Taken from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L239
+class T5LayerNorm(nn.Module):
+    def __init__(self, d_model, eps=1e-6) -> None:
+        """
+        Construct a layernorm module in the T5 style. No bias and no subtraction of the mean.
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(d_model))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states: Tensor) -> Tensor:
+        r"""
+        T5 uses a layer norm that only scales and does not shift, also known as Root Mean
+        Square Layer Normalization (https://arxiv.org/abs/1910.07467). The variance is therefore
+        computed without subtracting the mean, and there is no bias. Additionally, the
+        accumulation for half-precision inputs is done in fp32.
+        Args:
+            hidden_states: Tensor to be normalized. The final dimension must be the model dimension (i.e. the number of expected features in the input).
+        Returns:
+            A Tensor with the same shape as hidden_states after normalization.
+        """
+
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+        # Convert into half-precision if necessary
+        if self.weight.dtype in [torch.float16, torch.bfloat16]:
+            hidden_states = hidden_states.to(self.weight.dtype)
+
+        return self.weight * hidden_states
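
For reference (not part of the patch), below is a minimal sketch of the computation the added T5LayerNorm performs, written out directly so the no-centering/no-bias behaviour is easy to see. The names x, weight, eps, and d_model are illustrative placeholders, and the torch.nn.LayerNorm call is included only as a contrast.

import torch

eps = 1e-6
d_model = 8
weight = torch.ones(d_model)      # plays the role of T5LayerNorm.weight
x = torch.randn(2, 4, d_model)    # (batch, seq, d_model)

# Root-mean-square normalization over the last dimension: no mean subtraction, no bias.
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
out = weight * (x * torch.rsqrt(variance + eps))

# Standard LayerNorm, by contrast, also centers the input and (by default) learns a bias term.
ln = torch.nn.LayerNorm(d_model, eps=eps)
print(out.shape, ln(x).shape)  # torch.Size([2, 4, 8]) torch.Size([2, 4, 8])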