Skip to content

Commit a1ea8c0

Browse files
hlky, patil-suraj, anton-l, pcuenca, patrickvonplaten
authored
k-diffusion-euler (#1019)
* k-diffusion-euler
* make style make quality
* make fix-copies
* fix tests for euler a
* Update src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py — Co-authored-by: Anton Lozhkov <[email protected]>
* Update src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py — Co-authored-by: Anton Lozhkov <[email protected]>
* Update src/diffusers/schedulers/scheduling_euler_discrete.py — Co-authored-by: Anton Lozhkov <[email protected]>
* Update src/diffusers/schedulers/scheduling_euler_discrete.py — Co-authored-by: Anton Lozhkov <[email protected]>
* remove unused arg and method
* update doc
* quality
* make flake happy
* use logger instead of warn
* raise error instead of deprecation
* don't require scipy
* pass generator in step
* fix tests
* Apply suggestions from code review — Co-authored-by: Pedro Cuenca <[email protected]>
* Update tests/test_scheduler.py — Co-authored-by: Patrick von Platen <[email protected]>
* remove unused generator
* pass generator as extra_step_kwargs
* update tests
* pass generator as kwarg
* pass generator as kwarg
* quality
* fix test for lms
* fix tests

Co-authored-by: patil-suraj <[email protected]>
Co-authored-by: Anton Lozhkov <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
1 parent bf7b0bc commit a1ea8c0

File tree

11 files changed

+858
-12
lines changed

11 files changed

+858
-12
lines changed

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
from .schedulers import (
4242
DDIMScheduler,
4343
DDPMScheduler,
44+
EulerAncestralDiscreteScheduler,
45+
EulerDiscreteScheduler,
4446
IPNDMScheduler,
4547
KarrasVeScheduler,
4648
PNDMScheduler,

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,13 @@
99
from ...configuration_utils import FrozenDict
1010
from ...models import AutoencoderKL, UNet2DConditionModel
1111
from ...pipeline_utils import DiffusionPipeline
12-
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
12+
from ...schedulers import (
13+
DDIMScheduler,
14+
EulerAncestralDiscreteScheduler,
15+
EulerDiscreteScheduler,
16+
LMSDiscreteScheduler,
17+
PNDMScheduler,
18+
)
1319
from ...utils import deprecate, logging
1420
from . import StableDiffusionPipelineOutput
1521
from .safety_checker import StableDiffusionSafetyChecker
@@ -52,7 +58,9 @@ def __init__(
5258
text_encoder: CLIPTextModel,
5359
tokenizer: CLIPTokenizer,
5460
unet: UNet2DConditionModel,
55-
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
61+
scheduler: Union[
62+
DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler
63+
],
5664
safety_checker: StableDiffusionSafetyChecker,
5765
feature_extractor: CLIPFeatureExtractor,
5866
):
@@ -334,6 +342,11 @@ def __call__(
334342
if accepts_eta:
335343
extra_step_kwargs["eta"] = eta
336344

345+
# check if the scheduler accepts generator
346+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
347+
if accepts_generator:
348+
extra_step_kwargs["generator"] = generator
349+
337350
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
338351
# expand the latents if we are doing classifier free guidance
339352
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,13 @@
1010
from ...configuration_utils import FrozenDict
1111
from ...models import AutoencoderKL, UNet2DConditionModel
1212
from ...pipeline_utils import DiffusionPipeline
13-
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
13+
from ...schedulers import (
14+
DDIMScheduler,
15+
EulerAncestralDiscreteScheduler,
16+
EulerDiscreteScheduler,
17+
LMSDiscreteScheduler,
18+
PNDMScheduler,
19+
)
1420
from ...utils import deprecate, logging
1521
from . import StableDiffusionPipelineOutput
1622
from .safety_checker import StableDiffusionSafetyChecker
@@ -63,7 +69,9 @@ def __init__(
6369
text_encoder: CLIPTextModel,
6470
tokenizer: CLIPTokenizer,
6571
unet: UNet2DConditionModel,
66-
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
72+
scheduler: Union[
73+
DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler
74+
],
6775
safety_checker: StableDiffusionSafetyChecker,
6876
feature_extractor: CLIPFeatureExtractor,
6977
):
@@ -335,6 +343,11 @@ def __call__(
335343
if accepts_eta:
336344
extra_step_kwargs["eta"] = eta
337345

346+
# check if the scheduler accepts generator
347+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
348+
if accepts_generator:
349+
extra_step_kwargs["generator"] = generator
350+
338351
latents = init_latents
339352

340353
t_start = max(num_inference_steps - init_timestep + offset, 0)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,11 @@ def __call__(
379379
if accepts_eta:
380380
extra_step_kwargs["eta"] = eta
381381

382+
# check if the scheduler accepts generator
383+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
384+
if accepts_generator:
385+
extra_step_kwargs["generator"] = generator
386+
382387
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
383388
# expand the latents if we are doing classifier free guidance
384389
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,11 @@ def __call__(
352352
if accepts_eta:
353353
extra_step_kwargs["eta"] = eta
354354

355+
# check if the scheduler accepts generator
356+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
357+
if accepts_generator:
358+
extra_step_kwargs["generator"] = generator
359+
355360
latents = init_latents
356361

357362
t_start = max(num_inference_steps - init_timestep + offset, 0)

src/diffusers/schedulers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
if is_torch_available():
2020
from .scheduling_ddim import DDIMScheduler
2121
from .scheduling_ddpm import DDPMScheduler
22+
from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
23+
from .scheduling_euler_discrete import EulerDiscreteScheduler
2224
from .scheduling_ipndm import IPNDMScheduler
2325
from .scheduling_karras_ve import KarrasVeScheduler
2426
from .scheduling_pndm import PNDMScheduler
Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from dataclasses import dataclass
16+
from typing import Optional, Tuple, Union
17+
18+
import numpy as np
19+
import torch
20+
21+
from ..configuration_utils import ConfigMixin, register_to_config
22+
from ..utils import BaseOutput, deprecate, logging
23+
from .scheduling_utils import SchedulerMixin
24+
25+
26+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
27+
28+
29+
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete
class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    # NOTE(review): docstring and fields are kept verbatim — the "Copied from" marker above
    # means `make fix-copies` mirrors this class from DDPMSchedulerOutput; edit the source there.
    prev_sample: torch.FloatTensor
    pred_original_sample: Optional[torch.FloatTensor] = None
46+
47+
48+
class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson:
    https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
    [`~ConfigMixin.from_config`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear` or `scaled_linear`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.

    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[np.ndarray] = None,
    ):
        if trained_betas is not None:
            self.betas = torch.from_numpy(trained_betas)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        else:
            # fixed message typo: "does is not implemented" -> "is not implemented"
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t); stored descending, with a trailing 0.0
        # so that the final step lands exactly on the clean sample.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = self.sigmas.max()

        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        self.is_scale_input_called = False

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
    ) -> torch.FloatTensor:
        """
        Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)
        step_index = (self.timesteps == timestep).nonzero().item()
        sigma = self.sigmas[step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        # remembered so that step() can warn when a pipeline forgets to call this method
        self.is_scale_input_called = True
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        # interpolate the training sigmas down to the (possibly fractional) inference timesteps
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas).to(device=device)
        self.timesteps = torch.from_numpy(timesteps).to(device=device)

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`float`): current timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            generator (`torch.Generator`, optional): Random number generator.
            return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class

        Returns:
            [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] if `return_dict` is True, otherwise
            a `tuple`. When returning a tuple, the first element is the sample tensor.

        """

        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                # fixed copy/paste error: this class is EulerAncestralDiscreteScheduler,
                # not EulerDiscreteScheduler
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerAncestralDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep.",
            )

        if not self.is_scale_input_called:
            # `Logger.warn` is a deprecated alias of `Logger.warning`
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)

        step_index = (self.timesteps == timestep).nonzero().item()
        sigma = self.sigmas[step_index]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        pred_original_sample = sample - sigma * model_output
        # ancestral step: split the move from sigma_from to sigma_to into a deterministic part
        # (down to sigma_down) plus fresh noise of scale sigma_up
        sigma_from = self.sigmas[step_index]
        sigma_to = self.sigmas[step_index + 1]
        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma

        dt = sigma_down - sigma

        prev_sample = sample + derivative * dt

        # sample on CPU with the (possibly CPU) generator, then move to the model's device
        device = model_output.device if torch.is_tensor(model_output) else "cpu"
        noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
        prev_sample = prev_sample + noise * sigma_up

        if not return_dict:
            return (prev_sample,)

        return EulerAncestralDiscreteSchedulerOutput(
            prev_sample=prev_sample, pred_original_sample=pred_original_sample
        )

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            self.timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        schedule_timesteps = self.timesteps

        if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor):
            # integer tensors are treated as positional indices for backward compatibility only
            deprecate(
                "timesteps as indices",
                "0.8.0",
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerAncestralDiscreteScheduler.add_noise()` will not be supported in future versions. Make sure to"
                " pass values from `scheduler.timesteps` as timesteps.",
                standard_warn=False,
            )
            step_indices = timesteps
        else:
            step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = self.sigmas[step_indices].flatten()
        # broadcast sigma over the sample's trailing (channel/spatial) dimensions
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps

0 commit comments

Comments
 (0)