2 changes: 1 addition & 1 deletion src/diffusers/__init__.py
@@ -31,11 +31,11 @@
 from .pipeline_utils import DiffusionPipeline
 from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
 from .schedulers import (
+    BaseScheduler,
     DDIMScheduler,
     DDPMScheduler,
     KarrasVeScheduler,
     PNDMScheduler,
-    SchedulerMixin,
     ScoreSdeVeScheduler,
 )
 from .training_utils import EMAModel
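The rename ripples through every public import surface. A minimal sketch of what downstream code looks like once this change lands (hypothetical usage, assuming the re-exports shown above):

```python
# BaseScheduler replaces SchedulerMixin as the common scheduler parent
# exported from the package root.
from diffusers import BaseScheduler, DDIMScheduler, DDPMScheduler

# Type checks that previously referenced SchedulerMixin now target BaseScheduler.
assert issubclass(DDIMScheduler, BaseScheduler)
assert issubclass(DDPMScheduler, BaseScheduler)
```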
4 changes: 2 additions & 2 deletions src/diffusers/pipeline_flax_utils.py
@@ -30,7 +30,7 @@

 from .configuration_utils import ConfigMixin
 from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
-from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin
+from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, BaseScheduler
 from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging


@@ -436,7 +436,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             else:
                 loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
                 params[name] = loaded_params
-            elif issubclass(class_obj, SchedulerMixin):
+            elif issubclass(class_obj, BaseScheduler):
                 loaded_sub_model, scheduler_state = load_method(loadable_folder)
                 params[name] = scheduler_state
             else:
35 changes: 12 additions & 23 deletions src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -42,6 +42,14 @@ class StableDiffusionPipeline(DiffusionPipeline):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
 
+    vae: AutoencoderKL
+    text_encoder: CLIPTextModel
+    tokenizer: CLIPTokenizer
+    unet: UNet2DConditionModel
+    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+    safety_checker: StableDiffusionSafetyChecker
+    feature_extractor: CLIPFeatureExtractor
+
     def __init__(
         self,
         vae: AutoencoderKL,
@@ -231,30 +239,15 @@ def __call__(
         if latents.shape != latents_shape:
             raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
         latents = latents.to(self.device)
+        latents = self.scheduler.scale_initial_noise(latents)
 
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
 
-        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            latents = latents * self.scheduler.sigmas[0]
-
-        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
-        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
-        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
-        # and should be between [0, 1]
-        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        extra_step_kwargs = {}
-        if accepts_eta:
-            extra_step_kwargs["eta"] = eta
-
-        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+        for t in self.progress_bar(self.scheduler.timesteps):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[i]
-                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+            latent_model_input, t = self.scheduler.scale_model_inputs(latent_model_input, t)
 
             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -264,11 +257,7 @@ def __call__(
                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
-            # compute the previous noisy sample x_t -> x_t-1
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
-            else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, latents, eta=eta).prev_sample
 
         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
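The net effect of these hunks is that `__call__` no longer special-cases `LMSDiscreteScheduler`: initial-noise scaling, per-step input scaling, and `eta` handling all move behind the scheduler interface. A hedged sketch of the scheduler-agnostic loop this enables (`scale_initial_noise` and `scale_model_inputs` are the new hooks this PR relies on; classifier-free guidance is omitted for brevity, and `unet`/`text_embeddings` stand in for the real pipeline state):

```python
import torch

def denoising_loop(scheduler, unet, latents, text_embeddings,
                   num_inference_steps=50, eta=0.0):
    # Scheduler-specific initial noise scaling is hidden behind a uniform hook.
    latents = scheduler.scale_initial_noise(latents)
    scheduler.set_timesteps(num_inference_steps)
    for t in scheduler.timesteps:
        # Per-step input scaling (e.g. the K-LMS sigma division) now lives
        # inside the scheduler as well.
        model_input, t = scheduler.scale_model_inputs(latents, t)
        noise_pred = unet(model_input, t, encoder_hidden_states=text_embeddings).sample
        # eta is passed unconditionally; the diff implies every scheduler's
        # step() now accepts it, rather than being probed via inspect.signature.
        latents = scheduler.step(noise_pred, t, latents, eta=eta).prev_sample
    return latents
```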
2 changes: 1 addition & 1 deletion src/diffusers/schedulers/__init__.py
@@ -23,7 +23,7 @@
     from .scheduling_pndm import PNDMScheduler
     from .scheduling_sde_ve import ScoreSdeVeScheduler
     from .scheduling_sde_vp import ScoreSdeVpScheduler
-    from .scheduling_utils import SchedulerMixin
+    from .scheduling_utils import BaseScheduler
 else:
     from ..utils.dummy_pt_objects import *  # noqa F403

4 changes: 2 additions & 2 deletions src/diffusers/schedulers/scheduling_ddim.py
Expand Up @@ -23,7 +23,7 @@
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils import BaseScheduler, SchedulerOutput
 
 
 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@@ -55,7 +55,7 @@ def alpha_bar(time_step):
     return np.array(betas, dtype=np.float32)
 
 
-class DDIMScheduler(SchedulerMixin, ConfigMixin):
+class DDIMScheduler(BaseScheduler, ConfigMixin):
     """
     Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
     diffusion probabilistic models (DDPMs) with non-Markovian guidance.
4 changes: 2 additions & 2 deletions src/diffusers/schedulers/scheduling_ddim_flax.py
Expand Up @@ -23,7 +23,7 @@
 import jax.numpy as jnp
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils import BaseScheduler, SchedulerOutput
 
 
 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -72,7 +72,7 @@ class FlaxSchedulerOutput(SchedulerOutput):
     state: DDIMSchedulerState
 
 
-class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
+class FlaxDDIMScheduler(BaseScheduler, ConfigMixin):
     """
     Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
     diffusion probabilistic models (DDPMs) with non-Markovian guidance.
68 changes: 5 additions & 63 deletions src/diffusers/schedulers/scheduling_ddpm.py
@@ -20,40 +20,10 @@
 import numpy as np
 import torch
 
-from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils import BaseScheduler, SchedulerOutput
 
 
-def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
-    """
-    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
-    (1-beta) over time from t = [0,1].
-
-    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
-    to that part of the diffusion process.
-
-
-    Args:
-        num_diffusion_timesteps (`int`): the number of betas to produce.
-        max_beta (`float`): the maximum beta to use; use values lower than 1 to
-            prevent singularities.
-
-    Returns:
-        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
-    """
-
-    def alpha_bar(time_step):
-        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
-
-    betas = []
-    for i in range(num_diffusion_timesteps):
-        t1 = i / num_diffusion_timesteps
-        t2 = (i + 1) / num_diffusion_timesteps
-        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
-    return np.array(betas, dtype=np.float32)
-
-
-class DDPMScheduler(SchedulerMixin, ConfigMixin):
+class DDPMScheduler(BaseScheduler):
     """
     Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
     Langevin dynamics sampling.
@@ -83,43 +53,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):

"""

@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
variance_type: str = "fixed_small",
clip_sample: bool = True,
tensor_format: str = "pt",
**kwargs,
):
if trained_betas is not None:
self.betas = np.asarray(trained_betas)
elif beta_schedule == "linear":
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
self.one = np.array(1.0)

# setable values
self.num_inference_steps = None
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()

self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)

super().__init__(**kwargs)
self.variance_type = variance_type
self.clip_sample = clip_sample

def set_timesteps(self, num_inference_steps: int):
"""
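The beta-schedule construction deleted above has to live somewhere: `DDPMScheduler` now forwards `**kwargs` to `super().__init__`, so presumably `BaseScheduler` (not shown in this diff) absorbs it. A hedged reconstruction of what the base constructor would minimally need to do, mirroring the removed code:

```python
import numpy as np

class BaseScheduler:
    """Hypothetical sketch -- the real class lives in schedulers/scheduling_utils.py."""

    def __init__(self, num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02,
                 beta_schedule="linear", trained_betas=None, tensor_format="pt"):
        # Mirrors the schedule construction removed from DDPMScheduler.__init__.
        if trained_betas is not None:
            self.betas = np.asarray(trained_betas)
        elif beta_schedule == "linear":
            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
        # Setable values, reconfigured by set_timesteps() before inference.
        self.num_inference_steps = None
        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
```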
4 changes: 2 additions & 2 deletions src/diffusers/schedulers/scheduling_ddpm_flax.py
Expand Up @@ -23,7 +23,7 @@
 from jax import random
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils import BaseScheduler, SchedulerOutput
 
 
 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -71,7 +71,7 @@ class FlaxSchedulerOutput(SchedulerOutput):
     state: DDPMSchedulerState
 
 
-class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
+class FlaxDDPMScheduler(BaseScheduler, ConfigMixin):
     """
     Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
     Langevin dynamics sampling.
35 changes: 22 additions & 13 deletions src/diffusers/schedulers/scheduling_karras_ve.py
@@ -21,7 +21,7 @@

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput
-from .scheduling_utils import SchedulerMixin
+from .scheduling_utils import BaseScheduler
 
 
 @dataclass
@@ -41,7 +41,7 @@ class KarrasVeOutput(BaseOutput):
     derivative: torch.FloatTensor
 
 
-class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
+class KarrasVeScheduler(BaseScheduler, ConfigMixin):
     """
     Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
     the VE column of Table 1 from [1] for reference.
@@ -74,16 +74,15 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):

"""

@register_to_config
def __init__(
self,
sigma_min: float = 0.02,
sigma_max: float = 100,
s_noise: float = 1.007,
s_churn: float = 80,
s_min: float = 0.05,
s_max: float = 50,
tensor_format: str = "pt",
def __old_init__(
self,
sigma_min: float = 0.02,
sigma_max: float = 100,
s_noise: float = 1.007,
s_churn: float = 80,
s_min: float = 0.05,
s_max: float = 50,
tensor_format: str = "pt",
):
# setable values
self.num_inference_steps = None
@@ -93,7 +92,17 @@ def __init__(
         self.tensor_format = tensor_format
         self.set_format(tensor_format=tensor_format)
 
-    def set_timesteps(self, num_inference_steps: int):
+    def __init__(
+        self,
+        s_noise: float = 1.007,
+        s_churn: float = 80,
+        s_min: float = 0.05,
+        s_max: float = 50,
+        **kwargs,
+    ):
+        super().__init__(s_noise=s_noise, s_churn=s_churn, s_min=s_min, s_max=s_max, **kwargs)
+
+    def set_schedule(self, num_inference_steps: int):
         """
         Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
 
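Construction and schedule setup are decoupled here: the sigma bounds leave the public constructor (the old signature is kept as `__old_init__`, apparently for reference during the migration), the churn parameters are forwarded to the base class, and `set_timesteps` becomes `set_schedule`. A hedged usage sketch under those assumptions:

```python
from diffusers import KarrasVeScheduler

# Churn/noise parameters go through the constructor and on to BaseScheduler.
scheduler = KarrasVeScheduler(s_churn=80, s_noise=1.007)

# The sigma schedule is configured per run, via the renamed hook.
scheduler.set_schedule(num_inference_steps=50)
```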
4 changes: 2 additions & 2 deletions src/diffusers/schedulers/scheduling_karras_ve_flax.py
Expand Up @@ -22,7 +22,7 @@

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput
-from .scheduling_utils import SchedulerMixin
+from .scheduling_utils import BaseScheduler
 
 
 @flax.struct.dataclass
@@ -56,7 +56,7 @@ class FlaxKarrasVeOutput(BaseOutput):
     state: KarrasVeSchedulerState
 
 
-class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
+class FlaxKarrasVeScheduler(BaseScheduler, ConfigMixin):
     """
     Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
     the VE column of Table 1 from [1] for reference.