File tree: 7 files changed, +18 −14 lines changed
Directory: src/diffusers/pipelines/stable_diffusion
@@ -278,7 +278,7 @@ def __call__(
278278 if do_classifier_free_guidance :
279279 uncond_tokens : List [str ]
280280 if negative_prompt is None :
281- uncond_tokens = ["" ]
281+ uncond_tokens = ["" ] * batch_size
282282 elif type (prompt ) is not type (negative_prompt ):
283283 raise TypeError (
284284 f"`negative_prompt` should be the same type to `prompt`, but got { type (negative_prompt )} !="
@@ -307,7 +307,7 @@ def __call__(
307307
308308 # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
309309 seq_len = uncond_embeddings .shape [1 ]
310- uncond_embeddings = uncond_embeddings .repeat (batch_size , num_images_per_prompt , 1 )
310+ uncond_embeddings = uncond_embeddings .repeat (1 , num_images_per_prompt , 1 )
311311 uncond_embeddings = uncond_embeddings .view (batch_size * num_images_per_prompt , seq_len , - 1 )
312312
313313 # For classifier free guidance, we need to do two forward passes.
Original file line number Diff line number Diff line change @@ -148,7 +148,7 @@ def __call__(
148148 if do_classifier_free_guidance :
149149 uncond_tokens : List [str ]
150150 if negative_prompt is None :
151- uncond_tokens = ["" ]
151+ uncond_tokens = ["" ] * batch_size
152152 elif type (prompt ) is not type (negative_prompt ):
153153 raise TypeError (
154154 f"`negative_prompt` should be the same type to `prompt`, but got { type (negative_prompt )} !="
@@ -177,7 +177,7 @@ def __call__(
177177
178178 # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
179179 seq_len = uncond_embeddings .shape [1 ]
180- uncond_embeddings = uncond_embeddings .repeat (batch_size , num_images_per_prompt , 1 )
180+ uncond_embeddings = uncond_embeddings .repeat (1 , num_images_per_prompt , 1 )
181181 uncond_embeddings = uncond_embeddings .view (batch_size * num_images_per_prompt , seq_len , - 1 )
182182
183183 # For classifier free guidance, we need to do two forward passes.
Original file line number Diff line number Diff line change @@ -295,7 +295,7 @@ def __call__(
295295 if do_classifier_free_guidance :
296296 uncond_tokens : List [str ]
297297 if negative_prompt is None :
298- uncond_tokens = ["" ]
298+ uncond_tokens = ["" ] * batch_size
299299 elif type (prompt ) is not type (negative_prompt ):
300300 raise TypeError (
301301 f"`negative_prompt` should be the same type to `prompt`, but got { type (negative_prompt )} !="
@@ -324,7 +324,7 @@ def __call__(
324324
325325 # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
326326 seq_len = uncond_embeddings .shape [1 ]
327- uncond_embeddings = uncond_embeddings .repeat (batch_size , num_images_per_prompt , 1 )
327+ uncond_embeddings = uncond_embeddings .repeat (1 , num_images_per_prompt , 1 )
328328 uncond_embeddings = uncond_embeddings .view (batch_size * num_images_per_prompt , seq_len , - 1 )
329329
330330 # For classifier free guidance, we need to do two forward passes.
Original file line number Diff line number Diff line change @@ -297,7 +297,7 @@ def __call__(
297297 if do_classifier_free_guidance :
298298 uncond_tokens : List [str ]
299299 if negative_prompt is None :
300- uncond_tokens = ["" ]
300+ uncond_tokens = ["" ] * batch_size
301301 elif type (prompt ) is not type (negative_prompt ):
302302 raise TypeError (
303303 f"`negative_prompt` should be the same type to `prompt`, but got { type (negative_prompt )} !="
@@ -326,7 +326,7 @@ def __call__(
326326
327327 # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
328328 seq_len = uncond_embeddings .shape [1 ]
329- uncond_embeddings = uncond_embeddings .repeat (batch_size , num_images_per_prompt , 1 )
329+ uncond_embeddings = uncond_embeddings .repeat (1 , num_images_per_prompt , 1 )
330330 uncond_embeddings = uncond_embeddings .view (batch_size * num_images_per_prompt , seq_len , - 1 )
331331
332332 # For classifier free guidance, we need to do two forward passes.
Original file line number Diff line number Diff line change @@ -295,7 +295,7 @@ def __call__(
295295 if do_classifier_free_guidance :
296296 uncond_tokens : List [str ]
297297 if negative_prompt is None :
298- uncond_tokens = ["" ]
298+ uncond_tokens = ["" ] * batch_size
299299 elif type (prompt ) is not type (negative_prompt ):
300300 raise TypeError (
301301 f"`negative_prompt` should be the same type to `prompt`, but got { type (negative_prompt )} !="
@@ -319,7 +319,9 @@ def __call__(
319319 uncond_embeddings = self .text_encoder (uncond_input .input_ids .to (self .device ))[0 ]
320320
321321 # duplicate unconditional embeddings for each generation per prompt
322- uncond_embeddings = uncond_embeddings .repeat_interleave (batch_size * num_images_per_prompt , dim = 0 )
322+ seq_len = uncond_embeddings .shape [1 ]
323+ uncond_embeddings = uncond_embeddings .repeat (1 , num_images_per_prompt , 1 )
324+ uncond_embeddings = uncond_embeddings .view (batch_size * num_images_per_prompt , seq_len , - 1 )
323325
324326 # For classifier free guidance, we need to do two forward passes.
325327 # Here we concatenate the unconditional and text embeddings into a single batch
Original file line number Diff line number Diff line change @@ -302,7 +302,7 @@ def __call__(
302302 if do_classifier_free_guidance :
303303 uncond_tokens : List [str ]
304304 if negative_prompt is None :
305- uncond_tokens = ["" ]
305+ uncond_tokens = ["" ] * batch_size
306306 elif type (prompt ) is not type (negative_prompt ):
307307 raise TypeError (
308308 f"`negative_prompt` should be the same type to `prompt`, but got { type (negative_prompt )} !="
@@ -331,7 +331,7 @@ def __call__(
331331
332332 # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
333333 seq_len = uncond_embeddings .shape [1 ]
334- uncond_embeddings = uncond_embeddings .repeat (batch_size , num_images_per_prompt , 1 )
334+ uncond_embeddings = uncond_embeddings .repeat (1 , num_images_per_prompt , 1 )
335335 uncond_embeddings = uncond_embeddings .view (batch_size * num_images_per_prompt , seq_len , - 1 )
336336
337337 # For classifier free guidance, we need to do two forward passes.
Original file line number Diff line number Diff line change @@ -284,7 +284,7 @@ def __call__(
284284 if do_classifier_free_guidance :
285285 uncond_tokens : List [str ]
286286 if negative_prompt is None :
287- uncond_tokens = ["" ]
287+ uncond_tokens = ["" ] * batch_size
288288 elif type (prompt ) is not type (negative_prompt ):
289289 raise TypeError (
290290 f"`negative_prompt` should be the same type to `prompt`, but got { type (negative_prompt )} !="
@@ -312,7 +312,9 @@ def __call__(
312312 uncond_embeddings = self .text_encoder (uncond_input .input_ids .to (self .device ))[0 ]
313313
314314 # duplicate unconditional embeddings for each generation per prompt
315- uncond_embeddings = uncond_embeddings .repeat_interleave (batch_size * num_images_per_prompt , dim = 0 )
315+ seq_len = uncond_embeddings .shape [1 ]
316+ uncond_embeddings = uncond_embeddings .repeat (1 , num_images_per_prompt , 1 )
317+ uncond_embeddings = uncond_embeddings .view (batch_size * num_images_per_prompt , seq_len , - 1 )
316318
317319 # For classifier free guidance, we need to do two forward passes.
318320 # Here we concatenate the unconditional and text embeddings into a single batch
0 commit comments