Skip to content

Commit f1d4289

Browse files
[Flax] Add test (#824)
1 parent 323a9e1 commit f1d4289

File tree

3 files changed

+74
-8
lines changed

3 files changed

+74
-8
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ def __init__(
5252
dtype: jnp.dtype = jnp.float32,
5353
):
5454
super().__init__()
55-
scheduler = scheduler.set_format("np")
5655
self.dtype = dtype
5756

5857
self.register_modules(

src/diffusers/utils/testing_utils.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,27 @@
77
from pathlib import Path
88
from typing import Union
99

10-
import torch
11-
1210
import PIL.Image
1311
import PIL.ImageOps
1412
import requests
1513
from packaging import version
1614

17-
from .import_utils import is_flax_available
15+
from .import_utils import is_flax_available, is_torch_available
1816

1917

2018
global_rng = random.Random()
21-
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
22-
is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12")
2319

24-
if is_torch_higher_equal_than_1_12:
25-
torch_device = "mps" if torch.backends.mps.is_available() else torch_device
20+
21+
if is_torch_available():
22+
import torch
23+
24+
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
25+
is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse(
26+
"1.12"
27+
)
28+
29+
if is_torch_higher_equal_than_1_12:
30+
torch_device = "mps" if torch.backends.mps.is_available() else torch_device
2631

2732

2833
def get_tests_dir(append_path=None):

tests/test_pipelines_flax.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# coding=utf-8
2+
# Copyright 2022 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import unittest
17+
18+
import numpy as np
19+
20+
from diffusers.utils import is_flax_available
21+
from diffusers.utils.testing_utils import require_flax, slow
22+
23+
24+
if is_flax_available():
25+
import jax
26+
from diffusers import FlaxStableDiffusionPipeline
27+
from flax.jax_utils import replicate
28+
from flax.training.common_utils import shard
29+
from jax import pmap
30+
31+
32+
@require_flax
33+
@slow
34+
class FlaxPipelineTests(unittest.TestCase):
35+
def test_dummy_all_tpus(self):
36+
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
37+
"hf-internal-testing/tiny-stable-diffusion-pipe"
38+
)
39+
40+
prompt = (
41+
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
42+
" field, close up, split lighting, cinematic"
43+
)
44+
45+
prng_seed = jax.random.PRNGKey(0)
46+
num_inference_steps = 4
47+
48+
num_samples = jax.device_count()
49+
prompt = num_samples * [prompt]
50+
prompt_ids = pipeline.prepare_inputs(prompt)
51+
52+
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
53+
54+
# shard inputs and rng
55+
params = replicate(params)
56+
prng_seed = jax.random.split(prng_seed, 8)
57+
prompt_ids = shard(prompt_ids)
58+
59+
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
60+
images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
61+
62+
assert len(images_pil) == 8

0 commit comments

Comments
 (0)