Commit f832493

support cross-attention masking for CrossAttnProcessor, AttnProcessor2_0
1 parent 186689a commit f832493

3 files changed: 72 additions & 30 deletions
src/diffusers/models/attention_utils.py (new file)

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+from torch import BoolTensor, FloatTensor
+import torch
+
+def mask_to_bias(mask: BoolTensor, dtype: torch.dtype) -> FloatTensor:
+    bias: FloatTensor = (1 - mask.to(dtype=dtype)) * -10000.0
+    return bias
+
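
For orientation, a minimal sketch of what the new helper computes (not part of the commit; the module path is assumed from the import added to unet_2d_condition.py below):

import torch
from diffusers.models.attention_utils import mask_to_bias  # module path assumed

# True = attend to this token, False = mask it out
mask = torch.tensor([[True, True, False]])
bias = mask_to_bias(mask, dtype=torch.float16)
# bias is additive: (-)0.0 for kept tokens, -10000.0 for the masked one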

src/diffusers/models/cross_attention.py

Lines changed: 62 additions & 29 deletions
@@ -15,7 +15,7 @@
 
 import torch
 import torch.nn.functional as F
-from torch import nn
+from torch import nn, FloatTensor
 
 from ..utils import deprecate, logging
 from ..utils.import_utils import is_xformers_available
@@ -272,15 +272,22 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
         if attention_mask is None:
             return attention_mask
 
-        if attention_mask.shape[-1] != target_length:
+        current_length: int = attention_mask.shape[-1]
+        if current_length > target_length:
+            # we *could* trim the mask with:
+            #   attention_mask = attention_mask[:,:target_length]
+            # but this is weird enough that it's more likely to be a mistake than a shortcut
+            raise ValueError(f"mask's length ({current_length}) exceeds the sequence length ({target_length}).")
+        elif current_length < target_length:
             if attention_mask.device.type == "mps":
                 # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                 # Instead, we can manually construct the padding tensor.
                 padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
                 padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                 attention_mask = torch.cat([attention_mask, padding], dim=2)
             else:
-                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+                remaining_length: int = target_length-current_length
+                attention_mask = F.pad(attention_mask, (0, remaining_length), value=0.0)
 
         if attention_mask.shape[0] < batch_size * head_size:
             attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
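
The behavioural change here: the old branch padded the mask by a further target_length positions whenever the lengths differed, whereas the new branch pads only by the shortfall and rejects masks longer than the key sequence. A rough sketch of the padding arithmetic, with made-up shapes:

import torch
import torch.nn.functional as F

attention_mask = torch.zeros(2, 1, 77)                         # (batch, 1, current_length)
target_length = 80                                             # key/value sequence length
remaining_length = target_length - attention_mask.shape[-1]    # 3, not 80
attention_mask = F.pad(attention_mask, (0, remaining_length), value=0.0)
assert attention_mask.shape == (2, 1, 80)
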
@@ -291,19 +298,31 @@ class CrossAttnProcessor:
     def __call__(
         self,
         attn: CrossAttention,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
+        hidden_states: FloatTensor,
+        encoder_hidden_states: Optional[FloatTensor] = None,
+        attention_mask: Optional[FloatTensor] = None,
+        encoder_attention_bias: Optional[FloatTensor] = None,
     ):
-        batch_size, sequence_length, _ = hidden_states.shape
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-        query = attn.to_q(hidden_states)
-
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        else:
+            if encoder_attention_bias is not None:
+                if attention_mask is not None:
+                    # it's not well-defined whether `attention_mask` should be passed to self-attention, cross-attention, neither* or both.
+                    # if two sources of bias (`attention_mask`, `encoder_attention_bias`) are provided: it's likely to be a mistake.
+                    raise ValueError(f"two attention biases have been supplied: `attention_mask` and `encoder_attention_bias`. expected a maximum of one source of bias.")
+                attention_mask = encoder_attention_bias
+                # make broadcastable over query tokens
+                # TODO: consider aligning implementations such that AttnProcessor2_0 and CrossAttnProcessor do unsqueeze
+                # in the same way/circumstances -- AttnProcessor2_0 does it for `attention_mask` **and** for `encoder_attention_bias`.
+                attention_mask = attention_mask.unsqueeze(-2)
+            if attn.cross_attention_norm:
+                encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+
+        batch_size, key_tokens, _ = encoder_hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
 
+        query = attn.to_q(hidden_states)
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
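A hedged sketch of how a caller might build `encoder_attention_bias` for this processor: an additive bias over the prompt's key tokens, shape (batch, key_tokens) assumed; the processor unsqueezes it over query tokens and `prepare_attention_mask` repeats it per head. How the bias gets routed from a pipeline down to the processor is outside this diff:

import torch
from diffusers.models.attention_utils import mask_to_bias  # module path assumed

# hypothetical prompt of 77 tokens where the last 20 are padding
token_mask = torch.ones(2, 77, dtype=torch.bool)
token_mask[:, 57:] = False

# (batch, key_tokens) additive bias; the processor adds the query-token dim itself
encoder_attention_bias = mask_to_bias(token_mask, dtype=torch.float16)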

@@ -315,10 +334,10 @@ def __call__(
         hidden_states = torch.bmm(attention_probs, value)
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
+        linear_proj, dropout = attn.to_out
+
+        hidden_states = linear_proj(hidden_states)
+        hidden_states = dropout(hidden_states)
 
         return hidden_states
 

@@ -471,25 +490,39 @@ def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
 
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, inner_dim = hidden_states.shape
+    def __call__(
+        self,
+        attn: CrossAttention,
+        hidden_states: FloatTensor,
+        encoder_hidden_states: Optional[FloatTensor] = None,
+        attention_mask: Optional[FloatTensor] = None,
+        encoder_attention_bias: Optional[FloatTensor] = None,
+    ):
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            if encoder_attention_bias is not None:
+                if attention_mask is not None:
+                    # it's not well-defined whether `attention_mask` should be passed to self-attention, cross-attention, neither* or both.
+                    # if two sources of bias (`attention_mask`, `encoder_attention_bias`) are provided: it's likely to be a mistake.
+                    raise ValueError(f"two attention biases have been supplied: `attention_mask` and `encoder_attention_bias`. expected a maximum of one source of bias.")
+                attention_mask = encoder_attention_bias
+            if attn.cross_attention_norm:
+                encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+
+        batch_size, key_tokens, _ = encoder_hidden_states.shape
 
         if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
             # scaled_dot_product_attention expects attention_mask shape to be
             # (batch, heads, source_length, target_length)
             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
 
         query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
-
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
+        inner_dim = attn.to_q.out_features
         head_dim = inner_dim // attn.heads
         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
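
For reference, roughly the contract the reshaped mask has to satisfy for `scaled_dot_product_attention`: an additive bias broadcastable to (batch, heads, query_tokens, key_tokens). Shapes below are made up, not taken from the diff:

import torch
import torch.nn.functional as F

batch, heads, query_tokens, key_tokens, head_dim = 2, 8, 64, 77, 40
query = torch.randn(batch, heads, query_tokens, head_dim)
key = torch.randn(batch, heads, key_tokens, head_dim)
value = torch.randn(batch, heads, key_tokens, head_dim)

# additive bias: 0.0 keeps a key token, -10000.0 effectively removes it
bias = torch.zeros(batch, heads, query_tokens, key_tokens)
bias[..., 57:] = -10000.0

out = F.scaled_dot_product_attention(query, key, value, attn_mask=bias)
assert out.shape == (batch, heads, query_tokens, head_dim)
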
@@ -503,10 +536,10 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
+        linear_proj, dropout = attn.to_out
+
+        hidden_states = linear_proj(hidden_states)
+        hidden_states = dropout(hidden_states)
         return hidden_states
 
 

src/diffusers/models/unet_2d_condition.py

Lines changed: 3 additions & 1 deletion
@@ -21,6 +21,7 @@
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
 from ..utils import BaseOutput, logging
+from .attention_utils import mask_to_bias
 from .cross_attention import AttnProcessor
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
@@ -530,7 +531,8 @@ def forward(
 
         # prepare attention_mask
         if attention_mask is not None:
-            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+            attention_mask = mask_to_bias(attention_mask, sample.dtype)
+            # create singleton dimension for broadcasting bias over query tokens
             attention_mask = attention_mask.unsqueeze(1)
 
         # 0. center input if necessary
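
The swap is behaviour-preserving: `mask_to_bias` computes exactly the expression it replaces. A small self-check, assuming the module path noted above:

import torch
from diffusers.models.attention_utils import mask_to_bias  # module path assumed

mask = torch.randint(0, 2, (2, 77), dtype=torch.bool)
dtype = torch.float32

old_bias = (1 - mask.to(dtype)) * -10000.0      # the line being replaced
new_bias = mask_to_bias(mask, dtype)
assert torch.equal(old_bias, new_bias)

# as in the diff: add a singleton dim so the bias broadcasts over query tokens
new_bias = new_bias.unsqueeze(1)                # (batch, 1, key_tokens)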
