1616
1717import math
1818from dataclasses import dataclass
19- from typing import Optional , Tuple , Union
19+ from typing import Literal , Optional , Tuple , Union
2020
2121import numpy as np
2222import torch
2323
2424from ..configuration_utils import ConfigMixin , register_to_config
25- from ..utils import BaseOutput , deprecate
25+ from ..utils import BaseOutput
2626from .scheduling_utils import SchedulerMixin
2727
2828
def expand_to_shape(input, timesteps, shape, device):
    """
    Index a 1D tensor `input` with a 1D index tensor `timesteps`, then reshape the result so that it broadcasts
    against a tensor of shape `shape`. Useful for parallelizing operations over `shape[0]` diffusion steps at once.

    Args:
        input: 1D tensor of per-timestep values.
        timesteps: 1D integer tensor holding `shape[0]` indices into `input`.
        shape: target shape; only its first entry and its length are used.
        device: device both tensors are moved to before indexing.

    Returns:
        Tensor of shape `(shape[0], 1, ..., 1)` with `len(shape)` dimensions.
    """
    # `torch.gather` expects an int64 index; callers may hand in int32 timesteps, so cast while moving devices.
    index = timesteps.to(device=device, dtype=torch.long)
    out = torch.gather(input.to(device), 0, index)
    # Keep the leading (batch) dimension and append singleton dims so the result broadcasts with `shape`.
    return out.reshape(shape[0], *([1] * (len(shape) - 1)))
