
Commit 22f537b

computing relative attention bias (#1831)
1 parent 36bf7d5 commit 22f537b

File tree

1 file changed: +27 −0 lines changed


torchtext/prototype/t5/modules.py

Lines changed: 27 additions & 0 deletions
@@ -54,6 +54,33 @@ def __init__(
     def forward():
         pass
 
+    # NOTE: modified from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L421
+    def _compute_bias(
+        self,
+        query_length: int,
+        key_length: int,
+        relative_attention_bias: Tensor,
+        relative_attention_num_buckets: int = 32,
+        relative_attention_max_distance: int = 128,
+        bidirectional: bool = True,
+        device=None,
+    ) -> Tensor:
+        """Compute binned relative position bias"""
+        if device is None:
+            device = relative_attention_bias.weight.device
+        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
+        relative_position = memory_position - context_position  # shape (query_length, key_length)
+        relative_position_bucket = self._relative_position_bucket(
+            relative_position,  # shape (query_length, key_length)
+            bidirectional=bidirectional,
+            num_buckets=relative_attention_num_buckets,
+            max_distance=relative_attention_max_distance,
+        )
+        values = relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
+        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
+        return values
+
     # NOTE: Taken from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L374
     def _relative_position_bucket(
         self, relative_position: Tensor, bidirectional: bool = True, num_buckets: int = 32, max_distance: int = 128
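
For readers who want to see the added method in action outside the module, below is a minimal, self-contained sketch of the same computation. The standalone relative_position_bucket helper mirrors the Hugging Face T5 bucketing referenced in the NOTE above; the embedding table (bias_table), head/bucket counts, and sequence length are illustrative assumptions, not part of this commit.

    # Standalone sketch of how the new _compute_bias logic can be exercised.
    # The helper below follows the standard T5 bucketing; all concrete names
    # and sizes here are illustrative assumptions.
    import math
    import torch
    from torch import Tensor, nn

    def relative_position_bucket(
        relative_position: Tensor, bidirectional: bool = True, num_buckets: int = 32, max_distance: int = 128
    ) -> Tensor:
        # Nearby offsets get their own bucket; distant offsets share log-spaced buckets.
        relative_buckets = torch.zeros_like(relative_position)
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )
        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    num_heads, num_buckets, seq_len = 8, 32, 16
    bias_table = nn.Embedding(num_buckets, num_heads)  # learned per-bucket, per-head bias

    context_position = torch.arange(seq_len, dtype=torch.long)[:, None]
    memory_position = torch.arange(seq_len, dtype=torch.long)[None, :]
    buckets = relative_position_bucket(memory_position - context_position)

    bias = bias_table(buckets).permute([2, 0, 1]).unsqueeze(0)  # (1, num_heads, seq_len, seq_len)
    scores = torch.randn(1, num_heads, seq_len, seq_len)
    scores = scores + bias  # bias is added to raw attention scores before softmax
    print(bias.shape)  # torch.Size([1, 8, 16, 16])

In the T5 attention layer, the resulting (1, num_heads, query_length, key_length) tensor is added to the raw attention scores before the softmax, so the learned per-bucket values shift attention toward or away from particular relative offsets.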
