 
 import torch
 import torch.nn.functional as F
-from torch import nn
+from torch import nn, FloatTensor
 
 from ..utils import deprecate, logging
 from ..utils.import_utils import is_xformers_available
@@ -272,15 +272,22 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
         if attention_mask is None:
             return attention_mask
 
-        if attention_mask.shape[-1] != target_length:
+        current_length: int = attention_mask.shape[-1]
+        if current_length > target_length:
+            # we *could* trim the mask with:
+            # attention_mask = attention_mask[:,:target_length]
+            # but this is weird enough that it's more likely to be a mistake than a shortcut
+            raise ValueError(f"mask's length ({current_length}) exceeds the sequence length ({target_length}).")
+        elif current_length < target_length:
             if attention_mask.device.type == "mps":
                 # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                 # Instead, we can manually construct the padding tensor.
                 padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
                 padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                 attention_mask = torch.cat([attention_mask, padding], dim=2)
             else:
-                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+                remaining_length: int = target_length - current_length
+                attention_mask = F.pad(attention_mask, (0, remaining_length), value=0.0)
 
         if attention_mask.shape[0] < batch_size * head_size:
             attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
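
For reference, here is a standalone sketch (not part of this commit) of the length handling the hunk above introduces, using made-up shapes: a mask shorter than the key sequence is zero-padded on its last dimension, while an over-long mask now raises instead of being silently trimmed.

import torch
import torch.nn.functional as F

attention_mask = torch.zeros(2, 1, 64)  # bias covering 64 encoder tokens
target_length = 77                      # key tokens the attention will actually see

current_length = attention_mask.shape[-1]
if current_length > target_length:
    raise ValueError(f"mask's length ({current_length}) exceeds the sequence length ({target_length}).")
elif current_length < target_length:
    # pad with zeros (no bias) for the remaining key tokens
    attention_mask = F.pad(attention_mask, (0, target_length - current_length), value=0.0)

assert attention_mask.shape == (2, 1, 77)
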
@@ -291,19 +298,31 @@ class CrossAttnProcessor:
     def __call__(
         self,
         attn: CrossAttention,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
+        hidden_states: FloatTensor,
+        encoder_hidden_states: Optional[FloatTensor] = None,
+        attention_mask: Optional[FloatTensor] = None,
+        encoder_attention_bias: Optional[FloatTensor] = None,
     ):
-        batch_size, sequence_length, _ = hidden_states.shape
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-        query = attn.to_q(hidden_states)
-
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        else:
+            if encoder_attention_bias is not None:
+                if attention_mask is not None:
+                    # it's not well-defined whether `attention_mask` should be passed to self-attention, cross-attention, neither, or both.
+                    # if two sources of bias (`attention_mask`, `encoder_attention_bias`) are provided: it's likely to be a mistake.
+                    raise ValueError(f"two attention biases have been supplied: `attention_mask` and `encoder_attention_bias`. expected a maximum of one source of bias.")
+                attention_mask = encoder_attention_bias
+                # make broadcastable over query tokens
+                # TODO: consider aligning implementations such that AttnProcessor2_0 and CrossAttnProcessor do unsqueeze
+                # in the same way/circumstances -- AttnProcessor2_0 does it for `attention_mask` **and** for `encoder_attention_bias`.
+                attention_mask = attention_mask.unsqueeze(-2)
+            if attn.cross_attention_norm:
+                encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+
+        batch_size, key_tokens, _ = encoder_hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
 
+        query = attn.to_q(hidden_states)
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
@@ -315,10 +334,10 @@ def __call__(
         hidden_states = torch.bmm(attention_probs, value)
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
+        linear_proj, dropout = attn.to_out
+
+        hidden_states = linear_proj(hidden_states)
+        hidden_states = dropout(hidden_states)
 
         return hidden_states
 
@@ -471,25 +490,39 @@ def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
 
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, inner_dim = hidden_states.shape
+    def __call__(
+        self,
+        attn: CrossAttention,
+        hidden_states: FloatTensor,
+        encoder_hidden_states: Optional[FloatTensor] = None,
+        attention_mask: Optional[FloatTensor] = None,
+        encoder_attention_bias: Optional[FloatTensor] = None,
+    ):
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            if encoder_attention_bias is not None:
+                if attention_mask is not None:
+                    # it's not well-defined whether `attention_mask` should be passed to self-attention, cross-attention, neither, or both.
+                    # if two sources of bias (`attention_mask`, `encoder_attention_bias`) are provided: it's likely to be a mistake.
+                    raise ValueError(f"two attention biases have been supplied: `attention_mask` and `encoder_attention_bias`. expected a maximum of one source of bias.")
+                attention_mask = encoder_attention_bias
+            if attn.cross_attention_norm:
+                encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+
+        batch_size, key_tokens, _ = encoder_hidden_states.shape
 
         if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
             # scaled_dot_product_attention expects attention_mask shape to be
             # (batch, heads, source_length, target_length)
             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
 
         query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
-
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
+        inner_dim = attn.to_q.out_features
         head_dim = inner_dim // attn.heads
         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
@@ -503,10 +536,10 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
+        linear_proj, dropout = attn.to_out
+
+        hidden_states = linear_proj(hidden_states)
+        hidden_states = dropout(hidden_states)
         return hidden_states
 
 
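
Taken together, both processors now accept an `encoder_attention_bias` alongside `attention_mask`, and size the prepared mask by the key (encoder) sequence rather than the query sequence. Below is a minimal usage sketch, not part of this commit: it invokes `CrossAttnProcessor` directly, assumes the file above is `diffusers/models/cross_attention.py` as its imports suggest, and the layer dimensions, tensor shapes and the -10000.0 masking value are illustrative assumptions rather than anything prescribed by the diff.

import torch
from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor

# hypothetical cross-attention layer: 8 heads of 40 channels reading 768-dim text embeddings
attn = CrossAttention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)
processor = CrossAttnProcessor()

hidden_states = torch.randn(2, 4096, 320)        # query tokens (e.g. image latents)
encoder_hidden_states = torch.randn(2, 77, 768)  # key/value tokens (e.g. text embeddings)

# additive bias over the 77 encoder tokens: 0.0 keeps a token,
# a large negative value effectively removes it from the softmax
encoder_attention_bias = torch.zeros(2, 77)
encoder_attention_bias[:, 64:] = -10000.0

out = processor(
    attn,
    hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_bias=encoder_attention_bias,
)
assert out.shape == hidden_states.shape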