Commit 065d1d8

Merge remote-tracking branch 'origin/main' into fix-lms-mps

2 parents dc936fa + 4b9f589

110 files changed: +4880 −733 lines


.github/workflows/pr_tests.yml

Lines changed: 1 addition & 1 deletion

@@ -89,7 +89,7 @@ jobs:
       - name: Run all fast tests on MPS
         shell: arch -arch arm64 bash {0}
         run: |
-          ${CONDA_RUN} python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=tests_torch_mps tests/
+          ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps tests/
 
       - name: Failure short reports
         if: ${{ failure() }}

README.md

Lines changed: 96 additions & 17 deletions
@@ -64,44 +64,54 @@ In order to get started, we recommend taking a look at two notebooks:
 - The [Training a diffusers model](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) notebook summarizes diffusion models training methods. This notebook takes a step-by-step approach to training your
 diffusion models on an image dataset, with explanatory graphics.
 
-## **New** Stable Diffusion is now fully compatible with `diffusers`!
+## Stable Diffusion is fully compatible with `diffusers`!
 
-Stable Diffusion is a text-to-image latent diffusion model created by the researchers and engineers from [CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/) and [LAION](https://laion.ai/). It's trained on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database. This model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts. With its 860M UNet and 123M text encoder, the model is relatively lightweight and runs on a GPU with at least 10GB VRAM.
+Stable Diffusion is a text-to-image latent diffusion model created by the researchers and engineers from [CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/), [LAION](https://laion.ai/) and [RunwayML](https://runwayml.com/). It's trained on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database. This model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts. With its 860M UNet and 123M text encoder, the model is relatively lightweight and runs on a GPU with at least 4GB VRAM.
 See the [model card](https://huggingface.co/CompVis/stable-diffusion) for more information.
 
-You need to accept the model license before downloading or using the Stable Diffusion weights. Please, visit the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree. You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section](https://huggingface.co/docs/hub/security-tokens) of the documentation.
+You need to accept the model license before downloading or using the Stable Diffusion weights. Please, visit the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license carefully and tick the checkbox if you agree. You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section](https://huggingface.co/docs/hub/security-tokens) of the documentation.
 
 
 ### Text-to-Image generation with Stable Diffusion
 
+First let's install
+```bash
+pip install --upgrade diffusers transformers scipy
+```
+
+Run this command to log in with your HF Hub token if you haven't before (you can skip this step if you prefer to run the model locally, follow [this](#running-the-model-locally) instead)
+```bash
+huggingface-cli login
+```
+
 We recommend using the model in [half-precision (`fp16`)](https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/) as it gives almost always the same results as full
 precision while being roughly twice as fast and requiring half the amount of GPU RAM.
 
 ```python
-# make sure you're logged in with `huggingface-cli login`
 from diffusers import StableDiffusionPipeline
 
-pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_type=torch.float16, revision="fp16")
+pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, revision="fp16")
 pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
 image = pipe(prompt).images[0]
 ```
 
-**Note**: If you don't want to use the token, you can also simply download the model weights
-(after having [accepted the license](https://huggingface.co/CompVis/stable-diffusion-v1-4)) and pass
+#### Running the model locally
+If you don't want to login to Hugging Face, you can also simply download the model folder
+(after having [accepted the license](https://huggingface.co/runwayml/stable-diffusion-v1-5)) and pass
 the path to the local folder to the `StableDiffusionPipeline`.
 
 ```
 git lfs install
-git clone https://huggingface.co/CompVis/stable-diffusion-v1-4
+git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
 ```
 
-Assuming the folder is stored locally under `./stable-diffusion-v1-4`, you can also run stable diffusion
+Assuming the folder is stored locally under `./stable-diffusion-v1-5`, you can also run stable diffusion
 without requiring an authentication token:
 
 ```python
-pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-4")
+pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-5")
 pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
@@ -114,7 +124,7 @@ The following snippet should result in less than 4GB VRAM.
 
 ```python
 pipe = StableDiffusionPipeline.from_pretrained(
-    "CompVis/stable-diffusion-v1-4",
+    "runwayml/stable-diffusion-v1-5",
     revision="fp16",
     torch_dtype=torch.float16,
 )
@@ -125,7 +135,7 @@ pipe.enable_attention_slicing()
 image = pipe(prompt).images[0]
 ```
 
-If you wish to use a different scheduler, you can simply instantiate
+If you wish to use a different scheduler (e.g.: DDIM, LMS, PNDM/PLMS), you can instantiate
 it before the pipeline and pass it to `from_pretrained`.
 
 ```python
@@ -138,7 +148,7 @@ lms = LMSDiscreteScheduler(
 )
 
 pipe = StableDiffusionPipeline.from_pretrained(
-    "CompVis/stable-diffusion-v1-4",
+    "runwayml/stable-diffusion-v1-5",
     revision="fp16",
     torch_dtype=torch.float16,
     scheduler=lms,
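
The three hunks above update the low-VRAM snippet and the scheduler swap. A consolidated sketch of how they read after this commit (the `LMSDiscreteScheduler` arguments are elided by the diff, so the beta values below are an assumption based on the schedule commonly used with Stable Diffusion):

```python
import torch
from diffusers import LMSDiscreteScheduler, StableDiffusionPipeline

# assumed arguments: the scaled-linear beta schedule typically used for Stable Diffusion
lms = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
)

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16,
    scheduler=lms,
)
pipe = pipe.to("cuda")
pipe.enable_attention_slicing()  # compute attention in slices to stay under ~4GB VRAM

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
```
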
@@ -158,7 +168,7 @@ please run the model in the default *full-precision* setting:
 # make sure you're logged in with `huggingface-cli login`
 from diffusers import StableDiffusionPipeline
 
-pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
 
 # disable the following line if you run on CPU
 pipe = pipe.to("cuda")
@@ -169,6 +179,75 @@ image = pipe(prompt).images[0]
 image.save("astronaut_rides_horse.png")
 ```
 
+### JAX/Flax
+
+To use StableDiffusion on TPUs and GPUs for faster inference you can leverage JAX/Flax.
+
+Running the pipeline with default PNDMScheduler
+
+```python
+import jax
+import numpy as np
+from flax.jax_utils import replicate
+from flax.training.common_utils import shard
+
+from diffusers import FlaxStableDiffusionPipeline
+
+pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", revision="flax", dtype=jax.numpy.bfloat16
+)
+
+prompt = "a photo of an astronaut riding a horse on mars"
+
+prng_seed = jax.random.PRNGKey(0)
+num_inference_steps = 50
+
+num_samples = jax.device_count()
+prompt = num_samples * [prompt]
+prompt_ids = pipeline.prepare_inputs(prompt)
+
+# shard inputs and rng
+params = replicate(params)
+prng_seed = jax.random.split(prng_seed, jax.device_count())
+prompt_ids = shard(prompt_ids)
+
+images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
+images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
+```
+
+**Note**:
+If you are limited by TPU memory, please make sure to load the `FlaxStableDiffusionPipeline` in `bfloat16` precision instead of the default `float32` precision as done above. You can do so by telling diffusers to load the weights from "bf16" branch.
+
+```python
+import jax
+import numpy as np
+from flax.jax_utils import replicate
+from flax.training.common_utils import shard
+
+from diffusers import FlaxStableDiffusionPipeline
+
+pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jax.numpy.bfloat16
+)
+
+prompt = "a photo of an astronaut riding a horse on mars"
+
+prng_seed = jax.random.PRNGKey(0)
+num_inference_steps = 50
+
+num_samples = jax.device_count()
+prompt = num_samples * [prompt]
+prompt_ids = pipeline.prepare_inputs(prompt)
+
+# shard inputs and rng
+params = replicate(params)
+prng_seed = jax.random.split(prng_seed, jax.device_count())
+prompt_ids = shard(prompt_ids)
+
+images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
+images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
+```
+
 ### Image-to-Image text-guided generation with Stable Diffusion
 
 The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.
@@ -183,14 +262,14 @@ from diffusers import StableDiffusionImg2ImgPipeline
 
 # load the pipeline
 device = "cuda"
-model_id_or_path = "CompVis/stable-diffusion-v1-4"
+model_id_or_path = "runwayml/stable-diffusion-v1-5"
 pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
     model_id_or_path,
     revision="fp16",
     torch_dtype=torch.float16,
 )
-# or download via git clone https://huggingface.co/CompVis/stable-diffusion-v1-4
-# and pass `model_id_or_path="./stable-diffusion-v1-4"`.
+# or download via git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
+# and pass `model_id_or_path="./stable-diffusion-v1-5"`.
 pipe = pipe.to(device)
 
 # let's download an initial image
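
The hunk stops at the initial-image download. A sketch of how the img2img snippet typically continues (the example URL, prompt, and `strength`/`guidance_scale` values are illustrative assumptions; `init_image` is the image-to-image argument name in this version of the API):

```python
import torch
import requests
from io import BytesIO
from PIL import Image

from diffusers import StableDiffusionImg2ImgPipeline

device = "cuda"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16
).to(device)

# assumed example image; any RGB image resized to the model resolution works
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
init_image = Image.open(BytesIO(requests.get(url).content)).convert("RGB")
init_image = init_image.resize((768, 512))

prompt = "A fantasy landscape, trending on artstation"

# strength controls how far generation may drift from the initial image (0 = keep, 1 = ignore)
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
images[0].save("fantasy_landscape.png")
```
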

docs/source/_toctree.yml

Lines changed: 10 additions & 6 deletions

@@ -12,9 +12,9 @@
     title: "Loading Pipelines, Models, and Schedulers"
   - local: using-diffusers/configuration
     title: "Configuring Pipelines, Models, and Schedulers"
-  - local: using-diffusers/custom_pipelines
-    title: "Loading and Creating Custom Pipelines"
-  title: "Loading"
+  - local: using-diffusers/custom_pipeline_overview
+    title: "Loading and Adding Custom Pipelines"
+  title: "Loading & Hub"
 - sections:
   - local: using-diffusers/unconditional_image_generation
     title: "Unconditional Image Generation"
@@ -24,8 +24,10 @@
     title: "Text-Guided Image-to-Image"
   - local: using-diffusers/inpaint
     title: "Text-Guided Image-Inpainting"
-  - local: using-diffusers/custom
-    title: "Create a custom pipeline"
+  - local: using-diffusers/custom_pipeline_examples
+    title: "Community Pipelines"
+  - local: using-diffusers/contribute_pipeline
+    title: "How to contribute a Pipeline"
   title: "Pipelines for Inference"
 title: "Using Diffusers"
 - sections:
@@ -34,7 +36,7 @@
   - local: optimization/onnx
     title: "ONNX"
   - local: optimization/open_vino
-    title: "Open Vino"
+    title: "OpenVINO"
   - local: optimization/mps
     title: "MPS"
   title: "Optimization/Special Hardware"
@@ -90,5 +92,7 @@
     title: "Stable Diffusion"
   - local: api/pipelines/stochastic_karras_ve
     title: "Stochastic Karras VE"
+  - local: api/pipelines/dance_diffusion
+    title: "Dance Diffusion"
   title: "Pipelines"
 title: "API"

docs/source/api/models.mdx

Lines changed: 3 additions & 0 deletions

@@ -22,6 +22,9 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
 ## UNet2DOutput
 [[autodoc]] models.unet_2d.UNet2DOutput
 
+## UNet1DModel
+[[autodoc]] UNet1DModel
+
 ## UNet2DModel
 [[autodoc]] UNet2DModel
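
Since `UNet1DModel` is newly exposed in the docs here, a minimal smoke-test sketch may be useful; the constructor defaults and the `(batch, channels, sample_length)` input layout are assumptions about the API, not content of this diff:

```python
import torch
from diffusers import UNet1DModel

# randomly initialised 1D UNet; relying on (assumed) constructor defaults
model = UNet1DModel()

# 1D UNets denoise waveform-like tensors shaped (batch, channels, sample_length)
sample = torch.randn(1, model.config.in_channels, model.config.sample_size)
timestep = torch.tensor([10])

with torch.no_grad():
    denoised = model(sample, timestep).sample
print(denoised.shape)
```
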
docs/source/api/pipelines/dance_diffusion.mdx (new file; path inferred from the `api/pipelines/dance_diffusion` entry added to _toctree.yml)

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+<!--Copyright 2022 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# Dance Diffusion
+
+## Overview
+
+[Dance Diffusion](https://github.com/Harmonai-org/sample-generator) by Zach Evans.
+
+Dance Diffusion is the first in a suite of generative audio tools for producers and musicians to be released by Harmonai.
+For more info or to get involved in the development of these tools, please visit https://harmonai.org and fill out the form on the front page.
+
+The original codebase of this implementation can be found [here](https://github.com/Harmonai-org/sample-generator).
+
+## Available Pipelines:
+
+| Pipeline | Tasks | Colab
+|---|---|:---:|
+| [pipeline_dance_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py) | *Unconditional Audio Generation* | - |
+
+
+## DanceDiffusionPipeline
+[[autodoc]] DanceDiffusionPipeline
+    - __call__
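
A minimal usage sketch of the new pipeline follows; the `harmonai/maestro-150k` checkpoint id and the `audio_length_in_s` argument are assumptions based on the Harmonai release rather than content of this commit:

```python
import torch
from diffusers import DanceDiffusionPipeline

# assumed checkpoint id from the Harmonai release
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k")
pipe = pipe.to("cuda")

# unconditional audio generation: returns a batch of raw waveforms
output = pipe(audio_length_in_s=4.0)
audio = output.audios[0]  # numpy array, shape (channels, samples)
```
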

docs/source/api/pipelines/ddim.mdx

Lines changed: 12 additions & 0 deletions

@@ -1,3 +1,15 @@
+<!--Copyright 2022 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
 # DDIM
 
 ## Overview
docs/source/api/pipelines/ddpm.mdx

Lines changed: 12 additions & 0 deletions

@@ -1,3 +1,15 @@
+<!--Copyright 2022 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
 # DDPM
 
 ## Overview
docs/source/api/pipelines/latent_diffusion.mdx

Lines changed: 12 additions & 0 deletions

@@ -1,3 +1,15 @@
+<!--Copyright 2022 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
 # Latent Diffusion
 
 ## Overview
docs/source/api/pipelines/latent_diffusion_uncond.mdx

Lines changed: 12 additions & 0 deletions

@@ -1,3 +1,15 @@
+<!--Copyright 2022 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
 # Unconditional Latent Diffusion
 
 ## Overview
docs/source/api/pipelines/overview.mdx

Lines changed: 4 additions & 4 deletions

@@ -67,8 +67,8 @@ Diffusion models often consist of multiple independently-trained models or other
 Each model has been trained independently on a different task and the scheduler can easily be swapped out and replaced with a different one.
 During inference, we however want to be able to easily load all components and use them in inference - even if one component, *e.g.* CLIP's text encoder, originates from a different library, such as [Transformers](https://github.com/huggingface/transformers). To that end, all pipelines provide the following functionality:
 
-- [`from_pretrained` method](../diffusion_pipeline) that accepts a Hugging Face Hub repository id, *e.g.* [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) or a path to a local directory, *e.g.*
-"./stable-diffusion". To correctly retrieve which models and components should be loaded, one has to provide a `model_index.json` file, *e.g.* [CompVis/stable-diffusion-v1-4/model_index.json](https://huggingface.co/CompVis/stable-diffusion-v1-4/blob/main/model_index.json), which defines all components that should be
+- [`from_pretrained` method](../diffusion_pipeline) that accepts a Hugging Face Hub repository id, *e.g.* [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) or a path to a local directory, *e.g.*
+"./stable-diffusion". To correctly retrieve which models and components should be loaded, one has to provide a `model_index.json` file, *e.g.* [runwayml/stable-diffusion-v1-5/model_index.json](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json), which defines all components that should be
 loaded into the pipelines. More specifically, for each model/component one needs to define the format `<name>: ["<library>", "<class name>"]`. `<name>` is the attribute name given to the loaded instance of `<class name>` which can be found in the library or pipeline folder called `"<library>"`.
 - [`save_pretrained`](../diffusion_pipeline) that accepts a local path, *e.g.* `./stable-diffusion` under which all models/components of the pipeline will be saved. For each component/model a folder is created inside the local path that is named after the given attribute name, *e.g.* `./stable_diffusion/unet`.
 In addition, a `model_index.json` file is created at the root of the local path, *e.g.* `./stable_diffusion/model_index.json` so that the complete pipeline can again be instantiated
@@ -100,7 +100,7 @@ logic including pre-processing, an unrolled diffusion loop, and post-processing
 # make sure you're logged in with `huggingface-cli login`
 from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
 
-pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
 pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
@@ -123,7 +123,7 @@ from diffusers import StableDiffusionImg2ImgPipeline
 # load the pipeline
 device = "cuda"
 pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
-    "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
+    "runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16
 ).to(device)
 
 # let's download an initial image
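
To make the `from_pretrained`/`save_pretrained` round trip described in this page concrete, a minimal sketch (the local path is just an example):

```python
from diffusers import StableDiffusionPipeline

# download from the Hub; model_index.json declares each component as
# <name>: ["<library>", "<class name>"] so the right classes are loaded
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# writes model_index.json plus one sub-folder per component (unet, vae, ...)
pipe.save_pretrained("./stable-diffusion")

# the complete pipeline can then be re-instantiated from the local copy
pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion")
```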
