Skip to content

Commit 3304538

Browse files
authored
[DDIM, DDPM] fix add_noise (#648)
fix add noise
1 parent e5eed52 commit 3304538

File tree

3 files changed

+12
-5
lines changed

3 files changed

+12
-5
lines changed

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,12 @@ def add_noise(
282282
noise: torch.FloatTensor,
283283
timesteps: torch.IntTensor,
284284
) -> torch.FloatTensor:
285-
timesteps = timesteps.to(self.alphas_cumprod.device)
285+
if self.alphas_cumprod.device != original_samples.device:
286+
self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
287+
288+
if timesteps.device != original_samples.device:
289+
timesteps = timesteps.to(original_samples.device)
290+
286291
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
287292
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
288293
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,15 +268,19 @@ def add_noise(
268268
noise: torch.FloatTensor,
269269
timesteps: torch.IntTensor,
270270
) -> torch.FloatTensor:
271-
timesteps = timesteps.to(self.alphas_cumprod.device)
271+
if self.alphas_cumprod.device != original_samples.device:
272+
self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
273+
274+
if timesteps.device != original_samples.device:
275+
timesteps = timesteps.to(original_samples.device)
272276

273277
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
274278
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
275279
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
276280
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
277281

278282
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
279-
sqrt_one_minus_alpha_prod.flatten()
283+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
280284
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
281285
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
282286

src/diffusers/schedulers/scheduling_pndm.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,8 +387,6 @@ def add_noise(
387387
if timesteps.device != original_samples.device:
388388
timesteps = timesteps.to(original_samples.device)
389389

390-
timesteps = timesteps.to(self.alphas_cumprod.device)
391-
392390
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
393391
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
394392
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):

0 commit comments

Comments
 (0)