2323
2424from diffusers .configuration_utils import ConfigMixin , register_to_config
2525from diffusers .schedulers .scheduling_utils import SchedulerMixin
26- from diffusers .utils import BaseOutput
26+ from diffusers .utils import BaseOutput , deprecate
2727
2828
2929@dataclass
@@ -96,15 +96,17 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
9696 trained_betas (`np.ndarray`, optional):
9797 option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
9898 clip_sample (`bool`, default `True`):
99- option to clip predicted sample between -1 and 1 for numerical stability.
100- set_alpha_to_one (`bool`, default `True`):
99+ option to clip predicted sample for numerical stability.
100+ clip_sample_range (`float`, default `1.0`):
101+ the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
102+ set_alpha_to_zero (`bool`, default `True`):
101103 each diffusion step uses the value of alphas product at that step and at the previous one. For the final
102- step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1 `,
103- otherwise it uses the value of alpha at step 0 .
104+ step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `0 `,
105+ otherwise it uses the value of alpha at step `num_train_timesteps - 1` .
104106 steps_offset (`int`, default `0`):
105107 an offset added to the inference steps. You can use a combination of `offset=1` and
106- `set_alpha_to_one =False`, to make the last step use step 0 for the previous alpha product, as done in
107- stable diffusion .
108+ `set_alpha_to_zero =False`, to make the last step use step `num_train_timesteps - 1` for the previous alpha
109+ product .
108110 prediction_type (`str`, default `epsilon`, optional):
109111 prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
110112 process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
@@ -122,10 +124,18 @@ def __init__(
122124 beta_schedule : str = "linear" ,
123125 trained_betas : Optional [Union [np .ndarray , List [float ]]] = None ,
124126 clip_sample : bool = True ,
125- set_alpha_to_one : bool = True ,
127+ set_alpha_to_zero : bool = True ,
126128 steps_offset : int = 0 ,
127129 prediction_type : str = "epsilon" ,
130+ clip_sample_range : float = 1.0 ,
131+ ** kwargs ,
128132 ):
133+ if kwargs .get ("set_alpha_to_one" , None ) is not None :
134+ deprecation_message = (
135+ "The `set_alpha_to_one` argument is deprecated. Please use `set_alpha_to_zero` instead."
136+ )
137+ deprecate ("set_alpha_to_one" , "1.0.0" , deprecation_message , standard_warn = False )
138+ set_alpha_to_zero = kwargs ["set_alpha_to_one" ]
129139 if trained_betas is not None :
130140 self .betas = torch .tensor (trained_betas , dtype = torch .float32 )
131141 elif beta_schedule == "linear" :
@@ -144,11 +154,12 @@ def __init__(
144154 self .alphas = 1.0 - self .betas
145155 self .alphas_cumprod = torch .cumprod (self .alphas , dim = 0 )
146156
147- # At every step in ddim, we are looking into the previous alphas_cumprod
148- # For the final step, there is no previous alphas_cumprod because we are already at 0
149- # `set_alpha_to_one` decides whether we set this parameter simply to one or
150- # whether we use the final alpha of the "non-previous" one.
151- self .final_alpha_cumprod = torch .tensor (1.0 ) if set_alpha_to_one else self .alphas_cumprod [0 ]
157+ # At every step in inverted ddim, we are looking into the next alphas_cumprod
158+ # For the final step, there is no next alphas_cumprod, and the index is out of bounds
159+ # `set_alpha_to_zero` decides whether we set this parameter simply to zero
160+ # in this case, self.step() just outputs the predicted noise
161+ # or whether we use the final alpha of the "non-previous" one.
162+ self .final_alpha_cumprod = torch .tensor (0.0 ) if set_alpha_to_zero else self .alphas_cumprod [- 1 ]
152163
153164 # standard deviation of the initial noise distribution
154165 self .init_noise_sigma = 1.0
@@ -157,6 +168,7 @@ def __init__(
157168 self .num_inference_steps = None
158169 self .timesteps = torch .from_numpy (np .arange (0 , num_train_timesteps ).copy ().astype (np .int64 ))
159170
171+ # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input
160172 def scale_model_input (self , sample : torch .FloatTensor , timestep : Optional [int ] = None ) -> torch .FloatTensor :
161173 """
162174 Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
@@ -205,23 +217,52 @@ def step(
205217 variance_noise : Optional [torch .FloatTensor ] = None ,
206218 return_dict : bool = True ,
207219 ) -> Union [DDIMSchedulerOutput , Tuple ]:
208- e_t = model_output
209-
210- x = sample
220+ # 1. get previous step value (=t+1)
211221 prev_timestep = timestep + self .config .num_train_timesteps // self .num_inference_steps
212222
213- a_t = self .alphas_cumprod [timestep - 1 ]
214- a_prev = self .alphas_cumprod [prev_timestep - 1 ] if prev_timestep >= 0 else self .final_alpha_cumprod
223+ # 2. compute alphas, betas
224+ # change original implementation to exactly match noise levels for analogous forward process
225+ alpha_prod_t = self .alphas_cumprod [timestep ]
226+ alpha_prod_t_prev = (
227+ self .alphas_cumprod [prev_timestep ]
228+ if prev_timestep < self .config .num_train_timesteps
229+ else self .final_alpha_cumprod
230+ )
231+
232+ beta_prod_t = 1 - alpha_prod_t
233+
234+ # 3. compute predicted original sample from predicted noise also called
235+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
236+ if self .config .prediction_type == "epsilon" :
237+ pred_original_sample = (sample - beta_prod_t ** (0.5 ) * model_output ) / alpha_prod_t ** (0.5 )
238+ pred_epsilon = model_output
239+ elif self .config .prediction_type == "sample" :
240+ pred_original_sample = model_output
241+ pred_epsilon = (sample - alpha_prod_t ** (0.5 ) * pred_original_sample ) / beta_prod_t ** (0.5 )
242+ elif self .config .prediction_type == "v_prediction" :
243+ pred_original_sample = (alpha_prod_t ** 0.5 ) * sample - (beta_prod_t ** 0.5 ) * model_output
244+ pred_epsilon = (alpha_prod_t ** 0.5 ) * model_output + (beta_prod_t ** 0.5 ) * sample
245+ else :
246+ raise ValueError (
247+ f"prediction_type given as { self .config .prediction_type } must be one of `epsilon`, `sample`, or"
248+ " `v_prediction`"
249+ )
215250
216- pred_x0 = (x - (1 - a_t ) ** 0.5 * e_t ) / a_t .sqrt ()
251+ # 4. Clip or threshold "predicted x_0"
252+ if self .config .clip_sample :
253+ pred_original_sample = pred_original_sample .clamp (
254+ - self .config .clip_sample_range , self .config .clip_sample_range
255+ )
217256
218- dir_xt = (1.0 - a_prev ).sqrt () * e_t
257+ # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
258+ pred_sample_direction = (1 - alpha_prod_t_prev ) ** (0.5 ) * pred_epsilon
219259
220- prev_sample = a_prev .sqrt () * pred_x0 + dir_xt
260+ # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
261+ prev_sample = alpha_prod_t_prev ** (0.5 ) * pred_original_sample + pred_sample_direction
221262
222263 if not return_dict :
223- return (prev_sample , pred_x0 )
224- return DDIMSchedulerOutput (prev_sample = prev_sample , pred_original_sample = pred_x0 )
264+ return (prev_sample , pred_original_sample )
265+ return DDIMSchedulerOutput (prev_sample = prev_sample , pred_original_sample = pred_original_sample )
225266
def __len__(self):
    """Return the scheduler's length, i.e. the configured number of training timesteps."""
    num_steps = self.config.num_train_timesteps
    return num_steps
0 commit comments