# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion

-import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

from ..utils import deprecate
from .scheduling_utils_flax import (
    _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
+    CommonSchedulerState,
    FlaxSchedulerMixin,
    FlaxSchedulerOutput,
-    broadcast_to_shape_from_left,
+    add_noise_common,
)


-def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
-    """
-    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
-    (1-beta) over time from t = [0,1].
-
-    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
-    to that part of the diffusion process.
-
-
-    Args:
-        num_diffusion_timesteps (`int`): the number of betas to produce.
-        max_beta (`float`): the maximum beta to use; use values lower than 1 to
-                     prevent singularities.
-
-    Returns:
-        betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
-    """
-
-    def alpha_bar(time_step):
-        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
-
-    betas = []
-    for i in range(num_diffusion_timesteps):
-        t1 = i / num_diffusion_timesteps
-        t2 = (i + 1) / num_diffusion_timesteps
-        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
-    return jnp.array(betas, dtype=jnp.float32)
-
-
@flax.struct.dataclass
class DDIMSchedulerState:
+    common: CommonSchedulerState
+    final_alpha_cumprod: jnp.ndarray
+
    # setable values
+    init_noise_sigma: jnp.ndarray
    timesteps: jnp.ndarray
-    alphas_cumprod: jnp.ndarray
    num_inference_steps: Optional[int] = None

    @classmethod
-    def create(cls, num_train_timesteps: int, alphas_cumprod: jnp.ndarray):
-        return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1], alphas_cumprod=alphas_cumprod)
+    def create(
+        cls,
+        common: CommonSchedulerState,
+        final_alpha_cumprod: jnp.ndarray,
+        init_noise_sigma: jnp.ndarray,
+        timesteps: jnp.ndarray,
+    ):
+        return cls(
+            common=common,
+            final_alpha_cumprod=final_alpha_cumprod,
+            init_noise_sigma=init_noise_sigma,
+            timesteps=timesteps,
+        )


@dataclass
@@ -112,12 +97,15 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
        prediction_type (`str`, default `epsilon`):
            indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
            `v-prediction` is not supported for this scheduler.
-
+        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
+            the `dtype` used for params and computation.
    """

    _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
    _deprecated_kwargs = ["predict_epsilon"]

+    dtype: jnp.dtype
+
    @property
    def has_state(self):
        return True
@@ -129,43 +117,46 @@ def __init__(
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
+        trained_betas: Optional[jnp.ndarray] = None,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
+        dtype: jnp.dtype = jnp.float32,
        **kwargs,
    ):
        message = (
            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
-            " FlaxDDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
+            f" {self.__class__.__name__}.from_pretrained(<model_id>, prediction_type='epsilon')`."
        )
        predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
        if predict_epsilon is not None:
            self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")

-        if beta_schedule == "linear":
-            self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
-        elif beta_schedule == "scaled_linear":
-            # this schedule is very specific to the latent diffusion model.
-            self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
-        elif beta_schedule == "squaredcos_cap_v2":
-            # Glide cosine schedule
-            self.betas = betas_for_alpha_bar(num_train_timesteps)
-        else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
-
-        self.alphas = 1.0 - self.betas
+        self.dtype = dtype

-        # HACK for now - clean up later (PVP)
-        self._alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
+    def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState:
+        if common is None:
+            common = CommonSchedulerState.create(self)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
-        self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else float(self._alphas_cumprod[0])
+        final_alpha_cumprod = (
+            jnp.array(1.0, dtype=self.dtype) if self.config.set_alpha_to_one else common.alphas_cumprod[0]
+        )

        # standard deviation of the initial noise distribution
-        self.init_noise_sigma = 1.0
+        init_noise_sigma = jnp.array(1.0, dtype=self.dtype)
+
+        timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
+
+        return DDIMSchedulerState.create(
+            common=common,
+            final_alpha_cumprod=final_alpha_cumprod,
+            init_noise_sigma=init_noise_sigma,
+            timesteps=timesteps,
+        )

    def scale_model_input(
        self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
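
`create_state` now delegates the beta/alpha bookkeeping to `CommonSchedulerState`, which lives in `scheduling_utils_flax.py` and is not part of this diff. Based on the code removed from `__init__` above, it can be assumed to precompute roughly the quantities sketched below from the scheduler config (plus handling `trained_betas` and the other beta schedules). Because `DDIMSchedulerState` is a `flax.struct.dataclass`, later methods update it with `state.replace(...)`, which returns a new state rather than mutating the scheduler, keeping the scheduler object itself stateless and jit-friendly.

```python
# Rough sketch (assumption, not the actual helper) of what CommonSchedulerState.create(scheduler)
# precomputes for the default "linear" schedule; the real implementation lives in
# scheduling_utils_flax.py.
import jax.numpy as jnp

def _common_quantities(beta_start=0.0001, beta_end=0.02, num_train_timesteps=1000, dtype=jnp.float32):
    betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=dtype)  # "linear" schedule
    alphas = 1.0 - betas
    alphas_cumprod = jnp.cumprod(alphas, axis=0)  # read later as common.alphas_cumprod
    return betas, alphas, alphas_cumprod
```
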
@@ -181,21 +172,6 @@ def scale_model_input(
        """
        return sample

-    def create_state(self):
-        return DDIMSchedulerState.create(
-            num_train_timesteps=self.config.num_train_timesteps, alphas_cumprod=self._alphas_cumprod
-        )
-
-    def _get_variance(self, timestep, prev_timestep, alphas_cumprod):
-        alpha_prod_t = alphas_cumprod[timestep]
-        alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
-        beta_prod_t = 1 - alpha_prod_t
-        beta_prod_t_prev = 1 - alpha_prod_t_prev
-
-        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
-
-        return variance
-
    def set_timesteps(
        self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple = ()
    ) -> DDIMSchedulerState:
@@ -208,22 +184,35 @@ def set_timesteps(
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
        """
-        offset = self.config.steps_offset
-
        step_ratio = self.config.num_train_timesteps // num_inference_steps
        # creates integer timesteps by multiplying by ratio
-        # casting to int to avoid issues when num_inference_step is power of 3
-        timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1]
-        timesteps = timesteps + offset
+        # rounding to avoid issues when num_inference_step is power of 3
+        timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] + self.config.steps_offset
+
+        return state.replace(
+            num_inference_steps=num_inference_steps,
+            timesteps=timesteps,
+        )
+
+    def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep):
+        alpha_prod_t = state.common.alphas_cumprod[timestep]
+        alpha_prod_t_prev = jnp.where(
+            prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
+        )
+        beta_prod_t = 1 - alpha_prod_t
+        beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

-        return state.replace(num_inference_steps=num_inference_steps, timesteps=timesteps)
+        return variance

    def step(
        self,
        state: DDIMSchedulerState,
        model_output: jnp.ndarray,
        timestep: int,
        sample: jnp.ndarray,
+        eta: float = 0.0,
        return_dict: bool = True,
    ) -> Union[FlaxDDIMSchedulerOutput, Tuple]:
        """
@@ -259,17 +248,15 @@ def step(
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

-        # TODO(Patrick) - eta is always 0.0 for now, allow to be set in step function
-        eta = 0.0
-
        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps

-        alphas_cumprod = state.alphas_cumprod
+        alphas_cumprod = state.common.alphas_cumprod
+        final_alpha_cumprod = state.final_alpha_cumprod

        # 2. compute alphas, betas
        alpha_prod_t = alphas_cumprod[timestep]
-        alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
+        alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], final_alpha_cumprod)

        beta_prod_t = 1 - alpha_prod_t

@@ -291,7 +278,7 @@ def step(

        # 4. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
-        variance = self._get_variance(timestep, prev_timestep, alphas_cumprod)
+        variance = self._get_variance(state, timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)

        # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
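
For reference, the new `_get_variance(state, timestep, prev_timestep)` returns the square of this term with η = 1, and `std_dev_t = eta * variance ** 0.5` then gives σ_t(η) as in Eq. (16) of the DDIM paper cited in the comment above (with ᾱ denoting `alphas_cumprod`; η = 0 recovers deterministic DDIM sampling, η = 1 the DDPM-like variance):

```latex
\sigma_t(\eta) \;=\; \eta \,
  \sqrt{\frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_{t}}}\;
  \sqrt{1 - \frac{\bar{\alpha}_{t}}{\bar{\alpha}_{t-1}}}
```
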
@@ -307,20 +294,12 @@ def step(

    def add_noise(
        self,
+        state: DDIMSchedulerState,
        original_samples: jnp.ndarray,
        noise: jnp.ndarray,
        timesteps: jnp.ndarray,
    ) -> jnp.ndarray:
-        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
-        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
-        sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)
-
-        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-        sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
-
-        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
-        return noisy_samples
+        return add_noise_common(state.common, original_samples, noise, timesteps)

    def __len__(self):
        return self.config.num_train_timesteps
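
A minimal usage sketch of the refactored, stateless API (the config values, latent shape, and `apply_model` callable below are illustrative placeholders, not part of this commit):

```python
# Hedged sketch of driving the new stateful Flax DDIM scheduler outside of jit;
# `apply_model` stands in for whatever denoising model produces `model_output`.
import jax
import jax.numpy as jnp

scheduler = FlaxDDIMScheduler(num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02)
state = scheduler.create_state()                                # CommonSchedulerState + DDIM extras
state = scheduler.set_timesteps(state, num_inference_steps=50)  # timesteps = [980, 960, ..., 0]

sample = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 64, 64)) * state.init_noise_sigma
for t in state.timesteps:
    model_output = apply_model(sample, t)                       # placeholder for the denoising model
    sample = scheduler.step(state, model_output, t, sample, eta=0.0, return_dict=True).prev_sample
```
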