diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index ee4f608e09aa..0327c44e3c4a 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -16,16 +16,27 @@ import math from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Literal, Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, deprecate +from ..utils import BaseOutput from .scheduling_utils import SchedulerMixin +def expand_to_shape(input, timesteps, shape, device): + """ + Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast + nicely with `shape`. Useful for parellizing operations over `shape[0]` number of diffusion steps at once. + """ + out = torch.gather(input.to(device), 0, timesteps.to(device)) + reshape = [shape[0]] + [1] * (len(shape) - 1) + out = out.reshape(*reshape) + return out + + @dataclass class DDPMSchedulerOutput(BaseOutput): """ @@ -102,6 +113,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): """ + _compatible_classes = [ + "DDIMScheduler", + "PNDMScheduler", + "LMSDiscreteScheduler", + "EulerDiscreteScheduler", + "EulerAncestralDiscreteScheduler", + ] + @register_to_config def __init__( self, @@ -112,15 +131,8 @@ def __init__( trained_betas: Optional[np.ndarray] = None, variance_type: str = "fixed_small", clip_sample: bool = True, - **kwargs, + prediction_type: Literal["epsilon", "sample", "v"] = "epsilon", ): - deprecate( - "tensor_format", - "0.6.0", - "If you're running your code in PyTorch, you can safely remove this argument.", - take_from=kwargs, - ) - if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) elif beta_schedule == "linear": @@ -142,8 +154,8 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - self.sigmas = 1 - self.alphas**2 - self.one = torch.tensor(1.0) + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod) # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 @@ -153,6 +165,7 @@ def __init__( self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) self.variance_type = variance_type + self.prediction_type = prediction_type def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: """ @@ -185,7 +198,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic def _get_variance(self, timestep, predicted_variance=None, variance_type=None): alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else torch.tensor(1.0) # For timestep > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) # and sample from it to get previous sample @@ -213,6 +226,8 @@ def _get_variance(self, timestep, predicted_variance=None, variance_type=None): max_log = self.betas[timestep] frac = (predicted_variance + 1) / 2 variance = frac * max_log + (1 - frac) * min_log + elif variance_type == "v_diffusion": + variance = torch.log(self.betas[timestep] * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) return variance @@ -221,7 +236,7 @@ def step( model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, - prediction_type: str = "epsilon", + # prediction_type: Literal["epsilon", "sample", "v"] = "epsilon", generator=None, return_dict: bool = True, ) -> Union[DDPMSchedulerOutput, Tuple]: @@ -234,9 +249,9 @@ def step( timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. - prediction_type (`str`): + prediction_type (`Literal["epsilon", "sample", "v"]`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion - process), `sample` (directly predicting the noisy sample), or `v` (see section 2.4 + process), `sample` (directly predicting the noisy sample`) or `v` (see section 2.4 https://imagen.research.google/video/paper.pdf) generator: random number generator. return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class @@ -247,6 +262,8 @@ def step( returning a tuple, the first element is the sample tensor. """ + if self.variance_type == "v_diffusion": + assert self.prediction_type == "v", "Need to use v prediction with v_diffusion" if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) else: @@ -254,23 +271,27 @@ def step( # 1. compute alphas, betas alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else torch.tensor(1.0) beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if prediction_type == "epsilon": + if self.prediction_type == "v": + # x_recon in p_mean_variance + pred_original_sample = ( + sample * self.sqrt_alphas_cumprod[timestep] + - model_output * self.sqrt_one_minus_alphas_cumprod[timestep] + ) + elif self.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - elif prediction_type == "sample": + + elif self.prediction_type == "sample": pred_original_sample = model_output - elif prediction_type == "v": - # v_t = alpha_t * epsilon - sigma_t * x - # need to merge the PRs for sigma to be available in DDPM - pred = sample * self.alphas[timestep] - model_output * self.sigmas[timestep] - eps = model_output * self.alphas[timestep] - sample * self.sigmas[timestep] else: - raise ValueError(f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or `v`") + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `v`" + ) # 3. Clip "predicted x_0" if self.config.clip_sample: @@ -291,7 +312,12 @@ def step( noise = torch.randn( model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator ).to(model_output.device) - variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise + if self.variance_type == "fixed_small_log": + variance = self._get_variance(timestep, predicted_variance=predicted_variance) * noise + elif self.variance_type == "v_diffusion": + variance = torch.exp(0.5 * self._get_variance(timestep, predicted_variance)) * noise + else: + variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise pred_prev_sample = pred_prev_sample + variance @@ -306,6 +332,11 @@ def add_noise( noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: + if self.variance_type == "v_diffusion": + alpha, sigma = self.get_alpha_sigma(original_samples, timesteps, original_samples.device) + z_t = alpha * original_samples + sigma * noise + return z_t + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) @@ -325,3 +356,8 @@ def add_noise( def __len__(self): return self.config.num_train_timesteps + + def get_alpha_sigma(self, sample, timesteps, device): + alpha = expand_to_shape(self.sqrt_alphas_cumprod, timesteps, sample.shape, device) + sigma = expand_to_shape(self.sqrt_one_minus_alphas_cumprod, timesteps, sample.shape, device) + return alpha, sigma