Skip to content

Commit cd179a6

Browse files
zledas authored and anton-l committed
Stable Diffusion image-to-image and inpaint using onnx. (huggingface#552)
* * Stable Diffusion img2img using onnx. * * Stable Diffusion inpaint using onnx. * Export vae_encoder, upgrade img2img, add test * updated inpainting pipeline + test * style Co-authored-by: anton-l <[email protected]>
1 parent 1eddce8 commit cd179a6

File tree

8 files changed

+837
-2
lines changed

8 files changed

+837
-2
lines changed

scripts/convert_stable_diffusion_checkpoint_to_onnx.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def convert_models(model_path: str, output_path: str, opset: int):
9393
},
9494
opset=opset,
9595
)
96+
del pipeline.text_encoder
9697

9798
# UNET
9899
unet_path = output_path / "unet" / "model.onnx"
@@ -125,6 +126,7 @@ def convert_models(model_path: str, output_path: str, opset: int):
125126
location="weights.pb",
126127
convert_attribute=False,
127128
)
129+
del pipeline.unet
128130

129131
# VAE ENCODER
130132
vae_encoder = pipeline.vae
@@ -157,6 +159,7 @@ def convert_models(model_path: str, output_path: str, opset: int):
157159
},
158160
opset=opset,
159161
)
162+
del pipeline.vae
160163

161164
# SAFETY CHECKER
162165
safety_checker = pipeline.safety_checker
@@ -173,8 +176,10 @@ def convert_models(model_path: str, output_path: str, opset: int):
173176
},
174177
opset=opset,
175178
)
179+
del pipeline.safety_checker
176180

177181
onnx_pipeline = StableDiffusionOnnxPipeline(
182+
vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
178183
vae_decoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_decoder"),
179184
text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
180185
tokenizer=pipeline.tokenizer,
@@ -187,6 +192,8 @@ def convert_models(model_path: str, output_path: str, opset: int):
187192
onnx_pipeline.save_pretrained(output_path)
188193
print("ONNX pipeline saved to", output_path)
189194

195+
del pipeline
196+
del onnx_pipeline
190197
_ = StableDiffusionOnnxPipeline.from_pretrained(output_path, provider="CPUExecutionProvider")
191198
print("ONNX pipeline is loadable")
192199

src/diffusers/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,12 @@
5858
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
5959

6060
if is_torch_available() and is_transformers_available() and is_onnx_available():
61-
from .pipelines import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
61+
from .pipelines import (
62+
OnnxStableDiffusionImg2ImgPipeline,
63+
OnnxStableDiffusionInpaintPipeline,
64+
OnnxStableDiffusionPipeline,
65+
StableDiffusionOnnxPipeline,
66+
)
6267
else:
6368
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
6469

src/diffusers/pipelines/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@
2020
)
2121

2222
if is_transformers_available() and is_onnx_available():
23-
from .stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
23+
from .stable_diffusion import (
24+
OnnxStableDiffusionImg2ImgPipeline,
25+
OnnxStableDiffusionInpaintPipeline,
26+
OnnxStableDiffusionPipeline,
27+
StableDiffusionOnnxPipeline,
28+
)
2429

2530
if is_transformers_available() and is_flax_available():
2631
from .stable_diffusion import FlaxStableDiffusionPipeline

src/diffusers/pipelines/stable_diffusion/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ class StableDiffusionPipelineOutput(BaseOutput):
3535

3636
if is_transformers_available() and is_onnx_available():
3737
from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
38+
from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline
39+
from .pipeline_onnx_stable_diffusion_inpaint import OnnxStableDiffusionInpaintPipeline
3840

3941
if is_transformers_available() and is_flax_available():
4042
import flax

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
2626

2727
def __init__(
2828
self,
29+
vae_encoder: OnnxRuntimeModel,
2930
vae_decoder: OnnxRuntimeModel,
3031
text_encoder: OnnxRuntimeModel,
3132
tokenizer: CLIPTokenizer,
@@ -36,6 +37,7 @@ def __init__(
3637
):
3738
super().__init__()
3839
self.register_modules(
40+
vae_encoder=vae_encoder,
3941
vae_decoder=vae_decoder,
4042
text_encoder=text_encoder,
4143
tokenizer=tokenizer,

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py

Lines changed: 361 additions & 0 deletions
Large diffs are not rendered by default.

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py

Lines changed: 385 additions & 0 deletions
Large diffs are not rendered by default.

tests/test_pipelines.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
LDMPipeline,
3838
LDMTextToImagePipeline,
3939
LMSDiscreteScheduler,
40+
OnnxStableDiffusionImg2ImgPipeline,
41+
OnnxStableDiffusionInpaintPipeline,
4042
OnnxStableDiffusionPipeline,
4143
PNDMPipeline,
4244
PNDMScheduler,
@@ -2025,6 +2027,72 @@ def test_stable_diffusion_onnx(self):
20252027
expected_slice = np.array([0.3602, 0.3688, 0.3652, 0.3895, 0.3782, 0.3747, 0.3927, 0.4241, 0.4327])
20262028
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
20272029

2030+
@slow
2031+
def test_stable_diffusion_img2img_onnx(self):
2032+
init_image = load_image(
2033+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
2034+
"/img2img/sketch-mountains-input.jpg"
2035+
)
2036+
init_image = init_image.resize((768, 512))
2037+
2038+
pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(
2039+
"CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
2040+
)
2041+
pipe.set_progress_bar_config(disable=None)
2042+
2043+
prompt = "A fantasy landscape, trending on artstation"
2044+
2045+
np.random.seed(0)
2046+
output = pipe(
2047+
prompt=prompt,
2048+
init_image=init_image,
2049+
strength=0.75,
2050+
guidance_scale=7.5,
2051+
num_inference_steps=8,
2052+
output_type="np",
2053+
)
2054+
images = output.images
2055+
image_slice = images[0, 255:258, 383:386, -1]
2056+
2057+
assert images.shape == (1, 512, 768, 3)
2058+
expected_slice = np.array([[0.4806, 0.5125, 0.5453, 0.4846, 0.4984, 0.4955, 0.4830, 0.4962, 0.4969]])
2059+
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
2060+
2061+
@slow
2062+
def test_stable_diffusion_inpaint_onnx(self):
2063+
init_image = load_image(
2064+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
2065+
"/in_paint/overture-creations-5sI6fQgYIuo.png"
2066+
)
2067+
mask_image = load_image(
2068+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
2069+
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
2070+
)
2071+
2072+
pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
2073+
"CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
2074+
)
2075+
pipe.set_progress_bar_config(disable=None)
2076+
2077+
prompt = "A red cat sitting on a park bench"
2078+
2079+
np.random.seed(0)
2080+
output = pipe(
2081+
prompt=prompt,
2082+
init_image=init_image,
2083+
mask_image=mask_image,
2084+
strength=0.75,
2085+
guidance_scale=7.5,
2086+
num_inference_steps=8,
2087+
output_type="np",
2088+
)
2089+
images = output.images
2090+
image_slice = images[0, 255:258, 255:258, -1]
2091+
2092+
assert images.shape == (1, 512, 512, 3)
2093+
expected_slice = np.array([0.3524, 0.3289, 0.3464, 0.3872, 0.4129, 0.3566, 0.3709, 0.4128, 0.3734])
2094+
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
2095+
20282096
@slow
20292097
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
20302098
def test_stable_diffusion_text2img_intermediate_state(self):

0 commit comments

Comments (0)