Commit cf4227c

T5Attention support for cross-attention (#2654)
* fix AttnProcessor2_0: fix use of AttnProcessor2_0 for cross-attention with mask
* added scale_qk and out_bias flags
* fixed for xformers
* check if it has scale argument
* Update cross_attention.py
* check torch version
* fix sliced attn
* style
* set scale
* fix test
* fixed addedKV processor
* revert back AttnProcessor2_0
* if missing if
* fix inner_dim

---------

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 9d1341d commit cf4227c
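
The central fix: sequence_length for attention-mask preparation is now taken from the key/value sequence (encoder_hidden_states) instead of the query sequence whenever cross-attention is used. A minimal standalone sketch of the shape logic, with made-up tensor sizes rather than anything from the repo:

import torch

batch, query_tokens, dim = 2, 64, 320    # latent (query) tokens
text_tokens = 77                         # encoder (key/value) tokens

hidden_states = torch.randn(batch, query_tokens, dim)
encoder_hidden_states = torch.randn(batch, text_tokens, dim)

# Before: sequence_length always came from hidden_states, so a cross-attention
# mask of shape (batch, text_tokens) was padded/broadcast to the wrong length.
# After: the key/value sequence is used whenever encoder_hidden_states is given.
batch_size, sequence_length, _ = (
    hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
assert sequence_length == text_tokens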

2 files changed: +39, -20 lines changed

src/diffusers/models/cross_attention.py

Lines changed: 37 additions & 19 deletions
@@ -59,6 +59,8 @@ def __init__(
         cross_attention_norm: bool = False,
         added_kv_proj_dim: Optional[int] = None,
         norm_num_groups: Optional[int] = None,
+        out_bias: bool = True,
+        scale_qk: bool = True,
         processor: Optional["AttnProcessor"] = None,
     ):
         super().__init__()
@@ -68,7 +70,7 @@ def __init__(
         self.upcast_softmax = upcast_softmax
         self.cross_attention_norm = cross_attention_norm

-        self.scale = dim_head**-0.5
+        self.scale = dim_head**-0.5 if scale_qk else 1.0

         self.heads = heads
         # for slice_size > 0 the attention score computation
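
Note: with scale_qk=False the attention logits are left unscaled, which matches T5-style attention. A two-line illustration of the flag's effect (plain Python arithmetic, not the committed code):

dim_head, scale_qk = 64, False
scale = dim_head**-0.5 if scale_qk else 1.0   # 1.0 -> unscaled q @ k^T, T5-style
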
@@ -95,14 +97,17 @@ def __init__(
             self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)

         self.to_out = nn.ModuleList([])
-        self.to_out.append(nn.Linear(inner_dim, query_dim))
+        self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
         self.to_out.append(nn.Dropout(dropout))

         # set attention processor
-        # We use the AttnProcessor2_0 by default when torch2.x is used which uses
+        # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
         # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+        # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
         if processor is None:
-            processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and scale_qk else CrossAttnProcessor()
+            )
         self.set_processor(processor)

     def set_use_memory_efficient_attention_xformers(
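
For illustration, a hedged construction sketch using the new flags; apart from the names visible in this diff (query_dim, cross_attention_dim, heads, dim_head, out_bias, scale_qk), the exact keyword set of the constructor is assumed here:

from diffusers.models.cross_attention import CrossAttention

attn = CrossAttention(
    query_dim=512,
    cross_attention_dim=512,
    heads=8,
    dim_head=64,
    scale_qk=False,  # T5-style: no 1/sqrt(d) scaling of q @ k^T
    out_bias=False,  # T5-style: no bias on the output projection
)
# With scale_qk=False the SDPA fast path is skipped until torch 2.1 exposes `scale`:
print(type(attn.processor).__name__)  # CrossAttnProcessor, even on torch 2.x
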
@@ -295,7 +300,9 @@ def __call__(
         encoder_hidden_states=None,
         attention_mask=None,
     ):
-        batch_size, sequence_length, _ = hidden_states.shape
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         query = attn.to_q(hidden_states)

@@ -362,7 +369,9 @@ def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
     def __call__(
         self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
     ):
-        batch_size, sequence_length, _ = hidden_states.shape
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
@@ -435,7 +444,9 @@ def __init__(self, attention_op: Optional[Callable] = None):
         self.attention_op = attention_op

     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = hidden_states.shape
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

@@ -454,7 +465,7 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
         value = attn.head_to_batch_dim(value).contiguous()

         hidden_states = xformers.ops.memory_efficient_attention(
-            query, key, value, attn_bias=attention_mask, op=self.attention_op
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
         )
         hidden_states = hidden_states.to(query.dtype)
         hidden_states = attn.batch_to_head_dim(hidden_states)
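
xformers' memory_efficient_attention accepts an explicit scale, so attn.scale can simply be forwarded. A hedged call sketch with dummy tensors (the shapes and the CUDA/fp16 setup are illustrative assumptions):

import torch
import xformers.ops

# (batch * heads, tokens, dim_head), i.e. the layout after head_to_batch_dim
q = torch.randn(16, 4096, 40, device="cuda", dtype=torch.float16)
k = torch.randn(16, 77, 40, device="cuda", dtype=torch.float16)
v = torch.randn(16, 77, 40, device="cuda", dtype=torch.float16)
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, scale=1.0)  # scale=1.0 ~ scale_qk=False
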
@@ -472,7 +483,10 @@ def __init__(self):
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, inner_dim = hidden_states.shape
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        inner_dim = hidden_states.shape[-1]

         if attention_mask is not None:
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
@@ -496,6 +510,7 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

         # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = F.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
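
Until torch 2.1 exposes a scale argument on scaled_dot_product_attention, a non-default attn.scale can only be emulated, for instance by pre-scaling the query; a sketch of that workaround (an assumption, not something this commit does):

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 64, 40)   # (batch, heads, q_tokens, head_dim)
k = torch.randn(2, 8, 77, 40)
v = torch.randn(2, 8, 77, 40)

custom_scale = 1.0                    # e.g. scale_qk=False
default_scale = q.shape[-1] ** -0.5   # what SDPA applies internally
# softmax((q * s) @ k^T) == softmax(s * (q @ k^T)), so pre-scaling q is equivalent:
out = F.scaled_dot_product_attention(q * (custom_scale / default_scale), k, v)
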
@@ -527,7 +542,9 @@ def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optio
     def __call__(
         self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
     ):
-        batch_size, sequence_length, _ = hidden_states.shape
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
@@ -542,7 +559,7 @@ def __call__(
         value = attn.head_to_batch_dim(value).contiguous()

         hidden_states = xformers.ops.memory_efficient_attention(
-            query, key, value, attn_bias=attention_mask, op=self.attention_op
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
         )
         hidden_states = attn.batch_to_head_dim(hidden_states)

@@ -559,8 +576,9 @@ def __init__(self, slice_size):
         self.slice_size = slice_size

     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = hidden_states.shape
-
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

         query = attn.to_q(hidden_states)
@@ -577,12 +595,12 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
         key = attn.head_to_batch_dim(key)
         value = attn.head_to_batch_dim(value)

-        batch_size_attention = query.shape[0]
+        batch_size_attention, query_tokens, _ = query.shape
         hidden_states = torch.zeros(
-            (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
+            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
         )

-        for i in range(hidden_states.shape[0] // self.slice_size):
+        for i in range(batch_size_attention // self.slice_size):
             start_idx = i * self.slice_size
             end_idx = (i + 1) * self.slice_size

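In cross-attention the query and key/value lengths differ, so the sliced processors now size their output buffer from the query rather than from sequence_length (which, after the change above, is the key/value length). Illustrative shape arithmetic with made-up sizes:

batch, heads, query_tokens, kv_tokens, dim_head = 2, 8, 4096, 77, 40
slice_size = 4

# After head_to_batch_dim: query is (batch*heads, query_tokens, dim_head) and
# key/value are (batch*heads, kv_tokens, dim_head); the attention output has
# one row per query token, so the buffer's second dim must be query_tokens.
batch_size_attention = batch * heads
out_shape = (batch_size_attention, query_tokens, dim_head)
num_slices = batch_size_attention // slice_size   # the loop slices over batch*heads
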
@@ -638,12 +656,12 @@ def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=
         key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
         value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)

-        batch_size_attention = query.shape[0]
+        batch_size_attention, query_tokens, _ = query.shape
         hidden_states = torch.zeros(
-            (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
+            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
         )

-        for i in range(hidden_states.shape[0] // self.slice_size):
+        for i in range(batch_size_attention // self.slice_size):
             start_idx = i * self.slice_size
             end_idx = (i + 1) * self.slice_size

tests/models/test_models_unet_2d_condition.py

Lines changed: 2 additions & 1 deletion
@@ -118,7 +118,8 @@ def test_xformers_enable_works(self):
         model.enable_xformers_memory_efficient_attention()

         assert (
-            model.mid_block.attentions[0].transformer_blocks[0].attn1._use_memory_efficient_attention_xformers
+            model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
+            == "XFormersCrossAttnProcessor"
         ), "xformers is not enabled"

     @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
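
Since enabling xformers swaps in a different processor object, the test checks the processor's class name instead of the private `_use_memory_efficient_attention_xformers` attribute. A hedged equivalent check, assuming `model` is a loaded UNet2DConditionModel:

model.enable_xformers_memory_efficient_attention()
proc = model.mid_block.attentions[0].transformer_blocks[0].attn1.processor
assert type(proc).__name__ == "XFormersCrossAttnProcessor", "xformers is not enabled"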
