Commit 9d8943b

ChenWu98, patrickvonplaten, and patil-suraj authored
Add CycleDiffusion pipeline using Stable Diffusion (#888)
* Add CycleDiffusion pipeline for Stable Diffusion
* Add the option of passing noise to DDIMScheduler
  Add the option of providing the noise itself to DDIMScheduler, instead of the random seed generator.
* Update README.md
* Update README.md
* Update pipeline_stable_diffusion_cycle_diffusion.py
* Update pipeline_stable_diffusion_cycle_diffusion.py
* Update pipeline_stable_diffusion_cycle_diffusion.py
* Update pipeline_stable_diffusion_cycle_diffusion.py
* Update scheduling_ddim.py
* Update import format
* Update pipeline_stable_diffusion_cycle_diffusion.py
* Update scheduling_ddim.py
* Update src/diffusers/schedulers/scheduling_ddim.py
  Co-authored-by: Patrick von Platen <[email protected]>
* Update src/diffusers/schedulers/scheduling_ddim.py
  Co-authored-by: Patrick von Platen <[email protected]>
* Update src/diffusers/schedulers/scheduling_ddim.py
  Co-authored-by: Patrick von Platen <[email protected]>
* Update src/diffusers/schedulers/scheduling_ddim.py
  Co-authored-by: Patrick von Platen <[email protected]>
* Update src/diffusers/schedulers/scheduling_ddim.py
  Co-authored-by: Patrick von Platen <[email protected]>
* Update scheduling_ddim.py
* Update scheduling_ddim.py
* Update scheduling_ddim.py
* add two tests
* Update pipeline_stable_diffusion_cycle_diffusion.py
* Update pipeline_stable_diffusion_cycle_diffusion.py
* Update README.md
* Rename pipeline name as suggested in the latest reviewer comment
* Update test_pipelines.py
* Update test_pipelines.py
* Update test_pipelines.py
* Update pipeline_stable_diffusion_cycle_diffusion.py
* Remove the generator
  This generator does not control all randomness during sampling, which can be misleading.
* Update optimal hyperparameters
* Update src/diffusers/pipelines/stable_diffusion/README.md
  Co-authored-by: Suraj Patil <[email protected]>
* Update src/diffusers/pipelines/stable_diffusion/README.md
  Co-authored-by: Suraj Patil <[email protected]>
* Update src/diffusers/pipelines/stable_diffusion/README.md
  Co-authored-by: Suraj Patil <[email protected]>
* Apply suggestions from code review
* uP
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_cycle_diffusion.py
  Co-authored-by: Suraj Patil <[email protected]>
* up
* up
* Replace assert with ValueError
* finish docs

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Suraj Patil <[email protected]>
1 parent 1172c96 commit 9d8943b
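
The commit message mentions letting callers pass the noise itself to `DDIMScheduler` instead of a random generator. As a rough illustration of why that can matter, the sketch below writes out a single stochastic DDIM update (the standard update rule from the DDIM paper) on plain Python floats, with the noise as an explicit argument. The names `ddim_sigma`, `ddim_step`, and `recover_noise` and all numeric values are hypothetical and are not the library's API; `recover_noise` only illustrates the idea of solving for the noise that links two consecutive samples.

```python
import math


def ddim_sigma(alpha_t, alpha_prev, eta):
    """Standard deviation of the noise injected by one DDIM step (eta = 0 is deterministic)."""
    return eta * math.sqrt((1.0 - alpha_prev) / (1.0 - alpha_t)) * math.sqrt(1.0 - alpha_t / alpha_prev)


def ddim_step(x_t, eps, alpha_t, alpha_prev, eta, noise):
    """One DDIM update on a scalar sample, with the variance noise passed in explicitly.

    alpha_t and alpha_prev are cumulative alpha products (hypothetical values here),
    eps is the model's predicted noise for x_t.
    """
    # Predicted clean sample x_0 from the current sample and the predicted noise.
    x0_pred = (x_t - math.sqrt(1.0 - alpha_t) * eps) / math.sqrt(alpha_t)
    sigma = ddim_sigma(alpha_t, alpha_prev, eta)
    # Deterministic part that points back towards x_t.
    direction = math.sqrt(1.0 - alpha_prev - sigma**2) * eps
    return math.sqrt(alpha_prev) * x0_pred + direction + sigma * noise


def recover_noise(x_t, x_prev, eps, alpha_t, alpha_prev, eta):
    """Solve the update above for the noise that maps x_t to x_prev (requires eta > 0)."""
    mean = ddim_step(x_t, eps, alpha_t, alpha_prev, eta, noise=0.0)
    return (x_prev - mean) / ddim_sigma(alpha_t, alpha_prev, eta)


# Illustrative values only: feeding the same noise back reproduces the same step.
x_prev = ddim_step(1.0, 0.3, alpha_t=0.5, alpha_prev=0.8, eta=0.1, noise=0.7)
noise = recover_noise(1.0, x_prev, 0.3, alpha_t=0.5, alpha_prev=0.8, eta=0.1)
```

With the noise supplied by the caller, the step is a plain function of its inputs, which is what makes it possible to replay or invert a sampling trajectory; a seeded generator alone does not expose the individual noise values.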

12 files changed: +1097 -15 lines changed

docs/source/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -78,6 +78,8 @@
7878
- sections:
7979
- local: api/pipelines/overview
8080
title: "Overview"
81+
- local: api/pipelines/cycle_diffusion
82+
title: "Cycle Diffusion"
8183
- local: api/pipelines/ddim
8284
title: "DDIM"
8385
- local: api/pipelines/ddpm
Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
1+
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
4+
the License. You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
9+
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
10+
specific language governing permissions and limitations under the License.
11+
-->
12+
13+
# Cycle Diffusion
14+
15+
## Overview
16+
17+
Cycle Diffusion is a Text-Guided Image-to-Image Generation model proposed in [Unifying Diffusion Models' Latent Space, with Applications to CycleDiffusion and Guidance](https://arxiv.org/abs/2210.05559) by Chen Henry Wu and Fernando De la Torre.
18+
19+
The abstract of the paper is the following:
20+
21+
*Diffusion models have achieved unprecedented performance in generative modeling. The commonly-adopted formulation of the latent code of diffusion models is a sequence of gradually denoised samples, as opposed to the simpler (e.g., Gaussian) latent space of GANs, VAEs, and normalizing flows. This paper provides an alternative, Gaussian formulation of the latent space of various diffusion models, as well as an invertible DPM-Encoder that maps images into the latent space. While our formulation is purely based on the definition of diffusion models, we demonstrate several intriguing consequences. (1) Empirically, we observe that a common latent space emerges from two diffusion models trained independently on related domains. In light of this finding, we propose CycleDiffusion, which uses DPM-Encoder for unpaired image-to-image translation. Furthermore, applying CycleDiffusion to text-to-image diffusion models, we show that large-scale text-to-image diffusion models can be used as zero-shot image-to-image editors. (2) One can guide pre-trained diffusion models and GANs by controlling the latent codes in a unified, plug-and-play formulation based on energy-based models. Using the CLIP model and a face recognition model as guidance, we demonstrate that diffusion models have better coverage of low-density sub-populations and individuals than GANs.*
22+
23+
*Tips*:
24+
- The Cycle Diffusion pipeline is fully compatible with any [Stable Diffusion](./stable_diffusion) checkpoint.
25+
- Currently Cycle Diffusion only works with the [`DDIMScheduler`].
26+
27+
*Example*:
28+
29+
In the following, we show how to best use the [`CycleDiffusionPipeline`]:
30+
31+
```python
32+
import requests
33+
import torch
34+
from PIL import Image
35+
from io import BytesIO
36+
37+
from diffusers import CycleDiffusionPipeline, DDIMScheduler
38+
39+
# load the pipeline
40+
# make sure you're logged in with `huggingface-cli login`
41+
model_id_or_path = "CompVis/stable-diffusion-v1-4"
42+
scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler")
43+
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")
44+
45+
# let's download an initial image
46+
url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/An%20astronaut%20riding%20a%20horse.png"
47+
response = requests.get(url)
48+
init_image = Image.open(BytesIO(response.content)).convert("RGB")
49+
init_image = init_image.resize((512, 512))
50+
init_image.save("horse.png")
51+
52+
# let's specify a prompt
53+
source_prompt = "An astronaut riding a horse"
54+
prompt = "An astronaut riding an elephant"
55+
56+
# call the pipeline
57+
image = pipe(
58+
prompt=prompt,
59+
source_prompt=source_prompt,
60+
init_image=init_image,
61+
num_inference_steps=100,
62+
eta=0.1,
63+
strength=0.8,
64+
guidance_scale=2,
65+
source_guidance_scale=1,
66+
).images[0]
67+
68+
image.save("horse_to_elephant.png")
69+
70+
# let's try another example
71+
# See more samples at the original repo: https://github.com/ChenWu98/cycle-diffusion
72+
url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/A%20black%20colored%20car.png"
73+
response = requests.get(url)
74+
init_image = Image.open(BytesIO(response.content)).convert("RGB")
75+
init_image = init_image.resize((512, 512))
76+
init_image.save("black.png")
77+
78+
source_prompt = "A black colored car"
79+
prompt = "A blue colored car"
80+
81+
# call the pipeline
82+
torch.manual_seed(0)
83+
image = pipe(
84+
prompt=prompt,
85+
source_prompt=source_prompt,
86+
init_image=init_image,
87+
num_inference_steps=100,
88+
eta=0.1,
89+
strength=0.85,
90+
guidance_scale=3,
91+
source_guidance_scale=1,
92+
).images[0]
93+
94+
image.save("black_to_blue.png")
95+
```
96+
97+
## CycleDiffusionPipeline
98+
[[autodoc]] CycleDiffusionPipeline
99+
- __call__
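
The example above passes both `guidance_scale` and `source_guidance_scale`. As a minimal sketch, assuming both follow the usual classifier-free guidance rule, the helper below shows how a scale combines an unconditional and a text-conditioned noise prediction. The function name `apply_guidance` and the toy lists are hypothetical; the pipeline itself operates on tensors and may differ in detail.

```python
def apply_guidance(noise_uncond, noise_text, guidance_scale):
    """Classifier-free guidance on two noise predictions given as equal-length lists."""
    # Move from the unconditional prediction towards (and past) the text-conditioned one.
    return [u + guidance_scale * (t - u) for u, t in zip(noise_uncond, noise_text)]


# Toy predictions, for illustration only.
uncond = [0.0, 1.0]
text = [1.0, 3.0]

source_pred = apply_guidance(uncond, text, guidance_scale=1)  # equals the text prediction
target_pred = apply_guidance(uncond, text, guidance_scale=2)  # extrapolates past it
```

Under this assumption, a scale of 1 (as used for `source_guidance_scale` in the example) leaves the text-conditioned prediction unchanged, while larger values push the result further from the unconditional prediction.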

docs/source/api/pipelines/overview.mdx

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,21 +41,24 @@ If you are looking for *official* training examples, please have a look at [exam
4141
The following table summarizes all officially supported pipelines, their corresponding paper, and if
4242
available a colab notebook to directly try them out.
4343

44+
4445
| Pipeline | Paper | Tasks | Colab
4546
|---|---|:---:|:---:|
46-
| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
47-
| [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
48-
| [latent_diffusion](./latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
49-
| [latent_diffusion_uncond](./latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
50-
| [pndm](./pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
51-
| [score_sde_ve](./score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
52-
| [score_sde_vp](./score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
53-
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
54-
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
55-
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
56-
| [stochastic_karras_ve](./stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
57-
| [vq_diffusion](./vq_diffusion) | [**Vector Quantized Diffusion Model for Text-to-Image Synthesis**](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
58-
| [repaint](./repaint) | [**RePaint: Inpainting using Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2201.09865) | Image Inpainting |
47+
| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
48+
| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
49+
| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
50+
| [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
51+
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
52+
| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
53+
| [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
54+
| [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
55+
| [score_sde_vp](./api/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
56+
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
57+
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
58+
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
59+
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
60+
| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
61+
5962

6063
**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.
6164

docs/source/index.mdx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ available a colab notebook to directly try them out.
3434

3535
| Pipeline | Paper | Tasks | Colab
3636
|---|---|:---:|:---:|
37+
| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
3738
| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
3839
| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
3940
| [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363

6464
if is_torch_available() and is_transformers_available():
6565
from .pipelines import (
66+
CycleDiffusionPipeline,
6667
LDMTextToImagePipeline,
6768
StableDiffusionImg2ImgPipeline,
6869
StableDiffusionInpaintPipeline,

src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
if is_torch_available() and is_transformers_available():
1717
from .latent_diffusion import LDMTextToImagePipeline
1818
from .stable_diffusion import (
19+
CycleDiffusionPipeline,
1920
StableDiffusionImg2ImgPipeline,
2021
StableDiffusionInpaintPipeline,
2122
StableDiffusionInpaintPipelineLegacy,

src/diffusers/pipelines/stable_diffusion/README.md

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,74 @@ image = pipe(prompt).sample[0]
103103

104104
image.save("astronaut_rides_horse.png")
105105
```
106+
107+
### CycleDiffusion using Stable Diffusion and DDIM scheduler
108+
109+
```python
110+
import requests
111+
import torch
112+
from PIL import Image
113+
from io import BytesIO
114+
115+
from diffusers import CycleDiffusionPipeline, DDIMScheduler
116+
117+
118+
# load the scheduler. CycleDiffusion only supports stochastic schedulers.
119+
120+
# load the pipeline
121+
# make sure you're logged in with `huggingface-cli login`
122+
model_id_or_path = "CompVis/stable-diffusion-v1-4"
123+
scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler")
124+
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")
125+
126+
# let's download an initial image
127+
url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/An%20astronaut%20riding%20a%20horse.png"
128+
response = requests.get(url)
129+
init_image = Image.open(BytesIO(response.content)).convert("RGB")
130+
init_image = init_image.resize((512, 512))
131+
init_image.save("horse.png")
132+
133+
# let's specify a prompt
134+
source_prompt = "An astronaut riding a horse"
135+
prompt = "An astronaut riding an elephant"
136+
137+
# call the pipeline
138+
image = pipe(
139+
prompt=prompt,
140+
source_prompt=source_prompt,
141+
init_image=init_image,
142+
num_inference_steps=100,
143+
eta=0.1,
144+
strength=0.8,
145+
guidance_scale=2,
146+
source_guidance_scale=1,
147+
).images[0]
148+
149+
image.save("horse_to_elephant.png")
150+
151+
# let's try another example
152+
# See more samples at the original repo: https://github.com/ChenWu98/cycle-diffusion
153+
url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/A%20black%20colored%20car.png"
154+
response = requests.get(url)
155+
init_image = Image.open(BytesIO(response.content)).convert("RGB")
156+
init_image = init_image.resize((512, 512))
157+
init_image.save("black.png")
158+
159+
source_prompt = "A black colored car"
160+
prompt = "A blue colored car"
161+
162+
# call the pipeline
163+
torch.manual_seed(0)
164+
image = pipe(
165+
prompt=prompt,
166+
source_prompt=source_prompt,
167+
init_image=init_image,
168+
num_inference_steps=100,
169+
eta=0.1,
170+
strength=0.85,
171+
guidance_scale=3,
172+
source_guidance_scale=1,
173+
).images[0]
174+
175+
image.save("black_to_blue.png")
176+
```
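
The two calls above use `strength=0.8` and `strength=0.85` with `num_inference_steps=100`. As a rough sketch, assuming `strength` is interpreted as in the Stable Diffusion image-to-image pipeline (and ignoring any scheduler step offset), the helper below estimates how many denoising steps actually run. The name `effective_steps` is hypothetical and not part of the library.

```python
def effective_steps(num_inference_steps, strength):
    """Estimate the number of denoising steps run for a given strength in [0, 1]."""
    if not 0.0 <= strength <= 1.0:
        raise ValueError(f"strength must be in [0, 1], got {strength}")
    # Only the last `strength` fraction of the schedule is applied to the noised input image.
    return min(int(num_inference_steps * strength), num_inference_steps)


half = effective_steps(100, 0.5)
full = effective_steps(100, 1.0)
```

Under this convention, a lower `strength` keeps the result closer to the input image because fewer steps are run, and `strength=1.0` runs the full schedule.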

src/diffusers/pipelines/stable_diffusion/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class StableDiffusionPipelineOutput(BaseOutput):
2828

2929

3030
if is_transformers_available() and is_torch_available():
31+
from .pipeline_cycle_diffusion import CycleDiffusionPipeline
3132
from .pipeline_stable_diffusion import StableDiffusionPipeline
3233
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
3334
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
