From 15321c77e8e42cb3f5a6e12de64f2597bf869513 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 13 Mar 2023 15:55:35 +0100 Subject: [PATCH 01/14] fix AttnProcessor2_0 Fix use of AttnProcessor2_0 for cross attention with mask --- src/diffusers/models/cross_attention.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 9f994064d08f..1e6283746d5e 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -472,7 +472,9 @@ def __init__(self): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): - batch_size, sequence_length, inner_dim = hidden_states.shape + batch_size, sequence_length, inner_dim = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) From 70e7c4e2e7f1d655453d5f0b67cf3742a89fc376 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 13 Mar 2023 20:07:43 +0100 Subject: [PATCH 02/14] added scale_qk and out_bias flags --- src/diffusers/models/cross_attention.py | 31 +++++++++++++++++-------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 1e6283746d5e..e06537f83428 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -59,6 +59,8 @@ def __init__( cross_attention_norm: bool = False, added_kv_proj_dim: Optional[int] = None, norm_num_groups: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, processor: Optional["AttnProcessor"] = None, ): super().__init__() @@ -68,7 +70,7 @@ def __init__( self.upcast_softmax = upcast_softmax self.cross_attention_norm = cross_attention_norm - self.scale = dim_head**-0.5 + self.scale = dim_head**-0.5 if scale_qk else 1.0 self.heads = heads # for slice_size > 0 the attention score computation @@ -95,7 +97,7 @@ def __init__( self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) self.to_out = nn.ModuleList([]) - self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias)) self.to_out.append(nn.Dropout(dropout)) # set attention processor @@ -295,7 +297,9 @@ def __call__( encoder_hidden_states=None, attention_mask=None, ): - batch_size, sequence_length, _ = hidden_states.shape + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) query = attn.to_q(hidden_states) @@ -362,7 +366,9 @@ def __init__(self, hidden_size, cross_attention_dim=None, rank=4): def __call__( self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0 ): - batch_size, sequence_length, _ = hidden_states.shape + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) @@ -454,7 +460,7 @@ def __call__(self, attn: CrossAttention, hidden_states, 
encoder_hidden_states=No value = attn.head_to_batch_dim(value).contiguous() hidden_states = xformers.ops.memory_efficient_attention( - query, key, value, attn_bias=attention_mask, op=self.attention_op + query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale ) hidden_states = hidden_states.to(query.dtype) hidden_states = attn.batch_to_head_dim(hidden_states) @@ -499,7 +505,7 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No # the output of sdp = (batch, num_heads, seq_len, head_dim) hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) @@ -529,7 +535,9 @@ def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optio def __call__( self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0 ): - batch_size, sequence_length, _ = hidden_states.shape + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) @@ -561,8 +569,9 @@ def __init__(self, slice_size): self.slice_size = slice_size def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): - batch_size, sequence_length, _ = hidden_states.shape - + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) query = attn.to_q(hidden_states) @@ -617,7 +626,9 @@ def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states= hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) encoder_hidden_states = encoder_hidden_states.transpose(1, 2) - batch_size, sequence_length, _ = hidden_states.shape + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) From 0863561b0e90497e37cfce64d590945b8a822baf Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 13 Mar 2023 21:03:49 +0100 Subject: [PATCH 03/14] fixed for xformers --- src/diffusers/models/cross_attention.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index e06537f83428..d37504d30913 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -398,7 +398,9 @@ class CrossAttnAddedKVProcessor: def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): residual = hidden_states hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) - batch_size, sequence_length, _ = hidden_states.shape + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) encoder_hidden_states = encoder_hidden_states.transpose(1, 2) attention_mask = 
attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) @@ -441,7 +443,9 @@ def __init__(self, attention_op: Optional[Callable] = None): self.attention_op = attention_op def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): - batch_size, sequence_length, _ = hidden_states.shape + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) From 8cdaf3ff0dbb949e351b74a5b37b51f34560cce6 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 13 Mar 2023 21:14:20 +0100 Subject: [PATCH 04/14] check if it has scale argument --- src/diffusers/models/cross_attention.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index d37504d30913..67c9e94d3f6e 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect from typing import Callable, Optional, Union import torch @@ -101,10 +102,17 @@ def __init__( self.to_out.append(nn.Dropout(dropout)) # set attention processor - # We use the AttnProcessor2_0 by default when torch2.x is used which uses + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the `scale` argument if processor is None: - processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor() + if ( + hasattr(F, "scaled_dot_product_attention") + and inspect.signature(F.scaled_dot_product_attention).parameters.get("scale") is not None + ): + processor = AttnProcessor2_0() + else: + CrossAttnProcessor() self.set_processor(processor) def set_use_memory_efficient_attention_xformers( From 840af0af64f7541671a899bd93e8196ffb2b5cf2 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 13 Mar 2023 21:33:12 +0100 Subject: [PATCH 05/14] Update cross_attention.py --- src/diffusers/models/cross_attention.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 67c9e94d3f6e..6319cf1cf1b3 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -105,6 +105,9 @@ def __init__( # We use the AttnProcessor2_0 by default when torch 2.x is used which uses # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention # but only if it has the `scale` argument + import pdb + + pdb.set_trace() if processor is None: if ( hasattr(F, "scaled_dot_product_attention") @@ -486,7 +489,10 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No class AttnProcessor2_0: def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): + if ( + not hasattr(F, "scaled_dot_product_attention") + or inspect.signature(F.scaled_dot_product_attention).parameters.get("scale") is None + ): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__(self, attn: CrossAttention, hidden_states, 
encoder_hidden_states=None, attention_mask=None): From 19861b91f1857270cfcb09d0c6a923cfe93673aa Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 13 Mar 2023 21:41:47 +0100 Subject: [PATCH 06/14] check torch version --- src/diffusers/models/cross_attention.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 6319cf1cf1b3..8f262f7cd28c 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -105,17 +105,11 @@ def __init__( # We use the AttnProcessor2_0 by default when torch 2.x is used which uses # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention # but only if it has the `scale` argument - import pdb - - pdb.set_trace() if processor is None: - if ( - hasattr(F, "scaled_dot_product_attention") - and inspect.signature(F.scaled_dot_product_attention).parameters.get("scale") is not None - ): + if torch.torch_version.TorchVersion(torch.__version__) >= (2, 1, 0): processor = AttnProcessor2_0() else: - CrossAttnProcessor() + processor = CrossAttnProcessor() self.set_processor(processor) def set_use_memory_efficient_attention_xformers( @@ -489,10 +483,7 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No class AttnProcessor2_0: def __init__(self): - if ( - not hasattr(F, "scaled_dot_product_attention") - or inspect.signature(F.scaled_dot_product_attention).parameters.get("scale") is None - ): + if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): From afc92c4ef834d6a6d19e094ec7d897f2581f4422 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 14 Mar 2023 09:29:24 +0100 Subject: [PATCH 07/14] fix sliced attn --- src/diffusers/models/cross_attention.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 8f262f7cd28c..9cd33602eaee 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -597,12 +597,12 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) - batch_size_attention = query.shape[0] + batch_size_attention, query_tokens, _ = query.shape hidden_states = torch.zeros( - (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype + (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype ) - for i in range(hidden_states.shape[0] // self.slice_size): + for i in range(batch_size_attention // self.slice_size): start_idx = i * self.slice_size end_idx = (i + 1) * self.slice_size @@ -660,12 +660,12 @@ def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states= key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) - batch_size_attention = query.shape[0] + batch_size_attention, query_tokens, _ = query.shape hidden_states = torch.zeros( - (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype + (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, 
dtype=query.dtype ) - for i in range(hidden_states.shape[0] // self.slice_size): + for i in range(batch_size_attention // self.slice_size): start_idx = i * self.slice_size end_idx = (i + 1) * self.slice_size From d044d2bfea3b1f68de8c9bb05453f00d5a450673 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 14 Mar 2023 09:34:27 +0100 Subject: [PATCH 08/14] style --- src/diffusers/models/cross_attention.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 9cd33602eaee..ece360de3fd3 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect from typing import Callable, Optional, Union import torch From 836bc8acd387cb67b94831a7a90167c4c36d74fd Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 14 Mar 2023 09:46:05 +0100 Subject: [PATCH 09/14] set scale --- src/diffusers/models/cross_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index ece360de3fd3..a4af959aeefb 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -560,7 +560,7 @@ def __call__( value = attn.head_to_batch_dim(value).contiguous() hidden_states = xformers.ops.memory_efficient_attention( - query, key, value, attn_bias=attention_mask, op=self.attention_op + query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale ) hidden_states = attn.batch_to_head_dim(hidden_states) From 9631fa243bf961bff753f8dd496aceb6b682955e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 14 Mar 2023 09:57:34 +0100 Subject: [PATCH 10/14] fix test --- tests/models/test_models_unet_2d_condition.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index c1f3bc05d7c6..e313fcfb0b29 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -118,7 +118,8 @@ def test_xformers_enable_works(self): model.enable_xformers_memory_efficient_attention() assert ( - model.mid_block.attentions[0].transformer_blocks[0].attn1._use_memory_efficient_attention_xformers + model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__ + == "XFormersCrossAttnProcessor" ), "xformers is not enabled" @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS") From 6e77ced765c81f6f1499264a98627d3c1d7eff33 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 14 Mar 2023 12:10:13 +0100 Subject: [PATCH 11/14] fixed addedKV processor --- src/diffusers/models/cross_attention.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index a4af959aeefb..c1db075fbd7f 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -402,9 +402,7 @@ class CrossAttnAddedKVProcessor: def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): residual = hidden_states hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) - batch_size, 
sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
+        batch_size, sequence_length, _ = hidden_states.shape
         encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
 
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
@@ -634,9 +632,7 @@ def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
         encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
 
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
+        batch_size, sequence_length, _ = hidden_states.shape
 
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

From 0a96374f69c0a64885faf4e21f4cde8ec88afd54 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Tue, 14 Mar 2023 22:03:45 +0100
Subject: [PATCH 12/14] revert back AttnProcessor2_0

---
 src/diffusers/models/cross_attention.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py
index c1db075fbd7f..170e5db1bb4a 100644
--- a/src/diffusers/models/cross_attention.py
+++ b/src/diffusers/models/cross_attention.py
@@ -103,12 +103,10 @@ def __init__(
         # set attention processor
         # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
         # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
-        # but only if it has the `scale` argument
-        if processor is None:
-            if torch.torch_version.TorchVersion(torch.__version__) >= (2, 1, 0):
-                processor = AttnProcessor2_0()
-            else:
-                processor = CrossAttnProcessor()
+        # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+        processor = (
+            AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and scale_qk else CrossAttnProcessor()
+        )
         self.set_processor(processor)
 
     def set_use_memory_efficient_attention_xformers(
@@ -510,8 +508,9 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
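
Note (added for illustration, not part of the patch series): patch 12 removes the explicit `scale=attn.scale` argument because the `scale` keyword of `F.scaled_dot_product_attention` only arrives in torch 2.1; in torch 2.0 SDPA always applies its default `1/sqrt(head_dim)`, which equals `attn.scale` whenever `scale_qk=True`. That is why the default-processor selection above gates `AttnProcessor2_0` on `scale_qk`. A minimal sketch of the equivalence, assuming a PyTorch 2.0 install:

    # Illustrative sketch (assumes torch>=2.0): SDPA's built-in scaling equals
    # attn.scale when scale_qk=True, so omitting the kwarg changes nothing.
    import torch
    import torch.nn.functional as F

    batch, heads, seq_len, head_dim = 2, 8, 77, 64
    q, k, v = (torch.randn(batch, heads, seq_len, head_dim) for _ in range(3))

    sdpa_out = F.scaled_dot_product_attention(q, k, v)  # implicit scale = head_dim ** -0.5
    manual = torch.softmax((q @ k.transpose(-2, -1)) * head_dim**-0.5, dim=-1) @ v
    assert torch.allclose(sdpa_out, manual, atol=1e-4)

    # scale_qk=False means attn.scale == 1.0, which SDPA cannot express without the
    # `scale` kwarg, hence the fallback to CrossAttnProcessor until torch 2.1.
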
From afaa112425ca51d1f8b8d57923cac7bdd7372b90 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Tue, 14 Mar 2023 22:48:50 +0100
Subject: [PATCH 13/14] if missing if

---
 src/diffusers/models/cross_attention.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py
index 170e5db1bb4a..993f09027cf4 100644
--- a/src/diffusers/models/cross_attention.py
+++ b/src/diffusers/models/cross_attention.py
@@ -104,9 +104,10 @@ def __init__(
         # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
         # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
         # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
-        processor = (
-            AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and scale_qk else CrossAttnProcessor()
-        )
+        if processor is None:
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and scale_qk else CrossAttnProcessor()
+            )
         self.set_processor(processor)
 
     def set_use_memory_efficient_attention_xformers(

From 9e34860d02ab51dba088df0e3824f56745f67612 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Tue, 14 Mar 2023 23:06:54 +0100
Subject: [PATCH 14/14] fix inner_dim

---
 src/diffusers/models/cross_attention.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py
index 993f09027cf4..a0ecfb0f406d 100644
--- a/src/diffusers/models/cross_attention.py
+++ b/src/diffusers/models/cross_attention.py
@@ -483,9 +483,10 @@ def __init__(self):
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
 
     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, inner_dim = (
+        batch_size, sequence_length, _ = (
             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
         )
+        inner_dim = hidden_states.shape[-1]
 
         if attention_mask is not None:
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
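
Note (added for illustration, not part of the patch series): the change that patches 01 and 14 converge on is pure shape bookkeeping; the mask/sequence length must follow the key/value source, while `inner_dim` stays with the query. A small standalone sketch with made-up dimensions:

    import torch

    def sdp_shapes(hidden_states, encoder_hidden_states=None):
        # sequence_length must follow the key/value source (the encoder states in
        # cross attention), because the attention mask has one entry per key token.
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # inner_dim always comes from the query side, which patch 14 restores.
        inner_dim = hidden_states.shape[-1]
        return batch_size, sequence_length, inner_dim

    latents = torch.randn(2, 64, 320)      # query tokens (latent positions)
    text_emb = torch.randn(2, 77, 320)     # key/value tokens (text embeddings)
    assert sdp_shapes(latents) == (2, 64, 320)            # self-attention
    assert sdp_shapes(latents, text_emb) == (2, 77, 320)  # cross-attention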