Skip to content

Commit 51fbd29

Browse files
committed
Un-hardcode "cuda" as default device name
Allow configuring with `SGM_DEFAULT_DEVICE`
1 parent 45c443b commit 51fbd29

File tree

5 files changed

+41
-17
lines changed

5 files changed

+41
-17
lines changed

sgm/models/diffusion.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from contextlib import contextmanager
1+
from contextlib import contextmanager, nullcontext
22
from typing import Any, Dict, List, Tuple, Union
33

44
import pytorch_lightning as pl
@@ -13,6 +13,7 @@
1313
from ..util import (
1414
default,
1515
disabled_train,
16+
get_default_device_name,
1617
get_obj_from_str,
1718
instantiate_from_config,
1819
log_txt_as_img,
@@ -114,16 +115,22 @@ def get_input(self, batch):
114115
# image tensors should be scaled to -1 ... 1 and in bchw format
115116
return batch[self.input_key]
116117

118+
def _first_stage_autocast_context(self):
119+
device = get_default_device_name()
120+
if device not in ("cpu", "cuda"):
121+
return nullcontext()
122+
return torch.autocast(device, enabled=not self.disable_first_stage_autocast)
123+
117124
@torch.no_grad()
118125
def decode_first_stage(self, z):
119126
z = 1.0 / self.scale_factor * z
120-
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
127+
with self._first_stage_autocast_context():
121128
out = self.first_stage_model.decode(z)
122129
return out
123130

124131
@torch.no_grad()
125132
def encode_first_stage(self, x):
126-
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
133+
with self._first_stage_autocast_context():
127134
z = self.first_stage_model.encode(x)
128135
z = self.scale_factor * z
129136
return z

sgm/modules/diffusionmodules/openaimodel.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
timestep_embedding,
2020
zero_module,
2121
)
22-
from ...util import default, exists
22+
from ...util import default, exists, get_default_device_name
2323

2424

2525
# dummy replace
@@ -1241,6 +1241,7 @@ def __init__(self, in_channels=3, model_channels=64):
12411241
]
12421242
)
12431243

1244+
device = get_default_device_name()
12441245
model = UNetModel(
12451246
use_checkpoint=True,
12461247
image_size=64,
@@ -1255,8 +1256,8 @@ def __init__(self, in_channels=3, model_channels=64):
12551256
use_linear_in_transformer=True,
12561257
transformer_depth=1,
12571258
legacy=False,
1258-
).cuda()
1259-
x = th.randn(11, 4, 64, 64).cuda()
1260-
t = th.randint(low=0, high=10, size=(11,), device="cuda")
1259+
).to(device)
1260+
x = th.randn(11, 4, 64, 64).to(device)
1261+
t = th.randint(low=0, high=10, size=(11,), device=device)
12611262
o = model(x, t)
12621263
print("done.")

sgm/modules/diffusionmodules/sampling.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
to_neg_log_sigma,
1717
to_sigma,
1818
)
19-
from ...util import append_dims, default, instantiate_from_config
19+
from ...util import append_dims, default, instantiate_from_config, get_default_device_name
2020

2121
DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
2222

@@ -28,8 +28,10 @@ def __init__(
2828
num_steps: Union[int, None] = None,
2929
guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,
3030
verbose: bool = False,
31-
device: str = "cuda",
31+
device: Union[str, None] = None,
3232
):
33+
if device is None:
34+
device = get_default_device_name()
3335
self.num_steps = num_steps
3436
self.discretization = instantiate_from_config(discretization_config)
3537
self.guider = instantiate_from_config(

sgm/modules/encoders/modules.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
default,
3030
disabled_train,
3131
expand_dims_like,
32+
get_default_device_name,
3233
instantiate_from_config,
3334
)
3435

@@ -236,7 +237,9 @@ def forward(self, c):
236237
c = c[:, None, :]
237238
return c
238239

239-
def get_unconditional_conditioning(self, bs, device="cuda"):
240+
def get_unconditional_conditioning(self, bs, device=None):
241+
if device is None:
242+
device = get_default_device_name()
240243
uc_class = (
241244
self.n_classes - 1
242245
) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
@@ -261,9 +264,10 @@ class FrozenT5Embedder(AbstractEmbModel):
261264
"""Uses the T5 transformer encoder for text"""
262265

