@@ -185,19 +185,21 @@ def forward(self, query, key, value, attn_mask=None):
         where L is the target length, S is the source length, H is the number
         of attention heads, N is the batch size, and E is the embedding dimension.
         """
-        tgt_len, batch_heads, head_dim = query.size()
-        assert query.size(1) == key.size(1) == value.size(1), "Dimension 0 of query, key, value must be equal."
-        assert batch_heads % self.num_heads == 0, "Dimension 0 of query, key, value must be divisible by num_heads"
+        tgt_len, head_dim = query.size(-3), query.size(-1)
+        assert query.size(-1) == key.size(-1) == value.size(-1), "The feature dim of query, key, value must be equal."
         assert key.size() == value.size(), "Shape of key, value must match"
-        assert query.size(-1) == key.size(-1), "The head dimension of query must be equal to that of key"
-        src_len = key.size(0)
+        src_len = key.size(-3)
+        batch_heads = max(query.size(-2), key.size(-2))
 
         # Scale query
-        query, key, value = query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1)
+        query, key, value = query.transpose(-2, -3), key.transpose(-2, -3), value.transpose(-2, -3)
         query = query * (float(head_dim) ** -0.5)
         if attn_mask is not None:
-            if list(attn_mask.size()) != [batch_heads, tgt_len, src_len]:
-                raise RuntimeError('The size of the 3D attn_mask is not correct.')
+            if attn_mask.dim() != 3:
+                raise RuntimeError('attn_mask must be a 3D tensor.')
+            if (attn_mask.size(-1) != src_len) or (attn_mask.size(-2) != tgt_len) or \
+                    (attn_mask.size(-3) != 1 and attn_mask.size(-3) != batch_heads):
+                raise RuntimeError('The size of the attn_mask is not correct.')
             if attn_mask.dtype != torch.bool:
                 raise RuntimeError('Only bool tensor is supported for attn_mask')
 
@@ -211,4 +213,4 @@ def forward(self, query, key, value, attn_mask=None):
         attn_output_weights = torch.nn.functional.softmax(attn_output_weights, dim=-1)
         attn_output_weights = torch.nn.functional.dropout(attn_output_weights, p=self.dropout, training=self.training)
         attn_output = torch.matmul(attn_output_weights, value)
-        return attn_output.transpose(0, 1), attn_output_weights
+        return attn_output.transpose(-2, -3), attn_output_weights
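For reviewers, here is a minimal, self-contained sketch of the shape convention this change moves to: `query`/`key`/`value` carry `(..., L, N*H, E/H)` and are indexed with negative dims, and a 3D bool `attn_mask` may now broadcast over the batch-heads dim when `attn_mask.size(-3) == 1`. The driver code, the example dimensions, and the `masked_fill_`-based masking step (which sits between the two hunks and is not shown in this diff) are illustrative assumptions, not part of the patched module.

```python
import torch

tgt_len, src_len, batch_heads, head_dim = 5, 7, 8, 16  # batch_heads = N * H

query = torch.rand(tgt_len, batch_heads, head_dim)  # (L, N*H, E/H)
key = torch.rand(src_len, batch_heads, head_dim)    # (S, N*H, E/H)
value = torch.rand(src_len, batch_heads, head_dim)  # (S, N*H, E/H)

# One mask shared across all batch-heads: size(-3) == 1 is now accepted
# and broadcasts, instead of requiring a full (N*H, L, S) mask.
attn_mask = torch.zeros(1, tgt_len, src_len, dtype=torch.bool)
attn_mask[:, :, -1] = True  # e.g. hide the last source position

# Same steps as the patched forward(), written out inline.
q, k, v = (t.transpose(-2, -3) for t in (query, key, value))  # (N*H, L or S, E/H)
q = q * float(head_dim) ** -0.5
scores = torch.matmul(q, k.transpose(-2, -1))   # (N*H, L, S)
scores.masked_fill_(attn_mask, float('-inf'))   # assumed masking step between hunks
weights = torch.nn.functional.softmax(scores, dim=-1)
output = torch.matmul(weights, v).transpose(-2, -3)  # back to (L, N*H, E/H)
print(output.shape)  # torch.Size([5, 8, 16])
```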