Skip to content

Commit ad8c9a6

Browse files
[Flax] Add test (huggingface#824)
1 parent 696fb7d commit ad8c9a6

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ def __init__(
5252
dtype: jnp.dtype = jnp.float32,
5353
):
5454
super().__init__()
55-
scheduler = scheduler.set_format("np")
5655
self.dtype = dtype
5756

5857
self.register_modules(

utils/testing_utils.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,27 @@
77
from pathlib import Path
88
from typing import Union
99

10-
import torch
11-
1210
import PIL.Image
1311
import PIL.ImageOps
1412
import requests
1513
from packaging import version
1614

17-
from .import_utils import is_flax_available
15+
from .import_utils import is_flax_available, is_torch_available
1816

1917

2018
global_rng = random.Random()
21-
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
22-
is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12")
2319

24-
if is_torch_higher_equal_than_1_12:
25-
torch_device = "mps" if torch.backends.mps.is_available() else torch_device
20+
21+
if is_torch_available():
22+
import torch
23+
24+
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
25+
is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse(
26+
"1.12"
27+
)
28+
29+
if is_torch_higher_equal_than_1_12:
30+
torch_device = "mps" if torch.backends.mps.is_available() else torch_device
2631

2732

2833
def get_tests_dir(append_path=None):

0 commit comments

Comments
 (0)