263266
def __init__(
264-
self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True
267+
self, version="google/t5-v1_1-xxl", device=None, max_length=77, freeze=True
265268
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
266269
super().__init__()
270+
device = device or get_default_device_name()
267271
self.tokenizer = T5Tokenizer.from_pretrained(version)
268272
self.transformer = T5EncoderModel.from_pretrained(version)
269273
self.device = device
@@ -304,9 +308,10 @@ class FrozenByT5Embedder(AbstractEmbModel):
304308
"""
305309

306310
def __init__(
307-
self, version="google/byt5-base", device="cuda", max_length=77, freeze=True
311+
self, version="google/byt5-base", device=None, max_length=77, freeze=True
308312
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
309313
super().__init__()
314+
device = device or get_default_device_name()
310315
self.tokenizer = ByT5Tokenizer.from_pretrained(version)
311316
self.transformer = T5EncoderModel.from_pretrained(version)
312317
self.device = device
@@ -348,14 +353,15 @@ class FrozenCLIPEmbedder(AbstractEmbModel):
348353
def __init__(
349354
self,
350355
version="openai/clip-vit-large-patch14",
351-
device="cuda",
356+
device=None,
352357
max_length=77,
353358
freeze=True,
354359
layer="last",
355360
layer_idx=None,
356361
always_return_pooled=False,
357362
): # clip-vit-base-patch32
358363
super().__init__()
364+
device = device or get_default_device_name()
359365
assert layer in self.LAYERS
360366
self.tokenizer = CLIPTokenizer.from_pretrained(version)
361367
self.transformer = CLIPTextModel.from_pretrained(version)
@@ -416,14 +422,15 @@ def __init__(
416422
self,
417423
arch="ViT-H-14",
418424
version="laion2b_s32b_b79k",
419-
device="cuda",
425+
device=None,
420426
max_length=77,
421427
freeze=True,
422428
layer="last",
423429
always_return_pooled=False,
424430
legacy=True,
425431
):
426432
super().__init__()
433+
device = device or get_default_device_name()
427434
assert layer in self.LAYERS
428435
model, _, _ = open_clip.create_model_and_transforms(
429436
arch,
@@ -518,12 +525,13 @@ def __init__(
518525
self,
519526
arch="ViT-H-14",
520527
version="laion2b_s32b_b79k",
521-
device="cuda",
528+
device=None,
522529
max_length=77,
523530
freeze=True,
524531
layer="last",
525532
):
526533
super().__init__()
534+
device = device or get_default_device_name()
527535
assert layer in self.LAYERS
528536
model, _, _ = open_clip.create_model_and_transforms(
529537
arch, device=torch.device("cpu"), pretrained=version
@@ -588,7 +596,7 @@ def __init__(
588596
self,
589597
arch="ViT-H-14",
590598
version="laion2b_s32b_b79k",
591-
device="cuda",
599+
device=None,
592600
max_length=77,
593601
freeze=True,
594602
antialias=True,
@@ -599,6 +607,7 @@ def __init__(
599607
output_tokens=False,
600608
):
601609
super().__init__()
610+
device = device or get_default_device_name()
602611
model, _, _ = open_clip.create_model_and_transforms(
603612
arch,
604613
device=torch.device("cpu"),
@@ -744,11 +753,12 @@ def __init__(
744753
self,
745754
clip_version="openai/clip-vit-large-patch14",
746755
t5_version="google/t5-v1_1-xl",
747-
device="cuda",
756+
device=None,
748757
clip_max_length=77,
749758
t5_max_length=77,
750759
):
751760
super().__init__()
761+
device = device or get_default_device_name()
752762
self.clip_encoder = FrozenCLIPEmbedder(
753763
clip_version, device, max_length=clip_max_length
754764
)

sgm/util.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
from safetensors.torch import load_file as load_safetensors
1212

1313

14+
def get_default_device_name() -> str:
15+
return os.environ.get("SGM_DEFAULT_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
16+
17+
1418
def disabled_train(self, mode=True):
1519
"""Overwrite model.train with this function to make sure train/eval mode
1620
does not change anymore."""

0 commit comments

Comments
 (0)