diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index acdddaac4d26..802fa9a903ac 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -31,6 +31,7 @@ from .pipeline_utils import DiffusionPipeline from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline from .schedulers import ( + BaseScheduler, DDIMScheduler, DDPMScheduler, KarrasVeScheduler, diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 09a7baad560d..82ca5dbb6f56 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -17,7 +17,6 @@ "jaxlib": "jaxlib>=0.1.65,<=0.3.6", "modelcards": "modelcards>=0.1.4", "numpy": "numpy", - "onnxruntime": "onnxruntime", "onnxruntime-gpu": "onnxruntime-gpu", "pytest": "pytest", "pytest-timeout": "pytest-timeout", diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index fb8801bc959a..a04ee652c5a7 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -43,6 +43,7 @@ LOADABLE_CLASSES = { "diffusers": { "ModelMixin": ["save_pretrained", "from_pretrained"], + "BaseScheduler": ["save_config", "from_config"], "SchedulerMixin": ["save_config", "from_config"], "DiffusionPipeline": ["save_pretrained", "from_pretrained"], "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"], diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py index 74607fe87a3d..58808e5db0cf 100644 --- a/src/diffusers/pipelines/ddim/pipeline_ddim.py +++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py @@ -82,14 +82,15 @@ def __call__( # set step values self.scheduler.set_timesteps(num_inference_steps) - for t in self.progress_bar(self.scheduler.timesteps): + for step in self.progress_bar(self.scheduler.timesteps): # 1. predict noise model_output + t = self.scheduler.get_noise_condition(step) model_output = self.unet(image, t).sample # 2. predict previous mean of image x_t-1 and add variance depending on eta # eta corresponds to η in paper and should be between [0, 1] # do x_t -> x_t-1 - image = self.scheduler.step(model_output, t, image, eta).prev_sample + image = self.scheduler.step(model_output, step, image, eta).prev_sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index aae29737aae3..4258e79681a2 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -75,12 +75,13 @@ def __call__( # set step values self.scheduler.set_timesteps(1000) - for t in self.progress_bar(self.scheduler.timesteps): + for step in self.progress_bar(self.scheduler.timesteps): # 1. predict noise model_output + t = self.scheduler.get_noise_condition(step) model_output = self.unet(image, t).sample # 2. compute previous image: x_t -> t_t-1 - image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample + image = self.scheduler.step(model_output, step, image, generator=generator).prev_sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index 556e4211892b..63e27ef23707 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -117,6 +117,7 @@ def __call__( generator=generator, ) latents = latents.to(self.device) + latents = self.scheduler.scale_initial_noise(latents) self.scheduler.set_timesteps(num_inference_steps) @@ -127,7 +128,7 @@ def __call__( if accepts_eta: extra_kwargs["eta"] = eta - for t in self.progress_bar(self.scheduler.timesteps): + for step in self.progress_bar(self.scheduler.timesteps): if guidance_scale == 1.0: # guidance_scale of 1 means no guidance latents_input = latents @@ -140,6 +141,8 @@ def __call__( context = torch.cat([uncond_embeddings, text_embeddings]) # predict the noise residual + latents_input = self.scheduler.scale_model_input(latents_input, step) + t = self.scheduler.get_noise_condition(step) noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample # perform guidance if guidance_scale != 1.0: @@ -147,7 +150,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, step, latents, **extra_kwargs).prev_sample # scale and decode the image latents with vae latents = 1 / 0.18215 * latents diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py index ef82abb7e6cb..377d98e5338b 100644 --- a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +++ b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py @@ -63,6 +63,7 @@ def __call__( generator=generator, ) latents = latents.to(self.device) + latents = self.scheduler.scale_initial_noise(latents) self.scheduler.set_timesteps(num_inference_steps) @@ -73,11 +74,13 @@ def __call__( if accepts_eta: extra_kwargs["eta"] = eta - for t in self.progress_bar(self.scheduler.timesteps): + for step in self.progress_bar(self.scheduler.timesteps): # predict the noise residual + latents = self.scheduler.scale_model_input(latents, step) + t = self.scheduler.get_noise_condition(step) noise_prediction = self.unet(latents, t).sample # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample + latents = self.scheduler.step(noise_prediction, step, latents, **extra_kwargs).prev_sample # decode the image latents with the VAE image = self.vqvae.decode(latents).sample diff --git a/src/diffusers/pipelines/pndm/pipeline_pndm.py b/src/diffusers/pipelines/pndm/pipeline_pndm.py index f360da09ac94..29aa5435ec62 100644 --- a/src/diffusers/pipelines/pndm/pipeline_pndm.py +++ b/src/diffusers/pipelines/pndm/pipeline_pndm.py @@ -79,12 +79,14 @@ def __call__( generator=generator, ) image = image.to(self.device) + image = self.scheduler.scale_initial_noise(image) self.scheduler.set_timesteps(num_inference_steps) - for t in self.progress_bar(self.scheduler.timesteps): + for step in self.progress_bar(self.scheduler.timesteps): + t = self.scheduler.get_noise_condition(step) model_output = self.unet(image, t).sample - image = self.scheduler.step(model_output, t, image).prev_sample + image = self.scheduler.step(model_output, step, image).prev_sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 77f25ef1b9c5..62ea4247bff5 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -46,6 +46,14 @@ class StableDiffusionPipeline(DiffusionPipeline): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ + vae: AutoencoderKL + text_encoder: CLIPTextModel + tokenizer: CLIPTokenizer + unet: UNet2DConditionModel + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + safety_checker: StableDiffusionSafetyChecker + feature_extractor: CLIPFeatureExtractor + def __init__( self, vae: AutoencoderKL, @@ -230,14 +238,11 @@ def __call__( if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") latents = latents.to(self.device) + latents = self.scheduler.scale_initial_noise(latents) # set timesteps self.scheduler.set_timesteps(num_inference_steps) - # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = latents * self.scheduler.sigmas[0] - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 @@ -247,13 +252,11 @@ def __call__( if accepts_eta: extra_step_kwargs["eta"] = eta - for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + for step in self.progress_bar(self.scheduler.timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[i] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, step) + t = self.scheduler.get_noise_condition(step) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample @@ -264,10 +267,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample - else: - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, step, latents, **extra_step_kwargs).prev_sample # scale and decode the image latents with vae latents = 1 / 0.18215 * latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index f2ccee71c024..c9c8f4628901 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -206,13 +206,7 @@ def __call__( offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) - if isinstance(self.scheduler, LMSDiscreteScheduler): - timesteps = torch.tensor( - [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device - ) - else: - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) + timesteps = torch.tensor([init_timestep - 1] * batch_size, dtype=torch.long, device=self.device) # add noise to latents using the timesteps noise = torch.randn(init_latents.shape, generator=generator, device=self.device) @@ -265,17 +259,11 @@ def __call__( latents = init_latents t_start = max(num_inference_steps - init_timestep + offset, 0) - for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[t_start:])): - t_index = t_start + i - + for step in self.progress_bar(self.scheduler.timesteps[t_start:]): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - - # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[t_index] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, step) + t = self.scheduler.get_noise_condition(step) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample @@ -286,10 +274,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample - else: - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, step, latents, **extra_step_kwargs).prev_sample # scale and decode the image latents with vae latents = 1 / 0.18215 * latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index a95f9152279a..6deb6a2a3285 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -6,7 +6,6 @@ import torch import PIL -from tqdm.auto import tqdm from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict @@ -240,13 +239,7 @@ def __call__( offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) - if isinstance(self.scheduler, LMSDiscreteScheduler): - timesteps = torch.tensor( - [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device - ) - else: - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) + timesteps = torch.tensor([init_timestep - 1] * batch_size, dtype=torch.long, device=self.device) # add noise to latents using the timesteps noise = torch.randn(init_latents.shape, generator=generator, device=self.device) @@ -298,14 +291,11 @@ def __call__( latents = init_latents t_start = max(num_inference_steps - init_timestep + offset, 0) - for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])): - t_index = t_start + i + for step in self.progress_bar(self.scheduler.timesteps[t_start:]): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[t_index] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, step) + t = self.scheduler.get_noise_condition(step) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample @@ -316,14 +306,9 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample - # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t_index])) - else: - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t])) + latents = self.scheduler.step(noise_pred, step, latents, **extra_step_kwargs).prev_sample + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor(step)) latents = (init_latents_proper * mask) + (latents * (1 - mask)) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py index 07e9c1d9acd6..6adc5e23e78d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py @@ -109,6 +109,7 @@ def __call__( latents = np.random.randn(*latents_shape).astype(np.float32) elif latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = self.scheduler.scale_initial_noise(latents) # set timesteps self.scheduler.set_timesteps(num_inference_steps) @@ -126,13 +127,11 @@ def __call__( if accepts_eta: extra_step_kwargs["eta"] = eta - for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + for step in self.progress_bar(self.scheduler.timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[i] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, step) + t = self.scheduler.get_noise_condition(step) # predict the noise residual noise_pred = self.unet( @@ -146,10 +145,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample - else: - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, step, latents, **extra_step_kwargs).prev_sample # scale and decode the image latents with vae latents = 1 / 0.18215 * latents diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 495f30d9fabd..2828753e72d9 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -23,7 +23,7 @@ from .scheduling_pndm import PNDMScheduler from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_sde_vp import ScoreSdeVpScheduler - from .scheduling_utils import SchedulerMixin + from .scheduling_utils import BaseScheduler, SchedulerMixin else: from ..utils.dummy_pt_objects import * # noqa F403 diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 0d9e285e054e..61c386ac7a06 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -25,7 +25,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .scheduling_utils import SchedulerMixin +from .scheduling_utils import BaseScheduler @dataclass @@ -75,7 +75,7 @@ def alpha_bar(time_step): return torch.tensor(betas) -class DDIMScheduler(SchedulerMixin, ConfigMixin): +class DDIMScheduler(BaseScheduler, ConfigMixin): """ Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with non-Markovian guidance. @@ -155,7 +155,8 @@ def __init__( # setable values self.num_inference_steps = None - self.timesteps = np.arange(0, num_train_timesteps)[::-1] + self.schedule = np.arange(0, num_train_timesteps) + self.timesteps = self.schedule[::-1] def _get_variance(self, timestep, prev_timestep): alpha_prod_t = self.alphas_cumprod[timestep] @@ -167,6 +168,12 @@ def _get_variance(self, timestep, prev_timestep): return variance + def get_noise_condition(self, step: int): + """ + Returns the input noise condition for a model. + """ + return self.schedule[step] + def set_timesteps(self, num_inference_steps: int, **kwargs): """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -191,8 +198,10 @@ def set_timesteps(self, num_inference_steps: int, **kwargs): step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1] - self.timesteps += offset + self.schedule = (np.arange(0, num_inference_steps) * step_ratio).round().copy() + self.schedule += offset + + self.timesteps = np.arange(0, num_inference_steps)[::-1].copy() def step( self, @@ -240,6 +249,8 @@ def step( # - pred_sample_direction -> "direction pointing to x_t" # - pred_prev_sample -> "x_t-1" + timestep = self.schedule[timestep] + # 1. get previous step value (=t-1) prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps @@ -295,7 +306,7 @@ def add_noise( if timesteps.device != original_samples.device: timesteps = timesteps.to(original_samples.device) - + timesteps = self.schedule[timesteps] sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index cc17cee4c810..0e289e131dbb 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -24,7 +24,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .scheduling_utils import SchedulerMixin +from .scheduling_utils import BaseScheduler @dataclass @@ -74,7 +74,7 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) -class DDPMScheduler(SchedulerMixin, ConfigMixin): +class DDPMScheduler(BaseScheduler, ConfigMixin): """ Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and Langevin dynamics sampling. @@ -143,10 +143,17 @@ def __init__( # setable values self.num_inference_steps = None - self.timesteps = np.arange(0, num_train_timesteps)[::-1] + self.schedule = np.arange(0, num_train_timesteps) + self.timesteps = self.schedule[::-1] self.variance_type = variance_type + def get_noise_condition(self, step: int): + """ + Returns the input noise condition for a model. + """ + return self.schedule[step] + def set_timesteps(self, num_inference_steps: int): """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -223,7 +230,7 @@ def step( returning a tuple, the first element is the sample tensor. """ - t = timestep + t = self.schedule[timestep] if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) @@ -281,8 +288,7 @@ def add_noise( self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device) if timesteps.device != original_samples.device: - timesteps = timesteps.to(original_samples.device) - + timesteps = self.schedule[timesteps].to(original_samples.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index e6e5300e73e7..90246009fa5e 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .scheduling_utils import SchedulerMixin +from .scheduling_utils import BaseScheduler @dataclass @@ -46,7 +46,7 @@ class KarrasVeOutput(BaseOutput): pred_original_sample: Optional[torch.FloatTensor] = None -class KarrasVeScheduler(SchedulerMixin, ConfigMixin): +class KarrasVeScheduler(BaseScheduler, ConfigMixin): """ Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and the VE column of Table 1 from [1] for reference. @@ -101,6 +101,12 @@ def __init__( self.timesteps: np.ndarray = None self.schedule: torch.FloatTensor = None # sigma(t_i) + def get_noise_condition(self, step: int): + """ + Returns the input noise condition for the model. + """ + return self.schedule[step] + def set_timesteps(self, num_inference_steps: int): """ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 6d8db7682db5..a09fed168f2e 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -23,7 +23,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .scheduling_utils import SchedulerMixin +from .scheduling_utils import BaseScheduler @dataclass @@ -44,7 +44,7 @@ class LMSDiscreteSchedulerOutput(BaseOutput): pred_original_sample: Optional[torch.FloatTensor] = None -class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): +class LMSDiscreteScheduler(BaseScheduler, ConfigMixin): """ Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by Katherine Crowson: @@ -99,13 +99,36 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) # setable values self.num_inference_steps = None - self.timesteps = np.arange(0, num_train_timesteps)[::-1] # to be consistent has to be smaller than sigmas by 1 + self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + schedule = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float) + self.schedule = torch.from_numpy(schedule) self.derivatives = [] + def scale_initial_noise(self, noise: torch.FloatTensor): + """ + Scales the initial noise to the correct range for the scheduler. + """ + return noise * self.sigmas[0] + + def scale_model_input(self, sample: torch.FloatTensor, step: int): + """ + Scales the model input (`sample`) to the correct range for the scheduler. + """ + sigma = self.sigmas[self.num_inference_steps - step - 1] + return sample / ((sigma**2 + 1) ** 0.5) + + def get_noise_condition(self, step: int): + """ + Returns the input noise condition for a model. + """ + return self.schedule[step] + def get_lms_coefficient(self, order, t, current_order): """ Compute a linear multistep coefficient. @@ -137,17 +160,15 @@ def set_timesteps(self, num_inference_steps: int): the number of diffusion steps used when generating samples with a pre-trained model. """ self.num_inference_steps = num_inference_steps - timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) + self.timesteps = np.arange(0, num_inference_steps)[::-1].copy() - low_idx = np.floor(timesteps).astype(int) - high_idx = np.ceil(timesteps).astype(int) - frac = np.mod(timesteps, 1.0) + schedule = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] + sigmas = np.interp(schedule[::-1], np.arange(0, len(sigmas)), sigmas) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) + self.schedule = torch.from_numpy(schedule) - self.timesteps = timesteps.astype(int) self.derivatives = [] def step( @@ -176,6 +197,8 @@ def step( When returning a tuple, the first element is the sample tensor. """ + # FIXME: accounting for the descending sigmas + timestep = int(len(self.timesteps) - timestep - 1) sigma = self.sigmas[timestep] # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise @@ -209,12 +232,14 @@ def add_noise( ) -> torch.FloatTensor: sigmas = self.sigmas.to(original_samples.device) timesteps = timesteps.to(original_samples.device) + # FIXME: accounting for the descending sigmas + timesteps = self.num_inference_steps - timesteps - 1 - sigma = sigmas[timesteps].flatten() - while len(sigma.shape) < len(original_samples.shape): - sigma = sigma.unsqueeze(-1) + sigmas = sigmas[timesteps].flatten() + while len(sigmas.shape) < len(original_samples.shape): + sigmas = sigmas.unsqueeze(-1) - noisy_samples = original_samples + noise * sigma + noisy_samples = original_samples + noise * sigmas return noisy_samples def __len__(self): diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index d9e430c4a656..6c4e307f5f31 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -22,7 +22,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import BaseScheduler, SchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -54,7 +54,7 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) -class PNDMScheduler(SchedulerMixin, ConfigMixin): +class PNDMScheduler(BaseScheduler, ConfigMixin): """ Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, namely Runge-Kutta method and a linear multi-step method. @@ -145,7 +145,13 @@ def __init__( self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy() self.prk_timesteps = None self.plms_timesteps = None - self.timesteps = None + self.schedule = None + + def get_noise_condition(self, step: int): + """ + Returns the input noise condition for a model. + """ + return self.schedule[step] def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor: """ @@ -190,7 +196,9 @@ def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor ::-1 ].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy - self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64) + self.schedule = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64) + self.schedule = self.schedule[::-1].copy() # FIXME: create the schedule in ascending order to avoid this + self.timesteps = np.arange(0, len(self.schedule))[::-1].copy() self.ets = [] self.counter = 0 @@ -221,6 +229,7 @@ def step( returning a tuple, the first element is the sample tensor. """ + timestep = self.schedule[timestep] if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps: return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict) else: @@ -395,6 +404,7 @@ def add_noise( if timesteps.device != original_samples.device: timesteps = timesteps.to(original_samples.device) + timesteps = self.schedule[timesteps] sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index a549654c3b6f..6f2506195056 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -23,7 +23,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import BaseScheduler, SchedulerOutput @dataclass @@ -43,7 +43,7 @@ class SdeVeOutput(BaseOutput): prev_sample_mean: torch.FloatTensor -class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): +class ScoreSdeVeScheduler(BaseScheduler, ConfigMixin): """ The variance exploding stochastic differential equation (SDE) scheduler. @@ -90,6 +90,12 @@ def __init__( self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps) + def get_noise_condition(self, step: int): + """ + Returns the input noise condition for a model. + """ + return self.timesteps[step] + def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None): """ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index daea743873f1..7cdeac19c50c 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -22,10 +22,10 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin +from .scheduling_utils import BaseScheduler -class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): +class ScoreSdeVpScheduler(BaseScheduler, ConfigMixin): """ The variance preserving stochastic differential equation (SDE) scheduler. @@ -52,6 +52,12 @@ def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling self.discrete_sigmas = None self.timesteps = None + def get_noise_condition(self, step: int): + """ + Returns the input noise condition for a model. + """ + return self.discrete_sigmas[step] + def set_timesteps(self, num_inference_steps): self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 1cc1d94414a6..30ad19664e05 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc import warnings from dataclasses import dataclass @@ -36,6 +37,38 @@ class SchedulerOutput(BaseOutput): prev_sample: torch.FloatTensor +class BaseScheduler(abc.ABC): + config_name = SCHEDULER_CONFIG_NAME + + def scale_initial_noise(self, noise: torch.FloatTensor): + """ + Scales the initial noise to the correct range for the scheduler. + """ + return noise + + def scale_model_input(self, sample: torch.FloatTensor, step: int): + """ + Scales the model input (`sample`) to the correct range for the scheduler. + """ + return sample + + @abc.abstractmethod + def get_noise_condition(self, step: int): + """ + Returns the input noise condition for the model (e.g. `timestep` or `sigma`). + """ + raise NotImplementedError("Scheduler must implement the `get_noise_condition` function.") + + def set_format(self, tensor_format="pt"): + warnings.warn( + "The method `set_format` is deprecated and will be removed in version `0.5.0`." + "If you're running your code in PyTorch, you can safely remove this function as the schedulers" + "are always in Pytorch", + DeprecationWarning, + ) + return self + + class SchedulerMixin: """ Mixin containing common functions for the schedulers. diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index cf3e607ea9d2..483d8bb6803f 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -201,7 +201,7 @@ def recursive_check(tuple_object, dict_object): ) kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) + num_inference_steps = kwargs.pop("num_inference_steps", 50) for scheduler_class in self.scheduler_classes: scheduler_config = self.get_scheduler_config() @@ -336,9 +336,10 @@ def full_loop(self, **config): scheduler.set_timesteps(num_inference_steps) - for t in scheduler.timesteps: + for step in scheduler.timesteps: + t = scheduler.get_noise_condition(step) residual = model(sample, t) - sample = scheduler.step(residual, t, sample, eta).prev_sample + sample = scheduler.step(residual, step, sample, eta).prev_sample return sample @@ -354,7 +355,7 @@ def test_steps_offset(self): scheduler_config = self.get_scheduler_config(steps_offset=1) scheduler = scheduler_class(**scheduler_config) scheduler.set_timesteps(5) - assert np.equal(scheduler.timesteps, np.array([801, 601, 401, 201, 1])).all() + assert np.equal(scheduler.schedule[scheduler.timesteps], np.array([801, 601, 401, 201, 1])).all() def test_betas(self): for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): @@ -569,7 +570,7 @@ def test_steps_offset(self): scheduler = scheduler_class(**scheduler_config) scheduler.set_timesteps(10) assert np.equal( - scheduler.timesteps, + scheduler.schedule[scheduler.timesteps], np.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]), ).all() @@ -863,18 +864,18 @@ def test_full_loop_no_noise(self): scheduler.set_timesteps(self.num_inference_steps) model = self.dummy_model() - sample = self.dummy_sample_deter * scheduler.sigmas[0] - - for i, t in enumerate(scheduler.timesteps): - sample = sample / ((scheduler.sigmas[i] ** 2 + 1) ** 0.5) + sample = scheduler.scale_initial_noise(self.dummy_sample_deter) + for step in scheduler.timesteps: + sample = scheduler.scale_model_input(sample, step) + t = scheduler.get_noise_condition(step) model_output = model(sample, t) - output = scheduler.step(model_output, i, sample) + output = scheduler.step(model_output, step, sample) sample = output.prev_sample result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_sum.item() - 1006.370) < 1e-2 + assert abs(result_sum.item() - 1006.388) < 1e-2 assert abs(result_mean.item() - 1.31) < 1e-3