Merged

Changes from all commits (30 commits)
b9ca406  WIP: flax FlaxDiffusionPipeline & FlaxStableDiffusionPipeline (mishig25, Sep 19, 2022)
30abc63  todo comment (mishig25, Sep 19, 2022)
9b54559  Merge branch 'main' into flax_pipeline (Sep 19, 2022)
4b2becb  Fix imports (mishig25, Sep 19, 2022)
7f0e429  Fix imports (mishig25, Sep 19, 2022)
d9e2ae1  add dummies (patrickvonplaten, Sep 19, 2022)
d51e881  Fix empty init (mishig25, Sep 19, 2022)
741046d  Merge branch 'flax_pipeline' of https://github.com/huggingface/diffus… (mishig25, Sep 19, 2022)
7aab68d  make pipeline work (patrickvonplaten, Sep 19, 2022)
7d3fff6  merge conflict (patrickvonplaten, Sep 19, 2022)
47d7739  up (patrickvonplaten, Sep 19, 2022)
4dfcf21  Allow dtype to be overridden on model load. (pcuenca, Sep 20, 2022)
d480534  Convert params to bfloat16 or fp16 after loading. (pcuenca, Sep 20, 2022)
0c2a868  Use Flax schedulers (typing, docstring) (pcuenca, Sep 20, 2022)
a71e6be  Merge branch 'flax_pipeline' into flax_pipeline_bf16 (pcuenca, Sep 20, 2022)
aa3c010  PNDM: replace control flow with jax functions. (pcuenca, Sep 19, 2022)
d6dbb89  Pass latents shape to scheduler set_timesteps() (pcuenca, Sep 20, 2022)
69b1d7a  Wrap model imports inside availability checks. (pcuenca, Sep 20, 2022)
7091c1d  Merge branch 'flax_pipeline' into flax_pipeline_pndm (pcuenca, Sep 20, 2022)
23f7d73  Optionally return state in from_config. (pcuenca, Sep 20, 2022)
163df38  Do not convert model weights to dtype. (pcuenca, Sep 20, 2022)
039d1d6  Merge branch 'flax_pipeline_bf16' into flax_pipeline_pndm (pcuenca, Sep 20, 2022)
8bc06b0  Re-enable PRK steps with functional implementation. (pcuenca, Sep 20, 2022)
3752bbc  Merge remote-tracking branch 'origin/main' into flax_pipeline_pndm (pcuenca, Sep 21, 2022)
8a9ccf2  Remove left over has_state var. (pcuenca, Sep 21, 2022)
cf6cd7a  make style (pcuenca, Sep 21, 2022)
f974a41  Apply suggestion list -> tuple (pcuenca, Sep 22, 2022)
ce0a327  Apply suggestion list -> tuple (pcuenca, Sep 22, 2022)
7fcbc32  Remove unused comments. (pcuenca, Sep 22, 2022)
cd17c56  Use zeros instead of empty. (pcuenca, Sep 22, 2022)
1 change: 1 addition & 0 deletions src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -56,5 +56,6 @@ class FlaxStableDiffusionPipelineOutput(BaseOutput):
images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: List[bool]

+from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
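With this re-export, the scheduler state type becomes importable next to the pipeline. A small, hypothetical usage sketch (the annotation pattern below is assumed, not part of this diff):

```python
# Hypothetical consumer of the re-exported names; not part of this PR.
from diffusers.pipelines.stable_diffusion import (
    FlaxStableDiffusionPipeline,
    PNDMSchedulerState,
)

def describe(pipeline: FlaxStableDiffusionPipeline, scheduler_state: PNDMSchedulerState) -> None:
    # The state class is now available for type annotations alongside the pipeline.
    print(type(pipeline).__name__, type(scheduler_state).__name__)
```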
@@ -187,7 +187,9 @@ def loop_body(step, args):
latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
return latents, scheduler_state

-scheduler_state = self.scheduler.set_timesteps(params["scheduler"], num_inference_steps=num_inference_steps)
+scheduler_state = self.scheduler.set_timesteps(
+    params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
+)

if debug:
# run with python for loop
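For context, the `debug` flag above chooses between a plain Python loop (eager, easy to step through) and `jax.lax.fori_loop` (jit-friendly). A minimal sketch of that pattern, with a stand-in `loop_body` instead of the pipeline's real UNet-plus-scheduler step:

```python
import jax
import jax.numpy as jnp

def loop_body(step, args):
    # Stand-in for one denoising step; the real pipeline runs the UNet
    # and self.scheduler.step(...) here.
    latents, scheduler_state = args
    latents = latents * 0.99
    return latents, scheduler_state

latents = jnp.zeros((1, 4, 64, 64))
scheduler_state = jnp.array(0)  # placeholder for the real PNDMSchedulerState
num_inference_steps, debug = 50, False

if debug:
    # Python for loop: runs eagerly, convenient for debugging.
    for i in range(num_inference_steps):
        latents, scheduler_state = loop_body(i, (latents, scheduler_state))
else:
    # fori_loop traces loop_body once and executes the loop on-device.
    latents, scheduler_state = jax.lax.fori_loop(
        0, num_inference_steps, loop_body, (latents, scheduler_state)
    )
```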
186 changes: 133 additions & 53 deletions src/diffusers/schedulers/scheduling_pndm_flax.py
@@ -19,6 +19,7 @@
from typing import Optional, Tuple, Union

import flax
+import jax
import jax.numpy as jnp

from ..configuration_utils import ConfigMixin, register_to_config
@@ -155,7 +156,7 @@ def __init__(
def create_state(self):
return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)

-def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) -> PNDMSchedulerState:
+def set_timesteps(self, state: PNDMSchedulerState, shape: Tuple, num_inference_steps: int) -> PNDMSchedulerState:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

@@ -196,8 +197,11 @@ def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) ->

return state.replace(
timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int64),
-    ets=jnp.array([]),
    counter=0,
+    # Reserve space for the state variables
+    cur_model_output=jnp.zeros(shape),
+    cur_sample=jnp.zeros(shape),
+    ets=jnp.zeros((4,) + shape),
)
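`set_timesteps` now pre-allocates every state buffer at its final shape; `ets` holds the last four model outputs. Under `jit`, array shapes must be static, so the previous append-to-empty-array pattern cannot work, while a fixed-size buffer updated with `.at[...].set(...)` can. A minimal sketch of the idea, using a toy state class rather than the real `PNDMSchedulerState`:

```python
import flax
import jax.numpy as jnp

@flax.struct.dataclass
class ToyState:
    ets: jnp.ndarray  # fixed shape (4,) + sample_shape, so it can cross jit boundaries
    counter: int

shape = (1, 4, 8, 8)
state = ToyState(ets=jnp.zeros((4,) + shape), counter=0)

# Functional update: .at[...].set(...) returns a new array; nothing is mutated.
model_output = jnp.ones(shape)
state = state.replace(
    ets=state.ets.at[state.counter % 4].set(model_output),
    counter=state.counter + 1,
)
```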

def step(
@@ -227,22 +231,32 @@ def step(
When returning a tuple, the first element is the sample tensor.

"""
-if state.counter < len(state.prk_timesteps) and not self.config.skip_prk_steps:
-    return self.step_prk(
-        state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict
-    )
+if self.config.skip_prk_steps:
+    prev_sample, state = self.step_plms(
+        state=state, model_output=model_output, timestep=timestep, sample=sample
+    )
else:
-    return self.step_plms(
-        state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict
-    )
+    prev_sample, state = jax.lax.switch(
+        jnp.where(state.counter < len(state.prk_timesteps), 0, 1),
+        (self.step_prk, self.step_plms),
+        # Args to either branch
+        state,
+        model_output,
+        timestep,
+        sample,
+    )

+if not return_dict:
+    return (prev_sample, state)
+
+return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
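`jax.lax.switch` replaces the Python `if` on `state.counter`, which is a traced value inside `jit`: both branches are traced at compile time, and the integer index picks one at run time. A tiny sketch with stand-in branches (not the real step functions):

```python
import jax
import jax.numpy as jnp

def prk_branch(x):  # stand-in for self.step_prk
    return x + 1.0

def plms_branch(x):  # stand-in for self.step_plms
    return x * 2.0

counter, num_prk_steps = jnp.array(5), 8
index = jnp.where(counter < num_prk_steps, 0, 1)  # 0 -> PRK, 1 -> PLMS

# Operands after the branch tuple are passed to whichever branch runs.
y = jax.lax.switch(index, (prk_branch, plms_branch), jnp.array(3.0))  # 4.0 here
```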

def step_prk(
self,
state: PNDMSchedulerState,
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
-return_dict: bool = True,
) -> Union[FlaxSchedulerOutput, Tuple]:
"""
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
Expand All @@ -266,42 +280,53 @@ def step_prk(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)

-diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2
+diff_to_prev = jnp.where(
+    state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2
+)
prev_timestep = timestep - diff_to_prev
timestep = state.prk_timesteps[state.counter // 4 * 4]

-if state.counter % 4 == 0:
-    state = state.replace(
-        cur_model_output=state.cur_model_output + 1 / 6 * model_output,
-        ets=state.ets.append(model_output),
-        cur_sample=sample,
-    )
-elif (self.counter - 1) % 4 == 0:
-    state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
-elif (self.counter - 2) % 4 == 0:
-    state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
-elif (self.counter - 3) % 4 == 0:
-    model_output = state.cur_model_output + 1 / 6 * model_output
-    state = state.replace(cur_model_output=0)
-
-# cur_sample should not be `None`
-cur_sample = state.cur_sample if state.cur_sample is not None else sample
+def remainder_0(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+    return (
+        state.replace(
+            cur_model_output=state.cur_model_output + 1 / 6 * model_output,
+            ets=state.ets.at[ets_at].set(model_output),
+            cur_sample=sample,
+        ),
+        model_output,
+    )
+
+def remainder_1(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+    return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output
+
+def remainder_2(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+    return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output
+
+def remainder_3(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+    model_output = state.cur_model_output + 1 / 6 * model_output
+    return state.replace(cur_model_output=jnp.zeros_like(state.cur_model_output)), model_output
+
+state, model_output = jax.lax.switch(
+    state.counter % 4,
+    (remainder_0, remainder_1, remainder_2, remainder_3),
+    # Args to either branch
+    state,
+    model_output,
+    state.counter // 4,
+)
+
+cur_sample = state.cur_sample
prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
state = state.replace(counter=state.counter + 1)

-if not return_dict:
-    return (prev_sample, state)
-
-return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+return (prev_sample, state)
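Likewise, `diff_to_prev` can no longer use a Python conditional expression, because `state.counter` is abstract while tracing; `jnp.where` evaluates both candidate values and selects one. A small illustration under assumed config values (1000 training timesteps, 50 inference steps):

```python
import jax
import jax.numpy as jnp

num_train_timesteps, num_inference_steps = 1000, 50  # assumed config values

@jax.jit
def diff_to_prev(counter):
    # A Python conditional on a traced value fails inside jit;
    # jnp.where keeps the computation traceable.
    return jnp.where(counter % 2, 0, num_train_timesteps // num_inference_steps // 2)

print(diff_to_prev(jnp.array(4)))  # 10
print(diff_to_prev(jnp.array(5)))  # 0
```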

def step_plms(
self,
state: PNDMSchedulerState,
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
-return_dict: bool = True,
) -> Union[FlaxSchedulerOutput, Tuple]:
"""
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
@@ -334,36 +359,91 @@ def step_plms(
)

prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
+prev_timestep = jnp.where(prev_timestep > 0, prev_timestep, 0)
+
+# Reference:
+# if state.counter != 1:
+#     state.ets.append(model_output)
+# else:
+#     prev_timestep = timestep
+#     timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps
+
+prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep)
+timestep = jnp.where(
+    state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep
+)

-if state.counter != 1:
-    state = state.replace(ets=state.ets.append(model_output))
-else:
-    prev_timestep = timestep
-    timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps

-if len(state.ets) == 1 and state.counter == 0:
-    model_output = model_output
-    state = state.replace(cur_sample=sample)
-elif len(state.ets) == 1 and state.counter == 1:
-    model_output = (model_output + state.ets[-1]) / 2
-    sample = state.cur_sample
-    state = state.replace(cur_sample=None)
-elif len(state.ets) == 2:
-    model_output = (3 * state.ets[-1] - state.ets[-2]) / 2
-elif len(state.ets) == 3:
-    model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12
-else:
-    model_output = (1 / 24) * (
-        55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]
-    )
+# Reference:
+# if len(state.ets) == 1 and state.counter == 0:
+#     model_output = model_output
+#     state.cur_sample = sample
+# elif len(state.ets) == 1 and state.counter == 1:
+#     model_output = (model_output + state.ets[-1]) / 2
+#     sample = state.cur_sample
+#     state.cur_sample = None
+# elif len(state.ets) == 2:
+#     model_output = (3 * state.ets[-1] - state.ets[-2]) / 2
+# elif len(state.ets) == 3:
+#     model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12
+# else:
+#     model_output = (1 / 24) * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4])

+def counter_0(state: PNDMSchedulerState):
+    ets = state.ets.at[0].set(model_output)
+    return state.replace(
+        ets=ets,
+        cur_sample=sample,
+        cur_model_output=jnp.array(model_output, dtype=jnp.float32),
+    )
+
+def counter_1(state: PNDMSchedulerState):
+    return state.replace(
+        cur_model_output=(model_output + state.ets[0]) / 2,
+    )
+
+def counter_2(state: PNDMSchedulerState):
+    ets = state.ets.at[1].set(model_output)
+    return state.replace(
+        ets=ets,
+        cur_model_output=(3 * ets[1] - ets[0]) / 2,
+        cur_sample=sample,
+    )
+
+def counter_3(state: PNDMSchedulerState):
+    ets = state.ets.at[2].set(model_output)
+    return state.replace(
+        ets=ets,
+        cur_model_output=(23 * ets[2] - 16 * ets[1] + 5 * ets[0]) / 12,
+        cur_sample=sample,
+    )
+
+def counter_other(state: PNDMSchedulerState):
+    ets = state.ets.at[3].set(model_output)
+    next_model_output = (1 / 24) * (55 * ets[3] - 59 * ets[2] + 37 * ets[1] - 9 * ets[0])
+
+    ets = ets.at[0].set(ets[1])
+    ets = ets.at[1].set(ets[2])
+    ets = ets.at[2].set(ets[3])
+
+    return state.replace(
+        ets=ets,
+        cur_model_output=next_model_output,
+        cur_sample=sample,
+    )
+
+counter = jnp.clip(state.counter, 0, 4)
+state = jax.lax.switch(
+    counter,
+    [counter_0, counter_1, counter_2, counter_3, counter_other],
+    state,
+)

+sample = state.cur_sample
+model_output = state.cur_model_output
prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
state = state.replace(counter=state.counter + 1)

-if not return_dict:
-    return (prev_sample, state)
-
-return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+return (prev_sample, state)
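The `counter_*` branches apply linear multi-step (Adams-Bashforth) updates of increasing order, matching the reference comment above: (3, -1)/2, then (23, -16, 5)/12, then the steady-state (55, -59, 37, -9)/24 rule. As a quick sanity check on those coefficients (a verification sketch, not PR code), each k-step rule should integrate polynomials of degree below k exactly over a unit step:

```python
import numpy as np

# Weights from the PLMS branches, newest sample first.
rules = {
    2: np.array([3, -1]) / 2,
    3: np.array([23, -16, 5]) / 12,
    4: np.array([55, -59, 37, -9]) / 24,
}

for order, w in rules.items():
    nodes = -np.arange(len(w), dtype=float)  # samples at t = 0, -1, -2, ...
    for p in range(order):
        approx = w @ nodes**p  # multi-step estimate of the integral
        exact = 1.0 / (p + 1)  # exact integral of t**p over [0, 1]
        assert np.isclose(approx, exact), (order, p)
print("Adams-Bashforth coefficients check out")
```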

def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
@@ -379,7 +459,7 @@ def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
# model_output -> e_θ(x_t, t)
# prev_sample -> x_(t−δ)
alpha_prod_t = self.alphas_cumprod[timestep]
-alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+alpha_prod_t_prev = jnp.where(prev_timestep >= 0, self.alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev

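The rest of `_get_prev_sample` (collapsed in the diff) applies formula (9) of the PNDM paper. A reconstruction of that transfer function as a standalone sketch; the variable names follow the snippet above, but this is an illustration, not the verbatim diff:

```python
import jax.numpy as jnp

def get_prev_sample(sample, alpha_prod_t, alpha_prod_t_prev, model_output):
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev

    # x_(t-δ) = (α_(t-δ)/α_t)^0.5 * x_t
    #           - (α_(t-δ) - α_t) * e_θ / (α_t * β_(t-δ)^0.5 + (α_t * β_t * α_(t-δ))^0.5)
    sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** 0.5
    denom = alpha_prod_t * beta_prod_t_prev**0.5 + (alpha_prod_t * beta_prod_t * alpha_prod_t_prev) ** 0.5
    return sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / denom

# Example with plausible cumulative-alpha values:
x = get_prev_sample(jnp.ones((2, 2)), jnp.array(0.5), jnp.array(0.6), jnp.zeros((2, 2)))
```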