Skip to content

Commit 0343d8f

Browse files
authored
Do not use torch.float64 on the mps device (#942)
* Add failing test for #940. * Do not use torch.float64 in mps. * style * Temporarily skip add_noise for IPNDMScheduler. Until #990 is addressed.
1 parent 4b9f589 commit 0343d8f

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

src/diffusers/schedulers/scheduling_lms_discrete.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,8 @@ def add_noise(
252252
) -> torch.FloatTensor:
253253
# Make sure sigmas and timesteps have the same device and dtype as original_samples
254254
self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
255-
self.timesteps = self.timesteps.to(original_samples.device)
255+
dtype = torch.float32 if original_samples.device.type == "mps" else timesteps.dtype
256+
self.timesteps = self.timesteps.to(original_samples.device, dtype=dtype)
256257
timesteps = timesteps.to(original_samples.device)
257258

258259
schedule_timesteps = self.timesteps

tests/test_scheduler.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
PNDMScheduler,
2828
ScoreSdeVeScheduler,
2929
)
30+
from diffusers.utils import torch_device
3031

3132

3233
torch.backends.cuda.matmul.allow_tf32 = False
@@ -258,6 +259,23 @@ def test_scheduler_public_api(self):
258259
scaled_sample = scheduler.scale_model_input(sample, 0.0)
259260
self.assertEqual(sample.shape, scaled_sample.shape)
260261

262+
def test_add_noise_device(self):
263+
for scheduler_class in self.scheduler_classes:
264+
if scheduler_class == IPNDMScheduler:
265+
# Skip until #990 is addressed
266+
continue
267+
scheduler_config = self.get_scheduler_config()
268+
scheduler = scheduler_class(**scheduler_config)
269+
270+
sample = self.dummy_sample.to(torch_device)
271+
scaled_sample = scheduler.scale_model_input(sample, 0.0)
272+
self.assertEqual(sample.shape, scaled_sample.shape)
273+
274+
noise = torch.randn_like(scaled_sample).to(torch_device)
275+
t = torch.tensor([10]).to(torch_device)
276+
noised = scheduler.add_noise(scaled_sample, noise, t)
277+
self.assertEqual(noised.shape, scaled_sample.shape)
278+
261279

262280
class DDPMSchedulerTest(SchedulerCommonTest):
263281
scheduler_classes = (DDPMScheduler,)

0 commit comments

Comments
 (0)