Skip to content

Commit 2dd17e7

Browse files
committed
Add inverse sde-dpmsolver steps to tune image diversity from inverted latents
1 parent 1ba5739 commit 2dd17e7

File tree

1 file changed

+66
-7
lines changed

1 file changed

+66
-7
lines changed

src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import torch
2222

2323
from ..configuration_utils import ConfigMixin, register_to_config
24+
from ..utils import randn_tensor
2425
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
2526

2627

@@ -96,10 +97,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
9697
the threshold value for dynamic thresholding. Valid only when `thresholding=True` and
9798
`algorithm_type="dpmsolver++`.
9899
algorithm_type (`str`, default `dpmsolver++`):
99-
the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the
100-
algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in
101-
https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided
102-
sampling (e.g. stable-diffusion).
100+
the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++` or `sde-dpmsolver` or
101+
`sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and
102+
the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend to use
103+
`dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion).
103104
solver_type (`str`, default `midpoint`):
104105
the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects
105106
the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are
@@ -173,7 +174,7 @@ def __init__(
173174
self.init_noise_sigma = 1.0
174175

175176
# settings for DPM-Solver
176-
if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
177+
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
177178
if algorithm_type == "deis":
178179
self.register_to_config(algorithm_type="dpmsolver++")
179180
else:
@@ -380,6 +381,7 @@ def dpm_solver_first_order_update(
380381
timestep: int,
381382
prev_timestep: int,
382383
sample: torch.FloatTensor,
384+
noise: Optional[torch.FloatTensor] = None,
383385
) -> torch.FloatTensor:
384386
"""
385387
One step for the first-order DPM-Solver (equivalent to DDIM).
@@ -404,6 +406,20 @@ def dpm_solver_first_order_update(
404406
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
405407
elif self.config.algorithm_type == "dpmsolver":
406408
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
409+
elif self.config.algorithm_type == "sde-dpmsolver++":
410+
assert noise is not None
411+
x_t = (
412+
(sigma_t / sigma_s * torch.exp(-h)) * sample
413+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
414+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
415+
)
416+
elif self.config.algorithm_type == "sde-dpmsolver":
417+
assert noise is not None
418+
x_t = (
419+
(alpha_t / alpha_s) * sample
420+
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
421+
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
422+
)
407423
return x_t
408424

409425
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update
@@ -413,6 +429,7 @@ def multistep_dpm_solver_second_order_update(
413429
timestep_list: List[int],
414430
prev_timestep: int,
415431
sample: torch.FloatTensor,
432+
noise: Optional[torch.FloatTensor] = None,
416433
) -> torch.FloatTensor:
417434
"""
418435
One step for the second-order multistep DPM-Solver.
@@ -464,6 +481,38 @@ def multistep_dpm_solver_second_order_update(
464481
- (sigma_t * (torch.exp(h) - 1.0)) * D0
465482
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
466483
)
484+
elif self.config.algorithm_type == "sde-dpmsolver++":
485+
assert noise is not None
486+
if self.config.solver_type == "midpoint":
487+
x_t = (
488+
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
489+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
490+
+ 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
491+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
492+
)
493+
elif self.config.solver_type == "heun":
494+
x_t = (
495+
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
496+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
497+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
498+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
499+
)
500+
elif self.config.algorithm_type == "sde-dpmsolver":
501+
assert noise is not None
502+
if self.config.solver_type == "midpoint":
503+
x_t = (
504+
(alpha_t / alpha_s0) * sample
505+
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
506+
- (sigma_t * (torch.exp(h) - 1.0)) * D1
507+
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
508+
)
509+
elif self.config.solver_type == "heun":
510+
x_t = (
511+
(alpha_t / alpha_s0) * sample
512+
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
513+
- 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
514+
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
515+
)
467516
return x_t
468517

469518
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update
@@ -527,6 +576,7 @@ def step(
527576
model_output: torch.FloatTensor,
528577
timestep: int,
529578
sample: torch.FloatTensor,
579+
generator=None,
530580
return_dict: bool = True,
531581
) -> Union[SchedulerOutput, Tuple]:
532582
"""
@@ -571,12 +621,21 @@ def step(
571621
self.model_outputs[i] = self.model_outputs[i + 1]
572622
self.model_outputs[-1] = model_output
573623

624+
if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
625+
noise = randn_tensor(
626+
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
627+
)
628+
else:
629+
noise = None
630+
574631
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
575-
prev_sample = self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample)
632+
prev_sample = self.dpm_solver_first_order_update(
633+
model_output, timestep, prev_timestep, sample, noise=noise
634+
)
576635
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
577636
timestep_list = [self.timesteps[step_index - 1], timestep]
578637
prev_sample = self.multistep_dpm_solver_second_order_update(
579-
self.model_outputs, timestep_list, prev_timestep, sample
638+
self.model_outputs, timestep_list, prev_timestep, sample, noise=noise
580639
)
581640
else:
582641
timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]

0 commit comments

Comments
 (0)