1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- import math
16-
1715# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
16+
17+ import math
1818from dataclasses import dataclass
1919from typing import Optional , Tuple , Union
2020
@@ -59,7 +59,6 @@ class PNDMSchedulerState:
5959 # setable values
6060 _timesteps : jnp .ndarray
6161 num_inference_steps : Optional [int ] = None
62- _offset : int = 0
6362 prk_timesteps : Optional [jnp .ndarray ] = None
6463 plms_timesteps : Optional [jnp .ndarray ] = None
6564 timesteps : Optional [jnp .ndarray ] = None
@@ -104,6 +103,14 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
104103 skip_prk_steps (`bool`):
105104 allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
106105 before plms steps; defaults to `False`.
106+ set_alpha_to_one (`bool`, default `False`):
107+ each diffusion step uses the value of alphas product at that step and at the previous one. For the final
108+ step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
109+ otherwise it uses the value of alpha at step 0.
110+ steps_offset (`int`, default `0`):
111+ an offset added to the inference steps. You can use a combination of `offset=1` and
112+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
113+ stable diffusion.
107114 """
108115
109116 @register_to_config
@@ -115,6 +122,8 @@ def __init__(
115122 beta_schedule : str = "linear" ,
116123 trained_betas : Optional [jnp .ndarray ] = None ,
117124 skip_prk_steps : bool = False ,
125+ set_alpha_to_one : bool = False ,
126+ steps_offset : int = 0 ,
118127 ):
119128 if trained_betas is not None :
120129 self .betas = jnp .asarray (trained_betas )
@@ -132,16 +141,16 @@ def __init__(
132141 self .alphas = 1.0 - self .betas
133142 self .alphas_cumprod = jnp .cumprod (self .alphas , axis = 0 )
134143
144+ self .final_alpha_cumprod = jnp .array (1.0 ) if set_alpha_to_one else self .alphas_cumprod [0 ]
145+
135146 # For now we only support F-PNDM, i.e. the runge-kutta method
136147 # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
137148 # mainly at formula (9), (12), (13) and the Algorithm 2.
138149 self .pndm_order = 4
139150
140151 self .state = PNDMSchedulerState .create (num_train_timesteps = num_train_timesteps )
141152
142- def set_timesteps (
143- self , state : PNDMSchedulerState , num_inference_steps : int , offset : int = 0
144- ) -> PNDMSchedulerState :
153+ def set_timesteps (self , state : PNDMSchedulerState , num_inference_steps : int ) -> PNDMSchedulerState :
145154 """
146155 Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
147156
@@ -150,16 +159,15 @@ def set_timesteps(
150159 the `FlaxPNDMScheduler` state data class instance.
151160 num_inference_steps (`int`):
152161 the number of diffusion steps used when generating samples with a pre-trained model.
153- offset (`int`):
154- optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
155162 """
163+ offset = self .config .steps_offset
164+
156165 step_ratio = self .config .num_train_timesteps // num_inference_steps
157166 # creates integer timesteps by multiplying by ratio
158167 # rounding to avoid issues when num_inference_step is power of 3
159- _timesteps = (jnp .arange (0 , num_inference_steps ) * step_ratio ).round ()[::- 1 ]
160- _timesteps = _timesteps + offset
168+ _timesteps = (jnp .arange (0 , num_inference_steps ) * step_ratio ).round () + offset
161169
162- state = state .replace (num_inference_steps = num_inference_steps , _offset = offset , _timesteps = _timesteps )
170+ state = state .replace (num_inference_steps = num_inference_steps , _timesteps = _timesteps )
163171
164172 if self .config .skip_prk_steps :
165173 # for some models like stable diffusion the prk steps can/should be skipped to
@@ -254,7 +262,7 @@ def step_prk(
254262 )
255263
256264 diff_to_prev = 0 if state .counter % 2 else self .config .num_train_timesteps // state .num_inference_steps // 2
257- prev_timestep = max ( timestep - diff_to_prev , state . prk_timesteps [ - 1 ])
265+ prev_timestep = timestep - diff_to_prev
258266 timestep = state .prk_timesteps [state .counter // 4 * 4 ]
259267
260268 if state .counter % 4 == 0 :
@@ -274,7 +282,7 @@ def step_prk(
274282 # cur_sample should not be `None`
275283 cur_sample = state .cur_sample if state .cur_sample is not None else sample
276284
277- prev_sample = self ._get_prev_sample (cur_sample , timestep , prev_timestep , model_output , state = state )
285+ prev_sample = self ._get_prev_sample (cur_sample , timestep , prev_timestep , model_output )
278286 state = state .replace (counter = state .counter + 1 )
279287
280288 if not return_dict :
@@ -320,7 +328,7 @@ def step_plms(
320328 "for more information."
321329 )
322330
323- prev_timestep = max ( timestep - self .config .num_train_timesteps // state .num_inference_steps , 0 )
331+ prev_timestep = timestep - self .config .num_train_timesteps // state .num_inference_steps
324332
325333 if state .counter != 1 :
326334 state = state .replace (ets = state .ets .append (model_output ))
@@ -344,15 +352,15 @@ def step_plms(
344352 55 * state .ets [- 1 ] - 59 * state .ets [- 2 ] + 37 * state .ets [- 3 ] - 9 * state .ets [- 4 ]
345353 )
346354
347- prev_sample = self ._get_prev_sample (sample , timestep , prev_timestep , model_output , state = state )
355+ prev_sample = self ._get_prev_sample (sample , timestep , prev_timestep , model_output )
348356 state = state .replace (counter = state .counter + 1 )
349357
350358 if not return_dict :
351359 return (prev_sample , state )
352360
353361 return FlaxSchedulerOutput (prev_sample = prev_sample , state = state )
354362
355- def _get_prev_sample (self , sample , timestep , timestep_prev , model_output , state ):
363+ def _get_prev_sample (self , sample , timestep , prev_timestep , model_output ):
356364 # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
357365 # this function computes x_(t−δ) using the formula of (9)
358366 # Note that x_t needs to be added to both sides of the equation
@@ -365,8 +373,8 @@ def _get_prev_sample(self, sample, timestep, timestep_prev, model_output, state)
365373 # sample -> x_t
366374 # model_output -> e_θ(x_t, t)
367375 # prev_sample -> x_(t−δ)
368- alpha_prod_t = self .alphas_cumprod [timestep + 1 - state . _offset ]
369- alpha_prod_t_prev = self .alphas_cumprod [timestep_prev + 1 - state . _offset ]
376+ alpha_prod_t = self .alphas_cumprod [timestep ]
377+ alpha_prod_t_prev = self .alphas_cumprod [prev_timestep ] if prev_timestep >= 0 else self . final_alpha_cumprod
370378 beta_prod_t = 1 - alpha_prod_t
371379 beta_prod_t_prev = 1 - alpha_prod_t_prev
372380
@@ -395,9 +403,14 @@ def add_noise(
395403 timesteps : jnp .ndarray ,
396404 ) -> jnp .ndarray :
397405 sqrt_alpha_prod = self .alphas_cumprod [timesteps ] ** 0.5
398- sqrt_alpha_prod = self .match_shape (sqrt_alpha_prod , original_samples )
406+ sqrt_alpha_prod = sqrt_alpha_prod .flatten ()
407+ while len (sqrt_alpha_prod .shape ) < len (original_samples .shape ):
408+ sqrt_alpha_prod = sqrt_alpha_prod [..., None ]
409+
399410 sqrt_one_minus_alpha_prod = (1 - self .alphas_cumprod [timesteps ]) ** 0.5
400- sqrt_one_minus_alpha_prod = self .match_shape (sqrt_one_minus_alpha_prod , original_samples )
411+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod .flatten ()
412+ while len (sqrt_one_minus_alpha_prod .shape ) < len (original_samples .shape ):
413+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod [..., None ]
401414
402415 noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
403416 return noisy_samples
0 commit comments