Skip to content

Commit b671cb0

Browse files
authored
Remove deprecated torch_device kwarg (#623)
* Remove deprecated `torch_device` kwarg. * Remove unused imports.
1 parent bb0c5d1 commit b671cb0

File tree

8 files changed

+1
-102
lines changed

8 files changed

+1
-102
lines changed

src/diffusers/pipelines/ddim/pipeline_ddim.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# limitations under the License.
1515

1616

17-
import warnings
1817
from typing import Optional, Tuple, Union
1918

2019
import torch
@@ -74,20 +73,6 @@ def __call__(
7473
generated images.
7574
"""
7675

77-
if "torch_device" in kwargs:
78-
device = kwargs.pop("torch_device")
79-
warnings.warn(
80-
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
81-
" Consider using `pipe.to(torch_device)` instead."
82-
)
83-
84-
# Set device as before (to be removed in 0.3.0)
85-
if device is None:
86-
device = "cuda" if torch.cuda.is_available() else "cpu"
87-
self.to(device)
88-
89-
# eta corresponds to η in paper and should be between [0, 1]
90-
9176
# Sample gaussian noise to begin loop
9277
image = torch.randn(
9378
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
@@ -103,6 +88,7 @@ def __call__(
10388
model_output = self.unet(image, t).sample
10489

10590
# 2. predict previous mean of image x_t-1 and add variance depending on eta
91+
# eta corresponds to η in paper and should be between [0, 1]
10692
# do x_t -> x_t-1
10793
image = self.scheduler.step(model_output, t, image, eta).prev_sample
10894

src/diffusers/pipelines/ddpm/pipeline_ddpm.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# limitations under the License.
1515

1616

17-
import warnings
1817
from typing import Optional, Tuple, Union
1918

2019
import torch
@@ -66,17 +65,6 @@ def __call__(
6665
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
6766
generated images.
6867
"""
69-
if "torch_device" in kwargs:
70-
device = kwargs.pop("torch_device")
71-
warnings.warn(
72-
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
73-
" Consider using `pipe.to(torch_device)` instead."
74-
)
75-
76-
# Set device as before (to be removed in 0.3.0)
77-
if device is None:
78-
device = "cuda" if torch.cuda.is_available() else "cpu"
79-
self.to(device)
8068

8169
# Sample gaussian noise to begin loop
8270
image = torch.randn(

src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import inspect
2-
import warnings
32
from typing import List, Optional, Tuple, Union
43

54
import torch
@@ -94,17 +93,6 @@ def __call__(
9493
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
9594
generated images.
9695
"""
97-
if "torch_device" in kwargs:
98-
device = kwargs.pop("torch_device")
99-
warnings.warn(
100-
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
101-
" Consider using `pipe.to(torch_device)` instead."
102-
)
103-
104-
# Set device as before (to be removed in 0.3.0)
105-
if device is None:
106-
device = "cuda" if torch.cuda.is_available() else "cpu"
107-
self.to(device)
10896

10997
if isinstance(prompt, str):
11098
batch_size = 1

src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import inspect
2-
import warnings
32
from typing import Optional, Tuple, Union
43

54
import torch
@@ -60,18 +59,6 @@ def __call__(
6059
generated images.
6160
"""
6261

63-
if "torch_device" in kwargs:
64-
device = kwargs.pop("torch_device")
65-
warnings.warn(
66-
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
67-
" Consider using `pipe.to(torch_device)` instead."
68-
)
69-
70-
# Set device as before (to be removed in 0.3.0)
71-
if device is None:
72-
device = "cuda" if torch.cuda.is_available() else "cpu"
73-
self.to(device)
74-
7562
latents = torch.randn(
7663
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
7764
generator=generator,

src/diffusers/pipelines/pndm/pipeline_pndm.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# limitations under the License.
1515

1616

17-
import warnings
1817
from typing import Optional, Tuple, Union
1918

2019
import torch
@@ -75,18 +74,6 @@ def __call__(
7574
# For more information on the sampling method you can take a look at Algorithm 2 of
7675
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
7776

78-
if "torch_device" in kwargs:
79-
device = kwargs.pop("torch_device")
80-
warnings.warn(
81-
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
82-
" Consider using `pipe.to(torch_device)` instead."
83-
)
84-
85-
# Set device as before (to be removed in 0.3.0)
86-
if device is None:
87-
device = "cuda" if torch.cuda.is_available() else "cpu"
88-
self.to(device)
89-
9077
# Sample gaussian noise to begin loop
9178
image = torch.randn(
9279
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),

src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python3
2-
import warnings
32
from typing import Optional, Tuple, Union
43

54
import torch
@@ -53,18 +52,6 @@ def __call__(
5352
generated images.
5453
"""
5554

56-
if "torch_device" in kwargs:
57-
device = kwargs.pop("torch_device")
58-
warnings.warn(
59-
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
60-
" Consider using `pipe.to(torch_device)` instead."
61-
)
62-
63-
# Set device as before (to be removed in 0.3.0)
64-
if device is None:
65-
device = "cuda" if torch.cuda.is_available() else "cpu"
66-
self.to(device)
67-
6855
img_size = self.unet.config.sample_size
6956
shape = (batch_size, 3, img_size, img_size)
7057

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -169,18 +169,6 @@ def __call__(
169169
(nsfw) content, according to the `safety_checker`.
170170
"""
171171

172-
if "torch_device" in kwargs:
173-
device = kwargs.pop("torch_device")
174-
warnings.warn(
175-
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
176-
" Consider using `pipe.to(torch_device)` instead."
177-
)
178-
179-
# Set device as before (to be removed in 0.3.0)
180-
if device is None:
181-
device = "cuda" if torch.cuda.is_available() else "cpu"
182-
self.to(device)
183-
184172
if isinstance(prompt, str):
185173
batch_size = 1
186174
elif isinstance(prompt, list):

src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python3
2-
import warnings
32
from typing import Optional, Tuple, Union
43

54
import torch
@@ -64,17 +63,6 @@ def __call__(
6463
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
6564
generated images.
6665
"""
67-
if "torch_device" in kwargs:
68-
device = kwargs.pop("torch_device")
69-
warnings.warn(
70-
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
71-
" Consider using `pipe.to(torch_device)` instead."
72-
)
73-
74-
# Set device as before (to be removed in 0.3.0)
75-
if device is None:
76-
device = "cuda" if torch.cuda.is_available() else "cpu"
77-
self.to(device)
7866

7967
img_size = self.unet.config.sample_size
8068
shape = (batch_size, 3, img_size, img_size)

0 commit comments

Comments
 (0)