diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index acdddaac4d26..9caff8e38314 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -31,11 +31,11 @@ from .pipeline_utils import DiffusionPipeline from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline from .schedulers import ( + BaseScheduler, DDIMScheduler, DDPMScheduler, KarrasVeScheduler, PNDMScheduler, - SchedulerMixin, ScoreSdeVeScheduler, ) from .training_utils import EMAModel diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index 6cfd7ae32112..2702cedc7907 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -30,7 +30,7 @@ from .configuration_utils import ConfigMixin from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin -from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin +from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, BaseScheduler from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging @@ -436,7 +436,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P else: loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False) params[name] = loaded_params - elif issubclass(class_obj, SchedulerMixin): + elif issubclass(class_obj, BaseScheduler): loaded_sub_model, scheduler_state = load_method(loadable_folder) params[name] = scheduler_state else: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 216a76a55997..17d3c6efed7d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -42,6 +42,14 @@ class StableDiffusionPipeline(DiffusionPipeline): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ + vae: AutoencoderKL + text_encoder: CLIPTextModel + tokenizer: CLIPTokenizer + unet: UNet2DConditionModel + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + safety_checker: StableDiffusionSafetyChecker + feature_extractor: CLIPFeatureExtractor + def __init__( self, vae: AutoencoderKL, @@ -231,30 +239,15 @@ def __call__( if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") latents = latents.to(self.device) + latents = self.scheduler.scale_initial_noise(latents) # set timesteps self.scheduler.set_timesteps(num_inference_steps) - # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = latents * self.scheduler.sigmas[0] - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + for t in self.progress_bar(self.scheduler.timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[i] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + latent_model_input, t = self.scheduler.scale_model_inputs(latent_model_input, t) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample @@ -264,11 +257,7 @@ def __call__( noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - # compute the previous noisy sample x_t -> x_t-1 - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample - else: - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, eta=eta).prev_sample # scale and decode the image latents with vae latents = 1 / 0.18215 * latents diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 495f30d9fabd..dab373b23cc5 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -23,7 +23,7 @@ from .scheduling_pndm import PNDMScheduler from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_sde_vp import ScoreSdeVpScheduler - from .scheduling_utils import SchedulerMixin + from .scheduling_utils import BaseScheduler else: from ..utils.dummy_pt_objects import * # noqa F403 diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index a5369b1603c6..a0a8533dffae 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -23,7 +23,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import BaseScheduler, SchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -55,7 +55,7 @@ def alpha_bar(time_step): return np.array(betas, dtype=np.float32) -class DDIMScheduler(SchedulerMixin, ConfigMixin): +class DDIMScheduler(BaseScheduler, ConfigMixin): """ Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with non-Markovian guidance. 
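The pipeline hunk above removes every `isinstance(self.scheduler, LMSDiscreteScheduler)` branch from `StableDiffusionPipeline.__call__` and relies on the scheduler API introduced in this PR instead. As a rough illustration, a scheduler-agnostic denoising loop built on that API could look like the sketch below; `denoise`, `unet`, and `text_embeddings` are hypothetical stand-ins, and only `scale_initial_noise`, `scale_model_inputs`, and `step(..., eta=eta)` come from the diff.

```python
import torch

def denoise(unet, scheduler, latents, text_embeddings, num_inference_steps=50, eta=0.0):
    # Initial-noise scaling and per-step input scaling are delegated to the
    # scheduler, so the loop needs no scheduler-specific branches.
    latents = scheduler.scale_initial_noise(latents)
    scheduler.set_timesteps(num_inference_steps)

    for t in scheduler.timesteps:
        model_input, t = scheduler.scale_model_inputs(latents, t)
        with torch.no_grad():
            noise_pred = unet(model_input, t, encoder_hidden_states=text_embeddings).sample
        # eta is only meaningful for DDIM; per this PR, other schedulers are
        # expected to swallow it through **kwargs in their step() signature.
        latents = scheduler.step(noise_pred, t, latents, eta=eta).prev_sample
    return latents
```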
diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index d81d66607147..c920257cd0d5 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -23,7 +23,7 @@ import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import BaseScheduler, SchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray: @@ -72,7 +72,7 @@ class FlaxSchedulerOutput(SchedulerOutput): state: DDIMSchedulerState -class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): +class FlaxDDIMScheduler(BaseScheduler, ConfigMixin): """ Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with non-Markovian guidance. diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index d008b84da6e7..653c85e85f85 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -20,40 +20,10 @@ import numpy as np import torch -from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import BaseScheduler, SchedulerOutput -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return np.array(betas, dtype=np.float32) - - -class DDPMScheduler(SchedulerMixin, ConfigMixin): +class DDPMScheduler(BaseScheduler): """ Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and Langevin dynamics sampling. @@ -83,43 +53,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): """ - @register_to_config def __init__( self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[np.ndarray] = None, variance_type: str = "fixed_small", clip_sample: bool = True, - tensor_format: str = "pt", + **kwargs, ): - if trained_betas is not None: - self.betas = np.asarray(trained_betas) - elif beta_schedule == "linear": - self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. 
- self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = np.cumprod(self.alphas, axis=0) - self.one = np.array(1.0) - - # setable values - self.num_inference_steps = None - self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() - - self.tensor_format = tensor_format - self.set_format(tensor_format=tensor_format) - + super().__init__(**kwargs) self.variance_type = variance_type + self.clip_sample = clip_sample def set_timesteps(self, num_inference_steps: int): """ diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py index 7c7b8d29ab52..e120aa331f05 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_flax.py +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -23,7 +23,7 @@ from jax import random from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import BaseScheduler, SchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray: @@ -71,7 +71,7 @@ class FlaxSchedulerOutput(SchedulerOutput): state: DDPMSchedulerState -class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin): +class FlaxDDPMScheduler(BaseScheduler, ConfigMixin): """ Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and Langevin dynamics sampling. diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index caf7625fb683..2ea8866f7f68 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -21,7 +21,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .scheduling_utils import SchedulerMixin +from .scheduling_utils import BaseScheduler @dataclass @@ -41,7 +41,7 @@ class KarrasVeOutput(BaseOutput): derivative: torch.FloatTensor -class KarrasVeScheduler(SchedulerMixin, ConfigMixin): +class KarrasVeScheduler(BaseScheduler, ConfigMixin): """ Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and the VE column of Table 1 from [1] for reference. 
@@ -74,16 +74,15 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): """ - @register_to_config - def __init__( - self, - sigma_min: float = 0.02, - sigma_max: float = 100, - s_noise: float = 1.007, - s_churn: float = 80, - s_min: float = 0.05, - s_max: float = 50, - tensor_format: str = "pt", + def __old_init__( + self, + sigma_min: float = 0.02, + sigma_max: float = 100, + s_noise: float = 1.007, + s_churn: float = 80, + s_min: float = 0.05, + s_max: float = 50, + tensor_format: str = "pt", ): # setable values self.num_inference_steps = None @@ -93,7 +92,17 @@ def __init__( self.tensor_format = tensor_format self.set_format(tensor_format=tensor_format) - def set_timesteps(self, num_inference_steps: int): + def __init__( + self, + s_noise: float = 1.007, + s_churn: float = 80, + s_min: float = 0.05, + s_max: float = 50, + **kwargs, + ): + super().__init__(s_noise=s_noise, s_churn=s_churn, s_min=s_min, s_max=s_max, **kwargs) + + def set_schedule(self, num_inference_steps: int): """ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/src/diffusers/schedulers/scheduling_karras_ve_flax.py b/src/diffusers/schedulers/scheduling_karras_ve_flax.py index 51b16c96fa06..8c29d9dd8730 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_karras_ve_flax.py @@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .scheduling_utils import SchedulerMixin +from .scheduling_utils import BaseScheduler @flax.struct.dataclass @@ -56,7 +56,7 @@ class FlaxKarrasVeOutput(BaseOutput): state: KarrasVeSchedulerState -class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin): +class FlaxKarrasVeScheduler(BaseScheduler, ConfigMixin): """ Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and the VE column of Table 1 from [1] for reference. diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 5857ae70a856..60929b6bf1df 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -20,10 +20,11 @@ from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import BaseScheduler, SchedulerOutput +from ..models import UNet2DModel -class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): +class LMSDiscreteScheduler(BaseScheduler, ConfigMixin): """ Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by Katherine Crowson: @@ -50,38 +51,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): """ @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[np.ndarray] = None, - tensor_format: str = "pt", - ): - if trained_betas is not None: - self.betas = np.asarray(trained_betas) - if beta_schedule == "linear": - self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. 
- self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = np.cumprod(self.alphas, axis=0) - - self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.schedule = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 # setable values - self.num_inference_steps = None - self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() self.derivatives = [] - self.tensor_format = tensor_format - self.set_format(tensor_format=tensor_format) - def get_lms_coefficient(self, order, t, current_order): """ Compute a linear multistep coefficient. @@ -104,7 +81,7 @@ def lms_derivative(tau): return integrated_coeff - def set_timesteps(self, num_inference_steps: int): + def old_set_timesteps(self, num_inference_steps: int): """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -126,6 +103,28 @@ def set_timesteps(self, num_inference_steps: int): self.set_format(tensor_format=self.tensor_format) + def set_timesteps(self, num_inference_steps: int): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + self.num_inference_steps = num_inference_steps + + self.timesteps + + self.derivatives = [] + + self.set_format(tensor_format=self.tensor_format) + + def scale_model_inputs(self, sample, noise_cond, sigma): + return sample / ((sigma**2 + 1) ** 0.5), noise_cond + + def scale_initial_noise(self, noise): + return noise * self.sigmas[0] + def step( self, model_output: Union[torch.FloatTensor, np.ndarray], @@ -133,6 +132,7 @@ def step( sample: Union[torch.FloatTensor, np.ndarray], order: int = 4, return_dict: bool = True, + **kwargs, ) -> Union[SchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py index 7f4c076b54d1..6487347d810d 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -20,7 +20,7 @@ from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import BaseScheduler, SchedulerOutput @flax.struct.dataclass @@ -41,7 +41,7 @@ class FlaxSchedulerOutput(SchedulerOutput): state: LMSDiscreteSchedulerState -class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin): +class FlaxLMSDiscreteScheduler(BaseScheduler, ConfigMixin): """ Linear Multistep Scheduler for discrete beta schedules. 
Based on the original k-diffusion implementation by Katherine Crowson: diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 09e8a7e240c2..99b10df871d3 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -22,7 +22,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import BaseScheduler, SchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -54,7 +54,7 @@ def alpha_bar(time_step): return np.array(betas, dtype=np.float32) -class PNDMScheduler(SchedulerMixin, ConfigMixin): +class PNDMScheduler(BaseScheduler, ConfigMixin): """ Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, namely Runge-Kutta method and a linear multi-step method. diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 8344505620c4..b60c0da7ebfa 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -22,7 +22,7 @@ import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import BaseScheduler, SchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray: @@ -79,7 +79,7 @@ class FlaxSchedulerOutput(SchedulerOutput): state: PNDMSchedulerState -class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): +class FlaxPNDMScheduler(BaseScheduler, ConfigMixin): """ Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, namely Runge-Kutta method and a linear multi-step method. diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 4af8f4fdad7d..02cf70ef591b 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -23,7 +23,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import BaseScheduler, SchedulerOutput @dataclass @@ -43,7 +43,7 @@ class SdeVeOutput(BaseOutput): prev_sample_mean: torch.FloatTensor -class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): +class ScoreSdeVeScheduler(BaseScheduler, ConfigMixin): """ The variance exploding stochastic differential equation (SDE) scheduler. diff --git a/src/diffusers/schedulers/scheduling_sde_ve_flax.py b/src/diffusers/schedulers/scheduling_sde_ve_flax.py index 08fbe14732da..87adfc0ffe03 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_sde_ve_flax.py @@ -22,7 +22,7 @@ from jax import random from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import BaseScheduler, SchedulerOutput @flax.struct.dataclass @@ -56,7 +56,7 @@ class FlaxSdeVeOutput(SchedulerOutput): prev_sample_mean: Optional[jnp.ndarray] = None -class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): +class FlaxScoreSdeVeScheduler(BaseScheduler, ConfigMixin): """ The variance exploding stochastic differential equation (SDE) scheduler. 
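The `scheduling_lms_discrete.py` hunks above move the k-LMS specific scaling out of the pipeline and into `scale_initial_noise` / `scale_model_inputs`. The snippet below is only an illustrative numerical check of that equivalence; the standalone helpers and the sigma values are made up for the example and are not part of the PR.

```python
import numpy as np

def scale_initial_noise(noise, sigmas):
    # removed pipeline line: latents = latents * self.scheduler.sigmas[0]
    return noise * sigmas[0]

def scale_model_inputs(sample, noise_cond, sigma):
    # removed pipeline line: latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
    return sample / ((sigma**2 + 1) ** 0.5), noise_cond

sigmas = np.array([14.61, 9.68, 6.17])   # illustrative inference sigmas
latents = np.random.randn(2, 4, 8, 8)

scaled = scale_initial_noise(latents, sigmas)
model_input, t = scale_model_inputs(scaled, 999, sigmas[0])
assert np.allclose(model_input, scaled / np.sqrt(sigmas[0] ** 2 + 1))
```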
diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index f19a5ad76f81..ffc36e914ccf 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -20,10 +20,10 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin +from .scheduling_utils import BaseScheduler -class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): +class ScoreSdeVpScheduler(BaseScheduler, ConfigMixin): """ The variance preserving stochastic differential equation (SDE) scheduler. diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index f2bcd73acf32..af57d8d629f1 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -11,18 +11,49 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import math from dataclasses import dataclass -from typing import Union +from typing import Optional, Union import numpy as np import torch +from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput SCHEDULER_CONFIG_NAME = "scheduler_config.json" +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas, dtype=np.float32) + + @dataclass class SchedulerOutput(BaseOutput): """ @@ -37,7 +68,7 @@ class SchedulerOutput(BaseOutput): prev_sample: torch.FloatTensor -class SchedulerMixin: +class BaseScheduler(ConfigMixin): """ Mixin containing common functions for the schedulers. """ @@ -45,6 +76,66 @@ class SchedulerMixin: config_name = SCHEDULER_CONFIG_NAME ignore_for_config = ["tensor_format"] + @register_to_config + def __init__( + self, + beta_start: Optional[float] = None, + beta_end: Optional[float] = None, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + num_train_timesteps: int = 1000, + beta_schedule: Optional[str] = None, + trained_betas: Optional[np.ndarray] = None, + tensor_format: str = "pt", + **kwargs, + ): + if beta_start is not None and beta_end is not None: + if trained_betas is not None: + self.betas = np.asarray(trained_betas) + elif beta_schedule == "linear": + self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. 
+                self.betas = (
+                    np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
+                )
+            elif beta_schedule == "squaredcos_cap_v2":
+                # Glide cosine schedule
+                self.betas = betas_for_alpha_bar(num_train_timesteps)
+            else:
+                raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+            self.alphas = 1.0 - self.betas
+            self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+            self.one = np.array(1.0)
+            self.schedule_type = "beta"
+        elif sigma_min is not None and sigma_max is not None:
+            self.schedule = None
+        else:
+            raise ValueError("Either beta_start and beta_end or sigma_min and sigma_max must be provided.")
+
+        # setable values
+        self.num_inference_steps = None
+        self.timesteps = np.arange(0, num_train_timesteps, dtype=int)[::-1].copy()
+
+        self.tensor_format = tensor_format
+        self.set_format(tensor_format=tensor_format)
+
+    def set_schedule(self, num_inference_steps: int):
+        raise NotImplementedError("set_schedule is not implemented for this scheduler.")
+
+    def t_to_sigma(self, t: int):
+        raise NotImplementedError("t_to_sigma is not implemented for this scheduler.")
+
+    def sigma_to_t(self, sigma: float):
+        raise NotImplementedError("sigma_to_t is not implemented for this scheduler.")
+
+    def scale_model_inputs(self, sample, noise_cond, sigma=None):
+        return sample, noise_cond
+
+    def scale_initial_noise(self, noise):
+        return noise
+
     def set_format(self, tensor_format="pt"):
         self.tensor_format = tensor_format
         if tensor_format == "pt":
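With `BaseScheduler` owning beta-schedule construction, subclasses keep only their scheduler-specific options and forward the rest upward, as the reworked `DDPMScheduler` does earlier in this diff. Below is a minimal sketch of that pattern, assuming this branch is installed; `MyScheduler` is hypothetical.

```python
from diffusers.schedulers.scheduling_utils import BaseScheduler  # module path as in this PR

class MyScheduler(BaseScheduler):
    def __init__(self, clip_sample: bool = True, **kwargs):
        # beta_start / beta_end / beta_schedule / num_train_timesteps are consumed by
        # BaseScheduler.__init__, which builds betas, alphas_cumprod and default timesteps.
        super().__init__(**kwargs)
        self.clip_sample = clip_sample

scheduler = MyScheduler(
    beta_start=0.0001, beta_end=0.02, beta_schedule="linear", num_train_timesteps=1000
)
print(scheduler.alphas_cumprod.shape)  # one entry per training timestep
```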