Commit 12358b9

MC-E and sayakpaul authored
add models for T2I-Adapter-XL (#4696)
* T2I-Adapter-XL
* update
* update
* add pipeline
* modify pipeline
* modify pipeline
* modify pipeline
* modify pipeline
* modify pipeline
* modify modeling_text_unet
* fix styling.
* fix: copies.
* adapter settings
* new test case
* new test case
* debugging
* debugging
* debugging
* debugging
* debugging
* debugging
* debugging
* debugging
* revert prints.
* new test case
* remove print
* org test case
* add test_pipeline
* styling.
* fix copies.
* modify test parameter
* style.
* add adapter-xl doc
* double quotes in docs
* Fix potential type mismatch
* style.

---------

Co-authored-by: sayakpaul <[email protected]>
1 parent 5eeedd9 · commit 12358b9

10 files changed: +1252 -3 lines changed

docs/source/en/api/pipelines/stable_diffusion/adapter.md

Lines changed: 73 additions & 2 deletions
@@ -29,10 +29,11 @@ This model was contributed by the community contributor [HimariO](https://github
 | Pipeline | Tasks | Demo
 |---|---|:---:|
 | [StableDiffusionAdapterPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py) | *Text-to-Image Generation with T2I-Adapter Conditioning* | -
+| [StableDiffusionXLAdapterPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_xl_adapter.py) | *Text-to-Image Generation with T2I-Adapter Conditioning on StableDiffusion-XL* | -
 
-## Usage example
+## Usage example with the base model of StableDiffusion-1.4/1.5
 
-In the following we give a simple example of how to use a *T2IAdapter* checkpoint with Diffusers for inference.
+In the following we give a simple example of how to use a *T2IAdapter* checkpoint with Diffusers for inference based on StableDiffusion-1.4/1.5.
 All adapters use the same pipeline.
 
 1. Images are first converted into the appropriate *control image* format.
@@ -93,6 +94,62 @@ out_image = pipe(
 
 ![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_output.png)
 
+## Usage example with the base model of StableDiffusion-XL
+
+In the following we give a simple example of how to use a *T2IAdapter* checkpoint with Diffusers for inference based on StableDiffusion-XL.
+All adapters use the same pipeline.
+
+1. Images are first downloaded and converted into the appropriate *control image* format.
+2. The *control image* and *prompt* are passed to the [`StableDiffusionXLAdapterPipeline`].
+
+Let's have a look at a simple example using the [Sketch Adapter](https://huggingface.co/Adapter/t2iadapter/tree/main/sketch_sdxl_1.0).
+
+```python
+from diffusers.utils import load_image
+
+sketch_image = load_image("https://huggingface.co/Adapter/t2iadapter/resolve/main/sketch.png").convert("L")
+```
+
+![img](https://huggingface.co/Adapter/t2iadapter/resolve/main/sketch.png)
+
+Then, create the adapter pipeline:
+
+```py
+import torch
+from diffusers import (
+    T2IAdapter,
+    StableDiffusionXLAdapterPipeline,
+    DDPMScheduler,
+)
+
+model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+adapter = T2IAdapter.from_pretrained(
+    "Adapter/t2iadapter", subfolder="sketch_sdxl_1.0", torch_dtype=torch.float16, adapter_type="full_adapter_xl"
+)
+scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
+
+pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
+    model_id, adapter=adapter, safety_checker=None, torch_dtype=torch.float16, variant="fp16", scheduler=scheduler
+)
+
+pipe.to("cuda")
+```
+
+Finally, pass the prompt and control image to the pipeline:
+
+```py
+# Fix the random seed so you get the same result as the example.
+generator = torch.Generator().manual_seed(42)
+
+sketch_image_out = pipe(
+    prompt="a photo of a dog in real world, high quality",
+    negative_prompt="extra digit, fewer digits, cropped, worst quality, low quality",
+    image=sketch_image,
+    generator=generator,
+    guidance_scale=7.5,
+).images[0]
+```
+
+![img](https://huggingface.co/Adapter/t2iadapter/resolve/main/sketch_output.png)
 
 ## Available checkpoints
 
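Editor's note: the strength of the adapter guidance can typically be tuned at call time as well. A minimal sketch, assuming `StableDiffusionXLAdapterPipeline` exposes the same `adapter_conditioning_scale` argument as `StableDiffusionAdapterPipeline` — that parameter is not shown in this diff, so treat it as an assumption:

```py
# Continuing the example above. `adapter_conditioning_scale` is an assumed
# parameter (mirroring StableDiffusionAdapterPipeline): 1.0 applies the full
# adapter signal, lower values weaken the sketch guidance.
softer_image = pipe(
    prompt="a photo of a dog in real world, high quality",
    negative_prompt="extra digit, fewer digits, cropped, worst quality, low quality",
    image=sketch_image,
    generator=generator,
    adapter_conditioning_scale=0.8,
).images[0]
```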
@@ -113,6 +170,9 @@ Non-diffusers checkpoints can be found under [TencentARC/T2I-Adapter](https://hu
 |[TencentARC/t2iadapter_depth_sd15v2](https://huggingface.co/TencentARC/t2iadapter_depth_sd15v2)||
 |[TencentARC/t2iadapter_sketch_sd15v2](https://huggingface.co/TencentARC/t2iadapter_sketch_sd15v2)||
 |[TencentARC/t2iadapter_zoedepth_sd15v1](https://huggingface.co/TencentARC/t2iadapter_zoedepth_sd15v1)||
+|[Adapter/t2iadapter, subfolder='sketch_sdxl_1.0'](https://huggingface.co/Adapter/t2iadapter/tree/main/sketch_sdxl_1.0)||
+|[Adapter/t2iadapter, subfolder='canny_sdxl_1.0'](https://huggingface.co/Adapter/t2iadapter/tree/main/canny_sdxl_1.0)||
+|[Adapter/t2iadapter, subfolder='openpose_sdxl_1.0'](https://huggingface.co/Adapter/t2iadapter/tree/main/openpose_sdxl_1.0)||
 
 ## Combining multiple adapters
 
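Editor's note: the three SDXL rows added above all point at subfolders of the single `Adapter/t2iadapter` repository, so each checkpoint is loaded with the `subfolder` argument. A minimal sketch for the canny variant, reusing the call pattern from the usage example (the `adapter_type` value comes from this commit's `adapter.py` change):

```py
import torch
from diffusers import T2IAdapter

# Same loading pattern as the sketch example, pointed at the canny subfolder.
canny_adapter = T2IAdapter.from_pretrained(
    "Adapter/t2iadapter",
    subfolder="canny_sdxl_1.0",
    torch_dtype=torch.float16,
    adapter_type="full_adapter_xl",
)
```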
@@ -185,3 +245,14 @@ However, T2I-Adapter performs slightly worse than ControlNet.
 - disable_vae_slicing
 - enable_xformers_memory_efficient_attention
 - disable_xformers_memory_efficient_attention
+
+## StableDiffusionXLAdapterPipeline
+[[autodoc]] StableDiffusionXLAdapterPipeline
+- all
+- __call__
+- enable_attention_slicing
+- disable_attention_slicing
+- enable_vae_slicing
+- disable_vae_slicing
+- enable_xformers_memory_efficient_attention
+- disable_xformers_memory_efficient_attention
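Editor's note: the memory helpers listed in the autodoc block work as on the other Stable Diffusion pipelines; a brief sketch, assuming the `xformers` package is installed:

```py
# Optional memory savings on StableDiffusionXLAdapterPipeline; these are the
# methods listed in the autodoc block above.
pipe.enable_xformers_memory_efficient_attention()  # needs xformers installed
pipe.enable_vae_slicing()

# ... run inference ...

pipe.disable_vae_slicing()
pipe.disable_xformers_memory_efficient_attention()
```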

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -191,6 +191,7 @@
     StableDiffusionPix2PixZeroPipeline,
     StableDiffusionSAGPipeline,
     StableDiffusionUpscalePipeline,
+    StableDiffusionXLAdapterPipeline,
     StableDiffusionXLControlNetImg2ImgPipeline,
     StableDiffusionXLControlNetPipeline,
     StableDiffusionXLImg2ImgPipeline,

src/diffusers/models/adapter.py

Lines changed: 44 additions & 0 deletions
@@ -128,6 +128,8 @@ def __init__(
 
         if adapter_type == "full_adapter":
             self.adapter = FullAdapter(in_channels, channels, num_res_blocks, downscale_factor)
+        elif adapter_type == "full_adapter_xl":
+            self.adapter = FullAdapterXL(in_channels, channels, num_res_blocks, downscale_factor)
         elif adapter_type == "light_adapter":
             self.adapter = LightAdapter(in_channels, channels, num_res_blocks, downscale_factor)
         else:
@@ -184,6 +186,48 @@ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
         return features
 
 
+class FullAdapterXL(nn.Module):
+    def __init__(
+        self,
+        in_channels: int = 3,
+        channels: List[int] = [320, 640, 1280, 1280],
+        num_res_blocks: int = 2,
+        downscale_factor: int = 16,
+    ):
+        super().__init__()
+
+        in_channels = in_channels * downscale_factor**2
+
+        self.unshuffle = nn.PixelUnshuffle(downscale_factor)
+        self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)
+
+        self.body = []
+        # blocks to extract XL features with dimensions of [320, 64, 64], [640, 64, 64], [1280, 32, 32], [1280, 32, 32]
+        for i in range(len(channels)):
+            if i == 1:
+                self.body.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks))
+            elif i == 2:
+                self.body.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True))
+            else:
+                self.body.append(AdapterBlock(channels[i], channels[i], num_res_blocks))
+
+        self.body = nn.ModuleList(self.body)
+        # XL has one fewer downsampling
+        self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 2)
+
+    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+        x = self.unshuffle(x)
+        x = self.conv_in(x)
+
+        features = []
+
+        for block in self.body:
+            x = block(x)
+            features.append(x)
+
+        return features
+
+
 class AdapterBlock(nn.Module):
     def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
         super().__init__()
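Editor's note: the in-code comment pins down the feature pyramid `FullAdapterXL` should emit. A quick sanity-check sketch — the 1024x1024 input size is an assumption matching SDXL's default resolution; the class is importable from the module this commit edits:

```py
import torch
from diffusers.models.adapter import FullAdapterXL  # defined in this commit

adapter = FullAdapterXL()  # defaults: channels=[320, 640, 1280, 1280], downscale_factor=16
control = torch.randn(1, 3, 1024, 1024)  # assumed SDXL-sized control input

features = adapter(control)
for feat in features:
    print(tuple(feat.shape))
# Per the comment in the class body, this should print:
# (1, 320, 64, 64), (1, 640, 64, 64), (1, 1280, 32, 32), (1, 1280, 32, 32)
```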

src/diffusers/models/unet_2d_condition.py

Lines changed: 7 additions & 0 deletions
@@ -965,6 +965,13 @@ def forward(
                 cross_attention_kwargs=cross_attention_kwargs,
                 encoder_attention_mask=encoder_attention_mask,
             )
+            # To support T2I-Adapter-XL
+            if (
+                is_adapter
+                and len(down_block_additional_residuals) > 0
+                and sample.shape == down_block_additional_residuals[0].shape
+            ):
+                sample += down_block_additional_residuals.pop(0)
 
         if is_controlnet:
             sample = sample + mid_block_additional_residual
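Editor's note: the guard added here only consumes a leftover adapter residual when its shape matches the mid-block output; the down-block residuals are consumed earlier, so for T2I-Adapter-XL the final feature map lands at this point. A toy sketch of the same pop logic with dummy tensors (not the real UNet):

```py
import torch

# Stand-ins: mid-block output plus one leftover T2I-Adapter-XL residual.
sample = torch.zeros(1, 1280, 32, 32)
down_block_additional_residuals = [torch.ones(1, 1280, 32, 32)]
is_adapter = True

# Same guard as in the diff above: only add (and consume) the residual
# when the shapes line up.
if (
    is_adapter
    and len(down_block_additional_residuals) > 0
    and sample.shape == down_block_additional_residuals[0].shape
):
    sample += down_block_additional_residuals.pop(0)

print(sample.mean().item(), len(down_block_additional_residuals))  # 1.0 0
```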

src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -118,7 +118,7 @@
     StableDiffusionXLInstructPix2PixPipeline,
     StableDiffusionXLPipeline,
 )
-from .t2i_adapter import StableDiffusionAdapterPipeline
+from .t2i_adapter import StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline
 from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline, VideoToVideoSDPipeline
 from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
 from .unidiffuser import ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder

src/diffusers/pipelines/t2i_adapter/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -12,3 +12,4 @@
     from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
 else:
     from .pipeline_stable_diffusion_adapter import StableDiffusionAdapterPipeline
+    from .pipeline_stable_diffusion_xl_adapter import StableDiffusionXLAdapterPipeline
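Editor's note: taken together, the three `__init__.py` edits expose the new pipeline at the package root; a one-line smoke test (assuming a torch + transformers environment, per the dummy-objects guard above):

```py
# Works once the exports above are in place and torch/transformers are available.
from diffusers import StableDiffusionXLAdapterPipeline

print(StableDiffusionXLAdapterPipeline.__name__)
```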
