
Commit be99201

feat: add log-rho DEIS multistep scheduler (#1432)

* feat: add log-rho DEIS multistep scheduler
* docs: fix typo
* docs: add docs for the implemented algorithm
* docs: remove duplicate ref
* finish DEIS
* add docs
* fix

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 9b63854 commit be99201


10 files changed: +743 additions, -4 deletions


docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions

@@ -155,6 +155,8 @@
     title: "DDIM"
   - local: api/schedulers/ddpm
     title: "DDPM"
+  - local: api/schedulers/deis
+    title: "DEIS"
   - local: api/schedulers/singlestep_dpm_solver
     title: "Singlestep DPM-Solver"
   - local: api/schedulers/multistep_dpm_solver

docs/source/en/api/schedulers/deis.mdx

Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+<!--Copyright 2022 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# DEIS
+
+Fast Sampling of Diffusion Models with Exponential Integrator.
+
+## Overview
+
+The original paper can be found [here](https://arxiv.org/abs/2204.13902), and the original implementation [here](https://github.com/qsh-zh/deis).
+
+## DEISMultistepScheduler
+[[autodoc]] DEISMultistepScheduler
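
As context for the new docs page, here is a minimal usage sketch (not part of this commit; the checkpoint id and prompt are illustrative assumptions):

```python
# Hedged usage sketch: swap DEIS into an existing pipeline. The checkpoint
# "runwayml/stable-diffusion-v1-5" is an illustrative assumption.
import torch
from diffusers import DEISMultistepScheduler, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# Build the scheduler from the pipeline's existing config so the beta
# schedule and prediction type stay consistent.
pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# DEIS is a fast multistep sampler, so a small number of steps usually suffices.
image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]
```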

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -67,6 +67,7 @@
     from .schedulers import (
         DDIMScheduler,
         DDPMScheduler,
+        DEISMultistepScheduler,
         DPMSolverMultistepScheduler,
         DPMSolverSinglestepScheduler,
         EulerAncestralDiscreteScheduler,

src/diffusers/schedulers/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -24,6 +24,7 @@
 else:
     from .scheduling_ddim import DDIMScheduler
     from .scheduling_ddpm import DDPMScheduler
+    from .scheduling_deis_multistep import DEISMultistepScheduler
     from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
     from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
     from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler

src/diffusers/schedulers/scheduling_deis_multistep.py

Lines changed: 481 additions & 0 deletions
Large diffs are not rendered by default.
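
The scheduler implementation itself is not rendered here. For intuition, the sketch below shows the order-1 exponential-integrator update that DEIS reduces to at `solver_order=1`; the function name and signature are illustrative, not the file's actual API, and the real scheduler adds second- and third-order multistep corrections via log-rho polynomial extrapolation of past noise predictions.

```python
# Illustrative sketch of a first-order DEIS / exponential-integrator step in
# epsilon space (not the actual diffusers implementation).
import torch

def deis_first_order_step(sample, eps_pred, alpha_s, sigma_s, alpha_t, sigma_t):
    # lambda = log(alpha / sigma) is the half log-SNR; h is its increment
    # from the current timestep s to the next timestep t.
    lambda_s = torch.log(alpha_s / sigma_s)
    lambda_t = torch.log(alpha_t / sigma_t)
    h = lambda_t - lambda_s
    # Exact solution of the diffusion ODE when eps(x, t) is frozen at its
    # current estimate: x_t = (alpha_t / alpha_s) * x_s - sigma_t * (e^h - 1) * eps
    return (alpha_t / alpha_s) * sample - sigma_t * torch.expm1(h) * eps_pred
```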

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 8 additions & 2 deletions

@@ -174,9 +174,15 @@ def __init__(

         # settings for DPM-Solver
         if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
-            raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
+            if algorithm_type == "deis":
+                algorithm_type = "dpmsolver++"
+            else:
+                raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
         if solver_type not in ["midpoint", "heun"]:
-            raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
+            if solver_type == "logrho":
+                solver_type = "midpoint"
+            else:
+                raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

         # setable values
         self.num_inference_steps = None
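
A small sketch of what this aliasing buys (behavior inferred from the diff above): DEIS-flavored settings no longer raise when handed to the DPM-Solver schedulers, so configs stay interchangeable between the two scheduler families.

```python
from diffusers import DPMSolverMultistepScheduler

# "deis" / "logrho" are silently treated as "dpmsolver++" / "midpoint"
# instead of raising NotImplementedError.
scheduler = DPMSolverMultistepScheduler(algorithm_type="deis", solver_type="logrho")
```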

src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

Lines changed: 8 additions & 2 deletions

@@ -163,9 +163,15 @@ def __init__(

         # settings for DPM-Solver
         if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
-            raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
+            if algorithm_type == "deis":
+                algorithm_type = "dpmsolver++"
+            else:
+                raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
         if solver_type not in ["midpoint", "heun"]:
-            raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
+            if solver_type == "logrho":
+                solver_type = "midpoint"
+            else:
+                raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

         # setable values
         self.num_inference_steps = None

src/diffusers/utils/constants.py

Lines changed: 3 additions & 0 deletions

@@ -41,4 +41,7 @@
     "EulerAncestralDiscreteScheduler",
     "DPMSolverMultistepScheduler",
     "DPMSolverSinglestepScheduler",
+    "KDPM2DiscreteScheduler",
+    "KDPM2AncestralDiscreteScheduler",
+    "DEISMultistepScheduler",
 ]

src/diffusers/utils/dummy_pt_objects.py

Lines changed: 15 additions & 0 deletions

@@ -362,6 +362,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])


+class DEISMultistepScheduler(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class DPMSolverMultistepScheduler(metaclass=DummyObject):
     _backends = ["torch"]
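
For readers unfamiliar with this file: these stubs stand in for the real classes when PyTorch is absent, so `import diffusers` still succeeds and any actual use fails with an actionable message. A simplified, self-contained sketch of the pattern (illustrative; the real helpers live in diffusers.utils):

```python
# Simplified dummy-object pattern (not the exact diffusers implementation).
def requires_backends(obj, backends):
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    raise ImportError(f"{name} requires the following backends: {', '.join(backends)}")

class DummyObject(type):
    # Attribute lookup on the class itself raises too, so classmethod-style
    # access such as Stub.from_pretrained fails with the same clear error.
    def __getattr__(cls, key):
        requires_backends(cls, cls._backends)

class Stub(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

# Stub()                -> ImportError: Stub requires the following backends: torch
# Stub.from_pretrained  -> same ImportError, raised by the metaclass
```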

tests/test_scheduler.py

Lines changed: 202 additions & 0 deletions

@@ -27,6 +27,7 @@
 from diffusers import (
     DDIMScheduler,
     DDPMScheduler,
+    DEISMultistepScheduler,
     DPMSolverMultistepScheduler,
     DPMSolverSinglestepScheduler,
     EulerAncestralDiscreteScheduler,

@@ -2505,6 +2506,207 @@ def test_full_loop_device(self):
         assert abs(result_mean.item() - 0.0266) < 1e-3


+class DEISMultistepSchedulerTest(SchedulerCommonTest):
+    scheduler_classes = (DEISMultistepScheduler,)
+    forward_default_kwargs = (("num_inference_steps", 25),)
+
+    def get_scheduler_config(self, **kwargs):
+        config = {
+            "num_train_timesteps": 1000,
+            "beta_start": 0.0001,
+            "beta_end": 0.02,
+            "beta_schedule": "linear",
+            "solver_order": 2,
+        }
+
+        config.update(**kwargs)
+        return config
+
+    def check_over_configs(self, time_step=0, **config):
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+        sample = self.dummy_sample
+        residual = 0.1 * sample
+        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+            scheduler.set_timesteps(num_inference_steps)
+            # copy over dummy past residuals
+            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_config(tmpdirname)
+                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
+                new_scheduler.set_timesteps(num_inference_steps)
+                # copy over dummy past residuals
+                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
+
+            output, new_output = sample, sample
+            for t in range(time_step, time_step + scheduler.config.solver_order + 1):
+                output = scheduler.step(residual, t, output, **kwargs).prev_sample
+                new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
+
+                assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+    def test_from_save_pretrained(self):
+        pass
+
+    def check_over_forward(self, time_step=0, **forward_kwargs):
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+        sample = self.dummy_sample
+        residual = 0.1 * sample
+        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+            scheduler.set_timesteps(num_inference_steps)
+
+            # copy over dummy past residuals (must be after setting timesteps)
+            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_config(tmpdirname)
+                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
+                # copy over dummy past residuals
+                new_scheduler.set_timesteps(num_inference_steps)
+
+                # copy over dummy past residual (must be after setting timesteps)
+                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
+
+            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
+            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
+
+            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+    def full_loop(self, **config):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(**config)
+        scheduler = scheduler_class(**scheduler_config)
+
+        num_inference_steps = 10
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter
+        scheduler.set_timesteps(num_inference_steps)
+
+        for i, t in enumerate(scheduler.timesteps):
+            residual = model(sample, t)
+            sample = scheduler.step(residual, t, sample).prev_sample
+
+        return sample
+
+    def test_step_shape(self):
+        kwargs = dict(self.forward_default_kwargs)
+
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+
+            sample = self.dummy_sample
+            residual = 0.1 * sample
+
+            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+                scheduler.set_timesteps(num_inference_steps)
+            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+                kwargs["num_inference_steps"] = num_inference_steps
+
+            # copy over dummy past residuals (must be done after set_timesteps)
+            dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
+            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
+
+            time_step_0 = scheduler.timesteps[5]
+            time_step_1 = scheduler.timesteps[6]
+
+            output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
+            output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample
+
+            self.assertEqual(output_0.shape, sample.shape)
+            self.assertEqual(output_0.shape, output_1.shape)
+
+    def test_timesteps(self):
+        for timesteps in [25, 50, 100, 999, 1000]:
+            self.check_over_configs(num_train_timesteps=timesteps)
+
+    def test_thresholding(self):
+        self.check_over_configs(thresholding=False)
+        for order in [1, 2, 3]:
+            for solver_type in ["logrho"]:
+                for threshold in [0.5, 1.0, 2.0]:
+                    for prediction_type in ["epsilon", "sample"]:
+                        self.check_over_configs(
+                            thresholding=True,
+                            prediction_type=prediction_type,
+                            sample_max_value=threshold,
+                            algorithm_type="deis",
+                            solver_order=order,
+                            solver_type=solver_type,
+                        )
+
+    def test_prediction_type(self):
+        for prediction_type in ["epsilon", "v_prediction"]:
+            self.check_over_configs(prediction_type=prediction_type)
+
+    def test_solver_order_and_type(self):
+        for algorithm_type in ["deis"]:
+            for solver_type in ["logrho"]:
+                for order in [1, 2, 3]:
+                    for prediction_type in ["epsilon", "sample"]:
+                        self.check_over_configs(
+                            solver_order=order,
+                            solver_type=solver_type,
+                            prediction_type=prediction_type,
+                            algorithm_type=algorithm_type,
+                        )
+                        sample = self.full_loop(
+                            solver_order=order,
+                            solver_type=solver_type,
+                            prediction_type=prediction_type,
+                            algorithm_type=algorithm_type,
+                        )
+                        assert not torch.isnan(sample).any(), "Samples have nan numbers"
+
+    def test_lower_order_final(self):
+        self.check_over_configs(lower_order_final=True)
+        self.check_over_configs(lower_order_final=False)
+
+    def test_inference_steps(self):
+        for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
+            self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)
+
+    def test_full_loop_no_noise(self):
+        sample = self.full_loop()
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.23916) < 1e-3
+
+    def test_full_loop_with_v_prediction(self):
+        sample = self.full_loop(prediction_type="v_prediction")
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.091) < 1e-3
+
+    def test_fp16_support(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
+        scheduler = scheduler_class(**scheduler_config)
+
+        num_inference_steps = 10
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter.half()
+        scheduler.set_timesteps(num_inference_steps)
+
+        for i, t in enumerate(scheduler.timesteps):
+            residual = model(sample, t)
+            sample = scheduler.step(residual, t, sample).prev_sample
+
+        assert sample.dtype == torch.float16
+
+
 class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (KDPM2AncestralDiscreteScheduler,)
     num_inference_steps = 10
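
The new tests can be run in isolation with an invocation along the lines of `pytest tests/test_scheduler.py -k DEISMultistep` (illustrative; any standard pytest selection works).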
