Skip to content

Commit 249d9bc

Browse files
[Scheduler] Move predict epsilon to init (#1155)
* [Scheduler] Move predict epsilon to init * up * uP * uP * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * up Co-authored-by: Pedro Cuenca <[email protected]>
1 parent 5786b0e commit 249d9bc

File tree

10 files changed

+193
-25
lines changed

10 files changed

+193
-25
lines changed

examples/unconditional_image_generation/train_unconditional.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import argparse
2+
import inspect
23
import math
34
import os
45
from pathlib import Path
@@ -190,10 +191,10 @@ def parse_args():
190191
)
191192

192193
parser.add_argument(
193-
"--predict_mode",
194-
type=str,
195-
default="eps",
196-
help="What the model should predict. 'eps' to predict error, 'x0' to directly predict reconstruction",
194+
"--predict_epsilon",
195+
action="store_true",
196+
default=True,
197+
help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
197198
)
198199

199200
parser.add_argument("--ddpm_num_steps", type=int, default=1000)
@@ -252,7 +253,17 @@ def main(args):
252253
"UpBlock2D",
253254
),
254255
)
255-
noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
256+
accepts_predict_epsilon = "predict_epsilon" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
257+
258+
if accepts_predict_epsilon:
259+
noise_scheduler = DDPMScheduler(
260+
num_train_timesteps=args.ddpm_num_steps,
261+
beta_schedule=args.ddpm_beta_schedule,
262+
predict_epsilon=args.predict_epsilon,
263+
)
264+
else:
265+
noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
266+
256267
optimizer = torch.optim.AdamW(
257268
model.parameters(),
258269
lr=args.learning_rate,
@@ -351,9 +362,9 @@ def transforms(examples):
351362
# Predict the noise residual
352363
model_output = model(noisy_images, timesteps).sample
353364

354-
if args.predict_mode == "eps":
365+
if args.predict_epsilon:
355366
loss = F.mse_loss(model_output, noise) # this could have different weights!
356-
elif args.predict_mode == "x0":
367+
else:
357368
alpha_t = _extract_into_tensor(
358369
noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
359370
)
@@ -401,7 +412,6 @@ def transforms(examples):
401412
generator=generator,
402413
batch_size=args.eval_batch_size,
403414
output_type="numpy",
404-
predict_epsilon=args.predict_mode == "eps",
405415
).images
406416

407417
# denormalize the images and save to tensorboard

src/diffusers/configuration_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,11 @@ def extract_init_dict(cls, config_dict, **kwargs):
334334
# 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
335335
init_dict = {}
336336
for key in expected_keys:
337+
# if a config parameter is passed as a kwarg and is also present in the config dict,
338+
# the kwarg value should overwrite the existing config dict entry
339+
if key in kwargs and key in config_dict:
340+
config_dict[key] = kwargs.pop(key)
341+
337342
if key in kwargs:
338343
# overwrite key
339344
init_dict[key] = kwargs.pop(key)

src/diffusers/pipelines/ddpm/pipeline_ddpm.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818

1919
import torch
2020

21+
from ...configuration_utils import FrozenDict
2122
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
23+
from ...utils import deprecate
2224

2325

2426
class DDPMPipeline(DiffusionPipeline):
@@ -45,7 +47,6 @@ def __call__(
4547
num_inference_steps: int = 1000,
4648
output_type: Optional[str] = "pil",
4749
return_dict: bool = True,
48-
predict_epsilon: bool = True,
4950
**kwargs,
5051
) -> Union[ImagePipelineOutput, Tuple]:
5152
r"""
@@ -69,6 +70,16 @@ def __call__(
6970
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
7071
generated images.
7172
"""
73+
message = (
74+
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
75+
" DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
76+
)
77+
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
78+
79+
if predict_epsilon is not None:
80+
new_config = dict(self.scheduler.config)
81+
new_config["predict_epsilon"] = predict_epsilon
82+
self.scheduler._internal_dict = FrozenDict(new_config)
7283

