diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py
index a3db63b9ae..5ba410a593 100644
--- a/torchtext/prototype/t5/modules.py
+++ b/torchtext/prototype/t5/modules.py
@@ -14,11 +14,72 @@
 # */
 
 import math
+from typing import Optional, Tuple
 
 import torch
+import torch.nn.functional as F
 from torch import Tensor
 
 
+def _t5_scaled_dot_product_attention(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    position_bias: Tensor,
+    attn_mask: Optional[Tensor] = None,
+    dropout_p: float = 0.0,
+) -> Tuple[Tensor, Tensor]:
+    r"""
+    Computes scaled dot product attention on query, key and value tensors, using
+    an optional attention mask if passed, and applying dropout if a probability
+    greater than 0.0 is specified.
+    Returns a tensor pair containing attended values and attention weights.
+    Args:
+        q, k, v: query, key, and value tensors. See Shape section for shape details.
+        attn_mask: optional tensor containing mask values to be added to the calculated
+            attention. May be 2D or 3D; see Shape section for details.
+        dropout_p: dropout probability. If greater than 0.0, dropout is applied.
+        position_bias: position bias used to incorporate the relative attention bias into the attention scores.
+    Shape:
+        - q: :math:`(B, Nt, E)` where B is (batch size * num_heads), Nt is the target sequence length,
+            and E is the embedding dimension.
+        - k: :math:`(B, Ns, E)` where B is (batch size * num_heads), Ns is the source sequence length,
+            and E is the embedding dimension.
+        - v: :math:`(B, Ns, E)` where B is (batch size * num_heads), Ns is the source sequence length,
+            and E is the embedding dimension.
+        - attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of
+            shape :math:`(Nt, Ns)`.
+        - position_bias: :math:`(1, num_heads, Nt, Ns)`
+        - Output: attention values have shape :math:`(B, Nt, E)`; attention weights
+            have shape :math:`(B, Nt, Ns)`.
+    """
+    B, Nt, E = q.shape
+    # NOTE: the HF implementation does not perform this normalization; it is commented out here so that test results match.
+    # q = q / math.sqrt(E)
+
+    n_heads, tgt_len, src_len = position_bias.size()[1:]
+    assert B % n_heads == 0
+    assert tgt_len == Nt
+
+    # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
+    if attn_mask is not None:
+        attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1))
+    else:
+        attn = torch.bmm(q, k.transpose(-2, -1))
+
+    # NOTE: modified from torch.nn.functional._scaled_dot_product_attention to incorporate the relative attention bias
+    position_bias = position_bias.repeat(B // n_heads, 1, 1, 1)
+    position_bias = position_bias.view(B, tgt_len, src_len)
+    attn += position_bias
+
+    attn = F.softmax(attn, dim=-1)
+    if dropout_p > 0.0:
+        attn = F.dropout(attn, p=dropout_p)
+    # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
+    output = torch.bmm(attn, v)
+    return output, attn
+
+
 # NOTE: modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
 def _compute_bias(
     query_length: int,
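
For reviewers, a minimal usage sketch of the new _t5_scaled_dot_product_attention helper follows. It is not part of the patch: the tensor sizes, the (batch * heads, seq, head_dim) flattening, and the import path are illustrative assumptions based on the Shape section of the docstring above.

# Usage sketch (assumptions: sizes are arbitrary, attn_mask is omitted, dropout is 0.0)
import torch

from torchtext.prototype.t5.modules import _t5_scaled_dot_product_attention

batch_size, num_heads, head_dim = 2, 8, 64
tgt_len = src_len = 16
B = batch_size * num_heads  # projections are assumed flattened to (batch * heads, seq, head_dim)

q = torch.rand(B, tgt_len, head_dim)
k = torch.rand(B, src_len, head_dim)
v = torch.rand(B, src_len, head_dim)
position_bias = torch.rand(1, num_heads, tgt_len, src_len)  # one bias broadcast across the batch

output, attn_weights = _t5_scaled_dot_product_attention(q, k, v, position_bias)
assert output.shape == (B, tgt_len, head_dim)
assert attn_weights.shape == (B, tgt_len, src_len)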