Commit 4c68504

add ddim

1 parent 3eb2593

2 files changed (+41, -24)

src/diffusers/schedulers/scheduling_ddim.py (24 additions, 5 deletions)

```diff
@@ -145,6 +145,7 @@ def __init__(
 
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+        self.sigmas = 1 - self.alphas**2
 
         # At every step in ddim, we are looking into the previous alphas_cumprod
         # For the final step, there is no previous alphas_cumprod because we are already at 0
@@ -209,6 +210,7 @@ def step(
         model_output: torch.FloatTensor,
         timestep: int,
         sample: torch.FloatTensor,
+        prediction_type: str = "epsilon",
         eta: float = 0.0,
         use_clipped_model_output: bool = False,
         generator=None,
@@ -223,6 +225,10 @@ def step(
             timestep (`int`): current discrete timestep in the diffusion chain.
             sample (`torch.FloatTensor`):
                 current instance of sample being created by diffusion process.
+            prediction_type (`str`):
+                prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+                process), `sample` (directly predicting the noisy sample), or `v` (see section 2.4
+                https://imagen.research.google/video/paper.pdf)
             eta (`float`): weight of noise for added noise in diffusion step.
             use_clipped_model_output (`bool`): TODO
             generator: random number generator.
@@ -243,14 +249,14 @@ def step(
         # Ideally, read DDIM paper in-detail understanding
 
         # Notation (<variable name> -> <name in paper>
-        # - pred_noise_t -> e_theta(x_t, t)
-        # - pred_original_sample -> f_theta(x_t, t) or x_0
+        # - pred_noise_t -> e_theta(x_t, timestep)
+        # - pred_original_sample -> f_theta(x_t, timestep) or x_0
         # - std_dev_t -> sigma_t
         # - eta -> η
         # - pred_sample_direction -> "direction pointing to x_t"
         # - pred_prev_sample -> "x_t-1"
 
-        # 1. get previous step value (=t-1)
+        # 1. get previous step value (=timestep-1)
         prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
 
         # 2. compute alphas, betas
@@ -261,7 +267,20 @@ def step(
 
         # 3. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+        if prediction_type == "epsilon":
+            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+            eps = torch.tensor(1)
+        elif prediction_type == "sample":
+            pred_original_sample = model_output
+            eps = torch.tensor(1)
+        elif prediction_type == "v":
+            # v_t = alpha_t * epsilon - sigma_t * x
+            # need to merge the PRs for sigma to be available in DDPM
+            pred_original_sample = sample * self.alphas[timestep] - model_output * self.sigmas[timestep]
+            eps = model_output * self.alphas[timestep] - sample * self.sigmas[timestep]
+            raise NotImplementedError(f"v prediction not yet implemented for DDPM")
+        else:
+            raise ValueError(f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or `v`")
 
         # 4. Clip "predicted x_0"
         if self.config.clip_sample:
@@ -280,7 +299,7 @@ def step(
         pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
 
         # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + eps * pred_sample_direction
 
         if eta > 0:
             # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
```
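For context, a minimal sketch of how the new `prediction_type` argument to `DDIMScheduler.step` would be exercised at this revision. The random `model_output` is a hypothetical stand-in for a real noise-predicting UNet, and the shapes are illustrative:

```python
import torch
from diffusers import DDIMScheduler

# Minimal sketch, assuming the step() signature added in this commit.
scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(50)

sample = torch.randn(1, 3, 32, 32)  # start from pure noise
for t in scheduler.timesteps:
    # Placeholder for a real model call such as unet(sample, t).sample
    model_output = torch.randn_like(sample)
    # "epsilon" keeps the previous behaviour: the model output is
    # interpreted as the predicted noise of the diffusion process.
    sample = scheduler.step(model_output, int(t), sample, prediction_type="epsilon").prev_sample
```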

src/diffusers/schedulers/scheduling_ddpm.py (17 additions, 19 deletions)

```diff
@@ -183,14 +183,14 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
         )[::-1].copy()
         self.timesteps = torch.from_numpy(timesteps).to(device)
 
-    def _get_variance(self, t, predicted_variance=None, variance_type=None):
-        alpha_prod_t = self.alphas_cumprod[t]
-        alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
+    def _get_variance(self, timestep, predicted_variance=None, variance_type=None):
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one
 
-        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
+        # For timestep > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
         # and sample from it to get previous sample
-        # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
-        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]
+        # x_{timestep-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
+        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep]
 
         if variance_type is None:
             variance_type = self.config.variance_type
@@ -202,15 +202,15 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None):
         elif variance_type == "fixed_small_log":
             variance = torch.log(torch.clamp(variance, min=1e-20))
         elif variance_type == "fixed_large":
-            variance = self.betas[t]
+            variance = self.betas[timestep]
         elif variance_type == "fixed_large_log":
             # Glide max_log
-            variance = torch.log(self.betas[t])
+            variance = torch.log(self.betas[timestep])
         elif variance_type == "learned":
             return predicted_variance
         elif variance_type == "learned_range":
             min_log = variance
-            max_log = self.betas[t]
+            max_log = self.betas[timestep]
             frac = (predicted_variance + 1) / 2
             variance = frac * max_log + (1 - frac) * min_log
 
@@ -247,16 +247,14 @@ def step(
             returning a tuple, the first element is the sample tensor.
 
         """
-        t = timestep
-
         if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
             model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
         else:
             predicted_variance = None
 
         # 1. compute alphas, betas
-        alpha_prod_t = self.alphas_cumprod[t]
-        alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
 
@@ -269,8 +267,8 @@ def step(
         elif prediction_type == "v":
             # v_t = alpha_t * epsilon - sigma_t * x
             # need to merge the PRs for sigma to be available in DDPM
-            pred_original_sample = sample * self.alphas[t] - model_output * self.sigmas[t]
-            eps = model_output * self.alphas[t] - sample * self.sigmas[t]
+            pred = sample * self.alphas[timestep] - model_output * self.sigmas[timestep]
+            eps = model_output * self.alphas[timestep] - sample * self.sigmas[timestep]
             raise NotImplementedError(f"v prediction not yet implemented for DDPM")
         else:
             raise ValueError(f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or `v`")
@@ -281,20 +279,20 @@ def step(
 
         # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
-        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
-        current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
+        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[timestep]) / beta_prod_t
+        current_sample_coeff = self.alphas[timestep] ** (0.5) * beta_prod_t_prev / beta_prod_t
 
         # 5. Compute predicted previous sample µ_t
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
         pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
 
         # 6. Add noise
         variance = 0
-        if t > 0:
+        if timestep > 0:
             noise = torch.randn(
                 model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator
             ).to(model_output.device)
-            variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
+            variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise
 
         pred_prev_sample = pred_prev_sample + variance
 
```
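The `v` branch in both schedulers is still guarded by a `NotImplementedError`, and the `self.sigmas = 1 - self.alphas**2` added in `__init__` is per-step rather than cumulative. For reference only, a common formulation of v-prediction (Salimans & Ho, 2022, "Progressive Distillation for Fast Sampling of Diffusion Models") inverts v = alpha_t * eps - sigma_t * x_0 with alpha_t = sqrt(alphas_cumprod[t]) and sigma_t = sqrt(1 - alphas_cumprod[t]). The helper below is a hypothetical sketch of that algebra, not what this commit implements:

```python
import torch

def v_to_x0_and_eps(v: torch.Tensor, x_t: torch.Tensor,
                    alphas_cumprod: torch.Tensor, t: int):
    # Hypothetical helper, not part of this commit: recovers x_0 and epsilon
    # from a v-prediction using v = alpha_t * eps - sigma_t * x_0 together
    # with the forward process x_t = alpha_t * x_0 + sigma_t * eps.
    alpha_t = alphas_cumprod[t] ** 0.5
    sigma_t = (1 - alphas_cumprod[t]) ** 0.5
    x_0 = alpha_t * x_t - sigma_t * v
    eps = alpha_t * v + sigma_t * x_t
    return x_0, eps
```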
