From 798263f6292db82b586fd14dcd5f5e665eb29004 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 12 Oct 2022 17:24:36 -0700 Subject: [PATCH 01/17] init v-pred pr --- src/diffusers/schedulers/scheduling_ddpm.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 04c92904a660..020850d680fe 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -220,7 +220,7 @@ def step( model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, - predict_epsilon=True, + prediction_type: str = "epsilon", generator=None, return_dict: bool = True, ) -> Union[DDPMSchedulerOutput, Tuple]: @@ -233,8 +233,10 @@ def step( timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. - predict_epsilon (`bool`): - optional flag to use when model predicts the samples directly instead of the noise, epsilon. + prediction_type (`str`): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample), or `v` (see section 2.4 + https://imagen.research.google/video/paper.pdf) generator: random number generator. return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class @@ -259,10 +261,15 @@ def step( # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if predict_epsilon: + if prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - else: + elif prediction_type == "sample": pred_original_sample = model_output + elif prediction_type == "v": + # v_t = alpha_t * epsilon - sigma_t * x + raise NotImplementedError(f"v prediction not yet implemented") + else: + raise ValueError(f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or `v`") # 3. 
Clip "predicted x_0" if self.config.clip_sample: From b7d0c1e84aa9be151f577a31de17cbb8c15b65d8 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 12 Oct 2022 17:32:52 -0700 Subject: [PATCH 02/17] placeholder code --- src/diffusers/schedulers/scheduling_ddpm.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 020850d680fe..74dfc57bd017 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -267,7 +267,10 @@ def step( pred_original_sample = model_output elif prediction_type == "v": # v_t = alpha_t * epsilon - sigma_t * x - raise NotImplementedError(f"v prediction not yet implemented") + # need to merge the PRs for sigma to be available in DDPM + # pred_original_sample = sample*self.alphas[t] - model_output * self.sigmas[t] + # eps = model_output*self.alphas[t] - sample * self.sigmas[t] + raise NotImplementedError(f"v prediction not yet implemented for DDPM") else: raise ValueError(f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or `v`") From 7eb4bfae6c663ac3974406a65fd8160a809ace8e Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 12 Oct 2022 17:39:48 -0700 Subject: [PATCH 03/17] up --- src/diffusers/schedulers/scheduling_ddpm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 74dfc57bd017..ded5fea168e5 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -142,6 +142,7 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + self.sigmas = 1 - self.alphas ** 2 self.one = torch.tensor(1.0) # standard deviation of the initial noise distribution @@ -268,8 +269,8 @@ def step( elif prediction_type == "v": # v_t = alpha_t * epsilon - sigma_t * x # need to merge the PRs for sigma to be available in DDPM - # pred_original_sample = sample*self.alphas[t] - model_output * self.sigmas[t] - # eps = model_output*self.alphas[t] - sample * self.sigmas[t] + pred_original_sample = sample*self.alphas[t] - model_output * self.sigmas[t] + eps = model_output*self.alphas[t] - sample * self.sigmas[t] raise NotImplementedError(f"v prediction not yet implemented for DDPM") else: raise ValueError(f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or `v`") From 3eb2593d9a48f0c0861bd9d4b1089ef9843e57be Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 12 Oct 2022 20:10:03 -0700 Subject: [PATCH 04/17] a few more additions --- src/diffusers/schedulers/scheduling_ddpm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index ded5fea168e5..6ab2498956d7 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -142,7 +142,7 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - self.sigmas = 1 - self.alphas ** 2 + self.sigmas = 1 - self.alphas**2 self.one = torch.tensor(1.0) # standard deviation of the initial noise distribution @@ -269,8 +269,8 @@ def step( elif prediction_type == "v": # v_t = alpha_t * epsilon - sigma_t * x # need to merge the PRs for sigma to be available in DDPM - pred_original_sample = sample*self.alphas[t] - model_output 
* self.sigmas[t] - eps = model_output*self.alphas[t] - sample * self.sigmas[t] + pred_original_sample = sample * self.alphas[t] - model_output * self.sigmas[t] + eps = model_output * self.alphas[t] - sample * self.sigmas[t] raise NotImplementedError(f"v prediction not yet implemented for DDPM") else: raise ValueError(f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or `v`") From 4c6850473dfbdb7b6a792141d6683669e8a2e793 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Tue, 18 Oct 2022 11:22:46 -0700 Subject: [PATCH 05/17] add ddim --- src/diffusers/schedulers/scheduling_ddim.py | 29 ++++++++++++++--- src/diffusers/schedulers/scheduling_ddpm.py | 36 ++++++++++----------- 2 files changed, 41 insertions(+), 24 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 33d9bafb8aed..b12a565ee196 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -145,6 +145,7 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + self.sigmas = 1 - self.alphas**2 # At every step in ddim, we are looking into the previous alphas_cumprod # For the final step, there is no previous alphas_cumprod because we are already at 0 @@ -209,6 +210,7 @@ def step( model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, + prediction_type: str = "epsilon", eta: float = 0.0, use_clipped_model_output: bool = False, generator=None, @@ -223,6 +225,10 @@ def step( timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. + prediction_type (`str`): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample), or `v` (see section 2.4 + https://imagen.research.google/video/paper.pdf) eta (`float`): weight of noise for added noise in diffusion step. use_clipped_model_output (`bool`): TODO generator: random number generator. @@ -243,14 +249,14 @@ def step( # Ideally, read DDIM paper in-detail understanding # Notation ( -> - # - pred_noise_t -> e_theta(x_t, t) - # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - pred_noise_t -> e_theta(x_t, timestep) + # - pred_original_sample -> f_theta(x_t, timestep) or x_0 # - std_dev_t -> sigma_t # - eta -> η # - pred_sample_direction -> "direction pointing to x_t" # - pred_prev_sample -> "x_t-1" - # 1. get previous step value (=t-1) + # 1. get previous step value (=timestep-1) prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps # 2. compute alphas, betas @@ -261,7 +267,20 @@ def step( # 3. 
compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + if prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + eps = torch.tensor(1) + elif prediction_type == "sample": + pred_original_sample = model_output + eps = torch.tensor(1) + elif prediction_type == "v": + # v_t = alpha_t * epsilon - sigma_t * x + # need to merge the PRs for sigma to be available in DDPM + pred_original_sample = sample * self.alphas[timestep] - model_output * self.sigmas[timestep] + eps = model_output * self.alphas[timestep] - sample * self.sigmas[timestep] + raise NotImplementedError(f"v prediction not yet implemented for DDPM") + else: + raise ValueError(f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or `v`") # 4. Clip "predicted x_0" if self.config.clip_sample: @@ -280,7 +299,7 @@ def step( pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + eps * pred_sample_direction if eta > 0: # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 6ab2498956d7..daf17a2de3e2 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -183,14 +183,14 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic )[::-1].copy() self.timesteps = torch.from_numpy(timesteps).to(device) - def _get_variance(self, t, predicted_variance=None, variance_type=None): - alpha_prod_t = self.alphas_cumprod[t] - alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one + def _get_variance(self, timestep, predicted_variance=None, variance_type=None): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one - # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + # For timestep > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) # and sample from it to get previous sample - # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample - variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t] + # x_{timestep-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep] if variance_type is None: variance_type = self.config.variance_type @@ -202,15 +202,15 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None): elif variance_type == "fixed_small_log": variance = torch.log(torch.clamp(variance, min=1e-20)) elif variance_type == "fixed_large": - variance = self.betas[t] + variance = self.betas[timestep] elif variance_type == "fixed_large_log": # Glide max_log - variance = torch.log(self.betas[t]) + variance = torch.log(self.betas[timestep]) elif variance_type == "learned": return predicted_variance elif variance_type == "learned_range": min_log 
= variance - max_log = self.betas[t] + max_log = self.betas[timestep] frac = (predicted_variance + 1) / 2 variance = frac * max_log + (1 - frac) * min_log @@ -247,16 +247,14 @@ def step( returning a tuple, the first element is the sample tensor. """ - t = timestep - if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) else: predicted_variance = None # 1. compute alphas, betas - alpha_prod_t = self.alphas_cumprod[t] - alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev @@ -269,8 +267,8 @@ def step( elif prediction_type == "v": # v_t = alpha_t * epsilon - sigma_t * x # need to merge the PRs for sigma to be available in DDPM - pred_original_sample = sample * self.alphas[t] - model_output * self.sigmas[t] - eps = model_output * self.alphas[t] - sample * self.sigmas[t] + pred = sample * self.alphas[timestep] - model_output * self.sigmas[timestep] + eps = model_output * self.alphas[timestep] - sample * self.sigmas[timestep] raise NotImplementedError(f"v prediction not yet implemented for DDPM") else: raise ValueError(f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or `v`") @@ -281,8 +279,8 @@ def step( # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf - pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t - current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[timestep]) / beta_prod_t + current_sample_coeff = self.alphas[timestep] ** (0.5) * beta_prod_t_prev / beta_prod_t # 5. Compute predicted previous sample µ_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf @@ -290,11 +288,11 @@ def step( # 6. 
Add noise variance = 0 - if t > 0: + if timestep > 0: noise = torch.randn( model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator ).to(model_output.device) - variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise + variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise pred_prev_sample = pred_prev_sample + variance From ac6be90a718d1dc3fe5c57e411082577a19454ac Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Tue, 18 Oct 2022 11:42:51 -0700 Subject: [PATCH 06/17] style --- src/diffusers/schedulers/scheduling_ddim.py | 1 - src/diffusers/schedulers/scheduling_ddpm.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index b12a565ee196..abdcb3e81a58 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -278,7 +278,6 @@ def step( # need to merge the PRs for sigma to be available in DDPM pred_original_sample = sample * self.alphas[timestep] - model_output * self.sigmas[timestep] eps = model_output * self.alphas[timestep] - sample * self.sigmas[timestep] - raise NotImplementedError(f"v prediction not yet implemented for DDPM") else: raise ValueError(f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or `v`") diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index daf17a2de3e2..ee4f608e09aa 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -269,7 +269,6 @@ def step( # need to merge the PRs for sigma to be available in DDPM pred = sample * self.alphas[timestep] - model_output * self.sigmas[timestep] eps = model_output * self.alphas[timestep] - sample * self.sigmas[timestep] - raise NotImplementedError(f"v prediction not yet implemented for DDPM") else: raise ValueError(f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or `v`") From f00d896a1e693b371858f120f311b4bc536105c3 Mon Sep 17 00:00:00 2001 From: Ben Glickenhaus Date: Wed, 9 Nov 2022 14:33:15 -0500 Subject: [PATCH 07/17] DDPM changes to support v diffusion (#1121) * v diffusion support for ddpm * quality and style * variable name consistency * missing base case * pass prediction type along in the pipeline * put prediction type in scheduler config * style --- src/diffusers/schedulers/scheduling_ddpm.py | 88 +++++++++++++++------ 1 file changed, 62 insertions(+), 26 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index ee4f608e09aa..0327c44e3c4a 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -16,16 +16,27 @@ import math from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Literal, Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, deprecate +from ..utils import BaseOutput from .scheduling_utils import SchedulerMixin +def expand_to_shape(input, timesteps, shape, device): + """ + Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast + nicely with `shape`. Useful for parellizing operations over `shape[0]` number of diffusion steps at once. 
+ """ + out = torch.gather(input.to(device), 0, timesteps.to(device)) + reshape = [shape[0]] + [1] * (len(shape) - 1) + out = out.reshape(*reshape) + return out + + @dataclass class DDPMSchedulerOutput(BaseOutput): """ @@ -102,6 +113,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): """ + _compatible_classes = [ + "DDIMScheduler", + "PNDMScheduler", + "LMSDiscreteScheduler", + "EulerDiscreteScheduler", + "EulerAncestralDiscreteScheduler", + ] + @register_to_config def __init__( self, @@ -112,15 +131,8 @@ def __init__( trained_betas: Optional[np.ndarray] = None, variance_type: str = "fixed_small", clip_sample: bool = True, - **kwargs, + prediction_type: Literal["epsilon", "sample", "v"] = "epsilon", ): - deprecate( - "tensor_format", - "0.6.0", - "If you're running your code in PyTorch, you can safely remove this argument.", - take_from=kwargs, - ) - if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) elif beta_schedule == "linear": @@ -142,8 +154,8 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - self.sigmas = 1 - self.alphas**2 - self.one = torch.tensor(1.0) + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod) # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 @@ -153,6 +165,7 @@ def __init__( self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) self.variance_type = variance_type + self.prediction_type = prediction_type def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: """ @@ -185,7 +198,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic def _get_variance(self, timestep, predicted_variance=None, variance_type=None): alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else torch.tensor(1.0) # For timestep > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) # and sample from it to get previous sample @@ -213,6 +226,8 @@ def _get_variance(self, timestep, predicted_variance=None, variance_type=None): max_log = self.betas[timestep] frac = (predicted_variance + 1) / 2 variance = frac * max_log + (1 - frac) * min_log + elif variance_type == "v_diffusion": + variance = torch.log(self.betas[timestep] * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) return variance @@ -221,7 +236,7 @@ def step( model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, - prediction_type: str = "epsilon", + # prediction_type: Literal["epsilon", "sample", "v"] = "epsilon", generator=None, return_dict: bool = True, ) -> Union[DDPMSchedulerOutput, Tuple]: @@ -234,9 +249,9 @@ def step( timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. - prediction_type (`str`): + prediction_type (`Literal["epsilon", "sample", "v"]`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion - process), `sample` (directly predicting the noisy sample), or `v` (see section 2.4 + process), `sample` (directly predicting the noisy sample`) or `v` (see section 2.4 https://imagen.research.google/video/paper.pdf) generator: random number generator. 
return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class @@ -247,6 +262,8 @@ def step( returning a tuple, the first element is the sample tensor. """ + if self.variance_type == "v_diffusion": + assert self.prediction_type == "v", "Need to use v prediction with v_diffusion" if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) else: @@ -254,23 +271,27 @@ def step( # 1. compute alphas, betas alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else torch.tensor(1.0) beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if prediction_type == "epsilon": + if self.prediction_type == "v": + # x_recon in p_mean_variance + pred_original_sample = ( + sample * self.sqrt_alphas_cumprod[timestep] + - model_output * self.sqrt_one_minus_alphas_cumprod[timestep] + ) + elif self.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - elif prediction_type == "sample": + + elif self.prediction_type == "sample": pred_original_sample = model_output - elif prediction_type == "v": - # v_t = alpha_t * epsilon - sigma_t * x - # need to merge the PRs for sigma to be available in DDPM - pred = sample * self.alphas[timestep] - model_output * self.sigmas[timestep] - eps = model_output * self.alphas[timestep] - sample * self.sigmas[timestep] else: - raise ValueError(f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or `v`") + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `v`" + ) # 3. 
Clip "predicted x_0" if self.config.clip_sample: @@ -291,7 +312,12 @@ def step( noise = torch.randn( model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator ).to(model_output.device) - variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise + if self.variance_type == "fixed_small_log": + variance = self._get_variance(timestep, predicted_variance=predicted_variance) * noise + elif self.variance_type == "v_diffusion": + variance = torch.exp(0.5 * self._get_variance(timestep, predicted_variance)) * noise + else: + variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise pred_prev_sample = pred_prev_sample + variance @@ -306,6 +332,11 @@ def add_noise( noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: + if self.variance_type == "v_diffusion": + alpha, sigma = self.get_alpha_sigma(original_samples, timesteps, original_samples.device) + z_t = alpha * original_samples + sigma * noise + return z_t + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) @@ -325,3 +356,8 @@ def add_noise( def __len__(self): return self.config.num_train_timesteps + + def get_alpha_sigma(self, sample, timesteps, device): + alpha = expand_to_shape(self.sqrt_alphas_cumprod, timesteps, sample.shape, device) + sigma = expand_to_shape(self.sqrt_one_minus_alphas_cumprod, timesteps, sample.shape, device) + return alpha, sigma From 56164f56fb0ab92a212335b0f112a47766559b41 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 9 Nov 2022 11:53:25 -0800 Subject: [PATCH 08/17] quality --- src/diffusers/schedulers/scheduling_ddpm.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 03659040d769..d403c4f5959c 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -274,8 +274,6 @@ def step( new_config["predict_epsilon"] = predict_epsilon self._internal_dict = FrozenDict(new_config) - t = timestep - if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) else: From 11362ae5d2ad756c4762c86bfc7ab58d5ca4877a Mon Sep 17 00:00:00 2001 From: Ben Glickenhaus Date: Thu, 17 Nov 2022 13:26:19 -0500 Subject: [PATCH 09/17] V prediction ddim (#1313) * v diffusion support for ddpm * quality and style * variable name consistency * missing base case * pass prediction type along in the pipeline * put prediction type in scheduler config * style * try to train on ddim * changes to ddim * ddim v prediction works to train butterflies example * fix bad merge, style and quality * try to fix broken doc strings * second pass * one more * white space * Update src/diffusers/schedulers/scheduling_ddim.py * remove extra lines * Update src/diffusers/schedulers/scheduling_ddim.py Co-authored-by: Ben Glickenhaus Co-authored-by: Nathan Lambert --- examples/v_prediction/train_butterflies.py | 227 ++++++++++++++++++++ src/diffusers/schedulers/scheduling_ddim.py | 89 ++++++-- 2 files changed, 299 insertions(+), 17 deletions(-) create mode 100644 examples/v_prediction/train_butterflies.py diff --git a/examples/v_prediction/train_butterflies.py 
b/examples/v_prediction/train_butterflies.py
new file mode 100644
index 000000000000..5074ece86a98
--- /dev/null
+++ b/examples/v_prediction/train_butterflies.py
@@ -0,0 +1,227 @@
+import glob
+import os
+from dataclasses import dataclass
+
+import torch
+import torch.nn.functional as F
+
+from accelerate import Accelerator
+from datasets import load_dataset
+from diffusers import DDIMPipeline, DDIMScheduler, DDPMPipeline, DDPMScheduler, UNet2DModel
+from diffusers.hub_utils import init_git_repo, push_to_hub
+from diffusers.optimization import get_cosine_schedule_with_warmup
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+
+
+@dataclass
+class TrainingConfig:
+    image_size = 128  # the generated image resolution
+    train_batch_size = 16
+    eval_batch_size = 16  # how many images to sample during evaluation
+    num_epochs = 50
+    gradient_accumulation_steps = 1
+    learning_rate = 5e-5
+    lr_warmup_steps = 500
+    save_image_epochs = 10
+    save_model_epochs = 30
+    mixed_precision = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
+    output_dir = "ddim-butterflies-128-v-diffusion"  # the model name locally and on the HF Hub
+
+    push_to_hub = False  # whether to upload the saved model to the HF Hub
+    hub_private_repo = False
+    overwrite_output_dir = True  # overwrite the old model when re-running the notebook
+    seed = 0
+
+
+config = TrainingConfig()
+
+
+config.dataset_name = "huggan/smithsonian_butterflies_subset"
+dataset = load_dataset(config.dataset_name, split="train")
+
+
+preprocess = transforms.Compose(
+    [
+        transforms.Resize((config.image_size, config.image_size)),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        transforms.Normalize([0.5], [0.5]),
+    ]
+)
+
+
+def transform(examples):
+    images = [preprocess(image.convert("RGB")) for image in examples["image"]]
+    return {"images": images}
+
+
+dataset.set_transform(transform)
+
+
+train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)
+
+
+model = UNet2DModel(
+    sample_size=config.image_size,  # the target image resolution
+    in_channels=3,  # the number of input channels, 3 for RGB images
+    out_channels=3,  # the number of output channels
+    layers_per_block=2,  # how many ResNet layers to use per UNet block
+    block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block
+    down_block_types=(
+        "DownBlock2D",  # a regular ResNet downsampling block
+        "DownBlock2D",
+        "DownBlock2D",
+        "DownBlock2D",
+        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
+        "DownBlock2D",
+    ),
+    up_block_types=(
+        "UpBlock2D",  # a regular ResNet upsampling block
+        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
+        "UpBlock2D",
+        "UpBlock2D",
+        "UpBlock2D",
+        "UpBlock2D",
+    ),
+)
+
+
+if config.output_dir.startswith("ddpm"):
+    noise_scheduler = DDPMScheduler(
+        num_train_timesteps=1000,
+        beta_schedule="squaredcos_cap_v2",
+        variance_type="v_diffusion",
+        prediction_type="v",
+    )
+else:
+    noise_scheduler = DDIMScheduler(
+        num_train_timesteps=1000,
+        beta_schedule="squaredcos_cap_v2",
+        variance_type="v_diffusion",
+        prediction_type="v",
+    )
+
+
+optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
+
+
+lr_scheduler = get_cosine_schedule_with_warmup(
+    optimizer=optimizer,
+    num_warmup_steps=config.lr_warmup_steps,
+    num_training_steps=(len(train_dataloader) * config.num_epochs),
+)
+
+
+def make_grid(images, rows, cols):
+    w, h = images[0].size
+    grid 
= Image.new("RGB", size=(cols * w, rows * h)) + for i, image in enumerate(images): + grid.paste(image, box=(i % cols * w, i // cols * h)) + return grid + + +def evaluate(config, epoch, pipeline): + # Sample some images from random noise (this is the backward diffusion process). + # The default pipeline output type is `List[PIL.Image]` + images = pipeline( + batch_size=config.eval_batch_size, + generator=torch.manual_seed(config.seed), + ).images + + # Make a grid out of the images + image_grid = make_grid(images, rows=4, cols=4) + + # Save the images + test_dir = os.path.join(config.output_dir, "samples") + os.makedirs(test_dir, exist_ok=True) + image_grid.save(f"{test_dir}/{epoch:04d}.png") + + +def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler): + # Initialize accelerator and tensorboard logging + accelerator = Accelerator( + mixed_precision=config.mixed_precision, + gradient_accumulation_steps=config.gradient_accumulation_steps, + log_with="tensorboard", + logging_dir=os.path.join(config.output_dir, "logs"), + ) + if accelerator.is_main_process: + if config.push_to_hub: + repo = init_git_repo(config, at_init=True) + accelerator.init_trackers("train_example") + + # Prepare everything + # There is no specific order to remember, you just need to unpack the + # objects in the same order you gave them to the prepare method. + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + global_step = 0 + + if config.output_dir.startswith("ddpm"): + pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler) + else: + pipeline = DDIMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler) + + evaluate(config, 0, pipeline) + + # Now you train the model + for epoch in range(config.num_epochs): + progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process) + progress_bar.set_description(f"Epoch {epoch}") + + for step, batch in enumerate(train_dataloader): + clean_images = batch["images"] + # Sample noise to add to the images + noise = torch.randn(clean_images.shape).to(clean_images.device) + bs = clean_images.shape[0] + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device).long() + + with accelerator.accumulate(model): + # Predict the noise residual + alpha_t, sigma_t = noise_scheduler.get_alpha_sigma(clean_images, timesteps, accelerator.device) + z_t = alpha_t * clean_images + sigma_t * noise + noise_pred = model(z_t, timesteps).sample + v = alpha_t * noise - sigma_t * clean_images + loss = F.mse_loss(noise_pred, v) + accelerator.backward(loss) + + accelerator.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + progress_bar.update(1) + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + global_step += 1 + + # After each epoch you optionally sample some demo images with evaluate() and save the model + if accelerator.is_main_process: + if config.output_dir.startswith("ddpm"): + pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler) + else: + pipeline = DDIMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler) + + if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs 
- 1: + evaluate(config, epoch, pipeline) + + if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1: + if config.push_to_hub: + push_to_hub(config, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=True) + else: + pipeline.save_pretrained(config.output_dir) + + +args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler) + +train_loop(*args) + +sample_images = sorted(glob.glob(f"{config.output_dir}/samples/*.png")) +Image.open(sample_images[-1]) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 972e4d45b079..89d90ba60ad4 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -17,7 +17,7 @@ import math from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Literal, Optional, Tuple, Union import numpy as np import torch @@ -27,6 +27,17 @@ from .scheduling_utils import SchedulerMixin +def expand_to_shape(input, timesteps, shape, device): + """ + Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast + nicely with `shape`. Useful for parellizing operations over `shape[0]` number of diffusion steps at once. + """ + out = torch.gather(input.to(device), 0, timesteps.to(device)) + reshape = [shape[0]] + [1] * (len(shape) - 1) + out = out.reshape(*reshape) + return out + + @dataclass # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM class DDIMSchedulerOutput(BaseOutput): @@ -75,6 +86,18 @@ def alpha_bar(time_step): return torch.tensor(betas) +def t_to_alpha_sigma(num_diffusion_timesteps): + """Returns the scaling factors for the clean image and for the noise, given + a timestep.""" + alphas = torch.cos( + torch.tensor([(t / num_diffusion_timesteps) * math.pi / 2 for t in range(num_diffusion_timesteps)]) + ) + sigmas = torch.sin( + torch.tensor([(t / num_diffusion_timesteps) * math.pi / 2 for t in range(num_diffusion_timesteps)]) + ) + return alphas, sigmas + + class DDIMScheduler(SchedulerMixin, ConfigMixin): """ Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising @@ -128,7 +151,10 @@ def __init__( trained_betas: Optional[np.ndarray] = None, clip_sample: bool = True, set_alpha_to_one: bool = True, + variance_type: str = "fixed", steps_offset: int = 0, + prediction_type: Literal["epsilon", "sample", "v"] = "epsilon", + **kwargs, ): if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) @@ -145,15 +171,18 @@ def __init__( else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + self.variance_type = variance_type self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - self.sigmas = 1 - self.alphas**2 + if prediction_type == "v": + self.alphas, self.sigmas = t_to_alpha_sigma(num_train_timesteps) # At every step in ddim, we are looking into the previous alphas_cumprod # For the final step, there is no previous alphas_cumprod because we are already at 0 # `set_alpha_to_one` decides whether we set this parameter simply to one or # whether we use the final alpha of the "non-previous" one. 
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + self.final_sigma = torch.tensor(0.0) if set_alpha_to_one else self.sigmas[0] # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 @@ -161,6 +190,8 @@ def __init__( # setable values self.num_inference_steps = None self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + self.variance_type = variance_type + self.prediction_type = prediction_type def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: """ @@ -170,20 +201,31 @@ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = Args: sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep - Returns: `torch.FloatTensor`: scaled input sample """ return sample - def _get_variance(self, timestep, prev_timestep): + def _get_variance(self, timestep, prev_timestep, eta=0): alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev - variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) - + if self.variance_type == "fixed": + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + elif self.variance_type == "v_diffusion": + # If eta > 0, adjust the scaling factor for the predicted noise + # downward according to the amount of additional noise to add + alpha_prev = self.alphas[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + sigma_prev = self.sigmas[prev_timestep] if prev_timestep >= 0 else self.final_sigma + if eta: + numerator = eta * (sigma_prev**2 / self.sigmas[timestep] ** 2).clamp(min=1.0e-7).sqrt() + else: + numerator = 0 + denominator = (1 - self.alphas[timestep] ** 2 / alpha_prev**2).clamp(min=1.0e-7).sqrt() + ddim_sigma = (numerator * denominator).clamp(min=1.0e-7) + variance = (sigma_prev**2 - ddim_sigma**2).clamp(min=1.0e-7).sqrt() return variance def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): @@ -207,7 +249,6 @@ def step( model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, - prediction_type: str = "epsilon", eta: float = 0.0, use_clipped_model_output: bool = False, generator=None, @@ -271,19 +312,21 @@ def step( # 3. 
compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - if prediction_type == "epsilon": + if self.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) eps = torch.tensor(1) - elif prediction_type == "sample": + elif self.prediction_type == "sample": pred_original_sample = model_output eps = torch.tensor(1) - elif prediction_type == "v": + elif self.prediction_type == "v": # v_t = alpha_t * epsilon - sigma_t * x # need to merge the PRs for sigma to be available in DDPM pred_original_sample = sample * self.alphas[timestep] - model_output * self.sigmas[timestep] - eps = model_output * self.alphas[timestep] - sample * self.sigmas[timestep] + eps = model_output * self.alphas[timestep] + sample * self.sigmas[timestep] else: - raise ValueError(f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or `v`") + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `v`" + ) # 4. Clip "predicted x_0" if self.config.clip_sample: @@ -291,7 +334,7 @@ def step( # 5. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) - variance = self._get_variance(timestep, prev_timestep) + variance = self._get_variance(timestep, prev_timestep, eta) std_dev_t = eta * variance ** (0.5) if use_clipped_model_output: @@ -299,10 +342,14 @@ def step( model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output + if self.prediction_type == "epsilon": + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output - # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + eps * pred_sample_direction + # 7. 
compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + eps * pred_sample_direction + else: + alpha_prev = self.alphas[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + prev_sample = pred_original_sample * alpha_prev + eps * variance if eta > 0: # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 @@ -325,7 +372,6 @@ def step( variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise prev_sample = prev_sample + variance - if not return_dict: return (prev_sample,) @@ -337,6 +383,10 @@ def add_noise( noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: + if self.variance_type == "v_diffusion": + alpha, sigma = self.get_alpha_sigma(original_samples, timesteps, original_samples.device) + z_t = alpha * original_samples + sigma * noise + return z_t # Make sure alphas_cumprod and timestep have same device and dtype as original_samples self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) @@ -356,3 +406,8 @@ def add_noise( def __len__(self): return self.config.num_train_timesteps + + def get_alpha_sigma(self, sample, timesteps, device): + alpha = expand_to_shape(self.alphas, timesteps, sample.shape, device) + sigma = expand_to_shape(self.sigmas, timesteps, sample.shape, device) + return alpha, sigma From e39198306bf58a36dcd224aa98e713b7f95eb054 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Thu, 17 Nov 2022 14:43:14 -0800 Subject: [PATCH 10/17] fix tests --- src/diffusers/schedulers/scheduling_ddim.py | 12 ++++++++++-- src/diffusers/schedulers/scheduling_ddpm.py | 14 ++++++++++---- tests/test_scheduler.py | 8 ++++---- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 89d90ba60ad4..5945b0c1eaa4 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -129,6 +129,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. + prediction_type (`Literal["epsilon", "sample", "v"]`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v` (see section 2.4 + https://imagen.research.google/video/paper.pdf) """ @@ -181,8 +185,12 @@ def __init__( # For the final step, there is no previous alphas_cumprod because we are already at 0 # `set_alpha_to_one` decides whether we set this parameter simply to one or # whether we use the final alpha of the "non-previous" one. 
- self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] - self.final_sigma = torch.tensor(0.0) if set_alpha_to_one else self.sigmas[0] + if set_alpha_to_one: + self.final_alpha_cumprod = torch.tensor(1.0) + self.final_sigma = torch.tensor(0.0) # TODO rename set_alpha_to_one for something general with sigma=0 + else: + self.final_alpha_cumprod = self.alphas_cumprod[0] + self.final_sigma = self.sigmas[0] if prediction_type == "v" else None # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index d403c4f5959c..9931d8c14382 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -114,6 +114,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v` (see section 2.4 https://imagen.research.google/video/paper.pdf) + predict_epsilon (`bool`, default `True`): + depreciated flag (removing v0.10.0) for epsilon vs. direct sample prediction. """ _compatible_classes = [ @@ -136,6 +138,7 @@ def __init__( variance_type: str = "fixed_small", clip_sample: bool = True, prediction_type: Literal["epsilon", "sample", "v"] = "epsilon", + predict_epsilon: bool = True, ): if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) @@ -265,8 +268,8 @@ def step( if self.variance_type == "v_diffusion": assert self.prediction_type == "v", "Need to use v prediction with v_diffusion" message = ( - "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler =" - " DDPMScheduler.from_config(, predict_epsilon=True)`." + "Please make sure to instantiate your scheduler with `prediction_type=epsilon` instead. E.g. `scheduler =" + " DDPMScheduler.from_config(, prediction_type=epsilon)`." 
        )
        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
        if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
@@ -293,11 +296,14 @@ def step(
                 sample * self.sqrt_alphas_cumprod[timestep]
                 - model_output * self.sqrt_one_minus_alphas_cumprod[timestep]
             )
+
+        # no check on predict_epsilon is needed here; the deprecation flag is handled above
+        elif self.prediction_type == "sample" or not self.config.predict_epsilon:
+            pred_original_sample = model_output
+
         elif self.prediction_type == "epsilon" or self.config.predict_epsilon:
             pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
-        elif self.prediction_type == "sample":
-            pred_original_sample = model_output
         else:
             raise ValueError(
                 f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `v`"
             )
diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
index ab5217151125..5a19d1059cf8 100755
--- a/tests/test_scheduler.py
+++ b/tests/test_scheduler.py
@@ -393,9 +393,9 @@ def test_clip_sample(self):
         for clip_sample in [True, False]:
             self.check_over_configs(clip_sample=clip_sample)

-    def test_predict_epsilon(self):
-        for predict_epsilon in [True, False]:
-            self.check_over_configs(predict_epsilon=predict_epsilon)
+    def test_prediction_type(self):
+        for prediction_type in ["epsilon", "sample", "v"]:
+            self.check_over_configs(prediction_type=prediction_type)

     def test_deprecated_epsilon(self):
         deprecate("remove this test", "0.10.0", "remove")
@@ -407,7 +407,7 @@ def test_deprecated_epsilon(self):
         time_step = 4

         scheduler = scheduler_class(**scheduler_config)
-        scheduler_eps = scheduler_class(predict_epsilon=False, **scheduler_config)
+        scheduler_eps = scheduler_class(prediction_type="sample", **scheduler_config)

         kwargs = {}
         if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):

From 3adf87b2d9030ffc4778809ea35381ef574efb2f Mon Sep 17 00:00:00 2001
From: Nathan Lambert
Date: Thu, 17 Nov 2022 14:49:55 -0800
Subject: [PATCH 11/17] add ddim pred type test

---
 tests/test_scheduler.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
index 6ea5ee6e9216..075759c4fdc2 100755
--- a/tests/test_scheduler.py
+++ b/tests/test_scheduler.py
@@ -728,6 +728,10 @@ def test_schedules(self):
         for schedule in ["linear", "squaredcos_cap_v2"]:
             self.check_over_configs(beta_schedule=schedule)

+    def test_prediction_type(self):
+        for prediction_type in ["epsilon", "sample", "v"]:
+            self.check_over_configs(prediction_type=prediction_type)
+
     def test_clip_sample(self):
         for clip_sample in [True, False]:
             self.check_over_configs(clip_sample=clip_sample)

From c1a05842134fe90944913c07af24989fe5220f64 Mon Sep 17 00:00:00 2001
From: Nathan Lambert
Date: Thu, 17 Nov 2022 14:51:59 -0800
Subject: [PATCH 12/17] style

---
 src/diffusers/schedulers/scheduling_ddpm.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index 351651660d3e..b7a2db6ab9e7 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -333,7 +333,9 @@ def step(
         elif self.variance_type == "v_diffusion":
             variance = torch.exp(0.5 * self._get_variance(timestep, predicted_variance)) * variance_noise
         else:
-            variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * variance_noise
+            variance = (
+                self._get_variance(timestep, predicted_variance=predicted_variance) ** 
0.5 + ) * variance_noise pred_prev_sample = pred_prev_sample + variance From e701a97838f0e20f3e0af162f2f48d043c9dd629 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Thu, 17 Nov 2022 14:56:19 -0800 Subject: [PATCH 13/17] change name from v to velocity --- examples/v_prediction/train_butterflies.py | 4 ++-- src/diffusers/schedulers/scheduling_ddim.py | 10 +++++----- src/diffusers/schedulers/scheduling_ddpm.py | 8 ++++---- tests/test_scheduler.py | 4 ++-- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/examples/v_prediction/train_butterflies.py b/examples/v_prediction/train_butterflies.py index 5074ece86a98..ea2bc71a2b6f 100644 --- a/examples/v_prediction/train_butterflies.py +++ b/examples/v_prediction/train_butterflies.py @@ -93,14 +93,14 @@ def transform(examples): num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2", variance_type="v_diffusion", - prediction_type="v", + prediction_type="velocity", ) else: noise_scheduler = DDIMScheduler( num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2", variance_type="v_diffusion", - prediction_type="v", + prediction_type="velocity", ) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index f6d2eb81398d..a49b279daa35 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -129,7 +129,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. - prediction_type (`Literal["epsilon", "sample", "v"]`, optional): + prediction_type (`Literal["epsilon", "sample", "velocity"]`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v` (see section 2.4 https://imagen.research.google/video/paper.pdf) @@ -150,7 +150,7 @@ def __init__( set_alpha_to_one: bool = True, variance_type: str = "fixed", steps_offset: int = 0, - prediction_type: Literal["epsilon", "sample", "v"] = "epsilon", + prediction_type: Literal["epsilon", "sample", "velocity"] = "epsilon", **kwargs, ): if trained_betas is not None: @@ -171,7 +171,7 @@ def __init__( self.variance_type = variance_type self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - if prediction_type == "v": + if prediction_type == "velocity": self.alphas, self.sigmas = t_to_alpha_sigma(num_train_timesteps) # At every step in ddim, we are looking into the previous alphas_cumprod @@ -183,7 +183,7 @@ def __init__( self.final_sigma = torch.tensor(0.0) # TODO rename set_alpha_to_one for something general with sigma=0 else: self.final_alpha_cumprod = self.alphas_cumprod[0] - self.final_sigma = self.sigmas[0] if prediction_type == "v" else None + self.final_sigma = self.sigmas[0] if prediction_type == "velocity" else None # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 @@ -319,7 +319,7 @@ def step( elif self.prediction_type == "sample": pred_original_sample = model_output eps = torch.tensor(1) - elif self.prediction_type == "v": + elif self.prediction_type == "velocity": # v_t = alpha_t * epsilon - sigma_t * x # need to merge the PRs for sigma to be available in DDPM pred_original_sample = sample * self.alphas[timestep] - model_output * self.sigmas[timestep] diff --git a/src/diffusers/schedulers/scheduling_ddpm.py 
b/src/diffusers/schedulers/scheduling_ddpm.py index b7a2db6ab9e7..e6c204cf63df 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -110,7 +110,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. clip_sample (`bool`, default `True`): option to clip predicted sample between -1 and 1 for numerical stability. - prediction_type (`Literal["epsilon", "sample", "v"]`, optional): + prediction_type (`Literal["epsilon", "sample", "velocity"]`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v` (see section 2.4 https://imagen.research.google/video/paper.pdf) @@ -130,7 +130,7 @@ def __init__( trained_betas: Optional[np.ndarray] = None, variance_type: str = "fixed_small", clip_sample: bool = True, - prediction_type: Literal["epsilon", "sample", "v"] = "epsilon", + prediction_type: Literal["epsilon", "sample", "velocity"] = "epsilon", predict_epsilon: bool = True, ): if trained_betas is not None: @@ -260,7 +260,7 @@ def step( """ if self.variance_type == "v_diffusion": - assert self.prediction_type == "v", "Need to use v prediction with v_diffusion" + assert self.prediction_type == "velocity", "Need to use v prediction with v_diffusion" message = ( "Please make sure to instantiate your scheduler with `prediction_type=epsilon` instead. E.g. `scheduler =" " DDPMScheduler.from_config(, prediction_type=epsilon)`." @@ -284,7 +284,7 @@ def step( # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if self.prediction_type == "v": + if self.prediction_type == "velocity": # x_recon in p_mean_variance pred_original_sample = ( sample * self.sqrt_alphas_cumprod[timestep] diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 075759c4fdc2..c76823527a41 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -600,7 +600,7 @@ def test_clip_sample(self): self.check_over_configs(clip_sample=clip_sample) def test_prediction_type(self): - for prediction_type in ["epsilon", "sample", "v"]: + for prediction_type in ["epsilon", "sample", "velocity"]: self.check_over_configs(prediction_type=prediction_type) def test_deprecated_epsilon(self): @@ -729,7 +729,7 @@ def test_schedules(self): self.check_over_configs(beta_schedule=schedule) def test_prediction_type(self): - for prediction_type in ["epsilon", "sample", "v"]: + for prediction_type in ["epsilon", "sample", "velocity"]: self.check_over_configs(prediction_type=prediction_type) def test_clip_sample(self): From 172b242c2aa2582afbacee35b09d039846fe9931 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Thu, 17 Nov 2022 14:58:52 -0800 Subject: [PATCH 14/17] fix loose comments --- src/diffusers/schedulers/scheduling_ddim.py | 10 +++------- src/diffusers/schedulers/scheduling_ddpm.py | 6 +++--- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index a49b279daa35..6df8c0905131 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -30,7 +30,7 @@ def expand_to_shape(input, timesteps, shape, device): """ Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast - nicely with `shape`. 
From 172b242c2aa2582afbacee35b09d039846fe9931 Mon Sep 17 00:00:00 2001
From: Nathan Lambert
Date: Thu, 17 Nov 2022 14:58:52 -0800
Subject: [PATCH 14/17] fix loose comments

---
 src/diffusers/schedulers/scheduling_ddim.py | 10 +++-------
 src/diffusers/schedulers/scheduling_ddpm.py |  6 +++---
 2 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py
index a49b279daa35..6df8c0905131 100644
--- a/src/diffusers/schedulers/scheduling_ddim.py
+++ b/src/diffusers/schedulers/scheduling_ddim.py
@@ -30,7 +30,7 @@ def expand_to_shape(input, timesteps, shape, device):
     """
     Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast
-    nicely with `shape`. Useful for parellizing operations over `shape[0]` number of diffusion steps at once.
+    nicely with `shape`. Useful for parallelizing operations over `shape[0]` number of diffusion steps at once.
     """
     out = torch.gather(input.to(device), 0, timesteps.to(device))
     reshape = [shape[0]] + [1] * (len(shape) - 1)
     out = out.reshape(*reshape)
     return out
@@ -131,7 +131,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             stable diffusion.
         prediction_type (`Literal["epsilon", "sample", "velocity"]`, optional):
             prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
-            process), `sample` (directly predicting the noisy sample`) or `v` (see section 2.4
+            process), `sample` (directly predicting the noisy sample`) or `velocity` (see section 2.4
             https://imagen.research.google/video/paper.pdf)
 
     """
@@ -265,10 +265,6 @@ def step(
             timestep (`int`): current discrete timestep in the diffusion chain.
             sample (`torch.FloatTensor`):
                 current instance of sample being created by diffusion process.
-            prediction_type (`str`):
-                prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
-                process), `sample` (directly predicting the noisy sample), or `v` (see section 2.4
-                https://imagen.research.google/video/paper.pdf)
             eta (`float`): weight of noise for added noise in diffusion step.
             use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
                 predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
@@ -326,7 +322,7 @@ def step(
             eps = model_output * self.alphas[timestep] + sample * self.sigmas[timestep]
         else:
             raise ValueError(
-                f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `v`"
+                f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `velocity`"
             )
 
         # 4. Clip "predicted x_0"
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index e6c204cf63df..ed0371e95f86 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -29,7 +29,7 @@ def expand_to_shape(input, timesteps, shape, device):
     """
     Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast
-    nicely with `shape`. Useful for parellizing operations over `shape[0]` number of diffusion steps at once.
+    nicely with `shape`. Useful for parallelizing operations over `shape[0]` number of diffusion steps at once.
     """
     out = torch.gather(input.to(device), 0, timesteps.to(device))
     reshape = [shape[0]] + [1] * (len(shape) - 1)
     out = out.reshape(*reshape)
     return out
@@ -112,7 +112,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             option to clip predicted sample between -1 and 1 for numerical stability.
         prediction_type (`Literal["epsilon", "sample", "velocity"]`, optional):
             prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
-            process), `sample` (directly predicting the noisy sample`) or `v` (see section 2.4
+            process), `sample` (directly predicting the noisy sample`) or `velocity` (see section 2.4
             https://imagen.research.google/video/paper.pdf)
         predict_epsilon (`bool`, default `True`):
             depreciated flag (removing v0.10.0) for epsilon vs. direct sample prediction.
@@ -300,7 +300,7 @@ def step(
 
         else:
             raise ValueError(
-                f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `v`"
+                f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `velocity`"
             )
 
         # 3. Clip "predicted x_0"
Clip "predicted x_0" From 66951ec084d79888c399b05d6679dee4f2d63eb8 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 23 Nov 2022 07:42:21 -0800 Subject: [PATCH 15/17] Update src/diffusers/schedulers/scheduling_ddpm.py Co-authored-by: Pedro Cuenca --- src/diffusers/schedulers/scheduling_ddpm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index ed0371e95f86..4b0ae8f74a03 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -115,7 +115,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): process), `sample` (directly predicting the noisy sample`) or `velocity` (see section 2.4 https://imagen.research.google/video/paper.pdf) predict_epsilon (`bool`, default `True`): - depreciated flag (removing v0.10.0) for epsilon vs. direct sample prediction. + deprecated flag (removing v0.10.0) for epsilon vs. direct sample prediction. """ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() From b70f6cd5e0412aeb63b1dafe6b10e87f66be5f17 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 23 Nov 2022 11:59:15 -0800 Subject: [PATCH 16/17] move expand_to_shape --- src/diffusers/schedulers/scheduling_ddim.py | 13 +------------ src/diffusers/schedulers/scheduling_ddpm.py | 13 +------------ src/diffusers/schedulers/scheduling_utils.py | 11 +++++++++++ 3 files changed, 13 insertions(+), 24 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 6df8c0905131..f94b4486035c 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -24,18 +24,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput -from .scheduling_utils import SchedulerMixin - - -def expand_to_shape(input, timesteps, shape, device): - """ - Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast - nicely with `shape`. Useful for parallelizing operations over `shape[0]` number of diffusion steps at once. - """ - out = torch.gather(input.to(device), 0, timesteps.to(device)) - reshape = [shape[0]] + [1] * (len(shape) - 1) - out = out.reshape(*reshape) - return out +from .scheduling_utils import SchedulerMixin, expand_to_shape @dataclass diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 4b0ae8f74a03..26ce386f77a1 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -23,18 +23,7 @@ from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate -from .scheduling_utils import SchedulerMixin - - -def expand_to_shape(input, timesteps, shape, device): - """ - Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast - nicely with `shape`. Useful for parallelizing operations over `shape[0]` number of diffusion steps at once. 
- """ - out = torch.gather(input.to(device), 0, timesteps.to(device)) - reshape = [shape[0]] + [1] * (len(shape) - 1) - out = out.reshape(*reshape) - return out +from .scheduling_utils import SchedulerMixin, expand_to_shape @dataclass diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 90ab674e38a4..973b1298fcd2 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -152,3 +152,14 @@ def _get_compatibles(cls): getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) ] return compatible_classes + + +def expand_to_shape(input, timesteps, shape, device): + """ + Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast + nicely with `shape`. Useful for parallelizing operations over `shape[0]` number of diffusion steps at once. + """ + out = torch.gather(input.to(device), 0, timesteps.to(device)) + reshape = [shape[0]] + [1] * (len(shape) - 1) + out = out.reshape(*reshape) + return out From da5e677c18f11710983ac54f03c3ab00f2408167 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 23 Nov 2022 12:20:54 -0800 Subject: [PATCH 17/17] remove Literal, add deprecates --- .../train_unconditional.py | 28 +++++++++++----- src/diffusers/schedulers/scheduling_ddim.py | 12 +++---- src/diffusers/schedulers/scheduling_ddpm.py | 12 +++++-- .../scheduling_dpmsolver_multistep.py | 32 +++++++++++++++---- 4 files changed, 61 insertions(+), 23 deletions(-) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 54a94d98b578..f3359478355d 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -194,16 +194,28 @@ def parse_args(): ) parser.add_argument( - "--predict_epsilon", - action="store_true", - default=True, - help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.", + "--prediction_type", + type=str, + default="epsilon", + help=( + "Whether the model should predict the 'epsilon'/noise error, directly the reconstructed image 'x0', or the" + " velocity of the ODE 'velocity'." + ), ) parser.add_argument("--ddpm_num_steps", type=int, default=1000) parser.add_argument("--ddpm_beta_schedule", type=str, default="linear") args = parser.parse_args() + + message = ( + "Please make sure to instantiate your training with `--prediction_type=epsilon` instead. E.g. `scheduler =" + " DDPMScheduler.from_config(, prediction_type=epsilon)`." 

From da5e677c18f11710983ac54f03c3ab00f2408167 Mon Sep 17 00:00:00 2001
From: Nathan Lambert
Date: Wed, 23 Nov 2022 12:20:54 -0800
Subject: [PATCH 17/17] remove Literal, add deprecates

---
 .../train_unconditional.py                   | 28 +++++++++++-----
 src/diffusers/schedulers/scheduling_ddim.py  | 12 +++----
 src/diffusers/schedulers/scheduling_ddpm.py  | 12 +++++--
 .../scheduling_dpmsolver_multistep.py        | 32 +++++++++++++++----
 4 files changed, 61 insertions(+), 23 deletions(-)

diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py
index 54a94d98b578..f3359478355d 100644
--- a/examples/unconditional_image_generation/train_unconditional.py
+++ b/examples/unconditional_image_generation/train_unconditional.py
@@ -194,16 +194,28 @@ def parse_args():
     )
     parser.add_argument(
-        "--predict_epsilon",
-        action="store_true",
-        default=True,
-        help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
+        "--prediction_type",
+        type=str,
+        default="epsilon",
+        help=(
+            "Whether the model should predict the 'epsilon'/noise error, directly the reconstructed image 'x0', or the"
+            " velocity of the diffusion ODE ('velocity')."
+        ),
     )
     parser.add_argument("--ddpm_num_steps", type=int, default=1000)
     parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
 
     args = parser.parse_args()
+
+    message = (
+        "Please make sure to launch your training with `--prediction_type=epsilon` instead. E.g. `python"
+        " train_unconditional.py --prediction_type=epsilon`."
+    )
+    predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=args)
+    if predict_epsilon:
+        args.prediction_type = "epsilon"
+
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
@@ -256,13 +268,13 @@ def main(args):
             "UpBlock2D",
         ),
     )
-    accepts_predict_epsilon = "predict_epsilon" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
+    accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
 
-    if accepts_predict_epsilon:
+    if accepts_prediction_type:
         noise_scheduler = DDPMScheduler(
             num_train_timesteps=args.ddpm_num_steps,
             beta_schedule=args.ddpm_beta_schedule,
-            predict_epsilon=args.predict_epsilon,
+            prediction_type=args.prediction_type,
         )
     else:
         noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
@@ -365,7 +377,7 @@ def transforms(examples):
             # Predict the noise residual
             model_output = model(noisy_images, timesteps).sample
 
-            if args.predict_epsilon:
+            if args.prediction_type == "epsilon":
                 loss = F.mse_loss(model_output, noise)  # this could have different weights!
             else:
                 alpha_t = _extract_into_tensor(
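
Note: the `accepts_prediction_type` guard above keeps the example script working against older diffusers
releases whose `DDPMScheduler.__init__` does not yet accept the new argument. The same pattern in
isolation, with `OldScheduler` as a stand-in class rather than anything from the patch:

    import inspect

    class OldScheduler:
        def __init__(self, num_train_timesteps=1000):  # predates `prediction_type`
            self.num_train_timesteps = num_train_timesteps

    accepts_prediction_type = "prediction_type" in set(inspect.signature(OldScheduler.__init__).parameters.keys())
    kwargs = {"prediction_type": "epsilon"} if accepts_prediction_type else {}
    scheduler = OldScheduler(num_train_timesteps=1000, **kwargs)  # degrades gracefully
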
diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py
index f94b4486035c..314e8aa22454 100644
--- a/src/diffusers/schedulers/scheduling_ddim.py
+++ b/src/diffusers/schedulers/scheduling_ddim.py
@@ -17,7 +17,7 @@
 import math
 from dataclasses import dataclass
-from typing import Literal, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -118,10 +118,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             an offset added to the inference steps. You can use a combination of `offset=1` and
             `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
             stable diffusion.
-        prediction_type (`Literal["epsilon", "sample", "velocity"]`, optional):
-            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
-            process), `sample` (directly predicting the noisy sample`) or `velocity` (see section 2.4
-            https://imagen.research.google/video/paper.pdf)
+        prediction_type (`str`, default `epsilon`, optional):
+            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+            process), `sample` (directly predicting the noisy sample) or `velocity` (see section 2.4
+            https://imagen.research.google/video/paper.pdf)
 
     """
@@ -139,7 +139,7 @@ def __init__(
         set_alpha_to_one: bool = True,
         variance_type: str = "fixed",
         steps_offset: int = 0,
-        prediction_type: Literal["epsilon", "sample", "velocity"] = "epsilon",
+        prediction_type: str = "epsilon",
         **kwargs,
     ):
         if trained_betas is not None:
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index 26ce386f77a1..8264dd540bbb 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -16,7 +16,7 @@
 import math
 from dataclasses import dataclass
-from typing import Literal, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -99,7 +99,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
         clip_sample (`bool`, default `True`):
             option to clip predicted sample between -1 and 1 for numerical stability.
-        prediction_type (`Literal["epsilon", "sample", "velocity"]`, optional):
+        prediction_type (`str`, default `epsilon`, optional):
             prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
             process), `sample` (directly predicting the noisy sample`) or `velocity` (see section 2.4
             https://imagen.research.google/video/paper.pdf)
@@ -119,7 +119,7 @@ def __init__(
         trained_betas: Optional[np.ndarray] = None,
         variance_type: str = "fixed_small",
         clip_sample: bool = True,
-        prediction_type: Literal["epsilon", "sample", "velocity"] = "epsilon",
+        prediction_type: str = "epsilon",
         predict_epsilon: bool = True,
     ):
         if trained_betas is not None:
@@ -156,6 +156,12 @@ def __init__(
         self.variance_type = variance_type
         self.prediction_type = prediction_type
 
+        message = (
+            "Please make sure to instantiate your scheduler with `prediction_type=epsilon` instead. E.g. `scheduler ="
+            " DDPMScheduler.from_config(, prediction_type='epsilon')`."
+        )
+        deprecate("predict_epsilon", "0.10.0", message)
+
     def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
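
Note: the `deprecate` call above only warns; mapping the old flag onto the new argument happens at the
call sites (see the training script earlier in this patch). A rough sketch of the overall pattern, with a
stand-in `deprecate` rather than the real helper in diffusers.utils:

    import warnings

    def deprecate(name, version, message):  # stand-in for diffusers.utils.deprecate
        warnings.warn(f"`{name}` is deprecated and will be removed in {version}. {message}", FutureWarning)

    def resolve_prediction_type(prediction_type="epsilon", predict_epsilon=None):
        if predict_epsilon is not None:
            deprecate("predict_epsilon", "0.10.0", "Use `prediction_type` instead.")
            prediction_type = "epsilon" if predict_epsilon else "sample"
        return prediction_type
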
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index 472b24637dcf..bcc3d4c7a293 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -21,7 +21,7 @@
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, deprecate
 from .scheduling_utils import SchedulerMixin, SchedulerOutput
 
@@ -88,9 +88,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
             sampling, and `solver_order=3` for unconditional sampling.
         predict_epsilon (`bool`, default `True`):
-            we currently support both the noise prediction model and the data prediction model. If the model predicts
-            the noise / epsilon, set `predict_epsilon` to `True`. If the model predicts the data / x0 directly, set
-            `predict_epsilon` to `False`.
+            deprecated flag (removing v0.10.0); we currently support both the noise prediction model and the data
+            prediction model. If the model predicts the noise / epsilon, set `predict_epsilon` to `True`. If the model
+            predicts the data / x0 directly, set `predict_epsilon` to `False`.
+        prediction_type (`str`, default `epsilon`, optional):
+            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+            process), `sample` (directly predicting the noisy sample) or `velocity` (see section 2.4
+            https://imagen.research.google/video/paper.pdf)
         thresholding (`bool`, default `False`):
             whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
             For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
@@ -128,6 +132,7 @@ def __init__(
         beta_schedule: str = "linear",
         trained_betas: Optional[np.ndarray] = None,
         solver_order: int = 2,
+        prediction_type: str = "epsilon",
         predict_epsilon: bool = True,
         thresholding: bool = False,
         dynamic_thresholding_ratio: float = 0.995,
@@ -174,6 +179,17 @@ def __init__(
         self.model_outputs = [None] * solver_order
         self.lower_order_nums = 0
 
+        if prediction_type not in ["epsilon", "sample"]:
+            raise ValueError(
+                f"Prediction type {self.config.prediction_type} not supported by DPMSolverMultistepScheduler"
+            )
+
+        message = (
+            "Please make sure to instantiate your scheduler with `prediction_type=epsilon` instead. E.g. `scheduler ="
+            " DPMSolverMultistepScheduler.from_config(, prediction_type='epsilon')`."
+        )
+        deprecate("predict_epsilon", "0.10.0", message)
+
     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -221,11 +237,15 @@ def convert_model_output(
         """
         # DPM-Solver++ needs to solve an integral of the data prediction model.
         if self.config.algorithm_type == "dpmsolver++":
-            if self.config.predict_epsilon:
+            if self.config.prediction_type == "epsilon":
                 alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                 x0_pred = (sample - sigma_t * model_output) / alpha_t
-            else:
+            elif self.config.prediction_type == "sample":
                 x0_pred = model_output
+            else:
+                raise ValueError(
+                    f"Prediction type {self.config.prediction_type} not supported by DPMSolverMultistepScheduler"
+                )
 
             if self.config.thresholding:
                 # Dynamic thresholding in https://arxiv.org/abs/2205.11487
                 dynamic_max_val = torch.quantile(
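
Note: for context on the `torch.quantile` call that closes the hunk above, dynamic thresholding as
described in the cited paper (https://arxiv.org/abs/2205.11487) clamps the predicted x0 to a per-sample
quantile of its absolute values. The sketch below shows the general technique and is not the exact code
this PR adds:

    import torch

    def dynamic_threshold(x0_pred: torch.Tensor, ratio: float = 0.995) -> torch.Tensor:
        batch = x0_pred.shape[0]
        # Per-sample quantile of |x0|, floored at 1 so in-range samples pass through.
        s = torch.quantile(x0_pred.abs().reshape(batch, -1), ratio, dim=1)
        s = torch.clamp(s, min=1.0).reshape(batch, 1, 1, 1)
        return torch.clamp(x0_pred, -s, s) / s  # threshold, then rescale to [-1, 1]
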