247 changes: 158 additions & 89 deletions src/diffusers/schedulers/scheduling_pndm_flax.py
@@ -19,6 +19,7 @@
from typing import Optional, Tuple, Union

import flax
import jax
import jax.numpy as jnp

from ..configuration_utils import ConfigMixin, register_to_config
@@ -150,7 +151,12 @@ def __init__(

self.state = PNDMSchedulerState.create(num_train_timesteps=num_train_timesteps)

def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) -> PNDMSchedulerState:
def set_timesteps(
self,
state: PNDMSchedulerState,
shape: Tuple,
num_inference_steps: int
) -> PNDMSchedulerState:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

@@ -191,8 +197,11 @@ def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) ->

return state.replace(
timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int64),
ets=jnp.array([]),
counter=0,
# jnp.empty always zero-fills under XLA, so these buffers start as zeros rather than uninitialized memory
cur_model_output=jnp.empty(shape),
cur_sample=jnp.empty(shape),
ets=jnp.empty((4,) + shape),
)
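# Pre-allocating fixed-shape buffers is what keeps the state `jax.jit`-friendly.
# A minimal usage sketch (argument values are illustrative, not part of this diff):
#
#     scheduler = FlaxPNDMScheduler()
#     state = scheduler.set_timesteps(
#         scheduler.state, shape=(1, 4, 64, 64), num_inference_steps=50
#     )
#     # state.ets now has static shape (4, 1, 4, 64, 64), so every later
#     # `step` call traces with the same buffer shapes.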

def step(
@@ -222,73 +231,77 @@ def step(
When returning a tuple, the first element is the sample tensor.

"""
if state.counter < len(state.prk_timesteps) and not self.config.skip_prk_steps:
return self.step_prk(
state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict
)
else:
return self.step_plms(
state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict
)

def step_prk(
self,
state: PNDMSchedulerState,
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
return_dict: bool = True,
) -> Union[FlaxSchedulerOutput, Tuple]:
"""
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
solution to the differential equation.

Args:
state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
model_output (`jnp.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

Returns:
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.

"""
if state.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)

diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2
prev_timestep = timestep - diff_to_prev
timestep = state.prk_timesteps[state.counter // 4 * 4]

if state.counter % 4 == 0:
state = state.replace(
cur_model_output=state.cur_model_output + 1 / 6 * model_output,
ets=state.ets.append(model_output),
cur_sample=sample,
)
elif (self.counter - 1) % 4 == 0:
state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
elif (self.counter - 2) % 4 == 0:
state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
elif (self.counter - 3) % 4 == 0:
model_output = state.cur_model_output + 1 / 6 * model_output
state = state.replace(cur_model_output=0)

# cur_sample should not be `None`
cur_sample = state.cur_sample if state.cur_sample is not None else sample

prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
state = state.replace(counter=state.counter + 1)

if not return_dict:
return (prev_sample, state)
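# PRK warm-up steps are not supported in this jittable version: `step` now
# always delegates to `step_plms`, which effectively assumes
# `skip_prk_steps=True`. The PRK implementation is kept below, commented out,
# for reference.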
return self.step_plms(
state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict
)

return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
# if state.counter < len(state.prk_timesteps) and not self.config.skip_prk_steps:
# return self.step_prk(
# state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict
# )
# else:
# return self.step_plms(
# state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict
# )

# def step_prk(
# self,
# state: PNDMSchedulerState,
# model_output: jnp.ndarray,
# timestep: int,
# sample: jnp.ndarray,
# return_dict: bool = True,
# ) -> Union[FlaxSchedulerOutput, Tuple]:
# """
# Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
# solution to the differential equation.

# Args:
# state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
# model_output (`jnp.ndarray`): direct output from learned diffusion model.
# timestep (`int`): current discrete timestep in the diffusion chain.
# sample (`jnp.ndarray`):
# current instance of sample being created by diffusion process.
# return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

# Returns:
# [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
# When returning a tuple, the first element is the sample tensor.

# """
# if state.num_inference_steps is None:
# raise ValueError(
# "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
# )

# diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2
# prev_timestep = timestep - diff_to_prev
# timestep = state.prk_timesteps[state.counter // 4 * 4]

# if state.counter % 4 == 0:
# state = state.replace(
# cur_model_output=state.cur_model_output + 1 / 6 * model_output,
# ets=state.ets.append(model_output),
# cur_sample=sample,
# )
# elif (state.counter - 1) % 4 == 0:
# state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
# elif (state.counter - 2) % 4 == 0:
# state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
# elif (state.counter - 3) % 4 == 0:
# model_output = state.cur_model_output + 1 / 6 * model_output
# state = state.replace(cur_model_output=0)

# # cur_sample should not be `None`
# cur_sample = state.cur_sample if state.cur_sample is not None else sample

# prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
# state = state.replace(counter=state.counter + 1)

# if not return_dict:
# return (prev_sample, state)

# return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)

def step_plms(
self,
@@ -329,29 +342,85 @@ def step_plms(
)

prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
prev_timestep = jnp.where(prev_timestep > 0, prev_timestep, 0)

# Reference:
# if state.counter != 1:
# state.ets.append(model_output)
# else:
# prev_timestep = timestep
# timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps

prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep)
timestep = jnp.where(
state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep
)
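# `jnp.where` computes both operands and then selects, so a traced scalar like
# `state.counter` can steer the update without Python control flow, which would
# fail under `jax.jit`. Illustrative sketch, not part of the scheduler:
#
#     counter = jnp.asarray(1)
#     t = jnp.asarray(10)
#     t = jnp.where(counter == 1, t + 5, t)  # -> 15, no Python `if` on a tracer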

# Reference:
# if len(state.ets) == 1 and state.counter == 0:
# model_output = model_output
# state.cur_sample = sample
# elif len(state.ets) == 1 and state.counter == 1:
# model_output = (model_output + state.ets[-1]) / 2
# sample = state.cur_sample
# state.cur_sample = None
# elif len(state.ets) == 2:
# model_output = (3 * state.ets[-1] - state.ets[-2]) / 2
# elif len(state.ets) == 3:
# model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12
# else:
# model_output = (1 / 24) * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4])

def counter_0(state: PNDMSchedulerState):
ets = state.ets.at[0].set(model_output)
return state.replace(
ets=ets,
cur_sample=sample,
cur_model_output=jnp.array(model_output, dtype=jnp.float32),
)
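# JAX arrays are immutable, so `ets.at[0].set(...)` above returns an updated
# copy rather than writing in place. Illustrative sketch:
#
#     a = jnp.zeros(3)
#     b = a.at[0].set(1.0)  # a is unchanged; b == [1., 0., 0.]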

if state.counter != 1:
state = state.replace(ets=state.ets.append(model_output))
else:
prev_timestep = timestep
timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps

if len(state.ets) == 1 and state.counter == 0:
model_output = model_output
state = state.replace(cur_sample=sample)
elif len(state.ets) == 1 and state.counter == 1:
model_output = (model_output + state.ets[-1]) / 2
sample = state.cur_sample
state = state.replace(cur_sample=None)
elif len(state.ets) == 2:
model_output = (3 * state.ets[-1] - state.ets[-2]) / 2
elif len(state.ets) == 3:
model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12
else:
model_output = (1 / 24) * (
55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]
)

def counter_1(state: PNDMSchedulerState):
return state.replace(
cur_model_output=(model_output + state.ets[0]) / 2,
)

def counter_2(state: PNDMSchedulerState):
ets = state.ets.at[1].set(model_output)
return state.replace(
ets=ets,
cur_model_output=(3 * ets[1] - ets[0]) / 2,
cur_sample=sample,
)

def counter_3(state: PNDMSchedulerState):
ets = state.ets.at[2].set(model_output)
return state.replace(
ets=ets,
cur_model_output=(23 * ets[2] - 16 * ets[1] + 5 * ets[0]) / 12,
cur_sample=sample,
)

def counter_other(state: PNDMSchedulerState):
ets = state.ets.at[3].set(model_output)
next_model_output = (1 / 24) * (55 * ets[3] - 59 * ets[2] + 37 * ets[1] - 9 * ets[0])

ets = ets.at[0].set(ets[1])
ets = ets.at[1].set(ets[2])
ets = ets.at[2].set(ets[3])

return state.replace(
ets=ets,
cur_model_output=next_model_output,
cur_sample=sample,
)
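# The three `.at[...].set(...)` calls shift the window left so the next step
# writes the newest output into slot 3 again. `ets = jnp.roll(ets, -1, axis=0)`
# would behave the same here: slot 3 then holds a stale value either way, and
# it is overwritten before it is read on the following step.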

counter = jnp.clip(state.counter, 0, 4)
state = jax.lax.switch(
counter,
[counter_0, counter_1, counter_2, counter_3, counter_other],
state,
)
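# `jax.lax.switch(index, branches, operand)` traces every branch once and
# selects one at runtime, which is what keeps `step_plms` compatible with
# `jax.jit`; clipping the counter to 4 routes every step after the warm-up
# to `counter_other`. A minimal standalone sketch (illustrative only):
#
#     def double(x):
#         return 2 * x
#
#     def square(x):
#         return x * x
#
#     idx = jnp.clip(jnp.asarray(3), 0, 1)  # out-of-range index clipped to the last branch
#     y = jax.lax.switch(idx, [double, square], jnp.float32(3.0))  # -> 9.0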

sample = state.cur_sample
model_output = state.cur_model_output
prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
state = state.replace(counter=state.counter + 1)

@@ -374,7 +443,7 @@ def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
# model_output -> e_θ(x_t, t)
# prev_sample -> x_(t−δ)
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
alpha_prod_t_prev = jnp.where(prev_timestep >= 0, self.alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
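# `jnp.where` evaluates both operands, so the gather `self.alphas_cumprod[prev_timestep]`
# runs even when `prev_timestep < 0`; its result is discarded in that case, which is
# why the negative index is harmless here.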
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
