diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 894d63bf2df0..6d95ae92e1d4 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -145,9 +145,10 @@ def set_timesteps(self, num_inference_steps: int, offset: int = 0): offset (`int`): TODO """ self.num_inference_steps = num_inference_steps - self.timesteps = np.arange( - 0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps - )[::-1].copy() + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy() self.timesteps += offset self.set_format(tensor_format=self.tensor_format) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index b43d88bbab77..54de882d6d9c 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -141,9 +141,10 @@ def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.Floa offset (`int`): TODO """ self.num_inference_steps = num_inference_steps - self._timesteps = list( - range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps) - ) + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().tolist() self._offset = offset self._timesteps = np.array([t + self._offset for t in self._timesteps]) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 3c2e786fc1f4..ee4ee5649b57 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -378,7 +378,7 @@ def test_time_indices(self): def test_inference_steps(self): for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]): - self.check_over_forward(num_inference_steps=num_inference_steps) + self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) def test_eta(self): for t, eta in zip([1, 10, 49], [0.0, 0.5, 1.0]): @@ -621,6 +621,23 @@ def test_inference_steps(self): for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]): self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) + def test_pow_of_3_inference_steps(self): + # earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3 + num_inference_steps = 27 + + for scheduler_class in self.scheduler_classes: + sample = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(num_inference_steps) + + # before power of 3 fix, would error on first step, so we only need to do two + for i, t in enumerate(scheduler.prk_timesteps[:2]): + sample = scheduler.step_prk(residual, t, sample).prev_sample + def test_inference_plms_no_past_residuals(self): with self.assertRaises(ValueError): scheduler_class = self.scheduler_classes[0]