From 79b40750b4ab63aed2b497cc4416b57a8347e916 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Mon, 13 Mar 2023 16:59:38 +0000
Subject: [PATCH 1/3] [attention] Fix attention

---
 src/diffusers/models/attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index b476e762675c..ede7e13ca3bf 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -291,7 +291,7 @@ def forward(
         attn_output = self.attn1(
             norm_hidden_states,
             encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
-            attention_mask=attention_mask,
+            attention_mask=attention_mask if self.only_cross_attention else None,
             **cross_attention_kwargs,
         )
         if self.use_ada_layer_norm_zero:

From ad106be11949782e76fd0c0bb74493db60b63350 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Mon, 13 Mar 2023 17:15:32 +0000
Subject: [PATCH 2/3] fix

---
 src/diffusers/models/attention.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index ede7e13ca3bf..6da318e65593 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -271,9 +271,10 @@ def __init__(
     def forward(
         self,
         hidden_states,
+        attention_mask=None,
         encoder_hidden_states=None,
+        encoder_attention_mask=None,
         timestep=None,
-        attention_mask=None,
         cross_attention_kwargs=None,
         class_labels=None,
     ):
@@ -291,7 +292,7 @@ def forward(
         attn_output = self.attn1(
             norm_hidden_states,
             encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
-            attention_mask=attention_mask if self.only_cross_attention else None,
+            attention_mask=attention_mask,
             **cross_attention_kwargs,
         )
         if self.use_ada_layer_norm_zero:
@@ -302,12 +303,14 @@ def forward(
             norm_hidden_states = (
                 self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
             )
+            # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
+            # prepare attention mask here

             # 2. Cross-Attention
             attn_output = self.attn2(
                 norm_hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
-                attention_mask=attention_mask,
+                attention_mask=encoder_attention_mask,
                 **cross_attention_kwargs,
             )
             hidden_states = attn_output + hidden_states

From 665322e28f14f332b3e14afb9122c1c5b8f1a519 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Mon, 13 Mar 2023 17:46:11 +0000
Subject: [PATCH 3/3] correct

---
 tests/pipelines/stable_diffusion/test_stable_diffusion.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
index dfb14617f885..d4fd30458373 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -737,7 +737,7 @@ def test_stable_diffusion_vae_tiling(self):

         # make sure that more than 4 GB is allocated
         mem_bytes = torch.cuda.max_memory_allocated()
-        assert mem_bytes > 4e9
+        assert mem_bytes > 5e9
         assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-2

     def test_stable_diffusion_fp16_vs_autocast(self):
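
Usage note (not part of the patch series above): the second commit splits the single `attention_mask` argument of `BasicTransformerBlock.forward` into `attention_mask` for self-attention (`attn1`) and a new `encoder_attention_mask` for cross-attention (`attn2`), and it also moves `attention_mask` earlier in the signature. The following is a minimal sketch of how a caller might exercise the new keyword arguments; the constructor arguments, tensor shapes, and the 77-token text-encoder length are illustrative assumptions based on the diffusers code around this revision, not something the diff itself defines.

import torch

from diffusers.models.attention import BasicTransformerBlock

# Hypothetical dimensions for illustration only.
batch, tokens, dim, heads = 2, 64, 320, 8
text_len, text_dim = 77, 768  # e.g. a CLIP-style text encoder (assumption)

block = BasicTransformerBlock(
    dim=dim,
    num_attention_heads=heads,
    attention_head_dim=dim // heads,
    cross_attention_dim=text_dim,
)

hidden_states = torch.randn(batch, tokens, dim)
encoder_hidden_states = torch.randn(batch, text_len, text_dim)

# After PATCH 2/3, `attention_mask` applies only to self-attention (attn1) and
# `encoder_attention_mask` only to cross-attention (attn2); previously the same
# mask was reused for both calls.
out = block(
    hidden_states,
    attention_mask=None,                      # self-attention mask, if any
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=None,              # cross-attention mask, if any
)
assert out.shape == (batch, tokens, dim)

Keyword arguments are used deliberately here: because `attention_mask` moved from the fourth to the second positional slot, callers passing arguments positionally would bind them to the wrong parameters after this change.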