Skip to content

Commit f00d896

Browse files
authored
DDPM changes to support v diffusion (#1121)
* v diffusion support for ddpm * quality and style * variable name consistency * missing base case * pass prediction type along in the pipeline * put prediction type in scheduler config * style
1 parent ac6be90 commit f00d896

File tree

1 file changed

+62
-26
lines changed

1 file changed

+62
-26
lines changed

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 62 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,27 @@
1616

1717
import math
1818
from dataclasses import dataclass
19-
from typing import Optional, Tuple, Union
19+
from typing import Literal, Optional, Tuple, Union
2020

2121
import numpy as np
2222
import torch
2323

2424
from ..configuration_utils import ConfigMixin, register_to_config
25-
from ..utils import BaseOutput, deprecate
25+
from ..utils import BaseOutput
2626
from .scheduling_utils import SchedulerMixin
2727

2828

29+
def expand_to_shape(input, timesteps, shape, device):
30+
"""
31+
Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast
32+
nicely with `shape`. Useful for parellizing operations over `shape[0]` number of diffusion steps at once.
33+
"""
34+
out = torch.gather(input.to(device), 0, timesteps.to(device))
35+
reshape = [shape[0]] + [1] * (len(shape) - 1)
36+
out = out.reshape(*reshape)
37+
return out
38+
39+
2940
@dataclass
3041
class DDPMSchedulerOutput(BaseOutput):
3142
"""
@@ -102,6 +113,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
102113
103114
"""
104115

116+
_compatible_classes = [
117+
"DDIMScheduler",
118+
"PNDMScheduler",
119+
"LMSDiscreteScheduler",
120+
"EulerDiscreteScheduler",
121+
"EulerAncestralDiscreteScheduler",
122+
]
123+
105124
@register_to_config
106125
def __init__(
107126
self,
@@ -112,15 +131,8 @@ def __init__(
112131
trained_betas: Optional[np.ndarray] = None,
113132
variance_type: str = "fixed_small",
114133
clip_sample: bool = True,
115-
**kwargs,
134+
prediction_type: Literal["epsilon", "sample", "v"] = "epsilon",
116135
):
117-
deprecate(
118-
"tensor_format",
119-
"0.6.0",
120-
"If you're running your code in PyTorch, you can safely remove this argument.",
121-
take_from=kwargs,
122-
)
123-
124136
if trained_betas is not None:
125137
self.betas = torch.from_numpy(trained_betas)
126138
elif beta_schedule == "linear":
@@ -142,8 +154,8 @@ def __init__(
142154

143155
self.alphas = 1.0 - self.betas
144156
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
145-
self.sigmas = 1 - self.alphas**2
146-
self.one = torch.tensor(1.0)
157+
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
158+
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod)
147159

148160
# standard deviation of the initial noise distribution
149161
self.init_noise_sigma = 1.0
@@ -153,6 +165,7 @@ def __init__(
153165
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
154166

155167
self.variance_type = variance_type
168+
self.prediction_type = prediction_type
156169

157170
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
158171
"""
@@ -185,7 +198,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
185198

186199
def _get_variance(self, timestep, predicted_variance=None, variance_type=None):
187200
alpha_prod_t = self.alphas_cumprod[timestep]
188-
alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one
201+
alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else torch.tensor(1.0)
189202

190203
# For timestep > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
191204
# and sample from it to get previous sample
@@ -213,6 +226,8 @@ def _get_variance(self, timestep, predicted_variance=None, variance_type=None):
213226
max_log = self.betas[timestep]
214227
frac = (predicted_variance + 1) / 2
215228
variance = frac * max_log + (1 - frac) * min_log
229+
elif variance_type == "v_diffusion":
230+
variance = torch.log(self.betas[timestep] * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t))
216231

217232
return variance
218233

@@ -221,7 +236,7 @@ def step(
221236
model_output: torch.FloatTensor,
222237
timestep: int,
223238
sample: torch.FloatTensor,
224-
prediction_type: str = "epsilon",
239+
# prediction_type: Literal["epsilon", "sample", "v"] = "epsilon",
225240
generator=None,
226241
return_dict: bool = True,
227242
) -> Union[DDPMSchedulerOutput, Tuple]:
@@ -234,9 +249,9 @@ def step(
234249
timestep (`int`): current discrete timestep in the diffusion chain.
235250
sample (`torch.FloatTensor`):
236251
current instance of sample being created by diffusion process.
237-
prediction_type (`str`):
252+
prediction_type (`Literal["epsilon", "sample", "v"]`, optional):
238253
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
239-
process), `sample` (directly predicting the noisy sample), or `v` (see section 2.4
254+
process), `sample` (directly predicting the noisy sample`) or `v` (see section 2.4
240255
https://imagen.research.google/video/paper.pdf)
241256
generator: random number generator.
242257
return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
@@ -247,30 +262,36 @@ def step(
247262
returning a tuple, the first element is the sample tensor.
248263
249264
"""
265+
if self.variance_type == "v_diffusion":
266+
assert self.prediction_type == "v", "Need to use v prediction with v_diffusion"
250267
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
251268
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
252269
else:
253270
predicted_variance = None
254271

255272
# 1. compute alphas, betas
256273
alpha_prod_t = self.alphas_cumprod[timestep]
257-
alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one
274+
alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else torch.tensor(1.0)
258275
beta_prod_t = 1 - alpha_prod_t
259276
beta_prod_t_prev = 1 - alpha_prod_t_prev
260277

261278
# 2. compute predicted original sample from predicted noise also called
262279
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
263-
if prediction_type == "epsilon":
280+
if self.prediction_type == "v":
281+
# x_recon in p_mean_variance
282+
pred_original_sample = (
283+
sample * self.sqrt_alphas_cumprod[timestep]
284+
- model_output * self.sqrt_one_minus_alphas_cumprod[timestep]
285+
)
286+
elif self.prediction_type == "epsilon":
264287
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
265-
elif prediction_type == "sample":
288+
289+
elif self.prediction_type == "sample":
266290
pred_original_sample = model_output
267-
elif prediction_type == "v":
268-
# v_t = alpha_t * epsilon - sigma_t * x
269-
# need to merge the PRs for sigma to be available in DDPM
270-
pred = sample * self.alphas[timestep] - model_output * self.sigmas[timestep]
271-
eps = model_output * self.alphas[timestep] - sample * self.sigmas[timestep]
272291
else:
273-
raise ValueError(f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or `v`")
292+
raise ValueError(
293+
f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `v`"
294+
)
274295

275296
# 3. Clip "predicted x_0"
276297
if self.config.clip_sample:
@@ -291,7 +312,12 @@ def step(
291312
noise = torch.randn(
292313
model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator
293314
).to(model_output.device)
294-
variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise
315+
if self.variance_type == "fixed_small_log":
316+
variance = self._get_variance(timestep, predicted_variance=predicted_variance) * noise
317+
elif self.variance_type == "v_diffusion":
318+
variance = torch.exp(0.5 * self._get_variance(timestep, predicted_variance)) * noise
319+
else:
320+
variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise
295321

296322
pred_prev_sample = pred_prev_sample + variance
297323

@@ -306,6 +332,11 @@ def add_noise(
306332
noise: torch.FloatTensor,
307333
timesteps: torch.IntTensor,
308334
) -> torch.FloatTensor:
335+
if self.variance_type == "v_diffusion":
336+
alpha, sigma = self.get_alpha_sigma(original_samples, timesteps, original_samples.device)
337+
z_t = alpha * original_samples + sigma * noise
338+
return z_t
339+
309340
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
310341
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
311342
timesteps = timesteps.to(original_samples.device)
@@ -325,3 +356,8 @@ def add_noise(
325356

326357
def __len__(self):
327358
return self.config.num_train_timesteps
359+
360+
def get_alpha_sigma(self, sample, timesteps, device):
361+
alpha = expand_to_shape(self.sqrt_alphas_cumprod, timesteps, sample.shape, device)
362+
sigma = expand_to_shape(self.sqrt_one_minus_alphas_cumprod, timesteps, sample.shape, device)
363+
return alpha, sigma

0 commit comments

Comments
 (0)