From 0aec14e4b1060d1e16de5856e1e44c23de560111 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Mon, 19 Sep 2022 13:27:36 +0200
Subject: [PATCH 1/5] remove match_shape

---
 src/diffusers/schedulers/scheduling_ddim_flax.py   | 11 ++++++++---
 src/diffusers/schedulers/scheduling_ddpm_flax.py   |  9 +++++++--
 .../schedulers/scheduling_lms_discrete_flax.py     |  7 +++++--
 src/diffusers/schedulers/scheduling_pndm_flax.py   |  9 +++++++--
 src/diffusers/schedulers/scheduling_sde_ve_flax.py | 14 ++++++++++----
 5 files changed, 37 insertions(+), 13 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py
index dd3c2ac85d2c..f72efc957b77 100644
--- a/src/diffusers/schedulers/scheduling_ddim_flax.py
+++ b/src/diffusers/schedulers/scheduling_ddim_flax.py
@@ -263,9 +263,14 @@ def add_noise(
         timesteps: jnp.ndarray,
     ) -> jnp.ndarray:
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
-        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
-        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod[:, None]
+
+        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[:, None]
 
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py
index f686a2a32234..9096663016c2 100644
--- a/src/diffusers/schedulers/scheduling_ddpm_flax.py
+++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py
@@ -266,9 +266,14 @@ def add_noise(
         timesteps: jnp.ndarray,
     ) -> jnp.ndarray:
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
-        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod[..., None]
+
         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]
 
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
index 1431bdacf54c..7f4c076b54d1 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
@@ -198,8 +198,11 @@ def add_noise(
         noise: jnp.ndarray,
         timesteps: jnp.ndarray,
     ) -> jnp.ndarray:
-        sigmas = self.match_shape(state.sigmas[timesteps], noise)
-        noisy_samples = original_samples + noise * sigmas
+        sigma = state.sigmas[timesteps].flatten()
+        while len(sigma.shape) < len(noise.shape):
+            sigma = sigma[..., None]
+
+        noisy_samples = original_samples + noise * sigma
         return noisy_samples
diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py
index 8444d6680401..635696002d56 100644
--- a/src/diffusers/schedulers/scheduling_pndm_flax.py
+++ b/src/diffusers/schedulers/scheduling_pndm_flax.py
@@ -395,9 +395,14 @@ def add_noise(
         timesteps: jnp.ndarray,
     ) -> jnp.ndarray:
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
-        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod[..., None]
+
         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]
 
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
diff --git a/src/diffusers/schedulers/scheduling_sde_ve_flax.py b/src/diffusers/schedulers/scheduling_sde_ve_flax.py
index e5860706aa2e..08fbe14732da 100644
--- a/src/diffusers/schedulers/scheduling_sde_ve_flax.py
+++ b/src/diffusers/schedulers/scheduling_sde_ve_flax.py
@@ -192,14 +192,17 @@ def step_pred(
 
         # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
         # also equation 47 shows the analog from SDE models to ancestral sampling methods
-        drift = drift - diffusion[:, None, None, None] ** 2 * model_output
+        diffusion = diffusion.flatten()
+        while len(diffusion.shape) < len(sample.shape):
+            diffusion = diffusion[:, None]
+        drift = drift - diffusion**2 * model_output
 
         # equation 6: sample noise for the diffusion term of
         key = random.split(key, num=1)
         noise = random.normal(key=key, shape=sample.shape)
         prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
 
         # TODO is the variable diffusion the correct scaling term for the noise?
-        prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise  # add impact of diffusion field g
+        prev_sample = prev_sample_mean + diffusion * noise  # add impact of diffusion field g
 
         if not return_dict:
             return (prev_sample, prev_sample_mean, state)
@@ -248,8 +251,11 @@ def step_correct(
             step_size = step_size * jnp.ones(sample.shape[0])
 
         # compute corrected sample: model_output term and noise term
-        prev_sample_mean = sample + step_size[:, None, None, None] * model_output
-        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
+        step_size = step_size.flatten()
+        while len(step_size.shape) < len(sample.shape):
+            step_size = step_size[:, None]
+        prev_sample_mean = sample + step_size * model_output
+        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise
 
         if not return_dict:
             return (prev_sample, state)
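Every hunk in this patch applies the same replacement for the removed `match_shape` helper: gather a per-batch scalar from a 1-D schedule, flatten it, then append singleton axes until it broadcasts against the sample batch. A minimal standalone sketch of that pattern follows; the helper name `broadcast_to_shape_from_left` and the toy shapes are illustrative assumptions, not part of the patch:

import jax.numpy as jnp

def broadcast_to_shape_from_left(values, target):
    # hypothetical helper: reshape (batch,) -> (batch, 1, ..., 1) so that
    # `values` broadcasts against `target`, exactly what the while-loops above do
    values = values.flatten()
    while len(values.shape) < len(target.shape):
        values = values[:, None]
    return values

alphas_cumprod = jnp.linspace(0.9999, 0.02, 1000)  # stand-in noise schedule
timesteps = jnp.array([10, 500, 900])              # one timestep per batch element
samples = jnp.zeros((3, 4, 64, 64))                # (batch, channels, height, width)

sqrt_alpha_prod = broadcast_to_shape_from_left(alphas_cumprod[timesteps] ** 0.5, samples)
assert sqrt_alpha_prod.shape == (3, 1, 1, 1)       # multiplies cleanly with `samples`

`[:, None]` and `[..., None]` are interchangeable here: on a flattened 1-D array both append one trailing axis per loop iteration.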
""" + offset = self.config.steps_offset + if "offset" in kwargs: + warnings.warn( + "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0." + " Please pass `steps_offset` to `__init__` instead.", + DeprecationWarning, + ) + + offset = kwargs["offset"] + step_ratio = self.config.num_train_timesteps // num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 635696002d56..6cb882bd86db 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math - # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +import math +import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -59,7 +60,6 @@ class PNDMSchedulerState: # setable values _timesteps: jnp.ndarray num_inference_steps: Optional[int] = None - _offset: int = 0 prk_timesteps: Optional[jnp.ndarray] = None plms_timesteps: Optional[jnp.ndarray] = None timesteps: Optional[jnp.ndarray] = None @@ -104,6 +104,14 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): skip_prk_steps (`bool`): allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required before plms steps; defaults to `False`. + set_alpha_to_one (`bool`, default `False`): + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ @register_to_config @@ -115,6 +123,8 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[jnp.ndarray] = None, skip_prk_steps: bool = False, + set_alpha_to_one: bool = False, + steps_offset: int = 0, ): if trained_betas is not None: self.betas = jnp.asarray(trained_betas) @@ -132,6 +142,8 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0) + self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + # For now we only support F-PNDM, i.e. the runge-kutta method # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf # mainly at formula (9), (12), (13) and the Algorithm 2. @@ -139,9 +151,7 @@ def __init__( self.state = PNDMSchedulerState.create(num_train_timesteps=num_train_timesteps) - def set_timesteps( - self, state: PNDMSchedulerState, num_inference_steps: int, offset: int = 0 - ) -> PNDMSchedulerState: + def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, **kwargs) -> PNDMSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -150,16 +160,23 @@ def set_timesteps( the `FlaxPNDMScheduler` state data class instance. 
             num_inference_steps (`int`):
                 the number of diffusion steps used when generating samples with a pre-trained model.
-            offset (`int`):
-                optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
         """
+        offset = self.config.steps_offset
+
+        if "offset" in kwargs:
+            warnings.warn(
+                "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
+                " Please pass `steps_offset` to `__init__` instead."
+            )
+
+            offset = kwargs["offset"]
+
         step_ratio = self.config.num_train_timesteps // num_inference_steps
         # creates integer timesteps by multiplying by ratio
         # rounding to avoid issues when num_inference_step is power of 3
-        _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1]
-        _timesteps = _timesteps + offset
+        _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + offset
 
-        state = state.replace(num_inference_steps=num_inference_steps, _offset=offset, _timesteps=_timesteps)
+        state = state.replace(num_inference_steps=num_inference_steps, _timesteps=_timesteps)
 
         if self.config.skip_prk_steps:
             # for some models like stable diffusion the prk steps can/should be skipped to
@@ -254,7 +271,7 @@ def step_prk(
         )
 
         diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2
-        prev_timestep = max(timestep - diff_to_prev, state.prk_timesteps[-1])
+        prev_timestep = timestep - diff_to_prev
         timestep = state.prk_timesteps[state.counter // 4 * 4]
 
         if state.counter % 4 == 0:
@@ -320,7 +337,7 @@ def step_plms(
                 "for more information."
             )
 
-        prev_timestep = max(timestep - self.config.num_train_timesteps // state.num_inference_steps, 0)
+        prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
 
         if state.counter != 1:
             state = state.replace(ets=state.ets.append(model_output))
@@ -352,7 +369,7 @@ def step_plms(
 
         return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
 
-    def _get_prev_sample(self, sample, timestep, timestep_prev, model_output, state):
+    def _get_prev_sample(self, sample, timestep, prev_timestep, model_output, state):
         # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
         # this function computes x_(t−δ) using the formula of (9)
         # Note that x_t needs to be added to both sides of the equation
@@ -365,8 +382,8 @@ def _get_prev_sample(self, sample, timestep, prev_timestep, model_output, state)
         # sample -> x_t
         # model_output -> e_θ(x_t, t)
         # prev_sample -> x_(t−δ)
-        alpha_prod_t = self.alphas_cumprod[timestep + 1 - state._offset]
-        alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - state._offset]
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
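The net effect of this patch is that the timestep grid and the final-step alpha lookup become scheduler configuration rather than call-site arguments. A small worked example of the arithmetic, assuming the stable-diffusion-style settings named in the new docstrings (1000 training steps, 50 inference steps, `steps_offset=1`, `set_alpha_to_one=False`):

import jax.numpy as jnp

num_train_timesteps = 1000
num_inference_steps = 50
steps_offset = 1

step_ratio = num_train_timesteps // num_inference_steps  # 20
timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + steps_offset
# timesteps = [1, 21, 41, ..., 981]; sampling walks them in reverse order

# at the final denoising step (timestep 1), stepping back by step_ratio
# yields a negative previous timestep ...
prev_timestep = 1 - step_ratio
assert prev_timestep == -19
# ... which is why `_get_prev_sample` now falls back to `final_alpha_cumprod`
# (alphas_cumprod[0] when `set_alpha_to_one=False`) instead of indexing
# `alphas_cumprod` with a negative timestep, replacing the old `max(..., 0)` clamps.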
From 5779bb376ff7b693bb5a7d5dec66bae89408bb8e Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Mon, 19 Sep 2022 13:45:33 +0200
Subject: [PATCH 3/5] remove unused argument

---
 src/diffusers/schedulers/scheduling_pndm_flax.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py
index 6cb882bd86db..343b2ea1e1b9 100644
--- a/src/diffusers/schedulers/scheduling_pndm_flax.py
+++ b/src/diffusers/schedulers/scheduling_pndm_flax.py
@@ -361,7 +361,7 @@ def step_plms(
                 55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]
             )
 
-        prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output, state=state)
+        prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
         state = state.replace(counter=state.counter + 1)
 
         if not return_dict:
@@ -369,7 +369,7 @@ def step_plms(
 
         return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
 
-    def _get_prev_sample(self, sample, timestep, prev_timestep, model_output, state):
+    def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
         # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
         # this function computes x_(t−δ) using the formula of (9)
         # Note that x_t needs to be added to both sides of the equation

From 6846095447332531e96ee45c3d8b0393758cbfe1 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Mon, 19 Sep 2022 13:46:08 +0200
Subject: [PATCH 4/5] typo

---
 src/diffusers/schedulers/scheduling_pndm_flax.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py
index 343b2ea1e1b9..3f8bffbc1448 100644
--- a/src/diffusers/schedulers/scheduling_pndm_flax.py
+++ b/src/diffusers/schedulers/scheduling_pndm_flax.py
@@ -291,7 +291,7 @@ def step_prk(
 
         # cur_sample should not be `None`
         cur_sample = state.cur_sample if state.cur_sample is not None else sample
-        prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output, state=state)
+        prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
         state = state.replace(counter=state.counter + 1)
 
         if not return_dict:

From 9fe940439f5430e8f32c471cdee569ad76fa396e Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Mon, 19 Sep 2022 14:33:24 +0200
Subject: [PATCH 5/5] remove warnings

---
 src/diffusers/schedulers/scheduling_ddim_flax.py | 11 +----------
 src/diffusers/schedulers/scheduling_pndm_flax.py | 11 +----------
 2 files changed, 2 insertions(+), 20 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py
index 828f5ae4f40b..015b79b2780d 100644
--- a/src/diffusers/schedulers/scheduling_ddim_flax.py
+++ b/src/diffusers/schedulers/scheduling_ddim_flax.py
@@ -16,7 +16,6 @@
 # and https://github.com/hojonathanho/diffusion
 
 import math
-import warnings
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union
 
@@ -152,7 +151,7 @@ def _get_variance(self, timestep, prev_timestep):
 
         return variance
 
-    def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int, **kwargs) -> DDIMSchedulerState:
+    def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int) -> DDIMSchedulerState:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
 
@@ -163,14 +162,6 @@ def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int, **k
                 the number of diffusion steps used when generating samples with a pre-trained model.
         """
         offset = self.config.steps_offset
-        if "offset" in kwargs:
-            warnings.warn(
-                "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
- " Please pass `steps_offset` to `__init__` instead.", - DeprecationWarning, - ) - - offset = kwargs["offset"] step_ratio = self.config.num_train_timesteps // num_inference_steps # creates integer timesteps by multiplying by ratio diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 3f8bffbc1448..efc3858ca75a 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -15,7 +15,6 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim import math -import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -151,7 +150,7 @@ def __init__( self.state = PNDMSchedulerState.create(num_train_timesteps=num_train_timesteps) - def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, **kwargs) -> PNDMSchedulerState: + def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) -> PNDMSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -163,14 +162,6 @@ def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, **k """ offset = self.config.steps_offset - if "offset" in kwargs: - warnings.warn( - "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0." - " Please pass `steps_offset` to `__init__` instead." - ) - - offset = kwargs["offset"] - step_ratio = self.config.num_train_timesteps // num_inference_steps # creates integer timesteps by multiplying by ratio # rounding to avoid issues when num_inference_step is power of 3