Skip to content

Commit 53a845a

Browse files
authored
Merge branch 'main' into main
2 parents ddbdec7 + 5156acc commit 53a845a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+1650
-732
lines changed

README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,3 +377,16 @@ This library concretizes previous work by many different authors and would not h
377377
- @yang-song's Score-VE and Score-VP implementations, available [here](https://github.com/yang-song/score_sde_pytorch)
378378

379379
We also want to thank @heejkoo for the very helpful overview of papers, code and resources on diffusion models, available [here](https://github.com/heejkoo/Awesome-Diffusion-Models) as well as @crowsonkb and @rromb for useful discussions and insights.
380+
381+
## Citation
382+
383+
```bibtex
384+
@misc{von-platen-etal-2022-diffusers,
385+
author = {Patrick von Platen and Suraj Patil and Anton Lozhkov and Pedro Cuenca and Nathan Lambert and Kashif Rasul and Mishig Davaadorj},
386+
title = {Diffusers: State-of-the-art diffusion models},
387+
year = {2022},
388+
publisher = {GitHub},
389+
journal = {GitHub repository},
390+
howpublished = {\url{https://github.com/huggingface/diffusers}}
391+
}
392+
```

docs/source/api/schedulers.mdx

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ To this end, the design of schedulers is such that:
4444
The core API for any new scheduler must follow a limited structure.
4545
- Schedulers should provide one or more `def step(...)` functions that should be called to update the generated sample iteratively.
4646
- Schedulers should provide a `set_timesteps(...)` method that configures the parameters of a schedule function for a specific inference task.
47-
- Schedulers should be framework-agnostic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch
48-
with a `set_format(...)` method.
47+
- Schedulers should be framework-specific.
4948

5049
The base class [`SchedulerMixin`] implements low level utilities used by multiple schedulers.
5150

docs/source/index.mdx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ available a colab notebook to directly try them out.
3535
| Pipeline | Paper | Tasks | Colab
3636
|---|---|:---:|:---:|
3737
| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
38-
| [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
38+
| [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
3939
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
4040
| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
4141
| [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |

docs/source/optimization/fp16.mdx

Lines changed: 172 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,41 @@ specific language governing permissions and limitations under the License.
1414

1515
We present some techniques and ideas to optimize 🤗 Diffusers _inference_ for memory or speed.
1616

17-
## CUDA `autocast`
17+
18+
| | Latency | Speedup |
19+
|------------------|---------|---------|
20+
| original | 9.50s | x1 |
21+
| cuDNN auto-tuner | 9.37s | x1.01 |
22+
| autocast (fp16) | 5.47s | x1.91 |
23+
| fp16 | 3.61s | x2.91 |
24+
| channels last | 3.30s | x2.87 |
25+
| traced UNet | 3.21s | x2.96 |
26+
27+
<em>obtained on NVIDIA TITAN RTX by generating a single image of size 512x512 from the prompt "a photo of an astronaut riding a horse on mars" with 50 DDIM steps.</em>
28+
29+
## Enable cuDNN auto-tuner
30+
31+
[NVIDIA cuDNN](https://developer.nvidia.com/cudnn) supports many algorithms to compute a convolution. Autotuner runs a short benchmark and selects the kernel with the best performance on a given hardware for a given input size.
32+
33+
Since we’re using **convolutional networks** (other types currently not supported), we can enable cuDNN autotuner before launching the inference by setting:
34+
35+
```python
36+
import torch
37+
38+
torch.backends.cudnn.benchmark = True
39+
```
40+
41+
### Use tf32 instead of fp32 (on Ampere and later CUDA devices)
42+
43+
On Ampere and later CUDA devices matrix multiplications and convolutions can use the TensorFloat32 (TF32) mode for faster but slightly less accurate computations. By default PyTorch enables TF32 mode for convolutions but not matrix multiplications, and unless a network requires full float32 precision we recommend enabling this setting for matrix multiplications, too. It can significantly speed up computations with typically negligible loss of numerical accuracy. You can read more about it [here](https://huggingface.co/docs/transformers/v4.18.0/en/performance#tf32). All you need to do is to add this before your inference:
44+
45+
```python
46+
import torch
47+
48+
torch.backends.cuda.matmul.allow_tf32 = True
49+
```
50+
51+
## Automatic mixed precision (AMP)
1852

1953
If you use a CUDA GPU, you can take advantage of `torch.autocast` to perform inference roughly twice as fast at the cost of slightly lower precision. All you need to do is put your inference call inside an `autocast` context manager. The following example shows how to do it using Stable Diffusion text-to-image generation as an example:
2054

@@ -47,7 +81,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
4781

4882
## Sliced attention for additional memory savings
4983

50-
For even additional memory savings, you can use a sliced version of attention that performs the computation in steps instead of all at once.
84+
For even additional memory savings, you can use a sliced version of attention that performs the computation in steps instead of all at once.
5185

5286
<Tip>
5387
Attention slicing is useful even if a batch size of just 1 is used - as long as the model uses more than one attention head. If there is more than one attention head the *QK^T* attention matrix can be computed sequentially for each head which can save a significant amount of memory.
@@ -73,4 +107,139 @@ with torch.autocast("cuda"):
73107
image = pipe(prompt).images[0]
74108
```
75109

76-
There's a small performance penalty of about 10% slower inference times, but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM!
110+
There's a small performance penalty of about 10% slower inference times, but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM!
111+
112+
## Using Channels Last memory format
113+
114+
Channels last memory format is an alternative way of ordering NCHW tensors in memory preserving dimensions ordering. Channels last tensors ordered in such a way that channels become the densest dimension (aka storing images pixel-per-pixel). Since not all operators currently support channels last format it may result in a worst performance, so it's better to try it and see if it works for your model.
115+
116+
For example, in order to set the UNet model in our pipeline to use channels last format, we can use the following:
117+
118+
```python
119+
print(pipe.unet.conv_out.state_dict()["weight"].stride()) # (2880, 9, 3, 1)
120+
pipe.unet.to(memory_format=torch.channels_last) # in-place operation
121+
print(
122+
pipe.unet.conv_out.state_dict()["weight"].stride()
123+
) # (2880, 1, 960, 320) haveing a stride of 1 for the 2nd dimension proves that it works
124+
```
125+
126+
## Tracing
127+
128+
Tracing runs an example input tensor through your model, and captures the operations that are invoked as that input makes its way through the model's layers so that an executable or `ScriptFunction` is returned that will be optimized using just-in-time compilation.
129+
130+
To trace our UNet model, we can use the following:
131+
132+
```python
133+
import time
134+
import torch
135+
from diffusers import StableDiffusionPipeline
136+
import functools
137+
138+
# torch disable grad
139+
torch.set_grad_enabled(False)
140+
141+
# set variables
142+
n_experiments = 2
143+
unet_runs_per_experiment = 50
144+
145+
# load inputs
146+
def generate_inputs():
147+
sample = torch.randn(2, 4, 64, 64).half().cuda()
148+
timestep = torch.rand(1).half().cuda() * 999
149+
encoder_hidden_states = torch.randn(2, 77, 768).half().cuda()
150+
return sample, timestep, encoder_hidden_states
151+
152+
153+
pipe = StableDiffusionPipeline.from_pretrained(
154+
"CompVis/stable-diffusion-v1-4",
155+
# scheduler=scheduler,
156+
use_auth_token=True,
157+
revision="fp16",
158+
torch_dtype=torch.float16,
159+
).to("cuda")
160+
unet = pipe.unet
161+
unet.eval()
162+
unet.to(memory_format=torch.channels_last) # use channels_last memory format
163+
unet.forward = functools.partial(unet.forward, return_dict=False) # set return_dict=False as default
164+
165+
# warmup
166+
for _ in range(3):
167+
with torch.inference_mode():
168+
inputs = generate_inputs()
169+
orig_output = unet(*inputs)
170+
171+
# trace
172+
print("tracing..")
173+
unet_traced = torch.jit.trace(unet, inputs)
174+
unet_traced.eval()
175+
print("done tracing")
176+
177+
178+
# warmup and optimize graph
179+
for _ in range(5):
180+
with torch.inference_mode():
181+
inputs = generate_inputs()
182+
orig_output = unet_traced(*inputs)
183+
184+
185+
# benchmarking
186+
with torch.inference_mode():
187+
for _ in range(n_experiments):
188+
torch.cuda.synchronize()
189+
start_time = time.time()
190+
for _ in range(unet_runs_per_experiment):
191+
orig_output = unet_traced(*inputs)
192+
torch.cuda.synchronize()
193+
print(f"unet traced inference took {time.time() - start_time:.2f} seconds")
194+
for _ in range(n_experiments):
195+
torch.cuda.synchronize()
196+
start_time = time.time()
197+
for _ in range(unet_runs_per_experiment):
198+
orig_output = unet(*inputs)
199+
torch.cuda.synchronize()
200+
print(f"unet inference took {time.time() - start_time:.2f} seconds")
201+
202+
# save the model
203+
unet_traced.save("unet_traced.pt")
204+
```
205+
206+
Then we can replace the `unet` attribute of the pipeline with the traced model like the following
207+
208+
```python
209+
from diffusers import StableDiffusionPipeline
210+
import torch
211+
from dataclasses import dataclass
212+
213+
214+
@dataclass
215+
class UNet2DConditionOutput:
216+
sample: torch.FloatTensor
217+
218+
219+
pipe = StableDiffusionPipeline.from_pretrained(
220+
"CompVis/stable-diffusion-v1-4",
221+
# scheduler=scheduler,
222+
use_auth_token=True,
223+
revision="fp16",
224+
torch_dtype=torch.float16,
225+
).to("cuda")
226+
227+
# use jitted unet
228+
unet_traced = torch.jit.load("unet_traced.pt")
229+
# del pipe.unet
230+
class TracedUNet(torch.nn.Module):
231+
def __init__(self):
232+
super().__init__()
233+
self.in_channels = pipe.unet.in_channels
234+
self.device = pipe.unet.device
235+
236+
def forward(self, latent_model_input, t, encoder_hidden_states):
237+
sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
238+
return UNet2DConditionOutput(sample=sample)
239+
240+
241+
pipe.unet = TracedUNet()
242+
243+
with torch.inference_mode():
244+
image = pipe([prompt] * 1, num_inference_steps=50).images[0]
245+
```

examples/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ Training examples show how to pretrain or fine-tune diffusion models for a varie
3838

3939
| Task | 🤗 Accelerate | 🤗 Datasets | Colab
4040
|---|---|:---:|:---:|
41-
| [**Unconditional Image Generation**](https://github.com/huggingface/transformers/tree/main/examples/training/train_unconditional.py) | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
41+
| [**Unconditional Image Generation**](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py) | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
4242

4343
## Community
4444

examples/community/clip_guided_stable_diffusion.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def __init__(
6060
feature_extractor: CLIPFeatureExtractor,
6161
):
6262
super().__init__()
63-
scheduler = scheduler.set_format("pt")
6463
self.register_modules(
6564
vae=vae,
6665
text_encoder=text_encoder,
@@ -274,7 +273,7 @@ def __call__(
274273
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
275274
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
276275

277-
# # predict the noise residual
276+
# predict the noise residual
278277
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
279278

280279
# perform classifier free guidance
@@ -285,7 +284,7 @@ def __call__(
285284
# perform clip guidance
286285
if clip_guidance_scale > 0:
287286
text_embeddings_for_guidance = (
288-
text_embeddings.chunk(2)[0] if do_classifier_free_guidance else text_embeddings
287+
text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings
289288
)
290289
noise_pred, latents = self.cond_fn(
291290
latents,

0 commit comments

Comments
 (0)