Conversation

@yiyixuxu (Collaborator) commented Aug 13, 2023

working super well now!

import torch
import matplotlib
import matplotlib.cm
import numpy as np

from PIL import Image

from diffusers.utils import load_image
from diffusers import StableDiffusionXLControlNetImg2ImgPipeline, ControlNetModel, AutoencoderKL



torch.hub.help("intel-isl/MiDaS", "DPT_BEiT_L_384", force_reload=True)  # Triggers fresh download of MiDaS repo
model_zoe_n = torch.hub.load("isl-org/ZoeDepth", "ZoeD_NK", pretrained=True).eval()
model_zoe_n = model_zoe_n.to("cuda")


def colorize(value, vmin=None, vmax=None, cmap='gray_r', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None):
    if isinstance(value, torch.Tensor):
        value = value.detach().cpu().numpy()

    value = value.squeeze()
    if invalid_mask is None:
        invalid_mask = value == invalid_val
    mask = np.logical_not(invalid_mask)

    # normalize
    vmin = np.percentile(value[mask], 2) if vmin is None else vmin
    vmax = np.percentile(value[mask], 85) if vmax is None else vmax
    if vmin != vmax:
        value = (value - vmin) / (vmax - vmin)  # vmin..vmax
    else:
        # Avoid 0-division
        value = value * 0.

    # grey out the invalid values
    value[invalid_mask] = np.nan

    cmapper = matplotlib.cm.get_cmap(cmap)
    if value_transform:
        value = value_transform(value)
    value = cmapper(value, bytes=True)  # (n, m, 4)

    img = value[...]
    img[invalid_mask] = background_color

    # gamma correction
    img = img / 255
    img = np.power(img, 2.2)
    img = img * 255
    img = img.astype(np.uint8)
    img = Image.fromarray(img)
    return img


def get_zoe_depth_map(image):
    with torch.autocast("cuda", enabled=True):
        depth = model_zoe_n.infer_pil(image)
    depth = colorize(depth, cmap="gray_r")
    return depth



controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-zoe-depth-sdxl-1.0",
    use_safetensors=True,
    torch_dtype=torch.float16,
).to("cuda")
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to("cuda")
pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
).to("cuda")
pipe.enable_model_cpu_offload()


prompt = "A robot, 4k photo"
negative_prompt = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature"

image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
    "/kandinsky/cat.png"
).resize((1024, 1024))

controlnet_conditioning_scale = 0.55

depth_image = get_zoe_depth_map(image).resize((1024, 1024))

generator = torch.Generator("cuda").manual_seed(0)
images = pipe(
    prompt,
    image=image,
    control_image=depth_image,
    strength=0.99,
    num_inference_steps=50,
    controlnet_conditioning_scale=controlnet_conditioning_scale,
    generator=generator,
).images

images[0].save("sdxl_robot_cat.png")

[image: robot_cat_sdxl_img2img]

As a comparison, this is the output we would get from the regular (non-img2img) ControlNet pipeline 😂😂😂

[image: robot_cat_not_img2img]

@HuggingFaceDocBuilderDev commented Aug 13, 2023

The documentation is not available anymore as the PR was closed or merged.

@yamkz commented Aug 14, 2023

😍
#4589

@yiyixuxu yiyixuxu marked this pull request as draft August 16, 2023 18:06
@bghira (Contributor) commented Aug 16, 2023

please try:

seed = 123456
generator_1 = torch.manual_seed(seed)
generator_2 = torch.manual_seed(seed ^ 0xFFFFFFF)

or similar, to avoid using the same seed for both pipelines. This might improve the quality. Edit: I just noticed you're not passing the SD 1.5 image into SDXL. My misunderstanding.

The other thing is the img2img guidance scale: it defaults to 5.0, and maybe that is too low for use with ControlNet.
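For example, a quick sketch of overriding it explicitly in the call (purely illustrative; 7.5 is just an assumed value to try, and the variable names come from the script above):

images = pipe(
    prompt,
    image=image,
    control_image=depth_image,
    strength=0.99,
    guidance_scale=7.5,  # override the 5.0 default; illustrative value only
    num_inference_steps=50,
    controlnet_conditioning_scale=controlnet_conditioning_scale,
    generator=generator,
).images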

When you say it doesn't work well, do you mean any context that uses base SDXL 1.0 as the img2img model?

@yiyixuxu (Collaborator, Author)

@bghira

When you say it doesn't work well, do you mean any context that uses base SDXL 1.0 as the img2img model?

Yes, I'm referring to the context where we use the stabilityai/stable-diffusion-xl-base-1.0 checkpoint as the img2img model. For example, I'm unable to get our classic img2img example working with the script below. I tried different guidance_scale values too; that didn't seem to help.

import requests
from io import BytesIO

import torch
from PIL import Image

from diffusers import AutoencoderKL, StableDiffusionXLImg2ImgPipeline

device = "cuda"

url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"

response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((768, 512))

prompt = "A fantasy landscape, trending on artstation"

model_id_or_path = "stabilityai/stable-diffusion-xl-base-1.0"

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    model_id_or_path,
    vae=vae,
    torch_dtype=torch.float16,
)
pipe = pipe.to(device)

generator = torch.manual_seed(0)
images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5, generator=generator).images
images[0].save("fantasy_landscape_sdxl.png")

@bghira (Contributor) commented Aug 16, 2023

"doesn't work" means it kicks out an error, or simply results in unexpected image?

@bghira (Contributor) commented Aug 16, 2023

I think in the example you just provided, the only issue I can see is that the input image resolution is too low.

@yiyixuxu (Collaborator, Author)

I think in the example you just provided, the only issue I can see is that the input image resolution is too low.

So passing original_size or target_size should help?
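For reference, a minimal sketch of what passing those arguments to the SDXL img2img call would look like (the 1024x1024 values are illustrative only; pipe and init_image are from the script above):

images = pipe(
    prompt=prompt,
    image=init_image,
    strength=0.75,
    guidance_scale=7.5,
    original_size=(1024, 1024),  # illustrative micro-conditioning values
    target_size=(1024, 1024),
    generator=torch.manual_seed(0),
).images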

@bghira (Contributor) commented Aug 16, 2023

I don't believe that will help; it just seems that SDXL's ability to generate low-resolution images is compromised due to its fine-tuning on 1024px. It is as if the transformer layers are heavily conditioned to represent details that simply won't exist at lower resolutions.

@yiyixuxu (Collaborator, Author)

@bghira

I don't believe that will help; it just seems that SDXL's ability to generate low-resolution images is compromised due to its fine-tuning on 1024px. It is as if the transformer layers are heavily conditioned to represent details that simply won't exist at lower resolutions.

OK, I rechecked the paper and they did mention this. I will try it on a 1024 x 1024 image.
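A minimal sketch of that retry, assuming the same script as above with only the init image resized to SDXL's native 1024 x 1024 first:

init_image_1024 = init_image.resize((1024, 1024))
images = pipe(
    prompt=prompt,
    image=init_image_1024,
    strength=0.75,
    guidance_scale=7.5,
    generator=torch.manual_seed(0),
).images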

@yiyixuxu (Collaborator, Author)

@bghira

So I used an image generated with SDXL text2img. It kind of worked a little bit, but the result is way worse than the SD 1.5 output.

Original image (generated from SDXL):
[image: yiyi_test_1_init_image]
SDXL img2img:
[image: yiyi_test_1_out_fantasy_landscape_sdxl]
SD 1.5 img2img:
[image: yiyi_test_1_out_fantasy_landscape_sd15]

@yiyixuxu (Collaborator, Author)

I also tried resizing the other image to 1024x1024, and it didn't make any difference, so I'm not sure resolution plays a part here.

@bghira (Contributor) commented Aug 16, 2023

Sorry for my chosen example image, but I wanted to see if I could reproduce the issue, or if it would behave like 'hires fix':

The input, a classic meme:
[image]

The output:
[image]

I looped it back through img2img 60+ times:
[image]

My scheduler configuration:

{
  "_class_name": "DDIMScheduler",
  "_diffusers_version": "0.19.0.dev0",
  "beta_end": 0.012,
  "beta_schedule": "scaled_linear",
  "beta_start": 0.00085,
  "clip_sample": false,
  "interpolation_type": "linear",
  "num_train_timesteps": 1000,
  "prediction_type": "epsilon",
  "sample_max_value": 1.0,
  "set_alpha_to_one": false,
  "skip_prk_steps": true,
  "steps_offset": 1,
  "timestep_spacing": "trailing",
  "trained_betas": null,
  "use_karras_sigmas": false
}

Notably, we used timestep_spacing='trailing' via DDIMScheduler, at strength=0.5 and guidance_scale=7.5.

You can see that the detail is still obliterated from the background of the image, but fine details are still present.
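For reference, a minimal sketch of that setup, assuming the pipe and init_image variables from the earlier img2img script (DDIMScheduler.from_config can override timestep_spacing while keeping the rest of the config):

from diffusers import DDIMScheduler

# Swap in DDIM with trailing timestep spacing, keeping the remaining config values.
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

image = pipe(
    prompt=prompt,
    image=init_image,
    strength=0.5,
    guidance_scale=7.5,
    generator=torch.manual_seed(0),
).images[0]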

@yiyixuxu (Collaborator, Author)

Related: #4724

@yiyixuxu yiyixuxu marked this pull request as ready for review August 24, 2023 08:25
@patrickvonplaten (Contributor) left a comment

Looks great! Can we add some tests here as well? Then it should be good to go :-)

>>> image = pipe(
... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image
... ).images[0]
```
Member

The example needs to be edited.

unet: UNet2DConditionModel,
controlnet: ControlNetModel,
scheduler: KarrasDiffusionSchedulers,
requires_aesthetics_score: bool = False,
Member

Where is this coming from?

Contributor

BTW, let's make sure to add a docstring for requires_aesthetics_score here as well.

Member

I think this is yet to be addressed.

Comment on lines 161 to 162
if isinstance(controlnet, (list, tuple)):
    raise ValueError("MultiControlNet is not yet supported.")


# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
def _get_add_time_ids(
    self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype
Member

When did we add support for aesthetic scoring?

Contributor

This is copied from the XL img2img pipeline, which is designed for use with the refiner.

I suppose you could tune a ControlNet with aesthetic values, and it would:

  • lead to lower prompt adherence, in favour of the aesthetic-scored data distribution
  • require conditional dropout so that we don't overfit to relying on these scores
  • need additional dataset management

That said, I would actually like to see support for refiner tuning.

Collaborator (Author)

I think the aesthetic scores won't actually take effect here, since config.requires_aesthetics_score is set to False. I'm slightly in favor of keeping these arguments because:

  1. One could potentially train a ControlNet with the refiner in the future (maybe?).
  2. It is consistent with the SDXL img2img pipeline, where the aesthetic scores are also just placeholders when you use it with the base model.

However, I don't have a strong opinion on this; I'll be happy to remove them if you think it's better that way.
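For context, a simplified sketch (not the actual pipeline code) of how the copied _get_add_time_ids treats these arguments: the aesthetic scores only enter the conditioning tuple when requires_aesthetics_score is True; otherwise target_size is used and the scores are ignored.

def sketch_add_time_ids(original_size, crops_coords_top_left, target_size,
                        aesthetic_score, negative_aesthetic_score,
                        requires_aesthetics_score=False):
    if requires_aesthetics_score:
        # Refiner-style conditioning: (h, w, top, left, aesthetic_score)
        add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
        add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,))
    else:
        # Base-model conditioning: (h, w, top, left, target_h, target_w)
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
    return add_time_ids, add_neg_time_ids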

Contributor

Agree with your judgement here @yiyixuxu. I also think it could make sense to train a ControlNet refiner (by limiting the added noise to < 20%).
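A purely hypothetical sketch of such a refiner-style pass with this pipeline, limiting the added noise via a low strength (no ControlNet refiner checkpoint exists at this point; variable names reuse the first example in this thread):

refined = pipe(
    prompt,
    image=images[0],            # output of the earlier base pass
    control_image=depth_image,  # same conditioning image as before
    strength=0.2,               # add roughly <= 20% noise, as suggested above
    num_inference_steps=50,
    controlnet_conditioning_scale=controlnet_conditioning_scale,
    generator=generator,
).images[0]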

Comment on lines +832 to +834
aesthetic_score: float = 6.0,
negative_aesthetic_score: float = 2.5,
):
Member

Have you tested their effect?

self.controlnet.to("cpu")
torch.cuda.empty_cache()

# make sure the VAE is in float32 mode, as it overflows in float16
Contributor

Can we update this part of the code to conform with the changes in #4796?

@patrickvonplaten (Contributor)

Cool, I think we're mostly good to go here, no? OK to merge for me once all tests are green. @sayakpaul, if you could give a final review, that'd be great!

latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# controlnet(s) inference
if guess_mode and do_classifier_free_guidance:
Member

I guess we need to also correct the behaviour for guess_mode here per your other PR?

Collaborator (Author)

I forgot! Thanks, will add.

self._test_inference_batch_single_identical(expected_max_diff=2e-3)

# TODO(Patrick, Sayak) - skip for now as this requires more refiner tests
def test_save_load_optional_components(self):
Member

What is the problem here?

Collaborator (Author)

Not sure, actually! I thought maybe you left the note 😂
It was copied from sdxl-img2img.

Contributor

Ah yeah, that note was me! The problem is that for img2img/refiner the "text_encoder" is optional, so we should be able to load the pipeline without it, but that's not possible when using the non-refiner architecture (which is what the tests currently use). We should probably add a new test class here at some point.

@polavishnu4444 commented Nov 23, 2023

When is refiner support planned for the ControlNet img2img pipeline?

@sayakpaul (Member) left a comment

Looking very nice! Thanks for this one!

@sayakpaul (Member)

Let's also fix the tests before merging :-)

@yiyixuxu yiyixuxu merged commit 5eeedd9 into main Aug 28, 2023
@yiyixuxu yiyixuxu deleted the controlnet-img2img branch August 28, 2023 18:16
@markrmiller

Ignore this if I'm off base for some reason, but I'm not sure from_single_file support is there.

@vionwinnie

Hello! I am actually trying to use this for my prototype. Thank you so much for creating this pipeline :)

I have two questions about this pipeline:

  1. When I use the previous ControlNet img2img pipeline, there is logic for the input image, mask, and control_image, and then logic to create masked_image using image and mask.

In this pipeline, where is the equivalent logic? I was reading the code, and it seems like vae.image_processor only converts the image into a tensor; I don't see the masking logic in place.

  2. For the img2img pipeline, the UNet in_channels is 9: 4 (image) + 1 (mask) + 4 (masked_image). Are we doing that in this pipeline?

I am still learning how SD and ControlNet models work, so thank you so much for your help in advance. @yiyixuxu

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023

AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024