From a48b0266e170747f35d02be7d5e394f3ed796357 Mon Sep 17 00:00:00 2001 From: Ben Glickenhaus Date: Thu, 3 Nov 2022 11:00:04 -0400 Subject: [PATCH 1/7] v diffusion support for ddpm --- src/diffusers/schedulers/scheduling_ddpm.py | 92 ++++++++++++++------- 1 file changed, 60 insertions(+), 32 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index ee4f608e09aa..19acf3dc2ff5 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -16,16 +16,25 @@ import math from dataclasses import dataclass -from typing import Optional, Tuple, Union - +from typing import Optional, Tuple, Union, Literal import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, deprecate +from ..utils import BaseOutput from .scheduling_utils import SchedulerMixin +def expand_to_shape(input, timesteps, shape, device): + """ + Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast nicely with `shape`. Useful for parellizing operations over `shape[0]` number of diffusion steps at once. + """ + out = torch.gather(input.to(device), 0, timesteps.to(device)) + reshape = [shape[0]] + [1] * (len(shape) - 1) + out = out.reshape(*reshape) + return out + + @dataclass class DDPMSchedulerOutput(BaseOutput): """ @@ -102,6 +111,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): """ + _compatible_classes = [ + "DDIMScheduler", + "PNDMScheduler", + "LMSDiscreteScheduler", + "EulerDiscreteScheduler", + "EulerAncestralDiscreteScheduler", + ] + @register_to_config def __init__( self, @@ -112,15 +129,7 @@ def __init__( trained_betas: Optional[np.ndarray] = None, variance_type: str = "fixed_small", clip_sample: bool = True, - **kwargs, ): - deprecate( - "tensor_format", - "0.6.0", - "If you're running your code in PyTorch, you can safely remove this argument.", - take_from=kwargs, - ) - if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) elif beta_schedule == "linear": @@ -142,8 +151,8 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - self.sigmas = 1 - self.alphas**2 - self.one = torch.tensor(1.0) + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod) # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 @@ -185,11 +194,11 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic def _get_variance(self, timestep, predicted_variance=None, variance_type=None): alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else torch.tensor(1.0) - # For timestep > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) # and sample from it to get previous sample - # x_{timestep-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep] if variance_type is None: @@ -213,6 +222,8 @@ def _get_variance(self, timestep, predicted_variance=None, variance_type=None): max_log = self.betas[timestep] frac = (predicted_variance + 1) / 2 variance = frac * max_log + (1 - frac) * min_log + elif variance_type == "v_diffusion": + variance = torch.log(self.betas[timestep] * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) return variance @@ -221,9 +232,10 @@ def step( model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, - prediction_type: str = "epsilon", + prediction_type: Literal["epsilon", "sample", "v"] = "epsilon", generator=None, return_dict: bool = True, + v_prediction: bool = True, ) -> Union[DDPMSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion @@ -234,10 +246,8 @@ def step( timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. - prediction_type (`str`): - prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion - process), `sample` (directly predicting the noisy sample), or `v` (see section 2.4 - https://imagen.research.google/video/paper.pdf) + prediction_type (`Literal["epsilon", "sample", "v"]`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v` (see section 2.4 https://imagen.research.google/video/paper.pdf) generator: random number generator. return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class @@ -254,23 +264,26 @@ def step( # 1. compute alphas, betas alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else torch.tensor(1.0) beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if prediction_type == "epsilon": + if prediction_type == "v": + # x_recon in p_mean_variance + pred_original_sample = ( + sample * self.sqrt_alphas_cumprod[timestep] + - model_output * self.sqrt_one_minus_alphas_cumprod[timestep] + ) + eps = ( + model_output * self.sqrt_alphas_cumprod[timestep] + - sample * self.sqrt_one_minus_alphas_cumprod[timestep] + ) + elif prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - elif prediction_type == "sample": - pred_original_sample = model_output - elif prediction_type == "v": - # v_t = alpha_t * epsilon - sigma_t * x - # need to merge the PRs for sigma to be available in DDPM - pred = sample * self.alphas[timestep] - model_output * self.sigmas[timestep] - eps = model_output * self.alphas[timestep] - sample * self.sigmas[timestep] else: - raise ValueError(f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or `v`") + pred_original_sample = model_output # 3. Clip "predicted x_0" if self.config.clip_sample: @@ -291,7 +304,12 @@ def step( noise = torch.randn( model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator ).to(model_output.device) - variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise + if self.variance_type == "fixed_small_log": + variance = self._get_variance(t, predicted_variance=predicted_variance) * noise + elif self.variance_type == "v_diffusion": + variance = torch.exp(0.5 * self._get_variance(timestep, predicted_variance)) * noise + else: + variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise pred_prev_sample = pred_prev_sample + variance @@ -306,6 +324,11 @@ def add_noise( noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: + if self.variance_type == "v_diffusion": + alpha, sigma = self.get_alpha_sigma(original_samples, timesteps, original_samples.device) + z_t = alpha * original_samples + sigma * noise + return z_t + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) @@ -325,3 +348,8 @@ def add_noise( def __len__(self): return self.config.num_train_timesteps + + def get_alpha_sigma(self, sample, timesteps, device): + alpha = expand_to_shape(self.sqrt_alphas_cumprod, timesteps, sample.shape, device) + sigma = expand_to_shape(self.sqrt_one_minus_alphas_cumprod, timesteps, sample.shape, device) + return alpha, sigma From 3d702c6d652cb4d7a50d9cea750ca7115b3c4375 Mon Sep 17 00:00:00 2001 From: Ben Glickenhaus Date: Thu, 3 Nov 2022 11:03:26 -0400 Subject: [PATCH 2/7] quality and style --- src/diffusers/schedulers/scheduling_ddpm.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 19acf3dc2ff5..c3c72171f1cc 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -16,7 +16,8 @@ import math from dataclasses import dataclass -from typing import Optional, Tuple, Union, Literal +from typing import Literal, Optional, Tuple, Union + import numpy as np import torch @@ -27,7 +28,8 @@ def expand_to_shape(input, timesteps, shape, device): """ - Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast nicely with `shape`. Useful for parellizing operations over `shape[0]` number of diffusion steps at once. + Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast + nicely with `shape`. Useful for parellizing operations over `shape[0]` number of diffusion steps at once. """ out = torch.gather(input.to(device), 0, timesteps.to(device)) reshape = [shape[0]] + [1] * (len(shape) - 1) @@ -247,7 +249,9 @@ def step( sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. prediction_type (`Literal["epsilon", "sample", "v"]`, optional): - prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v` (see section 2.4 https://imagen.research.google/video/paper.pdf) + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v` (see section 2.4 + https://imagen.research.google/video/paper.pdf) generator: random number generator. return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class @@ -276,10 +280,6 @@ def step( sample * self.sqrt_alphas_cumprod[timestep] - model_output * self.sqrt_one_minus_alphas_cumprod[timestep] ) - eps = ( - model_output * self.sqrt_alphas_cumprod[timestep] - - sample * self.sqrt_one_minus_alphas_cumprod[timestep] - ) elif prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) else: @@ -305,7 +305,7 @@ def step( model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator ).to(model_output.device) if self.variance_type == "fixed_small_log": - variance = self._get_variance(t, predicted_variance=predicted_variance) * noise + variance = self._get_variance(timestep, predicted_variance=predicted_variance) * noise elif self.variance_type == "v_diffusion": variance = torch.exp(0.5 * self._get_variance(timestep, predicted_variance)) * noise else: From 0889fd1d1178fc203321e5a095b737a111c9c9fb Mon Sep 17 00:00:00 2001 From: Ben Glickenhaus Date: Thu, 3 Nov 2022 11:04:44 -0400 Subject: [PATCH 3/7] variable name consistency --- src/diffusers/schedulers/scheduling_ddpm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index c3c72171f1cc..25b081da87f1 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -198,9 +198,9 @@ def _get_variance(self, timestep, predicted_variance=None, variance_type=None): alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else torch.tensor(1.0) - # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + # For timestep > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) # and sample from it to get previous sample - # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + # x_{timestep-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep] if variance_type is None: From f7c709518fe1c036866ea60b09119fd3013b534b Mon Sep 17 00:00:00 2001 From: Ben Glickenhaus Date: Thu, 3 Nov 2022 11:07:38 -0400 Subject: [PATCH 4/7] missing base case --- src/diffusers/schedulers/scheduling_ddpm.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 25b081da87f1..878d67e817d1 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -282,8 +282,11 @@ def step( ) elif prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - else: + + elif prediction_type == "sample": pred_original_sample = model_output + else: + raise ValueError(f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or `v`") # 3. Clip "predicted x_0" if self.config.clip_sample: From 0c23e1162ad31fee62b327ce6ce4ab5ba7a82188 Mon Sep 17 00:00:00 2001 From: Ben Glickenhaus Date: Thu, 3 Nov 2022 11:23:35 -0400 Subject: [PATCH 5/7] pass prediction type along in the pipeline --- src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 7 +++++-- src/diffusers/schedulers/scheduling_ddpm.py | 3 ++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index aae29737aae3..3d5afc94e3df 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -14,7 +14,7 @@ # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, Literal import torch @@ -44,6 +44,7 @@ def __call__( generator: Optional[torch.Generator] = None, output_type: Optional[str] = "pil", return_dict: bool = True, + prediction_type: Literal["epsilon", "sample", "v"] = "epsilon", **kwargs, ) -> Union[ImagePipelineOutput, Tuple]: r""" @@ -80,7 +81,9 @@ def __call__( model_output = self.unet(image, t).sample # 2. compute previous image: x_t -> t_t-1 - image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample + image = self.scheduler.step( + model_output, t, image, generator=generator, prediction_type=prediction_type + ).prev_sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 878d67e817d1..1813592a069d 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -237,7 +237,6 @@ def step( prediction_type: Literal["epsilon", "sample", "v"] = "epsilon", generator=None, return_dict: bool = True, - v_prediction: bool = True, ) -> Union[DDPMSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion @@ -261,6 +260,8 @@ def step( returning a tuple, the first element is the sample tensor. """ + if self.variance_type == "v_diffusion": + assert prediction_type == "v", "Need to use v prediction with v_diffusion" if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) else: From b46327e89e079b2eaaa2351a27c86900d77c168d Mon Sep 17 00:00:00 2001 From: Ben Glickenhaus Date: Mon, 7 Nov 2022 09:19:22 -0500 Subject: [PATCH 6/7] put prediction type in scheduler config --- src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 7 ++----- src/diffusers/schedulers/scheduling_ddpm.py | 16 ++++++++++------ 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index 3d5afc94e3df..a9284063e884 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -14,7 +14,7 @@ # limitations under the License. -from typing import Optional, Tuple, Union, Literal +from typing import Literal, Optional, Tuple, Union import torch @@ -44,7 +44,6 @@ def __call__( generator: Optional[torch.Generator] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - prediction_type: Literal["epsilon", "sample", "v"] = "epsilon", **kwargs, ) -> Union[ImagePipelineOutput, Tuple]: r""" @@ -81,9 +80,7 @@ def __call__( model_output = self.unet(image, t).sample # 2. compute previous image: x_t -> t_t-1 - image = self.scheduler.step( - model_output, t, image, generator=generator, prediction_type=prediction_type - ).prev_sample + image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 1813592a069d..0327c44e3c4a 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -131,6 +131,7 @@ def __init__( trained_betas: Optional[np.ndarray] = None, variance_type: str = "fixed_small", clip_sample: bool = True, + prediction_type: Literal["epsilon", "sample", "v"] = "epsilon", ): if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) @@ -164,6 +165,7 @@ def __init__( self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) self.variance_type = variance_type + self.prediction_type = prediction_type def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: """ @@ -234,7 +236,7 @@ def step( model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, - prediction_type: Literal["epsilon", "sample", "v"] = "epsilon", + # prediction_type: Literal["epsilon", "sample", "v"] = "epsilon", generator=None, return_dict: bool = True, ) -> Union[DDPMSchedulerOutput, Tuple]: @@ -261,7 +263,7 @@ def step( """ if self.variance_type == "v_diffusion": - assert prediction_type == "v", "Need to use v prediction with v_diffusion" + assert self.prediction_type == "v", "Need to use v prediction with v_diffusion" if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) else: @@ -275,19 +277,21 @@ def step( # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if prediction_type == "v": + if self.prediction_type == "v": # x_recon in p_mean_variance pred_original_sample = ( sample * self.sqrt_alphas_cumprod[timestep] - model_output * self.sqrt_one_minus_alphas_cumprod[timestep] ) - elif prediction_type == "epsilon": + elif self.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - elif prediction_type == "sample": + elif self.prediction_type == "sample": pred_original_sample = model_output else: - raise ValueError(f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or `v`") + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `v`" + ) # 3. Clip "predicted x_0" if self.config.clip_sample: From 45c36c85d71ae278ca7d7ad87263394128b30eb1 Mon Sep 17 00:00:00 2001 From: Ben Glickenhaus Date: Mon, 7 Nov 2022 09:20:16 -0500 Subject: [PATCH 7/7] style --- src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index a9284063e884..aae29737aae3 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -14,7 +14,7 @@ # limitations under the License. -from typing import Literal, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch