diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 2dc85a93adc9..2d24ecac1d95 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -300,11 +300,9 @@ def add_noise( noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: - if self.alphas_cumprod.device != original_samples.device: - self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device) - - if timesteps.device != original_samples.device: - timesteps = timesteps.to(original_samples.device) + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index e1db9079d149..77ed98137708 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -294,11 +294,9 @@ def add_noise( noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: - if self.alphas_cumprod.device != original_samples.device: - self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device) - - if timesteps.device != original_samples.device: - timesteps = timesteps.to(original_samples.device) + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 802da468cda6..1f6187c727c9 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -257,9 +257,13 @@ def add_noise( noise: torch.FloatTensor, timesteps: torch.FloatTensor, ) -> torch.FloatTensor: - sigmas = self.sigmas.to(original_samples.device) - schedule_timesteps = self.timesteps.to(original_samples.device) + # Make sure sigmas and timesteps have the same device and dtype as original_samples + self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + self.timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) + + schedule_timesteps = self.timesteps + if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor): deprecate( "timesteps as indices", @@ -273,7 +277,7 @@ def add_noise( else: step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - sigma = sigmas[step_indices].flatten() + sigma = self.sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index f6a6d6153be5..b29712e1e736 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -400,11 +400,9 @@ def add_noise( noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.Tensor: - if self.alphas_cumprod.device != original_samples.device: - self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device) - - if timesteps.device != original_samples.device: - timesteps = timesteps.to(original_samples.device) + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten()