# See the License for the specific language governing permissions and
# limitations under the License.

-# Original code is taken from
-# https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
+# Parts of this code are originally from
+# https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py
# */

import math
@@ -28,32 +28,22 @@ def __init__(
        is_decoder=False,
        dropout=0.0,
        bias=False,
-        add_bias_kv=False,
-        add_zero_attn=False,
        kdim=None,
        vdim=None,
-        batch_first=False,
        device=None,
        dtype=None,
    ) -> None:
        r"""
        Args:
-            embed_dim: total dimension of the model.
-            num_heads: parallel attention heads.
-            is_decoder: whether or not multihead attention is being performed on a decoder layer. Default: ``False``
-            dropout: probability of an element to be zeroed. Default: 0.0
-            bias: If specified, adds bias to input / output projection layers. Default: ``False``.
-            add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
-            add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
-                Default: ``False``.
-            kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
-            vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
-            batch_first: If ``True``, then the input and output tensors are provided
-                as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
+            embed_dim: Total dimension of the model.
+            num_heads: Parallel attention heads.
+            is_decoder: Whether or not multihead attention is being performed on a decoder layer. Default: `False`
+            dropout: Probability of an element to be zeroed. Default: 0.0
+            bias: If specified, adds bias to input / output projection layers. Default: `False`.
+            kdim: Total number of features for keys. Default: `None` (uses `kdim=embed_dim`).
+            vdim: Total number of features for values. Default: `None` (uses `vdim=embed_dim`).
        """
-        super().__init__(
-            embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim, batch_first, device, dtype
-        )
+        super().__init__(embed_dim, num_heads, dropout, bias, False, False, kdim, vdim, True, device, dtype)
        factory_kwargs = {"device": device, "dtype": dtype}
        self.is_decoder = is_decoder
        self.q_proj_weight = nn.Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
@@ -64,7 +54,7 @@ def __init__(
    def forward():
        pass

-    # NOTE: modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
+    # NOTE: Modified from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L421
    def _compute_bias(
        self,
        query_length: int,
@@ -91,7 +81,7 @@ def _compute_bias(
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

-    # NOTE: taken from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
+    # NOTE: Taken from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L374
    def _relative_position_bucket(
        self, relative_position: Tensor, bidirectional: bool = True, num_buckets: int = 32, max_distance: int = 128
    ):
@@ -119,9 +109,9 @@ def _relative_position_bucket(
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
-        # now relative_position is in the range [0, inf)
+        # Ensure relative_position is in the range [0, inf)

-        # half of the buckets are for exact increments in positions
+        # Half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

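For reviewers unfamiliar with the scheme the two truncated hunks implement: small relative distances each get their own bucket, larger distances share logarithmically spaced buckets up to max_distance, and the resulting bucket ids index a learned per-head embedding to produce the additive attention bias. Below is a minimal, self-contained sketch of that flow, based on the upstream HuggingFace T5 code linked in the NOTEs above; the free-function form and the head/bucket counts and sequence lengths are illustrative and not part of this diff.

import math

import torch
from torch import nn


def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
    # Sketch of the T5 bucketing referenced above: half the buckets (per
    # direction) hold exact small offsets, the rest cover larger offsets
    # logarithmically, saturating at num_buckets - 1.
    relative_buckets = torch.zeros_like(relative_position)
    if bidirectional:
        num_buckets //= 2
        relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
        relative_position = torch.abs(relative_position)
    else:
        relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    relative_position_if_large = max_exact + (
        torch.log(relative_position.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).to(torch.long)
    relative_position_if_large = torch.min(
        relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
    )
    return relative_buckets + torch.where(is_small, relative_position, relative_position_if_large)


# Assembling the bias as _compute_bias does: bucket ids index an
# (num_buckets, num_heads) embedding, then permute to the attention-score
# shape (1, num_heads, query_length, key_length). Sizes here are made up.
num_heads, num_buckets, query_length, key_length = 8, 32, 6, 6
relative_attention_bias = nn.Embedding(num_buckets, num_heads)
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
relative_position = memory_position - context_position  # (query_length, key_length)
buckets = relative_position_bucket(relative_position, bidirectional=True, num_buckets=num_buckets)
values = relative_attention_bias(buckets)  # (query_length, key_length, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0)  # (1, num_heads, query_length, key_length)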