Skip to content

Commit 1d04e1b

Browse files
authored
Continuation of #942: additional float64 failure (#996)
* Add failing test for #940. * Do not use torch.float64 in mps. * style * Temporarily skip add_noise for IPNDMScheduler. Until #990 is addressed. * Fix additional float64 error in mps. * Improve add_noise test * Slight edit – I think it's clearer this way.
1 parent a23ad87 commit 1d04e1b

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

src/diffusers/schedulers/scheduling_lms_discrete.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,9 +252,13 @@ def add_noise(
252252
) -> torch.FloatTensor:
253253
# Make sure sigmas and timesteps have the same device and dtype as original_samples
254254
self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
255-
dtype = torch.float32 if original_samples.device.type == "mps" else timesteps.dtype
256-
self.timesteps = self.timesteps.to(original_samples.device, dtype=dtype)
257-
timesteps = timesteps.to(original_samples.device)
255+
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
256+
# mps does not support float64
257+
self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
258+
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
259+
else:
260+
self.timesteps = self.timesteps.to(original_samples.device)
261+
timesteps = timesteps.to(original_samples.device)
258262

259263
schedule_timesteps = self.timesteps
260264

tests/test_scheduler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,13 +266,14 @@ def test_add_noise_device(self):
266266
continue
267267
scheduler_config = self.get_scheduler_config()
268268
scheduler = scheduler_class(**scheduler_config)
269+
scheduler.set_timesteps(100)
269270

270271
sample = self.dummy_sample.to(torch_device)
271272
scaled_sample = scheduler.scale_model_input(sample, 0.0)
272273
self.assertEqual(sample.shape, scaled_sample.shape)
273274

274275
noise = torch.randn_like(scaled_sample).to(torch_device)
275-
t = torch.tensor([10]).to(torch_device)
276+
t = scheduler.timesteps[5][None]
276277
noised = scheduler.add_noise(scaled_sample, noise, t)
277278
self.assertEqual(noised.shape, scaled_sample.shape)
278279

0 commit comments

Comments
 (0)