From 6a4c62fcaa458fc1870455d9dffa28e6b12f0805 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 4 Oct 2022 21:37:43 +0200 Subject: [PATCH 1/4] pytorch timesteps --- docs/source/api/schedulers.mdx | 2 +- .../stable_diffusion/pipeline_stable_diffusion.py | 2 +- .../pipeline_stable_diffusion_img2img.py | 7 +++++-- .../pipeline_stable_diffusion_inpaint.py | 7 +++++-- src/diffusers/schedulers/README.md | 2 +- src/diffusers/schedulers/scheduling_ddim.py | 7 ++++--- src/diffusers/schedulers/scheduling_ddpm.py | 9 +++++---- src/diffusers/schedulers/scheduling_karras_ve.py | 9 +++++---- src/diffusers/schedulers/scheduling_pndm.py | 5 +++-- src/diffusers/schedulers/scheduling_sde_ve.py | 6 ++++-- src/diffusers/schedulers/scheduling_sde_vp.py | 6 +++--- tests/test_scheduler.py | 10 ++++++---- 12 files changed, 43 insertions(+), 29 deletions(-) diff --git a/docs/source/api/schedulers.mdx b/docs/source/api/schedulers.mdx index b5af14d4bf4a..12a6b5c587bc 100644 --- a/docs/source/api/schedulers.mdx +++ b/docs/source/api/schedulers.mdx @@ -36,7 +36,7 @@ This allows for rapid experimentation and cleaner abstractions in the code, wher To this end, the design of schedulers is such that: - Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality. -- Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Numpy support currently exists). +- Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Jax support currently exists). ## API diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index d190acb1fa1c..34678978ebad 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -251,7 +251,7 @@ def __call__( self.scheduler.set_timesteps(num_inference_steps) # Some schedulers like PNDM have timesteps as arrays - # It's more optimzed to move all timesteps to correct device beforehand + # It's more optimized to move all timesteps to correct device beforehand if torch.is_tensor(self.scheduler.timesteps): timesteps_tensor = self.scheduler.timesteps.to(self.device) else: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index c8f02b5896d6..06660fb2f847 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -283,8 +283,11 @@ def __call__( t_start = max(num_inference_steps - init_timestep + offset, 0) # Some schedulers like PNDM have timesteps as arrays - # It's more optimzed to move all timesteps to correct device beforehand - timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device) + # It's more optimized to move all timesteps to correct device beforehand + if torch.is_tensor(self.scheduler.timesteps): + timesteps_tensor = self.scheduler.timesteps[t_start:].to(self.device) + else: + timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:].copy(), device=self.device) for i, t in enumerate(self.progress_bar(timesteps_tensor)): t_index = t_start + i diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 21490d975730..e212f9024583 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -317,8 +317,11 @@ def __call__( t_start = max(num_inference_steps - init_timestep + offset, 0) # Some schedulers like PNDM have timesteps as arrays - # It's more optimzed to move all timesteps to correct device beforehand - timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device) + # It's more optimized to move all timesteps to correct device beforehand + if torch.is_tensor(self.scheduler.timesteps): + timesteps_tensor = self.scheduler.timesteps[t_start:].to(self.device) + else: + timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:].copy(), device=self.device) for i, t in tqdm(enumerate(timesteps_tensor)): t_index = t_start + i diff --git a/src/diffusers/schedulers/README.md b/src/diffusers/schedulers/README.md index b6b711ebbf3f..6a01c503a909 100644 --- a/src/diffusers/schedulers/README.md +++ b/src/diffusers/schedulers/README.md @@ -2,7 +2,7 @@ - Schedulers are the algorithms to use diffusion models in inference as well as for training. They include the noise schedules and define algorithm-specific diffusion steps. - Schedulers can be used interchangeable between diffusion models in inference to find the preferred trade-off between speed and generation quality. -- Schedulers are available in numpy, but can easily be transformed into PyTorch. +- Schedulers are available in PyTorch and Jax. ## API diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index a728ab29d7bb..44c7b268cb68 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -154,7 +154,7 @@ def __init__( # setable values self.num_inference_steps = None - self.timesteps = np.arange(0, num_train_timesteps)[::-1] + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) def _get_variance(self, timestep, prev_timestep): alpha_prod_t = self.alphas_cumprod[timestep] @@ -166,7 +166,7 @@ def _get_variance(self, timestep, prev_timestep): return variance - def set_timesteps(self, num_inference_steps: int, **kwargs): + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs): """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -183,7 +183,8 @@ def set_timesteps(self, num_inference_steps: int, **kwargs): step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1] + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy() + self.timesteps = torch.from_numpy(timesteps).to(device) self.timesteps += offset def step( diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 4d4e986a76ea..e5a7abfc3797 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -142,11 +142,11 @@ def __init__( # setable values self.num_inference_steps = None - self.timesteps = np.arange(0, num_train_timesteps)[::-1] + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) self.variance_type = variance_type - def set_timesteps(self, num_inference_steps: int): + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -156,9 +156,10 @@ def set_timesteps(self, num_inference_steps: int): """ num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps) self.num_inference_steps = num_inference_steps - self.timesteps = np.arange( + timesteps = np.arange( 0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps - )[::-1] + )[::-1].copy() + self.timesteps = torch.from_numpy(timesteps).to(device) def _get_variance(self, t, predicted_variance=None, variance_type=None): alpha_prod_t = self.alphas_cumprod[t] diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index 63e1400262d8..f8a7d9fe995e 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -97,10 +97,10 @@ def __init__( # setable values self.num_inference_steps: int = None - self.timesteps: np.ndarray = None + self.timesteps: np.IntTensor = None self.schedule: torch.FloatTensor = None # sigma(t_i) - def set_timesteps(self, num_inference_steps: int): + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -110,7 +110,8 @@ def set_timesteps(self, num_inference_steps: int): """ self.num_inference_steps = num_inference_steps - self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() + timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps).to(device) schedule = [ ( self.config.sigma_max**2 @@ -118,7 +119,7 @@ def set_timesteps(self, num_inference_steps: int): ) for i in self.timesteps ] - self.schedule = torch.tensor(schedule, dtype=torch.float32) + self.schedule = torch.tensor(schedule, dtype=torch.float32, device=device) def add_noise_to_input( self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 3974335a2f1b..86e9b35ccd8d 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -147,7 +147,7 @@ def __init__( self.plms_timesteps = None self.timesteps = None - def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor: + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs): """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -184,7 +184,8 @@ def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor ::-1 ].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy - self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64) + timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) self.ets = [] self.counter = 0 diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 12ed1a1b656e..9dda30e360de 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -89,7 +89,9 @@ def __init__( self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps) - def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None): + def set_timesteps( + self, num_inference_steps: int, sampling_eps: float = None, device: Union[str, torch.device] = None + ): """ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -101,7 +103,7 @@ def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None): """ sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps - self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps) + self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps, device=device) def set_sigmas( self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index 7cf1da44272a..6f4470bcb8d5 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -15,7 +15,7 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch # TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit - +from typing import Union import math import torch @@ -52,8 +52,8 @@ def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling self.discrete_sigmas = None self.timesteps = None - def set_timesteps(self, num_inference_steps): - self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) + def set_timesteps(self, num_inference_steps, device: Union[str, torch.device] = None): + self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps, device=device) def step_pred(self, score, x, t, generator=None): if self.timesteps is None: diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index bee36c39acdb..2eb9030155bf 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -354,7 +354,7 @@ def test_steps_offset(self): scheduler_config = self.get_scheduler_config(steps_offset=1) scheduler = scheduler_class(**scheduler_config) scheduler.set_timesteps(5) - assert np.equal(scheduler.timesteps, np.array([801, 601, 401, 201, 1])).all() + assert torch.equal(scheduler.timesteps, torch.IntTensor([801, 601, 401, 201, 1])) def test_betas(self): for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): @@ -568,10 +568,12 @@ def test_steps_offset(self): scheduler_config = self.get_scheduler_config(steps_offset=1) scheduler = scheduler_class(**scheduler_config) scheduler.set_timesteps(10) - assert np.equal( + assert torch.equal( scheduler.timesteps, - np.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]), - ).all() + torch.IntTensor( + [901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1] + ), + ) def test_betas(self): for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]): From 093e14e716af6b10a8ec3f2d87d498f9ac68205c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 4 Oct 2022 21:47:44 +0200 Subject: [PATCH 2/4] style --- src/diffusers/schedulers/scheduling_sde_vp.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index 6f4470bcb8d5..1130d3d99d59 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -14,9 +14,8 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch -# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit -from typing import Union import math +from typing import Union import torch From 588499dbfcf7c5451afc4d248fc686449f5bdc15 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 5 Oct 2022 07:57:24 +0200 Subject: [PATCH 3/4] get rid of if-else --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 5 +---- .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 5 +---- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 5 +---- 3 files changed, 3 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index b3ac9d41f65e..614367aea77e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -279,10 +279,7 @@ def __call__( # Some schedulers like PNDM have timesteps as arrays # It's more optimized to move all timesteps to correct device beforehand - if torch.is_tensor(self.scheduler.timesteps): - timesteps_tensor = self.scheduler.timesteps.to(self.device) - else: - timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device) + timesteps_tensor = self.scheduler.timesteps.to(self.device) # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas if isinstance(self.scheduler, LMSDiscreteScheduler): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 201716f7b671..473fbd9ad099 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -307,10 +307,7 @@ def __call__( # Some schedulers like PNDM have timesteps as arrays # It's more optimized to move all timesteps to correct device beforehand - if torch.is_tensor(self.scheduler.timesteps): - timesteps_tensor = self.scheduler.timesteps[t_start:].to(self.device) - else: - timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:].copy(), device=self.device) + timesteps_tensor = self.scheduler.timesteps[t_start:].to(self.device) for i, t in enumerate(self.progress_bar(timesteps_tensor)): t_index = t_start + i diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 1e707c8f7754..5e03ca332a1d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -345,10 +345,7 @@ def __call__( # Some schedulers like PNDM have timesteps as arrays # It's more optimized to move all timesteps to correct device beforehand - if torch.is_tensor(self.scheduler.timesteps): - timesteps_tensor = self.scheduler.timesteps[t_start:].to(self.device) - else: - timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:].copy(), device=self.device) + timesteps_tensor = self.scheduler.timesteps[t_start:].to(self.device) for i, t in tqdm(enumerate(timesteps_tensor)): t_index = t_start + i From 0302090cc363735473a952295af5286d8cdbd8d4 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 5 Oct 2022 08:43:52 +0200 Subject: [PATCH 4/4] fix test --- tests/test_scheduler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 2eb9030155bf..4e968aef70c4 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -354,7 +354,7 @@ def test_steps_offset(self): scheduler_config = self.get_scheduler_config(steps_offset=1) scheduler = scheduler_class(**scheduler_config) scheduler.set_timesteps(5) - assert torch.equal(scheduler.timesteps, torch.IntTensor([801, 601, 401, 201, 1])) + assert torch.equal(scheduler.timesteps, torch.LongTensor([801, 601, 401, 201, 1])) def test_betas(self): for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): @@ -570,7 +570,7 @@ def test_steps_offset(self): scheduler.set_timesteps(10) assert torch.equal( scheduler.timesteps, - torch.IntTensor( + torch.LongTensor( [901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1] ), )