diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py
new file mode 100644
index 0000000000..ef10ee984a
--- /dev/null
+++ b/torchtext/prototype/t5/modules.py
@@ -0,0 +1,102 @@
+# /* Portions Copyright (c) Meta Platforms, Inc. and affiliates.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Parts of code are originally from
+# https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py
+# */
+
+import math
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+
+class T5MultiheadAttention(nn.MultiheadAttention):
+    def __init__(
+        self,
+        embed_dim,
+        num_heads,
+        is_decoder=False,
+        dropout=0.0,
+        bias=False,
+        kdim=None,
+        vdim=None,
+        device=None,
+        dtype=None,
+    ) -> None:
+        r"""
+        Args:
+            embed_dim: Total dimension of the model.
+            num_heads: Parallel attention heads.
+            is_decoder: Whether or not multihead attention is being performed on a decoder layer. Default: `False`
+            dropout: Probability of an element to be zeroed. Default: 0.0
+            bias: If specified, adds bias to input / output projection layers. Default: `False`.
+            kdim: Total number of features for keys. Default: `None` (uses `kdim=embed_dim`).
+            vdim: Total number of features for values. Default: `None` (uses `vdim=embed_dim`).
+        """
+        super().__init__(embed_dim, num_heads, dropout, bias, False, False, kdim, vdim, True, device, dtype)
+        factory_kwargs = {"device": device, "dtype": dtype}
+        self.is_decoder = is_decoder
+        self.q_proj_weight = nn.Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
+        self.k_proj_weight = nn.Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
+        self.v_proj_weight = nn.Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
+        self.register_parameter("in_proj_weight", None)
+
+    def forward(self):
+        pass
+
+    # NOTE: Taken from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L374
+    def _relative_position_bucket(
+        self, relative_position: Tensor, bidirectional: bool = True, num_buckets: int = 32, max_distance: int = 128
+    ) -> Tensor:
+        """
+        Adapted from Mesh Tensorflow:
+        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+        Translate relative position to a bucket number for relative attention. The relative position is defined as
+        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+        positions >= max_distance map to the same bucket. All relative positions <= -max_distance map to the same bucket.
+        This should allow for more graceful generalization to longer sequences than the model has been trained on.
+        Args:
+            relative_position: an int32 Tensor
+            bidirectional: a boolean - whether the attention is bidirectional
+            num_buckets: an integer
+            max_distance: an integer
+        Returns:
+            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
+        """
+        relative_buckets = 0
+        if bidirectional:
+            num_buckets //= 2
+            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
+            relative_position = torch.abs(relative_position)
+        else:
+            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
+        # Ensure relative_position is in the range [0, inf)
+
+        # Half of the buckets are for exact increments in positions
+        max_exact = num_buckets // 2
+        is_small = relative_position < max_exact
+
+        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+        relative_position_if_large = max_exact + (
+            torch.log(relative_position.float() / max_exact)
+            / math.log(max_distance / max_exact)
+            * (num_buckets - max_exact)
+        ).to(torch.long)
+        relative_position_if_large = torch.min(
+            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
+        )
+
+        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
+        return relative_buckets
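
For context, here is a minimal usage sketch (not part of the diff) showing how `_relative_position_bucket` is exercised: relative positions are computed as `memory_position - query_position` and then mapped to buckets, as the docstring describes. The import path assumes the new module location introduced by this diff; the sequence lengths and constructor arguments are illustrative, and the private method is called directly only to demonstrate the bucketing behavior.

```python
import torch

from torchtext.prototype.t5.modules import T5MultiheadAttention

# Illustrative sizes only.
query_length, key_length = 4, 6

# relative_position[i, j] = j - i, i.e. memory_position - query_position.
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
relative_position = memory_position - context_position  # shape (query_length, key_length)

attn = T5MultiheadAttention(embed_dim=512, num_heads=8)
buckets = attn._relative_position_bucket(
    relative_position, bidirectional=True, num_buckets=32, max_distance=128
)
print(buckets.shape)  # torch.Size([4, 6]); values lie in [0, 32)
```

With `bidirectional=True`, half of the buckets are reserved for positive (future) offsets, the other half split between exact small offsets and logarithmically sized bins, so nearby tokens get distinct buckets while distant ones share them.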