torchtext/prototype/t5/modules.py (54 additions, 0 deletions)
@@ -14,9 +14,11 @@
# */

import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


@@ -54,6 +56,58 @@ def __init__(
    def forward():
        pass

    # NOTE: Modified from https://github.com/pytorch/pytorch/blob/5953fd9133c0bdcc0158acf1472fac403bc5f636/torch/nn/functional.py#L4814
    def _t5_dot_product_attention(
        self,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        position_bias: Tensor,
        attn_mask: Optional[Tensor] = None,
        dropout_p: float = 0.0,
    ) -> Tuple[Tensor, Tensor]:
r"""
Computes scaled dot product attention on query, key and value tensors, using
an optional attention mask if passed, and applying dropout if a probability
greater than 0.0 is specified.
Returns a tensor pair containing attended values and attention weights.
Args:
q, k, v: Query, key and value tensors. See Shape section for shape details.
attn_mask: Optional tensor containing mask values to be added to calculated
attention. May be 2D or 3D; see Shape section for details.
dropout_p: Dropout probability. If greater than 0.0, dropout is applied.
position_bias: Position bias used to incorporate realtive attention bias in attention scors
Shape:
- q: :math:`(B, H, Nt, E)` where B is the batch size, H is the number of heads, Nt is the target sequence length,
and E is the head dimension.
- key: :math:`(B, H, Ns, E)` where B is the batch size, H is the number of heads, Ns is the source sequence length,
and E is the head dimension.
- value: :math:`(B, H, Ns, E)` where B is the batch size, H is the number of heads, Ns is the source sequence length,
and E is the head dimension.
- attn_mask: a 4D tensor of shape :math:`(B, H, Nt, Ns)`
- position_bias: :math:`(1, H, Nt, Ns)`
- Output: attention values have shape :math:`(B, Nt, H*E)`; attention weights
have shape :math:`(B, H, Nt, Ns)`
"""
        B, H, _, E = q.shape
        # NOTE: HF implementation does not perform this normalization. For the sake of matching test results, we have commented it out.
        # q = q / math.sqrt(E)

        attn = torch.matmul(q, k.transpose(3, 2))

        # NOTE: modification from torch.nn.functional._scaled_dot_product_attention to incorporate relative attention bias
        position_bias = position_bias.repeat(B, 1, 1, 1)
        if attn_mask is not None:
            position_bias += attn_mask
        attn += position_bias

        attn = F.softmax(attn, dim=-1)
        if dropout_p > 0.0:
            attn = F.dropout(attn, p=dropout_p)
        output = torch.matmul(attn, v)
        output = output.transpose(1, 2).contiguous().view(B, -1, H * E)
        return output, attn

    # NOTE: Modified from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L421
    def _compute_bias(
        self,
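
As a quick illustration of the shape contract documented in the docstring above, the following standalone sketch mirrors the same computation with dummy tensors. The helper name `demo_t5_attention` and the dimensions chosen here are hypothetical, used only to show the expected `(B, Nt, H*E)` output and `(B, H, Nt, Ns)` attention-weight shapes; it is not part of this PR or of the torchtext API.

from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor


def demo_t5_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    position_bias: Tensor,
    attn_mask: Optional[Tensor] = None,
    dropout_p: float = 0.0,
) -> Tuple[Tensor, Tensor]:
    # Illustrative re-statement of the attention math above (unscaled scores, as in the diff).
    B, H, _, E = q.shape
    attn = torch.matmul(q, k.transpose(3, 2))
    # Broadcast the (1, H, Nt, Ns) bias across the batch and fold the mask into it.
    position_bias = position_bias.repeat(B, 1, 1, 1)
    if attn_mask is not None:
        position_bias += attn_mask
    attn += position_bias
    attn = F.softmax(attn, dim=-1)
    if dropout_p > 0.0:
        attn = F.dropout(attn, p=dropout_p)
    output = torch.matmul(attn, v).transpose(1, 2).contiguous().view(B, -1, H * E)
    return output, attn


# Arbitrary dimensions for the shape check.
B, H, Nt, Ns, E = 2, 8, 5, 7, 64
q = torch.randn(B, H, Nt, E)
k = torch.randn(B, H, Ns, E)
v = torch.randn(B, H, Ns, E)
position_bias = torch.randn(1, H, Nt, Ns)

out, weights = demo_t5_attention(q, k, v, position_bias)
assert out.shape == (B, Nt, H * E)      # attended values
assert weights.shape == (B, H, Nt, Ns)  # attention weights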