Skip to content

Commit 1e197c7

Browse files
authored
Do not use torch.float64 on the mps device (huggingface#942)
* Add failing test for huggingface#940. * Do not use torch.float64 in mps. * style * Temporarily skip add_noise for IPNDMScheduler. Until huggingface#990 is addressed.
1 parent 1b6b68c commit 1e197c7

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

schedulers/scheduling_lms_discrete.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,8 @@ def add_noise(
252252
) -> torch.FloatTensor:
253253
# Make sure sigmas and timesteps have the same device and dtype as original_samples
254254
self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
255-
self.timesteps = self.timesteps.to(original_samples.device)
255+
dtype = torch.float32 if original_samples.device.type == "mps" else timesteps.dtype
256+
self.timesteps = self.timesteps.to(original_samples.device, dtype=dtype)
256257
timesteps = timesteps.to(original_samples.device)
257258

258259
schedule_timesteps = self.timesteps

0 commit comments

Comments
 (0)