
Commit ba352ae

[feat] IP Adapters (author @okotaku ) (#5713)
* add ip-adapter --------- Co-authored-by: okotaku <[email protected]> Co-authored-by: sayakpaul <[email protected]> Co-authored-by: yiyixuxu <yixu310@gmail,com> Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Steven Liu <[email protected]>
1 parent 6fac136 commit ba352ae

40 files changed: +1755 −63 lines changed

docs/source/en/using-diffusers/loading_adapters.md

Lines changed: 328 additions & 0 deletions
@@ -307,3 +307,331 @@ prompt = "a house by william eggleston, sunrays, beautiful, sunlight, sunrays, b
image = pipeline(prompt=prompt).images[0]
image
```

## IP-Adapter

[IP-Adapter](https://ip-adapter.github.io/) is an effective and lightweight adapter that adds image prompting capabilities to a diffusion model. It works by decoupling the cross-attention layers for the image and text features. All the other model components are frozen and only the embedded image features in the UNet are trained. As a result, IP-Adapter files are typically only ~100MB.

IP-Adapter works with most of our pipelines, including Stable Diffusion, Stable Diffusion XL (SDXL), ControlNet, T2I-Adapter, and AnimateDiff. You can also use any custom model finetuned from the same base model, and it works with LCM-LoRA out of the box.

<Tip>

You can find official IP-Adapter checkpoints in [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter).

IP-Adapter was contributed by [okotaku](https://github.com/okotaku).

</Tip>

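If you want to check which adapter weights a repository provides before loading one, you can list its files programmatically. The snippet below is a small optional sketch (not part of the original example) that assumes `huggingface_hub` is installed:

```py
from huggingface_hub import list_repo_files

# List every file in the h94/IP-Adapter repository and keep only the adapter
# weights, e.g. "models/ip-adapter_sd15.bin" or "sdxl_models/ip-adapter_sdxl.bin".
files = list_repo_files("h94/IP-Adapter")
print([f for f in files if f.endswith(".bin")])
```
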
Let's first create a Stable Diffusion pipeline.

```py
from diffusers import AutoPipelineForText2Image
import torch
from diffusers.utils import load_image


pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
```

Now load the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) weights with the [`~loaders.IPAdapterMixin.load_ip_adapter`] method.

```py
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
```

<Tip>

IP-Adapter relies on an image encoder to generate the image features. If your IP-Adapter weights folder contains an `image_encoder` subfolder, the image encoder is automatically loaded and registered to the pipeline. Otherwise, you can explicitly load a [`~transformers.CLIPVisionModelWithProjection`] model and pass it to the Stable Diffusion pipeline when you create it.

```py
from diffusers import AutoPipelineForText2Image
from transformers import CLIPVisionModelWithProjection
import torch

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter",
    subfolder="models/image_encoder",
    torch_dtype=torch.float16,
).to("cuda")

pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, torch_dtype=torch.float16).to("cuda")
```

</Tip>

IP-Adapter allows you to use both an image and a text prompt to condition the image generation process. For example, let's use the bear image from the [Textual Inversion](#textual-inversion) section as the image prompt (`ip_adapter_image`) along with a text prompt to add "sunglasses". 😎

```py
pipeline.set_ip_adapter_scale(0.6)
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png")
generator = torch.Generator(device="cpu").manual_seed(33)
images = pipeline(
    prompt="best quality, high quality, wearing sunglasses",
    ip_adapter_image=image,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=50,
    generator=generator,
).images
images[0]
```

<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip-bear.png" />
</div>

<Tip>

You can use the [`~loaders.IPAdapterMixin.set_ip_adapter_scale`] method to adjust the balance between the text prompt and image prompt conditioning. If you're only using the image prompt, you should set the scale to `1.0`. You can lower the scale for more generation diversity, but the output will be less aligned with the image prompt. `scale=0.5` achieves good results in most cases when you use both text and image prompts.

</Tip>

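For example, a minimal sketch of relying on the image prompt alone (reusing the `pipeline`, `image`, and imports from the example above) could look like this:

```py
# Use only the image prompt: raise the IP-Adapter scale to 1.0 and pass an
# empty text prompt so generation is driven by the image features.
pipeline.set_ip_adapter_scale(1.0)
image_only = pipeline(
    prompt="",
    ip_adapter_image=image,
    num_inference_steps=50,
    generator=torch.Generator(device="cpu").manual_seed(33),
).images[0]
```
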
IP-Adapter also works great with image-to-image and inpainting pipelines. See the examples below for how to use it with each.

<hfoptions id="tasks">
<hfoption id="image-to-image">

```py
from diffusers import AutoPipelineForImage2Image
import torch
from diffusers.utils import load_image

pipeline = AutoPipelineForImage2Image.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")

image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/vermeer.jpg")
ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/river.png")

pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
generator = torch.Generator(device="cpu").manual_seed(33)
images = pipeline(
    prompt="best quality, high quality",
    image=image,
    ip_adapter_image=ip_image,
    num_inference_steps=50,
    generator=generator,
    strength=0.6,
).images
images[0]
```

</hfoption>
<hfoption id="inpaint">
416+
417+
```py
418+
from diffusers import AutoPipelineForInpaint
419+
import torch
420+
from diffusers.utils import load_image
421+
422+
pipeline = AutoPipelineForInpaint.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float).to("cuda")
423+
424+
image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/inpaint_image.png")
425+
mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/mask.png")
426+
ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/girl.png")
427+
428+
image = image.resize((512, 768))
429+
mask = mask.resize((512, 768))
430+
431+
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
432+
433+
generator = torch.Generator(device="cpu").manual_seed(33)
434+
images = pipeline(
435+
prompt='best quality, high quality',
436+
image = image,
437+
mask_image = mask,
438+
ip_adapter_image=ip_image,
439+
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
440+
num_inference_steps=50,
441+
generator=generator,
442+
strength=0.5,
443+
).images
444+
images[0]
445+
```
446+
</hfoption>
447+
</hfoptions>
448+
449+
450+
IP-Adapter can also be used with [SDXL](../api/pipelines/stable_diffusion/stable_diffusion_xl.md).

```python
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
import torch

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
).to("cuda")

image = load_image("https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/watercolor_painting.jpeg")

pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

generator = torch.Generator(device="cpu").manual_seed(33)
image = pipeline(
    prompt="best quality, high quality",
    ip_adapter_image=image,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=25,
    generator=generator,
).images[0]
image.save("sdxl_t2i.png")
```

<div class="flex flex-row gap-4">
  <div class="flex-1">
    <img class="rounded-xl" src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/watercolor_painting.jpeg"/>
    <figcaption class="mt-2 text-center text-sm text-gray-500">input image</figcaption>
  </div>
  <div class="flex-1">
    <img class="rounded-xl" src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/sdxl_t2i.png"/>
    <figcaption class="mt-2 text-center text-sm text-gray-500">adapted image</figcaption>
  </div>
</div>

### LCM-LoRA

You can use IP-Adapter with LCM-LoRA to achieve an "instant fine-tune" with custom images. Note that you need to load the IP-Adapter weights before loading the LCM-LoRA weights.

```py
from diffusers import DiffusionPipeline, LCMScheduler
import torch
from diffusers.utils import load_image

model_id = "sd-dreambooth-library/herge-style"
lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"

pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)

pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipe.load_lora_weights(lcm_lora_id)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()

prompt = "best quality, high quality"
image = load_image("https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png")
images = pipe(
    prompt=prompt,
    ip_adapter_image=image,
    num_inference_steps=4,
    guidance_scale=1,
).images[0]
```

### Other pipelines

IP-Adapter is compatible with any pipeline that (1) uses a text prompt and (2) uses a Stable Diffusion or Stable Diffusion XL checkpoint. To use IP-Adapter with a different pipeline, all you need to do is run the `load_ip_adapter()` method after you create the pipeline, and then pass your image to the pipeline as `ip_adapter_image`.

<Tip>

🤗 Diffusers currently only supports using IP-Adapter with some of the most popular pipelines. Feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use case and need to integrate IP-Adapter with a pipeline that doesn't support it yet!

</Tip>

You can find examples below of how to use IP-Adapter with ControlNet and AnimateDiff.

<hfoptions id="model">
531+
<hfoption id="ControlNet">
532+
533+
```
534+
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
535+
import torch
536+
from diffusers.utils import load_image
537+
538+
controlnet_model_path = "lllyasviel/control_v11f1p_sd15_depth"
539+
controlnet = ControlNetModel.from_pretrained(controlnet_model_path, torch_dtype=torch.float16)
540+
541+
pipeline = StableDiffusionControlNetPipeline.from_pretrained(
542+
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16)
543+
pipeline.to("cuda")
544+
545+
image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/statue.png")
546+
depth_map = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/depth.png")
547+
548+
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
549+
550+
generator = torch.Generator(device="cpu").manual_seed(33)
551+
images = pipeline(
552+
prompt='best quality, high quality',
553+
image=depth_map,
554+
ip_adapter_image=image,
555+
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
556+
num_inference_steps=50,
557+
generator=generator,
558+
).images
559+
images[0]
560+
```
561+
<div class="flex flex-row gap-4">
562+
<div class="flex-1">
563+
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/statue.png"/>
564+
<figcaption class="mt-2 text-center text-sm text-gray-500">input image</figcaption>
565+
</div>
566+
<div class="flex-1">
567+
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ipa-controlnet-out.png"/>
568+
<figcaption class="mt-2 text-center text-sm text-gray-500">adapted image</figcaption>
569+
</div>
570+
</div>
571+
572+
</hfoption>
573+
<hfoption id="AnimateDiff">
574+
575+
```py
576+
# animate diff + ip adapter
577+
import torch
578+
from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
579+
from diffusers.utils import export_to_gif, load_image
580+
581+
# Load the motion adapter
582+
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
583+
# load SD 1.5 based finetuned model
584+
model_id = "Lykon/DreamShaper"
585+
pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16)
586+
587+
# scheduler
588+
scheduler = DDIMScheduler(
589+
clip_sample=False,
590+
beta_start=0.00085,
591+
beta_end=0.012,
592+
beta_schedule="linear",
593+
timestep_spacing="trailing",
594+
steps_offset=1
595+
)
596+
pipe.scheduler = scheduler
597+
598+
# enable memory savings
599+
pipe.enable_vae_slicing()
600+
pipe.enable_model_cpu_offload()
601+
602+
# load ip_adapter
603+
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
604+
605+
# load motion adapters
606+
pipe.load_lora_weights("guoyww/animatediff-motion-lora-zoom-out", adapter_name="zoom-out")
607+
pipe.load_lora_weights("guoyww/animatediff-motion-lora-tilt-up", adapter_name="tilt-up")
608+
pipe.load_lora_weights("guoyww/animatediff-motion-lora-pan-left", adapter_name="pan-left")
609+
610+
seed = 42
611+
image = load_image("https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png")
612+
images = [image] * 3
613+
prompts = ["best quality, high quality"] * 3
614+
negative_prompt = "bad quality, worst quality"
615+
adapter_weights = [[0.75, 0.0, 0.0], [0.0, 0.0, 0.75], [0.0, 0.75, 0.75]]
616+
617+
# generate
618+
output_frames = []
619+
for prompt, image, adapter_weight in zip(prompts, images, adapter_weights):
620+
pipe.set_adapters(["zoom-out", "tilt-up", "pan-left"], adapter_weights=adapter_weight)
621+
output = pipe(
622+
prompt= prompt,
623+
num_frames=16,
624+
guidance_scale=7.5,
625+
num_inference_steps=30,
626+
ip_adapter_image = image,
627+
generator=torch.Generator("cpu").manual_seed(seed),
628+
)
629+
frames = output.frames[0]
630+
output_frames.extend(frames)
631+
632+
export_to_gif(output_frames, "test_out_animation.gif")
633+
```
634+
635+
</hfoption>
636+
</hfoptions>
637+

src/diffusers/loaders/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -62,6 +62,7 @@ def text_encoder_attn_modules(text_encoder):
    _import_structure["single_file"].extend(["FromSingleFileMixin"])
    _import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin"]
    _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
+   _import_structure["ip_adapter"] = ["IPAdapterMixin"]


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -72,6 +73,7 @@ def text_encoder_attn_modules(text_encoder):
    from .utils import AttnProcsLayers

    if is_transformers_available():
+       from .ip_adapter import IPAdapterMixin
        from .lora import LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin
        from .single_file import FromSingleFileMixin
        from .textual_inversion import TextualInversionLoaderMixin
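
With these two entries in place, `IPAdapterMixin` is exposed from `diffusers.loaders` whenever transformers is installed. As a quick sanity check (a sketch, assuming a diffusers install that includes this commit), you could verify that the mixin is importable and that the Stable Diffusion pipeline exposes `load_ip_adapter`:

```py
from diffusers import StableDiffusionPipeline
from diffusers.loaders import IPAdapterMixin

# The pipeline is expected to inherit from IPAdapterMixin, which provides the
# load_ip_adapter() method used throughout the docs above.
print(issubclass(StableDiffusionPipeline, IPAdapterMixin))
print(hasattr(StableDiffusionPipeline, "load_ip_adapter"))
```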
