This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit c8c6dbb (1 parent: 448e615)

Update base for Update on "computing attention scores using relative attention bias"

WIP PR to workshop implementation: #1812 [ghstack-poisoned]

File tree: 1 file changed, +14 -24 lines

torchtext/prototype/t5/modules.py (14 additions, 24 deletions)
@@ -9,8 +9,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# Original code is taken from
-# https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
+# Parts of code are originally from
+# https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py
 # */
 
 import math
@@ -28,32 +28,22 @@ def __init__(
         is_decoder=False,
         dropout=0.0,
         bias=False,
-        add_bias_kv=False,
-        add_zero_attn=False,
         kdim=None,
         vdim=None,
-        batch_first=False,
         device=None,
         dtype=None,
     ) -> None:
         r"""
         Args:
-            embed_dim: total dimension of the model.
-            num_heads: parallel attention heads.
-            is_decoder: whether or not multihead attention is being performed on a decoder layer. Default: ``False``
-            dropout: probability of an element to be zeroed. Default: 0.0
-            bias: If specified, adds bias to input / output projection layers. Default: ``False``.
-            add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
-            add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
-                Default: ``False``.
-            kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
-            vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
-            batch_first: If ``True``, then the input and output tensors are provided
-                as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
+            embed_dim: Total dimension of the model.
+            num_heads: Parallel attention heads.
+            is_decoder: Whether or not multihead attention is being performed on a decoder layer. Default: `False`
+            dropout: Probability of an element to be zeroed. Default: 0.0
+            bias: If specified, adds bias to input / output projection layers. Default: `False`.
+            kdim: Total number of features for keys. Default: `None` (uses `kdim=embed_dim`).
+            vdim: Total number of features for values. Default: `None` (uses `vdim=embed_dim`).
         """
-        super().__init__(
-            embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim, batch_first, device, dtype
-        )
+        super().__init__(embed_dim, num_heads, dropout, bias, False, False, kdim, vdim, True, device, dtype)
         factory_kwargs = {"device": device, "dtype": dtype}
         self.is_decoder = is_decoder
         self.q_proj_weight = nn.Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
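The collapsed super().__init__ call above passes its options positionally, which hides what the removed constructor arguments are now pinned to. As a minimal sketch (the class name T5MultiheadAttention and the surrounding class body are assumed for illustration; only the argument mapping comes from the diff), the same call written with nn.MultiheadAttention keywords looks like this:

```python
import torch.nn as nn


class T5MultiheadAttention(nn.MultiheadAttention):  # class name assumed for illustration
    def __init__(
        self,
        embed_dim,
        num_heads,
        is_decoder=False,
        dropout=0.0,
        bias=False,
        kdim=None,
        vdim=None,
        device=None,
        dtype=None,
    ) -> None:
        # Keyword form of the positional call in the diff:
        # super().__init__(embed_dim, num_heads, dropout, bias, False, False, kdim, vdim, True, device, dtype)
        super().__init__(
            embed_dim,
            num_heads,
            dropout=dropout,
            bias=bias,
            add_bias_kv=False,    # no longer a constructor argument; always disabled
            add_zero_attn=False,  # no longer a constructor argument; always disabled
            kdim=kdim,
            vdim=vdim,
            batch_first=True,     # no longer a constructor argument; inputs are always (batch, seq, feature)
            device=device,
            dtype=dtype,
        )
        self.is_decoder = is_decoder
```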
@@ -64,7 +54,7 @@ def __init__(
     def forward():
         pass
 
-    # NOTE: modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
+    # NOTE: modified from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L421
     def _compute_bias(
         self,
         query_length: int,
@@ -91,7 +81,7 @@ def _compute_bias(
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
         return values
 
-    # NOTE: taken from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
+    # NOTE: Taken from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L374
     def _relative_position_bucket(
         self, relative_position: Tensor, bidirectional: bool = True, num_buckets: int = 32, max_distance: int = 128
     ):
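The permute/unsqueeze line in the hunk above is the final step of building the relative attention bias tensor that T5 adds to the attention scores. Below is a minimal, self-contained sketch of that shape flow; it assumes a learned nn.Embedding named relative_attention_bias (as in the linked HF reference) and uses random bucket indices as stand-ins for the output of _relative_position_bucket:

```python
import torch
import torch.nn as nn

# Illustrative sizes (not taken from the diff).
num_heads, num_buckets = 8, 32
query_length, key_length = 5, 7
batch_size = 2

# In the real module these indices come from _relative_position_bucket();
# random values with the right shape and range stand in for them here.
relative_position_bucket = torch.randint(num_buckets, (query_length, key_length))

# One learned scalar per (bucket, head) pair.
relative_attention_bias = nn.Embedding(num_buckets, num_heads)

values = relative_attention_bias(relative_position_bucket)  # (query_length, key_length, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0)             # (1, num_heads, query_length, key_length)

# The bias broadcasts over the batch dimension when added to the raw
# attention scores before the softmax.
scores = torch.randn(batch_size, num_heads, query_length, key_length)
scores = scores + values
```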
@@ -119,9 +109,9 @@ def _relative_position_bucket(
             relative_position = torch.abs(relative_position)
         else:
             relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
-        # now relative_position is in the range [0, inf)
+        # Ensure relative_position is in the range [0, inf)
 
-        # half of the buckets are for exact increments in positions
+        # Half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact
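For reference, the bucketing scheme that the hunk above is part of (and that the linked modeling_t5.py implements) maps signed position offsets to a fixed number of buckets: half of them cover exact small offsets, and the other half cover larger distances in logarithmically growing bins up to max_distance. A self-contained sketch written as a free function, not the torchtext method itself:

```python
import math

import torch
from torch import Tensor


def relative_position_bucket(
    relative_position: Tensor, bidirectional: bool = True, num_buckets: int = 32, max_distance: int = 128
) -> Tensor:
    """Map signed offsets (key_position - query_position) to bucket indices in [0, num_buckets)."""
    relative_buckets = torch.zeros_like(relative_position)
    if bidirectional:
        # Split the buckets between "key after query" and "key at or before query".
        num_buckets //= 2
        relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
        relative_position = torch.abs(relative_position)
    else:
        # Causal case: clamp positive offsets to zero and work with distances into the past.
        relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
    # relative_position is now in the range [0, inf)

    # Half of the remaining buckets are for exact increments in positions.
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact

    # The other half covers distances up to max_distance in logarithmically growing bins.
    relative_position_if_large = max_exact + (
        torch.log(relative_position.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).to(torch.long)
    relative_position_if_large = torch.min(
        relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
    )

    relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
    return relative_buckets


# Example: bucket the offsets for a 5-token query attending over 7 keys.
context_position = torch.arange(5)[:, None]
memory_position = torch.arange(7)[None, :]
buckets = relative_position_bucket(memory_position - context_position)  # shape (5, 7), values in [0, 32)
```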
