From 87732e279a4f462f9b83b399d8e7635a59a98def Mon Sep 17 00:00:00 2001 From: daspartho Date: Mon, 26 Dec 2022 03:43:10 +0530 Subject: [PATCH 1/7] initial --- examples/community/magic_mix.py | 142 ++++++++++++++++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 examples/community/magic_mix.py diff --git a/examples/community/magic_mix.py b/examples/community/magic_mix.py new file mode 100644 index 000000000000..f2f6a769a582 --- /dev/null +++ b/examples/community/magic_mix.py @@ -0,0 +1,142 @@ +import torch + +from diffusers import DiffusionPipeline +from PIL import Image +from torchvision import transforms as tfms +from tqdm.auto import tqdm + + +class MagicMixPipeline(DiffusionPipeline): + def __init__( + self, + vae, + text_encoder, + tokenizer, + unet, + scheduler, + ): + super().__init__() + + self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + + # convert PIL image to latents + def encode(self, img): + with torch.no_grad(): + latent = self.vae.encode(tfms.ToTensor()(img).unsqueeze(0).to(self.device) * 2 - 1) + latent = 0.18215 * latent.latent_dist.sample() + return latent + + # convert latents to PIL image + def decode(self, latent): + latent = (1 / 0.18215) * latent + with torch.no_grad(): + img = self.vae.decode(latent).sample + img = (img / 2 + 0.5).clamp(0, 1) + img = img.detach().cpu().permute(0, 2, 3, 1).numpy() + img = (img * 255).round().astype("uint8") + return Image.fromarray(img[0]) + + # convert prompt into text embeddings, also unconditional embeddings + def prep_text(self, prompt): + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_embedding = self.text_encoder(text_input.input_ids.to(self.device))[0] + + uncond_input = self.tokenizer( + "", + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + uncond_embedding = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + return torch.cat([uncond_embedding, text_embedding]) + + def __call__( + self, + img, + prompt, + kmin=0.3, + kmax=0.6, + v=0.5, + seed=42, + steps=50, + guidance_scale=7.5, + ): + tmin = steps - int(kmin * steps) + tmax = steps - int(kmax * steps) + + text_embeddings = self.prep_text(prompt) + + self.scheduler.set_timesteps(steps) + + width, height = img.size + encoded = self.encode(img) + + torch.manual_seed(seed) + noise = torch.randn( + (1, self.unet.in_channels, height // 8, width // 8), + ).to(self.device) + + latents = self.scheduler.add_noise( + encoded, + noise, + timesteps=self.scheduler.timesteps[tmax], + ) + + input = torch.cat([latents] * 2) + + input = self.scheduler.scale_model_input(input, self.scheduler.timesteps[tmax]) + + with torch.no_grad(): + pred = self.unet( + input, + self.scheduler.timesteps[tmax], + encoder_hidden_states=text_embeddings, + ).sample + + pred_uncond, pred_text = pred.chunk(2) + pred = pred_uncond + guidance_scale * (pred_text - pred_uncond) + + latents = self.scheduler.step(pred, self.scheduler.timesteps[tmax], latents).prev_sample + + for i, t in enumerate(tqdm(self.scheduler.timesteps)): + if i > tmax: + if i < tmin: # layout generation phase + orig_latents = self.scheduler.add_noise( + encoded, + noise, + timesteps=t, + ) + + input = (v * latents) + ( + 1 - v + ) * orig_latents # interpolating between layout noise and conditionally generated noise to preserve layout sematics + input = 
torch.cat([input] * 2) + + else: # content generation phase + input = torch.cat([latents] * 2) + + input = self.scheduler.scale_model_input(input, t) + + with torch.no_grad(): + pred = self.unet( + input, + t, + encoder_hidden_states=text_embeddings, + ).sample + + pred_uncond, pred_text = pred.chunk(2) + pred = pred_uncond + guidance_scale * (pred_text - pred_uncond) + + latents = self.scheduler.step(pred, t, latents).prev_sample + + return self.decode(latents) From d930ced3489efa677ab9208933cf948dc00eac70 Mon Sep 17 00:00:00 2001 From: daspartho Date: Mon, 26 Dec 2022 13:24:38 +0530 Subject: [PATCH 2/7] type hints --- examples/community/magic_mix.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/examples/community/magic_mix.py b/examples/community/magic_mix.py index f2f6a769a582..04509de3b75b 100644 --- a/examples/community/magic_mix.py +++ b/examples/community/magic_mix.py @@ -1,19 +1,20 @@ import torch -from diffusers import DiffusionPipeline +from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, UNet2DConditionModel from PIL import Image from torchvision import transforms as tfms from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer class MagicMixPipeline(DiffusionPipeline): def __init__( self, - vae, - text_encoder, - tokenizer, - unet, - scheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, ): super().__init__() @@ -62,15 +63,15 @@ def prep_text(self, prompt): def __call__( self, - img, - prompt, - kmin=0.3, - kmax=0.6, - v=0.5, - seed=42, - steps=50, - guidance_scale=7.5, - ): + img: Image.Image, + prompt: str, + kmin: float = 0.3, + kmax: float = 0.6, + v: float = 0.5, + seed: int = 42, + steps: int = 50, + guidance_scale: float = 7.5, + ) -> Image.Image: tmin = steps - int(kmin * steps) tmax = steps - int(kmax * steps) From 6847d594e74dec14dcd680f36e2ea890bb078150 Mon Sep 17 00:00:00 2001 From: daspartho Date: Mon, 26 Dec 2022 13:31:08 +0530 Subject: [PATCH 3/7] update scheduler type hint --- examples/community/magic_mix.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/examples/community/magic_mix.py b/examples/community/magic_mix.py index 04509de3b75b..7560f6ee0602 100644 --- a/examples/community/magic_mix.py +++ b/examples/community/magic_mix.py @@ -1,6 +1,15 @@ +from typing import Union + import torch -from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, UNet2DConditionModel +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + DiffusionPipeline, + LMSDiscreteScheduler, + PNDMScheduler, + UNet2DConditionModel, +) from PIL import Image from torchvision import transforms as tfms from tqdm.auto import tqdm @@ -14,7 +23,7 @@ def __init__( text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: DDIMScheduler, + scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler], ): super().__init__() From a3095f1b6aaa32ef6e32a79e5216f7ea2bf6e4aa Mon Sep 17 00:00:00 2001 From: daspartho Date: Tue, 27 Dec 2022 00:57:41 +0530 Subject: [PATCH 4/7] add to README --- examples/community/README.md | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/examples/community/README.md b/examples/community/README.md index ddb0b8ce9389..7baacf9282d3 100644 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -25,6 +25,7 @@ If a community 
doesn't work as expected, please open an issue and ping the autho | K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) | | Checkpoint Merger Pipeline | Diffusion Pipeline that enables merging of saved model checkpoints | [Checkpoint Merger Pipeline](#checkpoint-merger-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) | Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | - | [Suvaditya Mukherjee](https://github.com/suvadityamuk) | +MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | - | [Partho Das](https://github.com/daspartho) | @@ -815,6 +816,38 @@ plt.title('Stable Diffusion v1.4') plt.axis('off') plt.show() +``` + +As a result, you can look at a grid of all 4 generated images being shown together, that captures a difference the advancement of the training between the 4 checkpoints. + +### Magic Mix + +Implementation of the [MagicMix: Semantic Mixing with Diffusion Models](https://arxiv.org/abs/2210.16056) paper. This is a Diffusion Pipeline for semantic mixing of an image and a text prompt to create a new concept while preserving the spatial layout and geometry of the subject in the image. The pipeline takes an image that provides the layout semantics and a prompt that provides the content semantics for the mixing process. + +There are 3 parameters for the method- +- `v`: It is the interpolation constant used in the layout generation phase. The greater the value of v, the greater the influence of the prompt on the layout generation process. +- `kmax` and `kmin`: These determine the range for the layout and content generation process. A higher value of kmax results in loss of more information about the layout of the original image and a higher value of kmin results in more steps for content generation process. + +Here is an example usage- + ```python +from diffusers import DiffusionPipeline, DDIMScheduler +from PIL import Image -As a result, you can look at a grid of all 4 generated images being shown together, that captures a difference the advancement of the training between the 4 checkpoints. \ No newline at end of file +pipe = DiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + custom_pipeline="magic_mix", + scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False), +).to('cuda') + +img = Image.open('phone.jpg') +mix_img = pipe( + img, + promt = 'bed', + kmin = 0.3, + kmax = 0.5, + v = 0.5, + ) +mix_img.save('phone_bed_mix.jpg') +``` +The `mix_img` is a PIL image that can be saved locally or displayed directly in a google colab. Generated image is a mix of the layout semantics of the given image and the content semantics of the prompt. 
\ No newline at end of file From cbf4ca58611a4948c18a9148b1e400b642475b0d Mon Sep 17 00:00:00 2001 From: Partho Date: Tue, 27 Dec 2022 01:10:08 +0530 Subject: [PATCH 5/7] add example generation to README --- examples/community/README.md | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/examples/community/README.md b/examples/community/README.md index 7baacf9282d3..3ad1c85f89fc 100644 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -843,11 +843,23 @@ pipe = DiffusionPipeline.from_pretrained( img = Image.open('phone.jpg') mix_img = pipe( img, - promt = 'bed', + prompt = 'bed', kmin = 0.3, kmax = 0.5, v = 0.5, ) mix_img.save('phone_bed_mix.jpg') ``` -The `mix_img` is a PIL image that can be saved locally or displayed directly in a google colab. Generated image is a mix of the layout semantics of the given image and the content semantics of the prompt. \ No newline at end of file +The `mix_img` is a PIL image that can be saved locally or displayed directly in a google colab. Generated image is a mix of the layout semantics of the given image and the content semantics of the prompt. + +E.g. the above script generates the following image: + +`phone.jpg` + +![206903102-34e79b9f-9ed2-4fac-bb38-82871343c655](https://user-images.githubusercontent.com/59410571/209578593-141467c7-d831-4792-8b9a-b17dc5e47816.jpg) + +`phone_bed_mix.jpg` + +![206903104-913a671d-ef53-4ae4-919d-64c3059c8f67](https://user-images.githubusercontent.com/59410571/209578602-70f323fa-05b7-4dd6-b055-e40683e37914.jpg) + +For more example generations check out this [demo notebook](https://github.com/daspartho/MagicMix/blob/main/demo.ipynb). From 28e1a3c0a809754bde9f4a2cf2408a687fa9d18a Mon Sep 17 00:00:00 2001 From: daspartho Date: Wed, 28 Dec 2022 11:08:48 +0530 Subject: [PATCH 6/7] v -> mix_factor --- examples/community/README.md | 4 ++-- examples/community/magic_mix.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/community/README.md b/examples/community/README.md index 3ad1c85f89fc..486ee7d27c35 100644 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -825,7 +825,7 @@ As a result, you can look at a grid of all 4 generated images being shown togeth Implementation of the [MagicMix: Semantic Mixing with Diffusion Models](https://arxiv.org/abs/2210.16056) paper. This is a Diffusion Pipeline for semantic mixing of an image and a text prompt to create a new concept while preserving the spatial layout and geometry of the subject in the image. The pipeline takes an image that provides the layout semantics and a prompt that provides the content semantics for the mixing process. There are 3 parameters for the method- -- `v`: It is the interpolation constant used in the layout generation phase. The greater the value of v, the greater the influence of the prompt on the layout generation process. +- `mix_factor`: It is the interpolation constant used in the layout generation phase. The greater the value of `mix_factor`, the greater the influence of the prompt on the layout generation process. - `kmax` and `kmin`: These determine the range for the layout and content generation process. A higher value of kmax results in loss of more information about the layout of the original image and a higher value of kmin results in more steps for content generation process. 
Here is an example usage- @@ -846,7 +846,7 @@ mix_img = pipe( prompt = 'bed', kmin = 0.3, kmax = 0.5, - v = 0.5, + mix_factor = 0.5, ) mix_img.save('phone_bed_mix.jpg') ``` diff --git a/examples/community/magic_mix.py b/examples/community/magic_mix.py index 7560f6ee0602..d67aec781c36 100644 --- a/examples/community/magic_mix.py +++ b/examples/community/magic_mix.py @@ -76,7 +76,7 @@ def __call__( prompt: str, kmin: float = 0.3, kmax: float = 0.6, - v: float = 0.5, + mix_factor: float = 0.5, seed: int = 42, steps: int = 50, guidance_scale: float = 7.5, @@ -127,8 +127,8 @@ def __call__( timesteps=t, ) - input = (v * latents) + ( - 1 - v + input = (mix_factor * latents) + ( + 1 - mix_factor ) * orig_latents # interpolating between layout noise and conditionally generated noise to preserve layout sematics input = torch.cat([input] * 2) From 573dec8557a1957767e785badcf387d3ccfb0cf0 Mon Sep 17 00:00:00 2001 From: daspartho Date: Wed, 28 Dec 2022 11:33:22 +0530 Subject: [PATCH 7/7] load scheduler from pretrained --- examples/community/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/community/README.md b/examples/community/README.md index 486ee7d27c35..a848f74f2a29 100644 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -837,7 +837,7 @@ from PIL import Image pipe = DiffusionPipeline.from_pretrained( "CompVis/stable-diffusion-v1-4", custom_pipeline="magic_mix", - scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False), + scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler"), ).to('cuda') img = Image.open('phone.jpg')
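# A consolidated sketch of the README usage example as it reads once the later
# patches in this series are applied (`prompt` instead of `promt`, `mix_factor`
# instead of `v`, and the DDIM scheduler loaded with `from_pretrained`). The
# model id and file names simply mirror the README example; any layout image
# and content prompt can be substituted.
from PIL import Image
from diffusers import DDIMScheduler, DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="magic_mix",
    scheduler=DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
).to("cuda")

img = Image.open("phone.jpg")  # the image provides the layout semantics
mix_img = pipe(img, prompt="bed", kmin=0.3, kmax=0.5, mix_factor=0.5)  # the prompt provides the content semantics
mix_img.save("phone_bed_mix.jpg")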
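# A small illustration of how `kmin`, `kmax` and `steps` map onto the denoising
# loop in `MagicMixPipeline.__call__` above. `phase_of_step` is a hypothetical
# helper written only for this sketch; it is not part of the pipeline itself.
def phase_of_step(i: int, steps: int = 50, kmin: float = 0.3, kmax: float = 0.6) -> str:
    tmin = steps - int(kmin * steps)  # layout generation runs while i < tmin
    tmax = steps - int(kmax * steps)  # denoising only starts once i > tmax
    if i <= tmax:
        return "skipped (the input image is noised directly to timesteps[tmax])"
    if i < tmin:
        return "layout phase (latents interpolated with the re-noised input image)"
    return "content phase (plain classifier-free guidance on the prompt)"

# With the defaults (steps=50, kmin=0.3, kmax=0.6) this gives tmax=20 and tmin=35,
# so steps 21-34 shape the layout and steps 35-49 generate the prompt's content.
for i in (10, 25, 40):
    print(i, phase_of_step(i))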