Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
IPNDMScheduler,
KarrasVeScheduler,
PNDMScheduler,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,6 @@ def __call__(

# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps

# 5. Prepare latent variables
num_channels_latents = self.unet.in_channels
Expand All @@ -496,7 +495,9 @@ def __call__(
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

# 7. Denoising loop
for i, t in enumerate(self.progress_bar(timesteps)):
i = 0
t = self.scheduler.timesteps[0]
while t > 0:
Comment on lines +498 to +500
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to keep i and t, which is very ugly. i is only required for the callback.

# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
Expand All @@ -510,12 +511,14 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
latents, t = self.scheduler.step(noise_pred, t, latents, return_dict=False, **extra_step_kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm this makes the step function a bit more black-box - not suuuper happy about it at this point of the library, but maybe at some point we have to let the scheduler adapt t anyways. However would still be happier with exposing the whole timestamp array @patil-suraj @anton-l


# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

i += 1

# 8. Post-processing
image = self.decode_latents(latents)

Expand Down
1 change: 1 addition & 0 deletions src/diffusers/schedulers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
from .scheduling_euler_discrete import EulerDiscreteScheduler
from .scheduling_heun import HeunDiscreteScheduler
from .scheduling_ipndm import IPNDMScheduler
from .scheduling_karras_ve import KarrasVeScheduler
from .scheduling_pndm import PNDMScheduler
Expand Down
218 changes: 218 additions & 0 deletions src/diffusers/schedulers/scheduling_heun.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput


class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
Args:
Implements Algorithm 2 (Heun steps) from Karras et al. (2022). for discrete beta schedules. Based on the original
k-diffusion implementation by Katherine Crowson:
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L90
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the
starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
"""

@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085, # sensible defaults
beta_end: float = 0.012,
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
):
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

# set all values
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)

def scale_model_input(
self,
sample: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
) -> torch.FloatTensor:
"""
Args:
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep
Returns:
`torch.FloatTensor`: scaled input sample
"""
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample

def set_timesteps(
self,
num_inference_steps: int,
device: Union[str, torch.device] = None,
num_train_timesteps: Optional[int] = None,
):
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, optional):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps

num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps

timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()

sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas).to(device=device)

timesteps = torch.from_numpy(timesteps)

# standard deviation of the initial noise distribution
self.init_noise_sigma = sigmas[0]

if str(device).startswith("mps"):
# mps does not support float64
self.timesteps = timesteps.to(device, dtype=torch.float32)
else:
self.timesteps = timesteps.to(device=device)

# empty dt and derivative
self.prev_derivative = None
self.dt = None

@property
def state_in_first_order(self):
return self.dt is None

def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: Union[float, torch.FloatTensor],
sample: Union[torch.FloatTensor, np.ndarray],
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Args:
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep
(`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
step_index = (self.timesteps == timestep).nonzero().item()

if self.state_in_first_order:
sigma = self.sigmas[step_index]
step_index += 1
sigma_next = self.sigmas[step_index]
sigma_hat = sigma
else:
# 2nd order / Heun's method
sigma = self.sigmas[step_index - 1]
sigma_next = self.sigmas[step_index]
sigma_hat = sigma_next

# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
pred_original_sample = sample - sigma_hat * model_output

# 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma_hat
if self.state_in_first_order:
# 3. 1st order derivative
dt = sigma_next - sigma_hat

# store for 2nd order step
self.sample = sample
self.prev_derivative = derivative
self.dt = dt
else:
# 2. 2nd order / Heun's method
derivative = (self.prev_derivative + derivative) / 2

# 3. Retrieve 1st order derivative
dt = self.dt

# free dt and derivative
# Note, this puts the scheduler in "first order mode"
self.prev_derivative = None
self.dt = None

prev_sample = self.sample + derivative * dt
print(f"step_index: {step_index}, state_in_first_order: {self.state_in_first_order}, sigma: {sigma}, sigma_next: {sigma_next}, sigma_hat: {sigma_hat}, dt: {dt}")

if not return_dict:
return (prev_sample, self.timesteps[step_index])

return SchedulerOutput(prev_sample=prev_sample, timestep=self.timesteps[step_index])
Comment on lines +196 to +199
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar implementation as in #1336 except that we return the next timestep.


def add_noise(
self,
original_samples: Union[torch.FloatTensor, np.ndarray],
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
) -> Union[torch.FloatTensor, np.ndarray]:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
self.timesteps = self.timesteps.to(original_samples.device)
sigma = self.sigmas[timesteps].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)

noisy_samples = original_samples + noise * sigma
return noisy_samples

def __len__(self):
return self.config.num_train_timesteps
2 changes: 2 additions & 0 deletions src/diffusers/schedulers/scheduling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class SchedulerOutput(BaseOutput):
"""

prev_sample: torch.FloatTensor
timestep: Union[float, torch.FloatTensor]



class SchedulerMixin:
Expand Down