@@ -205,6 +205,110 @@ def disable_attention_slicing(self):
         # set slice_size = `None` to disable `set_attention_slice`
         self.enable_attention_slicing(None)

+    @property
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
+    def _execution_device(self):
+        r"""
+        Returns the device on which the pipeline's models will be executed. After calling
+        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+        hooks.
+        """
+        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+            return self.device
+        for module in self.unet.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+        return self.device
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `list(int)`):
+                prompt to be encoded
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+        """
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+
+        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+            )
+            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+        text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
+
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        bs_embed, seq_len, _ = text_embeddings.shape
+        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance:
+            uncond_tokens: List[str]
+            if negative_prompt is None:
+                uncond_tokens = [""] * batch_size
+            elif type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif isinstance(negative_prompt, str):
+                uncond_tokens = [negative_prompt]
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = negative_prompt
+
+            max_length = text_input_ids.shape[-1]
+            uncond_input = self.tokenizer(
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
+
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = uncond_embeddings.shape[1]
+            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        return text_embeddings
+
     @torch.no_grad()
     def __call__(
         self,
@@ -309,89 +413,17 @@ def __call__(
         if isinstance(init_image, PIL.Image.Image):
             init_image = preprocess(init_image)

-        # get prompt text embeddings
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        )
-        source_text_inputs = self.tokenizer(
-            source_prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-        source_text_input_ids = source_text_inputs.input_ids
-
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        if source_text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(source_text_input_ids[:, self.tokenizer.model_max_length :])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-            source_text_input_ids = source_text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
-        source_text_embeddings = self.text_encoder(source_text_input_ids.to(self.device))[0]
-
-        # duplicate text embeddings for each generation per prompt
-        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
-        source_text_embeddings = source_text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+        device = self._execution_device

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0

-        # get unconditional embeddings for classifier free guidance
-        uncond_tokens = [""]
-
-        max_length = text_input_ids.shape[-1]
-        uncond_input = self.tokenizer(
-            uncond_tokens,
-            padding="max_length",
-            max_length=max_length,
-            truncation=True,
-            return_tensors="pt",
+        text_embeddings = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, None)
+        source_text_embeddings = self._encode_prompt(
+            source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None
         )
-        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
-
-        # duplicate unconditional embeddings for each generation per prompt
-        uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
-
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-
-        source_uncond_tokens = [""]
-
-        max_length = source_text_input_ids.shape[-1]
-        source_uncond_input = self.tokenizer(
-            source_uncond_tokens,
-            padding="max_length",
-            max_length=max_length,
-            truncation=True,
-            return_tensors="pt",
-        )
-        source_uncond_embeddings = self.text_encoder(source_uncond_input.input_ids.to(self.device))[0]
-
-        # duplicate unconditional embeddings for each generation per prompt
-        source_uncond_embeddings = source_uncond_embeddings.repeat_interleave(
-            batch_size * num_images_per_prompt, dim=0
-        )
-
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        source_text_embeddings = torch.cat([source_uncond_embeddings, source_text_embeddings])

         # encode the init image into latents and scale the latents
         latents_dtype = text_embeddings.dtype
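The `_execution_device` property exists for the sequential CPU-offload case: once Accelerate has offloaded the submodules, `self.device` reports `meta`, and the real device has to be read from the `_hf_hook` that Accelerate attaches to each module. Here is a hedged sketch of that mechanism, assuming a recent `accelerate` release where `cpu_offload` installs an `AlignDevicesHook` as `module._hf_hook`; the toy `nn.Linear` stands in for the pipeline's UNet.

```python
import torch
from accelerate import cpu_offload

# Toy stand-in for the pipeline's UNet; any nn.Module with parameters works here.
unet = torch.nn.Linear(4, 4)

# Record the device the module should execute on; weights are moved lazily by the hook.
execution_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
cpu_offload(unet, execution_device=execution_device)

# The parameters are now offloaded (reported as "meta"), but the Accelerate hook still
# knows the execution device -- exactly the attribute that _execution_device walks the
# submodules to find.
print(next(unet.parameters()).device)    # meta
print(unet._hf_hook.execution_device)    # cuda:0 (or cpu)
```

Note also that the refactored `__call__` passes `negative_prompt=None` to `_encode_prompt` for both `prompt` and `source_prompt`, so the unconditional embeddings are still derived from empty strings as in the removed inline code, while the new `do_classifier_free_guidance` flag skips the unconditional pass entirely when `guidance_scale <= 1`.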