From f4d0d83261fa765d98ff60c98038096882a53681 Mon Sep 17 00:00:00 2001
From: okotaku
Date: Fri, 9 Jun 2023 11:02:43 +0900
Subject: [PATCH 1/2] update reference pipeline

---
 .../stable_diffusion_controlnet_reference.py | 32 +++++++++++++------
 .../community/stable_diffusion_reference.py  | 15 +++++++++
 2 files changed, 37 insertions(+), 10 deletions(-)

diff --git a/examples/community/stable_diffusion_controlnet_reference.py b/examples/community/stable_diffusion_controlnet_reference.py
index ca06136d7829..2548fd58d365 100644
--- a/examples/community/stable_diffusion_controlnet_reference.py
+++ b/examples/community/stable_diffusion_controlnet_reference.py
@@ -3,6 +3,7 @@
 
 import PIL.Image
 import torch
+import numpy as np
 
 from diffusers import StableDiffusionControlNetPipeline
 from diffusers.models import ControlNetModel
@@ -97,7 +98,14 @@ def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
+        image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
+        ] = None,
         ref_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
@@ -130,8 +138,8 @@ def __call__(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
-                `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+                `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                 The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
                 the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
                 also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
@@ -223,15 +231,12 @@ def __call__(
             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content, according to the `safety_checker`.
         """
-        # 0. Default height and width to unet
-        height, width = self._default_height_width(height, width, image)
+        assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True."
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
             image,
-            height,
-            width,
             callback_steps,
             negative_prompt,
             prompt_embeds,
@@ -266,6 +271,9 @@ def __call__(
         guess_mode = guess_mode or global_pool_conditions
 
         # 3. Encode input prompt
+        text_encoder_lora_scale = (
+            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+        )
         prompt_embeds = self._encode_prompt(
             prompt,
             device,
@@ -274,6 +282,7 @@ def __call__(
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
         )
 
         # 4. Prepare image
@@ -289,6 +298,7 @@ def __call__(
                 do_classifier_free_guidance=do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
+            height, width = image.shape[-2:]
         elif isinstance(controlnet, MultiControlNetModel):
             images = []
 
@@ -308,6 +318,7 @@ def __call__(
                 images.append(image_)
 
             image = images
+            height, width = image[0].shape[-2:]
         else:
             assert False
 
@@ -720,14 +731,15 @@ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=
                 # controlnet(s) inference
                 if guess_mode and do_classifier_free_guidance:
                     # Infer ControlNet only for the conditional batch.
-                    controlnet_latent_model_input = latents
+                    control_model_input = latents
+                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
                     controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                 else:
-                    controlnet_latent_model_input = latent_model_input
+                    control_model_input = latent_model_input
                     controlnet_prompt_embeds = prompt_embeds
 
                 down_block_res_samples, mid_block_res_sample = self.controlnet(
-                    controlnet_latent_model_input,
+                    control_model_input,
                     t,
                     encoder_hidden_states=controlnet_prompt_embeds,
                     controlnet_cond=image,
diff --git a/examples/community/stable_diffusion_reference.py b/examples/community/stable_diffusion_reference.py
index dbfb768f8b4f..1829adc036ce 100644
--- a/examples/community/stable_diffusion_reference.py
+++ b/examples/community/stable_diffusion_reference.py
@@ -9,6 +9,7 @@
 from diffusers.models.attention import BasicTransformerBlock
 from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
 from diffusers.utils import PIL_INTERPOLATION, logging, randn_tensor
 
 
@@ -179,6 +180,7 @@ def __call__(
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guidance_rescale: float = 0.0,
         attention_auto_machine_weight: float = 1.0,
         gn_auto_machine_weight: float = 1.0,
         style_fidelity: float = 0.5,
@@ -248,6 +250,11 @@ def __call__(
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
                 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+            guidance_rescale (`float`, *optional*, defaults to 0.7):
+                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+                Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
+                [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+                Guidance rescale factor should fix overexposure when using zero terminal SNR.
             attention_auto_machine_weight (`float`):
                 Weight of using reference query for self attention's context.
                 If attention_auto_machine_weight=1.0, use reference query for all self attention's context.
@@ -295,6 +302,9 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0
 
         # 3. Encode input prompt
+        text_encoder_lora_scale = (
+            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+        )
         prompt_embeds = self._encode_prompt(
             prompt,
             device,
@@ -303,6 +313,7 @@ def __call__(
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
         )
 
         # 4. Preprocess reference image
@@ -747,6 +758,10 @@ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=
                 if do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                
+                if do_classifier_free_guidance and guidance_rescale > 0.0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

From 5363cb2b95a5322d81dc4af99473db1d34465829 Mon Sep 17 00:00:00 2001
From: okotaku
Date: Fri, 9 Jun 2023 11:03:30 +0900
Subject: [PATCH 2/2] update reference pipeline

---
 examples/community/stable_diffusion_controlnet_reference.py | 2 +-
 examples/community/stable_diffusion_reference.py             | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/community/stable_diffusion_controlnet_reference.py b/examples/community/stable_diffusion_controlnet_reference.py
index 2548fd58d365..f52da6f5a193 100644
--- a/examples/community/stable_diffusion_controlnet_reference.py
+++ b/examples/community/stable_diffusion_controlnet_reference.py
@@ -1,9 +1,9 @@
 # Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236 and https://github.com/Mikubill/sd-webui-controlnet/discussions/1280
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
+import numpy as np
 import PIL.Image
 import torch
-import numpy as np
 
 from diffusers import StableDiffusionControlNetPipeline
 from diffusers.models import ControlNetModel
diff --git a/examples/community/stable_diffusion_reference.py b/examples/community/stable_diffusion_reference.py
index 1829adc036ce..364d5d80d721 100644
--- a/examples/community/stable_diffusion_reference.py
+++ b/examples/community/stable_diffusion_reference.py
@@ -758,7 +758,7 @@ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=
                 if do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-                
+
                 if do_classifier_free_guidance and guidance_rescale > 0.0:
                     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                     noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
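
Taken together, the two commits (a) let the ControlNet reference pipeline accept `np.ndarray` control images and infer `height`/`width` from the prepared image instead of `self._default_height_width`, (b) forward a LoRA scale from `cross_attention_kwargs` to `_encode_prompt`, and (c) expose `guidance_rescale` (per https://arxiv.org/abs/2305.08891) in the reference-only pipeline. Below is a minimal usage sketch of the updated reference-only pipeline; it is not part of the patch, and the checkpoint name, reference image URL, and parameter values are illustrative. It assumes a diffusers release that resolves these community scripts via `custom_pipeline`.

```python
# Minimal sketch, assuming a CUDA device and a diffusers version that can
# load the community "stable_diffusion_reference" script via
# `custom_pipeline`; checkpoint and URL below are illustrative choices.
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="stable_diffusion_reference",
    torch_dtype=torch.float16,
).to("cuda")

# Reference image whose attention/AdaIN statistics guide the generation.
ref_image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
)

image = pipe(
    prompt="a portrait of a woman, masterpiece, best quality",
    ref_image=ref_image,
    num_inference_steps=30,
    reference_attn=True,
    reference_adain=True,
    # New in this patch: rescales classifier-free guidance as in
    # arXiv:2305.08891; 0.0 (the default) disables the rescale.
    guidance_rescale=0.7,
).images[0]
image.save("reference_only.png")
```

With LoRA weights loaded, the same `cross_attention_kwargs={"scale": ...}` that scales the UNet LoRA layers is now also forwarded to `_encode_prompt` as `lora_scale`, so the text encoder LoRA is scaled consistently. The ControlNet variant (`custom_pipeline="stable_diffusion_controlnet_reference"`) gains the same behavior, plus `np.ndarray` support for `image`.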