7384
# Sample gaussian noise to begin loop
7485
image = torch.randn(

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
import numpy as np
2222
import torch
2323

24-
from ..configuration_utils import ConfigMixin, register_to_config
25-
from ..utils import BaseOutput
24+
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
25+
from ..utils import BaseOutput, deprecate
2626
from .scheduling_utils import SchedulerMixin
2727

2828

@@ -99,6 +99,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
9999
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
100100
clip_sample (`bool`, default `True`):
101101
option to clip predicted sample between -1 and 1 for numerical stability.
102+
predict_epsilon (`bool`):
103+
whether the model predicts the noise (epsilon); if `False`, the model output is used directly as the predicted original sample.
102104
103105
"""
104106

@@ -121,6 +123,7 @@ def __init__(
121123
trained_betas: Optional[np.ndarray] = None,
122124
variance_type: str = "fixed_small",
123125
clip_sample: bool = True,
126+
predict_epsilon: bool = True,
124127
):
125128
if trained_betas is not None:
126129
self.betas = torch.from_numpy(trained_betas)
@@ -221,9 +224,9 @@ def step(
221224
model_output: torch.FloatTensor,
222225
timestep: int,
223226
sample: torch.FloatTensor,
224-
predict_epsilon=True,
225227
generator=None,
226228
return_dict: bool = True,
229+
**kwargs,
227230
) -> Union[DDPMSchedulerOutput, Tuple]:
228231
"""
229232
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
@@ -234,8 +237,6 @@ def step(
234237
timestep (`int`): current discrete timestep in the diffusion chain.
235238
sample (`torch.FloatTensor`):
236239
current instance of sample being created by diffusion process.
237-
predict_epsilon (`bool`):
238-
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
239240
generator: random number generator.
240241
return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
241242
@@ -245,6 +246,16 @@ def step(
245246
returning a tuple, the first element is the sample tensor.
246247
247248
"""
249+
message = (
250+
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
251+
" DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
252+
)
253+
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
254+
if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
255+
new_config = dict(self.config)
256+
new_config["predict_epsilon"] = predict_epsilon
257+
self._internal_dict = FrozenDict(new_config)
258+
248259
t = timestep
249260

250261
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
@@ -260,7 +271,7 @@ def step(
260271

261272
# 2. compute predicted original sample from predicted noise also called
262273
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
263-
if predict_epsilon:
274+
if self.config.predict_epsilon:
264275
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
265276
else:
266277
pred_original_sample = model_output

src/diffusers/schedulers/scheduling_ddpm_flax.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
import jax.numpy as jnp
2323
from jax import random
2424

25-
from ..configuration_utils import ConfigMixin, register_to_config
25+
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
26+
from ..utils import deprecate
2627
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
2728

2829

@@ -97,7 +98,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
9798
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
9899
clip_sample (`bool`, default `True`):
99100
option to clip predicted sample between -1 and 1 for numerical stability.
100-
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
101+
predict_epsilon (`bool`):
102+
whether the model predicts the noise (epsilon); if `False`, the model output is used directly as the predicted original sample.
101103
102104
"""
103105

@@ -115,6 +117,7 @@ def __init__(
115117
trained_betas: Optional[jnp.ndarray] = None,
116118
variance_type: str = "fixed_small",
117119
clip_sample: bool = True,
120+
predict_epsilon: bool = True,
118121
):
119122
if trained_betas is not None:
120123
self.betas = jnp.asarray(trained_betas)
@@ -196,6 +199,7 @@ def step(
196199
key: random.KeyArray,
197200
predict_epsilon: bool = True,
198201
return_dict: bool = True,
202+
**kwargs,
199203
) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
200204
"""
201205
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
@@ -208,15 +212,23 @@ def step(
208212
sample (`jnp.ndarray`):
209213
current instance of sample being created by diffusion process.
210214
key (`random.KeyArray`): a PRNG key.
211-
predict_epsilon (`bool`):
212-
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
213215
return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class
214216
215217
Returns:
216218
[`FlaxDDPMSchedulerOutput`] or `tuple`: [`FlaxDDPMSchedulerOutput`] if `return_dict` is True, otherwise a
217219
`tuple`. When returning a tuple, the first element is the sample tensor.
218220
219221
"""
222+
message = (
223+
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
224+
" DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
225+
)
226+
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
227+
if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
228+
new_config = dict(self.config)
229+
new_config["predict_epsilon"] = predict_epsilon
230+
self._internal_dict = FrozenDict(new_config)
231+
220232
t = timestep
221233

222234
if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]:
@@ -232,7 +244,7 @@ def step(
232244

233245
# 2. compute predicted original sample from predicted noise also called
234246
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
235-
if predict_epsilon:
247+
if self.config.predict_epsilon:
236248
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
237249
else:
238250
pred_original_sample = model_output

tests/fixtures/custom_pipeline/pipeline.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ def __call__(
4242
self,
4343
batch_size: int = 1,
4444
generator: Optional[torch.Generator] = None,
45-
eta: float = 0.0,
4645
num_inference_steps: int = 50,
4746
output_type: Optional[str] = "pil",
4847
return_dict: bool = True,
@@ -89,7 +88,7 @@ def __call__(
8988
# 2. predict previous mean of image x_t-1 and add variance depending on eta
9089
# eta corresponds to η in paper and should be between [0, 1]
9190
# do x_t -> x_t-1
92-
image = self.scheduler.step(model_output, t, image, eta).prev_sample
91+
image = self.scheduler.step(model_output, t, image).prev_sample
9392

9493
image = (image / 2 + 0.5).clamp(0, 1)
9594
image = image.cpu().permute(0, 2, 3, 1).numpy()

tests/pipelines/ddpm/test_ddpm.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import torch
2020

2121
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
22+
from diffusers.utils import deprecate
2223
from diffusers.utils.testing_utils import require_torch, slow, torch_device
2324

2425
from ...test_pipelines_common import PipelineTesterMixin
@@ -28,8 +29,74 @@
2829

2930

3031
class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
31-
# FIXME: add fast tests
32-
pass
32+
@property
33+
def dummy_uncond_unet(self):
34+
torch.manual_seed(0)
35+
model = UNet2DModel(
36+
block_out_channels=(32, 64),
37+
layers_per_block=2,
38+
sample_size=32,
39+
in_channels=3,
40+
out_channels=3,
41+
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
42+
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
43+
)
44+
return model
45+
46+
def test_inference(self):
47+
unet = self.dummy_uncond_unet
48+
scheduler = DDPMScheduler()
49+
50+
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
51+
ddpm.to(torch_device)
52+
ddpm.set_progress_bar_config(disable=None)
53+
54+
# Warmup pass when using mps (see #372)
55+
if torch_device == "mps":
56+
_ = ddpm(num_inference_steps=1)
57+
58+
generator = torch.manual_seed(0)
59+
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
60+
61+
generator = torch.manual_seed(0)
62+
image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]
63+
64+
image_slice = image[0, -3:, -3:, -1]
65+
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
66+
67+
assert image.shape == (1, 32, 32, 3)
68+
expected_slice = np.array(
69+
[5.589e-01, 7.089e-01, 2.632e-01, 6.841e-01, 1.000e-04, 9.999e-01, 1.973e-01, 1.000e-04, 8.010e-02]
70+
)
71+
tolerance = 1e-2 if torch_device != "mps" else 3e-2
72+
assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
73+
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
74+
75+
def test_inference_predict_epsilon(self):
76+
deprecate("remove this test", "0.10.0", "remove")
77+
unet = self.dummy_uncond_unet
78+
scheduler = DDPMScheduler(predict_epsilon=False)
79+
80+
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
81+
ddpm.to(torch_device)
82+
ddpm.set_progress_bar_config(disable=None)
83+
84+
# Warmup pass when using mps (see #372)
85+
if torch_device == "mps":
86+
_ = ddpm(num_inference_steps=1)
87+
88+
generator = torch.manual_seed(0)
89+
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
90+
91+
generator = torch.manual_seed(0)
92+
image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]
93+
94+
image_slice = image[0, -3:, -3:, -1]
95+
image_eps_slice = image_eps[0, -3:, -3:, -1]
96+
97+
assert image.shape == (1, 32, 32, 3)
98+
tolerance = 1e-2 if torch_device != "mps" else 3e-2
99+
assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance
33100

34101

35102
@slow

tests/test_config.py

100755100644
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import diffusers
2222
from diffusers import (
2323
DDIMScheduler,
24+
DDPMScheduler,
2425
DPMSolverMultistepScheduler,
2526
EulerAncestralDiscreteScheduler,
2627
EulerDiscreteScheduler,
@@ -291,6 +292,29 @@ def test_load_pndm(self):
291292
# no warning should be thrown
292293
assert cap_logger.out == ""
293294

295+
def test_overwrite_config_on_load(self):
296+
logger = logging.get_logger("diffusers.configuration_utils")
297+
298+
with CaptureLogger(logger) as cap_logger:
299+
ddpm = DDPMScheduler.from_config(
300+
"hf-internal-testing/tiny-stable-diffusion-torch",
301+
subfolder="scheduler",
302+
predict_epsilon=False,
303+
beta_end=8,
304+
)
305+
306+
with CaptureLogger(logger) as cap_logger_2:
307+
ddpm_2 = DDPMScheduler.from_config("google/ddpm-celebahq-256", beta_start=88)
308+
309+
assert ddpm.__class__ == DDPMScheduler
310+
assert ddpm.config.predict_epsilon is False
311+
assert ddpm.config.beta_end == 8
312+
assert ddpm_2.config.beta_start == 88
313+
314+
# no warning should be thrown
315+
assert cap_logger.out == ""
316+
assert cap_logger_2.out == ""
317+
294318
def test_load_dpmsolver(self):
295319
logger = logging.get_logger("diffusers.configuration_utils")
296320

tests/test_pipelines.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def test_run_custom_pipeline(self):
107107
images, output_str = pipeline(num_inference_steps=2, output_type="np")
108108

109109
assert images[0].shape == (1, 32, 32, 3)
110+
110111
# compare output to https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L102
111112
assert output_str == "This is a test"
112113

0 commit comments

Comments
 (0)