This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 20180ff

create T5MultiheadAttention module (#1825)
1 parent c60704a commit 20180ff

File tree

1 file changed: +52 -0 lines changed


torchtext/prototype/t5/modules.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
# /* Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Parts of code are originally from
# https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py
# */

import torch
import torch.nn as nn


class T5MultiheadAttention(nn.MultiheadAttention):
    def __init__(
        self,
        embed_dim,
        num_heads,
        is_decoder=False,
        dropout=0.0,
        bias=False,
        kdim=None,
        vdim=None,
        device=None,
        dtype=None,
    ) -> None:
        r"""
        Args:
            embed_dim: Total dimension of the model.
            num_heads: Parallel attention heads.
            is_decoder: Whether or not multihead attention is being performed on a decoder layer. Default: `False`
            dropout: Probability of an element to be zeroed. Default: 0.0
            bias: If specified, adds bias to input / output projection layers. Default: `False`.
            kdim: Total number of features for keys. Default: `None` (uses `kdim=embed_dim`).
            vdim: Total number of features for values. Default: `None` (uses `vdim=embed_dim`).
        """
        super().__init__(embed_dim, num_heads, dropout, bias, False, False, kdim, vdim, True, device, dtype)
        factory_kwargs = {"device": device, "dtype": dtype}
        self.is_decoder = is_decoder
        self.q_proj_weight = nn.Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
        self.k_proj_weight = nn.Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
        self.v_proj_weight = nn.Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
        self.register_parameter("in_proj_weight", None)
    def forward(self):
        pass
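
For reference, a minimal usage sketch (not part of this commit) that exercises only what the diff adds: it instantiates the module and inspects the projection weights the constructor registers. The import path assumes the file is exposed as torchtext.prototype.t5.modules, matching the path in this diff, and the embed_dim / num_heads / kdim / vdim values are arbitrary examples; forward() is still a stub at this point, so no attention is actually computed.

from torchtext.prototype.t5.modules import T5MultiheadAttention

# Hypothetical example dimensions; any embed_dim divisible by num_heads works.
attn = T5MultiheadAttention(embed_dim=512, num_heads=8, is_decoder=True, kdim=64, vdim=64)

# The fused in_proj_weight inherited from nn.MultiheadAttention is unregistered
# in favor of separate per-projection weight matrices.
print(attn.in_proj_weight)       # None
print(attn.q_proj_weight.shape)  # torch.Size([512, 512])
print(attn.k_proj_weight.shape)  # torch.Size([512, 64])
print(attn.v_proj_weight.shape)  # torch.Size([512, 64])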
