@@ -14,14 +14,18 @@
 
 
 import inspect
-from typing import Any, Callable, Dict, List, Optional, Union
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import PIL.Image
 import torch
+from torch import nn
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...models.controlnet import ControlNetOutput
+from ...models.modeling_utils import ModelMixin
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     PIL_INTERPOLATION,
@@ -85,6 +89,63 @@
 """
 
 
+class MultiControlNetModel(ModelMixin):
+    r"""
+    Multiple `ControlNetModel` wrapper class for Multi-ControlNet
+
+    This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
+    compatible with `ControlNetModel`.
+
+    Args:
+        controlnets (`List[ControlNetModel]`):
+            Provides additional conditioning to the unet during the denoising process. You must set multiple
+            `ControlNetModel` as a list.
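+
+    Example (an illustrative sketch; the checkpoint paths below are placeholders):
+
+        ```py
+        >>> controlnet_canny = ControlNetModel.from_pretrained("path/to/canny-controlnet")
+        >>> controlnet_pose = ControlNetModel.from_pretrained("path/to/pose-controlnet")
+        >>> multi_controlnet = MultiControlNetModel([controlnet_canny, controlnet_pose])
+        ```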
+    """
+
+    def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
+        super().__init__()
+        self.nets = nn.ModuleList(controlnets)
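+        # nn.ModuleList registers each controlnet as a submodule, so device moves,
+        # dtype casts, and state-dict handling apply to all of them at once.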
+
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        controlnet_cond: List[torch.Tensor],
+        conditioning_scale: List[float],
+        class_labels: Optional[torch.Tensor] = None,
+        timestep_cond: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ) -> Union[ControlNetOutput, Tuple]:
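+        # Run each controlnet on the shared sample and timestep with its own
+        # conditioning image and scale; the residuals are merged by summation below.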
+        for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
+            down_samples, mid_sample = controlnet(
+                sample,
+                timestep,
+                encoder_hidden_states,
+                image,
+                scale,
+                class_labels,
+                timestep_cond,
+                attention_mask,
+                cross_attention_kwargs,
+                return_dict,
+            )
+
+            # merge samples
+            if i == 0:
+                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
+            else:
+                down_block_res_samples = [
+                    samples_prev + samples_curr
+                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
+                ]
+                mid_block_res_sample += mid_sample
+
+        return down_block_res_samples, mid_block_res_sample
+
+
 class StableDiffusionControlNetPipeline(DiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
@@ -103,8 +164,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
             Tokenizer of class
             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
         unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
-        controlnet ([`ControlNetModel`]):
-            Provides additional conditioning to the unet during the denoising process
+        controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
+            Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
+            as a list, the outputs from each ControlNet are added together to create one combined additional
+            conditioning.
         scheduler ([`SchedulerMixin`]):
             A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
             [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
@@ -122,7 +185,7 @@ def __init__(
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        controlnet: ControlNetModel,
+        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
@@ -146,6 +209,9 @@ def __init__(
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
 
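+        # Transparently wrap a plain list/tuple of controlnets so the rest of the
+        # pipeline only needs to handle `ControlNetModel` or `MultiControlNetModel`.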
+        if isinstance(controlnet, (list, tuple)):
+            controlnet = MultiControlNetModel(controlnet)
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -432,6 +498,7 @@ def check_inputs(
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        controlnet_conditioning_scale=1.0,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -470,6 +537,41 @@ def check_inputs(
                 f" {negative_prompt_embeds.shape}."
             )
 
+        # Check `image`
+
+        if isinstance(self.controlnet, ControlNetModel):
+            self.check_image(image, prompt, prompt_embeds)
+        elif isinstance(self.controlnet, MultiControlNetModel):
+            if not isinstance(image, list):
+                raise TypeError("For multiple controlnets: `image` must be type `list`")
+
+            if len(image) != len(self.controlnet.nets):
+                raise ValueError(
+                    "For multiple controlnets: `image` must have the same length as the number of controlnets."
+                )
+
+            for image_ in image:
+                self.check_image(image_, prompt, prompt_embeds)
+        else:
+            assert False
+
+        # Check `controlnet_conditioning_scale`
+
+        if isinstance(self.controlnet, ControlNetModel):
+            if not isinstance(controlnet_conditioning_scale, float):
+                raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+        elif isinstance(self.controlnet, MultiControlNetModel):
+            if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+                self.controlnet.nets
+            ):
+                raise ValueError(
+                    "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+                    " the same length as the number of controlnets"
+                )
+        else:
+            assert False
+
+    def check_image(self, image, prompt, prompt_embeds):
         image_is_pil = isinstance(image, PIL.Image.Image)
         image_is_tensor = isinstance(image, torch.Tensor)
         image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
@@ -501,7 +603,9 @@ def check_inputs(
                 f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
             )
 
-    def prepare_image(self, image, width, height, batch_size, num_images_per_prompt, device, dtype):
+    def prepare_image(
+        self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance
+    ):
         if not isinstance(image, torch.Tensor):
             if isinstance(image, PIL.Image.Image):
                 image = [image]
@@ -529,6 +633,9 @@ def prepare_image(self, image, width, height, batch_size, num_images_per_prompt,
 
         image = image.to(device=device, dtype=dtype)
 
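+        # Under classifier-free guidance the latent batch is doubled (uncond + cond),
+        # so the conditioning image is duplicated to match.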
+        if do_classifier_free_guidance:
+            image = torch.cat([image] * 2)
+
         return image
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
@@ -550,7 +657,10 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
         return latents
 
     def _default_height_width(self, height, width, image):
-        if isinstance(image, list):
+        # NOTE: a list of images may have different dimensions for each image, so
+        # just checking the first image is not _exactly_ correct, but it is simple.
+        while isinstance(image, list):
             image = image[0]
 
         if height is None:
@@ -571,6 +681,18 @@ def _default_height_width(self, height, width, image):
 
         return height, width
 
+    # override DiffusionPipeline
+    def save_pretrained(
+        self,
+        save_directory: Union[str, os.PathLike],
+        safe_serialization: bool = False,
+        variant: Optional[str] = None,
+    ):
+        if isinstance(self.controlnet, ControlNetModel):
+            super().save_pretrained(save_directory, safe_serialization, variant)
+        else:
+            raise NotImplementedError("Currently, `save_pretrained()` is not implemented for Multi-ControlNet.")
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -593,7 +715,7 @@ def __call__(
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        controlnet_conditioning_scale: float = 1.0,
+        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -602,10 +724,14 @@ def __call__(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                 instead.
-            image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
+                `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
                 The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
-                the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. PIL.Image.Image` can
-                also be accepted as an image. The control image is automatically resized to fit the output image.
+                the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
+                also be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If
+                height and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified
+                in `__init__`, images must be passed as a list such that each element of the list can be correctly
+                batched for input to a single controlnet.
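+                For example, with two controlnets one might pass `image=[canny_image, pose_image]` (an
+                illustrative sketch; the variable names are placeholders).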
             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -658,10 +784,10 @@ def __call__(
                 A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
                 `self.processor` in
                 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
-            controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
+            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
                 The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
-                to the residual in the original unet.
-
+                to the residual in the original unet. If multiple ControlNets are specified in `__init__`, you can set
+                the corresponding scale as a list.
         Examples:
 
         Returns:
@@ -676,7 +802,15 @@ def __call__(
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
-            prompt, image, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+            prompt,
+            image,
+            height,
+            width,
+            callback_steps,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            controlnet_conditioning_scale,
         )
 
         # 2. Define call parameters
@@ -693,6 +827,9 @@ def __call__(
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
 
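+        # If a bare float was given with multiple controlnets, broadcast it so every
+        # controlnet receives the same conditioning scale.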
+        if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
+
         # 3. Encode input prompt
         prompt_embeds = self._encode_prompt(
             prompt,
@@ -705,18 +842,37 @@ def __call__(
         )
 
         # 4. Prepare image
-        image = self.prepare_image(
-            image,
-            width,
-            height,
-            batch_size * num_images_per_prompt,
-            num_images_per_prompt,
-            device,
-            self.controlnet.dtype,
-        )
+        if isinstance(self.controlnet, ControlNetModel):
+            image = self.prepare_image(
+                image=image,
+                width=width,
+                height=height,
+                batch_size=batch_size * num_images_per_prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                device=device,
+                dtype=self.controlnet.dtype,
+                do_classifier_free_guidance=do_classifier_free_guidance,
+            )
+        elif isinstance(self.controlnet, MultiControlNetModel):
+            images = []
+
+            for image_ in image:
+                image_ = self.prepare_image(
+                    image=image_,
+                    width=width,
+                    height=height,
+                    batch_size=batch_size * num_images_per_prompt,
+                    num_images_per_prompt=num_images_per_prompt,
+                    device=device,
+                    dtype=self.controlnet.dtype,
+                    do_classifier_free_guidance=do_classifier_free_guidance,
+                )
 
-        if do_classifier_free_guidance:
-            image = torch.cat([image] * 2)
+                images.append(image_)
+
+            image = images
+        else:
+            assert False
 
         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -746,20 +902,16 @@ def __call__(
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
+                # controlnet(s) inference
                 down_block_res_samples, mid_block_res_sample = self.controlnet(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
                     controlnet_cond=image,
+                    conditioning_scale=controlnet_conditioning_scale,
                     return_dict=False,
                 )
 
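+                # NOTE: the residuals are already scaled by `conditioning_scale` inside
+                # the controlnet forward, so the manual scaling below is removed.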
-                down_block_res_samples = [
-                    down_block_res_sample * controlnet_conditioning_scale
-                    for down_block_res_sample in down_block_res_samples
-                ]
-                mid_block_res_sample *= controlnet_conditioning_scale
-
                 # predict the noise residual
                 noise_pred = self.unet(
                     latent_model_input,