Skip to content

Commit 2bdde4d

Browse files
patil-surajpatrickvonplaten
authored andcommitted
[schedulers] hanlde dtype in add_noise (#767)
* handle dtype in vae and image2image pipeline * handle dtype in add noise * don't modify vae and pipeline * remove the if
1 parent 91ddd2a commit 2bdde4d

File tree

4 files changed

+16
-18
lines changed

4 files changed

+16
-18
lines changed

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -300,11 +300,9 @@ def add_noise(
300300
noise: torch.FloatTensor,
301301
timesteps: torch.IntTensor,
302302
) -> torch.FloatTensor:
303-
if self.alphas_cumprod.device != original_samples.device:
304-
self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
305-
306-
if timesteps.device != original_samples.device:
307-
timesteps = timesteps.to(original_samples.device)
303+
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
304+
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
305+
timesteps = timesteps.to(original_samples.device)
308306

309307
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
310308
sqrt_alpha_prod = sqrt_alpha_prod.flatten()

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -294,11 +294,9 @@ def add_noise(
294294
noise: torch.FloatTensor,
295295
timesteps: torch.IntTensor,
296296
) -> torch.FloatTensor:
297-
if self.alphas_cumprod.device != original_samples.device:
298-
self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
299-
300-
if timesteps.device != original_samples.device:
301-
timesteps = timesteps.to(original_samples.device)
297+
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
298+
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
299+
timesteps = timesteps.to(original_samples.device)
302300

303301
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
304302
sqrt_alpha_prod = sqrt_alpha_prod.flatten()

src/diffusers/schedulers/scheduling_lms_discrete.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,13 @@ def add_noise(
257257
noise: torch.FloatTensor,
258258
timesteps: torch.FloatTensor,
259259
) -> torch.FloatTensor:
260-
sigmas = self.sigmas.to(original_samples.device)
261-
schedule_timesteps = self.timesteps.to(original_samples.device)
260+
# Make sure sigmas and timesteps have the same device and dtype as original_samples
261+
self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
262+
self.timesteps = self.timesteps.to(original_samples.device)
262263
timesteps = timesteps.to(original_samples.device)
264+
265+
schedule_timesteps = self.timesteps
266+
263267
if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor):
264268
deprecate(
265269
"timesteps as indices",
@@ -273,7 +277,7 @@ def add_noise(
273277
else:
274278
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
275279

276-
sigma = sigmas[step_indices].flatten()
280+
sigma = self.sigmas[step_indices].flatten()
277281
while len(sigma.shape) < len(original_samples.shape):
278282
sigma = sigma.unsqueeze(-1)
279283

src/diffusers/schedulers/scheduling_pndm.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -400,11 +400,9 @@ def add_noise(
400400
noise: torch.FloatTensor,
401401
timesteps: torch.IntTensor,
402402
) -> torch.Tensor:
403-
if self.alphas_cumprod.device != original_samples.device:
404-
self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
405-
406-
if timesteps.device != original_samples.device:
407-
timesteps = timesteps.to(original_samples.device)
403+
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
404+
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
405+
timesteps = timesteps.to(original_samples.device)
408406

409407
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
410408
sqrt_alpha_prod = sqrt_alpha_prod.flatten()

0 commit comments

Comments
 (0)