 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Callable, Optional, Union
+from typing import Callable, Optional, Union, Dict, Any

 import torch
 import torch.nn.functional as F
@@ -198,7 +198,13 @@ def set_processor(self, processor: "AttnProcessor"):

         self.processor = processor

-    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
+    def forward(
+        self,
+        hidden_states: FloatTensor,
+        encoder_hidden_states: Optional[FloatTensor] = None,
+        attention_mask: Optional[FloatTensor] = None,
+        **cross_attention_kwargs: Dict[str, Any]
+    ):
         # The `CrossAttention` class can call different attention processors / attention functions
         # here we simply pass along all tensors to the selected processor class
         # For standard processors that are defined here, `**cross_attention_kwargs` is empty
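Not part of the diff: a self-contained toy sketch (class names and tensor sizes are invented) of the pass-through that the annotated `forward` performs, i.e. how an extra keyword such as `encoder_attention_bias` rides along in `**cross_attention_kwargs` and reaches the processor's `__call__`.

```python
from typing import Any, Dict, Optional

import torch
from torch import FloatTensor


class ToyProcessor:
    # stand-in for an AttnProcessor: reports which optional inputs arrived
    def __call__(
        self,
        attn,
        hidden_states: FloatTensor,
        encoder_hidden_states: Optional[FloatTensor] = None,
        attention_mask: Optional[FloatTensor] = None,
        encoder_attention_bias: Optional[FloatTensor] = None,
    ) -> FloatTensor:
        print("received encoder_attention_bias:", encoder_attention_bias is not None)
        return hidden_states


class ToyCrossAttention:
    # stand-in for CrossAttention: forward() simply defers to the selected processor
    def __init__(self):
        self.processor = ToyProcessor()

    def forward(
        self,
        hidden_states: FloatTensor,
        encoder_hidden_states: Optional[FloatTensor] = None,
        attention_mask: Optional[FloatTensor] = None,
        **cross_attention_kwargs: Dict[str, Any],
    ) -> FloatTensor:
        # all tensors, plus any extra keywords, are passed along unchanged
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )


hidden = torch.randn(2, 4096, 320)
text = torch.randn(2, 77, 768)
bias = torch.zeros(2, 77)
ToyCrossAttention().forward(hidden, encoder_hidden_states=text, encoder_attention_bias=bias)
```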
@@ -313,8 +319,12 @@ def __call__(
                     raise ValueError(f"two attention biases have been supplied: `attention_mask` and `encoder_attention_bias`. expected a maximum of one source of bias.")
                 attention_mask = encoder_attention_bias
                 # make broadcastable over query tokens
-                # TODO: consider aligning implementations such that AttnProcessor2_0 and CrossAttnProcessor do unsqueeze
-                # in the same way/circumstances -- AttnProcessor2_0 does it for `attention_mask` **and** for `encoder_attention_bias`.
+                # TODO: see if there's a satisfactory way to unify how the `attention_mask`/`encoder_attention_bias` code paths
+                # create this singleton dim. the way AttnProcessor2_0 does it could work.
+                # here I'm trying to avoid interfering with the original `attention_mask` code path,
+                # by limiting the unsqueeze() to just the `encoder_attention_bias` path, on the basis that
+                # `attention_mask` is already working without this change.
+                # maybe it's because UNet2DConditionModel#forward unsqueeze()s `attention_mask` earlier.
                 attention_mask = attention_mask.unsqueeze(-2)
             if attn.cross_attention_norm:
                 encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
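Not part of the diff: a standalone shape illustration (dimensions invented) of what the `unsqueeze(-2)` buys here. A per-key bias of shape `[batch, key_tokens]` gains a singleton query dimension so it can broadcast against attention scores of shape `[batch, query_tokens, key_tokens]`; the per-head handling happens later, in `prepare_attention_mask`.

```python
import torch

# invented shapes, for illustration only
batch, query_tokens, key_tokens = 2, 4096, 77

# per-key bias, e.g. -inf over padding positions of the text encoding
bias = torch.zeros(batch, key_tokens)
bias[:, 40:] = float("-inf")

scores = torch.randn(batch, query_tokens, key_tokens)  # query-key attention scores

bias = bias.unsqueeze(-2)   # [batch, 1, key_tokens]: singleton query dim
masked = scores + bias      # broadcasts over the query dimension
print(masked.shape)         # torch.Size([2, 4096, 77])
```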
@@ -453,18 +463,39 @@ class XFormersCrossAttnProcessor:
     def __init__(self, attention_op: Optional[Callable] = None):
         self.attention_op = attention_op

-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = hidden_states.shape
-
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
-        query = attn.to_q(hidden_states)
-
+    def __call__(
+        self,
+        attn: CrossAttention,
+        hidden_states: FloatTensor,
+        encoder_hidden_states: Optional[FloatTensor] = None,
+        attention_mask: Optional[FloatTensor] = None,
+        encoder_attention_bias: Optional[FloatTensor] = None,
+    ):
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        else:
+            if encoder_attention_bias is not None:
+                if attention_mask is not None:
+                    # it's not well-defined whether `attention_mask` should be passed to self-attention, cross-attention, neither* or both.
+                    # if two sources of bias (`attention_mask`, `encoder_attention_bias`) are provided: it's likely to be a mistake.
+                    raise ValueError(f"two attention biases have been supplied: `attention_mask` and `encoder_attention_bias`. expected a maximum of one source of bias.")
+                attention_mask = encoder_attention_bias
+
+                # TODO: figure out why the original `attention_mask` code path didn't attempt broadcasting over query tokens.
+                # it feels like this logic would be needed in that code path too.

+                # make broadcastable over query tokens
+                attention_mask = attention_mask.unsqueeze(-2)
+                _, query_tokens, _ = hidden_states.shape
+                # xformers doesn't broadcast for us, so we expand our singleton dimension manually
+                attention_mask = attention_mask.expand(-1, query_tokens, -1)
+            if attn.cross_attention_norm:
+                encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+
+        batch_size, key_tokens, _ = encoder_hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
+
+        query = attn.to_q(hidden_states)
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)

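Not part of the diff: a standalone sketch (shapes invented) of the `unsqueeze`/`expand` pair in this xformers path. As the diff's comment notes, xformers does not broadcast a singleton query dimension in the bias for us, so the singleton is expanded to the full query length; `expand` returns a view, so the repeated rows cost no extra memory.

```python
import torch

# invented shapes, for illustration only
batch, query_tokens, key_tokens = 2, 4096, 77

attention_mask = torch.zeros(batch, key_tokens)               # per-key bias
attention_mask = attention_mask.unsqueeze(-2)                 # [batch, 1, key_tokens]
attention_mask = attention_mask.expand(-1, query_tokens, -1)  # [batch, query_tokens, key_tokens]

print(attention_mask.shape)     # torch.Size([2, 4096, 77])
print(attention_mask.stride())  # (77, 0, 1): stride 0 on the query dim, i.e. a view, not a copy
```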
@@ -478,10 +509,10 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
         hidden_states = hidden_states.to(query.dtype)
         hidden_states = attn.batch_to_head_dim(hidden_states)

-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
+        linear_proj, dropout = attn.to_out
+
+        hidden_states = linear_proj(hidden_states)
+        hidden_states = dropout(hidden_states)
         return hidden_states

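Not part of the diff: a self-contained check (layer sizes invented) that the tuple-unpack is behaviourally identical to the indexing it replaces, on the assumption that `to_out` is an `nn.ModuleList` holding a `Linear` followed by a `Dropout`, as in this module.

```python
import torch
from torch import nn

# invented sizes, for illustration only
to_out = nn.ModuleList([nn.Linear(320, 320), nn.Dropout(p=0.0)])
hidden_states = torch.randn(2, 4096, 320)

# new style: tuple-unpack the two submodules
linear_proj, dropout = to_out
new_out = dropout(linear_proj(hidden_states))

# old style: index into the ModuleList
old_out = to_out[1](to_out[0](hidden_states))

print(torch.equal(new_out, old_out))  # True (dropout is a no-op at p=0.0)
```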