Merged
88 changes: 62 additions & 26 deletions src/diffusers/schedulers/scheduling_ddpm.py
@@ -16,16 +16,27 @@

import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from typing import Literal, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, deprecate
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin


def expand_to_shape(input, timesteps, shape, device):
"""
Helper that indexes a 1D tensor `input` with a 1D index tensor `timesteps`, then reshapes the result to broadcast
against `shape`. Useful for parallelizing operations over `shape[0]` diffusion steps at once.
"""
out = torch.gather(input.to(device), 0, timesteps.to(device))
reshape = [shape[0]] + [1] * (len(shape) - 1)
out = out.reshape(*reshape)
return out
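
As a quick illustration (a minimal sketch, not part of the diff; the shapes and values are assumptions chosen for the example), this is how `expand_to_shape` lets a per-timestep schedule broadcast over a batch of samples:

# Illustrative usage of expand_to_shape (hypothetical shapes):
coeffs = torch.linspace(0.1, 1.0, 1000)    # e.g. a (T,)-shaped schedule such as sqrt_alphas_cumprod
timesteps = torch.tensor([0, 499, 999])    # one timestep per batch element
samples = torch.randn(3, 3, 64, 64)        # hypothetical image batch of shape (B, C, H, W)
out = expand_to_shape(coeffs, timesteps, samples.shape, samples.device)
# out has shape (3, 1, 1, 1), so `out * samples` broadcasts cleanly.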


@dataclass
class DDPMSchedulerOutput(BaseOutput):
"""
@@ -102,6 +113,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):

"""

_compatible_classes = [
"DDIMScheduler",
"PNDMScheduler",
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
]

@register_to_config
def __init__(
self,
@@ -112,15 +131,8 @@ def __init__(
trained_betas: Optional[np.ndarray] = None,
variance_type: str = "fixed_small",
clip_sample: bool = True,
**kwargs,
prediction_type: Literal["epsilon", "sample", "v"] = "epsilon",
):
deprecate(
"tensor_format",
"0.6.0",
"If you're running your code in PyTorch, you can safely remove this argument.",
take_from=kwargs,
)

if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear":
@@ -142,8 +154,8 @@ def __init__(

self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.sigmas = 1 - self.alphas**2
self.one = torch.tensor(1.0)
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod)

# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
@@ -153,6 +165,7 @@ def __init__(
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())

self.variance_type = variance_type
self.prediction_type = prediction_type

def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
"""
@@ -185,7 +198,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic

def _get_variance(self, timestep, predicted_variance=None, variance_type=None):
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one
alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else torch.tensor(1.0)

# For timestep > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
# and sample from it to get previous sample
@@ -213,6 +226,8 @@ def _get_variance(self, timestep, predicted_variance=None, variance_type=None):
max_log = self.betas[timestep]
frac = (predicted_variance + 1) / 2
variance = frac * max_log + (1 - frac) * min_log
elif variance_type == "v_diffusion":
variance = torch.log(self.betas[timestep] * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t))

return variance

@@ -221,7 +236,7 @@ def step(
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
prediction_type: str = "epsilon",
Contributor:

Actually I think we should deprecate this argument and just have it in the config instead - I don't think it's something that differs for every scheduler call. @anton-l what do you think?

Contributor (Author):

That makes sense! That means the pipeline doesn't have to change at all either.

Member:

I agree, the prediction type doesn't need to change during sampling.

# prediction_type: Literal["epsilon", "sample", "v"] = "epsilon",
generator=None,
return_dict: bool = True,
) -> Union[DDPMSchedulerOutput, Tuple]:
@@ -234,9 +249,9 @@ def step(
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
prediction_type (`str`):
prediction_type (`Literal["epsilon", "sample", "v"]`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample), or `v` (see section 2.4
https://imagen.research.google/video/paper.pdf)
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
@@ -247,30 +262,36 @@ def step(
returning a tuple, the first element is the sample tensor.

"""
if self.variance_type == "v_diffusion":
assert self.prediction_type == "v", "Need to use v prediction with v_diffusion"
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
else:
predicted_variance = None

# 1. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one
alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else torch.tensor(1.0)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev

# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if prediction_type == "epsilon":
if self.prediction_type == "v":
# x_recon in p_mean_variance
pred_original_sample = (
sample * self.sqrt_alphas_cumprod[timestep]
- model_output * self.sqrt_one_minus_alphas_cumprod[timestep]
)
elif self.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif prediction_type == "sample":

elif self.prediction_type == "sample":
pred_original_sample = model_output
elif prediction_type == "v":
# v_t = alpha_t * epsilon - sigma_t * x
# need to merge the PRs for sigma to be available in DDPM
pred = sample * self.alphas[timestep] - model_output * self.sigmas[timestep]
eps = model_output * self.alphas[timestep] - sample * self.sigmas[timestep]
else:
raise ValueError(f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or `v`")
raise ValueError(
f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `v`"
)
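# Editor's note (a sketch, not part of the PR): why the "v" branch above recovers x0.
# With alpha = sqrt(alphas_cumprod[t]) and sigma = sqrt(1 - alphas_cumprod[t]), so alpha**2 + sigma**2 == 1:
#   forward process:     z_t = alpha * x0 + sigma * eps
#   v-parameterization:  v   = alpha * eps - sigma * x0
# then  alpha * z_t - sigma * v
#     = alpha**2 * x0 + alpha*sigma*eps - alpha*sigma*eps + sigma**2 * x0
#     = (alpha**2 + sigma**2) * x0 = x0,
# which is exactly `sample * sqrt_alphas_cumprod[t] - model_output * sqrt_one_minus_alphas_cumprod[t]`.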

# 3. Clip "predicted x_0"
if self.config.clip_sample:
@@ -291,7 +312,12 @@ def step(
noise = torch.randn(
model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator
).to(model_output.device)
variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise
if self.variance_type == "fixed_small_log":
variance = self._get_variance(timestep, predicted_variance=predicted_variance) * noise
elif self.variance_type == "v_diffusion":
variance = torch.exp(0.5 * self._get_variance(timestep, predicted_variance)) * noise
else:
variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise
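# Editor's note (a sketch, not part of the PR): the three branches differ in what `_get_variance` returns.
# For "v_diffusion" it returns a log-variance (see the torch.log(...) branch above), so
# exp(0.5 * logvar) is the corresponding standard deviation; "fixed_small_log" is assumed to already
# return a standard deviation; the remaining types return a plain variance, hence the `** 0.5`.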

pred_prev_sample = pred_prev_sample + variance

@@ -306,6 +332,11 @@ def add_noise(
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
if self.variance_type == "v_diffusion":
alpha, sigma = self.get_alpha_sigma(original_samples, timesteps, original_samples.device)
z_t = alpha * original_samples + sigma * noise
return z_t

# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device)
@@ -325,3 +356,8 @@ def add_noise(

def __len__(self):
return self.config.num_train_timesteps

def get_alpha_sigma(self, sample, timesteps, device):
alpha = expand_to_shape(self.sqrt_alphas_cumprod, timesteps, sample.shape, device)
sigma = expand_to_shape(self.sqrt_one_minus_alphas_cumprod, timesteps, sample.shape, device)
return alpha, sigma
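
To make the new v-diffusion pathway concrete, here is a minimal training-step sketch (illustrative only, not part of this PR; the model call, names, and shapes are assumptions): `add_noise` produces z_t = alpha * x0 + sigma * noise, and the training target follows the same v = alpha * eps - sigma * x0 relation that `step` inverts:

# Hypothetical v-prediction training sketch (names and shapes are assumptions):
scheduler = DDPMScheduler(variance_type="v_diffusion", prediction_type="v")
x0 = torch.randn(4, 3, 64, 64)    # clean samples
noise = torch.randn_like(x0)
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (4,))

z_t = scheduler.add_noise(x0, noise, timesteps)    # alpha * x0 + sigma * noise
alpha, sigma = scheduler.get_alpha_sigma(x0, timesteps, x0.device)
v_target = alpha * noise - sigma * x0              # v = alpha * eps - sigma * x0
# loss = torch.nn.functional.mse_loss(model(z_t, timesteps), v_target)  # `model` is hypothetical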