From 1f99650502524c5d54bfd11f58d219c0c1a33bb2 Mon Sep 17 00:00:00 2001
From: pmabbo13
Date: Tue, 12 Jul 2022 17:35:42 -0400
Subject: [PATCH] add layer norm module for t5 model

[ghstack-poisoned]
---
 torchtext/prototype/t5/modules.py | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py
index 2e69e3ea19..fcc08e266d 100644
--- a/torchtext/prototype/t5/modules.py
+++ b/torchtext/prototype/t5/modules.py
@@ -531,3 +531,29 @@ def _relative_position_bucket(
         relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
 
         return relative_buckets
+
+
+# NOTE: Taken from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
+class T5LayerNorm(nn.Module):
+    def __init__(self, d_model, eps=1e-6):
+        """
+        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(d_model))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        # T5 uses a layer norm that only scales and does not shift, also known as Root Mean Square
+        # Layer Normalization (https://arxiv.org/abs/1910.07467). Thus the variance is calculated
+        # without the mean and there is no bias. Additionally, we make sure that the accumulation
+        # for half-precision inputs is done in fp32.
+
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+        # convert back to half precision if necessary
+        if self.weight.dtype in [torch.float16, torch.bfloat16]:
+            hidden_states = hidden_states.to(self.weight.dtype)
+
+        return self.weight * hidden_states
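
A minimal usage sketch of the new module follows, assuming T5LayerNorm is importable from torchtext.prototype.t5.modules once the patch is applied; the embedding size and tensor shapes are illustrative and not part of the patch.

import torch
from torchtext.prototype.t5.modules import T5LayerNorm  # assumed import path

d_model = 768  # illustrative embedding size
layer_norm = T5LayerNorm(d_model, eps=1e-6)

# (batch, seq_len, d_model) activations; each vector is rescaled by its root mean square
# and then multiplied elementwise by the learned weight (initialized to ones).
hidden_states = torch.randn(2, 16, d_model)
normed = layer_norm(hidden_states)
print(normed.shape)  # torch.Size([2, 16, 768])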