@@ -1,6 +1,7 @@
 # Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236 and https://github.com/Mikubill/sd-webui-controlnet/discussions/1280
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
+import numpy as np
 import PIL.Image
 import torch
 
@@ -97,7 +98,14 @@ def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
+        image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
+        ] = None,
         ref_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
@@ -130,8 +138,8 @@ def __call__(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
-                `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
+                `List[np.ndarray]`, `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                 The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
                 the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
                 also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
@@ -223,15 +231,12 @@ def __call__(
                 list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
                 (nsfw) content, according to the `safety_checker`.
         """
-        # 0. Default height and width to unet
-        height, width = self._default_height_width(height, width, image)
+        assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True."
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
             image,
-            height,
-            width,
             callback_steps,
             negative_prompt,
             prompt_embeds,
@@ -266,6 +271,9 @@ def __call__(
         guess_mode = guess_mode or global_pool_conditions
 
         # 3. Encode input prompt
+        text_encoder_lora_scale = (
+            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+        )
         prompt_embeds = self._encode_prompt(
             prompt,
             device,
@@ -274,6 +282,7 @@ def __call__(
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
         )
 
         # 4. Prepare image
@@ -289,6 +298,7 @@ def __call__(
                 do_classifier_free_guidance=do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
+            height, width = image.shape[-2:]
         elif isinstance(controlnet, MultiControlNetModel):
             images = []
 
@@ -308,6 +318,7 @@ def __call__(
                 images.append(image_)
 
             image = images
+            height, width = image[0].shape[-2:]
         else:
             assert False
 
@@ -720,14 +731,15 @@ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=
                 # controlnet(s) inference
                 if guess_mode and do_classifier_free_guidance:
                     # Infer ControlNet only for the conditional batch.
-                    controlnet_latent_model_input = latents
+                    control_model_input = latents
+                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
                     controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                 else:
-                    controlnet_latent_model_input = latent_model_input
+                    control_model_input = latent_model_input
                     controlnet_prompt_embeds = prompt_embeds
 
                 down_block_res_samples, mid_block_res_sample = self.controlnet(
-                    controlnet_latent_model_input,
+                    control_model_input,
                     t,
                     encoder_hidden_states=controlnet_prompt_embeds,
                     controlnet_cond=image,