@@ -475,7 +475,7 @@ def get_timesteps(self, num_inference_steps, strength, device):
475475 t_start = max (num_inference_steps - init_timestep + offset , 0 )
476476 timesteps = self .scheduler .timesteps [t_start :]
477477
478- return timesteps
478+ return timesteps , num_inference_steps - t_start
479479
480480 def prepare_latents (self , init_image , timestep , batch_size , num_images_per_prompt , dtype , device , generator = None ):
481481 init_image = init_image .to (device = device , dtype = dtype )
@@ -607,7 +607,7 @@ def __call__(
607607
608608 # 5. Prepare timesteps
609609 self .scheduler .set_timesteps (num_inference_steps , device = device )
610- timesteps = self .get_timesteps (num_inference_steps , strength , device )
610+ timesteps , num_inference_steps = self .get_timesteps (num_inference_steps , strength , device )
611611 latent_timestep = timesteps [:1 ].repeat (batch_size * num_images_per_prompt )
612612
613613 # 6. Prepare latent variables
@@ -621,66 +621,70 @@ def __call__(
621621 generator = extra_step_kwargs .pop ("generator" , None )
622622
623623 # 8. Denoising loop
624- for i , t in enumerate (self .progress_bar (timesteps )):
625- # expand the latents if we are doing classifier free guidance
626- latent_model_input = torch .cat ([latents ] * 2 )
627- source_latent_model_input = torch .cat ([source_latents ] * 2 )
628- latent_model_input = self .scheduler .scale_model_input (latent_model_input , t )
629- source_latent_model_input = self .scheduler .scale_model_input (source_latent_model_input , t )
630-
631- # predict the noise residual
632- concat_latent_model_input = torch .stack (
633- [
634- source_latent_model_input [0 ],
635- latent_model_input [0 ],
636- source_latent_model_input [1 ],
637- latent_model_input [1 ],
638- ],
639- dim = 0 ,
640- )
641- concat_text_embeddings = torch .stack (
642- [
643- source_text_embeddings [0 ],
644- text_embeddings [0 ],
645- source_text_embeddings [1 ],
646- text_embeddings [1 ],
647- ],
648- dim = 0 ,
649- )
650- concat_noise_pred = self .unet (
651- concat_latent_model_input , t , encoder_hidden_states = concat_text_embeddings
652- ).sample
653-
654- # perform guidance
655- (
656- source_noise_pred_uncond ,
657- noise_pred_uncond ,
658- source_noise_pred_text ,
659- noise_pred_text ,
660- ) = concat_noise_pred .chunk (4 , dim = 0 )
661- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond )
662- source_noise_pred = source_noise_pred_uncond + source_guidance_scale * (
663- source_noise_pred_text - source_noise_pred_uncond
664- )
665-
666- # Sample source_latents from the posterior distribution.
667- prev_source_latents = posterior_sample (
668- self .scheduler , source_latents , t , clean_latents , generator = generator , ** extra_step_kwargs
669- )
670- # Compute noise.
671- noise = compute_noise (
672- self .scheduler , prev_source_latents , source_latents , t , source_noise_pred , ** extra_step_kwargs
673- )
674- source_latents = prev_source_latents
675-
676- # compute the previous noisy sample x_t -> x_t-1
677- latents = self .scheduler .step (
678- noise_pred , t , latents , variance_noise = noise , ** extra_step_kwargs
679- ).prev_sample
624+ num_warmup_steps = len (timesteps ) - num_inference_steps * self .scheduler .order
625+ with self .progress_bar (total = num_inference_steps ) as progress_bar :
626+ for i , t in enumerate (timesteps ):
627+ # expand the latents if we are doing classifier free guidance
628+ latent_model_input = torch .cat ([latents ] * 2 )
629+ source_latent_model_input = torch .cat ([source_latents ] * 2 )
630+ latent_model_input = self .scheduler .scale_model_input (latent_model_input , t )
631+ source_latent_model_input = self .scheduler .scale_model_input (source_latent_model_input , t )
632+
633+ # predict the noise residual
634+ concat_latent_model_input = torch .stack (
635+ [
636+ source_latent_model_input [0 ],
637+ latent_model_input [0 ],
638+ source_latent_model_input [1 ],
639+ latent_model_input [1 ],
640+ ],
641+ dim = 0 ,
642+ )
643+ concat_text_embeddings = torch .stack (
644+ [
645+ source_text_embeddings [0 ],
646+ text_embeddings [0 ],
647+ source_text_embeddings [1 ],
648+ text_embeddings [1 ],
649+ ],
650+ dim = 0 ,
651+ )
652+ concat_noise_pred = self .unet (
653+ concat_latent_model_input , t , encoder_hidden_states = concat_text_embeddings
654+ ).sample
655+
656+ # perform guidance
657+ (
658+ source_noise_pred_uncond ,
659+ noise_pred_uncond ,
660+ source_noise_pred_text ,
661+ noise_pred_text ,
662+ ) = concat_noise_pred .chunk (4 , dim = 0 )
663+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond )
664+ source_noise_pred = source_noise_pred_uncond + source_guidance_scale * (
665+ source_noise_pred_text - source_noise_pred_uncond
666+ )
680667
681- # call the callback, if provided
682- if callback is not None and i % callback_steps == 0 :
683- callback (i , t , latents )
668+ # Sample source_latents from the posterior distribution.
669+ prev_source_latents = posterior_sample (
670+ self .scheduler , source_latents , t , clean_latents , generator = generator , ** extra_step_kwargs
671+ )
672+ # Compute noise.
673+ noise = compute_noise (
674+ self .scheduler , prev_source_latents , source_latents , t , source_noise_pred , ** extra_step_kwargs
675+ )
676+ source_latents = prev_source_latents
677+
678+ # compute the previous noisy sample x_t -> x_t-1
679+ latents = self .scheduler .step (
680+ noise_pred , t , latents , variance_noise = noise , ** extra_step_kwargs
681+ ).prev_sample
682+
683+ # call the callback, if provided
684+ if (i + 1 ) > num_warmup_steps and (i + 1 ) % self .scheduler .order == 0 :
685+ progress_bar .update ()
686+ if callback is not None and i % callback_steps == 0 :
687+ callback (i , t , latents )
684688
685689 # 9. Post-processing
686690 image = self .decode_latents (latents )
0 commit comments