@@ -118,6 +118,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
118118 This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
119119 noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
120120 of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
121+ lambda_min_clipped (`float`, default `-inf`):
122+ the clipping threshold for the minimum value of lambda(t), for numerical stability. This is critical for
123+ the cosine (`squaredcos_cap_v2`) noise schedule.
124+ variance_type (`str`, *optional*):
125+ Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's
126+ guided-diffusion (https://github.com/openai/guided-diffusion) predicts both the mean and the variance of the
127+ Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on
128+ diffusion ODEs.
121132 """
122133
123134 _compatibles = [e .name for e in KarrasDiffusionSchedulers ]
@@ -140,6 +151,8 @@ def __init__(
140151 solver_type : str = "midpoint" ,
141152 lower_order_final : bool = True ,
142153 use_karras_sigmas : Optional [bool ] = False ,
154+ lambda_min_clipped : float = - float ("inf" ),
155+ variance_type : Optional [str ] = None ,
143156 ):
144157 if trained_betas is not None :
145158 self .betas = torch .tensor (trained_betas , dtype = torch .float32 )
@@ -187,7 +200,7 @@ def __init__(
187200 self .lower_order_nums = 0
188201 self .use_karras_sigmas = use_karras_sigmas
189202
190- def set_timesteps (self , num_inference_steps : int , device : Union [str , torch .device ] = None ):
203+ def set_timesteps (self , num_inference_steps : int = None , device : Union [str , torch .device ] = None ):
191204 """
192205 Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
193206
@@ -197,8 +210,11 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
197210 device (`str` or `torch.device`, optional):
198211 the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
199212 """
213+ # Clipping the minimum of all lambda(t) for numerical stability.
214+ # This is critical for the cosine (squaredcos_cap_v2) noise schedule.
215+ clipped_idx = torch .searchsorted (torch .flip (self .lambda_t , [0 ]), self .lambda_min_clipped )
200216 timesteps = (
201- np .linspace (0 , self .config .num_train_timesteps - 1 , num_inference_steps + 1 )
217+ np .linspace (0 , self .config .num_train_timesteps - 1 - clipped_idx , num_inference_steps + 1 )
202218 .round ()[::- 1 ][:- 1 ]
203219 .copy ()
204220 .astype (np .int64 )
@@ -320,9 +336,13 @@ def convert_model_output(
320336 Returns:
321337 `torch.FloatTensor`: the converted model output.
322338 """
339+
323340 # DPM-Solver++ needs to solve an integral of the data prediction model.
324341 if self .config .algorithm_type == "dpmsolver++" :
325342 if self .config .prediction_type == "epsilon" :
343+ # DPM-Solver and DPM-Solver++ only need the "mean" output.
344+ if self .config .variance_type in ["learned_range" ]:
345+ model_output = model_output [:, :3 ]
326346 alpha_t , sigma_t = self .alpha_t [timestep ], self .sigma_t [timestep ]
327347 x0_pred = (sample - sigma_t * model_output ) / alpha_t
328348 elif self .config .prediction_type == "sample" :
@@ -343,6 +363,9 @@ def convert_model_output(
343363 # DPM-Solver needs to solve an integral of the noise prediction model.
344364 elif self .config .algorithm_type == "dpmsolver" :
345365 if self .config .prediction_type == "epsilon" :
366+ # DPM-Solver and DPM-Solver++ only need the "mean" output.
367+ if self .config .variance_type in ["learned_range" ]:
368+ model_output = model_output [:, :3 ]
346369 return model_output
347370 elif self .config .prediction_type == "sample" :
348371 alpha_t , sigma_t = self .alpha_t [timestep ], self .sigma_t [timestep ]
0 commit comments