docs/source/api/schedulers.mdx (2 changes: 1 addition & 1 deletion)
@@ -36,7 +36,7 @@ This allows for rapid experimentation and cleaner abstractions in the code, wher
 To this end, the design of schedulers is such that:
 
 - Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality.
-- Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Numpy support currently exists).
+- Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Jax support currently exists).
 
 
 ## API
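The interchangeability described in that bullet is the core design point. A minimal sketch of what it looks like in practice; the model id, the `from_pretrained`/`from_config` helpers, and the `.images` output access are illustrative assumptions for a recent diffusers version, not part of this diff:

```python
import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)

# Swap the default scheduler for DDIM: only the inference-time denoising
# algorithm changes; the model weights are untouched.
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

# Fewer inference steps trade generation quality for speed.
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=25).images[0]
```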
@@ -278,11 +278,8 @@ def __call__(
         self.scheduler.set_timesteps(num_inference_steps)
 
         # Some schedulers like PNDM have timesteps as arrays
-        # It's more optimzed to move all timesteps to correct device beforehand
-        if torch.is_tensor(self.scheduler.timesteps):
-            timesteps_tensor = self.scheduler.timesteps.to(self.device)
-        else:
-            timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device)
+        # It's more optimized to move all timesteps to correct device beforehand
+        timesteps_tensor = self.scheduler.timesteps.to(self.device)
 
         # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
         if isinstance(self.scheduler, LMSDiscreteScheduler):
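The deleted `torch.is_tensor` branch existed because schedulers such as PNDM previously stored `timesteps` as NumPy arrays. After this change every scheduler in the diff exposes a `torch.Tensor`, so a single up-front `.to(device)` covers all cases. A small sketch of the invariant the pipelines now rely on:

```python
import torch
from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler

for scheduler_class in (DDIMScheduler, DDPMScheduler, PNDMScheduler):
    scheduler = scheduler_class()
    scheduler.set_timesteps(10)
    # Every scheduler now stores its schedule as a torch.Tensor, so one
    # device transfer before the denoising loop replaces per-type handling.
    assert torch.is_tensor(scheduler.timesteps)
```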
@@ -304,7 +304,10 @@ def __call__(
         latents = init_latents
 
         t_start = max(num_inference_steps - init_timestep + offset, 0)
-        timesteps = self.scheduler.timesteps[t_start:]
+
+        # Some schedulers like PNDM have timesteps as arrays
+        # It's more optimized to move all timesteps to correct device beforehand
+        timesteps = self.scheduler.timesteps[t_start:].to(self.device)
 
         for i, t in enumerate(self.progress_bar(timesteps)):
             t_index = t_start + i
@@ -342,7 +342,10 @@ def __call__(
         latents = init_latents
 
         t_start = max(num_inference_steps - init_timestep + offset, 0)
-        timesteps = self.scheduler.timesteps[t_start:]
+
+        # Some schedulers like PNDM have timesteps as arrays
+        # It's more optimized to move all timesteps to correct device beforehand
+        timesteps = self.scheduler.timesteps[t_start:].to(self.device)
 
         for i, t in tqdm(enumerate(timesteps)):
             t_index = t_start + i
src/diffusers/schedulers/README.md (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,7 @@
 
 - Schedulers are the algorithms to use diffusion models in inference as well as for training. They include the noise schedules and define algorithm-specific diffusion steps.
 - Schedulers can be used interchangeable between diffusion models in inference to find the preferred trade-off between speed and generation quality.
-- Schedulers are available in numpy, but can easily be transformed into PyTorch.
+- Schedulers are available in PyTorch and Jax.
 
 ## API
 
src/diffusers/schedulers/scheduling_ddim.py (7 changes: 4 additions & 3 deletions)
@@ -154,7 +154,7 @@ def __init__(
 
         # setable values
         self.num_inference_steps = None
-        self.timesteps = np.arange(0, num_train_timesteps)[::-1]
+        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
 
     def _get_variance(self, timestep, prev_timestep):
         alpha_prod_t = self.alphas_cumprod[timestep]
@@ -166,7 +166,7 @@ def _get_variance(self, timestep, prev_timestep):
 
         return variance
 
-    def set_timesteps(self, num_inference_steps: int, **kwargs):
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs):
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -183,7 +183,8 @@ def set_timesteps(self, num_inference_steps: int, **kwargs):
         step_ratio = self.config.num_train_timesteps // self.num_inference_steps
         # creates integer timesteps by multiplying by ratio
         # casting to int to avoid issues when num_inference_step is power of 3
-        self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1]
+        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
+        self.timesteps = torch.from_numpy(timesteps).to(device)
         self.timesteps += offset
 
     def step(
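With the new `device` argument the schedule lands on the execution device in a single call. A usage sketch; the printed values assume the default 1000 training timesteps and a zero steps offset:

```python
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)
device = "cuda" if torch.cuda.is_available() else "cpu"
scheduler.set_timesteps(num_inference_steps=50, device=device)

# step_ratio = 1000 // 50 = 20, so the schedule descends in strides of 20.
print(scheduler.timesteps[:3])     # tensor([980, 960, 940], device=...)
print(scheduler.timesteps.device)  # matches `device`
```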
src/diffusers/schedulers/scheduling_ddpm.py (9 changes: 5 additions & 4 deletions)
@@ -142,11 +142,11 @@ def __init__(
 
         # setable values
         self.num_inference_steps = None
-        self.timesteps = np.arange(0, num_train_timesteps)[::-1]
+        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
 
         self.variance_type = variance_type
 
-    def set_timesteps(self, num_inference_steps: int):
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
 
@@ -156,9 +156,10 @@ def set_timesteps(self, num_inference_steps: int):
         """
         num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
         self.num_inference_steps = num_inference_steps
-        self.timesteps = np.arange(
+        timesteps = np.arange(
             0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
-        )[::-1]
+        )[::-1].copy()
+        self.timesteps = torch.from_numpy(timesteps).to(device)
 
     def _get_variance(self, t, predicted_variance=None, variance_type=None):
         alpha_prod_t = self.alphas_cumprod[t]
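Note the `min(...)` guard in `set_timesteps` above: a request for more inference steps than training timesteps is clipped rather than rejected. A quick sketch of the resulting behavior:

```python
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(2000)  # more steps than the model was trained with

# The request is clipped to num_train_timesteps, so the schedule never
# references a timestep the model has not seen.
print(scheduler.num_inference_steps)  # 1000
print(scheduler.timesteps.shape)      # torch.Size([1000])
```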
src/diffusers/schedulers/scheduling_karras_ve.py (9 changes: 5 additions & 4 deletions)
@@ -97,10 +97,10 @@ def __init__(
 
         # setable values
         self.num_inference_steps: int = None
-        self.timesteps: np.ndarray = None
+        self.timesteps: torch.IntTensor = None
         self.schedule: torch.FloatTensor = None  # sigma(t_i)
 
-    def set_timesteps(self, num_inference_steps: int):
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -110,15 +110,16 @@ def set_timesteps(self, num_inference_steps: int):
         """
         self.num_inference_steps = num_inference_steps
-        self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
+        timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
+        self.timesteps = torch.from_numpy(timesteps).to(device)
         schedule = [
             (
                 self.config.sigma_max**2
                 * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
             )
             for i in self.timesteps
         ]
-        self.schedule = torch.tensor(schedule, dtype=torch.float32)
+        self.schedule = torch.tensor(schedule, dtype=torch.float32, device=device)
 
     def add_noise_to_input(
         self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
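For reference, the list comprehension that builds `schedule` is a geometric interpolation between the two noise extremes; written out, with N = `num_inference_steps` and i the loop variable:

```latex
% Value computed for each timestep index i:
\sigma^2(i) \;=\; \sigma_{\max}^{2}\,
  \left(\frac{\sigma_{\min}^{2}}{\sigma_{\max}^{2}}\right)^{i/(N-1)}
% a geometric ramp with sigma^2(0) = sigma_max^2 and
% sigma^2(N-1) = sigma_min^2, matching the "sigma(t_i)" comment above.
```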
src/diffusers/schedulers/scheduling_pndm.py (5 changes: 3 additions & 2 deletions)
@@ -147,7 +147,7 @@ def __init__(
         self.plms_timesteps = None
         self.timesteps = None
 
-    def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor:
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs):
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -184,7 +184,8 @@ def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor:
             ::-1
         ].copy()  # we copy to avoid having negative strides which are not supported by torch.from_numpy
 
-        self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
+        timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
+        self.timesteps = torch.from_numpy(timesteps).to(device)
 
         self.ets = []
         self.counter = 0
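The `.copy()` calls in this PR (and the comment above) are load-bearing: `torch.from_numpy` shares memory with the source array and rejects the negative strides that `[::-1]` produces. A minimal reproduction:

```python
import numpy as np
import torch

reversed_view = np.arange(10)[::-1]  # a stride -1 view, no data copied

# torch.from_numpy(reversed_view) would raise:
#   ValueError: At least one stride in the given numpy array is negative
timesteps = torch.from_numpy(reversed_view.copy())  # contiguous copy works
print(timesteps[:3])  # tensor([9, 8, 7])
```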
src/diffusers/schedulers/scheduling_sde_ve.py (6 changes: 4 additions & 2 deletions)
@@ -89,7 +89,9 @@ def __init__(
 
         self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)
 
-    def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None):
+    def set_timesteps(
+        self, num_inference_steps: int, sampling_eps: float = None, device: Union[str, torch.device] = None
+    ):
         """
         Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
 
@@ -101,7 +103,7 @@ def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None):
         """
         sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
 
-        self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps)
+        self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps, device=device)
 
     def set_sigmas(
         self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None
src/diffusers/schedulers/scheduling_sde_vp.py (7 changes: 3 additions & 4 deletions)
@@ -14,9 +14,8 @@
 
 # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
 
-# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit
-
 import math
+from typing import Union
 
 import torch
@@ -52,8 +51,8 @@ def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling
         self.discrete_sigmas = None
         self.timesteps = None
 
-    def set_timesteps(self, num_inference_steps):
-        self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
+    def set_timesteps(self, num_inference_steps, device: Union[str, torch.device] = None):
+        self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps, device=device)
 
     def step_pred(self, score, x, t, generator=None):
         if self.timesteps is None:
tests/test_scheduler.py (10 changes: 6 additions & 4 deletions)
@@ -354,7 +354,7 @@ def test_steps_offset(self):
         scheduler_config = self.get_scheduler_config(steps_offset=1)
         scheduler = scheduler_class(**scheduler_config)
         scheduler.set_timesteps(5)
-        assert np.equal(scheduler.timesteps, np.array([801, 601, 401, 201, 1])).all()
+        assert torch.equal(scheduler.timesteps, torch.LongTensor([801, 601, 401, 201, 1]))
 
     def test_betas(self):
         for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
@@ -568,10 +568,12 @@ def test_steps_offset(self):
         scheduler_config = self.get_scheduler_config(steps_offset=1)
         scheduler = scheduler_class(**scheduler_config)
         scheduler.set_timesteps(10)
-        assert np.equal(
+        assert torch.equal(
             scheduler.timesteps,
-            np.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]),
-        ).all()
+            torch.LongTensor(
+                [901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]
+            ),
+        )
 
     def test_betas(self):
         for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]):
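`torch.equal` replaces the `np.equal(...).all()` idiom here. It returns `True` only when shape, dtype, and values all match, which is why the expected tensors are built as `LongTensor` (int64) to mirror what the schedulers now produce. A quick illustration:

```python
import torch

a = torch.LongTensor([801, 601, 401, 201, 1])
b = torch.tensor([801, 601, 401, 201, 1])  # int64 by default
c = b.float()                              # same values, float32

print(torch.equal(a, b))  # True: identical shape, dtype, and values
print(torch.equal(a, c))  # False: a dtype mismatch counts as unequal
```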