61 changes: 61 additions & 0 deletions torchtext/prototype/t5/modules.py
@@ -14,11 +14,72 @@
# */

import math
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor


def _t5_scaled_dot_product_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    position_bias: Tensor,
    attn_mask: Optional[Tensor] = None,
    dropout_p: float = 0.0,
) -> Tuple[Tensor, Tensor]:
r"""
Computes scaled dot product attention on query, key and value tensors, using
an optional attention mask if passed, and applying dropout if a probability
greater than 0.0 is specified.
Returns a tensor pair containing attended values and attention weights.
Args:
q, k, v: query, key and value tensors. See Shape section for shape details.
attn_mask: optional tensor containing mask values to be added to calculated
attention. May be 2D or 3D; see Shape section for details.
dropout_p: dropout probability. If greater than 0.0, dropout is applied.
position_bias: position bias used to incorporate realtive attention bias in attention scors
Shape:
- q: :math:`(B, Nt, E)` where B is (batch size*num_heads), Nt is the target sequence length,
and E is embedding dimension.
- key: :math:`(B, Ns, E)` where B is (batch size*num_heads), Ns is the source sequence length,
and E is embedding dimension.
- value: :math:`(B, Ns, E)` where B is (batch size*num_heads), Ns is the source sequence length,
and E is embedding dimension.
- attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of
shape :math:`(Nt, Ns)`.
- position_bias: :math:`(1, num_heads, Nt, Ns)`
- Output: attention values have shape :math:`(B, Nt, E)`; attention weights
have shape :math:`(B, Nt, Ns)`
"""
    B, Nt, E = q.shape
    # NOTE: The HF implementation does not perform this normalization. For the sake of matching test
    # results, it has been commented out here.
    # q = q / math.sqrt(E)

    n_heads, tgt_len, src_len = position_bias.size()[1:]
    assert B % n_heads == 0
    assert tgt_len == Nt

    # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
    if attn_mask is not None:
        attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1))
    else:
        attn = torch.bmm(q, k.transpose(-2, -1))

    # NOTE: modification from torch.nn.functional._scaled_dot_product_attention to incorporate relative attention bias.
    # Expand the (1, n_heads, Nt, Ns) bias across the batch and flatten it to (B, Nt, Ns) so it can be
    # added to the attention scores directly.
    position_bias = position_bias.repeat(B // n_heads, 1, 1, 1)
    position_bias = position_bias.view(B, tgt_len, src_len)
    attn += position_bias

    attn = F.softmax(attn, dim=-1)
    if dropout_p > 0.0:
        attn = F.dropout(attn, p=dropout_p)
    # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
    output = torch.bmm(attn, v)
    return output, attn
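

# The block below is a minimal, illustrative shape check for the helper above; it is not part of the
# original diff. The tensor sizes and the __main__ guard are assumptions chosen only to exercise the
# documented shapes: q is (B, Nt, E), k and v are (B, Ns, E), and position_bias is
# (1, num_heads, Nt, Ns), with B = batch_size * num_heads.
if __name__ == "__main__":
    batch_size, num_heads, tgt_len, src_len, head_dim = 2, 8, 5, 7, 64
    B = batch_size * num_heads

    q = torch.randn(B, tgt_len, head_dim)
    k = torch.randn(B, src_len, head_dim)
    v = torch.randn(B, src_len, head_dim)
    position_bias = torch.randn(1, num_heads, tgt_len, src_len)

    output, attn_weights = _t5_scaled_dot_product_attention(q, k, v, position_bias)
    assert output.shape == (B, tgt_len, head_dim)  # attention values: (B, Nt, E)
    assert attn_weights.shape == (B, tgt_len, src_len)  # attention weights: (B, Nt, Ns)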


# NOTE: modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
def _compute_bias(
query_length: int,