File tree Expand file tree Collapse file tree 2 files changed +20
-1
lines changed Expand file tree Collapse file tree 2 files changed +20
-1
lines changed Original file line number Diff line number Diff line change @@ -252,7 +252,8 @@ def add_noise(
252252 ) -> torch .FloatTensor :
253253 # Make sure sigmas and timesteps have the same device and dtype as original_samples
254254 self .sigmas = self .sigmas .to (device = original_samples .device , dtype = original_samples .dtype )
255- self .timesteps = self .timesteps .to (original_samples .device )
255+ dtype = torch .float32 if original_samples .device .type == "mps" else timesteps .dtype
256+ self .timesteps = self .timesteps .to (original_samples .device , dtype = dtype )
256257 timesteps = timesteps .to (original_samples .device )
257258
258259 schedule_timesteps = self .timesteps
Original file line number Diff line number Diff line change 2727 PNDMScheduler ,
2828 ScoreSdeVeScheduler ,
2929)
30+ from diffusers .utils import torch_device
3031
3132
3233torch .backends .cuda .matmul .allow_tf32 = False
@@ -258,6 +259,23 @@ def test_scheduler_public_api(self):
258259 scaled_sample = scheduler .scale_model_input (sample , 0.0 )
259260 self .assertEqual (sample .shape , scaled_sample .shape )
260261
262+ def test_add_noise_device (self ):
263+ for scheduler_class in self .scheduler_classes :
264+ if scheduler_class == IPNDMScheduler :
265+ # Skip until #990 is addressed
266+ continue
267+ scheduler_config = self .get_scheduler_config ()
268+ scheduler = scheduler_class (** scheduler_config )
269+
270+ sample = self .dummy_sample .to (torch_device )
271+ scaled_sample = scheduler .scale_model_input (sample , 0.0 )
272+ self .assertEqual (sample .shape , scaled_sample .shape )
273+
274+ noise = torch .randn_like (scaled_sample ).to (torch_device )
275+ t = torch .tensor ([10 ]).to (torch_device )
276+ noised = scheduler .add_noise (scaled_sample , noise , t )
277+ self .assertEqual (noised .shape , scaled_sample .shape )
278+
261279
262280class DDPMSchedulerTest (SchedulerCommonTest ):
263281 scheduler_classes = (DDPMScheduler ,)
You can’t perform that action at this time.
0 commit comments