2 files changed: +6 -3

@@ -271,9 +271,10 @@ def __init__(
     def forward(
         self,
         hidden_states,
+        attention_mask=None,
         encoder_hidden_states=None,
+        encoder_attention_mask=None,
         timestep=None,
-        attention_mask=None,
         cross_attention_kwargs=None,
         class_labels=None,
     ):
@@ -302,12 +303,14 @@ def forward(
         norm_hidden_states = (
             self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
         )
+        # TODO (Birch-San): Here we should prepare the encoder_attention_mask correctly
+        # prepare attention mask here

         # 2. Cross-Attention
         attn_output = self.attn2(
             norm_hidden_states,
             encoder_hidden_states=encoder_hidden_states,
-            attention_mask=attention_mask,
+            attention_mask=encoder_attention_mask,
             **cross_attention_kwargs,
         )
         hidden_states = attn_output + hidden_states
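The TODO in the hunk above leaves open how `encoder_attention_mask` should be prepared before it is handed to `self.attn2`. A minimal sketch of one common approach, assuming the mask arrives as a (batch, key_len) padding mask; the helper name and shapes are illustrative assumptions, not part of this diff:

import torch

def prepare_encoder_attention_mask(mask, dtype):
    # Hypothetical helper (not in this diff): turn a (batch, key_len) padding
    # mask (True/1 = real token, False/0 = padding) into an additive bias of
    # shape (batch, 1, key_len) that broadcasts over query positions.
    if mask is None:
        return None
    bias = torch.zeros(mask.shape, dtype=dtype)
    bias = bias.masked_fill(~mask.bool(), torch.finfo(dtype).min)
    return bias.unsqueeze(1)

# Example: two prompts, the second padded after 2 of 4 tokens.
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
bias = prepare_encoder_attention_mask(mask, torch.float32)
scores = torch.randn(2, 3, 4)              # (batch, query_len, key_len)
weights = (scores + bias).softmax(dim=-1)  # padded keys get ~zero weight

Adding the bias before the softmax drives masked keys toward zero weight, which is why the diff splits the arguments: the self-attention mask covers the `hidden_states` sequence length, while `encoder_attention_mask` must match the encoder's key length.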
tests/pipelines/stable_diffusion
@@ -737,7 +737,7 @@ def test_stable_diffusion_vae_tiling(self):

         # make sure that more than 4 GB is allocated
         mem_bytes = torch.cuda.max_memory_allocated()
-        assert mem_bytes > 4e9
+        assert mem_bytes > 5e9
         assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-2

     def test_stable_diffusion_fp16_vs_autocast(self):
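For context on the threshold above: `torch.cuda.max_memory_allocated()` reports the peak bytes allocated since the stats were last reset, so tests like this usually clear the peak counter before running the pipeline (note that 5e9 bytes is about 4.7 GiB, so the "4 GB" comment still loosely holds). A minimal sketch of the measurement pattern; the pipeline call is a placeholder, while the `torch.cuda` calls are real APIs:

import torch

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

# ... run the tiled-VAE pipeline under test here (placeholder) ...

mem_bytes = torch.cuda.max_memory_allocated()  # peak bytes since the reset
assert mem_bytes > 5e9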