Commit b868bd6

committed
add docs
1 parent 1196092 commit b868bd6

2 files changed

+155
-36
lines changed

src/diffusers/schedulers/scheduling_dpmsolver_discrete.py

Lines changed: 78 additions & 20 deletions
@@ -55,15 +55,24 @@ def alpha_bar(time_step):
5555

5656
class DPMSolverDiscreteScheduler(SchedulerMixin, ConfigMixin):
5757
"""
58-
DPM-Solver.
58+
DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with
59+
a convergence order guarantee. Empirically, sampling with DPM-Solver in only 20 steps can generate high-quality
60+
samples, and it can generate quite good samples even in only 10 steps.
61+
62+
For more details, see the original papers: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
63+
64+
Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We
65+
recommend `solver_order=2` for guided sampling and `solver_order=3` for unconditional sampling.
66+
67+
We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
68+
diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the dynamic thresholding. Note
69+
that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion).
5970
6071
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
6172
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
6273
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
6374
[`~ConfigMixin.from_config`] functions.
6475
65-
For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
66-
6776
Args:
6877
num_train_timesteps (`int`): number of diffusion steps used to train the model.
6978
beta_start (`float`): the starting `beta` value of inference.
@@ -73,17 +82,26 @@ class DPMSolverDiscreteScheduler(SchedulerMixin, ConfigMixin):
7382
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
7483
trained_betas (`np.ndarray`, optional):
7584
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
76-
skip_prk_steps (`bool`):
77-
allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
78-
before plms steps; defaults to `False`.
79-
set_alpha_to_one (`bool`, default `False`):
80-
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
81-
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
82-
otherwise it uses the value of alpha at step 0.
83-
steps_offset (`int`, default `0`):
84-
an offset added to the inference steps. You can use a combination of `offset=1` and
85-
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
86-
stable diffusion.
85+
solver_order (`int`, default `2`):
86+
the order of DPM-Solver; can be `1`, `2`, or `3`. We recommend `solver_order=2` for guided
87+
sampling, and `solver_order=3` for unconditional sampling.
88+
predict_x0 (`bool`, default `True`):
89+
DPM-Solver is designed for both the noise prediction model (DPM-Solver, https://arxiv.org/abs/2206.00927)
90+
with `predict_x0=False` and the data prediction model (DPM-Solver++, https://arxiv.org/abs/2211.01095) with
91+
`predict_x0=True`. We recommend `predict_x0=True` and `solver_order=2` for guided sampling (e.g.
92+
stable-diffusion).
93+
thresholding (`bool`, default `False`):
94+
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
95+
For pixel-space diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the
96+
dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models
97+
(such as stable-diffusion).
98+
sample_max_value (`float`, default `1.0`):
99+
the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
100+
solver_type (`str`, default `dpm_solver`):
101+
the solver type for the second-order solver. Either `dpm_solver` or `taylor`. The solver type slightly
102+
affects the sample quality, especially for a small number of steps.
103+
denoise_final (`bool`, default `False`):
104+
whether to use lower-order solvers in the final steps.
87105
88106
"""
89107

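The `thresholding` and `sample_max_value` arguments documented above refer to Imagen's dynamic thresholding, which clips the predicted clean sample to a data-dependent range and then rescales it. A minimal sketch in plain Python, under stated assumptions: the function name `dynamic_threshold` and the `ratio` percentile parameter are hypothetical, and using `sample_max_value` as a floor on the clipping bound is an assumption that this diff does not confirm.

```python
def dynamic_threshold(x0, ratio=0.995, sample_max_value=1.0):
    """Clip a predicted clean sample to [-s, s], then divide by s.

    x0 is a non-empty flat list of floats standing in for a tensor.
    """
    magnitudes = sorted(abs(v) for v in x0)
    # Nearest-rank percentile of the absolute values (an illustrative choice).
    index = min(len(magnitudes) - 1, int(ratio * len(magnitudes)))
    # Assumption: the bound never drops below sample_max_value.
    s = max(magnitudes[index], sample_max_value)
    return [max(-s, min(s, v)) / s for v in x0]
```

Pixel-space samples are expected to lie in [-1, 1], so a sample already inside that range passes through unchanged. Latent values have no such bound, which is presumably why the docstring advises against thresholding for latent-space models.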
@@ -183,7 +201,16 @@ def convert_model_output(
183201
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
184202
) -> torch.FloatTensor:
185203
"""
186-
TODO
204+
Convert the output of the noise prediction model to the prediction type the solver uses: noise (`predict_x0=False`) or data (`predict_x0=True`).
205+
206+
Args:
207+
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
208+
timestep (`int`): current discrete timestep in the diffusion chain.
209+
sample (`torch.FloatTensor`):
210+
current instance of sample being created by diffusion process.
211+
212+
Returns:
213+
`torch.FloatTensor`: the converted model output.
187214
"""
188215
if self.predict_x0:
189216
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
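
The hunk ends right after the `alpha_t`, `sigma_t` lookup, so the conversion itself is not visible here. Assuming the usual forward relation `x_t = alpha_t * x0 + sigma_t * eps`, a data prediction can be recovered from a noise prediction as sketched below. The function name is hypothetical, and scalars stand in for tensors.

```python
def noise_to_data_prediction(eps, sample, alpha_t, sigma_t):
    """Solve sample = alpha_t * x0 + sigma_t * eps for x0."""
    return (sample - sigma_t * eps) / alpha_t


# Round trip: build a noisy sample from a known x0, then recover it.
x0, eps = 1.5, -0.4
alpha_t, sigma_t = 0.8, 0.6
sample = alpha_t * x0 + sigma_t * eps
recovered = noise_to_data_prediction(eps, sample, alpha_t, sigma_t)
```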
@@ -208,7 +235,17 @@ def dpm_solver_first_order_update(
208235
sample: torch.FloatTensor,
209236
) -> torch.FloatTensor:
210237
"""
211-
TODO
238+
One step for the first-order DPM-Solver (equivalent to DDIM).
239+
240+
Args:
241+
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
242+
timestep (`int`): current discrete timestep in the diffusion chain.
243+
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
244+
sample (`torch.FloatTensor`):
245+
current instance of sample being created by diffusion process.
246+
247+
Returns:
248+
`torch.FloatTensor`: the sample tensor at the previous timestep.
212249
"""
213250
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
214251
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
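
The docstring's claim that the first-order step is equivalent to DDIM can be checked numerically. The sketch below assumes the half-log-SNR definition `lambda = log(alpha) - log(sigma)`, which the `lambda_t` lookups suggest but this diff does not show, and it uses the noise-prediction form of the update. Function names are illustrative, and scalars stand in for tensors.

```python
import math


def half_log_snr(alpha, sigma):
    # Assumed definition of lambda: log(alpha / sigma).
    return math.log(alpha) - math.log(sigma)


def first_order_update(eps, sample, alpha_s, sigma_s, alpha_t, sigma_t):
    """First-order DPM-Solver step from time s to time t (noise prediction)."""
    h = half_log_snr(alpha_t, sigma_t) - half_log_snr(alpha_s, sigma_s)
    return (alpha_t / alpha_s) * sample - sigma_t * math.expm1(h) * eps


def ddim_step(eps, sample, alpha_s, sigma_s, alpha_t, sigma_t):
    """Deterministic DDIM step: predict x0, then re-noise it to time t."""
    x0 = (sample - sigma_s * eps) / alpha_s
    return alpha_t * x0 + sigma_t * eps


args = dict(eps=0.3, sample=-1.2, alpha_s=0.6, sigma_s=0.8, alpha_t=0.8, sigma_t=0.6)
difference = abs(first_order_update(**args) - ddim_step(**args))
```

The two results agree up to floating-point rounding, because `sigma_t * (exp(h) - 1)` expands to `alpha_t * sigma_s / alpha_s - sigma_t`.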
@@ -228,7 +265,18 @@ def multistep_dpm_solver_second_order_update(
228265
sample: torch.FloatTensor,
229266
) -> torch.FloatTensor:
230267
"""
231-
TODO
268+
One step for the second-order multistep DPM-Solver.
269+
270+
Args:
271+
model_output_list (`List[torch.FloatTensor]`):
272+
direct outputs from learned diffusion model at current and latter timesteps.
273+
timestep_list (`List[int]`): current and latter discrete timesteps in the diffusion chain.
274+
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
275+
sample (`torch.FloatTensor`):
276+
current instance of sample being created by diffusion process.
277+
278+
Returns:
279+
`torch.FloatTensor`: the sample tensor at the previous timestep.
232280
"""
233281
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
234282
m0, m1 = model_output_list[-1], model_output_list[-2]
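
The `solver_type` argument selects between two forms of this second-order step. Below is a sketch of both for the data-prediction case, following the DPM-Solver++ multistep scheme as it is commonly written. The diff truncates the function body, so the exact expressions are assumptions, as are the function name and the scalar arguments that stand in for tensors and table lookups.

```python
import math


def second_order_multistep_update(
    m0, m1, sample, lambda_t, lambda_s0, lambda_s1,
    alpha_t, sigma_t, sigma_s0, solver_type="dpm_solver",
):
    """One multistep step from the two most recent data predictions.

    m0 is the newest prediction (at s0) and m1 the one before it (at s1).
    """
    h = lambda_t - lambda_s0      # size of the step being taken
    h_0 = lambda_s0 - lambda_s1   # size of the previous step
    r0 = h_0 / h
    d0 = m0
    d1 = (m0 - m1) / r0           # finite-difference estimate of h * dm/dlambda
    phi = math.expm1(-h)          # exp(-h) - 1
    x_t = (sigma_t / sigma_s0) * sample - alpha_t * phi * d0
    if solver_type == "dpm_solver":
        return x_t - 0.5 * alpha_t * phi * d1
    if solver_type == "taylor":
        return x_t + alpha_t * (phi / h + 1.0) * d1
    raise ValueError(f"unknown solver_type: {solver_type}")
```

When `m0 == m1` the correction term vanishes and both variants reduce to the first-order update. The two coefficients on `d1` differ only at higher order in `h`, which fits the docstring's remark that the solver type affects quality only slightly.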
@@ -274,7 +322,18 @@ def multistep_dpm_solver_third_order_update(
274322
sample: torch.FloatTensor,
275323
) -> torch.FloatTensor:
276324
"""
277-
TODO
325+
One step for the third-order multistep DPM-Solver.
326+
327+
Args:
328+
model_output_list (`List[torch.FloatTensor]`):
329+
direct outputs from learned diffusion model at current and latter timesteps.
330+
timestep_list (`List[int]`): current and latter discrete timesteps in the diffusion chain.
331+
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
332+
sample (`torch.FloatTensor`):
333+
current instance of sample being created by diffusion process.
334+
335+
Returns:
336+
`torch.FloatTensor`: the sample tensor at the previous timestep.
278337
"""
279338
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
280339
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
@@ -316,8 +375,7 @@ def step(
316375
return_dict: bool = True,
317376
) -> Union[SchedulerOutput, Tuple]:
318377
"""
319-
Step function propagating the sample with the multistep DPM-Solver. This has one forward pass with multiple
320-
times to approximate the solution.
378+
Step function propagating the sample with the multistep DPM-Solver.
321379
322380
Args:
323381
model_output (`torch.FloatTensor`): direct output from learned diffusion model.

src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py

Lines changed: 77 additions & 16 deletions
@@ -80,7 +80,18 @@ class FlaxDPMSolverDiscreteSchedulerOutput(FlaxSchedulerOutput):
8080

8181
class FlaxDPMSolverDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
8282
"""
83-
DPM-Solver.
83+
DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with
84+
a convergence order guarantee. Empirically, sampling with DPM-Solver in only 20 steps can generate high-quality
85+
samples, and it can generate quite good samples even in only 10 steps.
86+
87+
For more details, see the original papers: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
88+
89+
Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We
90+
recommend `solver_order=2` for guided sampling and `solver_order=3` for unconditional sampling.
91+
92+
We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
93+
diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the dynamic thresholding. Note
94+
that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion).
8495
8596
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
8697
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
@@ -98,17 +109,26 @@ class FlaxDPMSolverDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
98109
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
99110
trained_betas (`np.ndarray`, optional):
100111
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
101-
skip_prk_steps (`bool`):
102-
allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
103-
before plms steps; defaults to `False`.
104-
set_alpha_to_one (`bool`, default `False`):
105-
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
106-
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
107-
otherwise it uses the value of alpha at step 0.
108-
steps_offset (`int`, default `0`):
109-
an offset added to the inference steps. You can use a combination of `offset=1` and
110-
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
111-
stable diffusion.
112+
solver_order (`int`, default `2`):
113+
the order of DPM-Solver; can be `1`, `2`, or `3`. We recommend `solver_order=2` for guided
114+
sampling, and `solver_order=3` for unconditional sampling.
115+
predict_x0 (`bool`, default `True`):
116+
DPM-Solver is designed for both the noise prediction model (DPM-Solver, https://arxiv.org/abs/2206.00927)
117+
with `predict_x0=False` and the data prediction model (DPM-Solver++, https://arxiv.org/abs/2211.01095) with
118+
`predict_x0=True`. We recommend `predict_x0=True` and `solver_order=2` for guided sampling (e.g.
119+
stable-diffusion).
120+
thresholding (`bool`, default `False`):
121+
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
122+
For pixel-space diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the
123+
dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models
124+
(such as stable-diffusion).
125+
sample_max_value (`float`, default `1.0`):
126+
the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
127+
solver_type (`str`, default `dpm_solver`):
128+
the solver type for the second-order solver. Either `dpm_solver` or `taylor`. The solver type slightly
129+
affects the sample quality, especially for a small number of steps.
130+
denoise_final (`bool`, default `False`):
131+
whether to use lower-order solvers in the final steps.
112132
113133
"""
114134

@@ -205,7 +225,16 @@ def convert_model_output(
205225
sample: jnp.ndarray,
206226
) -> jnp.ndarray:
207227
"""
208-
TODO
228+
Convert the output of the noise prediction model to the prediction type the solver uses: noise (`predict_x0=False`) or data (`predict_x0=True`).
229+
230+
Args:
231+
model_output (`jnp.ndarray`): direct output from learned diffusion model.
232+
timestep (`int`): current discrete timestep in the diffusion chain.
233+
sample (`jnp.ndarray`):
234+
current instance of sample being created by diffusion process.
235+
236+
Returns:
237+
`jnp.ndarray`: the converted model output.
209238
"""
210239
if self.predict_x0:
211240
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
@@ -224,7 +253,17 @@ def dpm_solver_first_order_update(
224253
self, model_output: jnp.ndarray, timestep: int, prev_timestep: int, sample: jnp.ndarray
225254
) -> jnp.ndarray:
226255
"""
227-
TODO
256+
One step for the first-order DPM-Solver (equivalent to DDIM).
257+
258+
Args:
259+
model_output (`jnp.ndarray`): direct output from learned diffusion model.
260+
timestep (`int`): current discrete timestep in the diffusion chain.
261+
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
262+
sample (`jnp.ndarray`):
263+
current instance of sample being created by diffusion process.
264+
265+
Returns:
266+
`jnp.ndarray`: the sample tensor at the previous timestep.
228267
"""
229268
t, s0 = prev_timestep, timestep
230269
m0 = model_output
@@ -246,7 +285,18 @@ def multistep_dpm_solver_second_order_update(
246285
sample: jnp.ndarray,
247286
) -> jnp.ndarray:
248287
"""
249-
TODO
288+
One step for the second-order multistep DPM-Solver.
289+
290+
Args:
291+
model_output_list (`List[jnp.ndarray]`):
292+
direct outputs from learned diffusion model at current and latter timesteps.
293+
timestep_list (`List[int]`): current and latter discrete timesteps in the diffusion chain.
294+
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
295+
sample (`jnp.ndarray`):
296+
current instance of sample being created by diffusion process.
297+
298+
Returns:
299+
`jnp.ndarray`: the sample tensor at the previous timestep.
250300
"""
251301
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
252302
m0, m1 = model_output_list[-1], model_output_list[-2]
@@ -292,7 +342,18 @@ def multistep_dpm_solver_third_order_update(
292342
sample: jnp.ndarray,
293343
) -> jnp.ndarray:
294344
"""
295-
TODO
345+
One step for the third-order multistep DPM-Solver.
346+
347+
Args:
348+
model_output_list (`List[jnp.ndarray]`):
349+
direct outputs from learned diffusion model at current and latter timesteps.
350+
timestep_list (`List[int]`): current and latter discrete timesteps in the diffusion chain.
351+
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
352+
sample (`jnp.ndarray`):
353+
current instance of sample being created by diffusion process.
354+
355+
Returns:
356+
`jnp.ndarray`: the sample tensor at the previous timestep.
296357
"""
297358
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
298359
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
