From 58d2c673b66b8be5c71bdd6dc6ecadca40d08abe Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Mon, 3 Oct 2022 18:57:12 +0000
Subject: [PATCH] up

---
 src/diffusers/configuration_utils.py          |  4 ++++
 .../pipeline_stable_diffusion.py              | 18 +++++++-----------
 src/diffusers/schedulers/__init__.py          |  2 +-
 .../schedulers/scheduling_lms_discrete.py     | 10 ++++++----
 src/diffusers/schedulers/scheduling_utils.py  |  9 +++++++++
 5 files changed, 27 insertions(+), 16 deletions(-)

diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index 59c93157891b..0253e72005cb 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -55,6 +55,10 @@ class ConfigMixin:
     def register_to_config(self, **kwargs):
         if self.config_name is None:
             raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
+
+        if hasattr(self, "type") and self.type is None:
+            raise NotImplementedError(f"Make sure that {self.__class__} has defined a scheduler type")
+
         kwargs["_class_name"] = self.__class__.__name__
         kwargs["_diffusers_version"] = __version__
 
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 45939e410672..0798fa627f65 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -9,7 +9,7 @@
 from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
-from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, SchedulerType
 from ...utils import logging
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
@@ -259,8 +259,8 @@ def __call__(
         timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device)
 
         # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            latents = latents * self.scheduler.sigmas[0]
+        if self.scheduler.type == SchedulerType.CONTINUOUS:
+            latents = latents * self.scheduler.init_sigma
 
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -274,10 +274,9 @@ def __call__(
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[i]
-                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+
+            if self.scheduler.type == SchedulerType.CONTINUOUS and self.unet.config.trained_scheduler_type == SchedulerType.DISCRETE:
+                latent_model_input = self.scheduler.scale_(latent_model_input)
 
             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -288,10 +287,7 @@ def __call__(
             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
-            else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
             # call the callback, if provided
             if callback is not None and i % callback_steps == 0:
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index a906c39eb24c..12889d3c7b4f 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -23,7 +23,7 @@
     from .scheduling_pndm import PNDMScheduler
     from .scheduling_sde_ve import ScoreSdeVeScheduler
     from .scheduling_sde_vp import ScoreSdeVpScheduler
-    from .scheduling_utils import SchedulerMixin
+    from .scheduling_utils import SchedulerMixin, SchedulerType
 else:
     from ..utils.dummy_pt_objects import *  # noqa F403
 
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py
index 8fd8c2b844a8..bcee646956e4 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -23,7 +23,7 @@
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput
-from .scheduling_utils import SchedulerMixin
+from .scheduling_utils import SchedulerMixin, SchedulerType
 
 
 @dataclass
@@ -66,6 +66,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
     """
 
+    type = SchedulerType.CONTINUOUS
 
     @register_to_config
     def __init__(
@@ -178,7 +179,8 @@ def step(
                 When returning a tuple, the first element is the sample tensor.
 
         """
-        sigma = self.sigmas[timestep]
+        index = (self.config.num_train_timesteps - timestep) // (self.config.num_train_timesteps // self.num_inference_steps)
+        sigma = self.sigmas[index]
 
         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
         pred_original_sample = sample - sigma * model_output
@@ -190,8 +192,8 @@
         self.derivatives.pop(0)
 
         # 3. Compute linear multistep coefficients
-        order = min(timestep + 1, order)
-        lms_coeffs = [self.get_lms_coefficient(order, timestep, curr_order) for curr_order in range(order)]
+        order = min(index + 1, order)
+        lms_coeffs = [self.get_lms_coefficient(order, index, curr_order) for curr_order in range(order)]
 
         # 4. Compute previous sample based on the derivatives path
         prev_sample = sample + sum(
diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py
index 1cc1d94414a6..543a5b4622f0 100644
--- a/src/diffusers/schedulers/scheduling_utils.py
+++ b/src/diffusers/schedulers/scheduling_utils.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 
 import warnings
 from dataclasses import dataclass
+from enum import Enum
+from typing import Optional
 
 import torch
@@ -22,6 +24,12 @@
 SCHEDULER_CONFIG_NAME = "scheduler_config.json"
 
 
+class SchedulerType(Enum):
+
+    CONTINUOUS = "continuous"
+    DISCRETE = "discrete"
+
+
 @dataclass
 class SchedulerOutput(BaseOutput):
     """
@@ -42,6 +50,7 @@ class SchedulerMixin:
     """
 
     config_name = SCHEDULER_CONFIG_NAME
+    type: Optional[SchedulerType] = None
 
     def set_format(self, tensor_format="pt"):
         warnings.warn(
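
Note on the central change: scheduler dispatch moves from isinstance checks on LMSDiscreteScheduler to a type tag on every scheduler. The sketch below is a minimal stand-alone illustration of that pattern, not the diffusers API: the ToyContinuousScheduler class, its init_sigma property, and the sigma values are made up, and scale_ mirrors the method this patch assumes continuous-time schedulers expose.

import torch
from enum import Enum


class SchedulerType(Enum):
    CONTINUOUS = "continuous"
    DISCRETE = "discrete"


class ToyContinuousScheduler:
    # mirrors the `type` class attribute this patch adds via SchedulerMixin
    type = SchedulerType.CONTINUOUS

    def __init__(self, sigmas):
        self.sigmas = sigmas  # noise levels, highest first (placeholder values below)
        self.step_index = 0   # a full implementation would advance this in step()

    @property
    def init_sigma(self):
        # the pipeline multiplies the initial latents by this
        return self.sigmas[0]

    def scale_(self, sample):
        # the scaling the old in-pipeline LMS branch applied:
        # x / sqrt(sigma^2 + 1), matching the continuous ODE formulation in K-LMS
        sigma = self.sigmas[self.step_index]
        return sample / ((sigma**2 + 1) ** 0.5)


scheduler = ToyContinuousScheduler(sigmas=[14.6, 9.3, 5.1, 2.0, 0.6])
latents = torch.randn(1, 4, 64, 64)

# what used to be `isinstance(scheduler, LMSDiscreteScheduler)` becomes:
if scheduler.type == SchedulerType.CONTINUOUS:
    latents = latents * scheduler.init_sigma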
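
A second subtlety: the pipeline now always passes the real timestep t to step() rather than the loop index i, so LMSDiscreteScheduler.step has to recover its position in the sigma schedule from the timestep. The worked check below uses assumed settings (1000 training steps, 50 inference steps, np.linspace spacing as in set_timesteps); note that the final timestep (t == 0) maps to the extra terminal sigma the scheduler appends to its schedule.

import numpy as np

num_train_timesteps = 1000  # assumed config value
num_inference_steps = 50

# mirrors set_timesteps: evenly spaced, descending from 999 to 0
timesteps = np.linspace(num_train_timesteps - 1, 0, num_inference_steps)

for i, t in enumerate(timesteps):
    # the remapping added to step(): recover the schedule position from t
    index = int((num_train_timesteps - t) // (num_train_timesteps // num_inference_steps))
    if t > 0:
        assert index == i  # matches the loop index the old API passed in
    else:
        assert index == num_inference_steps  # t == 0 hits the appended 0.0 sigma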