Commit d7dcba4

Unify offset configuration in DDIM and PNDM schedulers (#479)
* Unify offset configuration in DDIM and PNDM schedulers
* Format
  Add missing variables
* Fix pipeline test
* Update src/diffusers/schedulers/scheduling_ddim.py
  Co-authored-by: Patrick von Platen <[email protected]>
* Default set_alpha_to_one to false
* Format
* Add tests
* Format
* add deprecation warning

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 9e439d8 commit d7dcba4

7 files changed: +213 -77 lines
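In practice the change moves the Stable Diffusion timestep offset out of the pipelines and into the scheduler configuration. A minimal sketch of the before/after usage (the beta values shown are the ones commonly used for Stable Diffusion and are illustrative, not part of this commit):

    from diffusers import PNDMScheduler

    # Before this commit: pipelines probed set_timesteps for an `offset` kwarg
    # and passed offset=1 at call time, e.g.
    #     scheduler.set_timesteps(num_inference_steps, offset=1)   # now deprecated

    # After this commit: the offset is part of the scheduler configuration.
    scheduler = PNDMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        skip_prk_steps=True,
        set_alpha_to_one=False,
        steps_offset=1,  # replaces the old per-call offset=1
    )
    scheduler.set_timesteps(50)  # timesteps are shifted up by steps_offset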

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 17 additions & 6 deletions
@@ -6,6 +6,7 @@
 
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
+from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@@ -53,6 +54,21 @@ def __init__(
     ):
         super().__init__()
         scheduler = scheduler.set_format("pt")
+
+        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+            warnings.warn(
+                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly, as leaving `steps_offset` might lead to incorrect results"
+                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+                " file",
+                DeprecationWarning,
+            )
+            new_config = dict(scheduler.config)
+            new_config["steps_offset"] = 1
+            scheduler._internal_dict = FrozenDict(new_config)
 
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -217,12 +233,7 @@ def __call__(
             latents = latents.to(self.device)
 
         # set timesteps
-        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
-        extra_set_kwargs = {}
-        if accepts_offset:
-            extra_set_kwargs["offset"] = 1
-
-        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+        self.scheduler.set_timesteps(num_inference_steps)
 
         # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
         if isinstance(self.scheduler, LMSDiscreteScheduler):
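The block added to `__init__` above is a backward-compatibility shim: configs saved before this commit have no `steps_offset`, so they load with the new default of 0, and the pipeline warns and patches the frozen config to 1. Roughly the same pattern, isolated from the pipeline (the shortened warning text and the DDIMScheduler instantiation here are illustrative assumptions, not taken from this commit):

    import warnings

    from diffusers import DDIMScheduler
    from diffusers.configuration_utils import FrozenDict

    # Stand-in for a scheduler loaded from an old config: steps_offset defaults to 0.
    scheduler = DDIMScheduler(set_alpha_to_one=False)

    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        warnings.warn("Scheduler config is outdated; forcing steps_offset=1.", DeprecationWarning)
        new_config = dict(scheduler.config)                 # FrozenDict is immutable, so copy it ...
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)   # ... and swap the whole config

    assert scheduler.config.steps_offset == 1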

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 19 additions & 8 deletions
@@ -1,4 +1,5 @@
 import inspect
+import warnings
 from typing import List, Optional, Union
 
 import numpy as np
@@ -7,6 +8,7 @@
 import PIL
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
+from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@@ -64,6 +66,21 @@ def __init__(
     ):
         super().__init__()
         scheduler = scheduler.set_format("pt")
+
+        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+            warnings.warn(
+                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly, as leaving `steps_offset` might lead to incorrect results"
+                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+                " file",
+                DeprecationWarning,
+            )
+            new_config = dict(scheduler.config)
+            new_config["steps_offset"] = 1
+            scheduler._internal_dict = FrozenDict(new_config)
 
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -169,14 +186,7 @@ def __call__(
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
 
         # set timesteps
-        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
-        extra_set_kwargs = {}
-        offset = 0
-        if accepts_offset:
-            offset = 1
-            extra_set_kwargs["offset"] = 1
-
-        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+        self.scheduler.set_timesteps(num_inference_steps)
 
         if isinstance(init_image, PIL.Image.Image):
             init_image = preprocess(init_image)
@@ -190,6 +200,7 @@ def __call__(
             init_latents = torch.cat([init_latents] * batch_size)
 
         # get the original timestep using init_timestep
+        offset = self.scheduler.config.get("steps_offset", 0)
         init_timestep = int(num_inference_steps * strength) + offset
         init_timestep = min(init_timestep, num_inference_steps)
         if isinstance(self.scheduler, LMSDiscreteScheduler):
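The `offset` that the old code set locally is now read back from the scheduler config so that the img2img entry point lines up with the shifted timestep grid. A quick numeric illustration with hypothetical values:

    # With steps_offset=1 the img2img starting timestep shifts by the same
    # amount the timesteps were shifted in set_timesteps.
    num_inference_steps = 50
    strength = 0.75

    offset = 1  # scheduler.config.get("steps_offset", 0) for a Stable Diffusion scheduler
    init_timestep = int(num_inference_steps * strength) + offset  # 37 + 1 = 38
    init_timestep = min(init_timestep, num_inference_steps)       # 38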

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 21 additions & 10 deletions
@@ -1,4 +1,5 @@
 import inspect
+import warnings
 from typing import List, Optional, Union
 
 import numpy as np
@@ -8,6 +9,7 @@
 from tqdm.auto import tqdm
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
+from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, PNDMScheduler
@@ -83,6 +85,21 @@ def __init__(
         super().__init__()
         scheduler = scheduler.set_format("pt")
         logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
+
+        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+            warnings.warn(
+                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly, as leaving `steps_offset` might lead to incorrect results"
+                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+                " file",
+                DeprecationWarning,
+            )
+            new_config = dict(scheduler.config)
+            new_config["steps_offset"] = 1
+            scheduler._internal_dict = FrozenDict(new_config)
 
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -193,19 +210,12 @@ def __call__(
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
 
         # set timesteps
-        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
-        extra_set_kwargs = {}
-        offset = 0
-        if accepts_offset:
-            offset = 1
-            extra_set_kwargs["offset"] = 1
-
-        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+        self.scheduler.set_timesteps(num_inference_steps)
 
         # preprocess image
         if not isinstance(init_image, torch.FloatTensor):
             init_image = preprocess_image(init_image)
-        init_image.to(self.device)
+        init_image = init_image.to(self.device)
 
         # encode the init image into latents and scale the latents
         init_latent_dist = self.vae.encode(init_image).latent_dist
@@ -220,14 +230,15 @@ def __call__(
         # preprocess mask
         if not isinstance(mask_image, torch.FloatTensor):
             mask_image = preprocess_mask(mask_image)
-        mask_image.to(self.device)
+        mask_image = mask_image.to(self.device)
         mask = torch.cat([mask_image] * batch_size)
 
         # check sizes
         if not mask.shape == init_latents.shape:
             raise ValueError("The mask and init_image should be the same size!")
 
         # get the original timestep using init_timestep
+        offset = self.scheduler.config.get("steps_offset", 0)
         init_timestep = int(num_inference_steps * strength) + offset
         init_timestep = min(init_timestep, num_inference_steps)
         timesteps = self.scheduler.timesteps[-init_timestep]
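Apart from the offset unification, these hunks also fix a latent bug: `Tensor.to` is not in-place, so the old `init_image.to(self.device)` and `mask_image.to(self.device)` silently discarded the moved tensors. A small reminder of the semantics:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.zeros(1, 3, 64, 64)

    x.to(device)      # returns a new tensor; the original x is left where it was
    x = x.to(device)  # the fix applied in this commit: keep the returned tensor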

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py

Lines changed: 1 addition & 6 deletions
@@ -100,12 +100,7 @@ def __call__(
             raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
 
         # set timesteps
-        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
-        extra_set_kwargs = {}
-        if accepts_offset:
-            extra_set_kwargs["offset"] = 1
-
-        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+        self.scheduler.set_timesteps(num_inference_steps)
 
         # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
         if isinstance(self.scheduler, LMSDiscreteScheduler):

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 22 additions & 4 deletions
@@ -16,6 +16,7 @@
 # and https://github.com/hojonathanho/diffusion
 
 import math
+import warnings
 from typing import Optional, Tuple, Union
 
 import numpy as np
@@ -78,7 +79,13 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         clip_sample (`bool`, default `True`):
             option to clip predicted sample between -1 and 1 for numerical stability.
         set_alpha_to_one (`bool`, default `True`):
-            if alpha for final step is 1 or the final alpha of the "non-previous" one.
+            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
+            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+            otherwise it uses the value of alpha at step 0.
+        steps_offset (`int`, default `0`):
+            an offset added to the inference steps. You can use a combination of `offset=1` and
+            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+            stable diffusion.
         tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
 
     """
@@ -93,6 +100,7 @@ def __init__(
         trained_betas: Optional[np.ndarray] = None,
         clip_sample: bool = True,
         set_alpha_to_one: bool = True,
+        steps_offset: int = 0,
         tensor_format: str = "pt",
     ):
         if trained_betas is not None:
@@ -134,16 +142,26 @@ def _get_variance(self, timestep, prev_timestep):
 
         return variance
 
-    def set_timesteps(self, num_inference_steps: int, offset: int = 0):
+    def set_timesteps(self, num_inference_steps: int, **kwargs):
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
 
         Args:
             num_inference_steps (`int`):
                 the number of diffusion steps used when generating samples with a pre-trained model.
-            offset (`int`):
-                optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
         """
+
+        offset = self.config.steps_offset
+
+        if "offset" in kwargs:
+            warnings.warn(
+                "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
+                " Please pass `steps_offset` to `__init__` instead.",
+                DeprecationWarning,
+            )
+
+            offset = kwargs["offset"]
+
         self.num_inference_steps = num_inference_steps
         step_ratio = self.config.num_train_timesteps // self.num_inference_steps
         # creates integer timesteps by multiplying by ratio
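With the new config option, the typical Stable Diffusion setup becomes part of the scheduler itself, while the old `offset` kwarg keeps working behind a deprecation warning. A sketch, assuming the default num_train_timesteps=1000:

    from diffusers import DDIMScheduler

    # New style: the offset lives in the scheduler config.
    scheduler = DDIMScheduler(set_alpha_to_one=False, steps_offset=1)
    scheduler.set_timesteps(50)
    # scheduler.timesteps is roughly 981, 961, ..., 21, 1 (every value shifted up by steps_offset)

    # Old style still works, but warns and is slated for removal in v0.4.0:
    scheduler.set_timesteps(50, offset=1)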

src/diffusers/schedulers/scheduling_pndm.py

Lines changed: 34 additions & 14 deletions
@@ -15,6 +15,7 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
 
 import math
+import warnings
 from typing import Optional, Tuple, Union
 
 import numpy as np
@@ -74,10 +75,18 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
         trained_betas (`np.ndarray`, optional):
             option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
-        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays
         skip_prk_steps (`bool`):
             allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
             before plms steps; defaults to `False`.
+        set_alpha_to_one (`bool`, default `False`):
+            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
+            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+            otherwise it uses the value of alpha at step 0.
+        steps_offset (`int`, default `0`):
+            an offset added to the inference steps. You can use a combination of `offset=1` and
+            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+            stable diffusion.
+        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays
 
     """
 
@@ -89,8 +98,10 @@ def __init__(
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
         trained_betas: Optional[np.ndarray] = None,
-        tensor_format: str = "pt",
         skip_prk_steps: bool = False,
+        set_alpha_to_one: bool = False,
+        steps_offset: int = 0,
+        tensor_format: str = "pt",
     ):
         if trained_betas is not None:
             self.betas = np.asarray(trained_betas)
@@ -108,6 +119,8 @@ def __init__(
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
 
+        self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
         # For now we only support F-PNDM, i.e. the runge-kutta method
         # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
         # mainly at formula (9), (12), (13) and the Algorithm 2.
@@ -122,31 +135,38 @@ def __init__(
         # setable values
         self.num_inference_steps = None
         self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
-        self._offset = 0
         self.prk_timesteps = None
         self.plms_timesteps = None
         self.timesteps = None
 
         self.tensor_format = tensor_format
         self.set_format(tensor_format=tensor_format)
 
-    def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.FloatTensor:
+    def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
 
         Args:
             num_inference_steps (`int`):
                 the number of diffusion steps used when generating samples with a pre-trained model.
-            offset (`int`):
-                optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
         """
+
+        offset = self.config.steps_offset
+
+        if "offset" in kwargs:
+            warnings.warn(
+                "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
+                " Please pass `steps_offset` to `__init__` instead."
+            )
+
+            offset = kwargs["offset"]
+
         self.num_inference_steps = num_inference_steps
         step_ratio = self.config.num_train_timesteps // self.num_inference_steps
         # creates integer timesteps by multiplying by ratio
         # casting to int to avoid issues when num_inference_step is power of 3
-        self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().tolist()
-        self._offset = offset
-        self._timesteps = np.array([t + self._offset for t in self._timesteps])
+        self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()
+        self._timesteps += offset
 
         if self.config.skip_prk_steps:
             # for some models like stable diffusion the prk steps can/should be skipped to
@@ -231,7 +251,7 @@ def step_prk(
             )
 
         diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
-        prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1])
+        prev_timestep = timestep - diff_to_prev
         timestep = self.prk_timesteps[self.counter // 4 * 4]
 
         if self.counter % 4 == 0:
@@ -293,7 +313,7 @@ def step_plms(
                 "for more information."
             )
 
-        prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0)
+        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
 
         if self.counter != 1:
             self.ets.append(model_output)
@@ -323,7 +343,7 @@ def step_plms(
 
         return SchedulerOutput(prev_sample=prev_sample)
 
-    def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
+    def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
         # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
         # this function computes x_(t−δ) using the formula of (9)
         # Note that x_t needs to be added to both sides of the equation
@@ -336,8 +356,8 @@ def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
         # sample -> x_t
        # model_output -> e_θ(x_t, t)
         # prev_sample -> x_(t−δ)
-        alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset]
-        alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - self._offset]
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
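Dropping `self._offset` means `_get_prev_sample` now indexes `alphas_cumprod` with the raw timesteps, and on the final step `prev_timestep` can go negative; the new `final_alpha_cumprod` is the fallback for that case, matching what DDIM already did. A standalone sketch of the lookup, with a hypothetical linear beta schedule:

    import numpy as np

    num_train_timesteps = 1000
    betas = np.linspace(0.0001, 0.02, num_train_timesteps)   # hypothetical schedule
    alphas_cumprod = np.cumprod(1.0 - betas)
    final_alpha_cumprod = alphas_cumprod[0]                   # set_alpha_to_one=False

    timestep = 1                                              # last PLMS step when steps_offset=1
    prev_timestep = timestep - num_train_timesteps // 50      # 1 - 20 = -19

    alpha_prod_t = alphas_cumprod[timestep]
    # Guarding on >= 0 matters: a negative index would silently wrap around in NumPy.
    alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alpha_cumprod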
