Commit 1ae15fa: [Enhance] Update reference (#3723)

* update reference pipeline
* update reference pipeline

Co-authored-by: Patrick von Platen <[email protected]>

1 parent: 027a365

File tree: 2 files changed (+37, -10 lines)


examples/community/stable_diffusion_controlnet_reference.py

Lines changed: 22 additions & 10 deletions
@@ -1,6 +1,7 @@
 # Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236 and https://github.com/Mikubill/sd-webui-controlnet/discussions/1280
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
+import numpy as np
 import PIL.Image
 import torch
 
@@ -97,7 +98,14 @@ def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
+        image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
+        ] = None,
         ref_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
@@ -130,8 +138,8 @@ def __call__(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
-                `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+                `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                 The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
                 the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
                 also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
@@ -223,15 +231,12 @@ def __call__(
             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content, according to the `safety_checker`.
         """
-        # 0. Default height and width to unet
-        height, width = self._default_height_width(height, width, image)
+        assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True."
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
             image,
-            height,
-            width,
             callback_steps,
             negative_prompt,
             prompt_embeds,
@@ -266,6 +271,9 @@ def __call__(
         guess_mode = guess_mode or global_pool_conditions
 
         # 3. Encode input prompt
+        text_encoder_lora_scale = (
+            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+        )
         prompt_embeds = self._encode_prompt(
             prompt,
             device,
@@ -274,6 +282,7 @@ def __call__(
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
         )
 
         # 4. Prepare image
@@ -289,6 +298,7 @@ def __call__(
                 do_classifier_free_guidance=do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
+            height, width = image.shape[-2:]
         elif isinstance(controlnet, MultiControlNetModel):
             images = []
 
@@ -308,6 +318,7 @@ def __call__(
                 images.append(image_)
 
             image = images
+            height, width = image[0].shape[-2:]
         else:
             assert False
 
@@ -720,14 +731,15 @@ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=
                 # controlnet(s) inference
                 if guess_mode and do_classifier_free_guidance:
                     # Infer ControlNet only for the conditional batch.
-                    controlnet_latent_model_input = latents
+                    control_model_input = latents
+                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
                     controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                 else:
-                    controlnet_latent_model_input = latent_model_input
+                    control_model_input = latent_model_input
                     controlnet_prompt_embeds = prompt_embeds
 
                 down_block_res_samples, mid_block_res_sample = self.controlnet(
-                    controlnet_latent_model_input,
+                    control_model_input,
                     t,
                     encoder_hidden_states=controlnet_prompt_embeds,
                     controlnet_cond=image,
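
For context, a minimal usage sketch of the updated pipeline follows. It is not part of the commit: the checkpoint names, the placeholder edge map, and the gray reference image are illustrative assumptions; it exercises the new `np.ndarray` support for `image` and the new height/width defaulting shown above.

```python
# Hypothetical usage sketch (not part of this commit).
import numpy as np
import PIL.Image
import torch

from diffusers import ControlNetModel, DiffusionPipeline

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    custom_pipeline="stable_diffusion_controlnet_reference",
    torch_dtype=torch.float16,
).to("cuda")

# After this commit, `image` may also be a float array in [0, 1] of shape
# (H, W, 3), and height/width default to the prepared control image's size.
edge_map = np.zeros((512, 512, 3), dtype=np.float32)  # placeholder Canny edges
ref_image = PIL.Image.new("RGB", (512, 512), "gray")  # placeholder reference

result = pipe(
    prompt="a photo of a cat",
    image=edge_map,
    ref_image=ref_image,
    num_inference_steps=20,
    reference_attn=True,    # at least one of reference_attn / reference_adain
    reference_adain=False,  # must be True, per the new assert
).images[0]
```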

examples/community/stable_diffusion_reference.py

Lines changed: 15 additions & 0 deletions
@@ -9,6 +9,7 @@
 from diffusers.models.attention import BasicTransformerBlock
 from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
 from diffusers.utils import PIL_INTERPOLATION, logging, randn_tensor
 
 
@@ -179,6 +180,7 @@ def __call__(
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guidance_rescale: float = 0.0,
         attention_auto_machine_weight: float = 1.0,
         gn_auto_machine_weight: float = 1.0,
         style_fidelity: float = 0.5,
@@ -248,6 +250,11 @@ def __call__(
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
                 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+            guidance_rescale (`float`, *optional*, defaults to 0.7):
+                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+                Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
+                [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+                Guidance rescale factor should fix overexposure when using zero terminal SNR.
             attention_auto_machine_weight (`float`):
                 Weight of using reference query for self attention's context.
                 If attention_auto_machine_weight=1.0, use reference query for all self attention's context.
@@ -295,6 +302,9 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0
 
         # 3. Encode input prompt
+        text_encoder_lora_scale = (
+            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+        )
         prompt_embeds = self._encode_prompt(
             prompt,
             device,
@@ -303,6 +313,7 @@ def __call__(
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
         )
 
         # 4. Preprocess reference image
@@ -748,6 +759,10 @@ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
+                if do_classifier_free_guidance and guidance_rescale > 0.0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
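
For reference, the `rescale_noise_cfg` helper imported above is roughly the following; this is a sketch of its behavior per Sec. 3.4 of the paper, and `pipeline_stable_diffusion.py` in diffusers remains the authoritative implementation. The CFG output is rescaled so its per-sample standard deviation matches that of the text-conditioned prediction, then blended back in by `guidance_rescale`.

```python
import torch

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # Sketch of the imported helper, per Sec. 3.4 of
    # https://arxiv.org/pdf/2305.08891.pdf.
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # Rescale the guidance result to the std of the text-conditioned
    # prediction (fixes overexposure with zero-terminal-SNR schedules).
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # Blend with the original CFG result to avoid overly plain images.
    return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
```

With this change a caller can opt in via, for example, `pipe(prompt, ref_image=ref_image, guidance_rescale=0.7)`; the default of 0.0 leaves the pipeline's behavior unchanged.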
