diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py
index 39f8f17df5d3..070ca20cfa20 100644
--- a/src/diffusers/schedulers/scheduling_deis_multistep.py
+++ b/src/diffusers/schedulers/scheduling_deis_multistep.py
@@ -181,17 +181,23 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
             device (`str` or `torch.device`, optional):
                 the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
-        self.num_inference_steps = num_inference_steps
         timesteps = (
             np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
             .round()[::-1][:-1]
             .copy()
             .astype(np.int64)
         )
+
+        # when num_inference_steps == num_train_timesteps, we can end up with
+        # duplicates in timesteps.
+        _, unique_indices = np.unique(timesteps, return_index=True)
+        timesteps = timesteps[np.sort(unique_indices)]
+
         self.timesteps = torch.from_numpy(timesteps).to(device)
-        self.model_outputs = [
-            None,
-        ] * self.config.solver_order
+
+        self.num_inference_steps = len(timesteps)
+
+        self.model_outputs = [None] * self.config.solver_order
         self.lower_order_nums = 0
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index 474d9b0d7339..7f4a377d8e82 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -192,17 +192,23 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
             device (`str` or `torch.device`, optional):
                 the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
-        self.num_inference_steps = num_inference_steps
         timesteps = (
             np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
             .round()[::-1][:-1]
             .copy()
             .astype(np.int64)
         )
+
+        # when num_inference_steps == num_train_timesteps, we can end up with
+        # duplicates in timesteps.
+        _, unique_indices = np.unique(timesteps, return_index=True)
+        timesteps = timesteps[np.sort(unique_indices)]
+
         self.timesteps = torch.from_numpy(timesteps).to(device)
-        self.model_outputs = [
-            None,
-        ] * self.config.solver_order
+
+        self.num_inference_steps = len(timesteps)
+
+        self.model_outputs = [None] * self.config.solver_order
         self.lower_order_nums = 0
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py
index e4f38d0f5dad..16f4b1031129 100644
--- a/src/diffusers/schedulers/scheduling_unipc_multistep.py
+++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py
@@ -194,21 +194,26 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
             device (`str` or `torch.device`, optional):
                 the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
-        self.num_inference_steps = num_inference_steps
         timesteps = (
             np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
             .round()[::-1][:-1]
             .copy()
             .astype(np.int64)
        )
+
+        # when num_inference_steps == num_train_timesteps, we can end up with
+        # duplicates in timesteps.
+        _, unique_indices = np.unique(timesteps, return_index=True)
+        timesteps = timesteps[np.sort(unique_indices)]
+
         self.timesteps = torch.from_numpy(timesteps).to(device)
-        self.model_outputs = [
-            None,
-        ] * self.config.solver_order
+
+        self.num_inference_steps = len(timesteps)
+        self.model_outputs = [None] * self.config.solver_order
         self.lower_order_nums = 0
         self.last_sample = None
         if self.solver_p:
-            self.solver_p.set_timesteps(num_inference_steps, device=device)
+            self.solver_p.set_timesteps(self.num_inference_steps, device=device)
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
diff --git a/tests/schedulers/test_scheduler_deis.py b/tests/schedulers/test_scheduler_deis.py
index 8b14601bc982..e587a5b9824b 100644
--- a/tests/schedulers/test_scheduler_deis.py
+++ b/tests/schedulers/test_scheduler_deis.py
@@ -162,6 +162,15 @@ def test_timesteps(self):
         for timesteps in [25, 50, 100, 999, 1000]:
             self.check_over_configs(num_train_timesteps=timesteps)
 
+    def test_unique_timesteps(self, **config):
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+
+            if hasattr(scheduler, "set_timesteps"):
+                scheduler.set_timesteps(1000)
+                assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps
+
     def test_thresholding(self):
         self.check_over_configs(thresholding=False)
         for order in [1, 2, 3]:
diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py
index 295bbe882746..85518a3c9214 100644
--- a/tests/schedulers/test_scheduler_dpm_multi.py
+++ b/tests/schedulers/test_scheduler_dpm_multi.py
@@ -228,6 +228,15 @@ def test_switch(self):
 
         assert abs(result_mean.item() - 0.3301) < 1e-3
 
+    def test_unique_timesteps(self, **config):
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+
+            if hasattr(scheduler, "set_timesteps"):
+                scheduler.set_timesteps(1000)
+                assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps
+
     def test_fp16_support(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
diff --git a/tests/schedulers/test_scheduler_dpm_single.py b/tests/schedulers/test_scheduler_dpm_single.py
index 9dff04e7c998..6e67f36e1881 100644
--- a/tests/schedulers/test_scheduler_dpm_single.py
+++ b/tests/schedulers/test_scheduler_dpm_single.py
@@ -195,6 +195,15 @@ def test_full_loop_with_v_prediction(self):
 
         assert abs(result_mean.item() - 0.1453) < 1e-3
 
+    def test_unique_timesteps(self, **config):
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+
+            if hasattr(scheduler, "set_timesteps"):
+                scheduler.set_timesteps(1000)
+                assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps
+
     def test_fp16_support(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
diff --git a/tests/schedulers/test_scheduler_unipc.py b/tests/schedulers/test_scheduler_unipc.py
index 6154c8e2d625..072b179e09f5 100644
--- a/tests/schedulers/test_scheduler_unipc.py
+++ b/tests/schedulers/test_scheduler_unipc.py
@@ -160,6 +160,15 @@ def test_timesteps(self):
         for timesteps in [25, 50, 100, 999, 1000]:
             self.check_over_configs(num_train_timesteps=timesteps)
 
+    def test_unique_timesteps(self, **config):
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+
+            if hasattr(scheduler, "set_timesteps"):
+                scheduler.set_timesteps(1000)
+                assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps
+
     def test_thresholding(self):
         self.check_over_configs(thresholding=False)
         for order in [1, 2, 3]:
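
For reference, below is a minimal standalone sketch of the deduplication idiom the patch applies in all three schedulers. It is hypothetical example code, not part of the patch itself, and assumes the schedulers' default num_train_timesteps of 1000:

import numpy as np

num_train_timesteps = 1000
num_inference_steps = 1000  # the edge case: as many inference steps as train timesteps

# Build the descending schedule the same way set_timesteps does.
timesteps = (
    np.linspace(0, num_train_timesteps - 1, num_inference_steps + 1)
    .round()[::-1][:-1]
    .copy()
    .astype(np.int64)
)

# Rounding 1001 evenly spaced points onto 1000 integers necessarily
# produces duplicates (e.g. 499.5 and 500.499 both round to 500).
assert len(np.unique(timesteps)) < len(timesteps)

# np.unique returns the sorted unique values plus the index of the first
# occurrence of each; re-sorting those indices keeps the original
# descending order while dropping the later duplicates.
_, unique_indices = np.unique(timesteps, return_index=True)
timesteps = timesteps[np.sort(unique_indices)]

assert len(np.unique(timesteps)) == len(timesteps)

This is also why the patch re-derives num_inference_steps from len(timesteps) after deduplication: the schedule can end up shorter than the value the caller passed in.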