Merged
26 changes: 18 additions & 8 deletions examples/unconditional_image_generation/train_unconditional.py
@@ -1,4 +1,5 @@
import argparse
import inspect
import math
import os
from pathlib import Path
@@ -190,10 +191,10 @@ def parse_args():
)

parser.add_argument(
"--predict_mode",
Contributor Author: Let's try to align naming all over the codebase.
type=str,
default="eps",
help="What the model should predict. 'eps' to predict error, 'x0' to directly predict reconstruction",
"--predict_epsilon",
action="store_true",
default=True,
help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
)

parser.add_argument("--ddpm_num_steps", type=int, default=1000)
@@ -252,7 +253,17 @@ def main(args):
"UpBlock2D",
),
)
noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
accepts_predict_epsilon = "predict_epsilon" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())

if accepts_predict_epsilon:
noise_scheduler = DDPMScheduler(
num_train_timesteps=args.ddpm_num_steps,
beta_schedule=args.ddpm_beta_schedule,
predict_epsilon=args.predict_epsilon,
)
else:
noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)

optimizer = torch.optim.AdamW(
model.parameters(),
lr=args.learning_rate,
@@ -351,9 +362,9 @@ def transforms(examples):
# Predict the noise residual
model_output = model(noisy_images, timesteps).sample

if args.predict_mode == "eps":
if args.predict_epsilon:
loss = F.mse_loss(model_output, noise) # this could have different weights!
elif args.predict_mode == "x0":
else:
alpha_t = _extract_into_tensor(
noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
)
@@ -401,7 +412,6 @@ def transforms(examples):
generator=generator,
batch_size=args.eval_batch_size,
output_type="numpy",
predict_epsilon=args.predict_mode == "eps",
).images

# denormalize the images and save to tensorboard
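The guarded instantiation in this file lets the training script run against diffusers versions whose DDPMScheduler does not yet accept `predict_epsilon`. A minimal sketch of the same feature-detection pattern, using a hypothetical SchedulerStub in place of the real class:

import inspect


class SchedulerStub:
    # Stand-in for DDPMScheduler; only newer versions accept `predict_epsilon`.
    def __init__(self, num_train_timesteps=1000, predict_epsilon=True):
        self.num_train_timesteps = num_train_timesteps
        self.predict_epsilon = predict_epsilon


# Inspect __init__ before passing an argument older releases do not know about.
accepts_predict_epsilon = "predict_epsilon" in inspect.signature(SchedulerStub.__init__).parameters

init_kwargs = {"num_train_timesteps": 1000}
if accepts_predict_epsilon:
    init_kwargs["predict_epsilon"] = False  # only forwarded when supported

scheduler = SchedulerStub(**init_kwargs)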
5 changes: 5 additions & 0 deletions src/diffusers/configuration_utils.py
@@ -334,6 +334,11 @@ def extract_init_dict(cls, config_dict, **kwargs):
# 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
init_dict = {}
for key in expected_keys:
# if a config param is passed as a kwarg and is also present in the config dict,
# the kwarg value should overwrite the existing config dict entry
if key in kwargs and key in config_dict:
config_dict[key] = kwargs.pop(key)

if key in kwargs:
# overwrite key
init_dict[key] = kwargs.pop(key)
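The two added lines make keyword arguments passed at load time win over values already present in the loaded config. A simplified, standalone sketch of the merge behaviour (the real extract_init_dict also tracks unused and unexpected keys):

def extract_init_dict_sketch(expected_keys, config_dict, **kwargs):
    # Simplified model of the loop above; not the real implementation.
    init_dict = {}
    for key in expected_keys:
        # A kwarg that also exists in the loaded config overwrites the stored value.
        if key in kwargs and key in config_dict:
            config_dict[key] = kwargs.pop(key)
        if key in kwargs:
            init_dict[key] = kwargs.pop(key)
        elif key in config_dict:
            init_dict[key] = config_dict.pop(key)
    return init_dict


# The user-supplied value beats the one stored in the config file.
print(extract_init_dict_sketch(["predict_epsilon"], {"predict_epsilon": True}, predict_epsilon=False))
# {'predict_epsilon': False}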
13 changes: 12 additions & 1 deletion src/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -18,7 +18,9 @@

import torch

from ...configuration_utils import FrozenDict
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...utils import deprecate


class DDPMPipeline(DiffusionPipeline):
@@ -45,7 +47,6 @@ def __call__(
num_inference_steps: int = 1000,
output_type: Optional[str] = "pil",
return_dict: bool = True,
predict_epsilon: bool = True,
**kwargs,
) -> Union[ImagePipelineOutput, Tuple]:
r"""
@@ -69,6 +70,16 @@
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""
message = (
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)

if predict_epsilon is not None:
new_config = dict(self.scheduler.config)
new_config["predict_epsilon"] = predict_epsilon
Contributor Author: @pcuenca note that now you should change this into:

new_config["prediction_type"] = "predict_epsilon"

Here and everywhere else.

self.scheduler._internal_dict = FrozenDict(new_config)

# Sample gaussian noise to begin loop
image = torch.randn(
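The pipeline keeps accepting the old per-call predict_epsilon only through **kwargs: deprecate pops it, warns, and the value is folded back into the scheduler's frozen config. A hedged sketch of that flow, with deprecate_stub standing in for diffusers.utils.deprecate (whose exact signature may differ):

import warnings


def deprecate_stub(name, removal_version, message, take_from):
    # Stand-in for diffusers.utils.deprecate: pop the legacy kwarg and warn.
    value = take_from.pop(name, None)
    if value is not None:
        warnings.warn(
            f"`{name}` is deprecated and will be removed in {removal_version}. {message}",
            FutureWarning,
        )
    return value


call_kwargs = {"predict_epsilon": False}
predict_epsilon = deprecate_stub(
    "predict_epsilon", "0.10.0", "Instantiate the scheduler with it instead.", call_kwargs
)
if predict_epsilon is not None:
    # In the pipeline this value replaces scheduler.config via a new FrozenDict.
    print("rebuild scheduler config with predict_epsilon =", predict_epsilon)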
23 changes: 17 additions & 6 deletions src/diffusers/schedulers/scheduling_ddpm.py
@@ -21,8 +21,8 @@
import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ..utils import BaseOutput, deprecate
from .scheduling_utils import SchedulerMixin


@@ -99,6 +99,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
predict_epsilon (`bool`):
optional flag indicating whether the model predicts the noise (epsilon) or the sample directly.

"""

@@ -121,6 +123,7 @@ def __init__(
trained_betas: Optional[np.ndarray] = None,
variance_type: str = "fixed_small",
clip_sample: bool = True,
predict_epsilon: bool = True,
):
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
@@ -221,9 +224,9 @@
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
predict_epsilon=True,
Contributor Author: predict_epsilon is an inherent config parameter, just like beta_start, that should not change during inference.
generator=None,
return_dict: bool = True,
**kwargs,
) -> Union[DDPMSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
@@ -234,8 +237,6 @@
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
predict_epsilon (`bool`):
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class

@@ -245,6 +246,16 @@
returning a tuple, the first element is the sample tensor.

"""
message = (
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
new_config = dict(self.config)
new_config["predict_epsilon"] = predict_epsilon
self._internal_dict = FrozenDict(new_config)

t = timestep

if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
@@ -260,7 +271,7 @@

# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if predict_epsilon:
if self.config.predict_epsilon:
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
else:
pred_original_sample = model_output
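Whether the model output is treated as noise or as the clean image only changes step 2 above: with predict_epsilon=True the scheduler inverts the forward process, x_0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t), formula (15) of the DDPM paper; otherwise the output already is x_0. A small numeric check of that inversion:

import torch


def pred_original_sample(sample, model_output, alpha_prod_t, predict_epsilon=True):
    # Mirrors step 2 of DDPMScheduler.step: invert the forward noising when the
    # model predicts epsilon, otherwise take the output as x_0 directly.
    beta_prod_t = 1 - alpha_prod_t
    if predict_epsilon:
        return (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
    return model_output


x0 = torch.randn(1, 3, 8, 8)
eps = torch.randn_like(x0)
alpha_prod_t = torch.tensor(0.9)

# Forward noising: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
x_t = alpha_prod_t ** 0.5 * x0 + (1 - alpha_prod_t) ** 0.5 * eps

# Given the true noise, the reconstruction recovers the clean sample.
assert torch.allclose(pred_original_sample(x_t, eps, alpha_prod_t), x0, atol=1e-5)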
22 changes: 17 additions & 5 deletions src/diffusers/schedulers/scheduling_ddpm_flax.py
@@ -22,7 +22,8 @@
import jax.numpy as jnp
from jax import random

from ..configuration_utils import ConfigMixin, register_to_config
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ..utils import deprecate
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


@@ -97,7 +98,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
predict_epsilon (`bool`):
optional flag indicating whether the model predicts the noise (epsilon) or the sample directly.

"""

@@ -115,6 +117,7 @@ def __init__(
trained_betas: Optional[jnp.ndarray] = None,
variance_type: str = "fixed_small",
clip_sample: bool = True,
predict_epsilon: bool = True,
):
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
@@ -196,6 +199,7 @@ def step(
key: random.KeyArray,
predict_epsilon: bool = True,
return_dict: bool = True,
**kwargs,
) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
@@ -208,15 +212,23 @@
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
key (`random.KeyArray`): a PRNG key.
predict_epsilon (`bool`):
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class

Returns:
[`FlaxDDPMSchedulerOutput`] or `tuple`: [`FlaxDDPMSchedulerOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is the sample tensor.

"""
message = (
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
new_config = dict(self.config)
new_config["predict_epsilon"] = predict_epsilon
self._internal_dict = FrozenDict(new_config)

t = timestep

if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]:
@@ -232,7 +244,7 @@

# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if predict_epsilon:
if self.config.predict_epsilon:
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
else:
pred_original_sample = model_output
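Both the torch and Flax schedulers resolve a conflicting legacy value the same way: the frozen config is never mutated in place, a fresh FrozenDict replaces _internal_dict wholesale. A sketch of that immutable-update idiom with a minimal FrozenDict stand-in:

class FrozenDictSketch(dict):
    # Minimal stand-in for diffusers' FrozenDict: reads allowed, writes rejected.
    def __setitem__(self, key, value):
        raise TypeError("config is frozen; build a new FrozenDictSketch instead")


config = FrozenDictSketch(predict_epsilon=True, beta_start=0.0001)

# Updating means rebuilding: copy to a plain dict, change, re-freeze.
new_config = dict(config)
new_config["predict_epsilon"] = False
config = FrozenDictSketch(new_config)
assert config["predict_epsilon"] is False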
3 changes: 1 addition & 2 deletions tests/fixtures/custom_pipeline/pipeline.py
@@ -42,7 +42,6 @@ def __call__(
self,
batch_size: int = 1,
generator: Optional[torch.Generator] = None,
eta: float = 0.0,
num_inference_steps: int = 50,
output_type: Optional[str] = "pil",
return_dict: bool = True,
@@ -89,7 +88,7 @@ def __call__(
# 2. predict previous mean of image x_t-1 and add variance depending on eta
# eta corresponds to η in paper and should be between [0, 1]
# do x_t -> x_t-1
image = self.scheduler.step(model_output, t, image, eta).prev_sample
image = self.scheduler.step(model_output, t, image).prev_sample

image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
71 changes: 69 additions & 2 deletions tests/pipelines/ddpm/test_ddpm.py
@@ -19,6 +19,7 @@
import torch

from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.utils import deprecate
from diffusers.utils.testing_utils import require_torch, slow, torch_device

from ...test_pipelines_common import PipelineTesterMixin
@@ -28,8 +29,74 @@


class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
# FIXME: add fast tests
pass
@property
def dummy_uncond_unet(self):
torch.manual_seed(0)
model = UNet2DModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=3,
out_channels=3,
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
return model

def test_inference(self):
unet = self.dummy_uncond_unet
scheduler = DDPMScheduler()

ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None)

# Warmup pass when using mps (see #372)
if torch_device == "mps":
_ = ddpm(num_inference_steps=1)

generator = torch.manual_seed(0)
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images

generator = torch.manual_seed(0)
image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]

image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

assert image.shape == (1, 32, 32, 3)
expected_slice = np.array(
[5.589e-01, 7.089e-01, 2.632e-01, 6.841e-01, 1.000e-04, 9.999e-01, 1.973e-01, 1.000e-04, 8.010e-02]
)
tolerance = 1e-2 if torch_device != "mps" else 3e-2
assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance

def test_inference_predict_epsilon(self):
deprecate("remove this test", "0.10.0", "remove")
unet = self.dummy_uncond_unet
scheduler = DDPMScheduler(predict_epsilon=False)

ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None)

# Warmup pass when using mps (see #372)
if torch_device == "mps":
_ = ddpm(num_inference_steps=1)

generator = torch.manual_seed(0)
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images

generator = torch.manual_seed(0)
image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]

image_slice = image[0, -3:, -3:, -1]
image_eps_slice = image_eps[0, -3:, -3:, -1]

assert image.shape == (1, 32, 32, 3)
tolerance = 1e-2 if torch_device != "mps" else 3e-2
assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance


@slow
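The deprecate("remove this test", "0.10.0", "remove") line at the top of test_inference_predict_epsilon acts as a reminder with a deadline: assuming diffusers' deprecate raises once the installed version reaches the given one, the legacy test cannot outlive the kwarg it covers. A hedged sketch of that expiry mechanism:

from packaging import version


def removal_reminder(current_version, removal_version="0.10.0"):
    # Sketch of the assumed expiry behaviour: once the library version reaches
    # the removal deadline, raise so the stale test (or code path) gets deleted.
    if version.parse(current_version) >= version.parse(removal_version):
        raise ValueError(f"deprecation expired in {removal_version}: remove this test")


removal_reminder("0.9.0")   # before the deadline: no-op
# removal_reminder("0.10.0")  # at the deadline: raises ValueError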
24 changes: 24 additions & 0 deletions tests/test_config.py
100755 → 100644
@@ -21,6 +21,7 @@
import diffusers
from diffusers import (
DDIMScheduler,
DDPMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
@@ -291,6 +292,29 @@ def test_load_pndm(self):
# no warning should be thrown
assert cap_logger.out == ""

def test_overwrite_config_on_load(self):
logger = logging.get_logger("diffusers.configuration_utils")

with CaptureLogger(logger) as cap_logger:
ddpm = DDPMScheduler.from_config(
"hf-internal-testing/tiny-stable-diffusion-torch",
subfolder="scheduler",
predict_epsilon=False,
beta_end=8,
)

with CaptureLogger(logger) as cap_logger_2:
ddpm_2 = DDPMScheduler.from_config("google/ddpm-celebahq-256", beta_start=88)

assert ddpm.__class__ == DDPMScheduler
assert ddpm.config.predict_epsilon is False
assert ddpm.config.beta_end == 8
assert ddpm_2.config.beta_start == 88

# no warning should be thrown
assert cap_logger.out == ""
assert cap_logger_2.out == ""

def test_load_dpmsolver(self):
logger = logging.get_logger("diffusers.configuration_utils")

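In day-to-day use, the behaviour exercised by test_overwrite_config_on_load means a scheduler loaded from a hub config can be re-parameterized in the same call. A short usage sketch, repeating the model id from the test above (network access assumed):

from diffusers import DDPMScheduler

# kwargs passed to from_config override the values stored in the hub config.
scheduler = DDPMScheduler.from_config("google/ddpm-celebahq-256", predict_epsilon=False, beta_end=8)
assert scheduler.config.predict_epsilon is False
assert scheduler.config.beta_end == 8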
1 change: 1 addition & 0 deletions tests/test_pipelines.py
@@ -107,6 +107,7 @@ def test_run_custom_pipeline(self):
images, output_str = pipeline(num_inference_steps=2, output_type="np")

assert images[0].shape == (1, 32, 32, 3)

# compare output to https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L102
assert output_str == "This is a test"
