50 changes: 50 additions & 0 deletions torchtext/prototype/t5/modules.py
@@ -13,8 +13,11 @@
# https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py
# */

import math

import torch
import torch.nn as nn
from torch import Tensor


class T5MultiheadAttention(nn.MultiheadAttention):
@@ -50,3 +53,50 @@ def __init__(

    def forward(self):
        pass

    # NOTE: Taken from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L374
    def _relative_position_bucket(
        self, relative_position: Tensor, bidirectional: bool = True, num_buckets: int = 32, max_distance: int = 128
    ) -> Tensor:
"""
Adapted from Mesh Tensorflow:
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
Translate relative position to a bucket number for relative attention. The relative position is defined as
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences than the model has been trained on
Args:
relative_position: an int32 Tensor
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Returns:
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
"""
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            # Ensure relative_position is in the range [0, inf)
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))

        # Half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets
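
    # NOTE: Illustrative usage sketch, not part of the upstream diff. T5 consumes
    # these buckets by building a (query_length, key_length) matrix of relative
    # positions, bucketizing it, and indexing a learned per-bucket embedding to
    # produce the attention bias (see `compute_bias` in the HuggingFace file
    # referenced above). A minimal example, assuming an instance `attn` of this
    # class and illustrative lengths:
    #
    #   query_length, key_length = 4, 6
    #   context_position = torch.arange(query_length, dtype=torch.long)[:, None]
    #   memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
    #   # relative_position[i, j] = j - i: distance from query token i to key token j
    #   relative_position = memory_position - context_position
    #   buckets = attn._relative_position_bucket(relative_position, bidirectional=True)
    #   # buckets.shape == (4, 6), with values in [0, num_buckets)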