@@ -59,6 +59,8 @@ def __init__(
         cross_attention_norm: bool = False,
         added_kv_proj_dim: Optional[int] = None,
         norm_num_groups: Optional[int] = None,
+        out_bias: bool = True,
+        scale_qk: bool = True,
         processor: Optional["AttnProcessor"] = None,
     ):
         super().__init__()
@@ -68,7 +70,7 @@ def __init__(
         self.upcast_softmax = upcast_softmax
         self.cross_attention_norm = cross_attention_norm

-        self.scale = dim_head**-0.5
+        self.scale = dim_head**-0.5 if scale_qk else 1.0

         self.heads = heads
         # for slice_size > 0 the attention score computation
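For context on what the new `scale_qk` flag toggles, here is a minimal standalone sketch (illustrative only, not the library code) of how the 1/sqrt(dim_head) factor enters the attention scores; with `scale_qk=False` the scores are simply left unscaled:

```python
import torch

def attention_probs(q, k, scale_qk=True):
    # q: (batch*heads, q_tokens, dim_head), k: (batch*heads, kv_tokens, dim_head)
    scale = q.shape[-1] ** -0.5 if scale_qk else 1.0
    scores = torch.bmm(q, k.transpose(-1, -2)) * scale
    return scores.softmax(dim=-1)

q, k = torch.randn(2, 8, 64), torch.randn(2, 16, 64)
probs_scaled = attention_probs(q, k)                    # standard 1/sqrt(dim_head) scaling
probs_unscaled = attention_probs(q, k, scale_qk=False)  # scale == 1.0, the scale_qk=False case
```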
@@ -95,14 +97,17 @@ def __init__(
             self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)

         self.to_out = nn.ModuleList([])
-        self.to_out.append(nn.Linear(inner_dim, query_dim))
+        self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
         self.to_out.append(nn.Dropout(dropout))

         # set attention processor
-        # We use the AttnProcessor2_0 by default when torch2.x is used which uses
+        # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
         # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+        # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
         if processor is None:
-            processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and scale_qk else CrossAttnProcessor()
+            )
         self.set_processor(processor)

     def set_use_memory_efficient_attention_xformers(
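A hedged usage sketch of the new constructor arguments; it assumes the pre-existing `CrossAttention` signature (`query_dim`, `heads`, `dim_head`) and the `diffusers.models.cross_attention` module path this file belongs to:

```python
from diffusers.models.cross_attention import CrossAttention

attn = CrossAttention(
    query_dim=320,
    heads=8,
    dim_head=40,
    out_bias=False,  # new: drop the bias on the to_out projection
    scale_qk=False,  # new: self.scale becomes 1.0
)
# With scale_qk=False the constructor falls back to CrossAttnProcessor even on torch 2.x,
# since F.scaled_dot_product_attention cannot take a custom scale before torch 2.1.
```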
@@ -295,7 +300,9 @@ def __call__(
         encoder_hidden_states=None,
         attention_mask=None,
     ):
-        batch_size, sequence_length, _ = hidden_states.shape
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         query = attn.to_q(hidden_states)
@@ -362,7 +369,9 @@ def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
     def __call__(
         self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
     ):
-        batch_size, sequence_length, _ = hidden_states.shape
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
@@ -435,7 +444,9 @@ def __init__(self, attention_op: Optional[Callable] = None):
         self.attention_op = attention_op

     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = hidden_states.shape
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
@@ -454,7 +465,7 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         value = attn.head_to_batch_dim(value).contiguous()

         hidden_states = xformers.ops.memory_efficient_attention(
-            query, key, value, attn_bias=attention_mask, op=self.attention_op
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
         )
         hidden_states = hidden_states.to(query.dtype)
         hidden_states = attn.batch_to_head_dim(hidden_states)
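The xformers path now forwards `attn.scale` explicitly instead of relying on the kernel's built-in default. `scale` is an existing keyword of `xformers.ops.memory_efficient_attention` (default `None`, i.e. 1/sqrt(head_dim)), so a non-default value such as 1.0 from `scale_qk=False` is respected. A minimal sketch, assuming xformers and a CUDA device are available:

```python
import torch
import xformers.ops

q = torch.randn(2, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 77, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 77, 64, device="cuda", dtype=torch.float16)

out_default = xformers.ops.memory_efficient_attention(q, k, v)              # scale=None -> 1/sqrt(64)
out_explicit = xformers.ops.memory_efficient_attention(q, k, v, scale=1.0)  # what the patch now does with attn.scale
```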
@@ -472,7 +483,10 @@ def __init__(self):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, inner_dim = hidden_states.shape
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        inner_dim = hidden_states.shape[-1]

         if attention_mask is not None:
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
@@ -496,6 +510,7 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

         # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = F.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
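The TODO reflects that `F.scaled_dot_product_attention` only gained a `scale` keyword in PyTorch 2.1; on 2.0 the call always applies the implicit 1/sqrt(head_dim) scaling, which is why `scale_qk=False` avoids this processor. A sketch of both calls (the 2.1 form is commented out and only valid on 2.1+):

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 128, 64)  # (batch, heads, q_tokens, head_dim)
k = torch.randn(2, 8, 77, 64)
v = torch.randn(2, 8, 77, 64)

# torch 2.0: only the implicit 1/sqrt(head_dim) scaling is available
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

# torch >= 2.1 adds `scale`, which would let the processor forward attn.scale
# (e.g. scale=1.0 when scale_qk=False); uncomment on 2.1+:
# out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False, scale=1.0)
```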
@@ -527,7 +542,9 @@ def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None):
     def __call__(
         self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
     ):
-        batch_size, sequence_length, _ = hidden_states.shape
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
@@ -542,7 +559,7 @@ def __call__(
         value = attn.head_to_batch_dim(value).contiguous()

         hidden_states = xformers.ops.memory_efficient_attention(
-            query, key, value, attn_bias=attention_mask, op=self.attention_op
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
         )
         hidden_states = attn.batch_to_head_dim(hidden_states)
@@ -559,8 +576,9 @@ def __init__(self, slice_size):
         self.slice_size = slice_size

     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = hidden_states.shape
-
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

         query = attn.to_q(hidden_states)
@@ -577,12 +595,12 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         key = attn.head_to_batch_dim(key)
         value = attn.head_to_batch_dim(value)

-        batch_size_attention = query.shape[0]
+        batch_size_attention, query_tokens, _ = query.shape
         hidden_states = torch.zeros(
-            (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
+            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
         )

-        for i in range(hidden_states.shape[0] // self.slice_size):
+        for i in range(batch_size_attention // self.slice_size):
             start_idx = i * self.slice_size
             end_idx = (i + 1) * self.slice_size
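The sliced processors previously allocated their output buffer with `sequence_length`, which after the change above now tracks the key/value length; sizing it by `query_tokens` keeps the buffer correct when query and key lengths differ. A standalone sketch of the sliced loop (illustrative helper names, not the library's):

```python
import torch

def sliced_attention(q, k, v, slice_size):
    # q: (batch*heads, q_tokens, head_dim), k/v: (batch*heads, kv_tokens, head_dim)
    batch_size_attention, query_tokens, head_dim = q.shape
    out = torch.zeros(batch_size_attention, query_tokens, head_dim, dtype=q.dtype, device=q.device)
    scale = head_dim**-0.5
    for i in range(batch_size_attention // slice_size):
        s, e = i * slice_size, (i + 1) * slice_size
        probs = (torch.bmm(q[s:e], k[s:e].transpose(-1, -2)) * scale).softmax(dim=-1)
        out[s:e] = torch.bmm(probs, v[s:e])  # buffer sized by query_tokens, not the key length
    return out

out = sliced_attention(torch.randn(16, 4096, 40), torch.randn(16, 77, 40), torch.randn(16, 77, 40), slice_size=4)
```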
@@ -638,12 +656,12 @@ def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None):
         key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
         value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)

-        batch_size_attention = query.shape[0]
+        batch_size_attention, query_tokens, _ = query.shape
         hidden_states = torch.zeros(
-            (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
+            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
         )

-        for i in range(hidden_states.shape[0] // self.slice_size):
+        for i in range(batch_size_attention // self.slice_size):
             start_idx = i * self.slice_size
             end_idx = (i + 1) * self.slice_size