@@ -622,31 +622,53 @@ class SlicedAttnProcessor:
     def __init__(self, slice_size):
         self.slice_size = slice_size
 
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = hidden_states.shape
+    def __call__(
+        self,
+        attn: CrossAttention,
+        hidden_states: FloatTensor,
+        encoder_hidden_states: Optional[FloatTensor] = None,
+        attention_mask: Optional[FloatTensor] = None,
+        encoder_attention_bias: Optional[FloatTensor] = None,
+    ):
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            if encoder_attention_bias is not None:
+                if attention_mask is not None:
+                    # it's not well-defined whether `attention_mask` should be passed to self-attention, cross-attention, neither, or both.
+                    # if two sources of bias (`attention_mask`, `encoder_attention_bias`) are provided, it's likely to be a mistake.
+                    raise ValueError("two attention biases have been supplied: `attention_mask` and `encoder_attention_bias`. expected a maximum of one source of bias.")
+                attention_mask = encoder_attention_bias
+                # make broadcastable over query tokens
+                # TODO: see if there's a satisfactory way to unify how the `attention_mask`/`encoder_attention_bias` code paths
+                # create this singleton dim. the way AttnProcessor2_0 does it could work.
+                # here I'm trying to avoid interfering with the original `attention_mask` code path,
+                # by limiting the unsqueeze() to just the `encoder_attention_bias` path, on the basis that
+                # `attention_mask` is already working without this change.
+                # maybe it's because UNet2DConditionModel#forward unsqueeze()s `attention_mask` earlier.
+                attention_mask = attention_mask.unsqueeze(-2)
+            if attn.cross_attention_norm:
+                encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
 
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        batch_size, key_tokens, _ = encoder_hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
 
         query = attn.to_q(hidden_states)
-        dim = query.shape[-1]
         query = attn.head_to_batch_dim(query)
 
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
-
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
         key = attn.head_to_batch_dim(key)
         value = attn.head_to_batch_dim(value)
 
-        batch_size_attention = query.shape[0]
+        batch_x_heads, query_tokens, _ = query.shape
+        inner_dim = attn.to_q.out_features
+        channels_per_head = inner_dim // attn.heads
         hidden_states = torch.zeros(
-            (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
+            (batch_x_heads, query_tokens, channels_per_head), device=query.device, dtype=query.dtype
         )
 
-        for i in range(hidden_states.shape[0] // self.slice_size):
+        for i in range(batch_x_heads // self.slice_size):
             start_idx = i * self.slice_size
             end_idx = (i + 1) * self.slice_size
 
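(Not part of the diff: a quick shape sketch of why the `encoder_attention_bias` path adds the singleton dim. The bias starts as one value per key token; `unsqueeze(-2)` makes it broadcast over query tokens, and `prepare_attention_mask` then tiles it across heads. All sizes below are made up.)

import torch

batch, heads, query_tokens, key_tokens = 2, 8, 64, 77

# attention scores as computed per slice: [batch * heads, query_tokens, key_tokens]
scores = torch.zeros(batch * heads, query_tokens, key_tokens)

# a per-key-token bias, e.g. -10000.0 on padding tokens of the text prompt
bias = torch.zeros(batch, key_tokens)
bias = bias.unsqueeze(-2)                    # [batch, 1, key_tokens]: broadcastable over query tokens
bias = bias.repeat_interleave(heads, dim=0)  # [batch * heads, 1, key_tokens], roughly what prepare_attention_mask does
assert (scores + bias).shape == (batch * heads, query_tokens, key_tokens)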
@@ -662,10 +684,10 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
 
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
+        linear_proj, dropout = attn.to_out
+
+        hidden_states = linear_proj(hidden_states)
+        hidden_states = dropout(hidden_states)
 
         return hidden_states
 
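(Also not part of the diff: the `to_out` unpacking works because `CrossAttention.to_out` is a two-element ModuleList, a Linear followed by a Dropout, so tuple unpacking is equivalent to the old `to_out[0]` / `to_out[1]` indexing. A minimal sketch with placeholder sizes:)

import torch.nn as nn

to_out = nn.ModuleList([nn.Linear(320, 320), nn.Dropout(0.0)])
linear_proj, dropout = to_out
assert linear_proj is to_out[0] and dropout is to_out[1]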