From e37c3a8fb49d84249d5362d90b2af3adc45e17e3 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 6 Nov 2022 13:51:14 +0100 Subject: [PATCH 1/6] [Scheduler] Move predict epsilon to init --- .../train_unconditional.py | 26 ++++--- src/diffusers/configuration_utils.py | 5 ++ src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 13 +++- src/diffusers/schedulers/scheduling_ddpm.py | 23 ++++-- .../schedulers/scheduling_ddpm_flax.py | 22 ++++-- tests/pipelines/ddpm/test_ddpm.py | 71 ++++++++++++++++++- tests/test_config.py | 35 ++++++++- tests/test_scheduler.py | 30 +++++++- 8 files changed, 201 insertions(+), 24 deletions(-) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 3f9ffb11ef45..0eadecbd3095 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -1,4 +1,5 @@ import argparse +import inspect import math import os from pathlib import Path @@ -190,10 +191,10 @@ def parse_args(): ) parser.add_argument( - "--predict_mode", - type=str, - default="eps", - help="What the model should predict. 'eps' to predict error, 'x0' to directly predict reconstruction", + "--predict_epsilon", + action="store_true", + default=True, + help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.", ) parser.add_argument("--ddpm_num_steps", type=int, default=1000) @@ -252,7 +253,17 @@ def main(args): "UpBlock2D", ), ) - noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule) + accepts_predict_epsilon = "predict_epsilon" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys()) + + if accepts_predict_epsilon: + noise_scheduler = DDPMScheduler( + num_train_timesteps=args.ddpm_num_steps, + beta_schedule=args.ddpm_beta_schedule, + predict_epsilon=args.predict_epsilon, + ) + else: + noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule) + optimizer = torch.optim.AdamW( model.parameters(), lr=args.learning_rate, @@ -351,9 +362,9 @@ def transforms(examples): # Predict the noise residual model_output = model(noisy_images, timesteps).sample - if args.predict_mode == "eps": + if args.predict_epsilon: loss = F.mse_loss(model_output, noise) # this could have different weights! - elif args.predict_mode == "x0": + else: alpha_t = _extract_into_tensor( noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1) ) @@ -401,7 +412,6 @@ def transforms(examples): generator=generator, batch_size=args.eval_batch_size, output_type="numpy", - predict_epsilon=args.predict_mode == "eps", ).images # denormalize the images and save to tensorboard diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index d830857a302c..eb959edb09fa 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -334,6 +334,11 @@ def extract_init_dict(cls, config_dict, **kwargs): # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments init_dict = {} for key in expected_keys: + # if config param is passed to kwarg and is present in config dict + # it should overwrite existing config dict key + if key in kwargs and key in config_dict: + config_dict[key] = kwargs.pop(key) + if key in kwargs: # overwrite key init_dict[key] = kwargs.pop(key) diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index 811614ecbdde..37d12f2f5df6 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -18,7 +18,9 @@ import torch +from ...configuration_utils import FrozenDict from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...utils import deprecate class DDPMPipeline(DiffusionPipeline): @@ -45,7 +47,6 @@ def __call__( num_inference_steps: int = 1000, output_type: Optional[str] = "pil", return_dict: bool = True, - predict_epsilon: bool = True, **kwargs, ) -> Union[ImagePipelineOutput, Tuple]: r""" @@ -69,6 +70,16 @@ def __call__( `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ + message = ( + "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler =" + " DDPMScheduler.from_config(, predict_epsilon=True)`." + ) + predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) + + if predict_epsilon is not None: + new_config = dict(self.scheduler.config) + new_config["predict_epsilon"] = predict_epsilon + self.scheduler._internal_dict = FrozenDict(new_config) # Sample gaussian noise to begin loop image = torch.randn( diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 114a86b4320e..2ae9add9f7d9 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -21,8 +21,8 @@ import numpy as np import torch -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput +from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config +from ..utils import BaseOutput, deprecate from .scheduling_utils import SchedulerMixin @@ -99,6 +99,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. clip_sample (`bool`, default `True`): option to clip predicted sample between -1 and 1 for numerical stability. + predict_epsilon (`bool`): + optional flag to use when model predicts the samples directly instead of the noise, epsilon. """ @@ -120,6 +122,7 @@ def __init__( trained_betas: Optional[np.ndarray] = None, variance_type: str = "fixed_small", clip_sample: bool = True, + predict_epsilon: bool = True, ): if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) @@ -220,9 +223,9 @@ def step( model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, - predict_epsilon=True, generator=None, return_dict: bool = True, + **kwargs, ) -> Union[DDPMSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion @@ -233,8 +236,6 @@ def step( timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. - predict_epsilon (`bool`): - optional flag to use when model predicts the samples directly instead of the noise, epsilon. generator: random number generator. return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class @@ -244,6 +245,16 @@ def step( returning a tuple, the first element is the sample tensor. """ + message = ( + "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler =" + " DDPMScheduler.from_config(, predict_epsilon=True)`." + ) + predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) + if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon: + new_config = dict(self.config) + new_config["predict_epsilon"] = predict_epsilon + self._internal_dict = FrozenDict(new_config) + t = timestep if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: @@ -259,7 +270,7 @@ def step( # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if predict_epsilon: + if self.config.predict_epsilon: pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) else: pred_original_sample = model_output diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py index 7220a0145471..720f4ca19462 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_flax.py +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -22,7 +22,8 @@ import jax.numpy as jnp from jax import random -from ..configuration_utils import ConfigMixin, register_to_config +from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config +from ..utils import deprecate from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left @@ -97,7 +98,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. clip_sample (`bool`, default `True`): option to clip predicted sample between -1 and 1 for numerical stability. - tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. + predict_epsilon (`bool`): + optional flag to use when model predicts the samples directly instead of the noise, epsilon. """ @@ -115,6 +117,7 @@ def __init__( trained_betas: Optional[jnp.ndarray] = None, variance_type: str = "fixed_small", clip_sample: bool = True, + predict_epsilon: bool = True, ): if trained_betas is not None: self.betas = jnp.asarray(trained_betas) @@ -196,6 +199,7 @@ def step( key: random.KeyArray, predict_epsilon: bool = True, return_dict: bool = True, + **kwargs, ) -> Union[FlaxDDPMSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion @@ -208,8 +212,6 @@ def step( sample (`jnp.ndarray`): current instance of sample being created by diffusion process. key (`random.KeyArray`): a PRNG key. - predict_epsilon (`bool`): - optional flag to use when model predicts the samples directly instead of the noise, epsilon. return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class Returns: @@ -217,6 +219,16 @@ def step( `tuple`. When returning a tuple, the first element is the sample tensor. """ + message = ( + "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler =" + " DDPMScheduler.from_config(, predict_epsilon=True)`." + ) + predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) + if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon: + new_config = dict(self.config) + new_config["predict_epsilon"] = predict_epsilon + self._internal_dict = FrozenDict(new_config) + t = timestep if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]: @@ -232,7 +244,7 @@ def step( # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if predict_epsilon: + if self.config.predict_epsilon: pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) else: pred_original_sample = model_output diff --git a/tests/pipelines/ddpm/test_ddpm.py b/tests/pipelines/ddpm/test_ddpm.py index c58e2db38f21..a09f77d12467 100644 --- a/tests/pipelines/ddpm/test_ddpm.py +++ b/tests/pipelines/ddpm/test_ddpm.py @@ -19,6 +19,7 @@ import torch from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel +from diffusers.utils import deprecate from diffusers.utils.testing_utils import require_torch, slow, torch_device from ...test_pipelines_common import PipelineTesterMixin @@ -28,8 +29,74 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): - # FIXME: add fast tests - pass + @property + def dummy_uncond_unet(self): + torch.manual_seed(0) + model = UNet2DModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=3, + out_channels=3, + down_block_types=("DownBlock2D", "AttnDownBlock2D"), + up_block_types=("AttnUpBlock2D", "UpBlock2D"), + ) + return model + + def test_inference(self): + unet = self.dummy_uncond_unet + scheduler = DDPMScheduler() + + ddpm = DDPMPipeline(unet=unet, scheduler=scheduler) + ddpm.to(torch_device) + ddpm.set_progress_bar_config(disable=None) + + # Warmup pass when using mps (see #372) + if torch_device == "mps": + _ = ddpm(num_inference_steps=1) + + generator = torch.manual_seed(0) + image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images + + generator = torch.manual_seed(0) + image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array( + [5.589e-01, 7.089e-01, 2.632e-01, 6.841e-01, 1.000e-04, 9.999e-01, 1.973e-01, 1.000e-04, 8.010e-02] + ) + tolerance = 1e-2 if torch_device != "mps" else 3e-2 + assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance + + def test_inference_predict_epsilon(self): + deprecate("remove this test", "0.10.0", "remove") + unet = self.dummy_uncond_unet + scheduler = DDPMScheduler(predict_epsilon=False) + + ddpm = DDPMPipeline(unet=unet, scheduler=scheduler) + ddpm.to(torch_device) + ddpm.set_progress_bar_config(disable=None) + + # Warmup pass when using mps (see #372) + if torch_device == "mps": + _ = ddpm(num_inference_steps=1) + + generator = torch.manual_seed(0) + image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images + + generator = torch.manual_seed(0) + image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0] + + image_slice = image[0, -3:, -3:, -1] + image_eps_slice = image_eps[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + tolerance = 1e-2 if torch_device != "mps" else 3e-2 + assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance @slow diff --git a/tests/test_config.py b/tests/test_config.py index 7a9f270af364..a0f4bbfd5acf 100755 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -19,7 +19,14 @@ import unittest import diffusers -from diffusers import DDIMScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, PNDMScheduler, logging +from diffusers import ( + DDIMScheduler, + DDPMScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + PNDMScheduler, + logging, +) from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.utils.testing_utils import CaptureLogger @@ -283,3 +290,29 @@ def test_load_pndm(self): assert pndm.__class__ == PNDMScheduler # no warning should be thrown assert cap_logger.out == "" + + def test_overwrite_config_on_load(self): + logger = logging.get_logger("diffusers.configuration_utils") + + with CaptureLogger(logger) as cap_logger: + ddpm = DDPMScheduler.from_config( + "hf-internal-testing/tiny-stable-diffusion-torch", + subfolder="scheduler", + predict_epsilon=False, + beta_end=8, + ) + + with CaptureLogger(logger) as cap_logger_2: + ddpm_2 = DDPMScheduler.from_config("google/ddpm-celebahq-256", beta_start=88) + + import ipdb + + ipdb.set_trace() + assert ddpm.__class__ == DDPMScheduler + assert ddpm.config.predict_epsilon is False + assert ddpm.config.beta_end == 8 + assert ddpm_2.config.beta_start == 88 + + # no warning should be thrown + assert cap_logger.out == "" + assert cap_logger_2.out == "" diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 29186aaac99b..ce1e84ad56e5 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -32,7 +32,7 @@ ScoreSdeVeScheduler, VQDiffusionScheduler, ) -from diffusers.utils import torch_device +from diffusers.utils import deprecate, torch_device torch.backends.cuda.matmul.allow_tf32 = False @@ -392,6 +392,34 @@ def test_clip_sample(self): for clip_sample in [True, False]: self.check_over_configs(clip_sample=clip_sample) + def test_predict_epsilon(self): + for predict_epsilon in [True, False]: + self.check_over_configs(predict_epsilon=predict_epsilon) + + def test_deprecated_epsilon(self): + deprecate("remove this test", "0.10.0", "remove") + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + + sample = self.dummy_sample_deter + residual = 0.1 * self.dummy_sample_deter + time_step = 4 + + scheduler = scheduler_class(**scheduler_config) + scheduler_eps = scheduler_class(predict_epsilon=False, **scheduler_config) + + kwargs = {} + if "generator" in set(inspect.signature(scheduler.step).parameters.keys()): + kwargs["generator"] = torch.Generator().manual_seed(0) + output = scheduler.step(residual, time_step, sample, predict_epsilon=False, **kwargs).prev_sample + + kwargs = {} + if "generator" in set(inspect.signature(scheduler.step).parameters.keys()): + kwargs["generator"] = torch.Generator().manual_seed(0) + output_eps = scheduler_eps.step(residual, time_step, sample, predict_epsilon=False, **kwargs).prev_sample + + assert (output - output_eps).abs().sum() < 1e-5 + def test_time_indices(self): for t in [0, 500, 999]: self.check_over_forward(time_step=t) From ffdbd340a25e047a2775304479fe8ee3a11a540e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 6 Nov 2022 13:54:55 +0100 Subject: [PATCH 2/6] up --- tests/test_config.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index a0f4bbfd5acf..9e976eecb3c7 100755 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -305,9 +305,6 @@ def test_overwrite_config_on_load(self): with CaptureLogger(logger) as cap_logger_2: ddpm_2 = DDPMScheduler.from_config("google/ddpm-celebahq-256", beta_start=88) - import ipdb - - ipdb.set_trace() assert ddpm.__class__ == DDPMScheduler assert ddpm.config.predict_epsilon is False assert ddpm.config.beta_end == 8 From 2dc4837753a1221cf25b9f67013182cca570fd35 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 6 Nov 2022 14:09:37 +0100 Subject: [PATCH 3/6] uP --- tests/fixtures/custom_pipeline/pipeline.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/fixtures/custom_pipeline/pipeline.py b/tests/fixtures/custom_pipeline/pipeline.py index 10a22edaa490..e7429d0a1945 100644 --- a/tests/fixtures/custom_pipeline/pipeline.py +++ b/tests/fixtures/custom_pipeline/pipeline.py @@ -42,7 +42,6 @@ def __call__( self, batch_size: int = 1, generator: Optional[torch.Generator] = None, - eta: float = 0.0, num_inference_steps: int = 50, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -89,7 +88,7 @@ def __call__( # 2. predict previous mean of image x_t-1 and add variance depending on eta # eta corresponds to η in paper and should be between [0, 1] # do x_t -> x_t-1 - image = self.scheduler.step(model_output, t, image, eta).prev_sample + image = self.scheduler.step(model_output, t, image).prev_sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() From c8acb3125344e5f2fee48e404da567a4b9f4e069 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 6 Nov 2022 14:20:41 +0100 Subject: [PATCH 4/6] uP --- tests/test_pipelines.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index b8316075fa93..fd5a9b92ebbc 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -107,6 +107,7 @@ def test_run_custom_pipeline(self): images, output_str = pipeline(num_inference_steps=2, output_type="np") assert images[0].shape == (1, 32, 32, 3) + # compare output to https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L102 assert output_str == "This is a test" From 98e5f3cb17cbb04d025050d56bcdc43bea0d4294 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 8 Nov 2022 17:18:38 +0100 Subject: [PATCH 5/6] Apply suggestions from code review Co-authored-by: Pedro Cuenca --- src/diffusers/schedulers/scheduling_ddpm.py | 2 +- src/diffusers/schedulers/scheduling_ddpm_flax.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 2ae9add9f7d9..1250bb5cb8c3 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -100,7 +100,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): clip_sample (`bool`, default `True`): option to clip predicted sample between -1 and 1 for numerical stability. predict_epsilon (`bool`): - optional flag to use when model predicts the samples directly instead of the noise, epsilon. + optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise. """ diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py index 720f4ca19462..f1b04a04176e 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_flax.py +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -99,7 +99,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): clip_sample (`bool`, default `True`): option to clip predicted sample between -1 and 1 for numerical stability. predict_epsilon (`bool`): - optional flag to use when model predicts the samples directly instead of the noise, epsilon. + optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise. """ From 319b042cb607b7b70b5cecfcd5b5bdfe2d3967a2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 8 Nov 2022 17:24:15 +0100 Subject: [PATCH 6/6] up --- tests/test_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_config.py b/tests/test_config.py index 674a6d87ece8..8ae8e1d9e173 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -325,4 +325,4 @@ def test_load_dpmsolver(self): assert dpm.__class__ == DPMSolverMultistepScheduler # no warning should be thrown - assert cap_logger.out == "" \ No newline at end of file + assert cap_logger.out == ""