Skip to content

Commit cb0bf0b

Browse files
authored
fix(DDIM scheduler): use correct dtype for noise (#742)
Otherwise, it crashes when eta > 0 with float16.
1 parent e0fece2 commit cb0bf0b

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,8 +283,9 @@ def step(
283283
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
284284

285285
if eta > 0:
286+
# randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
286287
device = model_output.device if torch.is_tensor(model_output) else "cpu"
287-
noise = torch.randn(model_output.shape, generator=generator).to(device)
288+
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
288289
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
289290

290291
prev_sample = prev_sample + variance

0 commit comments

Comments
 (0)