 
 import torch
 import torch.nn.functional as F
-from torch import nn
+from torch import nn, Tensor
+from einops import rearrange, repeat
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
@@ -175,7 +176,7 @@ def __init__(
             self.norm_out = nn.LayerNorm(inner_dim)
             self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
 
-    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
+    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True, cross_attn_mask: Optional[torch.Tensor] = None):
         """
         Args:
             hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
@@ -213,7 +214,7 @@ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, retu
 
         # 2. Blocks
         for block in self.transformer_blocks:
-            hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
+            hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep, cross_attn_mask=cross_attn_mask)
 
         # 3. Output
         if self.is_input_continuous:
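For context (not part of the diff): a minimal sketch of how a caller might drive the new kwarg. `cross_attn_mask` is a boolean mask over the context tokens (True = attend, False = ignore); `tokenizer`, `text_encoder`, `transformer`, `prompts`, `latents` and `timesteps` below are illustrative names, not objects defined in this file.

tok = tokenizer(prompts, padding="max_length", max_length=77, return_tensors="pt")  # hypothetical tokenizer
encoder_hidden_states = text_encoder(tok.input_ids)[0]       # (batch, 77, dim), hypothetical text encoder
cross_attn_mask = tok.attention_mask.bool()                  # (batch, 77): True = real token, False = padding

out = transformer(                                           # any Transformer2DModel built from this file
    latents,                                                 # (batch, channels, height, width)
    encoder_hidden_states=encoder_hidden_states,
    timestep=timesteps,
    cross_attn_mask=cross_attn_mask,                         # padding tokens receive ~0 attention weight
)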
@@ -472,14 +473,14 @@ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_atten
         self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
         self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
 
-    def forward(self, hidden_states, context=None, timestep=None):
+    def forward(self, hidden_states, context=None, timestep=None, cross_attn_mask: Optional[torch.Tensor] = None):
         # 1. Self-Attention
         norm_hidden_states = (
             self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
         )
 
         if self.only_cross_attention:
-            hidden_states = self.attn1(norm_hidden_states, context) + hidden_states
+            hidden_states = self.attn1(norm_hidden_states, context, cross_attn_mask=cross_attn_mask) + hidden_states
         else:
             hidden_states = self.attn1(norm_hidden_states) + hidden_states
 
@@ -488,7 +489,7 @@ def forward(self, hidden_states, context=None, timestep=None):
             norm_hidden_states = (
                 self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
             )
-            hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
+            hidden_states = self.attn2(norm_hidden_states, context=context, cross_attn_mask=cross_attn_mask) + hidden_states
 
         # 3. Feed-forward
         hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
@@ -563,7 +564,7 @@ def set_attention_slice(self, slice_size):
 
         self._slice_size = slice_size
 
-    def forward(self, hidden_states, context=None, mask=None):
+    def forward(self, hidden_states, context=None, mask=None, cross_attn_mask: Optional[Tensor] = None):
         batch_size, sequence_length, _ = hidden_states.shape
 
         query = self.to_q(hidden_states)
@@ -577,26 +578,29 @@ def forward(self, hidden_states, context=None, mask=None):
         key = self.reshape_heads_to_batch_dim(key)
         value = self.reshape_heads_to_batch_dim(value)
 
-        # TODO(PVP) - mask is currently never used. Remember to re-implement when used
+        # TODO AKB: the `mask` param is still unimplemented; it is kept reserved in case we
+        # ever need a self-attention mask (e.g. a pixel/latent-space mask to avoid
+        # self-attending to padding pixels introduced by pillarboxing/letterboxing).
 
         # attention, what we cannot get enough of
         if self._use_memory_efficient_attention_xformers:
+            assert cross_attn_mask is None, "cross-attention masking not implemented for xformers attention"
             hidden_states = self._memory_efficient_attention_xformers(query, key, value)
             # Some versions of xformers return output in fp32, cast it back to the dtype of the input
             hidden_states = hidden_states.to(query.dtype)
         else:
             if self._slice_size is None or query.shape[0] // self._slice_size == 1:
-                hidden_states = self._attention(query, key, value)
+                hidden_states = self._attention(query, key, value, cross_attn_mask=cross_attn_mask)
             else:
-                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
+                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, cross_attn_mask=cross_attn_mask)
 
         # linear proj
         hidden_states = self.to_out[0](hidden_states)
         # dropout
         hidden_states = self.to_out[1](hidden_states)
         return hidden_states
 
-    def _attention(self, query, key, value):
+    def _attention(self, query, key, value, cross_attn_mask: Optional[Tensor] = None):
         if self.upcast_attention:
             query = query.float()
             key = key.float()
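Note the assert added above: the xformers path has no mask support, so supplying a mask together with xformers is a hard error. A caller-side sketch (illustrative only, not part of the diff; `transformer` is assumed to be the parent module) that falls back to the standard attention path before passing a mask:

# disable the memory-efficient path on every module that would otherwise hit the assert;
# the flag name comes from set_use_memory_efficient_attention_xformers above
for module in transformer.modules():
    if getattr(module, "_use_memory_efficient_attention_xformers", False):
        module._use_memory_efficient_attention_xformers = False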
@@ -608,6 +612,12 @@ def _attention(self, query, key, value):
             beta=0,
             alpha=self.scale,
         )
+        if cross_attn_mask is not None:
+            cross_attn_mask = rearrange(cross_attn_mask, 'b ... -> b (...)')
+            max_neg_value = -torch.finfo(attention_scores.dtype).max
+            cross_attn_mask = repeat(cross_attn_mask, 'b j -> (b h) () j', h=self.heads)
+            attention_scores.masked_fill_(~cross_attn_mask, max_neg_value)
+            del cross_attn_mask
         attention_probs = attention_scores.softmax(dim=-1)
 
         # cast back to the original dtype
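For reference, a standalone sketch of the shape bookkeeping performed by the masking block above (batch, head and token counts are illustrative, not from the diff):

import torch
from einops import rearrange, repeat

b, h, i, j = 2, 8, 4096, 77                            # batch, heads, latent tokens, context tokens
attention_scores = torch.randn(b * h, i, j)            # baddbmm output: (batch*heads, query, key)
cross_attn_mask = torch.ones(b, j, dtype=torch.bool)   # True = attend to this context token
cross_attn_mask[:, 64:] = False                        # e.g. trailing padding tokens

cross_attn_mask = rearrange(cross_attn_mask, 'b ... -> b (...)')     # flatten any extra dims -> (b, j)
cross_attn_mask = repeat(cross_attn_mask, 'b j -> (b h) () j', h=h)  # (b*h, 1, j), broadcast over queries
attention_scores.masked_fill_(~cross_attn_mask, -torch.finfo(attention_scores.dtype).max)
attention_probs = attention_scores.softmax(dim=-1)     # masked keys now receive ~0 probability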
@@ -620,11 +630,15 @@ def _attention(self, query, key, value):
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
 
-    def _sliced_attention(self, query, key, value, sequence_length, dim):
+    def _sliced_attention(self, query, key, value, sequence_length, dim, cross_attn_mask: Optional[Tensor] = None):
         batch_size_attention = query.shape[0]
         hidden_states = torch.zeros(
             (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
         )
+        if cross_attn_mask is not None:
+            cross_attn_mask = rearrange(cross_attn_mask, 'b ... -> b (...)')
+            max_neg_value = -torch.finfo(query.dtype).max
+            cross_attn_mask = repeat(cross_attn_mask, 'b j -> (b h) () j', h=self.heads)
         slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
         for i in range(hidden_states.shape[0] // slice_size):
             start_idx = i * slice_size
@@ -644,6 +658,10 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
                 beta=0,
                 alpha=self.scale,
             )
+            if cross_attn_mask is not None:
+                cross_attn_mask_slice = cross_attn_mask[start_idx:end_idx]
+                attn_slice.masked_fill_(~cross_attn_mask_slice, max_neg_value)
+                del cross_attn_mask_slice
             attn_slice = attn_slice.softmax(dim=-1)
 
             # cast back to the original dtype
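The sliced path repeats the mask to (batch*heads, 1, context) once up front and then indexes it with the same start_idx:end_idx window as the query, so every slice sees its own rows of the mask. A self-contained sketch of that bookkeeping (shapes and slice size are illustrative, not from the diff):

import torch
from einops import rearrange, repeat

b, h, i, j, d = 2, 8, 64, 77, 40                       # batch, heads, query tokens, context tokens, head dim
query = torch.randn(b * h, i, d)
key = torch.randn(b * h, j, d)
mask = torch.ones(b, j, dtype=torch.bool)
mask[:, 50:] = False                                   # pretend the last 27 context tokens are padding

mask = repeat(rearrange(mask, 'b ... -> b (...)'), 'b j -> (b h) () j', h=h)  # (b*h, 1, j)
max_neg_value = -torch.finfo(query.dtype).max

slice_size = 4                                         # illustrative; the real code uses self._slice_size
for s in range((b * h) // slice_size):
    start_idx, end_idx = s * slice_size, (s + 1) * slice_size
    attn_slice = torch.baddbmm(
        torch.empty(slice_size, i, j, dtype=query.dtype, device=query.device),
        query[start_idx:end_idx],
        key[start_idx:end_idx].transpose(-1, -2),
        beta=0,
        alpha=d ** -0.5,
    )
    attn_slice.masked_fill_(~mask[start_idx:end_idx], max_neg_value)  # same window as the query slice
    attn_slice = attn_slice.softmax(dim=-1)            # padding keys get ~0 weight in every slice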