38+
39+
2940@dataclass
3041class DDPMSchedulerOutput (BaseOutput ):
3142 """
@@ -102,6 +113,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
102113
103114 """
104115
116+ _compatible_classes = [
117+ "DDIMScheduler" ,
118+ "PNDMScheduler" ,
119+ "LMSDiscreteScheduler" ,
120+ "EulerDiscreteScheduler" ,
121+ "EulerAncestralDiscreteScheduler" ,
122+ ]
123+
105124 @register_to_config
106125 def __init__ (
107126 self ,
@@ -112,15 +131,8 @@ def __init__(
112131 trained_betas : Optional [np .ndarray ] = None ,
113132 variance_type : str = "fixed_small" ,
114133 clip_sample : bool = True ,
115- ** kwargs ,
134+ prediction_type : Literal [ "epsilon" , "sample" , "v" ] = "epsilon" ,
116135 ):
117- deprecate (
118- "tensor_format" ,
119- "0.6.0" ,
120- "If you're running your code in PyTorch, you can safely remove this argument." ,
121- take_from = kwargs ,
122- )
123-
124136 if trained_betas is not None :
125137 self .betas = torch .from_numpy (trained_betas )
126138 elif beta_schedule == "linear" :
@@ -142,8 +154,8 @@ def __init__(
142154
143155 self .alphas = 1.0 - self .betas
144156 self .alphas_cumprod = torch .cumprod (self .alphas , dim = 0 )
145- self .sigmas = 1 - self .alphas ** 2
146- self .one = torch .tensor ( 1.0 )
157+ self .sqrt_alphas_cumprod = torch . sqrt ( self .alphas_cumprod )
158+ self .sqrt_one_minus_alphas_cumprod = torch .sqrt ( 1 - self . alphas_cumprod )
147159
148160 # standard deviation of the initial noise distribution
149161 self .init_noise_sigma = 1.0
@@ -153,6 +165,7 @@ def __init__(
153165 self .timesteps = torch .from_numpy (np .arange (0 , num_train_timesteps )[::- 1 ].copy ())
154166
155167 self .variance_type = variance_type
168+ self .prediction_type = prediction_type
156169
157170 def scale_model_input (self , sample : torch .FloatTensor , timestep : Optional [int ] = None ) -> torch .FloatTensor :
158171 """
@@ -185,7 +198,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
185198
186199 def _get_variance (self , timestep , predicted_variance = None , variance_type = None ):
187200 alpha_prod_t = self .alphas_cumprod [timestep ]
188- alpha_prod_t_prev = self .alphas_cumprod [timestep - 1 ] if timestep > 0 else self . one
201+ alpha_prod_t_prev = self .alphas_cumprod [timestep - 1 ] if timestep > 0 else torch . tensor ( 1.0 )
189202
190203 # For timestep > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
191204 # and sample from it to get previous sample
@@ -213,6 +226,8 @@ def _get_variance(self, timestep, predicted_variance=None, variance_type=None):
213226 max_log = self .betas [timestep ]
214227 frac = (predicted_variance + 1 ) / 2
215228 variance = frac * max_log + (1 - frac ) * min_log
229+ elif variance_type == "v_diffusion" :
230+ variance = torch .log (self .betas [timestep ] * (1 - alpha_prod_t_prev ) / (1 - alpha_prod_t ))
216231
217232 return variance
218233
@@ -221,7 +236,7 @@ def step(
221236 model_output : torch .FloatTensor ,
222237 timestep : int ,
223238 sample : torch .FloatTensor ,
224- prediction_type : str = "epsilon" ,
239+ # prediction_type: Literal["epsilon", "sample", "v"] = "epsilon",
225240 generator = None ,
226241 return_dict : bool = True ,
227242 ) -> Union [DDPMSchedulerOutput , Tuple ]:
@@ -234,9 +249,9 @@ def step(
234249 timestep (`int`): current discrete timestep in the diffusion chain.
235250 sample (`torch.FloatTensor`):
236251 current instance of sample being created by diffusion process.
237- prediction_type (`str` ):
252+ prediction_type (`Literal["epsilon", "sample", "v"]`, optional ):
238253 prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
239- process), `sample` (directly predicting the noisy sample), or `v` (see section 2.4
254+ process), `sample` (directly predicting the noisy sample) or `v` (see section 2.4
240255 https://imagen.research.google/video/paper.pdf)
241256 generator: random number generator.
242257 return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
@@ -247,30 +262,36 @@ def step(
247262 returning a tuple, the first element is the sample tensor.
248263
249264 """
265+ if self .variance_type == "v_diffusion" :
266+ assert self .prediction_type == "v" , "Need to use v prediction with v_diffusion"
250267 if model_output .shape [1 ] == sample .shape [1 ] * 2 and self .variance_type in ["learned" , "learned_range" ]:
251268 model_output , predicted_variance = torch .split (model_output , sample .shape [1 ], dim = 1 )
252269 else :
253270 predicted_variance = None
254271
255272 # 1. compute alphas, betas
256273 alpha_prod_t = self .alphas_cumprod [timestep ]
257- alpha_prod_t_prev = self .alphas_cumprod [timestep - 1 ] if timestep > 0 else self . one
274+ alpha_prod_t_prev = self .alphas_cumprod [timestep - 1 ] if timestep > 0 else torch . tensor ( 1.0 )
258275 beta_prod_t = 1 - alpha_prod_t
259276 beta_prod_t_prev = 1 - alpha_prod_t_prev
260277
261278 # 2. compute predicted original sample from predicted noise also called
262279 # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
263- if prediction_type == "epsilon" :
280+ if self .prediction_type == "v" :
281+ # x_recon in p_mean_variance
282+ pred_original_sample = (
283+ sample * self .sqrt_alphas_cumprod [timestep ]
284+ - model_output * self .sqrt_one_minus_alphas_cumprod [timestep ]
285+ )
286+ elif self .prediction_type == "epsilon" :
264287 pred_original_sample = (sample - beta_prod_t ** (0.5 ) * model_output ) / alpha_prod_t ** (0.5 )
265- elif prediction_type == "sample" :
288+
289+ elif self .prediction_type == "sample" :
266290 pred_original_sample = model_output
267- elif prediction_type == "v" :
268- # v_t = alpha_t * epsilon - sigma_t * x
269- # need to merge the PRs for sigma to be available in DDPM
270- pred = sample * self .alphas [timestep ] - model_output * self .sigmas [timestep ]
271- eps = model_output * self .alphas [timestep ] - sample * self .sigmas [timestep ]
272291 else :
273- raise ValueError (f"prediction_type given as { prediction_type } must be one of `epsilon`, `sample`, or `v`" )
292+ raise ValueError (
293+ f"prediction_type given as { self .prediction_type } must be one of `epsilon`, `sample`, or `v`"
294+ )
274295
275296 # 3. Clip "predicted x_0"
276297 if self .config .clip_sample :
@@ -291,7 +312,12 @@ def step(
291312 noise = torch .randn (
292313 model_output .size (), dtype = model_output .dtype , layout = model_output .layout , generator = generator
293314 ).to (model_output .device )
294- variance = (self ._get_variance (timestep , predicted_variance = predicted_variance ) ** 0.5 ) * noise
315+ if self .variance_type == "fixed_small_log" :
316+ variance = self ._get_variance (timestep , predicted_variance = predicted_variance ) * noise
317+ elif self .variance_type == "v_diffusion" :
318+ variance = torch .exp (0.5 * self ._get_variance (timestep , predicted_variance )) * noise
319+ else :
320+ variance = (self ._get_variance (timestep , predicted_variance = predicted_variance ) ** 0.5 ) * noise
295321
296322 pred_prev_sample = pred_prev_sample + variance
297323
@@ -306,6 +332,11 @@ def add_noise(
306332 noise : torch .FloatTensor ,
307333 timesteps : torch .IntTensor ,
308334 ) -> torch .FloatTensor :
335+ if self .variance_type == "v_diffusion" :
336+ alpha , sigma = self .get_alpha_sigma (original_samples , timesteps , original_samples .device )
337+ z_t = alpha * original_samples + sigma * noise
338+ return z_t
339+
309340 # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
310341 self .alphas_cumprod = self .alphas_cumprod .to (device = original_samples .device , dtype = original_samples .dtype )
311342 timesteps = timesteps .to (original_samples .device )
@@ -325,3 +356,8 @@ def add_noise(
325356
326357 def __len__ (self ):
327358 return self .config .num_train_timesteps
359+
360+ def get_alpha_sigma (self , sample , timesteps , device ):
361+ alpha = expand_to_shape (self .sqrt_alphas_cumprod , timesteps , sample .shape , device )
362+ sigma = expand_to_shape (self .sqrt_one_minus_alphas_cumprod , timesteps , sample .shape , device )
363+ return alpha , sigma
0 commit comments