Skip to content

Commit 9a45d7f

Browse files
Add guidance start/stop (#3770)
* Add guidance start/stop * Add guidance start/stop to inpaint class * Black formatting * Add support for guidance for multicontrolnet * Add inclusive end * Improve design * correct imports * Finish * Finish all * Correct more * make style --------- Co-authored-by: Patrick von Platen <[email protected]>
1 parent 61916fe commit 9a45d7f

File tree

6 files changed

+305
-7
lines changed

6 files changed

+305
-7
lines changed

src/diffusers/pipelines/controlnet/pipeline_controlnet.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,8 @@ def check_inputs(
491491
prompt_embeds=None,
492492
negative_prompt_embeds=None,
493493
controlnet_conditioning_scale=1.0,
494+
control_guidance_start=0.0,
495+
control_guidance_end=1.0,
494496
):
495497
if (callback_steps is None) or (
496498
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
@@ -593,6 +595,27 @@ def check_inputs(
593595
else:
594596
assert False
595597

598+
if len(control_guidance_start) != len(control_guidance_end):
599+
raise ValueError(
600+
f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
601+
)
602+
603+
if isinstance(self.controlnet, MultiControlNetModel):
604+
if len(control_guidance_start) != len(self.controlnet.nets):
605+
raise ValueError(
606+
f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
607+
)
608+
609+
for start, end in zip(control_guidance_start, control_guidance_end):
610+
if start >= end:
611+
raise ValueError(
612+
f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
613+
)
614+
if start < 0.0:
615+
raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
616+
if end > 1.0:
617+
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
618+
596619
def check_image(self, image, prompt, prompt_embeds):
597620
image_is_pil = isinstance(image, PIL.Image.Image)
598621
image_is_tensor = isinstance(image, torch.Tensor)
@@ -709,6 +732,8 @@ def __call__(
709732
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
710733
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
711734
guess_mode: bool = False,
735+
control_guidance_start: Union[float, List[float]] = 0.0,
736+
control_guidance_end: Union[float, List[float]] = 1.0,
712737
):
713738
r"""
714739
Function invoked when calling the pipeline for generation.
@@ -784,6 +809,10 @@ def __call__(
784809
guess_mode (`bool`, *optional*, defaults to `False`):
785810
In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
786811
you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
812+
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
813+
The percentage of total steps at which the controlnet starts applying.
814+
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
815+
The percentage of total steps at which the controlnet stops applying.
787816
788817
Examples:
789818
@@ -794,6 +823,18 @@ def __call__(
794823
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
795824
(nsfw) content, according to the `safety_checker`.
796825
"""
826+
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
827+
828+
# align format for control guidance
829+
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
830+
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
831+
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
832+
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
833+
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
834+
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
835+
control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
836+
control_guidance_end
837+
]
797838

798839
# 1. Check inputs. Raise error if not correct
799840
self.check_inputs(
@@ -804,6 +845,8 @@ def __call__(
804845
prompt_embeds,
805846
negative_prompt_embeds,
806847
controlnet_conditioning_scale,
848+
control_guidance_start,
849+
control_guidance_end,
807850
)
808851

809852
# 2. Define call parameters
@@ -820,8 +863,6 @@ def __call__(
820863
# corresponds to doing no classifier free guidance.
821864
do_classifier_free_guidance = guidance_scale > 1.0
822865

823-
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
824-
825866
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
826867
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
827868

@@ -904,6 +945,15 @@ def __call__(
904945
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
905946
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
906947

948+
# 7.1 Create tensor stating which controlnets to keep
949+
controlnet_keep = []
950+
for i in range(num_inference_steps):
951+
keeps = [
952+
1.0 - float(i / num_inference_steps < s or (i + 1) / num_inference_steps > e)
953+
for s, e in zip(control_guidance_start, control_guidance_end)
954+
]
955+
controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
956+
907957
# 8. Denoising loop
908958
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
909959
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -922,12 +972,17 @@ def __call__(
922972
control_model_input = latent_model_input
923973
controlnet_prompt_embeds = prompt_embeds
924974

975+
if isinstance(controlnet_keep[i], list):
976+
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
977+
else:
978+
cond_scale = controlnet_conditioning_scale * controlnet_keep[i]
979+
925980
down_block_res_samples, mid_block_res_sample = self.controlnet(
926981
control_model_input,
927982
t,
928983
encoder_hidden_states=controlnet_prompt_embeds,
929984
controlnet_cond=image,
930-
conditioning_scale=controlnet_conditioning_scale,
985+
conditioning_scale=cond_scale,
931986
guess_mode=guess_mode,
932987
return_dict=False,
933988
)

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,8 @@ def check_inputs(
517517
prompt_embeds=None,
518518
negative_prompt_embeds=None,
519519
controlnet_conditioning_scale=1.0,
520+
control_guidance_start=0.0,
521+
control_guidance_end=1.0,
520522
):
521523
if (callback_steps is None) or (
522524
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
@@ -619,6 +621,27 @@ def check_inputs(
619621
else:
620622
assert False
621623

624+
if len(control_guidance_start) != len(control_guidance_end):
625+
raise ValueError(
626+
f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
627+
)
628+
629+
if isinstance(self.controlnet, MultiControlNetModel):
630+
if len(control_guidance_start) != len(self.controlnet.nets):
631+
raise ValueError(
632+
f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
633+
)
634+
635+
for start, end in zip(control_guidance_start, control_guidance_end):
636+
if start >= end:
637+
raise ValueError(
638+
f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
639+
)
640+
if start < 0.0:
641+
raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
642+
if end > 1.0:
643+
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
644+
622645
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
623646
def check_image(self, image, prompt, prompt_embeds):
624647
image_is_pil = isinstance(image, PIL.Image.Image)
@@ -796,6 +819,8 @@ def __call__(
796819
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
797820
controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
798821
guess_mode: bool = False,
822+
control_guidance_start: Union[float, List[float]] = 0.0,
823+
control_guidance_end: Union[float, List[float]] = 1.0,
799824
):
800825
r"""
801826
Function invoked when calling the pipeline for generation.
@@ -876,6 +901,10 @@ def __call__(
876901
guess_mode (`bool`, *optional*, defaults to `False`):
877902
In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
878903
you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
904+
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
905+
The percentage of total steps at which the controlnet starts applying.
906+
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
907+
The percentage of total steps at which the controlnet stops applying.
879908
880909
Examples:
881910
@@ -886,6 +915,19 @@ def __call__(
886915
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
887916
(nsfw) content, according to the `safety_checker`.
888917
"""
918+
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
919+
920+
# align format for control guidance
921+
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
922+
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
923+
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
924+
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
925+
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
926+
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
927+
control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
928+
control_guidance_end
929+
]
930+
889931
# 1. Check inputs. Raise error if not correct
890932
self.check_inputs(
891933
prompt,
@@ -895,6 +937,8 @@ def __call__(
895937
prompt_embeds,
896938
negative_prompt_embeds,
897939
controlnet_conditioning_scale,
940+
control_guidance_start,
941+
control_guidance_end,
898942
)
899943

900944
# 2. Define call parameters
@@ -994,6 +1038,15 @@ def __call__(
9941038
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
9951039
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
9961040

1041+
# 7.1 Create tensor stating which controlnets to keep
1042+
controlnet_keep = []
1043+
for i in range(num_inference_steps):
1044+
keeps = [
1045+
1.0 - float(i / num_inference_steps < s or (i + 1) / num_inference_steps > e)
1046+
for s, e in zip(control_guidance_start, control_guidance_end)
1047+
]
1048+
controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
1049+
9971050
# 8. Denoising loop
9981051
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
9991052
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -1012,12 +1065,17 @@ def __call__(
10121065
control_model_input = latent_model_input
10131066
controlnet_prompt_embeds = prompt_embeds
10141067

1068+
if isinstance(controlnet_keep[i], list):
1069+
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
1070+
else:
1071+
cond_scale = controlnet_conditioning_scale * controlnet_keep[i]
1072+
10151073
down_block_res_samples, mid_block_res_sample = self.controlnet(
10161074
control_model_input,
10171075
t,
10181076
encoder_hidden_states=controlnet_prompt_embeds,
10191077
controlnet_cond=control_image,
1020-
conditioning_scale=controlnet_conditioning_scale,
1078+
conditioning_scale=cond_scale,
10211079
guess_mode=guess_mode,
10221080
return_dict=False,
10231081
)

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,8 @@ def check_inputs(
646646
prompt_embeds=None,
647647
negative_prompt_embeds=None,
648648
controlnet_conditioning_scale=1.0,
649+
control_guidance_start=0.0,
650+
control_guidance_end=1.0,
649651
):
650652
if height % 8 != 0 or width % 8 != 0:
651653
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -751,6 +753,27 @@ def check_inputs(
751753
else:
752754
assert False
753755

756+
if len(control_guidance_start) != len(control_guidance_end):
757+
raise ValueError(
758+
f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
759+
)
760+
761+
if isinstance(self.controlnet, MultiControlNetModel):
762+
if len(control_guidance_start) != len(self.controlnet.nets):
763+
raise ValueError(
764+
f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
765+
)
766+
767+
for start, end in zip(control_guidance_start, control_guidance_end):
768+
if start >= end:
769+
raise ValueError(
770+
f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
771+
)
772+
if start < 0.0:
773+
raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
774+
if end > 1.0:
775+
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
776+
754777
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
755778
def check_image(self, image, prompt, prompt_embeds):
756779
image_is_pil = isinstance(image, PIL.Image.Image)
@@ -990,6 +1013,8 @@ def __call__(
9901013
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
9911014
controlnet_conditioning_scale: Union[float, List[float]] = 0.5,
9921015
guess_mode: bool = False,
1016+
control_guidance_start: Union[float, List[float]] = 0.0,
1017+
control_guidance_end: Union[float, List[float]] = 1.0,
9931018
):
9941019
r"""
9951020
Function invoked when calling the pipeline for generation.
@@ -1073,6 +1098,10 @@ def __call__(
10731098
guess_mode (`bool`, *optional*, defaults to `False`):
10741099
In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
10751100
you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
1101+
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
1102+
The percentage of total steps at which the controlnet starts applying.
1103+
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
1104+
The percentage of total steps at which the controlnet stops applying.
10761105
10771106
Examples:
10781107
@@ -1083,9 +1112,22 @@ def __call__(
10831112
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
10841113
(nsfw) content, according to the `safety_checker`.
10851114
"""
1115+
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
1116+
10861117
# 0. Default height and width to unet
10871118
height, width = self._default_height_width(height, width, image)
10881119

1120+
# align format for control guidance
1121+
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
1122+
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
1123+
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
1124+
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1125+
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
1126+
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
1127+
control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
1128+
control_guidance_end
1129+
]
1130+
10891131
# 1. Check inputs. Raise error if not correct
10901132
self.check_inputs(
10911133
prompt,
@@ -1097,6 +1139,8 @@ def __call__(
10971139
prompt_embeds,
10981140
negative_prompt_embeds,
10991141
controlnet_conditioning_scale,
1142+
control_guidance_start,
1143+
control_guidance_end,
11001144
)
11011145

11021146
# 2. Define call parameters
@@ -1113,8 +1157,6 @@ def __call__(
11131157
# corresponds to doing no classifier free guidance.
11141158
do_classifier_free_guidance = guidance_scale > 1.0
11151159

1116-
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
1117-
11181160
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
11191161
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
11201162

@@ -1231,6 +1273,15 @@ def __call__(
12311273
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
12321274
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
12331275

1276+
# 7.1 Create tensor stating which controlnets to keep
1277+
controlnet_keep = []
1278+
for i in range(num_inference_steps):
1279+
keeps = [
1280+
1.0 - float(i / num_inference_steps < s or (i + 1) / num_inference_steps > e)
1281+
for s, e in zip(control_guidance_start, control_guidance_end)
1282+
]
1283+
controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
1284+
12341285
# 8. Denoising loop
12351286
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
12361287
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -1249,12 +1300,17 @@ def __call__(
12491300
control_model_input = latent_model_input
12501301
controlnet_prompt_embeds = prompt_embeds
12511302

1303+
if isinstance(controlnet_keep[i], list):
1304+
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
1305+
else:
1306+
cond_scale = controlnet_conditioning_scale * controlnet_keep[i]
1307+
12521308
down_block_res_samples, mid_block_res_sample = self.controlnet(
12531309
control_model_input,
12541310
t,
12551311
encoder_hidden_states=controlnet_prompt_embeds,
12561312
controlnet_cond=control_image,
1257-
conditioning_scale=controlnet_conditioning_scale,
1313+
conditioning_scale=cond_scale,
12581314
guess_mode=guess_mode,
12591315
return_dict=False,
12601316
)

0 commit comments

Comments
 (0)