@@ -234,43 +234,6 @@ def __call__(
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
 
-        # preprocess image
-        if not isinstance(init_image, torch.FloatTensor):
-            init_image = preprocess_image(init_image)
-        init_image = init_image.to(self.device)
-
-        # encode the init image into latents and scale the latents
-        init_latent_dist = self.vae.encode(init_image).latent_dist
-        init_latents = init_latent_dist.sample(generator=generator)
-
-        init_latents = 0.18215 * init_latents
-
-        # Expand init_latents for batch_size and num_images_per_prompt
-        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
-        init_latents_orig = init_latents
-
-        # preprocess mask
-        if not isinstance(mask_image, torch.FloatTensor):
-            mask_image = preprocess_mask(mask_image)
-        mask_image = mask_image.to(self.device)
-        mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
-
-        # check sizes
-        if not mask.shape == init_latents.shape:
-            raise ValueError("The mask and init_image should be the same size!")
-
-        # get the original timestep using init_timestep
-        offset = self.scheduler.config.get("steps_offset", 0)
-        init_timestep = int(num_inference_steps * strength) + offset
-        init_timestep = min(init_timestep, num_inference_steps)
-
-        timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
-
-        # add noise to latents using the timesteps
-        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
-        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
-
         # get prompt text embeddings
         text_inputs = self.tokenizer(
             prompt,
@@ -335,6 +298,43 @@ def __call__(
         # to avoid doing two forward passes
         text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
 
+        # preprocess image
+        if not isinstance(init_image, torch.FloatTensor):
+            init_image = preprocess_image(init_image)
+
+        # encode the init image into latents and scale the latents
+        latents_dtype = text_embeddings.dtype
+        init_image = init_image.to(device=self.device, dtype=latents_dtype)
+        init_latent_dist = self.vae.encode(init_image).latent_dist
+        init_latents = init_latent_dist.sample(generator=generator)
+        init_latents = 0.18215 * init_latents
+
+        # Expand init_latents for batch_size and num_images_per_prompt
+        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
+        init_latents_orig = init_latents
+
+        # preprocess mask
+        if not isinstance(mask_image, torch.FloatTensor):
+            mask_image = preprocess_mask(mask_image)
+        mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
+        mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
+
+        # check sizes
+        if not mask.shape == init_latents.shape:
+            raise ValueError("The mask and init_image should be the same size!")
+
+        # get the original timestep using init_timestep
+        offset = self.scheduler.config.get("steps_offset", 0)
+        init_timestep = int(num_inference_steps * strength) + offset
+        init_timestep = min(init_timestep, num_inference_steps)
+
+        timesteps = self.scheduler.timesteps[-init_timestep]
+        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
+
+        # add noise to latents using the timesteps
+        noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
+        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
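
Taken together, the two hunks move the image/mask preprocessing and the latent noising from before the text encoder to after it, so the init image, the mask and the sampled noise can all be cast to `latents_dtype = text_embeddings.dtype`; previously the noise defaulted to float32 even when the rest of the pipeline ran in half precision. The sketch below illustrates the two ideas in plain PyTorch, outside any pipeline: the `strength`-to-timestep arithmetic and the dtype alignment. The concrete numbers (50 steps, strength 0.75, steps_offset 1) and the tensor shape are illustrative assumptions, not values taken from this commit.

```python
import torch

# Illustrative assumptions, not values from the commit.
num_inference_steps = 50
strength = 0.75
offset = 1                      # stands in for scheduler.config.get("steps_offset", 0)

# strength controls how much of the noise schedule is applied to the init image:
# int(50 * 0.75) + 1 = 38, so denoising starts 38 steps from the end of the
# schedule, capped at the total number of inference steps.
init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)
assert init_timestep == 38

# The dtype handling in miniature: everything derived from the init image follows
# the text embeddings' dtype, so a half-precision pipeline stays fp16 end to end.
latents_dtype = torch.float16   # stands in for text_embeddings.dtype
init_latents = torch.randn(1, 4, 64, 64).to(latents_dtype)   # assumed latent shape
noise = torch.randn(init_latents.shape).to(latents_dtype)
assert noise.dtype == init_latents.dtype == torch.float16
```

In the pipeline itself the noise is created directly with `dtype=latents_dtype` on `self.device`; the explicit cast here is only to keep the sketch CPU-friendly.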