# /* Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Parts of code are originally from
# https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py
# */

import torch
import torch.nn as nn


class T5MultiheadAttention(nn.MultiheadAttention):
    def __init__(
        self,
        embed_dim,
        num_heads,
        is_decoder=False,
        dropout=0.0,
        bias=False,
        kdim=None,
        vdim=None,
        device=None,
        dtype=None,
    ) -> None:
| 33 | + r""" |
| 34 | + Args: |
| 35 | + embed_dim: Total dimension of the model. |
| 36 | + num_heads: Parallel attention heads. |
| 37 | + is_decoder: Whether or not multihead attention is being performed on a decoder layer. Default: `False` |
| 38 | + dropout: Probability of an element to be zeroed. Default: 0.0 |
| 39 | + bias: If specified, adds bias to input / output projection layers. Default: `False`. |
| 40 | + kdim: Total number of features for keys. Default: `None` (uses `kdim=embed_dim`). |
| 41 | + vdim: Total number of features for values. Default: `None` (uses `vdim=embed_dim`). |
| 42 | + """ |
        # Positional args map to nn.MultiheadAttention(embed_dim, num_heads, dropout, bias,
        # add_bias_kv=False, add_zero_attn=False, kdim, vdim, batch_first=True, device, dtype).
        super().__init__(embed_dim, num_heads, dropout, bias, False, False, kdim, vdim, True, device, dtype)
        factory_kwargs = {"device": device, "dtype": dtype}
        self.is_decoder = is_decoder
        # T5 uses separate (un-fused) query/key/value projection weights, so the fused
        # `in_proj_weight` created by `nn.MultiheadAttention` is replaced with `None`.
        self.q_proj_weight = nn.Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
        self.k_proj_weight = nn.Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
        self.v_proj_weight = nn.Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
        self.register_parameter("in_proj_weight", None)
    def forward(self):
        # Stub: the T5-specific attention computation is not implemented here.
        raise NotImplementedError
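

# A minimal usage sketch with hypothetical hyperparameters (embed_dim=512, num_heads=8):
# constructing the module allocates separate query/key/value projection weights and
# clears the parent class's fused `in_proj_weight`.
if __name__ == "__main__":
    attn = T5MultiheadAttention(embed_dim=512, num_heads=8, is_decoder=True, dropout=0.1)
    print(attn.q_proj_weight.shape)  # torch.Size([512, 512])
    print(attn.k_proj_weight.shape)  # torch.Size([512, 512]) since kdim defaults to embed_dim
    print(attn.v_proj_weight.shape)  # torch.Size([512, 512]) since vdim defaults to embed_dim
    print(attn.in_proj_weight)       # None -- replaced by the per-projection weights above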