This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
16 changes: 7 additions & 9 deletions torchtext/prototype/models/t5/modules.py
@@ -74,6 +74,8 @@ def __init__(
else:
self.relative_attention_bias = None

+ self.device = device
+
def forward(
self,
query: Tensor,
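The hunk above resolves the device once at construction time instead of threading it through each call. A minimal sketch of the resulting pattern, with illustrative names (`RelativeBiasSketch` is not the torchtext class, just an assumed stand-in):

```python
import torch
import torch.nn as nn


class RelativeBiasSketch(nn.Module):
    """Minimal sketch (assumed API, not torchtext's T5 module): remember the
    construction-time device so tensors built in forward() land on it."""

    def __init__(self, num_buckets: int = 32, num_heads: int = 8, device=None):
        super().__init__()
        self.relative_attention_bias = nn.Embedding(num_buckets, num_heads, device=device)
        self.device = device  # resolved once; if None, torch.arange falls back to the default device

    def forward(self, query_length: int, key_length: int) -> torch.Tensor:
        # arange tensors are created directly on the stored device
        context_position = torch.arange(query_length, dtype=torch.long, device=self.device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=self.device)[None, :]
        return memory_position - context_position  # (query_length, key_length)
```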
@@ -258,7 +260,6 @@ def _t5_multi_head_attention_forward(
tgt_len,
src_len,
bidirectional=(not self.is_decoder),
- device=k.device,
)

# Calculate attention and out projection
@@ -402,20 +403,17 @@ def _t5_dot_product_attention(
output = output.transpose(1, 2).contiguous().view(B, -1, H * E)
return output, attn

- # NOTE: modified from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L421
+ # NOTE: Modified from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L421
def _compute_bias(
self,
query_length: int,
key_length: int,
bidirectional: bool = True,
- device: Optional[torch.device] = None,
) -> Tensor:
"""Compute binned relative position bias"""
assert self.relative_attention_bias is not None
- if device is None:
- device = self.relative_attention_bias.weight.device
- context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
- memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
+ context_position = torch.arange(query_length, dtype=torch.long, device=self.device)[:, None]
+ memory_position = torch.arange(key_length, dtype=torch.long, device=self.device)[None, :]
relative_position = memory_position - context_position # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket(
relative_position, # shape (query_length, key_length)
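For intuition, the `memory_position - context_position` matrix above is just the signed distance from each query index to each key index. A tiny worked example (illustrative, query_length = key_length = 3):

```python
import torch

context_position = torch.arange(3, dtype=torch.long)[:, None]  # shape (3, 1)
memory_position = torch.arange(3, dtype=torch.long)[None, :]   # shape (1, 3)
relative_position = memory_position - context_position         # shape (3, 3)
# tensor([[ 0,  1,  2],
#         [-1,  0,  1],
#         [-2, -1,  0]])
```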
@@ -427,7 +425,7 @@ def _compute_bias(
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
return values

- # NOTE: Taken from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L374
+ # NOTE: Modified from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L374
def _relative_position_bucket(
self, relative_position: Tensor, bidirectional: bool = True, num_buckets: int = 32, max_distance: int = 128
) -> Tensor:
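The tail of `_compute_bias` shown in the hunk above looks up the bucket indices in the `relative_attention_bias` embedding and reshapes the result so it can be added to the attention scores. A hedged sketch of that shaping step (the bucket indices here are random placeholders, not the real bucketing output):

```python
import torch
import torch.nn as nn

num_buckets, num_heads, q_len, k_len = 32, 8, 4, 5
relative_attention_bias = nn.Embedding(num_buckets, num_heads)

# Placeholder bucket indices standing in for _relative_position_bucket's output
relative_position_bucket = torch.randint(0, num_buckets, (q_len, k_len))

values = relative_attention_bias(relative_position_bucket)  # (q_len, k_len, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0)             # (1, num_heads, q_len, k_len)
assert values.shape == (1, num_heads, q_len, k_len)
```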
@@ -448,7 +446,7 @@ def _relative_position_bucket(
Returns:
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
"""
- relative_buckets = torch.zeros(relative_position.shape, dtype=torch.long)
+ relative_buckets = torch.zeros(relative_position.shape, dtype=torch.long, device=self.device)
if bidirectional:
num_buckets = num_buckets // 2
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
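The hunk above is truncated, but the bucketing scheme its docstring describes (half the buckets per direction when bidirectional, one bucket per small distance, logarithmically spaced buckets up to `max_distance` for larger ones) can be sketched as a standalone function along the lines of the Hugging Face reference linked in the NOTE. This is an approximation for illustration, not the exact torchtext code:

```python
import math
import torch


def relative_position_bucket_sketch(
    relative_position: torch.Tensor,
    bidirectional: bool = True,
    num_buckets: int = 32,
    max_distance: int = 128,
) -> torch.Tensor:
    relative_buckets = torch.zeros_like(relative_position)
    if bidirectional:
        # Half the buckets encode positive (future) offsets, half negative.
        num_buckets = num_buckets // 2
        relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
        relative_position = torch.abs(relative_position)
    else:
        # Causal case: only non-positive offsets are meaningful.
        relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))

    # Small distances get one bucket each; larger ones share log-spaced buckets.
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    relative_position_if_large = max_exact + (
        torch.log(relative_position.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).to(torch.long)
    relative_position_if_large = torch.min(
        relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
    )
    relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
    return relative_buckets
```

Applied to the 3x3 `relative_position` matrix from the earlier sketch with the defaults, all distances fall in the "small" range, so the buckets are just the absolute offsets, with positive offsets shifted into the upper half of the bucket range.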