From 5f1b5d5e09323d4441d3f8afc90ae95afb06093d Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 27 Sep 2022 17:15:58 +0200 Subject: [PATCH] fix add noise --- src/diffusers/schedulers/scheduling_ddim.py | 7 ++++++- src/diffusers/schedulers/scheduling_ddpm.py | 8 ++++++-- src/diffusers/schedulers/scheduling_pndm.py | 2 -- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 6880700ecef0..ccff870609e8 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -282,7 +282,12 @@ def add_noise( noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: - timesteps = timesteps.to(self.alphas_cumprod.device) + if self.alphas_cumprod.device != original_samples.device: + self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device) + + if timesteps.device != original_samples.device: + timesteps = timesteps.to(original_samples.device) + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 0383dea224c7..7f8988fdfd43 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -268,7 +268,11 @@ def add_noise( noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: - timesteps = timesteps.to(self.alphas_cumprod.device) + if self.alphas_cumprod.device != original_samples.device: + self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device) + + if timesteps.device != original_samples.device: + timesteps = timesteps.to(original_samples.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() @@ -276,7 +280,7 @@ def add_noise( sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod.flatten() + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 1935a6ef93f2..ade223e2fbc0 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -387,8 +387,6 @@ def add_noise( if timesteps.device != original_samples.device: timesteps = timesteps.to(original_samples.device) - timesteps = timesteps.to(self.alphas_cumprod.device) - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape):