Commit b1182bc

[Flax] fix Flax scheduler (#564)
* remove match_shape
* ported fixes from #479 to flax
* remove unused argument
* typo
* remove warnings
1 parent 0424615 commit b1182bc

File tree

5 files changed: +74 −37 lines changed


src/diffusers/schedulers/scheduling_ddim_flax.py

Lines changed: 19 additions & 9 deletions
@@ -96,7 +96,13 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
         clip_sample (`bool`, default `True`):
             option to clip predicted sample between -1 and 1 for numerical stability.
         set_alpha_to_one (`bool`, default `True`):
-            if alpha for final step is 1 or the final alpha of the "non-previous" one.
+            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
+            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+            otherwise it uses the value of alpha at step 0.
+        steps_offset (`int`, default `0`):
+            an offset added to the inference steps. You can use a combination of `offset=1` and
+            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+            stable diffusion.
     """

     @register_to_config
@@ -109,6 +115,7 @@ def __init__(
         trained_betas: Optional[jnp.ndarray] = None,
         clip_sample: bool = True,
         set_alpha_to_one: bool = True,
+        steps_offset: int = 0,
     ):
         if trained_betas is not None:
             self.betas = jnp.asarray(trained_betas)
@@ -144,9 +151,7 @@ def _get_variance(self, timestep, prev_timestep):

         return variance

-    def set_timesteps(
-        self, state: DDIMSchedulerState, num_inference_steps: int, offset: int = 0
-    ) -> DDIMSchedulerState:
+    def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int) -> DDIMSchedulerState:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

@@ -155,9 +160,9 @@ def set_timesteps(
                 the `FlaxDDIMScheduler` state data class instance.
             num_inference_steps (`int`):
                 the number of diffusion steps used when generating samples with a pre-trained model.
-            offset (`int`):
-                optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
         """
+        offset = self.config.steps_offset
+
         step_ratio = self.config.num_train_timesteps // num_inference_steps
         # creates integer timesteps by multiplying by ratio
         # casting to int to avoid issues when num_inference_step is power of 3
@@ -263,9 +268,14 @@ def add_noise(
         timesteps: jnp.ndarray,
     ) -> jnp.ndarray:
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
-        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
-        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod[:, None]
+
+        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[:, None]

         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
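The `match_shape` helper is replaced throughout this commit by an explicit flatten-and-broadcast loop. A minimal sketch of the pattern, with illustrative shapes and schedule values that are not taken from the original files:

import jax.numpy as jnp

# illustrative only: a batch of 2 samples of shape (3, 8, 8)
original_samples = jnp.zeros((2, 3, 8, 8))
alphas_cumprod = jnp.linspace(0.999, 0.01, 1000)
timesteps = jnp.array([10, 500])

sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5  # shape (2,)
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
# add singleton axes after the batch axis until the per-timestep
# coefficients broadcast against the sample batch
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
    sqrt_alpha_prod = sqrt_alpha_prod[:, None]

print(sqrt_alpha_prod.shape)  # (2, 1, 1, 1) -- broadcasts against (2, 3, 8, 8)

Each sample in the batch is then scaled by its own per-timestep coefficient, which is exactly what `match_shape` did before.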

src/diffusers/schedulers/scheduling_ddpm_flax.py

Lines changed: 7 additions & 2 deletions
@@ -266,9 +266,14 @@ def add_noise(
         timesteps: jnp.ndarray,
     ) -> jnp.ndarray:
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
-        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod[..., None]
+
         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]

         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
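Note that the DDIM file indexes with `[:, None]` while this file uses `[..., None]`. Starting from a flattened 1-D array, the two are interchangeable here, since every axis added is a singleton. A quick check:

import jax.numpy as jnp

a = jnp.ones((4,))
assert (a[:, None][:, None]).shape == (a[..., None][..., None]).shape == (4, 1, 1)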

src/diffusers/schedulers/scheduling_lms_discrete_flax.py

Lines changed: 5 additions & 2 deletions
@@ -198,8 +198,11 @@ def add_noise(
         noise: jnp.ndarray,
         timesteps: jnp.ndarray,
     ) -> jnp.ndarray:
-        sigmas = self.match_shape(state.sigmas[timesteps], noise)
-        noisy_samples = original_samples + noise * sigmas
+        sigma = state.sigmas[timesteps].flatten()
+        while len(sigma.shape) < len(noise.shape):
+            sigma = sigma[..., None]
+
+        noisy_samples = original_samples + noise * sigma

         return noisy_samples

src/diffusers/schedulers/scheduling_pndm_flax.py

Lines changed: 33 additions & 20 deletions
@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import math
-
 # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+
+import math
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union

@@ -59,7 +59,6 @@ class PNDMSchedulerState:
     # setable values
     _timesteps: jnp.ndarray
     num_inference_steps: Optional[int] = None
-    _offset: int = 0
     prk_timesteps: Optional[jnp.ndarray] = None
     plms_timesteps: Optional[jnp.ndarray] = None
     timesteps: Optional[jnp.ndarray] = None
@@ -104,6 +103,14 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         skip_prk_steps (`bool`):
             allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
             before plms steps; defaults to `False`.
+        set_alpha_to_one (`bool`, default `False`):
+            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
+            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+            otherwise it uses the value of alpha at step 0.
+        steps_offset (`int`, default `0`):
+            an offset added to the inference steps. You can use a combination of `offset=1` and
+            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+            stable diffusion.
     """

     @register_to_config
@@ -115,6 +122,8 @@ def __init__(
         beta_schedule: str = "linear",
         trained_betas: Optional[jnp.ndarray] = None,
         skip_prk_steps: bool = False,
+        set_alpha_to_one: bool = False,
+        steps_offset: int = 0,
     ):
         if trained_betas is not None:
             self.betas = jnp.asarray(trained_betas)
@@ -132,16 +141,16 @@ def __init__(
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)

+        self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
         # For now we only support F-PNDM, i.e. the runge-kutta method
         # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
         # mainly at formula (9), (12), (13) and the Algorithm 2.
         self.pndm_order = 4

         self.state = PNDMSchedulerState.create(num_train_timesteps=num_train_timesteps)

-    def set_timesteps(
-        self, state: PNDMSchedulerState, num_inference_steps: int, offset: int = 0
-    ) -> PNDMSchedulerState:
+    def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) -> PNDMSchedulerState:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

@@ -150,16 +159,15 @@ def set_timesteps(
                 the `FlaxPNDMScheduler` state data class instance.
             num_inference_steps (`int`):
                 the number of diffusion steps used when generating samples with a pre-trained model.
-            offset (`int`):
-                optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
         """
+        offset = self.config.steps_offset
+
         step_ratio = self.config.num_train_timesteps // num_inference_steps
         # creates integer timesteps by multiplying by ratio
         # rounding to avoid issues when num_inference_step is power of 3
-        _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1]
-        _timesteps = _timesteps + offset
+        _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + offset

-        state = state.replace(num_inference_steps=num_inference_steps, _offset=offset, _timesteps=_timesteps)
+        state = state.replace(num_inference_steps=num_inference_steps, _timesteps=_timesteps)

         if self.config.skip_prk_steps:
             # for some models like stable diffusion the prk steps can/should be skipped to
@@ -254,7 +262,7 @@ def step_prk(
         )

         diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2
-        prev_timestep = max(timestep - diff_to_prev, state.prk_timesteps[-1])
+        prev_timestep = timestep - diff_to_prev
         timestep = state.prk_timesteps[state.counter // 4 * 4]

         if state.counter % 4 == 0:
@@ -274,7 +282,7 @@ def step_prk(
             # cur_sample should not be `None`
             cur_sample = state.cur_sample if state.cur_sample is not None else sample

-            prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output, state=state)
+            prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
             state = state.replace(counter=state.counter + 1)

         if not return_dict:
@@ -320,7 +328,7 @@ def step_plms(
                 "for more information."
             )

-        prev_timestep = max(timestep - self.config.num_train_timesteps // state.num_inference_steps, 0)
+        prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps

         if state.counter != 1:
             state = state.replace(ets=state.ets.append(model_output))
@@ -344,15 +352,15 @@ def step_plms(
                 55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]
             )

-        prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output, state=state)
+        prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
         state = state.replace(counter=state.counter + 1)

         if not return_dict:
             return (prev_sample, state)

         return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)

-    def _get_prev_sample(self, sample, timestep, timestep_prev, model_output, state):
+    def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
         # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
         # this function computes x_(t−δ) using the formula of (9)
         # Note that x_t needs to be added to both sides of the equation
@@ -365,8 +373,8 @@ def _get_prev_sample(self, sample, timestep, timestep_prev, model_output, state)
         # sample -> x_t
         # model_output -> e_θ(x_t, t)
         # prev_sample -> x_(t−δ)
-        alpha_prod_t = self.alphas_cumprod[timestep + 1 - state._offset]
-        alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - state._offset]
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev

@@ -395,9 +403,14 @@ def add_noise(
         timesteps: jnp.ndarray,
     ) -> jnp.ndarray:
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
-        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod[..., None]
+
         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]

         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
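To see how `steps_offset` and `final_alpha_cumprod` interact, here is a sketch with made-up config values (the real scheduler reads them from `self.config` and its registered beta schedule). With `steps_offset=1` and `set_alpha_to_one=False`, the last denoising step lands on timestep 1, its previous timestep is negative, and the lookup falls back to the alpha product at step 0, exactly as the new docstring describes:

import jax.numpy as jnp

# illustrative config values, not taken from a real checkpoint
num_train_timesteps = 1000
num_inference_steps = 50
steps_offset = 1            # the stable-diffusion setting
set_alpha_to_one = False

betas = jnp.linspace(1e-4, 2e-2, num_train_timesteps)
alphas_cumprod = jnp.cumprod(1.0 - betas)
final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else alphas_cumprod[0]

step_ratio = num_train_timesteps // num_inference_steps
timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + steps_offset

last_timestep = int(timesteps[0])           # 1: the final denoising step
prev_timestep = last_timestep - step_ratio  # -19: no valid previous step
alpha_prod_t_prev = (
    alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alpha_cumprod
)
print(float(alpha_prod_t_prev))  # the alpha product at step 0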

src/diffusers/schedulers/scheduling_sde_ve_flax.py

Lines changed: 10 additions & 4 deletions
@@ -192,14 +192,17 @@ def step_pred(

         # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
         # also equation 47 shows the analog from SDE models to ancestral sampling methods
-        drift = drift - diffusion[:, None, None, None] ** 2 * model_output
+        diffusion = diffusion.flatten()
+        while len(diffusion.shape) < len(sample.shape):
+            diffusion = diffusion[:, None]
+        drift = drift - diffusion**2 * model_output

         # equation 6: sample noise for the diffusion term of
         key = random.split(key, num=1)
         noise = random.normal(key=key, shape=sample.shape)
         prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
         # TODO is the variable diffusion the correct scaling term for the noise?
-        prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise  # add impact of diffusion field g
+        prev_sample = prev_sample_mean + diffusion * noise  # add impact of diffusion field g

         if not return_dict:
             return (prev_sample, prev_sample_mean, state)
@@ -248,8 +251,11 @@ def step_correct(
             step_size = step_size * jnp.ones(sample.shape[0])

         # compute corrected sample: model_output term and noise term
-        prev_sample_mean = sample + step_size[:, None, None, None] * model_output
-        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
+        step_size = step_size.flatten()
+        while len(step_size.shape) < len(sample.shape):
+            step_size = step_size[:, None]
+        prev_sample_mean = sample + step_size * model_output
+        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise

         if not return_dict:
             return (prev_sample, state)
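The hardcoded `[:, None, None, None]` indexing assumed 4-D samples; the while-loop makes the same scaling work for samples of any rank. A small sketch, with hypothetical shapes and a helper name (`expand_like`) introduced only for illustration:

import jax.numpy as jnp

def expand_like(x, target):
    # flatten to (batch,) then add singleton axes until x broadcasts
    # against target, mirroring the loop used in the diff above
    x = x.flatten()
    while len(x.shape) < len(target.shape):
        x = x[:, None]
    return x

step_size = jnp.full((8,), 0.01)
print(expand_like(step_size, jnp.zeros((8, 16))).shape)         # (8, 1)
print(expand_like(step_size, jnp.zeros((8, 3, 32, 32))).shape)  # (8, 1, 1, 1)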
