Commit 3543f39

Add 2nd order heun scheduler (huggingface#1336)
* Add heun
* Finish first version of heun
* remove bogus
* finish
* finish
* improve
* up
* up
* fix more
* change progress bar
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
* finish
* up
* up
* up
1 parent 2b7e440 commit 3543f39

29 files changed: +624 −291 lines

__init__.py

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@
     DPMSolverMultistepScheduler,
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
+    HeunDiscreteScheduler,
     IPNDMScheduler,
     KarrasVeScheduler,
     PNDMScheduler,
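For context, this new export is the user-facing entry point for the Heun scheduler added by this commit. Below is a minimal usage sketch; it is not part of the commit, the model id, prompt, and step count are placeholders, and it assumes the usual `from_config` scheduler-swap pattern.

# Illustrative only: swap a pipeline's default scheduler for the newly
# exported HeunDiscreteScheduler. Model id, prompt, and step count are
# placeholders, not values from this commit.
from diffusers import StableDiffusionPipeline, HeunDiscreteScheduler

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.scheduler = HeunDiscreteScheduler.from_config(pipe.scheduler.config)

# Heun is a 2nd-order method (scheduler.order == 2), so it runs roughly two
# UNet evaluations per reported step; the reworked progress bar in the
# pipelines below still counts num_inference_steps, not raw timesteps.
image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]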

pipeline_utils.py

Lines changed: 9 additions & 4 deletions
@@ -129,8 +129,8 @@ def is_safetensors_compatible(info) -> bool:
             sf_filename = os.path.join(prefix, "model.safetensors")
         else:
             sf_filename = pt_filename[: -len(".bin")] + ".safetensors"
-        if sf_filename not in filenames:
-            logger.warning("{sf_filename} not found")
+        if is_safetensors_compatible and sf_filename not in filenames:
+            logger.warning(f"{sf_filename} not found")
             is_safetensors_compatible = False
     return is_safetensors_compatible
 
@@ -767,15 +767,20 @@ def numpy_to_pil(images):
 
         return pil_images
 
-    def progress_bar(self, iterable):
+    def progress_bar(self, iterable=None, total=None):
         if not hasattr(self, "_progress_bar_config"):
             self._progress_bar_config = {}
         elif not isinstance(self._progress_bar_config, dict):
             raise ValueError(
                 f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
             )
 
-        return tqdm(iterable, **self._progress_bar_config)
+        if iterable is not None:
+            return tqdm(iterable, **self._progress_bar_config)
+        elif total is not None:
+            return tqdm(total=total, **self._progress_bar_config)
+        else:
+            raise ValueError("Either `total` or `iterable` has to be defined.")
 
     def set_progress_bar_config(self, **kwargs):
         self._progress_bar_config = kwargs
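The `progress_bar` helper can now be driven either by an iterable (the old behaviour) or by an explicit `total`. The reworked denoising loops below rely on the second mode, because a higher-order scheduler visits more timesteps than the number of steps the user asked for. Here is a self-contained sketch of the two modes using `tqdm` directly rather than a real pipeline; the names are illustrative and the module-level function only mirrors the updated method's logic.

# Standalone sketch of the dual-mode helper; mirrors the updated method.
from tqdm.auto import tqdm

_progress_bar_config = {}  # stands in for what set_progress_bar_config(**kwargs) stores

def progress_bar(iterable=None, total=None):
    if iterable is not None:
        return tqdm(iterable, **_progress_bar_config)
    elif total is not None:
        return tqdm(total=total, **_progress_bar_config)
    raise ValueError("Either `total` or `iterable` has to be defined.")

# 1) Iterable-driven: one tick per element, as before.
for _ in progress_bar(range(10)):
    pass

# 2) Total-driven: the caller decides when a user-visible step has completed
#    and ticks the bar manually, as the reworked denoising loops do.
with progress_bar(total=10) as bar:
    for _ in range(10):
        bar.update()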

pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 23 additions & 19 deletions
@@ -541,25 +541,29 @@ def __call__(
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
         # 7. Denoising loop
-        for i, t in enumerate(self.progress_bar(timesteps)):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-            # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-            # call the callback, if provided
-            if callback is not None and i % callback_steps == 0:
-                callback(i, t, latents)
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                # call the callback, if provided
+                if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)
 
         # 8. Post-processing
         image = self.decode_latents(latents)
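The same denoising-loop rework appears in every pipeline touched by this commit, so it is worth unpacking once here. `self.scheduler.order` is 1 for the existing first-order schedulers and 2 for the new Heun scheduler, which interleaves intermediate timesteps, so `len(timesteps)` can exceed `num_inference_steps`. The `num_warmup_steps` bookkeeping keeps `progress_bar.update()` and the user callback tied to user-visible steps rather than raw scheduler timesteps. A small arithmetic sketch follows; the timestep counts are illustrative, not taken from the diff.

# Illustrative arithmetic only; exact timestep layouts depend on the scheduler.
num_inference_steps = 30

# First-order scheduler: one timestep per step.
order, n_timesteps = 1, 30
num_warmup_steps = n_timesteps - num_inference_steps * order   # -> 0
# (i + 1) % 1 == 0 is always true, so the bar ticks every iteration, as before.

# Second-order scheduler such as Heun: roughly twice as many timesteps.
order, n_timesteps = 2, 60
num_warmup_steps = n_timesteps - num_inference_steps * order   # -> 0
# (i + 1) % 2 == 0 ticks the bar only every second iteration, so the user
# still sees 30 steps, and the callback fires once per visible step rather
# than once per UNet evaluation.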

pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 25 additions & 21 deletions
@@ -433,7 +433,7 @@ def get_timesteps(self, num_inference_steps, strength, device):
         t_start = max(num_inference_steps - init_timestep + offset, 0)
         timesteps = self.scheduler.timesteps[t_start:]
 
-        return timesteps
+        return timesteps, num_inference_steps - t_start
 
     def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
         init_image = init_image.to(device=device, dtype=dtype)
@@ -562,7 +562,7 @@ def __call__(
 
         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps = self.get_timesteps(num_inference_steps, strength, device)
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
 
         # 6. Prepare latent variables
@@ -574,25 +574,29 @@ def __call__(
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
         # 8. Denoising loop
-        for i, t in enumerate(self.progress_bar(timesteps)):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-            # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-            # call the callback, if provided
-            if callback is not None and i % callback_steps == 0:
-                callback(i, t, latents)
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                # call the callback, if provided
+                if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)
 
         # 9. Post-processing
         image = self.decode_latents(latents)
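A note on the `get_timesteps` change in this file (and in `pipeline_cycle_diffusion.py` below): because `strength` truncates the schedule, the number of steps actually executed is `num_inference_steps - t_start`, and the caller now rebinds `num_inference_steps` to that value so the `num_warmup_steps` computation in the loop stays correct. A rough sketch of the arithmetic, assuming the usual strength-based truncation these img2img pipelines used at the time; the `offset` and concrete numbers are illustrative, not from the diff.

# Illustrative numbers; `offset` stands for the scheduler's steps_offset config.
num_inference_steps, strength, offset = 50, 0.3, 0

init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)  # 15
t_start = max(num_inference_steps - init_timestep + offset, 0)                          # 35

# timesteps = self.scheduler.timesteps[t_start:] keeps only the tail of the
# schedule, and get_timesteps now also returns how many steps that tail covers:
effective_num_inference_steps = num_inference_steps - t_start                           # 15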

pipelines/stable_diffusion/pipeline_cycle_diffusion.py

Lines changed: 65 additions & 61 deletions
@@ -475,7 +475,7 @@ def get_timesteps(self, num_inference_steps, strength, device):
         t_start = max(num_inference_steps - init_timestep + offset, 0)
         timesteps = self.scheduler.timesteps[t_start:]
 
-        return timesteps
+        return timesteps, num_inference_steps - t_start
 
     def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
         init_image = init_image.to(device=device, dtype=dtype)
@@ -607,7 +607,7 @@ def __call__(
 
         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps = self.get_timesteps(num_inference_steps, strength, device)
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
 
         # 6. Prepare latent variables
@@ -621,66 +621,70 @@ def __call__(
         generator = extra_step_kwargs.pop("generator", None)
 
         # 8. Denoising loop
-        for i, t in enumerate(self.progress_bar(timesteps)):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat([latents] * 2)
-            source_latent_model_input = torch.cat([source_latents] * 2)
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-            source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t)
-
-            # predict the noise residual
-            concat_latent_model_input = torch.stack(
-                [
-                    source_latent_model_input[0],
-                    latent_model_input[0],
-                    source_latent_model_input[1],
-                    latent_model_input[1],
-                ],
-                dim=0,
-            )
-            concat_text_embeddings = torch.stack(
-                [
-                    source_text_embeddings[0],
-                    text_embeddings[0],
-                    source_text_embeddings[1],
-                    text_embeddings[1],
-                ],
-                dim=0,
-            )
-            concat_noise_pred = self.unet(
-                concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings
-            ).sample
-
-            # perform guidance
-            (
-                source_noise_pred_uncond,
-                noise_pred_uncond,
-                source_noise_pred_text,
-                noise_pred_text,
-            ) = concat_noise_pred.chunk(4, dim=0)
-            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-            source_noise_pred = source_noise_pred_uncond + source_guidance_scale * (
-                source_noise_pred_text - source_noise_pred_uncond
-            )
-
-            # Sample source_latents from the posterior distribution.
-            prev_source_latents = posterior_sample(
-                self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs
-            )
-            # Compute noise.
-            noise = compute_noise(
-                self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs
-            )
-            source_latents = prev_source_latents
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(
-                noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs
-            ).prev_sample
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2)
+                source_latent_model_input = torch.cat([source_latents] * 2)
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+                source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t)
+
+                # predict the noise residual
+                concat_latent_model_input = torch.stack(
+                    [
+                        source_latent_model_input[0],
+                        latent_model_input[0],
+                        source_latent_model_input[1],
+                        latent_model_input[1],
+                    ],
+                    dim=0,
+                )
+                concat_text_embeddings = torch.stack(
+                    [
+                        source_text_embeddings[0],
+                        text_embeddings[0],
+                        source_text_embeddings[1],
+                        text_embeddings[1],
+                    ],
+                    dim=0,
+                )
+                concat_noise_pred = self.unet(
+                    concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings
+                ).sample
+
+                # perform guidance
+                (
+                    source_noise_pred_uncond,
+                    noise_pred_uncond,
+                    source_noise_pred_text,
+                    noise_pred_text,
+                ) = concat_noise_pred.chunk(4, dim=0)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                source_noise_pred = source_noise_pred_uncond + source_guidance_scale * (
+                    source_noise_pred_text - source_noise_pred_uncond
+                )
 
-            # call the callback, if provided
-            if callback is not None and i % callback_steps == 0:
-                callback(i, t, latents)
+                # Sample source_latents from the posterior distribution.
+                prev_source_latents = posterior_sample(
+                    self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs
+                )
+                # Compute noise.
+                noise = compute_noise(
+                    self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs
+                )
+                source_latents = prev_source_latents
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(
+                    noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs
+                ).prev_sample
+
+                # call the callback, if provided
+                if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)
 
         # 9. Post-processing
         image = self.decode_latents(latents)

pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 23 additions & 19 deletions
@@ -540,25 +540,29 @@ def __call__(
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
         # 7. Denoising loop
-        for i, t in enumerate(self.progress_bar(timesteps)):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-            # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-            # call the callback, if provided
-            if callback is not None and i % callback_steps == 0:
-                callback(i, t, latents)
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                # call the callback, if provided
+                if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)
 
         # 8. Post-processing
        image = self.decode_latents(latents)
