@@ -183,14 +183,14 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device
         )[::-1].copy()
         self.timesteps = torch.from_numpy(timesteps).to(device)
 
-    def _get_variance(self, t, predicted_variance=None, variance_type=None):
-        alpha_prod_t = self.alphas_cumprod[t]
-        alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
+    def _get_variance(self, timestep, predicted_variance=None, variance_type=None):
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one
 
-        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
+        # For timestep > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
         # and sample from it to get previous sample
-        # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
-        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]
+        # x_{timestep-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
+        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep]
 
         if variance_type is None:
             variance_type = self.config.variance_type
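As a quick sanity check on the renamed code path, here is a minimal standalone sketch of the posterior variance this hunk computes, β̃_t = (1 − ᾱ_{t−1}) / (1 − ᾱ_t) · β_t from formula (7) of https://arxiv.org/pdf/2006.11239.pdf. The linear beta schedule below is an illustrative stand-in, not the scheduler's actual config:

import torch

# Illustrative linear schedule; the real scheduler builds this from its config
betas = torch.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def posterior_variance(timestep: int) -> torch.Tensor:
    # beta_tilde_t = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t
    alpha_prod_t = alphas_cumprod[timestep]
    alpha_prod_t_prev = alphas_cumprod[timestep - 1] if timestep > 0 else torch.tensor(1.0)
    return (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * betas[timestep]

print(posterior_variance(999))  # approaches beta_t for large t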
@@ -202,15 +202,15 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None):
         elif variance_type == "fixed_small_log":
             variance = torch.log(torch.clamp(variance, min=1e-20))
         elif variance_type == "fixed_large":
-            variance = self.betas[t]
+            variance = self.betas[timestep]
         elif variance_type == "fixed_large_log":
             # Glide max_log
-            variance = torch.log(self.betas[t])
+            variance = torch.log(self.betas[timestep])
         elif variance_type == "learned":
             return predicted_variance
         elif variance_type == "learned_range":
             min_log = variance
-            max_log = self.betas[t]
+            max_log = self.betas[timestep]
             frac = (predicted_variance + 1) / 2
             variance = frac * max_log + (1 - frac) * min_log
 
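For the `learned_range` branch above, the intent (per Improved DDPM, https://arxiv.org/abs/2102.09672) is to map the model's variance prediction from [-1, 1] to a fraction and interpolate between a lower and an upper log-variance bound. A minimal sketch with hypothetical values, interpolating in log space as the paper describes:

import torch

min_log = torch.log(torch.tensor(1e-5))  # hypothetical log posterior variance (lower bound)
max_log = torch.log(torch.tensor(2e-2))  # hypothetical log beta_t (upper bound)
predicted_variance = torch.tensor(0.5)   # model output, assumed to lie in [-1, 1]

frac = (predicted_variance + 1) / 2      # map [-1, 1] -> [0, 1]
variance = torch.exp(frac * max_log + (1 - frac) * min_log)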
@@ -247,16 +247,14 @@ def step(
             returning a tuple, the first element is the sample tensor.
 
         """
-        t = timestep
-
         if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
             model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
         else:
             predicted_variance = None
 
         # 1. compute alphas, betas
-        alpha_prod_t = self.alphas_cumprod[t]
-        alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
 
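The channel split at the top of this hunk assumes the model emits twice the sample's channel count when the variance is learned: the first half is the noise prediction, the second half the variance prediction. A small sketch with made-up shapes:

import torch

sample = torch.randn(2, 3, 64, 64)        # x_t: (batch, channels, height, width)
model_output = torch.randn(2, 6, 64, 64)  # 2 * channels when variance is learned

# split along the channel dim into equal halves of sample.shape[1] channels
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
assert model_output.shape == predicted_variance.shape == sample.shape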
@@ -269,8 +267,8 @@ def step(
         elif prediction_type == "v":
             # v_t = alpha_t * epsilon - sigma_t * x
             # need to merge the PRs for sigma to be available in DDPM
-            pred_original_sample = sample * self.alphas[t] - model_output * self.sigmas[t]
-            eps = model_output * self.alphas[t] - sample * self.sigmas[t]
+            pred = sample * self.alphas[timestep] - model_output * self.sigmas[timestep]
+            eps = model_output * self.alphas[timestep] - sample * self.sigmas[timestep]
             raise NotImplementedError(f"v prediction not yet implemented for DDPM")
         else:
             raise ValueError(f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or `v`")
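For reference, the `epsilon` branch (elided between these hunks) recovers x_0 by inverting the forward process, x_0 = (x_t − √(1 − ᾱ_t) · ε) / √ᾱ_t (a rearrangement of formula (4) in the DDPM paper). A minimal sketch with illustrative values:

import torch

alpha_prod_t = torch.tensor(0.5)        # illustrative alpha_bar_t
beta_prod_t = 1 - alpha_prod_t
sample = torch.randn(1, 3, 8, 8)        # x_t
model_output = torch.randn(1, 3, 8, 8)  # predicted noise epsilon

# invert x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps for x_0
pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5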
@@ -281,20 +279,20 @@ def step(
 
         # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
-        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
-        current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
+        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[timestep]) / beta_prod_t
+        current_sample_coeff = self.alphas[timestep] ** (0.5) * beta_prod_t_prev / beta_prod_t
 
         # 5. Compute predicted previous sample µ_t
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
         pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
 
         # 6. Add noise
         variance = 0
-        if t > 0:
+        if timestep > 0:
             noise = torch.randn(
                 model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator
             ).to(model_output.device)
-            variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
+            variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise
 
         pred_prev_sample = pred_prev_sample + variance
 
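Putting the hunks together, a self-contained sketch of one reverse step in the shape of the code above (illustrative schedule, epsilon prediction, and "fixed_small" variance only; not the scheduler's full API):

import torch

betas = torch.linspace(1e-4, 0.02, 1000)  # illustrative schedule
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

def ddpm_step(model_output, sample, timestep, generator=None):
    alpha_prod_t = alphas_cumprod[timestep]
    alpha_prod_t_prev = alphas_cumprod[timestep - 1] if timestep > 0 else torch.tensor(1.0)
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev

    # epsilon prediction -> x_0
    pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5

    # formula (7): mean of the posterior q(x_{t-1} | x_t, x_0)
    pred_original_sample_coeff = alpha_prod_t_prev ** 0.5 * betas[timestep] / beta_prod_t
    current_sample_coeff = alphas[timestep] ** 0.5 * beta_prod_t_prev / beta_prod_t
    pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

    # formula (7): "fixed_small" posterior variance, sampled only for timestep > 0
    if timestep > 0:
        variance = beta_prod_t_prev / beta_prod_t * betas[timestep]
        noise = torch.randn(sample.shape, generator=generator, dtype=sample.dtype)
        pred_prev_sample = pred_prev_sample + variance ** 0.5 * noise
    return pred_prev_sample

x_prev = ddpm_step(torch.randn(1, 3, 8, 8), torch.randn(1, 3, 8, 8), timestep=999)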