Skip to content

Commit 769f0be

Browse files
Finalize 2nd order schedulers (#1503)
* up * up * finish * finish * up * up * finish
1 parent 4f59659 commit 769f0be

26 files changed

+1020
-36
lines changed

docs/source/api/schedulers.mdx

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,33 @@ Original paper can be found [here](https://arxiv.org/abs/2206.00927) and the [im
7676

7777
[[autodoc]] DPMSolverMultistepScheduler
7878

79+
#### Heun scheduler inspired by the Karras et al. paper
80+
81+
Algorithm 1 of [Karras et al.](https://arxiv.org/abs/2206.00364).
82+
Scheduler ported from @crowsonkb's [k-diffusion](https://github.com/crowsonkb/k-diffusion) library.
83+
84+
All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/).
85+
86+
[[autodoc]] HeunDiscreteScheduler
87+
88+
#### DPM Discrete Scheduler inspired by the Karras et al. paper
89+
90+
Inspired by [Karras et al.](https://arxiv.org/abs/2206.00364).
91+
Scheduler ported from @crowsonkb's [k-diffusion](https://github.com/crowsonkb/k-diffusion) library.
92+
93+
All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/).
94+
95+
[[autodoc]] KDPM2DiscreteScheduler
96+
97+
#### DPM Discrete Scheduler with ancestral sampling inspired by the Karras et al. paper
98+
99+
Inspired by [Karras et al.](https://arxiv.org/abs/2206.00364).
100+
Scheduler ported from @crowsonkb's [k-diffusion](https://github.com/crowsonkb/k-diffusion) library.
101+
102+
All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/).
103+
104+
[[autodoc]] KDPM2AncestralDiscreteScheduler
105+
79106
#### Variance exploding, stochastic sampling from Karras et al.
80107

81108
Original paper can be found [here](https://arxiv.org/abs/2006.11239).
@@ -86,7 +113,6 @@ Original paper can be found [here](https://arxiv.org/abs/2006.11239).
86113

87114
Original implementation can be found [here](https://arxiv.org/abs/2206.00364).
88115

89-
90116
[[autodoc]] LMSDiscreteScheduler
91117

92118
#### Pseudo numerical methods for diffusion models (PNDM)

examples/community/sd_text2img_k_diffusion.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from diffusers.pipeline_utils import DiffusionPipeline
2222
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
2323
from diffusers.utils import is_accelerate_available, logging
24-
from k_diffusion.external import CompVisDenoiser
24+
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
2525

2626

2727
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -33,7 +33,12 @@ def __init__(self, model, alphas_cumprod):
3333
self.alphas_cumprod = alphas_cumprod
3434

3535
def apply_model(self, *args, **kwargs):
36-
return self.model(*args, **kwargs).sample
36+
if len(args) == 3:
37+
encoder_hidden_states = args[-1]
38+
args = args[:2]
39+
if kwargs.get("cond", None) is not None:
40+
encoder_hidden_states = kwargs.pop("cond")
41+
return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample
3742

3843

3944
class StableDiffusionPipeline(DiffusionPipeline):
@@ -63,6 +68,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
6368
feature_extractor ([`CLIPFeatureExtractor`]):
6469
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
6570
"""
71+
_optional_components = ["safety_checker", "feature_extractor"]
6672

6773
def __init__(
6874
self,
@@ -99,7 +105,10 @@ def __init__(
99105
)
100106

101107
model = ModelWrapper(unet, scheduler.alphas_cumprod)
102-
self.k_diffusion_model = CompVisDenoiser(model)
108+
if scheduler.prediction_type == "v_prediction":
109+
self.k_diffusion_model = CompVisVDenoiser(model)
110+
else:
111+
self.k_diffusion_model = CompVisDenoiser(model)
103112

104113
def set_sampler(self, scheduler_type: str):
105114
library = importlib.import_module("k_diffusion")
@@ -417,6 +426,7 @@ def __call__(
417426
# 4. Prepare timesteps
418427
self.scheduler.set_timesteps(num_inference_steps, device=text_embeddings.device)
419428
sigmas = self.scheduler.sigmas
429+
sigmas = sigmas.to(text_embeddings.dtype)
420430

421431
# 5. Prepare latent variables
422432
num_channels_latents = self.unet.in_channels
@@ -437,7 +447,7 @@ def __call__(
437447
def model_fn(x, t):
438448
latent_model_input = torch.cat([x] * 2)
439449

440-
noise_pred = self.k_diffusion_model(latent_model_input, t, encoder_hidden_states=text_embeddings)
450+
noise_pred = self.k_diffusion_model(latent_model_input, t, cond=text_embeddings)
441451

442452
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
443453
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
HeunDiscreteScheduler,
5050
IPNDMScheduler,
5151
KarrasVeScheduler,
52+
KDPM2AncestralDiscreteScheduler,
53+
KDPM2DiscreteScheduler,
5254
PNDMScheduler,
5355
RePaintScheduler,
5456
SchedulerMixin,

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ def __call__(
558558
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
559559

560560
# call the callback, if provided
561-
if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
561+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
562562
progress_bar.update()
563563
if callback is not None and i % callback_steps == 0:
564564
callback(i, t, latents)

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,7 @@ def __call__(
580580
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
581581

582582
# call the callback, if provided
583-
if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
583+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
584584
progress_bar.update()
585585
if callback is not None and i % callback_steps == 0:
586586
callback(i, t, latents)

src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,7 @@ def __call__(
666666
).prev_sample
667667

668668
# call the callback, if provided
669-
if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
669+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
670670
progress_bar.update()
671671
if callback is not None and i % callback_steps == 0:
672672
callback(i, t, latents)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,7 @@ def __call__(
557557
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
558558

559559
# call the callback, if provided
560-
if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
560+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
561561
progress_bar.update()
562562
if callback is not None and i % callback_steps == 0:
563563
callback(i, t, latents)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ def __call__(
440440
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
441441

442442
# call the callback, if provided
443-
if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
443+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
444444
progress_bar.update()
445445
if callback is not None and i % callback_steps == 0:
446446
callback(i, t, latents)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,7 @@ def __call__(
587587
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
588588

589589
# call the callback, if provided
590-
if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
590+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
591591
progress_bar.update()
592592
if callback is not None and i % callback_steps == 0:
593593
callback(i, t, latents)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -701,7 +701,7 @@ def __call__(
701701
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
702702

703703
# call the callback, if provided
704-
if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
704+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
705705
progress_bar.update()
706706
if callback is not None and i % callback_steps == 0:
707707
callback(i, t, latents)

0 commit comments

Comments
 (0)