From 70dad25b5e4c738c2e92a3e138a5987dfda0d8f6 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Tue, 12 Jul 2022 17:35:25 -0400 Subject: [PATCH 01/16] compute relative position buckets for relative attention bias [ghstack-poisoned] --- torchtext/prototype/t5/modules.py | 67 +++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 torchtext/prototype/t5/modules.py diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py new file mode 100644 index 0000000000..de8d438e37 --- /dev/null +++ b/torchtext/prototype/t5/modules.py @@ -0,0 +1,67 @@ +# /* Portions Copyright (c) Meta Platforms, Inc. and affiliates. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Original code is taken from +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py +# */ + +import math + +import torch +from torch import Tensor + + +# NOTE: taken from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py +def _relative_position_bucket( + relative_position: Tensor, bidirectional: bool = True, num_buckets: int = 32, max_distance: int = 128 +): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
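A minimal sketch of the default bucketing (assuming the helper above is importable from ``torchtext.prototype.t5.modules`` as added by this patch)::

    >>> import torch
    >>> from torchtext.prototype.t5.modules import _relative_position_bucket
    >>> rel_pos = torch.arange(-9, 10).view(1, -1)        # offsets -9 .. 9
    >>> buckets = _relative_position_bucket(rel_pos)      # bidirectional=True, num_buckets=32
    >>> # offset 0 -> bucket 0, -1 -> 1, ..., -7 -> 7 are exact; positive offsets are shifted
    >>> # by num_buckets // 2 = 16 (e.g. +1 -> 17, +7 -> 23); |offsets| >= 8 fall into the
    >>> # log-spaced buckets, which saturate at 15 (negative side) and 31 (positive side).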
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets From 88743d6b4e327d2a85e58d34ed7b2e12524222fd Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Tue, 12 Jul 2022 17:35:29 -0400 Subject: [PATCH 02/16] compute relative position bias for t5 attention [ghstack-poisoned] --- torchtext/prototype/t5/modules.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py index de8d438e37..a3db63b9ae 100644 --- a/torchtext/prototype/t5/modules.py +++ b/torchtext/prototype/t5/modules.py @@ -19,6 +19,33 @@ from torch import Tensor +# NOTE: modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py +def _compute_bias( + query_length: int, + key_length: int, + relative_attention_bias: Tensor, + relative_attention_num_buckets: int = 32, + relative_attention_max_distance: int = 128, + bidirectional: bool = True, + device=None, +): + """Compute binned relative position bias""" + if device is None: + device = relative_attention_bias.weight.device + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=bidirectional, + num_buckets=relative_attention_num_buckets, + max_distance=relative_attention_max_distance, + ) + values = relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + # NOTE: taken from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py def _relative_position_bucket( relative_position: Tensor, bidirectional: bool = True, num_buckets: int = 32, max_distance: int = 128 From 7b67d563cdbd1b69ac5a40bbaf8efdbf813e796a Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Tue, 12 Jul 2022 17:35:33 
-0400 Subject: [PATCH 03/16] compute attention scores for t5 model using relative attention bias [ghstack-poisoned] --- torchtext/prototype/t5/modules.py | 61 +++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py index a3db63b9ae..5ba410a593 100644 --- a/torchtext/prototype/t5/modules.py +++ b/torchtext/prototype/t5/modules.py @@ -14,11 +14,72 @@ # */ import math +from typing import Optional, Tuple import torch +import torch.nn.functional as F from torch import Tensor +def _t5_scaled_dot_product_attention( + q: Tensor, + k: Tensor, + v: Tensor, + position_bias: Tensor, + attn_mask: Optional[Tensor] = None, + dropout_p: float = 0.0, +) -> Tuple[Tensor, Tensor]: + r""" + Computes scaled dot product attention on query, key and value tensors, using + an optional attention mask if passed, and applying dropout if a probability + greater than 0.0 is specified. + Returns a tensor pair containing attended values and attention weights. + Args: + q, k, v: query, key and value tensors. See Shape section for shape details. + attn_mask: optional tensor containing mask values to be added to calculated + attention. May be 2D or 3D; see Shape section for details. + dropout_p: dropout probability. If greater than 0.0, dropout is applied. + position_bias: position bias used to incorporate realtive attention bias in attention scors + Shape: + - q: :math:`(B, Nt, E)` where B is (batch size*num_heads), Nt is the target sequence length, + and E is embedding dimension. + - key: :math:`(B, Ns, E)` where B is (batch size*num_heads), Ns is the source sequence length, + and E is embedding dimension. + - value: :math:`(B, Ns, E)` where B is (batch size*num_heads), Ns is the source sequence length, + and E is embedding dimension. + - attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of + shape :math:`(Nt, Ns)`. + - position_bias: :math:`(1, num_heads, Nt, Ns)` + - Output: attention values have shape :math:`(B, Nt, E)`; attention weights + have shape :math:`(B, Nt, Ns)` + """ + B, Nt, E = q.shape + # NOTE: HF implementation does not perform this normalization. 
For the sake of matching test results, we have commented it out + # q = q / math.sqrt(E) + + n_heads, tgt_len, src_len = position_bias.size()[1:] + assert B % n_heads == 0 + assert tgt_len == Nt + + # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns) + if attn_mask is not None: + attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1)) + else: + attn = torch.bmm(q, k.transpose(-2, -1)) + + # NOTE: modification from torch.nn.functional._scaled_dot_product_attention to incorporate relative attention bias + position_bias = position_bias.repeat(B // n_heads, 1, 1, 1) + position_bias = position_bias.view(B, tgt_len, src_len) + attn += position_bias + + attn = F.softmax(attn, dim=-1) + if dropout_p > 0.0: + attn = F.dropout(attn, p=dropout_p) + # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E) + output = torch.bmm(attn, v) + return output, attn + + # NOTE: modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py def _compute_bias( query_length: int, From f3fac0e791a78d3f84d0b86467fdeed486996a4c Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Tue, 12 Jul 2022 17:35:36 -0400 Subject: [PATCH 04/16] perform multihead attention using relative attention bias for t5 model [ghstack-poisoned] --- torchtext/prototype/t5/modules.py | 220 +++++++++++++++++++++++++++++- 1 file changed, 219 insertions(+), 1 deletion(-) diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py index 5ba410a593..3500f3b4e2 100644 --- a/torchtext/prototype/t5/modules.py +++ b/torchtext/prototype/t5/modules.py @@ -12,8 +12,8 @@ # Original code is taken from # https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py # */ - import math +import warnings from typing import Optional, Tuple import torch @@ -21,6 +21,224 @@ from torch import Tensor +def t5_multi_head_attention_forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + compute_relative_attention_bias: bool, + relative_attention_num_buckets: Optional[int], + relative_attention_max_distance: Optional[int], + relative_attention_bias: Optional[Tensor], + position_bias: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + static_k: Optional[Tensor] = None, + static_v: Optional[Tensor] = None, + average_attn_weights: bool = False, +) -> Tuple[Tensor, Optional[Tensor]]: + is_batched = F._mha_shape_check(query, key, value, key_padding_mask, attn_mask, self.num_heads) + + # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input + # is batched, run the computation and before returning squeeze the + # batch dimension so that the output doesn't carry this temporary batch dimension. 
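    # (Illustrative sketch of the `_t5_scaled_dot_product_attention` helper from the previous
    #  patch, with B = batch_size * num_heads: T5 adds the position bias to the raw scores and,
    #  unlike the stock PyTorch implementation, applies no 1/sqrt(E) scaling to q.)
    #
    #     scores = torch.bmm(q, k.transpose(-2, -1))        # (B, Nt, Ns); attn_mask, if given, is added here
    #     scores += position_bias.repeat(B // n_heads, 1, 1, 1).view(B, Nt, Ns)
    #     attn = F.softmax(scores, dim=-1)                  # followed by optional dropout
    #     output = torch.bmm(attn, v)                       # (B, Nt, E)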
+ if not is_batched: + # unsqueeze if the input is unbatched + query = query.unsqueeze(1) + key = key.unsqueeze(1) + value = value.unsqueeze(1) + if key_padding_mask is not None: + key_padding_mask = key_padding_mask.unsqueeze(0) + + # set up shape vars + tgt_len, bsz, embed_dim = query.shape + src_len, _, _ = key.shape + + assert embed_dim == self.embed_dim, f"was expecting embedding dimension of {self.embed_dim}, but got {embed_dim}" + if isinstance(embed_dim, Tensor): + # embed_dim can be a tensor when JIT tracing + head_dim = embed_dim.div(self.num_heads, rounding_mode="trunc") + else: + head_dim = embed_dim // self.num_heads + assert head_dim * self.num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {self.num_heads}" + # allow MHA to have different embedding dimensions when separate projection weights are used + assert ( + key.shape[:2] == value.shape[:2] + ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" + + # + # compute in-projection + # + assert self.q_proj_weight is not None, "q_proj_weight is None" + assert self.k_proj_weight is not None, "k_proj_weight is None" + assert self.v_proj_weight is not None, "v_proj_weight is None" + if self.in_proj_bias is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = self.in_proj_bias.chunk(3) + q, k, v = F._in_projection( + query, key, value, self.q_proj_weight, self.k_proj_weight, self.v_proj_weight, b_q, b_k, b_v + ) + + # prep attention mask + if attn_mask is not None: + if attn_mask.dtype == torch.uint8: + warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") + attn_mask = attn_mask.to(torch.bool) + else: + assert ( + attn_mask.is_floating_point() or attn_mask.dtype == torch.bool + ), f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}" + # ensure attn_mask's dim is 3 + if attn_mask.dim() == 2: + correct_2d_size = (tgt_len, src_len) + if attn_mask.shape != correct_2d_size: + raise RuntimeError( + f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}." + ) + attn_mask = attn_mask.unsqueeze(0) + elif attn_mask.dim() == 3: + correct_3d_size = (bsz * self.num_heads, tgt_len, src_len) + if attn_mask.shape != correct_3d_size: + raise RuntimeError( + f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}." + ) + else: + raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported") + + # prep key padding mask + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead." + ) + key_padding_mask = key_padding_mask.to(torch.bool) + + # add bias along batch dimension (currently second) + if self.bias_k is not None and self.bias_v is not None: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." 
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = F.pad(key_padding_mask, (0, 1)) + else: + assert self.bias_k is None + assert self.bias_v is None + + # + # reshape q, k, v for multihead attention and make em batch first + # + + q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1) + if static_k is None: + k = k.contiguous().view(k.shape[0], bsz * self.num_heads, head_dim).transpose(0, 1) + else: + # TODO finish disentangling control flow so we don't do in-projections when statics are passed + assert ( + static_k.size(0) == bsz * self.num_heads + ), f"expecting static_k.size(0) of {bsz * self.num_heads}, but got {static_k.size(0)}" + assert static_k.size(2) == head_dim, f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}" + k = static_k + if static_v is None: + v = v.contiguous().view(v.shape[0], bsz * self.num_heads, head_dim).transpose(0, 1) + else: + # TODO finish disentangling control flow so we don't do in-projections when statics are passed + assert ( + static_v.size(0) == bsz * self.num_heads + ), f"expecting static_v.size(0) of {bsz * self.num_heads}, but got {static_v.size(0)}" + assert static_v.size(2) == head_dim, f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}" + v = static_v + + # add zero attention along batch dimension (now first) + if self.add_zero_attn: + zero_attn_shape = (bsz * self.num_heads, 1, head_dim) + k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1) + v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1) + if attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = F.pad(key_padding_mask, (0, 1)) + + # update source sequence length after adjustments + src_len = k.size(1) + + # merge key padding and attention masks + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + bsz, + src_len, + ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" + key_padding_mask = ( + key_padding_mask.view(bsz, 1, 1, src_len) + .expand(-1, self.num_heads, -1, -1) + .reshape(bsz * self.num_heads, 1, src_len) + ) + if attn_mask is None: + attn_mask = key_padding_mask + elif attn_mask.dtype == torch.bool: + attn_mask = attn_mask.logical_or(key_padding_mask) + else: + attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf")) + + # convert mask to float + if attn_mask is not None and attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + + # adjust dropout probability + if not self.training: + dropout_p = 0.0 + else: + dropout_p = self.dropout + + # NOTE: modification from torch.nn.functional._multi_head_attention_forward to incorporate relative attention bias + if position_bias is None: + if not compute_relative_attention_bias: + position_bias = torch.zeros((self.num_heads, tgt_len, src_len), device=k.device, dtype=k.dtype).unsqueeze(0) + else: + position_bias = self._compute_bias( + tgt_len, + src_len, + relative_attention_bias, + relative_attention_num_buckets=relative_attention_num_buckets, + relative_attention_max_distance=relative_attention_max_distance, + bidirectional=(not self.is_decoder), + device=k.device, + ) + + # 
calculate attention and out projection + attn_output, attn_output_weights = self._t5_scaled_dot_product_attention( + q, k, v, position_bias, attn_mask, dropout_p + ) + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim) + attn_output = F.linear(attn_output, self.out_proj.weight, self.out_proj.bias) + attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) + + if need_weights: + # optionally average attention weights over heads + attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len) + if average_attn_weights: + attn_output_weights = attn_output_weights.sum(dim=1) / self.num_heads + + if not is_batched: + # squeeze the output if input was unbatched + attn_output = attn_output.squeeze(1) + attn_output_weights = attn_output_weights.squeeze(0) + + return attn_output, attn_output_weights, position_bias + + else: + if not is_batched: + # squeeze the output if input was unbatched + attn_output = attn_output.squeeze(1) + + return attn_output, None, position_bias + + def _t5_scaled_dot_product_attention( q: Tensor, k: Tensor, From 13f9c2209520b28be51d219b43b8e3d22e34dba9 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Tue, 12 Jul 2022 17:35:39 -0400 Subject: [PATCH 05/16] create T5MultiheadAttention module [ghstack-poisoned] --- torchtext/prototype/t5/modules.py | 832 ++++++++++++++++++------------ 1 file changed, 496 insertions(+), 336 deletions(-) diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py index 3500f3b4e2..2e69e3ea19 100644 --- a/torchtext/prototype/t5/modules.py +++ b/torchtext/prototype/t5/modules.py @@ -12,362 +12,522 @@ # Original code is taken from # https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py # */ + import math import warnings from typing import Optional, Tuple import torch +import torch.nn as nn import torch.nn.functional as F from torch import Tensor -def t5_multi_head_attention_forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - compute_relative_attention_bias: bool, - relative_attention_num_buckets: Optional[int], - relative_attention_max_distance: Optional[int], - relative_attention_bias: Optional[Tensor], - position_bias: Optional[Tensor], - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[Tensor] = None, - static_k: Optional[Tensor] = None, - static_v: Optional[Tensor] = None, - average_attn_weights: bool = False, -) -> Tuple[Tensor, Optional[Tensor]]: - is_batched = F._mha_shape_check(query, key, value, key_padding_mask, attn_mask, self.num_heads) - - # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input - # is batched, run the computation and before returning squeeze the - # batch dimension so that the output doesn't carry this temporary batch dimension. 
- if not is_batched: - # unsqueeze if the input is unbatched - query = query.unsqueeze(1) - key = key.unsqueeze(1) - value = value.unsqueeze(1) - if key_padding_mask is not None: - key_padding_mask = key_padding_mask.unsqueeze(0) - - # set up shape vars - tgt_len, bsz, embed_dim = query.shape - src_len, _, _ = key.shape - - assert embed_dim == self.embed_dim, f"was expecting embedding dimension of {self.embed_dim}, but got {embed_dim}" - if isinstance(embed_dim, Tensor): - # embed_dim can be a tensor when JIT tracing - head_dim = embed_dim.div(self.num_heads, rounding_mode="trunc") - else: - head_dim = embed_dim // self.num_heads - assert head_dim * self.num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {self.num_heads}" - # allow MHA to have different embedding dimensions when separate projection weights are used - assert ( - key.shape[:2] == value.shape[:2] - ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" - - # - # compute in-projection - # - assert self.q_proj_weight is not None, "q_proj_weight is None" - assert self.k_proj_weight is not None, "k_proj_weight is None" - assert self.v_proj_weight is not None, "v_proj_weight is None" - if self.in_proj_bias is None: - b_q = b_k = b_v = None - else: - b_q, b_k, b_v = self.in_proj_bias.chunk(3) - q, k, v = F._in_projection( - query, key, value, self.q_proj_weight, self.k_proj_weight, self.v_proj_weight, b_q, b_k, b_v - ) - - # prep attention mask - if attn_mask is not None: - if attn_mask.dtype == torch.uint8: - warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") - attn_mask = attn_mask.to(torch.bool) - else: - assert ( - attn_mask.is_floating_point() or attn_mask.dtype == torch.bool - ), f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}" - # ensure attn_mask's dim is 3 - if attn_mask.dim() == 2: - correct_2d_size = (tgt_len, src_len) - if attn_mask.shape != correct_2d_size: - raise RuntimeError( - f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}." - ) - attn_mask = attn_mask.unsqueeze(0) - elif attn_mask.dim() == 3: - correct_3d_size = (bsz * self.num_heads, tgt_len, src_len) - if attn_mask.shape != correct_3d_size: - raise RuntimeError( - f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}." - ) +class T5MultiheadAttention(nn.MultiheadAttention): + def __init__( + self, + embed_dim, + num_heads, + is_decoder=False, + dropout=0.0, + bias=False, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + device=None, + dtype=None, + ) -> None: + r""" + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + is_decoder: whether or not multihead attention is being performed on a decoder layer. Default: ``False`` + dropout: probability of an element to be zeroed. Default: 0.0 + bias: If specified, adds bias to input / output projection layers. Default: ``False``. + add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``. + add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1. + Default: ``False``. + kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``). + vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``). 
+ batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + """ + super().__init__( + embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim, batch_first, device, dtype + ) + factory_kwargs = {"device": device, "dtype": dtype} + self.is_decoder = is_decoder + self.q_proj_weight = nn.Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs)) + self.k_proj_weight = nn.Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs)) + self.v_proj_weight = nn.Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs)) + self.register_parameter("in_proj_weight", None) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = False, + compute_relative_attention_bias=False, + relative_attention_num_buckets=32, + relative_attention_max_distance=128, + relative_attention_bias: Optional[Tensor] = None, + position_bias: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Allows the model to jointly attend to information from different representation subspaces + as described in the paper: + `Attention Is All You Need `_. + Also incorporates relative attention bias when computing attention scores as descripted in the paper: + `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer `_. + + Args: + query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` + or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, + :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. + Queries are compared against key-value pairs to produce the output. + See "Attention Is All You Need" for more details. + key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` + or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, + :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. + See "Attention Is All You Need" for more details. + value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when + ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source + sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. + See "Attention Is All You Need" for more details. + key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` + to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. + Binary and byte masks are supported. + For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for + the purpose of attention. For a byte mask, a non-zero value indicates that the corresponding ``key`` + value will be ignored. + need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. + Default: ``True``. + attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. 
Must be of shape + :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, + :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be + broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. + Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the + corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the + corresponding position is not allowed to attend. For a float mask, the mask values will be added to + the attention weight. + average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across + heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an + effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) + compute_relative_attention_bias: whether or not the relative position embeddings + need to be computed. typically occurs in the first layer of encoder/decoder + and resulting position embeddings are returned to be passed up to higher layers. (defualt: False) + relative_attention_num_buckets: the number of relative position buckets (default: 32) + relative_attention_max_distance: maximum threshold on the relative distance used to + allocate buckets. anything larger than that gets placed in the same bucket (default: 128) + relative_attention_bias: tensor of weights to compute relative position embeddings. (default: None) + position_bias: tensor of position bias used if to add relative attention bias to attention scores. (default: None) + Outputs: + - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, + :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, + where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the + embedding dimension ``embed_dim``. + - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, + returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or + :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and + :math:`S` is the source sequence length. If ``average_weights=False``, returns attention weights per + head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. + - **position_bias** - used in attention scoring. Only computed when ``compute_relative_attention_bias=True`` + and ``position_bias=None``. Has shape :math:`(1, num_heads, L, S)`. + .. note:: + `batch_first` argument is ignored for unbatched inputs. 
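        A minimal self-attention sketch (illustrative; ``relative_attention_bias`` is assumed to be an
        ``nn.Embedding(relative_attention_num_buckets, num_heads)`` holding the learned bucket
        embeddings, as in the HF reference implementation)::

            >>> import torch
            >>> attn = T5MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True)
            >>> bias_embedding = torch.nn.Embedding(32, 12)
            >>> seq = torch.rand(8, 10, 768)                  # (batch, seq, feature)
            >>> out, weights, position_bias = attn(
            ...     seq, seq, seq,
            ...     compute_relative_attention_bias=True,
            ...     relative_attention_bias=bias_embedding,
            ... )
            >>> position_bias.shape                           # reused by subsequent layers
            torch.Size([1, 12, 10, 10])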
+ """ + is_batched = query.dim() == 3 + + if self.batch_first and is_batched: + # make sure that the transpose op does not affect the "is" property + if key is value: + if query is key: + query = key = value = query.transpose(1, 0) + else: + query, key = [x.transpose(1, 0) for x in (query, key)] + value = key + else: + query, key, value = [x.transpose(1, 0) for x in (query, key, value)] + + attn_output, attn_output_weights, position_bias = self.t5_multi_head_attention_forward( + query, + key, + value, + compute_relative_attention_bias=compute_relative_attention_bias, + relative_attention_num_buckets=relative_attention_num_buckets, + relative_attention_max_distance=relative_attention_max_distance, + relative_attention_bias=relative_attention_bias, + position_bias=position_bias, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + average_attn_weights=average_attn_weights, + ) + if self.batch_first and is_batched: + return attn_output.transpose(1, 0), attn_output_weights, position_bias else: - raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported") + return attn_output, attn_output_weights, position_bias + + def t5_multi_head_attention_forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + compute_relative_attention_bias: bool, + relative_attention_num_buckets: Optional[int], + relative_attention_max_distance: Optional[int], + relative_attention_bias: Optional[Tensor], + position_bias: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + static_k: Optional[Tensor] = None, + static_v: Optional[Tensor] = None, + average_attn_weights: bool = False, + ) -> Tuple[Tensor, Optional[Tensor]]: + is_batched = F._mha_shape_check(query, key, value, key_padding_mask, attn_mask, self.num_heads) + + # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input + # is batched, run the computation and before returning squeeze the + # batch dimension so that the output doesn't carry this temporary batch dimension. + if not is_batched: + # unsqueeze if the input is unbatched + query = query.unsqueeze(1) + key = key.unsqueeze(1) + value = value.unsqueeze(1) + if key_padding_mask is not None: + key_padding_mask = key_padding_mask.unsqueeze(0) + + # set up shape vars + tgt_len, bsz, embed_dim = query.shape + src_len, _, _ = key.shape - # prep key padding mask - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead." - ) - key_padding_mask = key_padding_mask.to(torch.bool) - - # add bias along batch dimension (currently second) - if self.bias_k is not None and self.bias_v is not None: - assert static_k is None, "bias cannot be added to static key." - assert static_v is None, "bias cannot be added to static value." 
- k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) - v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) - if attn_mask is not None: - attn_mask = F.pad(attn_mask, (0, 1)) - if key_padding_mask is not None: - key_padding_mask = F.pad(key_padding_mask, (0, 1)) - else: - assert self.bias_k is None - assert self.bias_v is None - - # - # reshape q, k, v for multihead attention and make em batch first - # - - q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1) - if static_k is None: - k = k.contiguous().view(k.shape[0], bsz * self.num_heads, head_dim).transpose(0, 1) - else: - # TODO finish disentangling control flow so we don't do in-projections when statics are passed assert ( - static_k.size(0) == bsz * self.num_heads - ), f"expecting static_k.size(0) of {bsz * self.num_heads}, but got {static_k.size(0)}" - assert static_k.size(2) == head_dim, f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}" - k = static_k - if static_v is None: - v = v.contiguous().view(v.shape[0], bsz * self.num_heads, head_dim).transpose(0, 1) - else: - # TODO finish disentangling control flow so we don't do in-projections when statics are passed + embed_dim == self.embed_dim + ), f"was expecting embedding dimension of {self.embed_dim}, but got {embed_dim}" + if isinstance(embed_dim, Tensor): + # embed_dim can be a tensor when JIT tracing + head_dim = embed_dim.div(self.num_heads, rounding_mode="trunc") + else: + head_dim = embed_dim // self.num_heads assert ( - static_v.size(0) == bsz * self.num_heads - ), f"expecting static_v.size(0) of {bsz * self.num_heads}, but got {static_v.size(0)}" - assert static_v.size(2) == head_dim, f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}" - v = static_v - - # add zero attention along batch dimension (now first) - if self.add_zero_attn: - zero_attn_shape = (bsz * self.num_heads, 1, head_dim) - k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1) - v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1) - if attn_mask is not None: - attn_mask = F.pad(attn_mask, (0, 1)) - if key_padding_mask is not None: - key_padding_mask = F.pad(key_padding_mask, (0, 1)) - - # update source sequence length after adjustments - src_len = k.size(1) - - # merge key padding and attention masks - if key_padding_mask is not None: - assert key_padding_mask.shape == ( - bsz, - src_len, - ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" - key_padding_mask = ( - key_padding_mask.view(bsz, 1, 1, src_len) - .expand(-1, self.num_heads, -1, -1) - .reshape(bsz * self.num_heads, 1, src_len) + head_dim * self.num_heads == embed_dim + ), f"embed_dim {embed_dim} not divisible by num_heads {self.num_heads}" + # allow MHA to have different embedding dimensions when separate projection weights are used + assert ( + key.shape[:2] == value.shape[:2] + ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" + + # + # compute in-projection + # + assert self.q_proj_weight is not None, "q_proj_weight is None" + assert self.k_proj_weight is not None, "k_proj_weight is None" + assert self.v_proj_weight is not None, "v_proj_weight is None" + if self.in_proj_bias is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = self.in_proj_bias.chunk(3) + q, k, v = F._in_projection( + query, key, value, self.q_proj_weight, self.k_proj_weight, self.v_proj_weight, b_q, b_k, b_v ) - if attn_mask is None: - attn_mask = 
key_padding_mask - elif attn_mask.dtype == torch.bool: - attn_mask = attn_mask.logical_or(key_padding_mask) + + # prep attention mask + if attn_mask is not None: + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + else: + assert ( + attn_mask.is_floating_point() or attn_mask.dtype == torch.bool + ), f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}" + # ensure attn_mask's dim is 3 + if attn_mask.dim() == 2: + correct_2d_size = (tgt_len, src_len) + if attn_mask.shape != correct_2d_size: + raise RuntimeError( + f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}." + ) + attn_mask = attn_mask.unsqueeze(0) + elif attn_mask.dim() == 3: + correct_3d_size = (bsz * self.num_heads, tgt_len, src_len) + if attn_mask.shape != correct_3d_size: + raise RuntimeError( + f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}." + ) + else: + raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported") + + # prep key padding mask + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead." + ) + key_padding_mask = key_padding_mask.to(torch.bool) + + # add bias along batch dimension (currently second) + if self.bias_k is not None and self.bias_v is not None: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = F.pad(key_padding_mask, (0, 1)) + else: + assert self.bias_k is None + assert self.bias_v is None + + # + # reshape q, k, v for multihead attention and make em batch first + # + + q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1) + if static_k is None: + k = k.contiguous().view(k.shape[0], bsz * self.num_heads, head_dim).transpose(0, 1) else: - attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf")) - - # convert mask to float - if attn_mask is not None and attn_mask.dtype == torch.bool: - new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) - new_attn_mask.masked_fill_(attn_mask, float("-inf")) - attn_mask = new_attn_mask - - # adjust dropout probability - if not self.training: - dropout_p = 0.0 - else: - dropout_p = self.dropout - - # NOTE: modification from torch.nn.functional._multi_head_attention_forward to incorporate relative attention bias - if position_bias is None: - if not compute_relative_attention_bias: - position_bias = torch.zeros((self.num_heads, tgt_len, src_len), device=k.device, dtype=k.dtype).unsqueeze(0) + # TODO finish disentangling control flow so we don't do in-projections when statics are passed + assert ( + static_k.size(0) == bsz * self.num_heads + ), f"expecting static_k.size(0) of {bsz * self.num_heads}, but got {static_k.size(0)}" + assert static_k.size(2) == head_dim, f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}" + k = static_k + if static_v is None: + v = v.contiguous().view(v.shape[0], bsz * self.num_heads, head_dim).transpose(0, 1) else: - position_bias = self._compute_bias( - tgt_len, + # TODO finish 
disentangling control flow so we don't do in-projections when statics are passed + assert ( + static_v.size(0) == bsz * self.num_heads + ), f"expecting static_v.size(0) of {bsz * self.num_heads}, but got {static_v.size(0)}" + assert static_v.size(2) == head_dim, f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}" + v = static_v + + # add zero attention along batch dimension (now first) + if self.add_zero_attn: + zero_attn_shape = (bsz * self.num_heads, 1, head_dim) + k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1) + v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1) + if attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = F.pad(key_padding_mask, (0, 1)) + + # update source sequence length after adjustments + src_len = k.size(1) + + # merge key padding and attention masks + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + bsz, src_len, - relative_attention_bias, - relative_attention_num_buckets=relative_attention_num_buckets, - relative_attention_max_distance=relative_attention_max_distance, - bidirectional=(not self.is_decoder), - device=k.device, + ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" + key_padding_mask = ( + key_padding_mask.view(bsz, 1, 1, src_len) + .expand(-1, self.num_heads, -1, -1) + .reshape(bsz * self.num_heads, 1, src_len) ) + if attn_mask is None: + attn_mask = key_padding_mask + elif attn_mask.dtype == torch.bool: + attn_mask = attn_mask.logical_or(key_padding_mask) + else: + attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf")) + + # convert mask to float + if attn_mask is not None and attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + + # adjust dropout probability + if not self.training: + dropout_p = 0.0 + else: + dropout_p = self.dropout + + # NOTE: modification from torch.nn.functional._multi_head_attention_forward to incorporate relative attention bias + if position_bias is None: + if not compute_relative_attention_bias: + position_bias = torch.zeros( + (self.num_heads, tgt_len, src_len), device=k.device, dtype=k.dtype + ).unsqueeze(0) + else: + position_bias = self._compute_bias( + tgt_len, + src_len, + relative_attention_bias, + relative_attention_num_buckets=relative_attention_num_buckets, + relative_attention_max_distance=relative_attention_max_distance, + bidirectional=(not self.is_decoder), + device=k.device, + ) - # calculate attention and out projection - attn_output, attn_output_weights = self._t5_scaled_dot_product_attention( - q, k, v, position_bias, attn_mask, dropout_p - ) - attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim) - attn_output = F.linear(attn_output, self.out_proj.weight, self.out_proj.bias) - attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) + # calculate attention and out projection + attn_output, attn_output_weights = self._t5_scaled_dot_product_attention( + q, k, v, position_bias, attn_mask, dropout_p + ) + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim) + attn_output = F.linear(attn_output, self.out_proj.weight, self.out_proj.bias) + attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) - if need_weights: - # optionally average attention weights over heads - 
attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len) - if average_attn_weights: - attn_output_weights = attn_output_weights.sum(dim=1) / self.num_heads + if need_weights: + # optionally average attention weights over heads + attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len) + if average_attn_weights: + attn_output_weights = attn_output_weights.sum(dim=1) / self.num_heads - if not is_batched: - # squeeze the output if input was unbatched - attn_output = attn_output.squeeze(1) - attn_output_weights = attn_output_weights.squeeze(0) + if not is_batched: + # squeeze the output if input was unbatched + attn_output = attn_output.squeeze(1) + attn_output_weights = attn_output_weights.squeeze(0) - return attn_output, attn_output_weights, position_bias + return attn_output, attn_output_weights, position_bias - else: - if not is_batched: - # squeeze the output if input was unbatched - attn_output = attn_output.squeeze(1) - - return attn_output, None, position_bias - - -def _t5_scaled_dot_product_attention( - q: Tensor, - k: Tensor, - v: Tensor, - position_bias: Tensor, - attn_mask: Optional[Tensor] = None, - dropout_p: float = 0.0, -) -> Tuple[Tensor, Tensor]: - r""" - Computes scaled dot product attention on query, key and value tensors, using - an optional attention mask if passed, and applying dropout if a probability - greater than 0.0 is specified. - Returns a tensor pair containing attended values and attention weights. - Args: - q, k, v: query, key and value tensors. See Shape section for shape details. - attn_mask: optional tensor containing mask values to be added to calculated - attention. May be 2D or 3D; see Shape section for details. - dropout_p: dropout probability. If greater than 0.0, dropout is applied. - position_bias: position bias used to incorporate realtive attention bias in attention scors - Shape: - - q: :math:`(B, Nt, E)` where B is (batch size*num_heads), Nt is the target sequence length, - and E is embedding dimension. - - key: :math:`(B, Ns, E)` where B is (batch size*num_heads), Ns is the source sequence length, - and E is embedding dimension. - - value: :math:`(B, Ns, E)` where B is (batch size*num_heads), Ns is the source sequence length, - and E is embedding dimension. - - attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of - shape :math:`(Nt, Ns)`. - - position_bias: :math:`(1, num_heads, Nt, Ns)` - - Output: attention values have shape :math:`(B, Nt, E)`; attention weights - have shape :math:`(B, Nt, Ns)` - """ - B, Nt, E = q.shape - # NOTE: HF implementation does not perform this normalization. 
For the sake of matching test results, we have commented it out - # q = q / math.sqrt(E) - - n_heads, tgt_len, src_len = position_bias.size()[1:] - assert B % n_heads == 0 - assert tgt_len == Nt - - # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns) - if attn_mask is not None: - attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1)) - else: - attn = torch.bmm(q, k.transpose(-2, -1)) - - # NOTE: modification from torch.nn.functional._scaled_dot_product_attention to incorporate relative attention bias - position_bias = position_bias.repeat(B // n_heads, 1, 1, 1) - position_bias = position_bias.view(B, tgt_len, src_len) - attn += position_bias - - attn = F.softmax(attn, dim=-1) - if dropout_p > 0.0: - attn = F.dropout(attn, p=dropout_p) - # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E) - output = torch.bmm(attn, v) - return output, attn - - -# NOTE: modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py -def _compute_bias( - query_length: int, - key_length: int, - relative_attention_bias: Tensor, - relative_attention_num_buckets: int = 32, - relative_attention_max_distance: int = 128, - bidirectional: bool = True, - device=None, -): - """Compute binned relative position bias""" - if device is None: - device = relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] - memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] - relative_position = memory_position - context_position # shape (query_length, key_length) - relative_position_bucket = self._relative_position_bucket( - relative_position, # shape (query_length, key_length) - bidirectional=bidirectional, - num_buckets=relative_attention_num_buckets, - max_distance=relative_attention_max_distance, - ) - values = relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) - values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) - return values - - -# NOTE: taken from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py -def _relative_position_bucket( - relative_position: Tensor, bidirectional: bool = True, num_buckets: int = 32, max_distance: int = 128 -): - """ - Adapted from Mesh Tensorflow: - https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 - Translate relative position to a bucket number for relative attention. The relative position is defined as - memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to - position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for - small absolute relative_position and larger buckets for larger absolute relative_positions. All relative - positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
- This should allow for more graceful generalization to longer sequences than the model has been trained on - Args: - relative_position: an int32 Tensor - bidirectional: a boolean - whether the attention is bidirectional - num_buckets: an integer - max_distance: an integer - Returns: - a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) - """ - relative_buckets = 0 - if bidirectional: - num_buckets //= 2 - relative_buckets += (relative_position > 0).to(torch.long) * num_buckets - relative_position = torch.abs(relative_position) - else: - relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) - # now relative_position is in the range [0, inf) - - # half of the buckets are for exact increments in positions - max_exact = num_buckets // 2 - is_small = relative_position < max_exact - - # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance - relative_position_if_large = max_exact + ( - torch.log(relative_position.float() / max_exact) - / math.log(max_distance / max_exact) - * (num_buckets - max_exact) - ).to(torch.long) - relative_position_if_large = torch.min( - relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) - ) - - relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) - return relative_buckets + else: + if not is_batched: + # squeeze the output if input was unbatched + attn_output = attn_output.squeeze(1) + + return attn_output, None, position_bias + + def _t5_scaled_dot_product_attention( + self, + q: Tensor, + k: Tensor, + v: Tensor, + position_bias: Tensor, + attn_mask: Optional[Tensor] = None, + dropout_p: float = 0.0, + ) -> Tuple[Tensor, Tensor]: + r""" + Computes scaled dot product attention on query, key and value tensors, using + an optional attention mask if passed, and applying dropout if a probability + greater than 0.0 is specified. + Returns a tensor pair containing attended values and attention weights. + Args: + q, k, v: query, key and value tensors. See Shape section for shape details. + attn_mask: optional tensor containing mask values to be added to calculated + attention. May be 2D or 3D; see Shape section for details. + dropout_p: dropout probability. If greater than 0.0, dropout is applied. + position_bias: position bias used to incorporate realtive attention bias in attention scors + Shape: + - q: :math:`(B, Nt, E)` where B is (batch size*num_heads), Nt is the target sequence length, + and E is embedding dimension. + - key: :math:`(B, Ns, E)` where B is (batch size*num_heads), Ns is the source sequence length, + and E is embedding dimension. + - value: :math:`(B, Ns, E)` where B is (batch size*num_heads), Ns is the source sequence length, + and E is embedding dimension. + - attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of + shape :math:`(Nt, Ns)`. + - position_bias: :math:`(1, num_heads, Nt, Ns)` + - Output: attention values have shape :math:`(B, Nt, E)`; attention weights + have shape :math:`(B, Nt, Ns)` + """ + B, Nt, E = q.shape + # NOTE: HF implementation does not perform this normalization. 
For the sake of matching test results, we have commented it out + # q = q / math.sqrt(E) + + n_heads, tgt_len, src_len = position_bias.size()[1:] + assert B % n_heads == 0 + assert tgt_len == Nt + + # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns) + if attn_mask is not None: + attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1)) + else: + attn = torch.bmm(q, k.transpose(-2, -1)) + + # NOTE: modification from torch.nn.functional._scaled_dot_product_attention to incorporate relative attention bias + position_bias = position_bias.repeat(B // n_heads, 1, 1, 1) + position_bias = position_bias.view(B, tgt_len, src_len) + attn += position_bias + + attn = F.softmax(attn, dim=-1) + if dropout_p > 0.0: + attn = F.dropout(attn, p=dropout_p) + # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E) + output = torch.bmm(attn, v) + return output, attn + + # NOTE: modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py + def _compute_bias( + self, + query_length: int, + key_length: int, + relative_attention_bias: Tensor, + relative_attention_num_buckets: int = 32, + relative_attention_max_distance: int = 128, + bidirectional: bool = True, + device=None, + ): + """Compute binned relative position bias""" + if device is None: + device = relative_attention_bias.weight.device + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=bidirectional, + num_buckets=relative_attention_num_buckets, + max_distance=relative_attention_max_distance, + ) + values = relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + # NOTE: taken from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py + def _relative_position_bucket( + self, relative_position: Tensor, bidirectional: bool = True, num_buckets: int = 32, max_distance: int = 128 + ): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
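        For example, a rough sketch of how ``_compute_bias`` above turns these buckets into the
        additive attention bias (``relative_attention_bias`` is assumed to be an
        ``nn.Embedding(num_buckets, num_heads)``)::

            >>> import torch
            >>> mha = T5MultiheadAttention(embed_dim=64, num_heads=8)
            >>> bias_embedding = torch.nn.Embedding(32, 8)    # one row per bucket, one column per head
            >>> position_bias = mha._compute_bias(4, 4, bias_embedding)
            >>> position_bias.shape                           # (1, num_heads, query_length, key_length)
            torch.Size([1, 8, 4, 4])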
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets From 1f99650502524c5d54bfd11f58d219c0c1a33bb2 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Tue, 12 Jul 2022 17:35:42 -0400 Subject: [PATCH 06/16] add layer norm module for t5 model [ghstack-poisoned] --- torchtext/prototype/t5/modules.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py index 2e69e3ea19..fcc08e266d 100644 --- a/torchtext/prototype/t5/modules.py +++ b/torchtext/prototype/t5/modules.py @@ -531,3 +531,29 @@ def _relative_position_bucket( relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) return relative_buckets + + +# NOTE: Taken from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py +class T5LayerNorm(nn.Module): + def __init__(self, d_model, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(d_model)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. 
Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states From 12e48e1ce08123287c1c3289032e94e8e2c171e8 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Tue, 12 Jul 2022 17:35:44 -0400 Subject: [PATCH 07/16] add t5 layer module that can be used for both encoder or decoder stack [ghstack-poisoned] --- torchtext/prototype/t5/modules.py | 169 +++++++++++++++++++++++++++++- 1 file changed, 168 insertions(+), 1 deletion(-) diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py index fcc08e266d..356f346dd7 100644 --- a/torchtext/prototype/t5/modules.py +++ b/torchtext/prototype/t5/modules.py @@ -15,7 +15,7 @@ import math import warnings -from typing import Optional, Tuple +from typing import Optional, Tuple, Union, Callable import torch import torch.nn as nn @@ -557,3 +557,170 @@ def forward(self, hidden_states): hidden_states = hidden_states.to(self.weight.dtype) return self.weight * hidden_states + + +class T5Layer(nn.Module): + r"""T5Layer is made up of self-attn, cross-attn (decoder only) and feedforward network. + This T5 layer is based on the paper: + "Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer". + Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, + Yanqi Zhou, Wei Li, and Peter J. Liu. 2020. Journal of Machine Learning Research. + Volume 21 Issue 140 pages 1-67. http://jmlr.org/papers/v21/20-074.html + Users may modify or implement in a different way during application. + Args: + is_decoder: whether or not the layer belongs to the decoder. (required) + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of the intermediate layer, can be a string + ("relu" or "gelu") or a unary callable. (default: relu) + layer_norm_eps: the eps value in layer normalization components (default=1e-5). + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). (default: ``False``) (seq, batch, feature). + relative_attention_num_buckets: the number of relative position buckets (default: 32) + relative_attention_max_distance: maximum threshold on the relative distance used to + allocate buckets. anything larger than that gets placed in the same bucket (default: 128) + compute_relative_attention_bias: whether or not the relative position embeddings + need to be computed. typically occurs in the first layer of encoder/decoder (default: False) + and resulting position embeddings are returned to be passed up to higher layers. + relative_attention_bias: tensor of weights to compute relative position embeddings. 
(default: None) + + Examples:: + >>> decoder_layer = T5Layer(is_decoder=True, d_model=768, nhead=12, batch_first=True) + >>> memory = torch.rand(32, 10, 768) + >>> tgt = torch.rand(32, 20, 768) + >>> out = deoder_layer(tgt, memory) + """ + + def __init__( + self, + is_decoder: bool, + d_model: int, + nhead: int, + dim_feedforward: int = 3072, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + layer_norm_eps: float = 1e-6, + batch_first: bool = False, + relative_attention_num_buckets: int = 32, + relative_attention_max_distance: int = 128, + compute_relative_attention_bias: bool = False, + relative_attention_bias: Optional[Tensor] = None, + device=None, + dtype=None, + ) -> None: + super().__init__() + + self.is_decoder = is_decoder + self.compute_relative_attention_bias = compute_relative_attention_bias + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + self.relative_attention_bias = relative_attention_bias + + self.self_attn = T5MultiheadAttention( + d_model, nhead, is_decoder=is_decoder, dropout=dropout, batch_first=batch_first, device=device, dtype=dtype + ) + self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False) + self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False) + self.norm1 = T5LayerNorm(d_model, eps=layer_norm_eps) + self.norm2 = T5LayerNorm(d_model, eps=layer_norm_eps) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + if is_decoder: + self.cross_attn = T5MultiheadAttention( + d_model, + nhead, + is_decoder=is_decoder, + dropout=dropout, + batch_first=batch_first, + device=device, + dtype=dtype, + ) + self.norm3 = T5LayerNorm(d_model, eps=layer_norm_eps) + self.dropout4 = nn.Dropout(dropout) + + if isinstance(activation, str): + if activation == "relu": + self.activation = F.relu + elif activation == "gelu": + self.activation = F.gelu + else: + self.activation = activation + + def forward( + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + position_bias: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the inputs (and mask) through the decoder layer. + Args: + tgt: the input sequence to the encoder/decoder layer (required). + memory: the sequence from the last layer of the encoder (used for decoder only). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + position_bias: position embeddings to be used when computing attention scores (optional) + """ + + # see Fig. 
1 of https://arxiv.org/pdf/2002.04745v1.pdf + x = tgt + sa_out, position_bias, sa_scores = self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, position_bias) + x = x + sa_out + if self.is_decoder: + ca_out, ca_scores = self._ca_block(self.norm3(x), memory, memory_mask, memory_key_padding_mask) + x = x + ca_out + x = x + self._ff_block(self.norm2(x)) + + return x, position_bias, sa_scores, ca_scores if self.is_decoder else None + + # self-attention block + def _sa_block( + self, + x: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + position_bias: Optional[Tensor], + ) -> Tensor: + attn = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=True, + compute_relative_attention_bias=self.compute_relative_attention_bias, + relative_attention_num_buckets=self.relative_attention_num_buckets, + relative_attention_max_distance=self.relative_attention_max_distance, + relative_attention_bias=self.relative_attention_bias, + position_bias=position_bias, + ) + + x = attn[0] + scores = attn[1] + if self.compute_relative_attention_bias and position_bias is None: + position_bias = attn[2] + + return self.dropout1(x), position_bias, scores + + # cross attention block + def _ca_block( + self, x: Tensor, mem: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor] + ) -> Tensor: + attn = self.cross_attn(x, mem, mem, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=True) + x = attn[0] + scores = attn[1] + return self.dropout4(x), scores + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout2(self.activation(self.linear1(x)))) + return self.dropout3(x) From be54aef696a8b84d68eabbcb3c1cbd71f0ac202d Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Tue, 12 Jul 2022 17:35:47 -0400 Subject: [PATCH 08/16] add t5 stack that can function as either the encoder or decoder of a t5 model [ghstack-poisoned] --- torchtext/prototype/t5/modules.py | 105 ++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py index 356f346dd7..0b6d1c8630 100644 --- a/torchtext/prototype/t5/modules.py +++ b/torchtext/prototype/t5/modules.py @@ -724,3 +724,108 @@ def _ca_block( def _ff_block(self, x: Tensor) -> Tensor: x = self.linear2(self.dropout2(self.activation(self.linear1(x)))) return self.dropout3(x) + + +class T5Stack(nn.Module): + r"""T5 is a stack of N encoder/decoder layers + Args: + is_decoder: whether or not the layer belongs to the decoder. (required) + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + num_layers: the number of encoder/decoder layers in the stack (required) + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of the intermediate layer, can be a string + ("relu" or "gelu") or a unary callable. (default: relu) + layer_norm_eps: the eps value in layer normalization components (default=1e-5). + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). (default: ``False``) (seq, batch, feature). + relative_attention_num_buckets: the number of relative position buckets (default: 32) + relative_attention_max_distance: maximum threshold on the relative distance used to + allocate buckets. 
anything larger than that gets placed in the same bucket (defulat: 128) + Examples:: + >>> decoder = nn.T5Stack(is_decoder=True, d_model=768, nhead=12, num_layers=12) + >>> memory = torch.rand(32, 10, 512) + >>> tgt = torch.rand(32, 10, 512) + >>> out = decoder(tgt, memory) + """ + + def __init__( + self, + is_decoder: bool, + d_model: int, + nhead: int, + num_layers: int, + dim_feedforward: int = 3072, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + layer_norm_eps: float = 1e-6, + batch_first: bool = False, + relative_attention_num_buckets: int = 32, + relative_attention_max_distance: int = 128, + device=None, + dtype=None, + ): + super().__init__() + + self.layers = nn.ModuleList( + [ + T5Layer( + is_decoder, + d_model, + nhead, + dim_feedforward, + dropout, + activation, + layer_norm_eps, + batch_first, + relative_attention_num_buckets, + relative_attention_max_distance, + compute_relative_attention_bias=True if i == 0 else False, + relative_attention_bias=nn.Embedding(relative_attention_num_buckets, nhead) if i == 0 else None, + device=device, + dtype=dtype, + ) + for i in range(num_layers) + ] + ) + self.num_layers = num_layers + + def forward( + self, + tgt: Tensor, + memory: Tensor = None, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the inputs (and mask) through the decoder layer in turn. + Args: + tgt: the input sequence to the encoder/decoder (required). + memory: the sequence from the last layer of the encoder (for decoder only). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). 
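+            A usage sketch (an encoder-only stack is called with just ``tgt``; the returned tuple holds the
+            final hidden states, the per-layer inputs, the shared position bias, and the per-layer
+            self-/cross-attention scores):
+                >>> encoder = T5Stack(is_decoder=False, d_model=768, nhead=12, num_layers=12, batch_first=True)
+                >>> src = torch.rand(32, 10, 768)
+                >>> out, layer_inputs, position_bias, sa_scores, ca_scores = encoder(src)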
+ """ + output = tgt + position_bias = None + all_outputs = () + sa_scores = () + ca_scores = () + for mod in self.layers: + all_outputs = all_outputs + (output,) + output, position_bias, sa_score, ca_score = mod( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + position_bias=position_bias, + ) + sa_scores = sa_scores + (sa_score,) + ca_scores = ca_scores + (ca_score,) + + return output, all_outputs, position_bias, sa_scores, ca_scores From dc55ec9d02687680ab31b1556c54e955b8096d30 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Wed, 13 Jul 2022 13:24:27 -0400 Subject: [PATCH 09/16] Update base for Update on "add t5 model that can function as both encodery-only or encoder-decoder model" [ghstack-poisoned] --- torchtext/prototype/t5/modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py index 0b6d1c8630..5dbe01c79d 100644 --- a/torchtext/prototype/t5/modules.py +++ b/torchtext/prototype/t5/modules.py @@ -15,7 +15,7 @@ import math import warnings -from typing import Optional, Tuple, Union, Callable +from typing import Optional, Union, Tuple, Callable import torch import torch.nn as nn From a213d8cc912261bab55df1de7a1416d9e13311dc Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Wed, 13 Jul 2022 13:40:46 -0400 Subject: [PATCH 10/16] Update base for Update on "add t5 model that can function as both encodery-only or encoder-decoder model" [ghstack-poisoned] --- torchtext/prototype/t5/modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py index 5dbe01c79d..0b6d1c8630 100644 --- a/torchtext/prototype/t5/modules.py +++ b/torchtext/prototype/t5/modules.py @@ -15,7 +15,7 @@ import math import warnings -from typing import Optional, Union, Tuple, Callable +from typing import Optional, Tuple, Union, Callable import torch import torch.nn as nn From c48d2927762f82a21bcb825edc0f34a5e85cf309 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Wed, 13 Jul 2022 13:44:43 -0400 Subject: [PATCH 11/16] Update base for Update on "add t5 model that can function as both encodery-only or encoder-decoder model" [ghstack-poisoned] From 4e661a0d6518ff563d28e562a2100cccc33199bc Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Wed, 13 Jul 2022 13:52:07 -0400 Subject: [PATCH 12/16] Update base for Update on "add t5 model that can function as both encodery-only or encoder-decoder model" [ghstack-poisoned] From ff1f0e8c078a87f2f48218e0e67698536cb684e1 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Thu, 14 Jul 2022 16:22:50 -0400 Subject: [PATCH 13/16] Update base for Update on "add t5 model that can function as both encodery-only or encoder-decoder model" [ghstack-poisoned] --- torchtext/prototype/t5/modules.py | 455 ++++++++++++------------------ 1 file changed, 174 insertions(+), 281 deletions(-) diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py index 0b6d1c8630..d9c5615d2a 100644 --- a/torchtext/prototype/t5/modules.py +++ b/torchtext/prototype/t5/modules.py @@ -9,8 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-# Original code is taken from -# https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py +# Parts of code are originally from +# https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py # */ import math @@ -31,32 +31,22 @@ def __init__( is_decoder=False, dropout=0.0, bias=False, - add_bias_kv=False, - add_zero_attn=False, kdim=None, vdim=None, - batch_first=False, device=None, dtype=None, ) -> None: r""" Args: - embed_dim: total dimension of the model. - num_heads: parallel attention heads. - is_decoder: whether or not multihead attention is being performed on a decoder layer. Default: ``False`` - dropout: probability of an element to be zeroed. Default: 0.0 - bias: If specified, adds bias to input / output projection layers. Default: ``False``. - add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``. - add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1. - Default: ``False``. - kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``). - vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``). - batch_first: If ``True``, then the input and output tensors are provided - as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + embed_dim: Total dimension of the model. + num_heads: Parallel attention heads. + is_decoder: Whether or not multihead attention is being performed on a decoder layer. Default: `False` + dropout: Probability of an element to be zeroed. Default: 0.0 + bias: If specified, adds bias to input / output projection layers. Default: `False`. + kdim: Total number of features for keys. Default: `None` (uses `kdim=embed_dim`). + vdim: Total number of features for values. Default: `None` (uses `vdim=embed_dim`). """ - super().__init__( - embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim, batch_first, device, dtype - ) + super().__init__(embed_dim, num_heads, dropout, bias, False, False, kdim, vdim, True, device, dtype) factory_kwargs = {"device": device, "dtype": dtype} self.is_decoder = is_decoder self.q_proj_weight = nn.Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs)) @@ -87,74 +77,49 @@ def forward( `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer `_. Args: - query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` - or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, - :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. + query: Query embeddings of shape :math:`(N, L, E_q)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, + and :math:`E_q` is the query embedding dimension `embed_dim`. Queries are compared against key-value pairs to produce the output. See "Attention Is All You Need" for more details. - key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` - or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, - :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. 
+ key: Key embeddings of shape :math:`(N, S, E_k)`, where :math:`N` is the batch size, :math:`S` is the source sequence length, + and :math:`E_k` is the key embedding dimension `kdim`. See "Attention Is All You Need" for more details. - value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when - ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source - sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. + value: Value embeddings of shape :math:`(N, S, E_v)`, where :math:`N` is the batch size, :math:`S` is the source + sequence length, and :math:`E_v` is the value embedding dimension `vdim`. See "Attention Is All You Need" for more details. - key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` - to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. - Binary and byte masks are supported. - For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for - the purpose of attention. For a byte mask, a non-zero value indicates that the corresponding ``key`` - value will be ignored. - need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. - Default: ``True``. - attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape - :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, - :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be - broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. - Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the - corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the - corresponding position is not allowed to attend. For a float mask, the mask values will be added to - the attention weight. - average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across - heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an - effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) - compute_relative_attention_bias: whether or not the relative position embeddings - need to be computed. typically occurs in the first layer of encoder/decoder - and resulting position embeddings are returned to be passed up to higher layers. (defualt: False) - relative_attention_num_buckets: the number of relative position buckets (default: 32) - relative_attention_max_distance: maximum threshold on the relative distance used to - allocate buckets. anything larger than that gets placed in the same bucket (default: 128) - relative_attention_bias: tensor of weights to compute relative position embeddings. (default: None) - position_bias: tensor of position bias used if to add relative attention bias to attention scores. (default: None) + key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within `key` + to ignore for the purpose of attention (i.e. treat as "padding"). + Binary masks are supported. For a binary mask, a `True` value indicates that the corresponding `key` + value will be ignored for the purpose of attention. 
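+                For example (a sketch), a batch of two sequences with lengths 3 and 5 padded to :math:`S = 5`
+                could use ``key_padding_mask = torch.tensor([[False, False, False, True, True], [False] * 5])``.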
+ need_weights: If specified, returns `attn_output_weights` in addition to `attn_outputs`. + Default: `True`. + attn_mask: If specified, a 2D mask preventing attention to certain positions. Must be of shape + :math:`(L, S)`, :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be + broadcasted across the batch. Binary, and float masks are supported. + For a binary mask, a `True` value indicates that the corresponding position is not allowed to attend. + For a float mask, the mask values will be added to the attention weight. Default: `None` + average_attn_weights: If true, indicates that the returned `attn_weights` should be averaged across + heads. Otherwise, `attn_weights` are provided separately per head. Note that this flag only has an + effect when `need_weights=True`. Default: `False` (i.e. average weights across heads) + compute_relative_attention_bias: Whether or not the relative position embeddings + need to be computed. Wypically occurs in the first layer of the encoder/decoder + and the resulting position embeddings are returned to be passed up to higher layers. (defualt: False) + relative_attention_num_buckets: Number of relative position buckets. Default: `32` + relative_attention_max_distance: Maximum threshold on the relative distance used to + allocate buckets. Anything larger gets placed in the same bucket. Default: `128` + relative_attention_bias: nn.Embeding object used to compute relative position embeddings. Default: `None` + position_bias: Position bias tensor used if to add relative attention bias to attention scores. Default: `None` Outputs: - - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, - :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, - where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the - embedding dimension ``embed_dim``. - - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, + - **attn_output** - Attention outputs of shape :math:`(N, L, E)`, where :math:`N` is the batch size, + :math:`L` is the target sequence length, and :math:`E` is the embedding dimension `embed_dim`. + - **attn_output_weights** - Only returned when `need_weights=True`. If `average_attn_weights=True`, returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and - :math:`S` is the source sequence length. If ``average_weights=False``, returns attention weights per + :math:`S` is the source sequence length. If `average_weights=False`, returns attention weights per head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. - - **position_bias** - used in attention scoring. Only computed when ``compute_relative_attention_bias=True`` - and ``position_bias=None``. Has shape :math:`(1, num_heads, L, S)`. - .. note:: - `batch_first` argument is ignored for unbatched inputs. + - **position_bias** - Used in attention scoring. Only computed when `compute_relative_attention_bias=True` + and `position_bias=None`. Has shape :math:`(1, num_heads, L, S)`. 
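+        A usage sketch for self-attention in a first layer, where the bias is computed and returned for
+        reuse by later layers (``bias_emb`` stands in for the shared ``nn.Embedding(num_buckets, num_heads)``):
+            >>> mha = T5MultiheadAttention(768, 12)
+            >>> x = torch.rand(2, 10, 768)
+            >>> bias_emb = nn.Embedding(32, 12)
+            >>> attn_out, attn_weights, position_bias = mha(
+            ...     x, x, x, compute_relative_attention_bias=True, relative_attention_bias=bias_emb
+            ... )
+        Here ``attn_out`` has shape (2, 10, 768), ``attn_weights`` has shape (2, 12, 10, 10) under the default
+        ``average_attn_weights=False``, and ``position_bias`` has shape (1, 12, 10, 10).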
""" - is_batched = query.dim() == 3 - - if self.batch_first and is_batched: - # make sure that the transpose op does not affect the "is" property - if key is value: - if query is key: - query = key = value = query.transpose(1, 0) - else: - query, key = [x.transpose(1, 0) for x in (query, key)] - value = key - else: - query, key, value = [x.transpose(1, 0) for x in (query, key, value)] - attn_output, attn_output_weights, position_bias = self.t5_multi_head_attention_forward( query, key, @@ -169,11 +134,9 @@ def forward( attn_mask=attn_mask, average_attn_weights=average_attn_weights, ) - if self.batch_first and is_batched: - return attn_output.transpose(1, 0), attn_output_weights, position_bias - else: - return attn_output, attn_output_weights, position_bias + return attn_output, attn_output_weights, position_bias + # NOTE: Modified from https://github.com/pytorch/pytorch/blob/5953fd9133c0bdcc0158acf1472fac403bc5f636/torch/nn/functional.py#L4909 def t5_multi_head_attention_forward( self, query: Tensor, @@ -187,8 +150,6 @@ def t5_multi_head_attention_forward( key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, attn_mask: Optional[Tensor] = None, - static_k: Optional[Tensor] = None, - static_v: Optional[Tensor] = None, average_attn_weights: bool = False, ) -> Tuple[Tensor, Optional[Tensor]]: is_batched = F._mha_shape_check(query, key, value, key_padding_mask, attn_mask, self.num_heads) @@ -197,36 +158,34 @@ def t5_multi_head_attention_forward( # is batched, run the computation and before returning squeeze the # batch dimension so that the output doesn't carry this temporary batch dimension. if not is_batched: - # unsqueeze if the input is unbatched + # Unsqueeze if the input is unbatched query = query.unsqueeze(1) key = key.unsqueeze(1) value = value.unsqueeze(1) if key_padding_mask is not None: key_padding_mask = key_padding_mask.unsqueeze(0) - # set up shape vars - tgt_len, bsz, embed_dim = query.shape - src_len, _, _ = key.shape + # Set up shape vars + bsz, tgt_len, embed_dim = query.shape + _, src_len, _ = key.shape assert ( embed_dim == self.embed_dim ), f"was expecting embedding dimension of {self.embed_dim}, but got {embed_dim}" if isinstance(embed_dim, Tensor): - # embed_dim can be a tensor when JIT tracing + # Embed_dim can be a tensor when JIT tracing head_dim = embed_dim.div(self.num_heads, rounding_mode="trunc") else: head_dim = embed_dim // self.num_heads assert ( head_dim * self.num_heads == embed_dim ), f"embed_dim {embed_dim} not divisible by num_heads {self.num_heads}" - # allow MHA to have different embedding dimensions when separate projection weights are used + # Allow MHA to have different embedding dimensions when separate projection weights are used assert ( key.shape[:2] == value.shape[:2] ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" - # - # compute in-projection - # + # Compute in-projection assert self.q_proj_weight is not None, "q_proj_weight is None" assert self.k_proj_weight is not None, "k_proj_weight is None" assert self.v_proj_weight is not None, "v_proj_weight is None" @@ -238,103 +197,43 @@ def t5_multi_head_attention_forward( query, key, value, self.q_proj_weight, self.k_proj_weight, self.v_proj_weight, b_q, b_k, b_v ) - # prep attention mask + # Prep attention mask if attn_mask is not None: if attn_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead." 
- ) + warnings.warn("Byte tensor for attn_mask is not supported. Using bool tensor instead.") attn_mask = attn_mask.to(torch.bool) else: assert ( attn_mask.is_floating_point() or attn_mask.dtype == torch.bool - ), f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}" - # ensure attn_mask's dim is 3 + ), f"Only float and bool types are supported for attn_mask, not {attn_mask.dtype}" + # Ensure attn_mask's dim is 3 if attn_mask.dim() == 2: correct_2d_size = (tgt_len, src_len) if attn_mask.shape != correct_2d_size: raise RuntimeError( f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}." ) - attn_mask = attn_mask.unsqueeze(0) - elif attn_mask.dim() == 3: - correct_3d_size = (bsz * self.num_heads, tgt_len, src_len) - if attn_mask.shape != correct_3d_size: - raise RuntimeError( - f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}." - ) + attn_mask = attn_mask.view(1, 1, tgt_len, tgt_len).expand(bsz, self.num_heads, -1, -1) else: raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported") - # prep key padding mask + # Prep key padding mask if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead." - ) + warnings.warn("Byte tensor for key_padding_mask is not supported. Using bool tensor instead.") key_padding_mask = key_padding_mask.to(torch.bool) - # add bias along batch dimension (currently second) - if self.bias_k is not None and self.bias_v is not None: - assert static_k is None, "bias cannot be added to static key." - assert static_v is None, "bias cannot be added to static value." - k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) - v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) - if attn_mask is not None: - attn_mask = F.pad(attn_mask, (0, 1)) - if key_padding_mask is not None: - key_padding_mask = F.pad(key_padding_mask, (0, 1)) - else: - assert self.bias_k is None - assert self.bias_v is None - - # - # reshape q, k, v for multihead attention and make em batch first - # + # Reshape q, k, v for multihead attention and make em batch first + q = q.view(bsz, -1, self.num_heads, head_dim).transpose(1, 2) + k = k.view(bsz, -1, self.num_heads, head_dim).transpose(1, 2) + v = v.view(bsz, -1, self.num_heads, head_dim).transpose(1, 2) + src_len = k.size(2) - q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1) - if static_k is None: - k = k.contiguous().view(k.shape[0], bsz * self.num_heads, head_dim).transpose(0, 1) - else: - # TODO finish disentangling control flow so we don't do in-projections when statics are passed - assert ( - static_k.size(0) == bsz * self.num_heads - ), f"expecting static_k.size(0) of {bsz * self.num_heads}, but got {static_k.size(0)}" - assert static_k.size(2) == head_dim, f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}" - k = static_k - if static_v is None: - v = v.contiguous().view(v.shape[0], bsz * self.num_heads, head_dim).transpose(0, 1) - else: - # TODO finish disentangling control flow so we don't do in-projections when statics are passed - assert ( - static_v.size(0) == bsz * self.num_heads - ), f"expecting static_v.size(0) of {bsz * self.num_heads}, but got {static_v.size(0)}" - assert static_v.size(2) == head_dim, f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}" - v = static_v - - # add zero attention along batch dimension (now first) 
- if self.add_zero_attn: - zero_attn_shape = (bsz * self.num_heads, 1, head_dim) - k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1) - v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1) - if attn_mask is not None: - attn_mask = F.pad(attn_mask, (0, 1)) - if key_padding_mask is not None: - key_padding_mask = F.pad(key_padding_mask, (0, 1)) - - # update source sequence length after adjustments - src_len = k.size(1) - - # merge key padding and attention masks if key_padding_mask is not None: assert key_padding_mask.shape == ( bsz, src_len, ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" - key_padding_mask = ( - key_padding_mask.view(bsz, 1, 1, src_len) - .expand(-1, self.num_heads, -1, -1) - .reshape(bsz * self.num_heads, 1, src_len) - ) + key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, self.num_heads, tgt_len, -1) if attn_mask is None: attn_mask = key_padding_mask elif attn_mask.dtype == torch.bool: @@ -342,19 +241,19 @@ def t5_multi_head_attention_forward( else: attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf")) - # convert mask to float + # Convert mask to float if attn_mask is not None and attn_mask.dtype == torch.bool: new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) new_attn_mask.masked_fill_(attn_mask, float("-inf")) attn_mask = new_attn_mask - # adjust dropout probability + # Adjust dropout probability if not self.training: dropout_p = 0.0 else: dropout_p = self.dropout - # NOTE: modification from torch.nn.functional._multi_head_attention_forward to incorporate relative attention bias + # NOTE: Modification to torch.nn.functional._multi_head_attention_forward to incorporate relative attention bias if position_bias is None: if not compute_relative_attention_bias: position_bias = torch.zeros( @@ -371,22 +270,17 @@ def t5_multi_head_attention_forward( device=k.device, ) - # calculate attention and out projection - attn_output, attn_output_weights = self._t5_scaled_dot_product_attention( - q, k, v, position_bias, attn_mask, dropout_p - ) - attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim) + # Calculate attention and out projection + attn_output, attn_output_weights = self._t5_dot_product_attention(q, k, v, position_bias, attn_mask, dropout_p) attn_output = F.linear(attn_output, self.out_proj.weight, self.out_proj.bias) - attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) if need_weights: - # optionally average attention weights over heads - attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len) + # Optionally average attention weights over heads if average_attn_weights: attn_output_weights = attn_output_weights.sum(dim=1) / self.num_heads if not is_batched: - # squeeze the output if input was unbatched + # Aqueeze the output if input was unbatched attn_output = attn_output.squeeze(1) attn_output_weights = attn_output_weights.squeeze(0) @@ -394,12 +288,13 @@ def t5_multi_head_attention_forward( else: if not is_batched: - # squeeze the output if input was unbatched + # Aqueeze the output if input was unbatched attn_output = attn_output.squeeze(1) return attn_output, None, position_bias - def _t5_scaled_dot_product_attention( + # NOTE: Modified from https://github.com/pytorch/pytorch/blob/5953fd9133c0bdcc0158acf1472fac403bc5f636/torch/nn/functional.py#L4814 + def _t5_dot_product_attention( self, q: Tensor, k: Tensor, @@ -414,51 +309,43 @@ def 
_t5_scaled_dot_product_attention( greater than 0.0 is specified. Returns a tensor pair containing attended values and attention weights. Args: - q, k, v: query, key and value tensors. See Shape section for shape details. - attn_mask: optional tensor containing mask values to be added to calculated + q, k, v: Query, key and value tensors. See Shape section for shape details. + attn_mask: Optional tensor containing mask values to be added to calculated attention. May be 2D or 3D; see Shape section for details. - dropout_p: dropout probability. If greater than 0.0, dropout is applied. - position_bias: position bias used to incorporate realtive attention bias in attention scors + dropout_p: Dropout probability. If greater than 0.0, dropout is applied. + position_bias: Position bias used to incorporate realtive attention bias in attention scors Shape: - - q: :math:`(B, Nt, E)` where B is (batch size*num_heads), Nt is the target sequence length, - and E is embedding dimension. - - key: :math:`(B, Ns, E)` where B is (batch size*num_heads), Ns is the source sequence length, - and E is embedding dimension. - - value: :math:`(B, Ns, E)` where B is (batch size*num_heads), Ns is the source sequence length, - and E is embedding dimension. - - attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of - shape :math:`(Nt, Ns)`. - - position_bias: :math:`(1, num_heads, Nt, Ns)` - - Output: attention values have shape :math:`(B, Nt, E)`; attention weights - have shape :math:`(B, Nt, Ns)` + - q: :math:`(B, H, Nt, E)` where B is the batch size, H is the number of heads, Nt is the target sequence length, + and E is the head dimension. + - key: :math:`(B, H, Ns, E)` where B is the batch size, H is the number of heads, Ns is the source sequence length, + and E is the head dimension. + - value: :math:`(B, H, Ns, E)` where B is the batch size, H is the number of heads, Ns is the source sequence length, + and E is the head dimension. + - attn_mask: a 4D tensor of shape :math:`(B, H, Nt, Ns)` + - position_bias: :math:`(1, H, Nt, Ns)` + - Output: attention values have shape :math:`(B, Nt, H*E)`; attention weights + have shape :math:`(B, H, Nt, Ns)` """ - B, Nt, E = q.shape + B, H, _, E = q.shape # NOTE: HF implementation does not perform this normalization. 
For the sake of matching test results, we have commented it out # q = q / math.sqrt(E) - n_heads, tgt_len, src_len = position_bias.size()[1:] - assert B % n_heads == 0 - assert tgt_len == Nt - - # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns) - if attn_mask is not None: - attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1)) - else: - attn = torch.bmm(q, k.transpose(-2, -1)) + attn = torch.matmul(q, k.transpose(3, 2)) # NOTE: modification from torch.nn.functional._scaled_dot_product_attention to incorporate relative attention bias - position_bias = position_bias.repeat(B // n_heads, 1, 1, 1) - position_bias = position_bias.view(B, tgt_len, src_len) + position_bias = position_bias.repeat(B, 1, 1, 1) + if attn_mask is not None: + position_bias += attn_mask attn += position_bias attn = F.softmax(attn, dim=-1) if dropout_p > 0.0: attn = F.dropout(attn, p=dropout_p) - # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E) - output = torch.bmm(attn, v) + output = torch.matmul(attn, v) + output = output.transpose(1, 2).contiguous().view(B, -1, H * E) return output, attn - # NOTE: modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py + # NOTE: modified from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L421 def _compute_bias( self, query_length: int, @@ -485,7 +372,7 @@ def _compute_bias( values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) return values - # NOTE: taken from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py + # NOTE: Taken from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L374 def _relative_position_bucket( self, relative_position: Tensor, bidirectional: bool = True, num_buckets: int = 32, max_distance: int = 128 ): @@ -513,9 +400,9 @@ def _relative_position_bucket( relative_position = torch.abs(relative_position) else: relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) - # now relative_position is in the range [0, inf) + # Ensure relative_position is in the range [0, inf) - # half of the buckets are for exact increments in positions + # Half of the buckets are for exact increments in positions max_exact = num_buckets // 2 is_small = relative_position < max_exact @@ -533,7 +420,7 @@ def _relative_position_bucket( return relative_buckets -# NOTE: Taken from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py +# NOTE: Taken from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L239 class T5LayerNorm(nn.Module): def __init__(self, d_model, eps=1e-6): """ @@ -552,13 +439,14 @@ def forward(self, hidden_states): variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - # convert into half-precision if necessary + # Convert into half-precision if necessary if self.weight.dtype in [torch.float16, torch.bfloat16]: hidden_states = hidden_states.to(self.weight.dtype) return self.weight * hidden_states +# NOTE: Comparable HuggingFace implentation can be found at https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L622 class T5Layer(nn.Module): r"""T5Layer is made up 
of self-attn, cross-attn (decoder only) and feedforward network. This T5 layer is based on the paper: @@ -568,26 +456,24 @@ class T5Layer(nn.Module): Volume 21 Issue 140 pages 1-67. http://jmlr.org/papers/v21/20-074.html Users may modify or implement in a different way during application. Args: - is_decoder: whether or not the layer belongs to the decoder. (required) - d_model: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - dim_feedforward: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - activation: the activation function of the intermediate layer, can be a string + is_decoder: Whether or not the layer belongs to the decoder. (required) + d_model: Number of expected features in the input (required). + nhead: Number of heads in the multihead attention models (required). + dim_feedforward: Dimension of the feedforward network model (default=3072). + dropout: Dropout value (default=0.1). + activation: Activation function of the intermediate layer, can be a string ("relu" or "gelu") or a unary callable. (default: relu) - layer_norm_eps: the eps value in layer normalization components (default=1e-5). - batch_first: If ``True``, then the input and output tensors are provided - as (batch, seq, feature). (default: ``False``) (seq, batch, feature). - relative_attention_num_buckets: the number of relative position buckets (default: 32) - relative_attention_max_distance: maximum threshold on the relative distance used to - allocate buckets. anything larger than that gets placed in the same bucket (default: 128) - compute_relative_attention_bias: whether or not the relative position embeddings - need to be computed. typically occurs in the first layer of encoder/decoder (default: False) - and resulting position embeddings are returned to be passed up to higher layers. - relative_attention_bias: tensor of weights to compute relative position embeddings. (default: None) + layer_norm_eps: The eps value in layer normalization components (default=1e-6). + relative_attention_num_buckets: Number of relative position buckets (default: 32) + relative_attention_max_distance: Maximum threshold on the relative distance used to + allocate buckets. Anything larger gets placed in the same bucket (default: 128) + compute_relative_attention_bias: Whether or not the relative position embeddings + need to be computed. Typically occurs in the first layer of encoder/decoder + and resulting position embeddings are returned to be passed up to higher layers. (default: False) + relative_attention_bias: nn.Embeding object used to compute relative position embeddings. 
(default: None) Examples:: - >>> decoder_layer = T5Layer(is_decoder=True, d_model=768, nhead=12, batch_first=True) + >>> decoder_layer = T5Layer(is_decoder=True, d_model=768, nhead=12) >>> memory = torch.rand(32, 10, 768) >>> tgt = torch.rand(32, 20, 768) >>> out = deoder_layer(tgt, memory) @@ -602,7 +488,6 @@ def __init__( dropout: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, layer_norm_eps: float = 1e-6, - batch_first: bool = False, relative_attention_num_buckets: int = 32, relative_attention_max_distance: int = 128, compute_relative_attention_bias: bool = False, @@ -619,7 +504,7 @@ def __init__( self.relative_attention_bias = relative_attention_bias self.self_attn = T5MultiheadAttention( - d_model, nhead, is_decoder=is_decoder, dropout=dropout, batch_first=batch_first, device=device, dtype=dtype + d_model, nhead, is_decoder=is_decoder, dropout=dropout, device=device, dtype=dtype ) self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False) self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False) @@ -631,13 +516,7 @@ def __init__( if is_decoder: self.cross_attn = T5MultiheadAttention( - d_model, - nhead, - is_decoder=is_decoder, - dropout=dropout, - batch_first=batch_first, - device=device, - dtype=dtype, + d_model, nhead, is_decoder=is_decoder, dropout=dropout, device=device, dtype=dtype ) self.norm3 = T5LayerNorm(d_model, eps=layer_norm_eps) self.dropout4 = nn.Dropout(dropout) @@ -660,18 +539,27 @@ def forward( memory_key_padding_mask: Optional[Tensor] = None, position_bias: Optional[Tensor] = None, ) -> Tensor: - r"""Pass the inputs (and mask) through the decoder layer. + r"""Pass the inputs (and mask) through the encoder/decoder layer. Args: - tgt: the input sequence to the encoder/decoder layer (required). - memory: the sequence from the last layer of the encoder (used for decoder only). - tgt_mask: the mask for the tgt sequence (optional). - memory_mask: the mask for the memory sequence (optional). - tgt_key_padding_mask: the mask for the tgt keys per batch (optional). - memory_key_padding_mask: the mask for the memory keys per batch (optional). - position_bias: position embeddings to be used when computing attention scores (optional) + tgt: Input sequence to the encoder/decoder layer. (required). + Must have shape (B, Nt, E) where B is the batch size, Nt is the target sequence + length, and E is the model dimension. + memory: Sequence from the last layer of the encoder (used for decoder only). (required). + Must have shape (B, Nts, E) where B is the batch size, Ns is the source sequence + length, and E is the model dimension. + tgt_mask: Attention mask for self-attention. (optional). + Must have shape (Nt, Nt). + memory_mask: Attention mask for cross-attention (decoder-only) (optional). + Must have shape (Nt, Ns). + tgt_key_padding_mask: Mask for the tgt keys per batch (optional). + Must have shape (B, Nt). + memory_key_padding_mask: Mask for the memory keys per batch (decoder-only) (optional). + Must have shape (B, Ns). + position_bias: Relative attention bias to be used when computing self-attention scores (optional) + Must have shape (B, H, Nt, Nt) where H is the number of heads. """ - # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf + # See Fig. 
1 of https://arxiv.org/pdf/2002.04745v1.pdf x = tgt sa_out, position_bias, sa_scores = self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, position_bias) x = x + sa_out @@ -682,7 +570,7 @@ def forward( return x, position_bias, sa_scores, ca_scores if self.is_decoder else None - # self-attention block + # Self-attention block def _sa_block( self, x: Tensor, @@ -711,7 +599,7 @@ def _sa_block( return self.dropout1(x), position_bias, scores - # cross attention block + # Cross attention block def _ca_block( self, x: Tensor, mem: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor] ) -> Tensor: @@ -720,29 +608,28 @@ def _ca_block( scores = attn[1] return self.dropout4(x), scores - # feed forward block + # Feed forward block def _ff_block(self, x: Tensor) -> Tensor: x = self.linear2(self.dropout2(self.activation(self.linear1(x)))) return self.dropout3(x) +# NOTE: Comparable HuggingFace implentation can be found at https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L835 class T5Stack(nn.Module): r"""T5 is a stack of N encoder/decoder layers Args: - is_decoder: whether or not the layer belongs to the decoder. (required) - d_model: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - num_layers: the number of encoder/decoder layers in the stack (required) - dim_feedforward: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - activation: the activation function of the intermediate layer, can be a string + is_decoder: Whether or not the layer belongs to the decoder. (required) + d_model: Number of expected features in the input (required). + nhead: Number of heads in the multihead attention models (required). + num_layers: Number of encoder/decoder layers in the stack (required) + dim_feedforward: Dimension of the feedforward network model (default=3072). + dropout: Dropout value (default=0.1). + activation: Activation function of the intermediate layer, can be a string ("relu" or "gelu") or a unary callable. (default: relu) - layer_norm_eps: the eps value in layer normalization components (default=1e-5). - batch_first: If ``True``, then the input and output tensors are provided - as (batch, seq, feature). (default: ``False``) (seq, batch, feature). - relative_attention_num_buckets: the number of relative position buckets (default: 32) - relative_attention_max_distance: maximum threshold on the relative distance used to - allocate buckets. anything larger than that gets placed in the same bucket (defulat: 128) + layer_norm_eps: The eps value in layer normalization components (default=1e-6). + relative_attention_num_buckets: Number of relative position buckets (default: 32) + relative_attention_max_distance: Maximum threshold on the relative distance used to + allocate buckets. 
Anything larger gets placed in the same bucket (defulat: 128) Examples:: >>> decoder = nn.T5Stack(is_decoder=True, d_model=768, nhead=12, num_layers=12) >>> memory = torch.rand(32, 10, 512) @@ -760,7 +647,6 @@ def __init__( dropout: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, layer_norm_eps: float = 1e-6, - batch_first: bool = False, relative_attention_num_buckets: int = 32, relative_attention_max_distance: int = 128, device=None, @@ -778,7 +664,6 @@ def __init__( dropout, activation, layer_norm_eps, - batch_first, relative_attention_num_buckets, relative_attention_max_distance, compute_relative_attention_bias=True if i == 0 else False, @@ -800,20 +685,28 @@ def forward( tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: - r"""Pass the inputs (and mask) through the decoder layer in turn. + r"""Pass the inputs (and mask) through the stack of encoder/decoder layers. Args: - tgt: the input sequence to the encoder/decoder (required). - memory: the sequence from the last layer of the encoder (for decoder only). - tgt_mask: the mask for the tgt sequence (optional). - memory_mask: the mask for the memory sequence (optional). - tgt_key_padding_mask: the mask for the tgt keys per batch (optional). - memory_key_padding_mask: the mask for the memory keys per batch (optional). + tgt: Input sequence to the encoder/decoder layer. (required). + Must have shape (B, Nt, E) where B is the batch size, Nt is the target sequence + length, and E is the model dimension. + memory: Sequence from the last layer of the encoder (used for decoder only). (required). + Must have shape (B, Nts, E) where B is the batch size, Ns is the source sequence + length, and E is the model dimension. + tgt_mask: Attention mask for self-attention. (optional). + Must have shape (Nt, Nt). + memory_mask: Attention mask for cross-attention (decoder-only) (optional). + Must have shape (Nt, Ns). + tgt_key_padding_mask: Mask for the tgt keys per batch (optional). + Must have shape (B, Nt). + memory_key_padding_mask: Mask for the memory keys per batch (decoder-only) (optional). + Must have shape (B, Ns). 
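+            A decoder-style usage sketch with the batch-first-only interface of this revision (the encoder
+            output plays the role of ``memory``; shapes follow the Args above):
+                >>> decoder = T5Stack(is_decoder=True, d_model=768, nhead=12, num_layers=12)
+                >>> memory = torch.rand(32, 10, 768)
+                >>> tgt = torch.rand(32, 20, 768)
+                >>> out, layer_inputs, position_bias, sa_scores, ca_scores = decoder(tgt, memory)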
""" output = tgt position_bias = None all_outputs = () - sa_scores = () - ca_scores = () + all_sa_scores = () + all_ca_scores = () for mod in self.layers: all_outputs = all_outputs + (output,) output, position_bias, sa_score, ca_score = mod( @@ -825,7 +718,7 @@ def forward( memory_key_padding_mask=memory_key_padding_mask, position_bias=position_bias, ) - sa_scores = sa_scores + (sa_score,) - ca_scores = ca_scores + (ca_score,) + all_sa_scores = all_sa_scores + (sa_score,) + all_ca_scores = all_ca_scores + (ca_score,) - return output, all_outputs, position_bias, sa_scores, ca_scores + return output, all_outputs, position_bias, all_sa_scores, all_ca_scores From 6b7746082cc12a4244a4f8d36c1f79263b0b8830 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Fri, 15 Jul 2022 12:36:15 -0400 Subject: [PATCH 14/16] Update base for Update on "add t5 model that can function as both encodery-only or encoder-decoder model" [ghstack-poisoned] From eef9916fbd14ff1bac12893495ba4cbf9fd946e4 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Fri, 15 Jul 2022 16:58:26 -0400 Subject: [PATCH 15/16] Update base for Update on "add t5 model that can function as both encodery-only or encoder-decoder model" [ghstack-poisoned] --- torchtext/prototype/t5/modules.py | 60 ++++++++++++++++++------------- 1 file changed, 35 insertions(+), 25 deletions(-) diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py index d9c5615d2a..77b6733de4 100644 --- a/torchtext/prototype/t5/modules.py +++ b/torchtext/prototype/t5/modules.py @@ -68,7 +68,7 @@ def forward( relative_attention_max_distance=128, relative_attention_bias: Optional[Tensor] = None, position_bias: Optional[Tensor] = None, - ) -> Tuple[Tensor, Optional[Tensor]]: + ) -> Tuple[Tensor, Tensor, Optional[Tensor]]: r""" Allows the model to jointly attend to information from different representation subspaces as described in the paper: @@ -120,7 +120,7 @@ def forward( - **position_bias** - Used in attention scoring. Only computed when `compute_relative_attention_bias=True` and `position_bias=None`. Has shape :math:`(1, num_heads, L, S)`. 
""" - attn_output, attn_output_weights, position_bias = self.t5_multi_head_attention_forward( + attn_output, position_bias, attn_output_weights = self._t5_multi_head_attention_forward( query, key, value, @@ -134,10 +134,10 @@ def forward( attn_mask=attn_mask, average_attn_weights=average_attn_weights, ) - return attn_output, attn_output_weights, position_bias + return attn_output, position_bias, attn_output_weights # NOTE: Modified from https://github.com/pytorch/pytorch/blob/5953fd9133c0bdcc0158acf1472fac403bc5f636/torch/nn/functional.py#L4909 - def t5_multi_head_attention_forward( + def _t5_multi_head_attention_forward( self, query: Tensor, key: Tensor, @@ -151,7 +151,7 @@ def t5_multi_head_attention_forward( need_weights: bool = True, attn_mask: Optional[Tensor] = None, average_attn_weights: bool = False, - ) -> Tuple[Tensor, Optional[Tensor]]: + ) -> Tuple[Tensor, Tensor, Optional[Tensor]]: is_batched = F._mha_shape_check(query, key, value, key_padding_mask, attn_mask, self.num_heads) # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input @@ -280,18 +280,18 @@ def t5_multi_head_attention_forward( attn_output_weights = attn_output_weights.sum(dim=1) / self.num_heads if not is_batched: - # Aqueeze the output if input was unbatched + # Squeeze the output if input was unbatched attn_output = attn_output.squeeze(1) attn_output_weights = attn_output_weights.squeeze(0) - return attn_output, attn_output_weights, position_bias + return attn_output, position_bias, attn_output_weights else: if not is_batched: - # Aqueeze the output if input was unbatched + # Squeeze the output if input was unbatched attn_output = attn_output.squeeze(1) - return attn_output, None, position_bias + return attn_output, position_bias, None # NOTE: Modified from https://github.com/pytorch/pytorch/blob/5953fd9133c0bdcc0158acf1472fac403bc5f636/torch/nn/functional.py#L4814 def _t5_dot_product_attention( @@ -355,7 +355,7 @@ def _compute_bias( relative_attention_max_distance: int = 128, bidirectional: bool = True, device=None, - ): + ) -> Tensor: """Compute binned relative position bias""" if device is None: device = relative_attention_bias.weight.device @@ -375,7 +375,7 @@ def _compute_bias( # NOTE: Taken from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L374 def _relative_position_bucket( self, relative_position: Tensor, bidirectional: bool = True, num_buckets: int = 32, max_distance: int = 128 - ): + ) -> Tensor: """ Adapted from Mesh Tensorflow: https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 @@ -422,7 +422,7 @@ def _relative_position_bucket( # NOTE: Taken from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L239 class T5LayerNorm(nn.Module): - def __init__(self, d_model, eps=1e-6): + def __init__(self, d_model, eps=1e-6) -> None: """ Construct a layernorm module in the T5 style. No bias and no subtraction of mean. """ @@ -430,11 +430,17 @@ def __init__(self, d_model, eps=1e-6): self.weight = nn.Parameter(torch.ones(d_model)) self.variance_epsilon = eps - def forward(self, hidden_states): - # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean - # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated - # w/o mean and there is no bias. 
Additionally we want to make sure that the accumulation for - # half-precision inputs is done in fp32 + def forward(self, hidden_states: Tensor) -> Tensor: + r""" + T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + half-precision inputs is done in fp32. + Args: + hidden_states: Tensor to be normalized. Final dimension must be model dimension (i.e. number of expected features in the input) + Returns: + a Tensor with the same shape as hidden_states after having been normalized + """ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) @@ -522,6 +528,10 @@ def __init__( self.dropout4 = nn.Dropout(dropout) if isinstance(activation, str): + assert activation in ( + "relu", + "gelu", + ), f"Do not support '{activation}' activation. Use either 'relu' or 'gelu'" if activation == "relu": self.activation = F.relu elif activation == "gelu": @@ -538,7 +548,7 @@ def forward( tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, position_bias: Optional[Tensor] = None, - ) -> Tensor: + ) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor]]: r"""Pass the inputs (and mask) through the encoder/decoder layer. Args: tgt: Input sequence to the encoder/decoder layer. (required). @@ -577,7 +587,7 @@ def _sa_block( attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], position_bias: Optional[Tensor], - ) -> Tensor: + ) -> Tuple[Tensor, Tensor, Optional[Tensor]]: attn = self.self_attn( x, x, @@ -593,19 +603,19 @@ def _sa_block( ) x = attn[0] - scores = attn[1] + scores = attn[2] if self.compute_relative_attention_bias and position_bias is None: - position_bias = attn[2] + position_bias = attn[1] return self.dropout1(x), position_bias, scores # Cross attention block def _ca_block( self, x: Tensor, mem: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor] - ) -> Tensor: + ) -> Tuple[Tensor, Optional[Tensor]]: attn = self.cross_attn(x, mem, mem, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=True) x = attn[0] - scores = attn[1] + scores = attn[2] return self.dropout4(x), scores # Feed forward block @@ -651,7 +661,7 @@ def __init__( relative_attention_max_distance: int = 128, device=None, dtype=None, - ): + ) -> None: super().__init__() self.layers = nn.ModuleList( @@ -684,7 +694,7 @@ def forward( memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: + ) -> Tuple[Tensor, Tuple[Tensor], Tensor, Tuple[Tensor], Tuple[Tensor]]: r"""Pass the inputs (and mask) through the stack of encoder/decoder layers. Args: tgt: Input sequence to the encoder/decoder layer. (required). 
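Before the final patch wires everything together, it may help to see the scale-only normalization from the T5LayerNorm hunk above in isolation. The snippet below is a minimal, standalone sketch rather than code from this patch series: the function name rms_layer_norm is invented for illustration, and the trailing weight multiplication and half-precision cast are assumptions, since the hunk above only shows the fp32 variance computation.

import torch

def rms_layer_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Scale-only "root mean square" layer norm: no mean subtraction and no bias term.
    # The variance is accumulated in fp32 even for half-precision inputs.
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    normed = hidden_states * torch.rsqrt(variance + eps)
    # Assumed tail: cast back to the parameter dtype before applying the learned scale.
    if weight.dtype in (torch.float16, torch.bfloat16):
        normed = normed.to(weight.dtype)
    return weight * normed

x = torch.randn(2, 4, 768)
w = torch.ones(768)
print(rms_layer_norm(x, w).shape)  # torch.Size([2, 4, 768])

Compared to nn.LayerNorm, this skips the mean subtraction and the bias term, matching the RMSNorm formulation referenced in the docstring above.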
From adbc51174d9e8d20a1ae52e3f1adeec5416a3828 Mon Sep 17 00:00:00 2001
From: pmabbo13 <88948596+pmabbo13@users.noreply.github.com>
Date: Mon, 18 Jul 2022 14:17:49 -0400
Subject: [PATCH 16/16] add t5 model that can function as both encodery-only or encoder-decoder model (#1829)

---
 torchtext/prototype/t5/model.py | 200 ++++++++++++++++++++++++++++++++
 1 file changed, 200 insertions(+)
 create mode 100644 torchtext/prototype/t5/model.py

diff --git a/torchtext/prototype/t5/model.py b/torchtext/prototype/t5/model.py
new file mode 100644
index 0000000000..cd9ebf2367
--- /dev/null
+++ b/torchtext/prototype/t5/model.py
@@ -0,0 +1,200 @@
+from typing import Dict, Optional, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from .modules import T5Stack, T5LayerNorm
+
+
+# NOTE: Comparable HuggingFace implementation can be found at https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L1269
+class T5Model(nn.Module):
+    r"""A T5 model. User is able to modify the attributes as needed. The architecture
+    is based on the paper "Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer".
+    Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena,
+    Yanqi Zhou, Wei Li, and Peter J. Liu. 2020. Journal of Machine Learning Research.
+    Volume 21 Issue 140 pages 1-67. http://jmlr.org/papers/v21/20-074.html
+    Args:
+        encoder_only: Whether or not the model should consist of only the encoder, as opposed to encoder-decoder (required)
+        d_model: Number of expected features in the encoder/decoder inputs (default=768).
+        nhead: Number of heads in the multiheadattention models (default=12).
+        num_encoder_layers: Number of encoder layers in the encoder (default=12).
+        num_decoder_layers: Number of decoder layers in the decoder (default=12).
+        dim_feedforward: Dimension of the feedforward network model (default=3072).
+        dropout: Dropout value (default=0.1).
+        activation: Activation function of encoder/decoder intermediate layer, can be a string
+            ("relu" or "gelu") or a unary callable. Default: relu
+        layer_norm_eps: The eps value in layer normalization components (default=1e-6).
+        relative_attention_num_buckets: Number of relative position buckets (default: 32)
+        relative_attention_max_distance: Maximum threshold on the relative distance used to
+            allocate buckets. Anything larger gets placed in the same bucket (default: 128)
+        padding_idx: Index assigned to padding token in vocabulary (default: 0)
+        max_seq_len: Maximum sequence length (default: 512)
+        vocab_size: Size of vocabulary (default: 32128)
+    Examples::
+        >>> t5_model = T5Model(encoder_only=False)
+        >>> src = torch.randint(0, 32128, (32, 10))
+        >>> tgt = torch.randint(0, 32128, (32, 20))
+        >>> out = t5_model(src, tgt)
+    """
+
+    def __init__(
+        self,
+        encoder_only: bool,
+        d_model: int = 768,
+        nhead: int = 12,
+        num_encoder_layers: int = 12,
+        num_decoder_layers: int = 12,
+        dim_feedforward: int = 3072,
+        dropout: float = 0.1,
+        activation: Union[str, Callable[[Tensor], Tensor]] = "relu",
+        layer_norm_eps: float = 1e-6,
+        relative_attention_num_buckets: int = 32,
+        relative_attention_max_distance: int = 128,
+        padding_idx: int = 0,
+        max_seq_len: int = 512,
+        vocab_size: int = 32128,
+        device=None,
+        dtype=None,
+    ) -> None:
+        super().__init__()
+
+        self.encoder_only = encoder_only
+        self.d_model = d_model
+        self.dim_feedforward = dim_feedforward
+        self.dropout = dropout
+        self.activation = activation
+        self.layer_norm_eps = layer_norm_eps
+        self.nhead = nhead
+        self.num_encoder_layers = num_encoder_layers
+        self.num_decoder_layers = num_decoder_layers
+        self.relative_attention_num_buckets = relative_attention_num_buckets
+        self.relative_attention_max_distance = relative_attention_max_distance
+        self.padding_idx = padding_idx
+        self.max_seq_len = max_seq_len
+        self.vocab_size = vocab_size
+        self.device = device
+        self.dtype = dtype
+
+        self.token_embeddings = nn.Embedding(vocab_size, d_model, padding_idx)
+        self.encoder = T5Stack(
+            is_decoder=False,
+            d_model=d_model,
+            nhead=nhead,
+            num_layers=num_encoder_layers,
+            dim_feedforward=dim_feedforward,
+            dropout=dropout,
+            activation=activation,
+            layer_norm_eps=layer_norm_eps,
+            relative_attention_num_buckets=relative_attention_num_buckets,
+            relative_attention_max_distance=relative_attention_max_distance,
+            device=device,
+            dtype=dtype,
+        )
+        self.norm1 = T5LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+
+        if not encoder_only:
+            self.decoder = T5Stack(
+                is_decoder=True,
+                d_model=d_model,
+                nhead=nhead,
+                num_layers=num_decoder_layers,
+                dim_feedforward=dim_feedforward,
+                dropout=dropout,
+                activation=activation,
+                layer_norm_eps=layer_norm_eps,
+                relative_attention_num_buckets=relative_attention_num_buckets,
+                relative_attention_max_distance=relative_attention_max_distance,
+                device=device,
+                dtype=dtype,
+            )
+            self.norm2 = T5LayerNorm(d_model)
+            self.dropout3 = nn.Dropout(dropout)
+            self.dropout4 = nn.Dropout(dropout)
+
+    def forward(
+        self,
+        encoder_tokens: Tensor,
+        decoder_tokens: Optional[Tensor] = None,
+        encoder_mask: Optional[Tensor] = None,
+        decoder_mask: Optional[Tensor] = None,
+    ) -> Dict[str, Union[Tensor, Tuple[Tensor]]]:
+        r"""Pass the inputs (and masks) through the encoder and, if present, the decoder in turn.
+        Args:
+            encoder_tokens: Tokenized input sequence to the encoder.
+                Must be batch first with shape (B, Ne) where B is the batch size and Ne is the
+                encoder input sequence length. (required).
+            decoder_tokens: Tokenized input sequence to the decoder.
+                Must be batch first with shape (B, Nd) where B is the batch size and Nd is the
+                decoder input sequence length. (required when encoder_only=False; ignored otherwise).
+            encoder_mask: Self-attention mask for the encoder input sequence.
+                Must have shape (Ne, Ne) (optional).
+            decoder_mask: Self-attention mask for the decoder input sequence.
+                Must have shape (Nd, Nd) (optional).
+        Returns:
+            encoder_output: Output Tensor from the final layer of the encoder
+            encoder_hidden_states: Tuple of output Tensors from each layer of the encoder
+            encoder_position_bias: Tensor of relative attention bias computed for input sequence to encoder
+            encoder_sa_scores: Tuple of self-attention scores computed at each layer of the encoder
+            decoder_output: Output Tensor from the final layer of the decoder
+            decoder_hidden_states: Tuple of output Tensors from each layer of the decoder
+            decoder_position_bias: Tensor of relative attention bias computed for input sequence to decoder
+            decoder_sa_scores: Tuple of self-attention scores computed at each layer of the decoder
+            decoder_ca_scores: Tuple of cross-attention scores computed at each layer of the decoder
+            The decoder_* entries are only present in the returned dict when encoder_only=False.
+        """
+        encoder_padding_mask = encoder_tokens.eq(self.padding_idx)
+        encoder_embeddings = self.dropout1(self.token_embeddings(encoder_tokens))
+        encoder_output, encoder_hidden_states, encoder_position_bias, encoder_sa, _ = self.encoder(
+            encoder_embeddings, tgt_mask=encoder_mask, tgt_key_padding_mask=encoder_padding_mask
+        )
+
+        encoder_output = self.norm1(encoder_output)
+        encoder_output = self.dropout2(encoder_output)
+        encoder_hidden_states = encoder_hidden_states + (encoder_output,)
+
+        if not self.encoder_only:
+            assert decoder_tokens is not None
+            if decoder_mask is None:
+                tgt_len = decoder_tokens.shape[1]
+                decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1).bool()
+
+            decoder_padding_mask = decoder_tokens.eq(self.padding_idx)
+            # The T5 implementation uses the padding idx to start the sequence, so ignore it when masking
+            decoder_padding_mask[:, 0] = False
+
+            decoder_embeddings = self.dropout3(self.token_embeddings(decoder_tokens))
+            decoder_output, decoder_hidden_states, decoder_position_bias, decoder_sa, decoder_ca = self.decoder(
+                decoder_embeddings,
+                memory=encoder_output,
+                tgt_mask=decoder_mask,
+                memory_mask=encoder_mask,
+                tgt_key_padding_mask=decoder_padding_mask,
+                memory_key_padding_mask=encoder_padding_mask,
+            )
+
+            decoder_output = self.norm2(decoder_output)
+            decoder_output = self.dropout4(decoder_output)
+            decoder_hidden_states = decoder_hidden_states + (decoder_output,)
+
+            t5_output = {
+                "encoder_output": encoder_output,
+                "encoder_hidden_states": encoder_hidden_states,
+                "encoder_position_bias": encoder_position_bias,
+                "encoder_sa_scores": encoder_sa,
+                "decoder_output": decoder_output,
+                "decoder_hidden_states": decoder_hidden_states,
+                "decoder_position_bias": decoder_position_bias,
+                "decoder_sa_scores": decoder_sa,
+                "decoder_ca_scores": decoder_ca,
+            }
+        else:
+            t5_output = {
+                "encoder_output": encoder_output,
+                "encoder_hidden_states": encoder_hidden_states,
+                "encoder_position_bias": encoder_position_bias,
+                "encoder_sa_scores": encoder_sa,
+            }
+
+        return t5_output
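To see how the embedding, the two stacks, and the output dictionary fit together, here is a minimal usage sketch of the T5Model added in this final patch, run in encoder-decoder mode. It is illustrative only: the import path assumes the prototype layout created by this series (torchtext/prototype/t5/model.py), the inputs are integer token ids as required by forward(), and the printed shapes follow the default hyperparameters (d_model=768, 12 encoder and 12 decoder layers).

import torch
from torchtext.prototype.t5.model import T5Model  # path introduced by this patch series

model = T5Model(encoder_only=False)
model.eval()  # disable dropout for the demonstration forward pass

# Two sequences of token ids; index 0 is reserved for padding by default.
encoder_tokens = torch.randint(1, 32128, (2, 10))
decoder_tokens = torch.randint(1, 32128, (2, 7))

with torch.no_grad():
    out = model(encoder_tokens, decoder_tokens)

print(out["encoder_output"].shape)    # torch.Size([2, 10, 768])
print(out["decoder_output"].shape)    # torch.Size([2, 7, 768])
print(len(out["decoder_sa_scores"]))  # 12, one self-attention score tensor per decoder layer

Passing encoder_only=True instead would skip the decoder entirely, so the returned dictionary would contain only the encoder_* entries.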