2121import torch
2222
2323from ..configuration_utils import ConfigMixin , register_to_config
24+ from ..utils import randn_tensor
2425from .scheduling_utils import KarrasDiffusionSchedulers , SchedulerMixin , SchedulerOutput
2526
2627
@@ -96,10 +97,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
9697 the threshold value for dynamic thresholding. Valid only when `thresholding=True` and
9798 `algorithm_type="dpmsolver++`.
9899 algorithm_type (`str`, default `dpmsolver++`):
99- the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The ` dpmsolver` type implements the
100- algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in
101- https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided
102- sampling (e.g. stable-diffusion).
100+ the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++` or `sde- dpmsolver` or
101+ `sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and
102+ the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend to use
103+ `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion).
103104 solver_type (`str`, default `midpoint`):
104105 the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects
105106 the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are
@@ -173,7 +174,7 @@ def __init__(
173174 self .init_noise_sigma = 1.0
174175
175176 # settings for DPM-Solver
176- if algorithm_type not in ["dpmsolver" , "dpmsolver++" ]:
177+ if algorithm_type not in ["dpmsolver" , "dpmsolver++" , "sde-dpmsolver" , "sde-dpmsolver++" ]:
177178 if algorithm_type == "deis" :
178179 self .register_to_config (algorithm_type = "dpmsolver++" )
179180 else :
@@ -380,6 +381,7 @@ def dpm_solver_first_order_update(
380381 timestep : int ,
381382 prev_timestep : int ,
382383 sample : torch .FloatTensor ,
384+ noise : Optional [torch .FloatTensor ] = None ,
383385 ) -> torch .FloatTensor :
384386 """
385387 One step for the first-order DPM-Solver (equivalent to DDIM).
@@ -404,6 +406,20 @@ def dpm_solver_first_order_update(
404406 x_t = (sigma_t / sigma_s ) * sample - (alpha_t * (torch .exp (- h ) - 1.0 )) * model_output
405407 elif self .config .algorithm_type == "dpmsolver" :
406408 x_t = (alpha_t / alpha_s ) * sample - (sigma_t * (torch .exp (h ) - 1.0 )) * model_output
409+ elif self .config .algorithm_type == "sde-dpmsolver++" :
410+ assert noise is not None
411+ x_t = (
412+ (sigma_t / sigma_s * torch .exp (- h )) * sample
413+ + (alpha_t * (1 - torch .exp (- 2.0 * h ))) * model_output
414+ + sigma_t * torch .sqrt (1.0 - torch .exp (- 2 * h )) * noise
415+ )
416+ elif self .config .algorithm_type == "sde-dpmsolver" :
417+ assert noise is not None
418+ x_t = (
419+ (alpha_t / alpha_s ) * sample
420+ - 2.0 * (sigma_t * (torch .exp (h ) - 1.0 )) * model_output
421+ + sigma_t * torch .sqrt (torch .exp (2 * h ) - 1.0 ) * noise
422+ )
407423 return x_t
408424
409425 # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update
@@ -413,6 +429,7 @@ def multistep_dpm_solver_second_order_update(
413429 timestep_list : List [int ],
414430 prev_timestep : int ,
415431 sample : torch .FloatTensor ,
432+ noise : Optional [torch .FloatTensor ] = None ,
416433 ) -> torch .FloatTensor :
417434 """
418435 One step for the second-order multistep DPM-Solver.
@@ -464,6 +481,38 @@ def multistep_dpm_solver_second_order_update(
464481 - (sigma_t * (torch .exp (h ) - 1.0 )) * D0
465482 - (sigma_t * ((torch .exp (h ) - 1.0 ) / h - 1.0 )) * D1
466483 )
484+ elif self .config .algorithm_type == "sde-dpmsolver++" :
485+ assert noise is not None
486+ if self .config .solver_type == "midpoint" :
487+ x_t = (
488+ (sigma_t / sigma_s0 * torch .exp (- h )) * sample
489+ + (alpha_t * (1 - torch .exp (- 2.0 * h ))) * D0
490+ + 0.5 * (alpha_t * (1 - torch .exp (- 2.0 * h ))) * D1
491+ + sigma_t * torch .sqrt (1.0 - torch .exp (- 2 * h )) * noise
492+ )
493+ elif self .config .solver_type == "heun" :
494+ x_t = (
495+ (sigma_t / sigma_s0 * torch .exp (- h )) * sample
496+ + (alpha_t * (1 - torch .exp (- 2.0 * h ))) * D0
497+ + (alpha_t * ((1.0 - torch .exp (- 2.0 * h )) / (- 2.0 * h ) + 1.0 )) * D1
498+ + sigma_t * torch .sqrt (1.0 - torch .exp (- 2 * h )) * noise
499+ )
500+ elif self .config .algorithm_type == "sde-dpmsolver" :
501+ assert noise is not None
502+ if self .config .solver_type == "midpoint" :
503+ x_t = (
504+ (alpha_t / alpha_s0 ) * sample
505+ - 2.0 * (sigma_t * (torch .exp (h ) - 1.0 )) * D0
506+ - (sigma_t * (torch .exp (h ) - 1.0 )) * D1
507+ + sigma_t * torch .sqrt (torch .exp (2 * h ) - 1.0 ) * noise
508+ )
509+ elif self .config .solver_type == "heun" :
510+ x_t = (
511+ (alpha_t / alpha_s0 ) * sample
512+ - 2.0 * (sigma_t * (torch .exp (h ) - 1.0 )) * D0
513+ - 2.0 * (sigma_t * ((torch .exp (h ) - 1.0 ) / h - 1.0 )) * D1
514+ + sigma_t * torch .sqrt (torch .exp (2 * h ) - 1.0 ) * noise
515+ )
467516 return x_t
468517
469518 # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update
@@ -527,6 +576,7 @@ def step(
527576 model_output : torch .FloatTensor ,
528577 timestep : int ,
529578 sample : torch .FloatTensor ,
579+ generator = None ,
530580 return_dict : bool = True ,
531581 ) -> Union [SchedulerOutput , Tuple ]:
532582 """
@@ -571,12 +621,21 @@ def step(
571621 self .model_outputs [i ] = self .model_outputs [i + 1 ]
572622 self .model_outputs [- 1 ] = model_output
573623
624+ if self .config .algorithm_type in ["sde-dpmsolver" , "sde-dpmsolver++" ]:
625+ noise = randn_tensor (
626+ model_output .shape , generator = generator , device = model_output .device , dtype = model_output .dtype
627+ )
628+ else :
629+ noise = None
630+
574631 if self .config .solver_order == 1 or self .lower_order_nums < 1 or lower_order_final :
575- prev_sample = self .dpm_solver_first_order_update (model_output , timestep , prev_timestep , sample )
632+ prev_sample = self .dpm_solver_first_order_update (
633+ model_output , timestep , prev_timestep , sample , noise = noise
634+ )
576635 elif self .config .solver_order == 2 or self .lower_order_nums < 2 or lower_order_second :
577636 timestep_list = [self .timesteps [step_index - 1 ], timestep ]
578637 prev_sample = self .multistep_dpm_solver_second_order_update (
579- self .model_outputs , timestep_list , prev_timestep , sample
638+ self .model_outputs , timestep_list , prev_timestep , sample , noise = noise
580639 )
581640 else :
582641 timestep_list = [self .timesteps [step_index - 2 ], self .timesteps [step_index - 1 ], timestep ]
0 commit comments