From c63e6e85c8ded726622c06bc1eb9df59c40f4e66 Mon Sep 17 00:00:00 2001 From: anton-l Date: Mon, 26 Sep 2022 03:07:38 +0200 Subject: [PATCH 1/8] [Schedulers Refactoring] Phase 1: timesteps and scaling --- src/diffusers/dependency_versions_table.py | 1 - src/diffusers/pipelines/ddim/pipeline_ddim.py | 5 +-- .../pipeline_latent_diffusion.py | 7 ++-- .../pipeline_latent_diffusion_uncond.py | 7 ++-- .../pipeline_stable_diffusion.py | 26 +++++++------- .../pipeline_stable_diffusion_inpaint.py | 28 ++++----------- src/diffusers/schedulers/scheduling_ddim.py | 22 +++++++++--- .../schedulers/scheduling_lms_discrete.py | 34 +++++++++++++++---- src/diffusers/schedulers/scheduling_pndm.py | 18 +++++++--- src/diffusers/schedulers/scheduling_utils.py | 23 +++++++++++++ 10 files changed, 114 insertions(+), 57 deletions(-) diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 09a7baad560d..82ca5dbb6f56 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -17,7 +17,6 @@ "jaxlib": "jaxlib>=0.1.65,<=0.3.6", "modelcards": "modelcards>=0.1.4", "numpy": "numpy", - "onnxruntime": "onnxruntime", "onnxruntime-gpu": "onnxruntime-gpu", "pytest": "pytest", "pytest-timeout": "pytest-timeout", diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py index 95b49e045f67..0815f386e053 100644 --- a/src/diffusers/pipelines/ddim/pipeline_ddim.py +++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py @@ -98,13 +98,14 @@ def __call__( # set step values self.scheduler.set_timesteps(num_inference_steps) - for t in self.progress_bar(self.scheduler.timesteps): + for step in self.progress_bar(self.scheduler.timesteps): # 1. predict noise model_output + t = self.scheduler.get_noise_condition(step) model_output = self.unet(image, t).sample # 2. predict previous mean of image x_t-1 and add variance depending on eta # do x_t -> x_t-1 - image = self.scheduler.step(model_output, t, image, eta).prev_sample + image = self.scheduler.step(model_output, step, image, eta).prev_sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index 4a4f29be7f75..6fc33c577ef7 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -130,6 +130,7 @@ def __call__( generator=generator, ) latents = latents.to(self.device) + latents = self.scheduler.scale_initial_noise(latents) self.scheduler.set_timesteps(num_inference_steps) @@ -140,7 +141,7 @@ def __call__( if accepts_eta: extra_kwargs["eta"] = eta - for t in self.progress_bar(self.scheduler.timesteps): + for step in self.progress_bar(self.scheduler.timesteps): if guidance_scale == 1.0: # guidance_scale of 1 means no guidance latents_input = latents @@ -153,6 +154,8 @@ def __call__( context = torch.cat([uncond_embeddings, text_embeddings]) # predict the noise residual + latents_input = self.scheduler.scale_model_input(latents_input, step) + t = self.scheduler.get_noise_condition(step) noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample # perform guidance if guidance_scale != 1.0: @@ -160,7 +163,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, step, latents, **extra_kwargs).prev_sample # scale and decode the image latents with vae latents = 1 / 0.18215 * latents diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py index 5574b65df9f8..141a185a5770 100644 --- a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +++ b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py @@ -77,6 +77,7 @@ def __call__( generator=generator, ) latents = latents.to(self.device) + latents = self.scheduler.scale_initial_noise(latents) self.scheduler.set_timesteps(num_inference_steps) @@ -87,11 +88,13 @@ def __call__( if accepts_eta: extra_kwargs["eta"] = eta - for t in self.progress_bar(self.scheduler.timesteps): + for step in self.progress_bar(self.scheduler.timesteps): # predict the noise residual + latents = self.scheduler.scale_model_input(latents, step) + t = self.scheduler.get_noise_condition(step) noise_prediction = self.unet(latents, t).sample # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample + latents = self.scheduler.step(noise_prediction, step, latents, **extra_kwargs).prev_sample # decode the image latents with the VAE image = self.vqvae.decode(latents).sample diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 411f27308aa7..7d05b190cc05 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -42,6 +42,14 @@ class StableDiffusionPipeline(DiffusionPipeline): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ + vae: AutoencoderKL + text_encoder: CLIPTextModel + tokenizer: CLIPTokenizer + unet: UNet2DConditionModel + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + safety_checker: StableDiffusionSafetyChecker + feature_extractor: CLIPFeatureExtractor + def __init__( self, vae: AutoencoderKL, @@ -231,14 +239,11 @@ def __call__( if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") latents = latents.to(self.device) + latents = self.scheduler.scale_initial_noise(latents) # set timesteps self.scheduler.set_timesteps(num_inference_steps) - # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = latents * self.scheduler.sigmas[0] - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 @@ -248,13 +253,11 @@ def __call__( if accepts_eta: extra_step_kwargs["eta"] = eta - for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + for step in self.progress_bar(self.scheduler.timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[i] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, step) + t = self.scheduler.get_noise_condition(step) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample @@ -265,10 +268,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample - else: - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, step, latents, **extra_step_kwargs).prev_sample # scale and decode the image latents with vae latents = 1 / 0.18215 * latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 7de7925a302b..f9037dd5f2f7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -241,13 +241,7 @@ def __call__( offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) - if isinstance(self.scheduler, LMSDiscreteScheduler): - timesteps = torch.tensor( - [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device - ) - else: - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) + timesteps = torch.tensor([num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device) # add noise to latents using the timesteps noise = torch.randn(init_latents.shape, generator=generator, device=self.device) @@ -291,14 +285,11 @@ def __call__( latents = init_latents t_start = max(num_inference_steps - init_timestep + offset, 0) - for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])): - t_index = t_start + i + for step in self.progress_bar(self.scheduler.timesteps[t_start:]): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[t_index] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, step) + t = self.scheduler.get_noise_condition(step) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample @@ -309,14 +300,9 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample - # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor(t_index)) - else: - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t) + latents = self.scheduler.step(noise_pred, step, latents, **extra_step_kwargs).prev_sample + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor(step)) latents = (init_latents_proper * mask) + (latents * (1 - mask)) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 0613ffd41d0e..0e47e30e15e1 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -25,7 +25,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .scheduling_utils import SchedulerMixin +from .scheduling_utils import BaseScheduler, SchedulerMixin @dataclass @@ -75,7 +75,7 @@ def alpha_bar(time_step): return np.array(betas, dtype=np.float32) -class DDIMScheduler(SchedulerMixin, ConfigMixin): +class DDIMScheduler(BaseScheduler, SchedulerMixin, ConfigMixin): """ Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with non-Markovian guidance. @@ -147,7 +147,8 @@ def __init__( # setable values self.num_inference_steps = None - self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + self.schedule = np.arange(0, num_train_timesteps) + self.timesteps = self.schedule[::-1].copy() self.tensor_format = tensor_format self.set_format(tensor_format=tensor_format) @@ -162,6 +163,12 @@ def _get_variance(self, timestep, prev_timestep): return variance + def get_noise_condition(self, step: int): + """ + Returns the input noise condition for a model. + """ + return self.schedule[step] + def set_timesteps(self, num_inference_steps: int, **kwargs): """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -186,8 +193,10 @@ def set_timesteps(self, num_inference_steps: int, **kwargs): step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy() - self.timesteps += offset + self.schedule = (np.arange(0, num_inference_steps) * step_ratio).round().copy() + self.schedule += offset + + self.timesteps = np.arange(0, num_inference_steps)[::-1].copy() self.set_format(tensor_format=self.tensor_format) def step( @@ -236,6 +245,8 @@ def step( # - pred_sample_direction -> "direction pointing to x_t" # - pred_prev_sample -> "x_t-1" + timestep = self.schedule[timestep] + # 1. get previous step value (=t-1) prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps @@ -291,6 +302,7 @@ def add_noise( ) -> Union[torch.FloatTensor, np.ndarray]: if self.tensor_format == "pt": timesteps = timesteps.to(self.alphas_cumprod.device) + timesteps = self.schedule[timesteps] sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 1dd6dbda1e19..59a3ef8f7ad6 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .scheduling_utils import SchedulerMixin +from .scheduling_utils import BaseScheduler, SchedulerMixin @dataclass @@ -43,7 +43,7 @@ class LMSDiscreteSchedulerOutput(BaseOutput): pred_original_sample: Optional[torch.FloatTensor] = None -class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): +class LMSDiscreteScheduler(BaseScheduler, SchedulerMixin, ConfigMixin): """ Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by Katherine Crowson: @@ -93,15 +93,36 @@ def __init__( self.alphas_cumprod = np.cumprod(self.alphas, axis=0) self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + self.sigmas = self.sigmas[::-1].copy() # setable values self.num_inference_steps = None self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + self.schedule = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float) self.derivatives = [] self.tensor_format = tensor_format self.set_format(tensor_format=tensor_format) + def scale_initial_noise(self, noise: torch.FloatTensor): + """ + Scales the initial noise to the correct range for the scheduler. + """ + return noise * self.sigmas[0] + + def scale_model_input(self, sample: torch.FloatTensor, step: int): + """ + Scales the model input (`sample`) to the correct range for the scheduler. + """ + sigma = self.sigmas[self.num_inference_steps - step - 1] + return sample / ((sigma**2 + 1) ** 0.5) + + def get_noise_condition(self, step: int): + """ + Returns the input noise condition for a model. + """ + return self.schedule[step] + def get_lms_coefficient(self, order, t, current_order): """ Compute a linear multistep coefficient. @@ -133,13 +154,11 @@ def set_timesteps(self, num_inference_steps: int): the number of diffusion steps used when generating samples with a pre-trained model. """ self.num_inference_steps = num_inference_steps - self.timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) + self.timesteps = np.arange(0, num_inference_steps)[::-1].copy() - low_idx = np.floor(self.timesteps).astype(int) - high_idx = np.ceil(self.timesteps).astype(int) - frac = np.mod(self.timesteps, 1.0) + self.schedule = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] + sigmas = np.interp(self.schedule[::-1], np.arange(0, len(sigmas)), sigmas) self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) self.derivatives = [] @@ -172,6 +191,7 @@ def step( When returning a tuple, the first element is the sample tensor. """ + timestep = int(self.num_inference_steps - timestep - 1) sigma = self.sigmas[timestep] # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 09e8a7e240c2..79c1e2d8c790 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -22,7 +22,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import BaseScheduler, SchedulerMixin, SchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -54,7 +54,7 @@ def alpha_bar(time_step): return np.array(betas, dtype=np.float32) -class PNDMScheduler(SchedulerMixin, ConfigMixin): +class PNDMScheduler(BaseScheduler, SchedulerMixin, ConfigMixin): """ Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, namely Runge-Kutta method and a linear multi-step method. @@ -137,11 +137,17 @@ def __init__( self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy() self.prk_timesteps = None self.plms_timesteps = None - self.timesteps = None + self.schedule = None self.tensor_format = tensor_format self.set_format(tensor_format=tensor_format) + def get_noise_condition(self, step: int): + """ + Returns the input noise condition for a model. + """ + return self.schedule[step] + def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -185,7 +191,9 @@ def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor ::-1 ].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy - self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64) + self.schedule = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64) + self.schedule = self.schedule[::-1].copy() # FIXME: create the schedule in ascending order to avoid this + self.timesteps = np.arange(0, len(self.schedule))[::-1].copy() self.ets = [] self.counter = 0 @@ -217,6 +225,7 @@ def step( returning a tuple, the first element is the sample tensor. """ + timestep = self.schedule[timestep] if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps: return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict) else: @@ -387,6 +396,7 @@ def add_noise( ) -> torch.Tensor: if self.tensor_format == "pt": timesteps = timesteps.to(self.alphas_cumprod.device) + timesteps = self.schedule[timesteps] sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index f2bcd73acf32..71d25084722c 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc from dataclasses import dataclass from typing import Union @@ -37,6 +38,28 @@ class SchedulerOutput(BaseOutput): prev_sample: torch.FloatTensor +class BaseScheduler(abc.ABC): + + def scale_initial_noise(self, noise: torch.FloatTensor): + """ + Scales the initial noise to the correct range for the scheduler. + """ + return noise + + def scale_model_input(self, sample: torch.FloatTensor, step: int): + """ + Scales the model input (`sample`) to the correct range for the scheduler. + """ + return sample + + @abc.abstractmethod + def get_noise_condition(self, step: int): + """ + Returns the input noise condition for the model (e.g. `timestep` or `sigma`). + """ + raise NotImplementedError("Scheduler must implement the `get_noise_condition` function.") + + class SchedulerMixin: """ Mixin containing common functions for the schedulers. From 556e6872b9fdee31eebcf082bd42eb41ea0fcfa6 Mon Sep 17 00:00:00 2001 From: anton-l Date: Mon, 26 Sep 2022 12:24:40 +0200 Subject: [PATCH 2/8] cover more schedulers, fix onnxpipeline --- src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 5 ++-- src/diffusers/pipelines/pndm/pipeline_pndm.py | 6 +++-- .../pipeline_stable_diffusion_img2img.py | 23 ++++--------------- .../pipeline_stable_diffusion_onnx.py | 14 ++++------- src/diffusers/schedulers/scheduling_ddpm.py | 16 +++++++++---- .../schedulers/scheduling_lms_discrete.py | 2 ++ 6 files changed, 31 insertions(+), 35 deletions(-) diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index b7f7093e379b..2c768ad5becd 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -88,12 +88,13 @@ def __call__( # set step values self.scheduler.set_timesteps(1000) - for t in self.progress_bar(self.scheduler.timesteps): + for step in self.progress_bar(self.scheduler.timesteps): # 1. predict noise model_output + t = self.scheduler.get_noise_condition(step) model_output = self.unet(image, t).sample # 2. compute previous image: x_t -> t_t-1 - image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample + image = self.scheduler.step(model_output, step, image, generator=generator).prev_sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() diff --git a/src/diffusers/pipelines/pndm/pipeline_pndm.py b/src/diffusers/pipelines/pndm/pipeline_pndm.py index ae6c10e9e967..14f223410e33 100644 --- a/src/diffusers/pipelines/pndm/pipeline_pndm.py +++ b/src/diffusers/pipelines/pndm/pipeline_pndm.py @@ -93,12 +93,14 @@ def __call__( generator=generator, ) image = image.to(self.device) + image = self.scheduler.scale_initial_noise(image) self.scheduler.set_timesteps(num_inference_steps) - for t in self.progress_bar(self.scheduler.timesteps): + for step in self.progress_bar(self.scheduler.timesteps): + t = self.scheduler.get_noise_condition(step) model_output = self.unet(image, t).sample - image = self.scheduler.step(model_output, t, image).prev_sample + image = self.scheduler.step(model_output, step, image).prev_sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 46299bf3b3e7..752abfa3501e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -203,13 +203,9 @@ def __call__( offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) - if isinstance(self.scheduler, LMSDiscreteScheduler): - timesteps = torch.tensor( + timesteps = torch.tensor( [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device ) - else: - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) # add noise to latents using the timesteps noise = torch.randn(init_latents.shape, generator=generator, device=self.device) @@ -254,17 +250,11 @@ def __call__( latents = init_latents t_start = max(num_inference_steps - init_timestep + offset, 0) - for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[t_start:])): - t_index = t_start + i - + for step in self.progress_bar(self.scheduler.timesteps[t_start:]): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - - # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[t_index] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, step) + t = self.scheduler.get_noise_condition(step) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample @@ -275,10 +265,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample - else: - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, step, latents, **extra_step_kwargs).prev_sample # scale and decode the image latents with vae latents = 1 / 0.18215 * latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py index ccba29ade5d3..fe3be5ee20e1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py @@ -98,6 +98,7 @@ def __call__( latents = np.random.randn(*latents_shape).astype(np.float32) elif latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = self.scheduler.scale_initial_noise(latents) # set timesteps self.scheduler.set_timesteps(num_inference_steps) @@ -115,13 +116,11 @@ def __call__( if accepts_eta: extra_step_kwargs["eta"] = eta - for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + for step in self.progress_bar(self.scheduler.timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[i] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, step) + t = self.scheduler.get_noise_condition(step) # predict the noise residual noise_pred = self.unet( @@ -135,10 +134,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample - else: - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, step, latents, **extra_step_kwargs).prev_sample # scale and decode the image latents with vae latents = 1 / 0.18215 * latents diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 440b880385d4..85c94bb4bbf0 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -23,7 +23,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .scheduling_utils import SchedulerMixin +from .scheduling_utils import BaseScheduler, SchedulerMixin @dataclass @@ -73,7 +73,7 @@ def alpha_bar(time_step): return np.array(betas, dtype=np.float32) -class DDPMScheduler(SchedulerMixin, ConfigMixin): +class DDPMScheduler(BaseScheduler, SchedulerMixin, ConfigMixin): """ Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and Langevin dynamics sampling. @@ -134,13 +134,20 @@ def __init__( # setable values self.num_inference_steps = None - self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + self.schedule = np.arange(0, num_train_timesteps) + self.timesteps = self.schedule[::-1].copy() self.tensor_format = tensor_format self.set_format(tensor_format=tensor_format) self.variance_type = variance_type + def get_noise_condition(self, step: int): + """ + Returns the input noise condition for a model. + """ + return self.schedule[step] + def set_timesteps(self, num_inference_steps: int): """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -218,7 +225,7 @@ def step( returning a tuple, the first element is the sample tensor. """ - t = timestep + t = self.schedule[timestep] if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) @@ -272,6 +279,7 @@ def add_noise( ) -> Union[torch.FloatTensor, np.ndarray]: if self.tensor_format == "pt": timesteps = timesteps.to(self.alphas_cumprod.device) + timesteps = self.schedule[timesteps] sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 59a3ef8f7ad6..ec0d87efeb7e 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -225,6 +225,8 @@ def add_noise( ) -> Union[torch.FloatTensor, np.ndarray]: if self.tensor_format == "pt": timesteps = timesteps.to(self.sigmas.device) + # FIXME: accounting for the descending sigmas + timesteps = self.num_inference_steps - timesteps - 1 sigmas = self.match_shape(self.sigmas[timesteps], noise) noisy_samples = original_samples + noise * sigmas From ccc6afb7c05a73864fb3c80e56cbb986fdea7c7b Mon Sep 17 00:00:00 2001 From: anton-l Date: Mon, 26 Sep 2022 12:52:10 +0200 Subject: [PATCH 3/8] style --- .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 4 ++-- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 4 +++- src/diffusers/schedulers/scheduling_lms_discrete.py | 2 +- src/diffusers/schedulers/scheduling_utils.py | 1 - 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 752abfa3501e..30d3b68f0f0e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -204,8 +204,8 @@ def __call__( init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) timesteps = torch.tensor( - [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device - ) + [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device + ) # add noise to latents using the timesteps noise = torch.randn(init_latents.shape, generator=generator, device=self.device) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index f9037dd5f2f7..b7ef494c8329 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -241,7 +241,9 @@ def __call__( offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) - timesteps = torch.tensor([num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device) + timesteps = torch.tensor( + [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device + ) # add noise to latents using the timesteps noise = torch.randn(init_latents.shape, generator=generator, device=self.device) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index ec0d87efeb7e..d75e2bd47656 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -226,7 +226,7 @@ def add_noise( if self.tensor_format == "pt": timesteps = timesteps.to(self.sigmas.device) # FIXME: accounting for the descending sigmas - timesteps = self.num_inference_steps - timesteps - 1 + timesteps = self.num_inference_steps - timesteps sigmas = self.match_shape(self.sigmas[timesteps], noise) noisy_samples = original_samples + noise * sigmas diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 71d25084722c..010977f0ce7a 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -39,7 +39,6 @@ class SchedulerOutput(BaseOutput): class BaseScheduler(abc.ABC): - def scale_initial_noise(self, noise: torch.FloatTensor): """ Scales the initial noise to the correct range for the scheduler. From ba351f5a3c54fc6513517e7349f41963392d77ba Mon Sep 17 00:00:00 2001 From: anton-l Date: Mon, 26 Sep 2022 14:52:18 +0200 Subject: [PATCH 4/8] fixed all tests --- .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 4 +--- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 4 +--- src/diffusers/schedulers/scheduling_lms_discrete.py | 2 +- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 30d3b68f0f0e..a82a892316d3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -203,9 +203,7 @@ def __call__( offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) - timesteps = torch.tensor( - [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device - ) + timesteps = torch.tensor([init_timestep - 1] * batch_size, dtype=torch.long, device=self.device) # add noise to latents using the timesteps noise = torch.randn(init_latents.shape, generator=generator, device=self.device) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index b7ef494c8329..4a423d5dd225 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -241,9 +241,7 @@ def __call__( offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) - timesteps = torch.tensor( - [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device - ) + timesteps = torch.tensor([init_timestep - 1] * batch_size, dtype=torch.long, device=self.device) # add noise to latents using the timesteps noise = torch.randn(init_latents.shape, generator=generator, device=self.device) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index d75e2bd47656..ec0d87efeb7e 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -226,7 +226,7 @@ def add_noise( if self.tensor_format == "pt": timesteps = timesteps.to(self.sigmas.device) # FIXME: accounting for the descending sigmas - timesteps = self.num_inference_steps - timesteps + timesteps = self.num_inference_steps - timesteps - 1 sigmas = self.match_shape(self.sigmas[timesteps], noise) noisy_samples = original_samples + noise * sigmas From f4e717e1b7992be461e68f93baa3fc60d6b433f2 Mon Sep 17 00:00:00 2001 From: anton-l Date: Mon, 26 Sep 2022 15:06:51 +0200 Subject: [PATCH 5/8] style --- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 4a423d5dd225..f7c30a1ddd2d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -6,7 +6,6 @@ import torch import PIL -from tqdm.auto import tqdm from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict From f58846d06e50e94696eb6ed51ddcd5de6b913e39 Mon Sep 17 00:00:00 2001 From: anton-l Date: Mon, 26 Sep 2022 15:47:29 +0200 Subject: [PATCH 6/8] fix scheduler tests --- .../schedulers/scheduling_lms_discrete.py | 8 ++++--- tests/test_scheduler.py | 21 ++++++++++--------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index ec0d87efeb7e..854fd49dbf72 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -92,8 +92,9 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = np.cumprod(self.alphas, axis=0) - self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 - self.sigmas = self.sigmas[::-1].copy() + sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + sigmas = sigmas[::-1].copy() + self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) # setable values self.num_inference_steps = None @@ -191,7 +192,8 @@ def step( When returning a tuple, the first element is the sample tensor. """ - timestep = int(self.num_inference_steps - timestep - 1) + # FIXME: accounting for the descending sigmas + timestep = int(len(self.timesteps) - timestep - 1) sigma = self.sigmas[timestep] # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 7377797bebfa..8664f067bb8a 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -229,7 +229,7 @@ def recursive_check(tuple_object, dict_object): ) kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) + num_inference_steps = kwargs.pop("num_inference_steps", 50) for scheduler_class in self.scheduler_classes: scheduler_config = self.get_scheduler_config() @@ -369,9 +369,10 @@ def full_loop(self, **config): scheduler.set_timesteps(num_inference_steps) - for t in scheduler.timesteps: + for step in scheduler.timesteps: + t = scheduler.get_noise_condition(step) residual = model(sample, t) - sample = scheduler.step(residual, t, sample, eta).prev_sample + sample = scheduler.step(residual, step, sample, eta).prev_sample return sample @@ -387,7 +388,7 @@ def test_steps_offset(self): scheduler_config = self.get_scheduler_config(steps_offset=1) scheduler = scheduler_class(**scheduler_config) scheduler.set_timesteps(5) - assert torch.equal(scheduler.timesteps, torch.tensor([801, 601, 401, 201, 1])) + assert torch.equal(scheduler.schedule[scheduler.timesteps], torch.tensor([801, 601, 401, 201, 1])) def test_betas(self): for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): @@ -668,7 +669,7 @@ def test_steps_offset(self): scheduler = scheduler_class(**scheduler_config) scheduler.set_timesteps(10) assert torch.equal( - scheduler.timesteps, + scheduler.schedule[scheduler.timesteps], torch.tensor( [901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1] ), @@ -988,14 +989,14 @@ def test_full_loop_no_noise(self): scheduler.set_timesteps(self.num_inference_steps) model = self.dummy_model() - sample = self.dummy_sample_deter * scheduler.sigmas[0] - - for i, t in enumerate(scheduler.timesteps): - sample = sample / ((scheduler.sigmas[i] ** 2 + 1) ** 0.5) + sample = scheduler.scale_initial_noise(self.dummy_sample_deter) + for step in scheduler.timesteps: + sample = scheduler.scale_model_input(sample, step) + t = scheduler.get_noise_condition(step) model_output = model(sample, t) - output = scheduler.step(model_output, i, sample) + output = scheduler.step(model_output, step, sample) sample = output.prev_sample result_sum = torch.sum(torch.abs(sample)) From 0d0395bdb9e84b9195f50d190a83807888c338dd Mon Sep 17 00:00:00 2001 From: anton-l Date: Tue, 27 Sep 2022 18:49:41 +0200 Subject: [PATCH 7/8] Fix tests after merging --- src/diffusers/__init__.py | 1 + src/diffusers/pipeline_utils.py | 1 + .../pipeline_stable_diffusion_onnx.py | 2 +- src/diffusers/schedulers/__init__.py | 2 +- src/diffusers/schedulers/scheduling_ddim.py | 5 ++--- src/diffusers/schedulers/scheduling_ddpm.py | 4 ++-- src/diffusers/schedulers/scheduling_karras_ve.py | 10 ++++++++-- src/diffusers/schedulers/scheduling_lms_discrete.py | 13 ++++++------- src/diffusers/schedulers/scheduling_pndm.py | 4 ++-- src/diffusers/schedulers/scheduling_sde_ve.py | 10 ++++++++-- src/diffusers/schedulers/scheduling_sde_vp.py | 10 ++++++++-- src/diffusers/schedulers/scheduling_utils.py | 2 ++ 12 files changed, 42 insertions(+), 22 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index acdddaac4d26..802fa9a903ac 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -31,6 +31,7 @@ from .pipeline_utils import DiffusionPipeline from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline from .schedulers import ( + BaseScheduler, DDIMScheduler, DDPMScheduler, KarrasVeScheduler, diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index fb8801bc959a..a04ee652c5a7 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -43,6 +43,7 @@ LOADABLE_CLASSES = { "diffusers": { "ModelMixin": ["save_pretrained", "from_pretrained"], + "BaseScheduler": ["save_config", "from_config"], "SchedulerMixin": ["save_config", "from_config"], "DiffusionPipeline": ["save_pretrained", "from_pretrained"], "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"], diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py index 6c76db4d0d90..6adc5e23e78d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py @@ -73,7 +73,7 @@ def __call__( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, - return_tensors="pt", + return_tensors="np", ) text_input_ids = text_inputs.input_ids diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 495f30d9fabd..2828753e72d9 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -23,7 +23,7 @@ from .scheduling_pndm import PNDMScheduler from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_sde_vp import ScoreSdeVpScheduler - from .scheduling_utils import SchedulerMixin + from .scheduling_utils import BaseScheduler, SchedulerMixin else: from ..utils.dummy_pt_objects import * # noqa F403 diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 1c5cd91e967f..68134706f4b8 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -25,7 +25,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .scheduling_utils import BaseScheduler, SchedulerMixin +from .scheduling_utils import BaseScheduler @dataclass @@ -75,7 +75,7 @@ def alpha_bar(time_step): return torch.tensor(betas) -class DDIMScheduler(BaseScheduler, SchedulerMixin, ConfigMixin): +class DDIMScheduler(BaseScheduler, ConfigMixin): """ Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with non-Markovian guidance. @@ -194,7 +194,6 @@ def set_timesteps(self, num_inference_steps: int, **kwargs): self.schedule += offset self.timesteps = np.arange(0, num_inference_steps)[::-1].copy() - self.set_format(tensor_format=self.tensor_format) def step( self, diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 31ea55e6bf83..55c3d7075804 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -23,7 +23,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .scheduling_utils import BaseScheduler, SchedulerMixin +from .scheduling_utils import BaseScheduler @dataclass @@ -73,7 +73,7 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) -class DDPMScheduler(BaseScheduler, SchedulerMixin, ConfigMixin): +class DDPMScheduler(BaseScheduler, ConfigMixin): """ Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and Langevin dynamics sampling. diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index 5826858faee4..6ee6692597e8 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -21,7 +21,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .scheduling_utils import SchedulerMixin +from .scheduling_utils import BaseScheduler @dataclass @@ -45,7 +45,7 @@ class KarrasVeOutput(BaseOutput): pred_original_sample: Optional[torch.FloatTensor] = None -class KarrasVeScheduler(SchedulerMixin, ConfigMixin): +class KarrasVeScheduler(BaseScheduler, ConfigMixin): """ Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and the VE column of Table 1 from [1] for reference. @@ -92,6 +92,12 @@ def __init__( self.timesteps: np.ndarray = None self.schedule: torch.FloatTensor = None # sigma(t_i) + def get_noise_condition(self, step: int): + """ + Returns the input noise condition for the model. + """ + return self.schedule[step] + def set_timesteps(self, num_inference_steps: int): """ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 4855f42bf326..d706b22eab1d 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .scheduling_utils import BaseScheduler, SchedulerMixin +from .scheduling_utils import BaseScheduler @dataclass @@ -43,7 +43,7 @@ class LMSDiscreteSchedulerOutput(BaseOutput): pred_original_sample: Optional[torch.FloatTensor] = None -class LMSDiscreteScheduler(BaseScheduler, SchedulerMixin, ConfigMixin): +class LMSDiscreteScheduler(BaseScheduler, ConfigMixin): """ Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by Katherine Crowson: @@ -90,9 +90,9 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 - sigmas = sigmas[::-1].copy() - self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) # setable values self.num_inference_steps = None @@ -150,7 +150,7 @@ def set_timesteps(self, num_inference_steps: int): the number of diffusion steps used when generating samples with a pre-trained model. """ self.num_inference_steps = num_inference_steps - timesteps = np.arange(0, num_inference_steps)[::-1].copy() + self.timesteps = np.arange(0, num_inference_steps)[::-1].copy() self.schedule = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) @@ -158,7 +158,6 @@ def set_timesteps(self, num_inference_steps: int): sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) - self.timesteps = timesteps.astype(int) self.derivatives = [] def step( diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 53dd5763c7a6..eb0a36de42ee 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -22,7 +22,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import BaseScheduler, SchedulerMixin, SchedulerOutput +from .scheduling_utils import BaseScheduler, SchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -54,7 +54,7 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) -class PNDMScheduler(BaseScheduler, SchedulerMixin, ConfigMixin): +class PNDMScheduler(BaseScheduler, ConfigMixin): """ Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, namely Runge-Kutta method and a linear multi-step method. diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 7b06ae16c5e9..ce94b0ee9a8b 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -23,7 +23,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import BaseScheduler, SchedulerOutput @dataclass @@ -43,7 +43,7 @@ class SdeVeOutput(BaseOutput): prev_sample_mean: torch.FloatTensor -class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): +class ScoreSdeVeScheduler(BaseScheduler, ConfigMixin): """ The variance exploding stochastic differential equation (SDE) scheduler. @@ -82,6 +82,12 @@ def __init__( self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps) + def get_noise_condition(self, step: int): + """ + Returns the input noise condition for a model. + """ + return self.timesteps[step] + def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None): """ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index 2f9821579c52..3a5bb9847d5f 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -21,10 +21,10 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin +from .scheduling_utils import BaseScheduler -class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): +class ScoreSdeVpScheduler(BaseScheduler, ConfigMixin): """ The variance preserving stochastic differential equation (SDE) scheduler. @@ -45,6 +45,12 @@ def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling self.discrete_sigmas = None self.timesteps = None + def get_noise_condition(self, step: int): + """ + Returns the input noise condition for a model. + """ + return self.discrete_sigmas[step] + def set_timesteps(self, num_inference_steps): self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 2f0675282798..8462a61db201 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -37,6 +37,8 @@ class SchedulerOutput(BaseOutput): class BaseScheduler(abc.ABC): + config_name = SCHEDULER_CONFIG_NAME + def scale_initial_noise(self, noise: torch.FloatTensor): """ Scales the initial noise to the correct range for the scheduler. From 9cfd2dcc6754362f5d3ea142148e381c0e2c51ce Mon Sep 17 00:00:00 2001 From: anton-l Date: Tue, 27 Sep 2022 21:38:08 +0200 Subject: [PATCH 8/8] Fix numerical issues introduced with pytorch --- .../schedulers/scheduling_lms_discrete.py | 16 +++++++++------- tests/test_scheduler.py | 2 +- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index bba1a33923cb..a09fed168f2e 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -106,7 +106,8 @@ def __init__( # setable values self.num_inference_steps = None self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() - self.schedule = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float) + schedule = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float) + self.schedule = torch.from_numpy(schedule) self.derivatives = [] def scale_initial_noise(self, noise: torch.FloatTensor): @@ -161,11 +162,12 @@ def set_timesteps(self, num_inference_steps: int): self.num_inference_steps = num_inference_steps self.timesteps = np.arange(0, num_inference_steps)[::-1].copy() - self.schedule = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float) + schedule = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - sigmas = np.interp(self.schedule[::-1], np.arange(0, len(sigmas)), sigmas) + sigmas = np.interp(schedule[::-1], np.arange(0, len(sigmas)), sigmas) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) + self.schedule = torch.from_numpy(schedule) self.derivatives = [] @@ -224,10 +226,10 @@ def step( def add_noise( self, - original_samples: Union[torch.FloatTensor, np.ndarray], - noise: Union[torch.FloatTensor, np.ndarray], - timesteps: Union[torch.IntTensor, np.ndarray], - ) -> Union[torch.FloatTensor, np.ndarray]: + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: sigmas = self.sigmas.to(original_samples.device) timesteps = timesteps.to(original_samples.device) # FIXME: accounting for the descending sigmas diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 639d9c21b802..483d8bb6803f 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -877,5 +877,5 @@ def test_full_loop_no_noise(self): result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_sum.item() - 1006.370) < 1e-2 + assert abs(result_sum.item() - 1006.388) < 1e-2 assert abs(result_mean.item() - 1.31) < 1e-3