2121import torch
2222
2323from ..configuration_utils import ConfigMixin , register_to_config
24+ from ..utils import randn_tensor
2425from .scheduling_utils import KarrasDiffusionSchedulers , SchedulerMixin , SchedulerOutput
2526
2627
@@ -70,6 +71,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
7071 thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
7172 stable-diffusion).
7273
74+ We also support the SDE variants of DPM-Solver and DPM-Solver++, which are fast solvers for the reverse
75+ diffusion SDE. Currently only the first-order and second-order SDE solvers are supported. We recommend using the
76+ second-order `sde-dpmsolver++`.
77+
7378 [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
7479 function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
7580 [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
@@ -103,10 +108,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
103108 the threshold value for dynamic thresholding. Valid only when `thresholding=True` and
104109 `algorithm_type="dpmsolver++"`.
105110 algorithm_type (`str`, default `dpmsolver++`):
106- the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The ` dpmsolver` type implements the
107- algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in
108- https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided
109- sampling (e.g. stable-diffusion).
111+ the algorithm type for the solver. One of `dpmsolver`, `dpmsolver++`, `sde-dpmsolver`, or
112+ `sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and
113+ the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend using
114+ `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion).
110115 solver_type (`str`, default `midpoint`):
111116 the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects
112117 the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are
@@ -180,7 +185,7 @@ def __init__(
180185 self .init_noise_sigma = 1.0
181186
182187 # settings for DPM-Solver
183- if algorithm_type not in ["dpmsolver" , "dpmsolver++" ]:
188+ if algorithm_type not in ["dpmsolver" , "dpmsolver++" , "sde-dpmsolver" , "sde-dpmsolver++" ]:
184189 if algorithm_type == "deis" :
185190 self .register_to_config (algorithm_type = "dpmsolver++" )
186191 else :
@@ -212,7 +217,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
212217 """
213218 # Clipping the minimum of all lambda(t) for numerical stability.
214219 # This is critical for cosine (squaredcos_cap_v2) noise schedule.
215- clipped_idx = torch .searchsorted (torch .flip (self .lambda_t , [0 ]), self .lambda_min_clipped )
220+ clipped_idx = torch .searchsorted (torch .flip (self .lambda_t , [0 ]), self .config . lambda_min_clipped )
216221 timesteps = (
217222 np .linspace (0 , self .config .num_train_timesteps - 1 - clipped_idx , num_inference_steps + 1 )
218223 .round ()[::- 1 ][:- 1 ]
@@ -338,10 +343,10 @@ def convert_model_output(
338343 """
339344
340345 # DPM-Solver++ needs to solve an integral of the data prediction model.
341- if self .config .algorithm_type == "dpmsolver++" :
346+ if self .config .algorithm_type in [ "dpmsolver++" , "sde-dpmsolver++" ] :
342347 if self .config .prediction_type == "epsilon" :
343348 # DPM-Solver and DPM-Solver++ only need the "mean" output.
344- if self .config .variance_type in ["learned_range" ]:
349+ if self .config .variance_type in ["learned" , " learned_range" ]:
345350 model_output = model_output [:, :3 ]
346351 alpha_t , sigma_t = self .alpha_t [timestep ], self .sigma_t [timestep ]
347352 x0_pred = (sample - sigma_t * model_output ) / alpha_t
@@ -360,33 +365,42 @@ def convert_model_output(
360365 x0_pred = self ._threshold_sample (x0_pred )
361366
362367 return x0_pred
368+
363369 # DPM-Solver needs to solve an integral of the noise prediction model.
364- elif self .config .algorithm_type == "dpmsolver" :
370+ elif self .config .algorithm_type in [ "dpmsolver" , "sde-dpmsolver" ] :
365371 if self .config .prediction_type == "epsilon" :
366372 # DPM-Solver and DPM-Solver++ only need the "mean" output.
367- if self .config .variance_type in ["learned_range" ]:
368- model_output = model_output [:, :3 ]
369- return model_output
373+ if self .config .variance_type in ["learned" , "learned_range" ]:
374+ epsilon = model_output [:, :3 ]
375+ else :
376+ epsilon = model_output
370377 elif self .config .prediction_type == "sample" :
371378 alpha_t , sigma_t = self .alpha_t [timestep ], self .sigma_t [timestep ]
372379 epsilon = (sample - alpha_t * model_output ) / sigma_t
373- return epsilon
374380 elif self .config .prediction_type == "v_prediction" :
375381 alpha_t , sigma_t = self .alpha_t [timestep ], self .sigma_t [timestep ]
376382 epsilon = alpha_t * model_output + sigma_t * sample
377- return epsilon
378383 else :
379384 raise ValueError (
380385 f"prediction_type given as { self .config .prediction_type } must be one of `epsilon`, `sample`, or"
381386 " `v_prediction` for the DPMSolverMultistepScheduler."
382387 )
383388
389+ if self .config .thresholding :
390+ alpha_t , sigma_t = self .alpha_t [timestep ], self .sigma_t [timestep ]
391+ x0_pred = (sample - sigma_t * epsilon ) / alpha_t
392+ x0_pred = self ._threshold_sample (x0_pred )
393+ epsilon = (sample - alpha_t * x0_pred ) / sigma_t
394+
395+ return epsilon
396+
384397 def dpm_solver_first_order_update (
385398 self ,
386399 model_output : torch .FloatTensor ,
387400 timestep : int ,
388401 prev_timestep : int ,
389402 sample : torch .FloatTensor ,
403+ noise : Optional [torch .FloatTensor ] = None ,
390404 ) -> torch .FloatTensor :
391405 """
392406 One step for the first-order DPM-Solver (equivalent to DDIM).
@@ -411,6 +425,20 @@ def dpm_solver_first_order_update(
411425 x_t = (sigma_t / sigma_s ) * sample - (alpha_t * (torch .exp (- h ) - 1.0 )) * model_output
412426 elif self .config .algorithm_type == "dpmsolver" :
413427 x_t = (alpha_t / alpha_s ) * sample - (sigma_t * (torch .exp (h ) - 1.0 )) * model_output
428+ elif self .config .algorithm_type == "sde-dpmsolver++" :
429+ assert noise is not None
430+ x_t = (
431+ (sigma_t / sigma_s * torch .exp (- h )) * sample
432+ + (alpha_t * (1 - torch .exp (- 2.0 * h ))) * model_output
433+ + sigma_t * torch .sqrt (1.0 - torch .exp (- 2 * h )) * noise
434+ )
435+ elif self .config .algorithm_type == "sde-dpmsolver" :
436+ assert noise is not None
437+ x_t = (
438+ (alpha_t / alpha_s ) * sample
439+ - 2.0 * (sigma_t * (torch .exp (h ) - 1.0 )) * model_output
440+ + sigma_t * torch .sqrt (torch .exp (2 * h ) - 1.0 ) * noise
441+ )
414442 return x_t
415443
416444 def multistep_dpm_solver_second_order_update (
@@ -419,6 +447,7 @@ def multistep_dpm_solver_second_order_update(
419447 timestep_list : List [int ],
420448 prev_timestep : int ,
421449 sample : torch .FloatTensor ,
450+ noise : Optional [torch .FloatTensor ] = None ,
422451 ) -> torch .FloatTensor :
423452 """
424453 One step for the second-order multistep DPM-Solver.
@@ -470,6 +499,38 @@ def multistep_dpm_solver_second_order_update(
470499 - (sigma_t * (torch .exp (h ) - 1.0 )) * D0
471500 - (sigma_t * ((torch .exp (h ) - 1.0 ) / h - 1.0 )) * D1
472501 )
502+ elif self .config .algorithm_type == "sde-dpmsolver++" :
503+ assert noise is not None
504+ if self .config .solver_type == "midpoint" :
505+ x_t = (
506+ (sigma_t / sigma_s0 * torch .exp (- h )) * sample
507+ + (alpha_t * (1 - torch .exp (- 2.0 * h ))) * D0
508+ + 0.5 * (alpha_t * (1 - torch .exp (- 2.0 * h ))) * D1
509+ + sigma_t * torch .sqrt (1.0 - torch .exp (- 2 * h )) * noise
510+ )
511+ elif self .config .solver_type == "heun" :
512+ x_t = (
513+ (sigma_t / sigma_s0 * torch .exp (- h )) * sample
514+ + (alpha_t * (1 - torch .exp (- 2.0 * h ))) * D0
515+ + (alpha_t * ((1.0 - torch .exp (- 2.0 * h )) / (- 2.0 * h ) + 1.0 )) * D1
516+ + sigma_t * torch .sqrt (1.0 - torch .exp (- 2 * h )) * noise
517+ )
518+ elif self .config .algorithm_type == "sde-dpmsolver" :
519+ assert noise is not None
520+ if self .config .solver_type == "midpoint" :
521+ x_t = (
522+ (alpha_t / alpha_s0 ) * sample
523+ - 2.0 * (sigma_t * (torch .exp (h ) - 1.0 )) * D0
524+ - (sigma_t * (torch .exp (h ) - 1.0 )) * D1
525+ + sigma_t * torch .sqrt (torch .exp (2 * h ) - 1.0 ) * noise
526+ )
527+ elif self .config .solver_type == "heun" :
528+ x_t = (
529+ (alpha_t / alpha_s0 ) * sample
530+ - 2.0 * (sigma_t * (torch .exp (h ) - 1.0 )) * D0
531+ - 2.0 * (sigma_t * ((torch .exp (h ) - 1.0 ) / h - 1.0 )) * D1
532+ + sigma_t * torch .sqrt (torch .exp (2 * h ) - 1.0 ) * noise
533+ )
473534 return x_t
474535
475536 def multistep_dpm_solver_third_order_update (
@@ -532,6 +593,7 @@ def step(
532593 model_output : torch .FloatTensor ,
533594 timestep : int ,
534595 sample : torch .FloatTensor ,
596+ generator = None ,
535597 return_dict : bool = True ,
536598 ) -> Union [SchedulerOutput , Tuple ]:
537599 """
@@ -574,12 +636,21 @@ def step(
574636 self .model_outputs [i ] = self .model_outputs [i + 1 ]
575637 self .model_outputs [- 1 ] = model_output
576638
639+ if self .config .algorithm_type in ["sde-dpmsolver" , "sde-dpmsolver++" ]:
640+ noise = randn_tensor (
641+ model_output .shape , generator = generator , device = model_output .device , dtype = model_output .dtype
642+ )
643+ else :
644+ noise = None
645+
577646 if self .config .solver_order == 1 or self .lower_order_nums < 1 or lower_order_final :
578- prev_sample = self .dpm_solver_first_order_update (model_output , timestep , prev_timestep , sample )
647+ prev_sample = self .dpm_solver_first_order_update (
648+ model_output , timestep , prev_timestep , sample , noise = noise
649+ )
579650 elif self .config .solver_order == 2 or self .lower_order_nums < 2 or lower_order_second :
580651 timestep_list = [self .timesteps [step_index - 1 ], timestep ]
581652 prev_sample = self .multistep_dpm_solver_second_order_update (
582- self .model_outputs , timestep_list , prev_timestep , sample
653+ self .model_outputs , timestep_list , prev_timestep , sample , noise = noise
583654 )
584655 else :
585656 timestep_list = [self .timesteps [step_index - 2 ], self .timesteps [step_index - 1 ], timestep ]
0 commit comments