diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py
index ab0680431a..8e31e65ffc 100644
--- a/torchtext/prototype/t5/modules.py
+++ b/torchtext/prototype/t5/modules.py
@@ -14,9 +14,11 @@
 # */
 
 import math
+from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from torch import Tensor
 
 
@@ -54,6 +56,58 @@ def __init__(
     def forward():
         pass
 
+    # NOTE: Modified from https://github.com/pytorch/pytorch/blob/5953fd9133c0bdcc0158acf1472fac403bc5f636/torch/nn/functional.py#L4814
+    def _t5_dot_product_attention(
+        self,
+        q: Tensor,
+        k: Tensor,
+        v: Tensor,
+        position_bias: Tensor,
+        attn_mask: Optional[Tensor] = None,
+        dropout_p: float = 0.0,
+    ) -> Tuple[Tensor, Tensor]:
+        r"""
+        Computes scaled dot product attention on query, key and value tensors, using
+        an optional attention mask if passed, and applying dropout if a probability
+        greater than 0.0 is specified.
+        Returns a tensor pair containing attended values and attention weights.
+        Args:
+            q, k, v: Query, key and value tensors. See Shape section for shape details.
+            position_bias: Position bias used to incorporate relative attention bias in attention scores.
+            attn_mask: Optional tensor containing mask values to be added to the calculated
+                attention. Must be a 4D tensor; see Shape section for details.
+            dropout_p: Dropout probability. If greater than 0.0, dropout is applied.
+        Shape:
+            - q: :math:`(B, H, Nt, E)` where B is the batch size, H is the number of heads, Nt is the target sequence length,
+                and E is the head dimension.
+            - k: :math:`(B, H, Ns, E)` where B is the batch size, H is the number of heads, Ns is the source sequence length,
+                and E is the head dimension.
+            - v: :math:`(B, H, Ns, E)` where B is the batch size, H is the number of heads, Ns is the source sequence length,
+                and E is the head dimension.
+            - attn_mask: a 4D tensor of shape :math:`(B, H, Nt, Ns)`
+            - position_bias: :math:`(1, H, Nt, Ns)`
+            - Output: attention values have shape :math:`(B, Nt, H*E)`; attention weights
+                have shape :math:`(B, H, Nt, Ns)`
+        """
+        B, H, _, E = q.shape
+        # NOTE: The HF implementation does not perform this normalization. For the sake of matching test results, it is commented out.
+        # q = q / math.sqrt(E)
+
+        attn = torch.matmul(q, k.transpose(3, 2))
+
+        # NOTE: Modification from torch.nn.functional._scaled_dot_product_attention to incorporate relative attention bias
+        position_bias = position_bias.repeat(B, 1, 1, 1)
+        if attn_mask is not None:
+            position_bias += attn_mask
+        attn += position_bias
+
+        attn = F.softmax(attn, dim=-1)
+        if dropout_p > 0.0:
+            attn = F.dropout(attn, p=dropout_p)
+        output = torch.matmul(attn, v)
+        output = output.transpose(1, 2).contiguous().view(B, -1, H * E)
+        return output, attn
+
     # NOTE: modified from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L421
     def _compute_bias(
         self,