@@ -55,15 +55,24 @@ def alpha_bar(time_step):
5555
5656class DPMSolverDiscreteScheduler (SchedulerMixin , ConfigMixin ):
5757 """
58- DPM-Solver.
58+ DPM-Solver (and the improved version DPM-Solver++) is a fast, dedicated high-order solver for diffusion ODEs with
59+ a convergence order guarantee. Empirically, sampling with DPM-Solver in only 20 steps can generate high-quality
60+ samples, and it can generate quite good samples in as few as 10 steps.
61+
62+ For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
63+
64+ Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We
65+ recommend using `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling.
66+
67+ We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
68+ diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the dynamic thresholding. Note
69+ that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion).
5970
6071 [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
6172 function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
6273 [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
6374 [`~ConfigMixin.from_config`] functions.
6475
65- For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
66-
6776 Args:
6877 num_train_timesteps (`int`): number of diffusion steps used to train the model.
6978 beta_start (`float`): the starting `beta` value of inference.
@@ -73,17 +82,26 @@ class DPMSolverDiscreteScheduler(SchedulerMixin, ConfigMixin):
7382 `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
7483 trained_betas (`np.ndarray`, optional):
7584 option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
76- skip_prk_steps (`bool`):
77- allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
78- before plms steps; defaults to `False`.
79- set_alpha_to_one (`bool`, default `False`):
80- each diffusion step uses the value of alphas product at that step and at the previous one. For the final
81- step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
82- otherwise it uses the value of alpha at step 0.
83- steps_offset (`int`, default `0`):
84- an offset added to the inference steps. You can use a combination of `offset=1` and
85- `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
86- stable diffusion.
85+ solver_order (`int`, default `2`):
86+ the order of DPM-Solver; can be `1` or `2` or `3`. We recommend using `solver_order=2` for guided
87+ sampling, and `solver_order=3` for unconditional sampling.
88+ predict_x0 (`bool`, default `True`):
89+ DPM-Solver is designed for both the noise prediction model (DPM-Solver, https://arxiv.org/abs/2206.00927)
90+ with `predict_x0=False` and the data prediction model (DPM-Solver++, https://arxiv.org/abs/2211.01095) with
91+ `predict_x0=True`. We recommend using `predict_x0=True` and `solver_order=2` for guided sampling (e.g.
92+ stable-diffusion).
93+ thresholding (`bool`, default `False`):
94+ whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
95+ For pixel-space diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the
96+ dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models
97+ (such as stable-diffusion).
98+ sample_max_value (`float`, default `1.0`):
99+ the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
100+ solver_type (`str`, default `dpm_solver`):
101+ the solver type for the second-order solver. Either `dpm_solver` or `taylor`. The solver type slightly
102+ affects the sample quality, especially for a small number of steps.
103+ denoise_final (`bool`, default `False`):
104+ whether to use lower-order solvers in the final steps.
87105
88106 """
89107
@@ -183,7 +201,16 @@ def convert_model_output(
183201 self , model_output : torch .FloatTensor , timestep : int , sample : torch .FloatTensor
184202 ) -> torch .FloatTensor :
185203 """
186- TODO
204+ Convert the output of the noise prediction model into either a noise prediction or a data prediction, depending on `predict_x0`.
205+
206+ Args:
207+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
208+ timestep (`int`): current discrete timestep in the diffusion chain.
209+ sample (`torch.FloatTensor`):
210+ current instance of sample being created by diffusion process.
211+
212+ Returns:
213+ `torch.FloatTensor`: the converted model output.
187214 """
188215 if self .predict_x0 :
189216 alpha_t , sigma_t = self .alpha_t [timestep ], self .sigma_t [timestep ]
@@ -208,7 +235,17 @@ def dpm_solver_first_order_update(
208235 sample : torch .FloatTensor ,
209236 ) -> torch .FloatTensor :
210237 """
211- TODO
238+ One step for the first-order DPM-Solver (equivalent to DDIM).
239+
240+ Args:
241+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
242+ timestep (`int`): current discrete timestep in the diffusion chain.
243+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
244+ sample (`torch.FloatTensor`):
245+ current instance of sample being created by diffusion process.
246+
247+ Returns:
248+ `torch.FloatTensor`: the sample tensor at the previous timestep.
212249 """
213250 lambda_t , lambda_s = self .lambda_t [prev_timestep ], self .lambda_t [timestep ]
214251 alpha_t , alpha_s = self .alpha_t [prev_timestep ], self .alpha_t [timestep ]
@@ -228,7 +265,18 @@ def multistep_dpm_solver_second_order_update(
228265 sample : torch .FloatTensor ,
229266 ) -> torch .FloatTensor :
230267 """
231- TODO
268+ One step for the second-order multistep DPM-Solver.
269+
270+ Args:
271+ model_output_list (`List[torch.FloatTensor]`):
272+ direct outputs from learned diffusion model at current and latter timesteps.
273+ timestep_list (`List[int]`): current and latter discrete timesteps in the diffusion chain.
274+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
275+ sample (`torch.FloatTensor`):
276+ current instance of sample being created by diffusion process.
277+
278+ Returns:
279+ `torch.FloatTensor`: the sample tensor at the previous timestep.
232280 """
233281 t , s0 , s1 = prev_timestep , timestep_list [- 1 ], timestep_list [- 2 ]
234282 m0 , m1 = model_output_list [- 1 ], model_output_list [- 2 ]
@@ -274,7 +322,18 @@ def multistep_dpm_solver_third_order_update(
274322 sample : torch .FloatTensor ,
275323 ) -> torch .FloatTensor :
276324 """
277- TODO
325+ One step for the third-order multistep DPM-Solver.
326+
327+ Args:
328+ model_output_list (`List[torch.FloatTensor]`):
329+ direct outputs from learned diffusion model at current and latter timesteps.
330+ timestep_list (`List[int]`): current and latter discrete timesteps in the diffusion chain.
331+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
332+ sample (`torch.FloatTensor`):
333+ current instance of sample being created by diffusion process.
334+
335+ Returns:
336+ `torch.FloatTensor`: the sample tensor at the previous timestep.
278337 """
279338 t , s0 , s1 , s2 = prev_timestep , timestep_list [- 1 ], timestep_list [- 2 ], timestep_list [- 3 ]
280339 m0 , m1 , m2 = model_output_list [- 1 ], model_output_list [- 2 ], model_output_list [- 3 ]
@@ -316,8 +375,7 @@ def step(
316375 return_dict : bool = True ,
317376 ) -> Union [SchedulerOutput , Tuple ]:
318377 """
319- Step function propagating the sample with the multistep DPM-Solver. This has one forward pass with multiple
320- times to approximate the solution.
378+ Step function propagating the sample with the multistep DPM-Solver.
321379
322380 Args:
323381 model_output (`torch.FloatTensor`): direct output from learned diffusion model.
0 commit comments