28 changes: 19 additions & 9 deletions src/diffusers/schedulers/scheduling_ddim_flax.py
@@ -96,7 +96,13 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
set_alpha_to_one (`bool`, default `True`):
if alpha for final step is 1 or the final alpha of the "non-previous" one.
each diffusion step uses the value of the alphas product at that step and at the previous one. For the final
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the value of alpha at step 0.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
"""

@register_to_config
@@ -109,6 +115,7 @@ def __init__(
trained_betas: Optional[jnp.ndarray] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
@@ -144,9 +151,7 @@ def _get_variance(self, timestep, prev_timestep):

return variance

def set_timesteps(
self, state: DDIMSchedulerState, num_inference_steps: int, offset: int = 0
) -> DDIMSchedulerState:
def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int) -> DDIMSchedulerState:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

@@ -155,9 +160,9 @@ def set_timesteps(
the `FlaxDDIMScheduler` state data class instance.
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
offset (`int`):
optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
"""
offset = self.config.steps_offset

step_ratio = self.config.num_train_timesteps // num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_steps is a power of 3
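As a concrete illustration of how `steps_offset` shifts the timestep grid, here is a minimal sketch with assumed step counts (not code from this diff), mirroring the construction in `set_timesteps`:

```python
import jax.numpy as jnp

# Assumed configuration, for illustration only.
num_train_timesteps, num_inference_steps, steps_offset = 1000, 50, 1
step_ratio = num_train_timesteps // num_inference_steps  # 20

# Evenly spaced integer timesteps, shifted up by the offset and walked
# in descending order.
timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] + steps_offset
print(timesteps[:3], timesteps[-1])  # [981 961 941] ... 1
```

With `steps_offset=1` the grid ends at timestep 1 rather than 0, which is what makes the `set_alpha_to_one=False` fallback to the step-0 alpha product reachable.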
@@ -263,9 +268,14 @@ def add_noise(
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod[:, None]

sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[:, None]

noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
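The recurring change in these `add_noise` implementations replaces the old `match_shape` helper with an explicit flatten-then-expand broadcast that works for any sample rank. A self-contained sketch of the idiom, with a hypothetical helper name and assumed shapes:

```python
import jax.numpy as jnp

def expand_like(values: jnp.ndarray, target: jnp.ndarray) -> jnp.ndarray:
    # Flatten to one value per sample, then append singleton axes until the
    # rank matches, so the coefficients broadcast over trailing dimensions.
    values = values.flatten()
    while len(values.shape) < len(target.shape):
        values = values[:, None]
    return values

batch = jnp.zeros((4, 3, 64, 64))        # e.g. a batch of NCHW latents
coeff = jnp.array([0.1, 0.2, 0.3, 0.4])  # one coefficient per sample
print(expand_like(coeff, batch).shape)   # (4, 1, 1, 1)
```

Unlike a hardcoded `[:, None, None, None]`, this covers 2-D audio or 5-D video samples as well as 4-D images.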
9 changes: 7 additions & 2 deletions src/diffusers/schedulers/scheduling_ddpm_flax.py
@@ -266,9 +266,14 @@ def add_noise(
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod[..., None]

sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]

noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
7 changes: 5 additions & 2 deletions src/diffusers/schedulers/scheduling_lms_discrete_flax.py
@@ -198,8 +198,11 @@ def add_noise(
noise: jnp.ndarray,
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sigmas = self.match_shape(state.sigmas[timesteps], noise)
noisy_samples = original_samples + noise * sigmas
sigma = state.sigmas[timesteps].flatten()
while len(sigma.shape) < len(noise.shape):
sigma = sigma[..., None]

noisy_samples = original_samples + noise * sigma

return noisy_samples
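Note that the LMS scheduler uses the variance-exploding formulation: the sample is left unscaled and the noise is multiplied by `sigma`, in contrast to the `sqrt_alpha_prod` weighting in the DDIM/DDPM `add_noise` above. A toy side-by-side with assumed values:

```python
import jax.numpy as jnp

x0 = jnp.ones((2, 3, 8, 8))
noise = jnp.zeros_like(x0)  # stand-in for Gaussian noise

# Variance-preserving (DDPM/DDIM): interpolate between signal and noise.
alpha_prod = 0.9
noisy_vp = alpha_prod**0.5 * x0 + (1 - alpha_prod) ** 0.5 * noise

# Variance-exploding (LMS / score SDE): add sigma-scaled noise to the raw sample.
sigma = 1.5
noisy_ve = x0 + sigma * noise
```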

53 changes: 33 additions & 20 deletions src/diffusers/schedulers/scheduling_pndm_flax.py
@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math

# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

@@ -59,7 +59,6 @@ class PNDMSchedulerState:
# setable values
_timesteps: jnp.ndarray
num_inference_steps: Optional[int] = None
_offset: int = 0
prk_timesteps: Optional[jnp.ndarray] = None
plms_timesteps: Optional[jnp.ndarray] = None
timesteps: Optional[jnp.ndarray] = None
@@ -104,6 +103,14 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
skip_prk_steps (`bool`):
allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
before plms steps; defaults to `False`.
set_alpha_to_one (`bool`, default `False`):
each diffusion step uses the value of the alphas product at that step and at the previous one. For the final
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the value of alpha at step 0.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
"""

@register_to_config
@@ -115,6 +122,8 @@ def __init__(
beta_schedule: str = "linear",
trained_betas: Optional[jnp.ndarray] = None,
skip_prk_steps: bool = False,
set_alpha_to_one: bool = False,
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
@@ -132,16 +141,16 @@ def __init__(
self.alphas = 1.0 - self.betas
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)

self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

# For now we only support F-PNDM, i.e. the runge-kutta method
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
# mainly at formula (9), (12), (13) and the Algorithm 2.
self.pndm_order = 4

self.state = PNDMSchedulerState.create(num_train_timesteps=num_train_timesteps)

def set_timesteps(
self, state: PNDMSchedulerState, num_inference_steps: int, offset: int = 0
) -> PNDMSchedulerState:
def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) -> PNDMSchedulerState:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

@@ -150,16 +159,15 @@ def set_timesteps(
the `FlaxPNDMScheduler` state data class instance.
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
offset (`int`):
optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
"""
offset = self.config.steps_offset

step_ratio = self.config.num_train_timesteps // num_inference_steps
# creates integer timesteps by multiplying by ratio
# rounding to avoid issues when num_inference_steps is a power of 3
_timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1]
_timesteps = _timesteps + offset
_timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + offset

state = state.replace(num_inference_steps=num_inference_steps, _offset=offset, _timesteps=_timesteps)
state = state.replace(num_inference_steps=num_inference_steps, _timesteps=_timesteps)

if self.config.skip_prk_steps:
# for some models like stable diffusion the prk steps can/should be skipped to
@@ -254,7 +262,7 @@ def step_prk(
)

diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2
prev_timestep = max(timestep - diff_to_prev, state.prk_timesteps[-1])
prev_timestep = timestep - diff_to_prev
timestep = state.prk_timesteps[state.counter // 4 * 4]

if state.counter % 4 == 0:
@@ -274,7 +282,7 @@ def step_prk(
# cur_sample should not be `None`
cur_sample = state.cur_sample if state.cur_sample is not None else sample

prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output, state=state)
prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
state = state.replace(counter=state.counter + 1)

if not return_dict:
@@ -320,7 +328,7 @@ def step_plms(
"for more information."
)

prev_timestep = max(timestep - self.config.num_train_timesteps // state.num_inference_steps, 0)
prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps

if state.counter != 1:
state = state.replace(ets=state.ets.append(model_output))
@@ -344,15 +352,15 @@ def step_plms(
55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]
)

prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output, state=state)
prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
state = state.replace(counter=state.counter + 1)

if not return_dict:
return (prev_sample, state)

return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
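The `55, -59, 37, -9` weights in the four-output branch above (divided by 24 in the full expression) are the classic fourth-order Adams-Bashforth coefficients used by PLMS. A minimal sketch of that combination, with assumed history values:

```python
import jax.numpy as jnp

# ets holds the four most recent model outputs, newest last.
ets = [jnp.full((1,), v) for v in (0.1, 0.2, 0.3, 0.4)]

# Fourth-order Adams-Bashforth linear multistep combination, as in step_plms.
model_output = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
```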

def _get_prev_sample(self, sample, timestep, timestep_prev, model_output, state):
def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
# this function computes x_(t−δ) using the formula of (9)
# Note that x_t needs to be added to both sides of the equation
@@ -365,8 +373,8 @@ def _get_prev_sample(self, sample, timestep, timestep_prev, model_output, state)
# sample -> x_t
# model_output -> e_θ(x_t, t)
# prev_sample -> x_(t−δ)
alpha_prod_t = self.alphas_cumprod[timestep + 1 - state._offset]
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - state._offset]
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
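With `_offset` gone from the state, `alphas_cumprod` is indexed by the raw timestep, and a negative `prev_timestep` on the final step falls back to `final_alpha_cumprod`. The explicit guard matters because JAX silently wraps negative indices rather than raising. A boundary-case sketch with a toy schedule:

```python
import jax.numpy as jnp

betas = jnp.linspace(1e-4, 2e-2, 1000)   # toy linear schedule
alphas_cumprod = jnp.cumprod(1.0 - betas)
final_alpha_cumprod = alphas_cumprod[0]  # set_alpha_to_one=False

# Last step with steps_offset=1 and 50 inference steps:
timestep = 1
prev_timestep = timestep - 1000 // 50    # -19, no longer clamped to 0
alpha_prod_t_prev = (
    alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alpha_cumprod
)
```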

@@ -395,9 +403,14 @@ def add_noise(
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod[..., None]

sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]

noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
14 changes: 10 additions & 4 deletions src/diffusers/schedulers/scheduling_sde_ve_flax.py
@@ -192,14 +192,17 @@ def step_pred(

# equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
# also equation 47 shows the analog from SDE models to ancestral sampling methods
drift = drift - diffusion[:, None, None, None] ** 2 * model_output
diffusion = diffusion.flatten()
while len(diffusion.shape) < len(sample.shape):
diffusion = diffusion[:, None]
drift = drift - diffusion**2 * model_output

# equation 6: sample noise for the diffusion term of the SDE
key = random.split(key, num=1)
noise = random.normal(key=key, shape=sample.shape)
prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep
# TODO is the variable diffusion the correct scaling term for the noise?
prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g
prev_sample = prev_sample_mean + diffusion * noise # add impact of diffusion field g

if not return_dict:
return (prev_sample, prev_sample_mean, state)
@@ -248,8 +251,11 @@ def step_correct(
step_size = step_size * jnp.ones(sample.shape[0])

# compute corrected sample: model_output term and noise term
prev_sample_mean = sample + step_size[:, None, None, None] * model_output
prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
step_size = step_size.flatten()
while len(step_size.shape) < len(sample.shape):
step_size = step_size[:, None]
prev_sample_mean = sample + step_size * model_output
prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise

if not return_dict:
return (prev_sample, state)
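The corrector update above is a standard Langevin step: move along the score by `step_size` and inject noise with variance `2 * step_size`. A self-contained sketch with assumed shapes and values (not code from this diff):

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
sample = jnp.zeros((2, 3, 8, 8))
score = jnp.ones_like(sample)     # stand-in for model_output = grad log p(x)
step_size = jnp.full((2,), 0.01)  # one step size per sample

# Rank-agnostic broadcast of the per-sample step size, then the Langevin
# update x <- x + eps * score + sqrt(2 * eps) * z, with z ~ N(0, I).
step_size = step_size.flatten()
while len(step_size.shape) < len(sample.shape):
    step_size = step_size[:, None]
noise = random.normal(key, sample.shape)
prev_sample = sample + step_size * score + (step_size * 2) ** 0.5 * noise
```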