Skip to content

Commit ab3fd67

Browse files
pcuencamishig25Mishig Davaadorjpatrickvonplatenpatil-suraj
authored
Flax pipeline pndm (#583)
* WIP: flax FlaxDiffusionPipeline & FlaxStableDiffusionPipeline * todo comment * Fix imports * Fix imports * add dummies * Fix empty init * make pipeline work * up * Allow dtype to be overridden on model load. This may be a temporary solution until #567 is addressed. * Convert params to bfloat16 or fp16 after loading. This deals with the weights, not the model. * Use Flax schedulers (typing, docstring) * PNDM: replace control flow with jax functions. Otherwise jitting/parallelization don't work properly as they don't know how to deal with traced objects. I temporarily removed `step_prk`. * Pass latents shape to scheduler set_timesteps() PNDMScheduler uses it to reserve space, other schedulers will just ignore it. * Wrap model imports inside availability checks. * Optionally return state in from_config. Useful for Flax schedulers. * Do not convert model weights to dtype. * Re-enable PRK steps with functional implementation. Values returned still not verified for correctness. * Remove left over has_state var. * make style * Apply suggestion list -> tuple Co-authored-by: Suraj Patil <[email protected]> * Apply suggestion list -> tuple Co-authored-by: Suraj Patil <[email protected]> * Remove unused comments. * Use zeros instead of empty. Co-authored-by: Mishig Davaadorj <[email protected]> Co-authored-by: Mishig Davaadorj <[email protected]> Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Suraj Patil <[email protected]>
1 parent c070e5f commit ab3fd67

File tree

3 files changed

+137
-54
lines changed

3 files changed

+137
-54
lines changed

src/diffusers/pipelines/stable_diffusion/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,5 +56,6 @@ class FlaxStableDiffusionPipelineOutput(BaseOutput):
5656
images: Union[List[PIL.Image.Image], np.ndarray]
5757
nsfw_content_detected: List[bool]
5858

59+
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
5960
from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
6061
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,9 @@ def loop_body(step, args):
186186
latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
187187
return latents, scheduler_state
188188

189-
scheduler_state = self.scheduler.set_timesteps(params["scheduler"], num_inference_steps=num_inference_steps)
189+
scheduler_state = self.scheduler.set_timesteps(
190+
params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
191+
)
190192

191193
if debug:
192194
# run with python for loop

src/diffusers/schedulers/scheduling_pndm_flax.py

Lines changed: 133 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import Optional, Tuple, Union
2020

2121
import flax
22+
import jax
2223
import jax.numpy as jnp
2324

2425
from ..configuration_utils import ConfigMixin, register_to_config
@@ -155,7 +156,7 @@ def __init__(
155156
def create_state(self):
156157
return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
157158

158-
def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) -> PNDMSchedulerState:
159+
def set_timesteps(self, state: PNDMSchedulerState, shape: Tuple, num_inference_steps: int) -> PNDMSchedulerState:
159160
"""
160161
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
161162
@@ -196,8 +197,11 @@ def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) ->
196197

197198
return state.replace(
198199
timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int64),
199-
ets=jnp.array([]),
200200
counter=0,
201+
# Reserve space for the state variables
202+
cur_model_output=jnp.zeros(shape),
203+
cur_sample=jnp.zeros(shape),
204+
ets=jnp.zeros((4,) + shape),
201205
)
202206

203207
def step(
@@ -227,22 +231,32 @@ def step(
227231
When returning a tuple, the first element is the sample tensor.
228232
229233
"""
230-
if state.counter < len(state.prk_timesteps) and not self.config.skip_prk_steps:
231-
return self.step_prk(
232-
state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict
234+
if self.config.skip_prk_steps:
235+
prev_sample, state = self.step_plms(
236+
state=state, model_output=model_output, timestep=timestep, sample=sample
233237
)
234238
else:
235-
return self.step_plms(
236-
state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict
239+
prev_sample, state = jax.lax.switch(
240+
jnp.where(state.counter < len(state.prk_timesteps), 0, 1),
241+
(self.step_prk, self.step_plms),
242+
# Args to either branch
243+
state,
244+
model_output,
245+
timestep,
246+
sample,
237247
)
238248

249+
if not return_dict:
250+
return (prev_sample, state)
251+
252+
return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
253+
239254
def step_prk(
240255
self,
241256
state: PNDMSchedulerState,
242257
model_output: jnp.ndarray,
243258
timestep: int,
244259
sample: jnp.ndarray,
245-
return_dict: bool = True,
246260
) -> Union[FlaxSchedulerOutput, Tuple]:
247261
"""
248262
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
@@ -266,42 +280,53 @@ def step_prk(
266280
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
267281
)
268282

269-
diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2
283+
diff_to_prev = jnp.where(
284+
state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2
285+
)
270286
prev_timestep = timestep - diff_to_prev
271287
timestep = state.prk_timesteps[state.counter // 4 * 4]
272288

273-
if state.counter % 4 == 0:
274-
state = state.replace(
275-
cur_model_output=state.cur_model_output + 1 / 6 * model_output,
276-
ets=state.ets.append(model_output),
277-
cur_sample=sample,
289+
def remainder_0(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
290+
return (
291+
state.replace(
292+
cur_model_output=state.cur_model_output + 1 / 6 * model_output,
293+
ets=state.ets.at[ets_at].set(model_output),
294+
cur_sample=sample,
295+
),
296+
model_output,
278297
)
279-
elif (self.counter - 1) % 4 == 0:
280-
state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
281-
elif (self.counter - 2) % 4 == 0:
282-
state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
283-
elif (self.counter - 3) % 4 == 0:
284-
model_output = state.cur_model_output + 1 / 6 * model_output
285-
state = state.replace(cur_model_output=0)
286298

287-
# cur_sample should not be `None`
288-
cur_sample = state.cur_sample if state.cur_sample is not None else sample
299+
def remainder_1(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
300+
return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output
289301

302+
def remainder_2(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
303+
return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output
304+
305+
def remainder_3(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
306+
model_output = state.cur_model_output + 1 / 6 * model_output
307+
return state.replace(cur_model_output=jnp.zeros_like(state.cur_model_output)), model_output
308+
309+
state, model_output = jax.lax.switch(
310+
state.counter % 4,
311+
(remainder_0, remainder_1, remainder_2, remainder_3),
312+
# Args to either branch
313+
state,
314+
model_output,
315+
state.counter // 4,
316+
)
317+
318+
cur_sample = state.cur_sample
290319
prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
291320
state = state.replace(counter=state.counter + 1)
292321

293-
if not return_dict:
294-
return (prev_sample, state)
295-
296-
return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
322+
return (prev_sample, state)
297323

298324
def step_plms(
299325
self,
300326
state: PNDMSchedulerState,
301327
model_output: jnp.ndarray,
302328
timestep: int,
303329
sample: jnp.ndarray,
304-
return_dict: bool = True,
305330
) -> Union[FlaxSchedulerOutput, Tuple]:
306331
"""
307332
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
@@ -334,36 +359,91 @@ def step_plms(
334359
)
335360

336361
prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
362+
prev_timestep = jnp.where(prev_timestep > 0, prev_timestep, 0)
363+
364+
# Reference:
365+
# if state.counter != 1:
366+
# state.ets.append(model_output)
367+
# else:
368+
# prev_timestep = timestep
369+
# timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps
370+
371+
prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep)
372+
timestep = jnp.where(
373+
state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep
374+
)
337375

338-
if state.counter != 1:
339-
state = state.replace(ets=state.ets.append(model_output))
340-
else:
341-
prev_timestep = timestep
342-
timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps
343-
344-
if len(state.ets) == 1 and state.counter == 0:
345-
model_output = model_output
346-
state = state.replace(cur_sample=sample)
347-
elif len(state.ets) == 1 and state.counter == 1:
348-
model_output = (model_output + state.ets[-1]) / 2
349-
sample = state.cur_sample
350-
state = state.replace(cur_sample=None)
351-
elif len(state.ets) == 2:
352-
model_output = (3 * state.ets[-1] - state.ets[-2]) / 2
353-
elif len(state.ets) == 3:
354-
model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12
355-
else:
356-
model_output = (1 / 24) * (
357-
55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]
376+
# Reference:
377+
# if len(state.ets) == 1 and state.counter == 0:
378+
# model_output = model_output
379+
# state.cur_sample = sample
380+
# elif len(state.ets) == 1 and state.counter == 1:
381+
# model_output = (model_output + state.ets[-1]) / 2
382+
# sample = state.cur_sample
383+
# state.cur_sample = None
384+
# elif len(state.ets) == 2:
385+
# model_output = (3 * state.ets[-1] - state.ets[-2]) / 2
386+
# elif len(state.ets) == 3:
387+
# model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12
388+
# else:
389+
# model_output = (1 / 24) * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4])
390+
391+
def counter_0(state: PNDMSchedulerState):
392+
ets = state.ets.at[0].set(model_output)
393+
return state.replace(
394+
ets=ets,
395+
cur_sample=sample,
396+
cur_model_output=jnp.array(model_output, dtype=jnp.float32),
397+
)
398+
399+
def counter_1(state: PNDMSchedulerState):
400+
return state.replace(
401+
cur_model_output=(model_output + state.ets[0]) / 2,
358402
)
359403

404+
def counter_2(state: PNDMSchedulerState):
405+
ets = state.ets.at[1].set(model_output)
406+
return state.replace(
407+
ets=ets,
408+
cur_model_output=(3 * ets[1] - ets[0]) / 2,
409+
cur_sample=sample,
410+
)
411+
412+
def counter_3(state: PNDMSchedulerState):
413+
ets = state.ets.at[2].set(model_output)
414+
return state.replace(
415+
ets=ets,
416+
cur_model_output=(23 * ets[2] - 16 * ets[1] + 5 * ets[0]) / 12,
417+
cur_sample=sample,
418+
)
419+
420+
def counter_other(state: PNDMSchedulerState):
421+
ets = state.ets.at[3].set(model_output)
422+
next_model_output = (1 / 24) * (55 * ets[3] - 59 * ets[2] + 37 * ets[1] - 9 * ets[0])
423+
424+
ets = ets.at[0].set(ets[1])
425+
ets = ets.at[1].set(ets[2])
426+
ets = ets.at[2].set(ets[3])
427+
428+
return state.replace(
429+
ets=ets,
430+
cur_model_output=next_model_output,
431+
cur_sample=sample,
432+
)
433+
434+
counter = jnp.clip(state.counter, 0, 4)
435+
state = jax.lax.switch(
436+
counter,
437+
[counter_0, counter_1, counter_2, counter_3, counter_other],
438+
state,
439+
)
440+
441+
sample = state.cur_sample
442+
model_output = state.cur_model_output
360443
prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
361444
state = state.replace(counter=state.counter + 1)
362445

363-
if not return_dict:
364-
return (prev_sample, state)
365-
366-
return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
446+
return (prev_sample, state)
367447

368448
def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
369449
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
@@ -379,7 +459,7 @@ def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
379459
# model_output -> e_θ(x_t, t)
380460
# prev_sample -> x_(t−δ)
381461
alpha_prod_t = self.alphas_cumprod[timestep]
382-
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
462+
alpha_prod_t_prev = jnp.where(prev_timestep >= 0, self.alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
383463
beta_prod_t = 1 - alpha_prod_t
384464
beta_prod_t_prev = 1 - alpha_prod_t_prev
385465

0 commit comments

Comments
 (0)