From 6a0c1dba8afca93640b845e2ff69d259e7861c42 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 9 Sep 2022 20:20:51 +0200 Subject: [PATCH 1/3] Fix LMS scheduler indexing in `add_noise` #358. --- src/diffusers/schedulers/scheduling_lms_discrete.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 1381587febf1..55970db951a8 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -121,7 +121,7 @@ def set_timesteps(self, num_inference_steps: int): frac = np.mod(self.timesteps, 1.0) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] - self.sigmas = np.concatenate([sigmas, [0.0]]) + self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) self.derivatives = [] @@ -184,6 +184,7 @@ def add_noise( noise: Union[torch.FloatTensor, np.ndarray], timesteps: Union[torch.IntTensor, np.ndarray], ) -> Union[torch.FloatTensor, np.ndarray]: + timesteps = timesteps.to(self.sigmas.device) sigmas = self.match_shape(self.sigmas[timesteps], noise) noisy_samples = original_samples + noise * sigmas From c8e6b26c99db816a309a36063133c662826fb95b Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 9 Sep 2022 20:58:45 +0200 Subject: [PATCH 2/3] Fix DDIM and DDPM indexing with mps device. --- src/diffusers/schedulers/scheduling_ddim.py | 1 + src/diffusers/schedulers/scheduling_ddpm.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 894d63bf2df0..09a5a57cc457 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -249,6 +249,7 @@ def add_noise( noise: Union[torch.FloatTensor, np.ndarray], timesteps: Union[torch.IntTensor, np.ndarray], ) -> Union[torch.FloatTensor, np.ndarray]: + timesteps = timesteps.to(self.alphas_cumprod.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 020739406e71..248cb07d99c7 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -250,6 +250,7 @@ def add_noise( noise: Union[torch.FloatTensor, np.ndarray], timesteps: Union[torch.IntTensor, np.ndarray], ) -> Union[torch.FloatTensor, np.ndarray]: + timesteps = timesteps.to(self.alphas_cumprod.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 From bdaceed1e21fd55324abfd3a452f2ac959db60d9 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 14 Sep 2022 11:08:58 +0200 Subject: [PATCH 3/3] Verify format is PyTorch before using `.to()` --- src/diffusers/schedulers/scheduling_ddim.py | 3 ++- src/diffusers/schedulers/scheduling_ddpm.py | 3 ++- src/diffusers/schedulers/scheduling_lms_discrete.py | 3 ++- src/diffusers/schedulers/scheduling_pndm.py | 4 ++-- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 09a5a57cc457..9b123087a4eb 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -249,7 +249,8 @@ def add_noise( noise: Union[torch.FloatTensor, np.ndarray], timesteps: Union[torch.IntTensor, np.ndarray], ) -> Union[torch.FloatTensor, np.ndarray]: - timesteps = timesteps.to(self.alphas_cumprod.device) + if self.tensor_format == "pt": + timesteps = timesteps.to(self.alphas_cumprod.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 248cb07d99c7..b790995c153a 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -250,7 +250,8 @@ def add_noise( noise: Union[torch.FloatTensor, np.ndarray], timesteps: Union[torch.IntTensor, np.ndarray], ) -> Union[torch.FloatTensor, np.ndarray]: - timesteps = timesteps.to(self.alphas_cumprod.device) + if self.tensor_format == "pt": + timesteps = timesteps.to(self.alphas_cumprod.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 55970db951a8..c4e2712c5f18 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -184,7 +184,8 @@ def add_noise( noise: Union[torch.FloatTensor, np.ndarray], timesteps: Union[torch.IntTensor, np.ndarray], ) -> Union[torch.FloatTensor, np.ndarray]: - timesteps = timesteps.to(self.sigmas.device) + if self.tensor_format == "pt": + timesteps = timesteps.to(self.sigmas.device) sigmas = self.match_shape(self.sigmas[timesteps], noise) noisy_samples = original_samples + noise * sigmas diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index b43d88bbab77..1e3cff2919b4 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -364,8 +364,8 @@ def add_noise( noise: Union[torch.FloatTensor, np.ndarray], timesteps: Union[torch.IntTensor, np.ndarray], ) -> torch.Tensor: - # mps requires indices to be in the same device, so we use cpu as is the default with cuda - timesteps = timesteps.to(self.alphas_cumprod.device) + if self.tensor_format == "pt": + timesteps = timesteps.to(self.alphas_cumprod.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5