@@ -19,6 +19,7 @@
 from typing import Optional, Tuple, Union
 
 import flax
+import jax
 import jax.numpy as jnp
 
 from ..configuration_utils import ConfigMixin, register_to_config
@@ -155,7 +156,7 @@ def __init__(
     def create_state(self):
         return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
 
-    def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) -> PNDMSchedulerState:
+    def set_timesteps(self, state: PNDMSchedulerState, shape: Tuple, num_inference_steps: int) -> PNDMSchedulerState:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
 
@@ -196,8 +197,11 @@ def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) ->
 
         return state.replace(
             timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int64),
-            ets=jnp.array([]),
             counter=0,
+            # Reserve space for the state variables
+            cur_model_output=jnp.zeros(shape),
+            cur_sample=jnp.zeros(shape),
+            ets=jnp.zeros((4,) + shape),
         )
 
     def step(
@@ -227,22 +231,32 @@ def step(
             When returning a tuple, the first element is the sample tensor.
 
         """
-        if state.counter < len(state.prk_timesteps) and not self.config.skip_prk_steps:
-            return self.step_prk(
-                state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict
+        if self.config.skip_prk_steps:
+            prev_sample, state = self.step_plms(
+                state=state, model_output=model_output, timestep=timestep, sample=sample
             )
         else:
-            return self.step_plms(
-                state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict
+            prev_sample, state = jax.lax.switch(
+                jnp.where(state.counter < len(state.prk_timesteps), 0, 1),
+                (self.step_prk, self.step_plms),
+                # Args to either branch
+                state,
+                model_output,
+                timestep,
+                sample,
             )
 
+        if not return_dict:
+            return (prev_sample, state)
+
+        return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+
     def step_prk(
         self,
         state: PNDMSchedulerState,
         model_output: jnp.ndarray,
         timestep: int,
         sample: jnp.ndarray,
-        return_dict: bool = True,
     ) -> Union[FlaxSchedulerOutput, Tuple]:
         """
         Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
@@ -266,42 +280,53 @@ def step_prk(
             "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
         )
 
-        diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2
+        diff_to_prev = jnp.where(
+            state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2
+        )
        prev_timestep = timestep - diff_to_prev
         timestep = state.prk_timesteps[state.counter // 4 * 4]
 
-        if state.counter % 4 == 0:
-            state = state.replace(
-                cur_model_output=state.cur_model_output + 1 / 6 * model_output,
-                ets=state.ets.append(model_output),
-                cur_sample=sample,
+        def remainder_0(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+            return (
+                state.replace(
+                    cur_model_output=state.cur_model_output + 1 / 6 * model_output,
+                    ets=state.ets.at[ets_at].set(model_output),
+                    cur_sample=sample,
+                ),
+                model_output,
             )
-        elif (self.counter - 1) % 4 == 0:
-            state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
-        elif (self.counter - 2) % 4 == 0:
-            state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
-        elif (self.counter - 3) % 4 == 0:
-            model_output = state.cur_model_output + 1 / 6 * model_output
-            state = state.replace(cur_model_output=0)
 
-        # cur_sample should not be `None`
-        cur_sample = state.cur_sample if state.cur_sample is not None else sample
+        def remainder_1(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+            return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output
 
+        def remainder_2(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+            return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output
+
+        def remainder_3(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+            model_output = state.cur_model_output + 1 / 6 * model_output
+            return state.replace(cur_model_output=jnp.zeros_like(state.cur_model_output)), model_output
+
+        state, model_output = jax.lax.switch(
+            state.counter % 4,
+            (remainder_0, remainder_1, remainder_2, remainder_3),
+            # Args to either branch
+            state,
+            model_output,
+            state.counter // 4,
+        )
+
+        cur_sample = state.cur_sample
         prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
         state = state.replace(counter=state.counter + 1)
 
-        if not return_dict:
-            return (prev_sample, state)
-
-        return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+        return (prev_sample, state)
 
     def step_plms(
         self,
         state: PNDMSchedulerState,
         model_output: jnp.ndarray,
         timestep: int,
         sample: jnp.ndarray,
-        return_dict: bool = True,
     ) -> Union[FlaxSchedulerOutput, Tuple]:
         """
         Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
@@ -334,36 +359,91 @@ def step_plms(
         )
 
         prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
+        prev_timestep = jnp.where(prev_timestep > 0, prev_timestep, 0)
+
+        # Reference:
+        # if state.counter != 1:
+        #     state.ets.append(model_output)
+        # else:
+        #     prev_timestep = timestep
+        #     timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps
+
+        prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep)
+        timestep = jnp.where(
+            state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep
+        )
 
-        if state.counter != 1:
-            state = state.replace(ets=state.ets.append(model_output))
-        else:
-            prev_timestep = timestep
-            timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps
-
-        if len(state.ets) == 1 and state.counter == 0:
-            model_output = model_output
-            state = state.replace(cur_sample=sample)
-        elif len(state.ets) == 1 and state.counter == 1:
-            model_output = (model_output + state.ets[-1]) / 2
-            sample = state.cur_sample
-            state = state.replace(cur_sample=None)
-        elif len(state.ets) == 2:
-            model_output = (3 * state.ets[-1] - state.ets[-2]) / 2
-        elif len(state.ets) == 3:
-            model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12
-        else:
-            model_output = (1 / 24) * (
-                55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]
+        # Reference:
+        # if len(state.ets) == 1 and state.counter == 0:
+        #     model_output = model_output
+        #     state.cur_sample = sample
+        # elif len(state.ets) == 1 and state.counter == 1:
+        #     model_output = (model_output + state.ets[-1]) / 2
+        #     sample = state.cur_sample
+        #     state.cur_sample = None
+        # elif len(state.ets) == 2:
+        #     model_output = (3 * state.ets[-1] - state.ets[-2]) / 2
+        # elif len(state.ets) == 3:
+        #     model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12
+        # else:
+        #     model_output = (1 / 24) * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4])
+
+        def counter_0(state: PNDMSchedulerState):
+            ets = state.ets.at[0].set(model_output)
+            return state.replace(
+                ets=ets,
+                cur_sample=sample,
+                cur_model_output=jnp.array(model_output, dtype=jnp.float32),
+            )
+
+        def counter_1(state: PNDMSchedulerState):
+            return state.replace(
+                cur_model_output=(model_output + state.ets[0]) / 2,
             )
 
+        def counter_2(state: PNDMSchedulerState):
+            ets = state.ets.at[1].set(model_output)
+            return state.replace(
+                ets=ets,
+                cur_model_output=(3 * ets[1] - ets[0]) / 2,
+                cur_sample=sample,
+            )
+
+        def counter_3(state: PNDMSchedulerState):
+            ets = state.ets.at[2].set(model_output)
+            return state.replace(
+                ets=ets,
+                cur_model_output=(23 * ets[2] - 16 * ets[1] + 5 * ets[0]) / 12,
+                cur_sample=sample,
+            )
+
+        def counter_other(state: PNDMSchedulerState):
+            ets = state.ets.at[3].set(model_output)
+            next_model_output = (1 / 24) * (55 * ets[3] - 59 * ets[2] + 37 * ets[1] - 9 * ets[0])
+
+            ets = ets.at[0].set(ets[1])
+            ets = ets.at[1].set(ets[2])
+            ets = ets.at[2].set(ets[3])
+
+            return state.replace(
+                ets=ets,
+                cur_model_output=next_model_output,
+                cur_sample=sample,
+            )
+
+        counter = jnp.clip(state.counter, 0, 4)
+        state = jax.lax.switch(
+            counter,
+            [counter_0, counter_1, counter_2, counter_3, counter_other],
+            state,
+        )
+
+        sample = state.cur_sample
+        model_output = state.cur_model_output
         prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
         state = state.replace(counter=state.counter + 1)
 
-        if not return_dict:
-            return (prev_sample, state)
-
-        return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+        return (prev_sample, state)
 
     def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
         # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
@@ -379,7 +459,7 @@ def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
         # model_output -> e_θ(x_t, t)
         # prev_sample -> x_(t−δ)
         alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        alpha_prod_t_prev = jnp.where(prev_timestep >= 0, self.alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
 
0 commit comments