File tree Expand file tree Collapse file tree 2 files changed +12
-8
lines changed
pipelines/stable_diffusion Expand file tree Collapse file tree 2 files changed +12
-8
lines changed Original file line number Diff line number Diff line change @@ -52,7 +52,6 @@ def __init__(
5252 dtype : jnp .dtype = jnp .float32 ,
5353 ):
5454 super ().__init__ ()
55- scheduler = scheduler .set_format ("np" )
5655 self .dtype = dtype
5756
5857 self .register_modules (
Original file line number Diff line number Diff line change 77from pathlib import Path
88from typing import Union
99
10- import torch
11-
1210import PIL .Image
1311import PIL .ImageOps
1412import requests
1513from packaging import version
1614
17- from .import_utils import is_flax_available
15+ from .import_utils import is_flax_available , is_torch_available
1816
1917
2018global_rng = random .Random ()
21- torch_device = "cuda" if torch .cuda .is_available () else "cpu"
22- is_torch_higher_equal_than_1_12 = version .parse (version .parse (torch .__version__ ).base_version ) >= version .parse ("1.12" )
2319
24- if is_torch_higher_equal_than_1_12 :
25- torch_device = "mps" if torch .backends .mps .is_available () else torch_device
20+
21+ if is_torch_available ():
22+ import torch
23+
24+ torch_device = "cuda" if torch .cuda .is_available () else "cpu"
25+ is_torch_higher_equal_than_1_12 = version .parse (version .parse (torch .__version__ ).base_version ) >= version .parse (
26+ "1.12"
27+ )
28+
29+ if is_torch_higher_equal_than_1_12 :
30+ torch_device = "mps" if torch .backends .mps .is_available () else torch_device
2631
2732
2833def get_tests_dir (append_path = None ):
You can’t perform that action at this time.
0 commit comments