From 90e3ae20bf8cd5c9b5e1c76a63b8795d78df4077 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Sat, 10 Sep 2022 16:56:19 -0600 Subject: [PATCH 1/5] initial attempt at solving --- src/diffusers/schedulers/scheduling_ddim.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 894d63bf2df0..b1f330c1fc94 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -145,9 +145,10 @@ def set_timesteps(self, num_inference_steps: int, offset: int = 0): offset (`int`): TODO """ self.num_inference_steps = num_inference_steps - self.timesteps = np.arange( - 0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps - )[::-1].copy() + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multipling by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).astype(int)[::-1].copy() self.timesteps += offset self.set_format(tensor_format=self.tensor_format) From f9bbce342cfea1218e5244a5665efd86f5dd66ca Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Sun, 11 Sep 2022 16:17:50 -0600 Subject: [PATCH 2/5] fix pndm power of 3 inference_step --- src/diffusers/schedulers/scheduling_ddim.py | 2 +- src/diffusers/schedulers/scheduling_pndm.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index b1f330c1fc94..92fb2e9a3fc7 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -146,7 +146,7 @@ def set_timesteps(self, num_inference_steps: int, offset: int = 0): """ self.num_inference_steps = num_inference_steps step_ratio = self.config.num_train_timesteps // self.num_inference_steps - # creates integer timesteps by multipling by ratio + # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).astype(int)[::-1].copy() self.timesteps += offset diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index b43d88bbab77..e3a8d6596a3e 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -141,9 +141,10 @@ def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.Floa offset (`int`): TODO """ self.num_inference_steps = num_inference_steps - self._timesteps = list( - range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps) - ) + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).astype(int).tolist() self._offset = offset self._timesteps = np.array([t + self._offset for t in self._timesteps]) From a7359ac0cb8bb8e808091aa5c2b12017c8d5f0d5 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 12 Sep 2022 09:23:04 -0600 Subject: [PATCH 3/5] add power of 3 test --- tests/test_scheduler.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 3c2e786fc1f4..28431d4e6c8a 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -378,7 +378,22 @@ def test_time_indices(self): def test_inference_steps(self): for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]): - self.check_over_forward(num_inference_steps=num_inference_steps) + self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) + + def test_pow_of_3_inference_steps(self): + num_inference_steps = 27 + + for scheduler_class in self.scheduler_classes: + sample = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(num_inference_steps) + + for i, t in enumerate(scheduler.timesteps): + sample = scheduler.step(residual, i, sample).prev_sample def test_eta(self): for t, eta in zip([1, 10, 49], [0.0, 0.5, 1.0]): @@ -621,6 +636,21 @@ def test_inference_steps(self): for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]): self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) + def test_pow_of_3_inference_steps(self): + num_inference_steps = 27 + + for scheduler_class in self.scheduler_classes: + sample = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(num_inference_steps) + + for i, t in enumerate(scheduler.prk_timesteps): + sample = scheduler.step_prk(residual, i, sample).prev_sample + def test_inference_plms_no_past_residuals(self): with self.assertRaises(ValueError): scheduler_class = self.scheduler_classes[0] From 7b40737f5969b65d68a222d0ab48b2cd7a1e80d3 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 12 Sep 2022 09:44:17 -0600 Subject: [PATCH 4/5] fix index in pndm test, remove ddim test --- tests/test_scheduler.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 28431d4e6c8a..1a8857fb1f9e 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -380,21 +380,6 @@ def test_inference_steps(self): for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]): self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) - def test_pow_of_3_inference_steps(self): - num_inference_steps = 27 - - for scheduler_class in self.scheduler_classes: - sample = self.dummy_sample - residual = 0.1 * sample - - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - - scheduler.set_timesteps(num_inference_steps) - - for i, t in enumerate(scheduler.timesteps): - sample = scheduler.step(residual, i, sample).prev_sample - def test_eta(self): for t, eta in zip([1, 10, 49], [0.0, 0.5, 1.0]): self.check_over_forward(time_step=t, eta=eta) @@ -649,7 +634,7 @@ def test_pow_of_3_inference_steps(self): scheduler.set_timesteps(num_inference_steps) for i, t in enumerate(scheduler.prk_timesteps): - sample = scheduler.step_prk(residual, i, sample).prev_sample + sample = scheduler.step_prk(residual, t, sample).prev_sample def test_inference_plms_no_past_residuals(self): with self.assertRaises(ValueError): From ee3f490cde37a3fc4ee25aa18175400fc0c09a71 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Tue, 13 Sep 2022 09:41:31 -0600 Subject: [PATCH 5/5] add comments, change to round() --- src/diffusers/schedulers/scheduling_ddim.py | 2 +- src/diffusers/schedulers/scheduling_pndm.py | 2 +- tests/test_scheduler.py | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 92fb2e9a3fc7..6d95ae92e1d4 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -148,7 +148,7 @@ def set_timesteps(self, num_inference_steps: int, offset: int = 0): step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).astype(int)[::-1].copy() + self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy() self.timesteps += offset self.set_format(tensor_format=self.tensor_format) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index e3a8d6596a3e..54de882d6d9c 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -144,7 +144,7 @@ def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.Floa step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).astype(int).tolist() + self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().tolist() self._offset = offset self._timesteps = np.array([t + self._offset for t in self._timesteps]) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 1a8857fb1f9e..ee4ee5649b57 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -622,6 +622,7 @@ def test_inference_steps(self): self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) def test_pow_of_3_inference_steps(self): + # earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3 num_inference_steps = 27 for scheduler_class in self.scheduler_classes: @@ -633,7 +634,8 @@ def test_pow_of_3_inference_steps(self): scheduler.set_timesteps(num_inference_steps) - for i, t in enumerate(scheduler.prk_timesteps): + # before power of 3 fix, would error on first step, so we only need to do two + for i, t in enumerate(scheduler.prk_timesteps[:2]): sample = scheduler.step_prk(residual, t, sample).prev_sample def test_inference_plms_no_past_residuals(self):