diff --git a/generative/networks/schedulers/ddim.py b/generative/networks/schedulers/ddim.py index 364cc592..6f2b3bc6 100644 --- a/generative/networks/schedulers/ddim.py +++ b/generative/networks/schedulers/ddim.py @@ -104,7 +104,7 @@ def __init__( # setable values self.num_inference_steps = None - self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].astype(np.int64)) self.clip_sample = clip_sample self.steps_offset = steps_offset @@ -117,6 +117,13 @@ def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, to num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. device: target device to put the data. """ + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + self.num_inference_steps = num_inference_steps step_ratio = self.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio diff --git a/generative/networks/schedulers/ddpm.py b/generative/networks/schedulers/ddpm.py index 9f5ca107..bd4fac29 100644 --- a/generative/networks/schedulers/ddpm.py +++ b/generative/networks/schedulers/ddpm.py @@ -102,11 +102,18 @@ def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, to num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. device: target device to put the data. """ - num_inference_steps = min(self.num_train_timesteps, num_inference_steps) + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + self.num_inference_steps = num_inference_steps - timesteps = np.arange(0, self.num_train_timesteps, self.num_train_timesteps // self.num_inference_steps)[ - ::-1 - ].copy() + step_ratio = self.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64) self.timesteps = torch.from_numpy(timesteps).to(device) def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor: diff --git a/generative/networks/schedulers/pndm.py b/generative/networks/schedulers/pndm.py index 69218f6b..109019d0 100644 --- a/generative/networks/schedulers/pndm.py +++ b/generative/networks/schedulers/pndm.py @@ -121,11 +121,18 @@ def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, to num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. device: target device to put the data. """ + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + self.num_inference_steps = num_inference_steps step_ratio = self.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round() + self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64) self._timesteps += self.steps_offset if self.skip_prk_steps: diff --git a/tests/test_scheduler_ddim.py b/tests/test_scheduler_ddim.py index 802d85a7..206ae363 100644 --- a/tests/test_scheduler_ddim.py +++ b/tests/test_scheduler_ddim.py @@ -57,6 +57,13 @@ def test_set_timesteps(self): self.assertEqual(scheduler.num_inference_steps, 100) self.assertEqual(len(scheduler.timesteps), 100) + def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self): + scheduler = DDIMScheduler( + num_train_timesteps=1000, + ) + with self.assertRaises(ValueError): + scheduler.set_timesteps(num_inference_steps=2000) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_scheduler_ddpm.py b/tests/test_scheduler_ddpm.py index d0f3778c..cb635e31 100644 --- a/tests/test_scheduler_ddpm.py +++ b/tests/test_scheduler_ddpm.py @@ -72,6 +72,13 @@ def test_set_timesteps(self): self.assertEqual(scheduler.num_inference_steps, 100) self.assertEqual(len(scheduler.timesteps), 100) + def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self): + scheduler = DDPMScheduler( + num_train_timesteps=1000, + ) + with self.assertRaises(ValueError): + scheduler.set_timesteps(num_inference_steps=2000) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_scheduler_pndm.py b/tests/test_scheduler_pndm.py index 6f30302e..543f692e 100644 --- a/tests/test_scheduler_pndm.py +++ b/tests/test_scheduler_pndm.py @@ -67,6 +67,13 @@ def test_set_timesteps_prk(self): self.assertEqual(scheduler.num_inference_steps, 109) self.assertEqual(len(scheduler.timesteps), 109) + def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self): + scheduler = PNDMScheduler( + num_train_timesteps=1000, + ) + with self.assertRaises(ValueError): + scheduler.set_timesteps(num_inference_steps=2000) + if __name__ == "__main__": unittest.main()