Skip to content

Commit 7b030a7

Browse files
authored
handle device for randn in euler step (#1124)
* handle device for randn in euler step * convert device to str
1 parent 42bb459 commit 7b030a7

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,16 @@ def step(
217217
prev_sample = sample + derivative * dt
218218

219219
device = model_output.device if torch.is_tensor(model_output) else "cpu"
220-
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
220+
if str(device) == "mps":
221+
# randn does not work reproducibly on mps
222+
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
223+
device
224+
)
225+
else:
226+
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
227+
device
228+
)
229+
221230
prev_sample = prev_sample + noise * sigma_up
222231

223232
if not return_dict:

src/diffusers/schedulers/scheduling_euler_discrete.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,16 @@ def step(
214214
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
215215

216216
device = model_output.device if torch.is_tensor(model_output) else "cpu"
217-
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
217+
if str(device) == "mps":
218+
# randn does not work reproducibly on mps
219+
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
220+
device
221+
)
222+
else:
223+
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
224+
device
225+
)
226+
218227
eps = noise * s_noise
219228
sigma_hat = sigma * (gamma + 1)
220229

0 commit comments

Comments
 (0)