Skip to content

Commit f25f1c1

Browse files
committed
add device to set_timesteps in LMSD scheduler
1 parent e3c38e8 commit f25f1c1

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

src/diffusers/schedulers/scheduling_lms_discrete.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,22 +131,24 @@ def lms_derivative(tau):
131131

132132
return integrated_coeff
133133

134-
def set_timesteps(self, num_inference_steps: int):
134+
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs):
135135
"""
136136
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
137137
138138
Args:
139139
num_inference_steps (`int`):
140140
the number of diffusion steps used when generating samples with a pre-trained model.
141+
device (`str` or `torch.device`, optional):
142+
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
141143
"""
142144
self.num_inference_steps = num_inference_steps
143145

144146
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
145147
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
146148
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
147149
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
148-
self.sigmas = torch.from_numpy(sigmas)
149-
self.timesteps = torch.from_numpy(timesteps)
150+
self.sigmas = torch.from_numpy(sigmas).to(device=device)
151+
self.timesteps = torch.from_numpy(timesteps).to(device=device)
150152

151153
self.derivatives = []
152154

0 commit comments

Comments
 (0)