Skip to content

Commit 0ee3350

Browse files
committed
[bug fix] dpm multistep solver duplicate timesteps
1 parent a87e88b commit 0ee3350

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,14 +192,22 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
192192
device (`str` or `torch.device`, optional):
193193
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
194194
"""
195-
self.num_inference_steps = num_inference_steps
196195
timesteps = (
197196
np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
198197
.round()[::-1][:-1]
199198
.copy()
200199
.astype(np.int64)
201200
)
201+
202+
# when num_inference_steps == num_train_timesteps, we can end up with
203+
# duplicates in timesteps.
204+
_, unique_indices = np.unique(timesteps, return_index=True)
205+
timesteps = timesteps[np.sort(unique_indices)]
206+
202207
self.timesteps = torch.from_numpy(timesteps).to(device)
208+
209+
self.num_inference_steps = len(timesteps)
210+
203211
self.model_outputs = [
204212
None,
205213
] * self.config.solver_order

0 commit comments

Comments
 (0)