@@ -296,23 +296,31 @@ def forward(
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         query_states, key_states, value_states = (
             self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
         )
-
-        cos, sin = position_embeddings
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        else:
+            cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

         query_states = query_states.transpose(0, 1).unsqueeze(0)
         key_states = key_states.transpose(0, 1).unsqueeze(0)
         value_states = value_states.transpose(0, 1).unsqueeze(0)
-
-        attention_mask = torch.zeros([1, 1, seq_length, seq_length], device=query_states.device, dtype=torch.bool)
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
@@ -323,13 +331,17 @@ def forward(
             query_states,
             key_states,
             value_states,
-            attention_mask,
+            attention_mask=attention_mask,
             dropout=0.0 if not self.training else self.attention_dropout,
             scaling=self.scaling,
-            is_causal=self.is_causal,
+            cu_seq_lens_q=cu_seqlens,  # pass cu seq lens for FA2
+            cu_seq_lens_k=cu_seqlens,
+            max_length_q=max_seqlen,
+            max_length_k=max_seqlen,
+            is_causal=False,
             **kwargs,
         )
-        attn_output = attn_output.squeeze(0)
+
         attn_output = attn_output.reshape(seq_length, -1).contiguous()
         attn_output = self.proj(attn_output)
         return attn_output
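
For context beyond the diff itself: `cu_seqlens` holds the cumulative boundaries of the packed vision sequences, so both the block-diagonal mask that the removed loop built and the new `max_seqlen` value derive from it; passing `cu_seq_lens_q`/`cu_seq_lens_k` lets the varlen attention path handle the segmentation instead of an explicit mask. A minimal standalone sketch (not part of the committed change), with made-up lengths of three packed sequences of 3, 4 and 5 tokens, showing that correspondence:

import torch

# Hypothetical cumulative sequence lengths: three packed sequences of 3, 4 and 5 tokens.
cu_seqlens = torch.tensor([0, 3, 7, 12])
seq_length = cu_seqlens[-1].item()  # 12 tokens in total

# New code path: only the longest segment length is needed for the varlen kernels.
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()  # 5

# Old code path: an explicit block-diagonal boolean mask of shape [1, 1, S, S],
# True only where query and key fall inside the same segment.
attention_mask = torch.zeros([1, 1, seq_length, seq_length], dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True

print(max_seqlen)                  # 5
print(attention_mask[0, 0].int())  # three True blocks along the diagonal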