Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions src/diffusers/schedulers/scheduling_deis_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,17 +181,23 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
device (`str` or `torch.device`, optional):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps
timesteps = (
np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)
)

# when num_inference_steps == num_train_timesteps, we can end up with
# duplicates in timesteps.
_, unique_indices = np.unique(timesteps, return_index=True)
timesteps = timesteps[np.sort(unique_indices)]

self.timesteps = torch.from_numpy(timesteps).to(device)
self.model_outputs = [
None,
] * self.config.solver_order

self.num_inference_steps = len(timesteps)

self.model_outputs = [None] * self.config.solver_order
self.lower_order_nums = 0

# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
Expand Down
14 changes: 10 additions & 4 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,17 +192,23 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
device (`str` or `torch.device`, optional):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps
timesteps = (
np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)
)

# when num_inference_steps == num_train_timesteps, we can end up with
# duplicates in timesteps.
_, unique_indices = np.unique(timesteps, return_index=True)
timesteps = timesteps[np.sort(unique_indices)]

self.timesteps = torch.from_numpy(timesteps).to(device)
self.model_outputs = [
None,
] * self.config.solver_order

self.num_inference_steps = len(timesteps)

self.model_outputs = [None] * self.config.solver_order
self.lower_order_nums = 0

# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
Expand Down
15 changes: 10 additions & 5 deletions src/diffusers/schedulers/scheduling_unipc_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,21 +194,26 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
device (`str` or `torch.device`, optional):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps
timesteps = (
np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)
)

# when num_inference_steps == num_train_timesteps, we can end up with
# duplicates in timesteps.
_, unique_indices = np.unique(timesteps, return_index=True)
timesteps = timesteps[np.sort(unique_indices)]

self.timesteps = torch.from_numpy(timesteps).to(device)
self.model_outputs = [
None,
] * self.config.solver_order

self.num_inference_steps = len(timesteps)
self.model_outputs = [None] * self.config.solver_order
self.lower_order_nums = 0
self.last_sample = None
if self.solver_p:
self.solver_p.set_timesteps(num_inference_steps, device=device)
self.solver_p.set_timesteps(self.num_inference_steps, device=device)

# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
Expand Down
9 changes: 9 additions & 0 deletions tests/schedulers/test_scheduler_deis.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,15 @@ def test_timesteps(self):
for timesteps in [25, 50, 100, 999, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)

def test_unique_timesteps(self, **config):
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)

if hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(1000)
assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps

def test_thresholding(self):
self.check_over_configs(thresholding=False)
for order in [1, 2, 3]:
Expand Down
9 changes: 9 additions & 0 deletions tests/schedulers/test_scheduler_dpm_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,15 @@ def test_switch(self):

assert abs(result_mean.item() - 0.3301) < 1e-3

def test_unique_timesteps(self, **config):
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)

if hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(1000)
assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps

def test_fp16_support(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
Expand Down
9 changes: 9 additions & 0 deletions tests/schedulers/test_scheduler_dpm_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,15 @@ def test_full_loop_with_v_prediction(self):

assert abs(result_mean.item() - 0.1453) < 1e-3

def test_unique_timesteps(self, **config):
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)

if hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(1000)
assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps

def test_fp16_support(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
Expand Down
9 changes: 9 additions & 0 deletions tests/schedulers/test_scheduler_unipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,15 @@ def test_timesteps(self):
for timesteps in [25, 50, 100, 999, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)

def test_unique_timesteps(self, **config):
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)

if hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(1000)
assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps

def test_thresholding(self):
self.check_over_configs(thresholding=False)
for order in [1, 2, 3]:
Expand Down