Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion generative/networks/schedulers/ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __init__(

# setable values
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].astype(np.int64))

self.clip_sample = clip_sample
self.steps_offset = steps_offset
Expand All @@ -117,6 +117,13 @@ def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, to
num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
device: target device to put the data.
"""
if num_inference_steps > self.num_train_timesteps:
raise ValueError(
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:"
f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle"
f" maximal {self.num_train_timesteps} timesteps."
)

self.num_inference_steps = num_inference_steps
step_ratio = self.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
Expand Down
15 changes: 11 additions & 4 deletions generative/networks/schedulers/ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,18 @@ def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, to
num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
device: target device to put the data.
"""
num_inference_steps = min(self.num_train_timesteps, num_inference_steps)
if num_inference_steps > self.num_train_timesteps:
raise ValueError(
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:"
f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle"
f" maximal {self.num_train_timesteps} timesteps."
)

self.num_inference_steps = num_inference_steps
timesteps = np.arange(0, self.num_train_timesteps, self.num_train_timesteps // self.num_inference_steps)[
::-1
].copy()
step_ratio = self.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64)
self.timesteps = torch.from_numpy(timesteps).to(device)

def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:
Expand Down
9 changes: 8 additions & 1 deletion generative/networks/schedulers/pndm.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,18 @@ def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, to
num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
device: target device to put the data.
"""
if num_inference_steps > self.num_train_timesteps:
raise ValueError(
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:"
f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle"
f" maximal {self.num_train_timesteps} timesteps."
)

self.num_inference_steps = num_inference_steps
step_ratio = self.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()
self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64)
self._timesteps += self.steps_offset

if self.skip_prk_steps:
Expand Down
7 changes: 7 additions & 0 deletions tests/test_scheduler_ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@ def test_set_timesteps(self):
self.assertEqual(scheduler.num_inference_steps, 100)
self.assertEqual(len(scheduler.timesteps), 100)

def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self):
scheduler = DDIMScheduler(
num_train_timesteps=1000,
)
with self.assertRaises(ValueError):
scheduler.set_timesteps(num_inference_steps=2000)


if __name__ == "__main__":
unittest.main()
7 changes: 7 additions & 0 deletions tests/test_scheduler_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ def test_set_timesteps(self):
self.assertEqual(scheduler.num_inference_steps, 100)
self.assertEqual(len(scheduler.timesteps), 100)

def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self):
scheduler = DDPMScheduler(
num_train_timesteps=1000,
)
with self.assertRaises(ValueError):
scheduler.set_timesteps(num_inference_steps=2000)


if __name__ == "__main__":
unittest.main()
7 changes: 7 additions & 0 deletions tests/test_scheduler_pndm.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ def test_set_timesteps_prk(self):
self.assertEqual(scheduler.num_inference_steps, 109)
self.assertEqual(len(scheduler.timesteps), 109)

def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self):
scheduler = PNDMScheduler(
num_train_timesteps=1000,
)
with self.assertRaises(ValueError):
scheduler.set_timesteps(num_inference_steps=2000)


if __name__ == "__main__":
unittest.main()