29 changes: 8 additions & 21 deletions README.md
@@ -76,15 +76,13 @@ You need to accept the model license before downloading or using the Stable Diff

```python
# make sure you're logged in with `huggingface-cli login`
-from torch import autocast
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt).images[0]
image = pipe(prompt).images[0]
```

Contributor:

Are you sure that this is faster? Using autocast currently (before this PR) gives a 2x boost in generation speed.

Will also test a bit locally on a GPU tomorrow.
Contributor Author:

This is extremely surprising, but I am also measuring a 2x speedup with autocast on f32.

I am looking into it. I see two copies, but not nearly the same amount; there's probably a device-to-host/host-to-device transfer somewhere that kills performance, but I haven't found it yet.

Contributor Author (@Narsil, Sep 15, 2022):

Okay, I figured it out.

autocast will actually use fp16 for some ops based on heuristics: https://pytorch.org/docs/stable/amp.html#cuda-op-specific-behavior

So it's faster because it's running in fp16 even though the model was loaded in f32, and without autocast it's slower simply because it's actually running in f32.

If we enable real fp16 with a big performance boost, I feel like we shouldn't need autocast for f32 (which is slower but also "more" correct). Some ops are kind of dangerous to actually run in fp16, but we should be able to tell them apart (and for now the generations still seem good enough even when running everything in fp16).

But it's still a nice way to get fp16 running "for free". The heuristics they use seem OK (though, as they mention, they probably wouldn't work for gradients, I guess because fp16 overflows faster).
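A quick way to see the behavior described above (a minimal sketch, not part of the PR; it assumes a CUDA device is available): under `torch.autocast`, matmuls on float32 inputs are executed in half precision, which is where the "free" speedup comes from.

```python
import torch

a = torch.randn(1024, 1024, device="cuda")  # float32 inputs
b = torch.randn(1024, 1024, device="cuda")

with torch.autocast("cuda"):
    c = a @ b  # autocast routes the matmul to fp16 kernels

print(c.dtype)  # torch.float16, even though a and b are float32
```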

Contributor (@patrickvonplaten, Oct 5, 2022):

Commented extensively here: #511

Should we maybe change the example to load native fp16 weights then?

Contributor:

So replace:

`pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)`

with

`pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16, revision="fp16")`

Contributor Author:

That seems OK. Is fp16 supported by enough GPU cards at this point? That would be my only point of concern.
But given the speed difference, advocating for fp16 is definitely something we should do!

Contributor:

Yes, I think fp16 is widely supported now.
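For reference, a minimal sketch of what the updated example discussed above could look like (it assumes the `fp16` revision of `CompVis/stable-diffusion-v1-4` on the Hub, the same weights the new test in this PR loads):

```python
import torch
from diffusers import StableDiffusionPipeline

# Load the natively fp16 weights; no autocast context is needed at inference time.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=True
)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
```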

**Note**: If you don't want to use the token, you can also simply download the model weights
@@ -104,8 +102,7 @@ pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-4")
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt).images[0]
image = pipe(prompt).images[0]
```

If you are limited by GPU memory, you might want to consider using the model in `fp16` as
@@ -123,8 +120,7 @@ pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_attention_slicing()
with autocast("cuda"):
image = pipe(prompt).images[0]
image = pipe(prompt).images[0]
```

Finally, if you wish to use a different scheduler, you can simply instantiate
@@ -149,8 +145,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt).images[0]
image = pipe(prompt).images[0]

image.save("astronaut_rides_horse.png")
```
@@ -160,7 +155,6 @@ image.save("astronaut_rides_horse.png")
The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.

```python
-from torch import autocast
import requests
import torch
from PIL import Image
@@ -190,8 +184,7 @@ init_image = init_image.resize((768, 512))

prompt = "A fantasy landscape, trending on artstation"

with autocast("cuda"):
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images

images[0].save("fantasy_landscape.png")
```
@@ -204,7 +197,6 @@ The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by
```python
from io import BytesIO

-from torch import autocast
import torch
import requests
import PIL
@@ -234,8 +226,7 @@ pipe = StableDiffusionInpaintPipeline.from_pretrained(
pipe = pipe.to(device)

prompt = "a cat sitting on a bench"
with autocast("cuda"):
images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images

images[0].save("cat_on_bench.png")
```
@@ -258,7 +249,6 @@ If you want to run the code yourself 💻, you can try out:
- [Text-to-Image Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256)
```python
# !pip install diffusers transformers
-from torch import autocast
from diffusers import DiffusionPipeline

device = "cuda"
@@ -270,16 +260,14 @@ ldm = ldm.to(device)

# run pipeline in inference (sample random noise and denoise)
prompt = "A painting of a squirrel eating a burger"
-with autocast(device):
-    image = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images[0]
+image = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images[0]

# save image
image.save("squirrel.png")
```
- [Unconditional Diffusion with discrete scheduler](https://huggingface.co/google/ddpm-celebahq-256)
```python
# !pip install diffusers
-from torch import autocast
from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline

model_id = "google/ddpm-celebahq-256"
@@ -290,8 +278,7 @@ ddpm = DDPMPipeline.from_pretrained(model_id) # you can replace DDPMPipeline wi
ddpm.to(device)

# run pipeline in inference (sample random noise and denoise)
with autocast("cuda"):
image = ddpm().images[0]
image = ddpm().images[0]

# save image
image.save("ddpm_generated_image.png")
7 changes: 6 additions & 1 deletion src/diffusers/models/unet_2d_condition.py
@@ -266,7 +266,12 @@ def forward(
timesteps = timesteps.expand(sample.shape[0])

t_emb = self.time_proj(timesteps)
-emb = self.time_embedding(t_emb.to(self.dtype))
+
+# timesteps does not contain any weights and will always return f32 tensors
+# but time_embedding might actually be running in fp16. so we need to cast here.
+# there might be better ways to encapsulate this.
+t_emb = t_emb.to(dtype=self.dtype)
+emb = self.time_embedding(t_emb)

# 2. pre-process
sample = self.conv_in(sample)
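To illustrate why the cast above is needed (a minimal sketch with a hypothetical layer size, not code from the PR): the sinusoidal timestep projection is produced in float32, but when the UNet is loaded with `torch_dtype=torch.float16` the time-embedding MLP holds fp16 weights, and feeding it a float32 tensor raises a dtype-mismatch error.

```python
import torch
from torch import nn

# Hypothetical stand-in for the UNet's time_embedding layer, loaded in fp16 on a GPU.
time_embedding = nn.Linear(320, 1280).half().cuda()

t_emb = torch.randn(1, 320, device="cuda")  # float32, like the output of time_proj
# time_embedding(t_emb)  # would fail: input dtype (float32) != weight dtype (float16)
emb = time_embedding(t_emb.to(dtype=time_embedding.weight.dtype))  # cast first, as in the diff
print(emb.dtype)  # torch.float16
```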
12 changes: 3 additions & 9 deletions src/diffusers/pipelines/README.md
@@ -86,15 +86,13 @@ logic including pre-processing, an unrolled diffusion loop, and post-processing

```python
# make sure you're logged in with `huggingface-cli login`
-from torch import autocast
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt).images[0]
image = pipe(prompt).images[0]

image.save("astronaut_rides_horse.png")
```
@@ -104,7 +102,6 @@ image.save("astronaut_rides_horse.png")
The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.

```python
-from torch import autocast
import requests
from PIL import Image
from io import BytesIO
@@ -129,8 +126,7 @@ init_image = init_image.resize((768, 512))

prompt = "A fantasy landscape, trending on artstation"

with autocast("cuda"):
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images

images[0].save("fantasy_landscape.png")
```
@@ -148,7 +144,6 @@ The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by
```python
from io import BytesIO

-from torch import autocast
import requests
import PIL

@@ -173,8 +168,7 @@ ).to(device)
).to(device)

prompt = "a cat sitting on a bench"
with autocast("cuda"):
images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images

images[0].save("cat_on_bench.png")
```
12 changes: 3 additions & 9 deletions src/diffusers/pipelines/stable_diffusion/README.md
@@ -59,15 +59,13 @@ pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-4")

```python
# make sure you're logged in with `huggingface-cli login`
-from torch import autocast
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt).images[0]
image = pipe(prompt).sample[0]

image.save("astronaut_rides_horse.png")
```
@@ -76,7 +74,6 @@ image.save("astronaut_rides_horse.png")

```python
# make sure you're logged in with `huggingface-cli login`
-from torch import autocast
from diffusers import StableDiffusionPipeline, DDIMScheduler

scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
@@ -88,8 +85,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
).to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt).images[0]
image = pipe(prompt).sample[0]

image.save("astronaut_rides_horse.png")
```
@@ -98,7 +94,6 @@ image.save("astronaut_rides_horse.png")

```python
# make sure you're logged in with `huggingface-cli login`
-from torch import autocast
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler

lms = LMSDiscreteScheduler(
@@ -114,8 +109,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
).to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt).images[0]
image = pipe(prompt).sample[0]

image.save("astronaut_rides_horse.png")
```
@@ -260,19 +260,20 @@ def __call__(
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
latents_device = "cpu" if self.device.type == "mps" else self.device
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+latents_dtype = text_embeddings.dtype
if latents is None:
-    latents = torch.randn(
-        latents_shape,
-        generator=generator,
-        device=latents_device,
-        dtype=text_embeddings.dtype,
-    )
+    if self.device.type == "mps":
Contributor:

@pcuenca could you take a look here?

Member:

Yes, I commented in the other conversation, I think it's OK like this.

+        # randn does not exist on mps
+        latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+            self.device
+        )
+    else:
+        latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
else:
    if latents.shape != latents_shape:
        raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-    latents = latents.to(latents_device)
+    latents = latents.to(self.device)

# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
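A compact sketch of the device-handling pattern introduced above (illustrative only; the helper name and call site are simplified, not the pipeline's actual API): at the time of this PR, `torch.randn` with a generator was not available on the `mps` backend, so latents are sampled on CPU and then moved to the target device, keeping seeded results reproducible.

```python
import torch

def sample_latents(shape, generator, device, dtype):
    if device.type == "mps":
        # randn with a generator is not supported on mps: sample on CPU, then move over
        return torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
    return torch.randn(shape, generator=generator, device=device, dtype=dtype)

# Example usage with illustrative latent dimensions:
latents = sample_latents((1, 4, 64, 64), torch.Generator().manual_seed(0), torch.device("cpu"), torch.float32)
```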
31 changes: 31 additions & 0 deletions tests/test_pipelines.py
@@ -1214,6 +1214,37 @@ def test_stable_diffusion_memory_chunking(self):
assert mem_bytes > 3.75 * 10**9
assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3

@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_text2img_pipeline_fp16(self):
    torch.cuda.reset_peak_memory_stats()
    model_id = "CompVis/stable-diffusion-v1-4"
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
    ).to(torch_device)
    pipe.set_progress_bar_config(disable=None)

    prompt = "a photograph of an astronaut riding a horse"

    generator = torch.Generator(device=torch_device).manual_seed(0)
    output_chunked = pipe(
        [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
    )
    image_chunked = output_chunked.images

    generator = torch.Generator(device=torch_device).manual_seed(0)
    with torch.autocast(torch_device):
        output = pipe(
            [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
        )
    image = output.images

    # Make sure results are close enough
    diff = np.abs(image_chunked.flatten() - image.flatten())
    # They ARE different since ops are not run always at the same precision
    # however, they should be extremely close.
    assert diff.mean() < 2e-2
Contributor:

This test tolerance seems a bit high to me => will play around with it a bit!

Contributor Author (@Narsil, Oct 5, 2022):

It's roughly the same difference as in #371 for a pure fp16 run.

Running some ops directly in f32 instead of f16 (which autocast will do) does change some patches.
This is less than 2% total variance in the images, so it really doesn't show visually. (Much less than some of the other changes in #371, where I think some visual differences were visible.)
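To make the tolerance discussion concrete (a small illustrative helper, not part of the test suite): with images as float arrays in [0, 1], a mean absolute difference below 2e-2 means the fp16-native and autocast outputs differ by less than about 2% of the value range on average.

```python
import numpy as np

def mean_abs_diff(img_a: np.ndarray, img_b: np.ndarray) -> float:
    # Average per-pixel deviation between two images scaled to [0, 1].
    return float(np.abs(img_a.flatten() - img_b.flatten()).mean())

# e.g. mean_abs_diff(image_chunked, image) < 2e-2, as asserted in the test above
```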


@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_text2img_pipeline(self):