32 changes: 22 additions & 10 deletions examples/community/stable_diffusion_controlnet_reference.py
@@ -1,6 +1,7 @@
# Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236 and https://github.com/Mikubill/sd-webui-controlnet/discussions/1280
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch

@@ -97,7 +98,14 @@ def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
ref_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
height: Optional[int] = None,
width: Optional[int] = None,
@@ -130,8 +138,8 @@ def __call__(
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
`List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition. ControlNet uses this input condition to generate guidance for the UNet. If
the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. A `PIL.Image.Image` can
also be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If
@@ -223,15 +231,12 @@ def __call__(
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height, width = self._default_height_width(height, width, image)
assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True."

# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
image,
height,
width,
callback_steps,
negative_prompt,
prompt_embeds,
@@ -266,6 +271,9 @@ def __call__(
guess_mode = guess_mode or global_pool_conditions

# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds = self._encode_prompt(
prompt,
device,
@@ -274,6 +282,7 @@
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
)

# 4. Prepare image
@@ -289,6 +298,7 @@
do_classifier_free_guidance=do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = image.shape[-2:]
elif isinstance(controlnet, MultiControlNetModel):
images = []

@@ -308,6 +318,7 @@
images.append(image_)

image = images
height, width = image[0].shape[-2:]
else:
assert False

@@ -720,14 +731,15 @@ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=
# controlnet(s) inference
if guess_mode and do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch.
controlnet_latent_model_input = latents
control_model_input = latents
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
else:
controlnet_latent_model_input = latent_model_input
control_model_input = latent_model_input
controlnet_prompt_embeds = prompt_embeds

down_block_res_samples, mid_block_res_sample = self.controlnet(
controlnet_latent_model_input,
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=image,
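Taken together, the changes to this file extend the accepted conditioning-image types with `np.ndarray`, infer `height`/`width` from the prepared image rather than a separate helper, and forward a LoRA scale from `cross_attention_kwargs` to the text encoder. A minimal, untested usage sketch; the model IDs and the `custom_pipeline` loading path are assumptions based on the usual diffusers community-pipeline convention:

import numpy as np
import torch
import PIL.Image
from diffusers import ControlNetModel, DiffusionPipeline

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    custom_pipeline="stable_diffusion_controlnet_reference",
    torch_dtype=torch.float16,
).to("cuda")

# The conditioning image may now be a NumPy array of shape (H, W, 3) in
# [0, 1], not only a PIL.Image.Image or torch.FloatTensor; height/width
# are then taken from the prepared tensor's last two dimensions.
canny_edges = np.zeros((512, 512, 3), dtype=np.float32)
ref_image = PIL.Image.new("RGB", (512, 512), color=(127, 127, 127))

result = pipe(
    prompt="a photo of a cat",
    image=canny_edges,
    ref_image=ref_image,
    num_inference_steps=20,
    reference_attn=True,
    reference_adain=True,
    # With LoRA weights loaded (e.g. via pipe.load_lora_weights(...)),
    # passing cross_attention_kwargs={"scale": 0.8} is now also forwarded
    # to the text encoder as lora_scale through _encode_prompt.
).images[0]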
15 changes: 15 additions & 0 deletions examples/community/stable_diffusion_reference.py
@@ -9,6 +9,7 @@
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
from diffusers.utils import PIL_INTERPOLATION, logging, randn_tensor


@@ -179,6 +180,7 @@ def __call__(
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
attention_auto_machine_weight: float = 1.0,
gn_auto_machine_weight: float = 1.0,
style_fidelity: float = 0.5,
@@ -248,6 +250,11 @@ def __call__(
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
the paper. Guidance rescale should fix overexposure when using zero terminal SNR.
attention_auto_machine_weight (`float`):
Weight for using the reference query in the self-attention context.
If attention_auto_machine_weight=1.0, the reference query is used in all self-attention contexts.
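For reference, the helper imported at the top of the file implements the rescale step from section 3.4 of the paper. A paraphrased sketch of what `rescale_noise_cfg` computes (the actual implementation lives in `diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion`):

import torch

def rescale_noise_cfg_sketch(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # Per-sample standard deviation over all non-batch dimensions.
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # Rescale the CFG output toward the conditional branch's statistics
    # to counteract overexposure from large guidance scales.
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # Interpolate between rescaled and original results by guidance_rescale.
    return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg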
@@ -295,6 +302,9 @@ def __call__(
do_classifier_free_guidance = guidance_scale > 1.0

# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds = self._encode_prompt(
prompt,
device,
@@ -303,6 +313,7 @@
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
)

# 4. Preprocess reference image
@@ -748,6 +759,10 @@ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

if do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

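With the import and plumbing above, the community reference pipeline accepts the same `guidance_rescale` knob as the main Stable Diffusion pipeline. A minimal, untested usage sketch; the model ID is an assumption:

import torch
import PIL.Image
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="stable_diffusion_reference",
    torch_dtype=torch.float16,
).to("cuda")

ref_image = PIL.Image.new("RGB", (512, 512), color=(180, 160, 140))

image = pipe(
    prompt="a portrait, same style as the reference",
    ref_image=ref_image,
    num_inference_steps=20,
    guidance_scale=7.5,
    # 0.0 (the default) disables rescaling; the paper suggests values
    # around 0.7 when sampling from zero-terminal-SNR models.
    guidance_rescale=0.7,
    reference_attn=True,
    reference_adain=False,
).images[0]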