Commit f106ab4

skirsten and pcuenca authored
[Flax] Stateless schedulers, fixes and refactors (#1661)
* [Flax] Stateless schedulers, fixes and refactors

* Remove scheduling_common_flax and some renames

* Update src/diffusers/schedulers/scheduling_pndm_flax.py

Co-authored-by: Pedro Cuenca <[email protected]>
1 parent d87cc15 commit f106ab4

File tree

12 files changed: +632 / -550 lines changed
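
The central change: Flax schedulers no longer carry mutable attributes (betas, alphas_cumprod, timesteps, init_noise_sigma) on the scheduler object. The object only holds its config; everything that changes lives in an immutable state dataclass created once with create_state() and passed explicitly to set_timesteps, step and add_noise. A minimal usage sketch of the resulting API (illustrative, based on the FlaxDDIMScheduler diff below, not code from this commit):

import jax.numpy as jnp
from diffusers import FlaxDDIMScheduler

scheduler = FlaxDDIMScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
)
# All mutable values live in an explicit, immutable state object.
scheduler_state = scheduler.create_state()

# set_timesteps does not mutate the scheduler; it returns a new state.
scheduler_state = scheduler.set_timesteps(scheduler_state, num_inference_steps=50)

sample = jnp.zeros((1, 4, 64, 64))
for t in scheduler_state.timesteps:
    model_output = sample  # stand-in for the UNet prediction
    sample = scheduler.step(scheduler_state, model_output, int(t), sample).prev_sample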

examples/dreambooth/train_dreambooth_flax.py

Lines changed: 2 additions & 1 deletion
@@ -475,6 +475,7 @@ def collate_fn(examples):
     noise_scheduler = FlaxDDPMScheduler(
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
     )
+    noise_scheduler_state = noise_scheduler.create_state()
 
     # Initialize our training
     train_rngs = jax.random.split(rng, jax.local_device_count())
@@ -511,7 +512,7 @@ def compute_loss(params):
 
             # Add noise to the latents according to the noise magnitude at each timestep
             # (this is the forward diffusion process)
-            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+            noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
 
             # Get the text embedding for conditioning
             if args.train_text_encoder:

examples/text_to_image/train_text_to_image_flax.py

Lines changed: 2 additions & 1 deletion
@@ -417,6 +417,7 @@ def collate_fn(examples):
     noise_scheduler = FlaxDDPMScheduler(
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
     )
+    noise_scheduler_state = noise_scheduler.create_state()
 
     # Initialize our training
     rng = jax.random.PRNGKey(args.seed)
@@ -449,7 +450,7 @@ def compute_loss(params):
 
             # Add noise to the latents according to the noise magnitude at each timestep
             # (this is the forward diffusion process)
-            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+            noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
 
             # Get the text embedding for conditioning
             encoder_hidden_states = text_encoder(

examples/textual_inversion/textual_inversion_flax.py

Lines changed: 2 additions & 1 deletion
@@ -505,6 +505,7 @@ def update_fn(updates, state, params=None):
     noise_scheduler = FlaxDDPMScheduler(
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
     )
+    noise_scheduler_state = noise_scheduler.create_state()
 
     # Initialize our training
     train_rngs = jax.random.split(rng, jax.local_device_count())
@@ -531,7 +532,7 @@ def compute_loss(params):
                 0,
                 noise_scheduler.config.num_train_timesteps,
             )
-            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+            noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
             encoder_hidden_states = state.apply_fn(
                 batch["input_ids"], params=params, dropout_rng=dropout_rng, train=True
             )[0]
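
The three training scripts above all make the same two-line change: the DDPM scheduler state is created once, then passed as the first argument to add_noise inside the jitted/pmapped loss function, which keeps the call a pure function of its inputs. A self-contained sketch of that call with hypothetical shapes (not taken from the scripts):

import jax
import jax.numpy as jnp
from diffusers import FlaxDDPMScheduler

noise_scheduler = FlaxDDPMScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
)
noise_scheduler_state = noise_scheduler.create_state()

# Hypothetical latents, noise and timesteps, just to exercise the new signature.
rng = jax.random.PRNGKey(0)
latents = jax.random.normal(rng, (4, 4, 64, 64))
noise = jax.random.normal(jax.random.fold_in(rng, 1), latents.shape)
timesteps = jax.random.randint(
    jax.random.fold_in(rng, 2), (latents.shape[0],), 0, noise_scheduler.config.num_train_timesteps
)

# The state is an explicit argument, so the call is traceable under jax.jit / jax.pmap.
noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)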

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 2 additions & 1 deletion
@@ -261,7 +261,8 @@ def loop_body(step, args):
        )
 
        # scale the initial noise by the standard deviation required by the scheduler
-       latents = latents * self.scheduler.init_noise_sigma
+       latents = latents * params["scheduler"].init_noise_sigma
+
        if DEBUG:
            # run with python for loop
            for i in range(num_inference_steps):
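
In the Flax pipeline the scheduler state travels inside the params dict under the "scheduler" key, so init_noise_sigma is read from that state rather than from self.scheduler. A rough, self-contained sketch of just the scaling step (the surrounding pipeline plumbing, which also carries the model weights in params, is omitted):

import jax
import jax.numpy as jnp
from diffusers import FlaxDDIMScheduler

scheduler = FlaxDDIMScheduler()
# In the real pipeline, params also holds the unet/vae/text-encoder weights.
params = {"scheduler": scheduler.create_state()}

latents = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 64, 64))
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * params["scheduler"].init_noise_sigma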

src/diffusers/schedulers/scheduling_ddim_flax.py

Lines changed: 68 additions & 89 deletions
@@ -15,7 +15,6 @@
 # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
 # and https://github.com/hojonathanho/diffusion
 
-import math
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union
 
@@ -26,51 +25,37 @@
 from ..utils import deprecate
 from .scheduling_utils_flax import (
     _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
+    CommonSchedulerState,
     FlaxSchedulerMixin,
     FlaxSchedulerOutput,
-    broadcast_to_shape_from_left,
+    add_noise_common,
 )
 
 
-def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
-    """
-    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
-    (1-beta) over time from t = [0,1].
-
-    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
-    to that part of the diffusion process.
-
-    Args:
-        num_diffusion_timesteps (`int`): the number of betas to produce.
-        max_beta (`float`): the maximum beta to use; use values lower than 1 to
-            prevent singularities.
-
-    Returns:
-        betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
-    """
-
-    def alpha_bar(time_step):
-        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
-
-    betas = []
-    for i in range(num_diffusion_timesteps):
-        t1 = i / num_diffusion_timesteps
-        t2 = (i + 1) / num_diffusion_timesteps
-        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
-    return jnp.array(betas, dtype=jnp.float32)
-
-
 @flax.struct.dataclass
 class DDIMSchedulerState:
+    common: CommonSchedulerState
+    final_alpha_cumprod: jnp.ndarray
+
     # setable values
+    init_noise_sigma: jnp.ndarray
     timesteps: jnp.ndarray
-    alphas_cumprod: jnp.ndarray
     num_inference_steps: Optional[int] = None
 
     @classmethod
-    def create(cls, num_train_timesteps: int, alphas_cumprod: jnp.ndarray):
-        return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1], alphas_cumprod=alphas_cumprod)
+    def create(
+        cls,
+        common: CommonSchedulerState,
+        final_alpha_cumprod: jnp.ndarray,
+        init_noise_sigma: jnp.ndarray,
+        timesteps: jnp.ndarray,
+    ):
+        return cls(
+            common=common,
+            final_alpha_cumprod=final_alpha_cumprod,
+            init_noise_sigma=init_noise_sigma,
+            timesteps=timesteps,
+        )
 
 
 @dataclass
@@ -112,12 +97,15 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
         prediction_type (`str`, default `epsilon`):
             indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
             `v-prediction` is not supported for this scheduler.
-
+        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
+            the `dtype` used for params and computation.
     """
 
     _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
     _deprecated_kwargs = ["predict_epsilon"]
 
+    dtype: jnp.dtype
+
     @property
     def has_state(self):
         return True
@@ -129,43 +117,46 @@ def __init__(
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
+        trained_betas: Optional[jnp.ndarray] = None,
         set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
+        dtype: jnp.dtype = jnp.float32,
        **kwargs,
    ):
        message = (
            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
-            " FlaxDDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
+            f" {self.__class__.__name__}.from_pretrained(<model_id>, prediction_type='epsilon')`."
        )
        predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
        if predict_epsilon is not None:
            self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
 
-        if beta_schedule == "linear":
-            self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
-        elif beta_schedule == "scaled_linear":
-            # this schedule is very specific to the latent diffusion model.
-            self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
-        elif beta_schedule == "squaredcos_cap_v2":
-            # Glide cosine schedule
-            self.betas = betas_for_alpha_bar(num_train_timesteps)
-        else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
-
-        self.alphas = 1.0 - self.betas
+        self.dtype = dtype
 
-        # HACK for now - clean up later (PVP)
-        self._alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
+    def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState:
+        if common is None:
+            common = CommonSchedulerState.create(self)
 
        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
-        self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else float(self._alphas_cumprod[0])
+        final_alpha_cumprod = (
+            jnp.array(1.0, dtype=self.dtype) if self.config.set_alpha_to_one else common.alphas_cumprod[0]
+        )
 
        # standard deviation of the initial noise distribution
-        self.init_noise_sigma = 1.0
+        init_noise_sigma = jnp.array(1.0, dtype=self.dtype)
+
+        timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
+
+        return DDIMSchedulerState.create(
+            common=common,
+            final_alpha_cumprod=final_alpha_cumprod,
+            init_noise_sigma=init_noise_sigma,
+            timesteps=timesteps,
+        )
 
    def scale_model_input(
        self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
@@ -181,21 +172,6 @@ def scale_model_input(
        """
        return sample
 
-    def create_state(self):
-        return DDIMSchedulerState.create(
-            num_train_timesteps=self.config.num_train_timesteps, alphas_cumprod=self._alphas_cumprod
-        )
-
-    def _get_variance(self, timestep, prev_timestep, alphas_cumprod):
-        alpha_prod_t = alphas_cumprod[timestep]
-        alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
-        beta_prod_t = 1 - alpha_prod_t
-        beta_prod_t_prev = 1 - alpha_prod_t_prev
-
-        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
-
-        return variance
-
    def set_timesteps(
        self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple = ()
    ) -> DDIMSchedulerState:
@@ -208,22 +184,35 @@ def set_timesteps(
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
        """
-        offset = self.config.steps_offset
-
        step_ratio = self.config.num_train_timesteps // num_inference_steps
        # creates integer timesteps by multiplying by ratio
-        # casting to int to avoid issues when num_inference_step is power of 3
-        timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1]
-        timesteps = timesteps + offset
+        # rounding to avoid issues when num_inference_step is power of 3
+        timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] + self.config.steps_offset
+
+        return state.replace(
+            num_inference_steps=num_inference_steps,
+            timesteps=timesteps,
+        )
+
+    def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep):
+        alpha_prod_t = state.common.alphas_cumprod[timestep]
+        alpha_prod_t_prev = jnp.where(
+            prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
+        )
+        beta_prod_t = 1 - alpha_prod_t
+        beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
 
-        return state.replace(num_inference_steps=num_inference_steps, timesteps=timesteps)
+        return variance
 
    def step(
        self,
        state: DDIMSchedulerState,
        model_output: jnp.ndarray,
        timestep: int,
        sample: jnp.ndarray,
+        eta: float = 0.0,
        return_dict: bool = True,
    ) -> Union[FlaxDDIMSchedulerOutput, Tuple]:
        """
@@ -259,17 +248,15 @@ def step(
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"
 
-        # TODO(Patrick) - eta is always 0.0 for now, allow to be set in step function
-        eta = 0.0
-
        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
 
-        alphas_cumprod = state.alphas_cumprod
+        alphas_cumprod = state.common.alphas_cumprod
+        final_alpha_cumprod = state.final_alpha_cumprod
 
        # 2. compute alphas, betas
        alpha_prod_t = alphas_cumprod[timestep]
-        alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
+        alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], final_alpha_cumprod)
 
        beta_prod_t = 1 - alpha_prod_t
 
@@ -291,7 +278,7 @@ def step(
 
        # 4. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
-        variance = self._get_variance(timestep, prev_timestep, alphas_cumprod)
+        variance = self._get_variance(state, timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)
 
        # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
@@ -307,20 +294,12 @@
 
    def add_noise(
        self,
+        state: DDIMSchedulerState,
        original_samples: jnp.ndarray,
        noise: jnp.ndarray,
        timesteps: jnp.ndarray,
    ) -> jnp.ndarray:
-        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
-        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
-        sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)
-
-        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-        sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
-
-        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
-        return noisy_samples
+        return add_noise_common(state.common, original_samples, noise, timesteps)
 
    def __len__(self):
        return self.config.num_train_timesteps
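
Two smaller API consequences of the DDIM refactor: the per-scheduler state now wraps a shared CommonSchedulerState (so shared quantities such as alphas_cumprod are read from state.common), and eta is an explicit argument of step instead of being hard-coded to 0.0. A short sketch poking at both (illustrative, not code from the commit):

import jax.numpy as jnp
from diffusers import FlaxDDIMScheduler

scheduler = FlaxDDIMScheduler(beta_schedule="scaled_linear")
state = scheduler.create_state()

# Shared quantities live on state.common, DDIM-specific ones on the state itself.
print(state.common.alphas_cumprod.shape)  # (num_train_timesteps,)
print(state.final_alpha_cumprod, state.init_noise_sigma)

state = scheduler.set_timesteps(state, num_inference_steps=10)
sample = jnp.zeros((1, 4, 8, 8))
# eta is now a step argument (previously hard-coded); 0.0 keeps DDIM deterministic.
out = scheduler.step(state, model_output=sample, timestep=int(state.timesteps[0]), sample=sample, eta=0.0)
sample = out.prev_sample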
