@@ -103,8 +103,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
103103 solver_type (`str`, default `dpm_solver`):
104104 the solver type for the second-order solver. Either `dpm_solver` or `taylor`. The solver type slightly
105105 affects the sample quality, especially for small number of steps.
106- denoise_final (`bool`, default `True`):
107- whether to use lower-order solvers in the final steps.
106+ lower_order_final (`bool`, default `True`):
107+ whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
108+ find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
108109
109110 """
110111
@@ -131,7 +132,7 @@ def __init__(
131132 dynamic_thresholding_ratio : float = 0.995 ,
132133 sample_max_value : float = 1.0 ,
133134 solver_type : str = "dpm_solver" ,
134- denoise_final : bool = True ,
135+ lower_order_final : bool = True ,
135136 ):
136137 if trained_betas is not None :
137138 self .betas = torch .from_numpy (trained_betas )
@@ -405,17 +406,21 @@ def step(
405406 else :
406407 step_index = step_index .item ()
407408 prev_timestep = 0 if step_index == len (self .timesteps ) - 1 else self .timesteps [step_index + 1 ]
408- denoise_final = (step_index == len (self .timesteps ) - 1 ) and self .config .denoise_final
409- denoise_second = (step_index == len (self .timesteps ) - 2 ) and self .config .denoise_final
409+ lower_order_final = (
410+ (step_index == len (self .timesteps ) - 1 ) and self .config .lower_order_final and len (self .timesteps ) < 15
411+ )
412+ lower_order_second = (
413+ (step_index == len (self .timesteps ) - 2 ) and self .config .lower_order_final and len (self .timesteps ) < 15
414+ )
410415
411416 model_output = self .convert_model_output (model_output , timestep , sample )
412417 for i in range (self .config .solver_order - 1 ):
413418 self .model_outputs [i ] = self .model_outputs [i + 1 ]
414419 self .model_outputs [- 1 ] = model_output
415420
416- if self .config .solver_order == 1 or self .lower_order_nums < 1 or denoise_final :
421+ if self .config .solver_order == 1 or self .lower_order_nums < 1 or lower_order_final :
417422 prev_sample = self .dpm_solver_first_order_update (model_output , timestep , prev_timestep , sample )
418- elif self .config .solver_order == 2 or self .lower_order_nums < 2 or denoise_second :
423+ elif self .config .solver_order == 2 or self .lower_order_nums < 2 or lower_order_second :
419424 timestep_list = [self .timesteps [step_index - 1 ], timestep ]
420425 prev_sample = self .multistep_dpm_solver_second_order_update (
421426 self .model_outputs , timestep_list , prev_timestep , sample
0 commit comments