Skip to content

Commit 00d5a51

Browse files
committed
self.scheduler.set_timesteps now uses device arg for schedulers that accept it
1 parent f25f1c1 commit 00d5a51

File tree

11 files changed

+12
-13
lines changed

11 files changed

+12
-13
lines changed

examples/community/clip_guided_stable_diffusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def __call__(
259259
if accepts_offset:
260260
extra_set_kwargs["offset"] = 1
261261

262-
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
262+
self.scheduler.set_timesteps(num_inference_steps, device=self.device, **extra_set_kwargs)
263263

264264
# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
265265
if isinstance(self.scheduler, LMSDiscreteScheduler):

src/diffusers/pipelines/ddim/pipeline_ddim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def __call__(
8080
image = image.to(self.device)
8181

8282
# set step values
83-
self.scheduler.set_timesteps(num_inference_steps)
83+
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
8484

8585
for t in self.progress_bar(self.scheduler.timesteps):
8686
# 1. predict noise model_output

src/diffusers/pipelines/ddpm/pipeline_ddpm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def __call__(
7373
image = image.to(self.device)
7474

7575
# set step values
76-
self.scheduler.set_timesteps(1000)
76+
self.scheduler.set_timesteps(1000, device=self.device)
7777

7878
for t in self.progress_bar(self.scheduler.timesteps):
7979
# 1. predict noise model_output

src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def __call__(
118118
)
119119
latents = latents.to(self.device)
120120

121-
self.scheduler.set_timesteps(num_inference_steps)
121+
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
122122

123123
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
124124
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())

src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __call__(
6464
)
6565
latents = latents.to(self.device)
6666

67-
self.scheduler.set_timesteps(num_inference_steps)
67+
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
6868

6969
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
7070
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())

src/diffusers/pipelines/pndm/pipeline_pndm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616

17+
from os import device_encoding
1718
from typing import Optional, Tuple, Union
1819

1920
import torch
@@ -80,7 +81,7 @@ def __call__(
8081
)
8182
image = image.to(self.device)
8283

83-
self.scheduler.set_timesteps(num_inference_steps)
84+
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
8485
for t in self.progress_bar(self.scheduler.timesteps):
8586
model_output = self.unet(image, t).sample
8687

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,9 +233,7 @@ def __call__(
233233
latents = latents.to(latents_device)
234234

235235
# set timesteps
236-
self.scheduler.set_timesteps(num_inference_steps)
237-
if isinstance(self.scheduler.timesteps, torch.Tensor):
238-
self.scheduler.timesteps = self.scheduler.timesteps.to(self.device)
236+
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
239237

240238
# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
241239
if isinstance(self.scheduler, LMSDiscreteScheduler):

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def __call__(
189189
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
190190

191191
# set timesteps
192-
self.scheduler.set_timesteps(num_inference_steps)
192+
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
193193

194194
if isinstance(init_image, PIL.Image.Image):
195195
init_image = preprocess(init_image)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def __call__(
209209
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
210210

211211
# set timesteps
212-
self.scheduler.set_timesteps(num_inference_steps)
212+
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
213213

214214
# preprocess image
215215
if not isinstance(init_image, torch.FloatTensor):

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def __call__(
111111
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
112112

113113
# set timesteps
114-
self.scheduler.set_timesteps(num_inference_steps)
114+
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
115115

116116
# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
117117
if isinstance(self.scheduler, LMSDiscreteScheduler):

0 commit comments

Comments
 (0)