diff --git a/README.md b/README.md
index 8af4766f6744..5b3660349e90 100644
--- a/README.md
+++ b/README.md
@@ -76,15 +76,13 @@ You need to accept the model license before downloading or using the Stable Diff
 
 ```python
 # make sure you're logged in with `huggingface-cli login`
-from torch import autocast
 from diffusers import StableDiffusionPipeline
 
 pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
 pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
-    image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
 ```
 
 **Note**: If you don't want to use the token, you can also simply download the model weights
@@ -104,8 +102,7 @@ pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-4")
 pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
-    image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
 ```
 
 If you are limited by GPU memory, you might want to consider using the model in `fp16` as
@@ -123,8 +120,7 @@ pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
 pipe.enable_attention_slicing()
-with autocast("cuda"):
-    image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
 ```
 
 Finally, if you wish to use a different scheduler, you can simply instantiate
@@ -149,8 +145,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
 pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
-    image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
 
 image.save("astronaut_rides_horse.png")
 ```
@@ -160,7 +155,6 @@ image.save("astronaut_rides_horse.png")
 The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.
 
 ```python
-from torch import autocast
 import requests
 import torch
 from PIL import Image
@@ -190,8 +184,7 @@ init_image = init_image.resize((768, 512))
 
 prompt = "A fantasy landscape, trending on artstation"
 
-with autocast("cuda"):
-    images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
+images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
 
 images[0].save("fantasy_landscape.png")
 ```
@@ -204,7 +197,6 @@ The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by
 ```python
 from io import BytesIO
 
-from torch import autocast
 import torch
 import requests
 import PIL
@@ -234,8 +226,7 @@ pipe = StableDiffusionInpaintPipeline.from_pretrained(
 pipe = pipe.to(device)
 
 prompt = "a cat sitting on a bench"
-with autocast("cuda"):
-    images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
+images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
 
 images[0].save("cat_on_bench.png")
 ```
@@ -258,7 +249,6 @@ If you want to run the code yourself 💻, you can try out:
 - [Text-to-Image Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256)
 ```python
 # !pip install diffusers transformers
-from torch import autocast
 from diffusers import DiffusionPipeline
 
 device = "cuda"
@@ -270,8 +260,7 @@ ldm = ldm.to(device)
 
 # run pipeline in inference (sample random noise and denoise)
 prompt = "A painting of a squirrel eating a burger"
-with autocast(device):
-    image = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images[0]
+image = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images[0]
 
 # save image
 image.save("squirrel.png")
@@ -279,7 +268,6 @@ image.save("squirrel.png")
 - [Unconditional Diffusion with discrete scheduler](https://huggingface.co/google/ddpm-celebahq-256)
 ```python
 # !pip install diffusers
-from torch import autocast
 from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline
 
 model_id = "google/ddpm-celebahq-256"
@@ -290,8 +278,7 @@ ddpm = DDPMPipeline.from_pretrained(model_id) # you can replace DDPMPipeline wi
 ddpm.to(device)
 
 # run pipeline in inference (sample random noise and denoise)
-with autocast("cuda"):
-    image = ddpm().images[0]
+image = ddpm().images[0]
 
 # save image
 image.save("ddpm_generated_image.png")
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index dd9e2e570bf7..d510589404bc 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -266,7 +266,12 @@ def forward(
         timesteps = timesteps.expand(sample.shape[0])
 
         t_emb = self.time_proj(timesteps)
-        emb = self.time_embedding(t_emb.to(self.dtype))
+
+        # timesteps does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=self.dtype)
+        emb = self.time_embedding(t_emb)
 
         # 2. pre-process
         sample = self.conv_in(sample)
diff --git a/src/diffusers/pipelines/README.md b/src/diffusers/pipelines/README.md
index 3462f5ff518d..71841e023372 100644
--- a/src/diffusers/pipelines/README.md
+++ b/src/diffusers/pipelines/README.md
@@ -86,15 +86,13 @@ logic including pre-processing, an unrolled diffusion loop, and post-processing
 
 ```python
 # make sure you're logged in with `huggingface-cli login`
-from torch import autocast
 from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
 
 pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
 pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
-    image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
 
 image.save("astronaut_rides_horse.png")
 ```
@@ -104,7 +102,6 @@ image.save("astronaut_rides_horse.png")
 The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.
 
 ```python
-from torch import autocast
 import requests
 from PIL import Image
 from io import BytesIO
@@ -129,8 +126,7 @@ init_image = init_image.resize((768, 512))
 
 prompt = "A fantasy landscape, trending on artstation"
 
-with autocast("cuda"):
-    images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
+images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
 
 images[0].save("fantasy_landscape.png")
 ```
@@ -148,7 +144,6 @@ The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by
 
 ```python
 from io import BytesIO
 
-from torch import autocast
 import requests
 import PIL
@@ -173,8 +168,7 @@ pipe = StableDiffusionInpaintPipeline.from_pretrained(
 ).to(device)
 
 prompt = "a cat sitting on a bench"
-with autocast("cuda"):
-    images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
+images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
 
 images[0].save("cat_on_bench.png")
 ```
diff --git a/src/diffusers/pipelines/stable_diffusion/README.md b/src/diffusers/pipelines/stable_diffusion/README.md
index 3a600c5859e9..45c4e4798cf7 100644
--- a/src/diffusers/pipelines/stable_diffusion/README.md
+++ b/src/diffusers/pipelines/stable_diffusion/README.md
@@ -59,15 +59,13 @@ pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-4")
 
 ```python
 # make sure you're logged in with `huggingface-cli login`
-from torch import autocast
 from diffusers import StableDiffusionPipeline
 
 pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
 pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
-    image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
 
 image.save("astronaut_rides_horse.png")
 ```
@@ -76,7 +74,6 @@ image.save("astronaut_rides_horse.png")
 
 ```python
 # make sure you're logged in with `huggingface-cli login`
-from torch import autocast
 from diffusers import StableDiffusionPipeline, DDIMScheduler
 
 scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
@@ -88,8 +85,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
 ).to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
-    image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
 
 image.save("astronaut_rides_horse.png")
 ```
@@ -98,7 +94,6 @@ image.save("astronaut_rides_horse.png")
 
 ```python
 # make sure you're logged in with `huggingface-cli login`
-from torch import autocast
 from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
 
 lms = LMSDiscreteScheduler(
@@ -114,8 +109,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
 ).to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
-    image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
 
 image.save("astronaut_rides_horse.png")
 ```
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index bc6ca1efbd9b..705b23d5fa7e 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -260,19 +260,20 @@ def __call__(
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_device = "cpu" if self.device.type == "mps" else self.device
         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        latents_dtype = text_embeddings.dtype
         if latents is None:
-            latents = torch.randn(
-                latents_shape,
-                generator=generator,
-                device=latents_device,
-                dtype=text_embeddings.dtype,
-            )
+            if self.device.type == "mps":
+                # randn does not exist on mps
+                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+                    self.device
+                )
+            else:
+                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            latents = latents.to(latents_device)
+            latents = latents.to(self.device)
 
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 78a22ec3138b..ca82cf5ce9b9 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -1214,6 +1214,37 @@ def test_stable_diffusion_memory_chunking(self):
         assert mem_bytes > 3.75 * 10**9
         assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3
 
+    @slow
+    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+    def test_stable_diffusion_text2img_pipeline_fp16(self):
+        torch.cuda.reset_peak_memory_stats()
+        model_id = "CompVis/stable-diffusion-v1-4"
+        pipe = StableDiffusionPipeline.from_pretrained(
+            model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
+        ).to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        prompt = "a photograph of an astronaut riding a horse"
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        output_chunked = pipe(
+            [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
+        )
+        image_chunked = output_chunked.images
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        with torch.autocast(torch_device):
+            output = pipe(
+                [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
+            )
+        image = output.images
+
+        # Make sure results are close enough
+        diff = np.abs(image_chunked.flatten() - image.flatten())
+        # They ARE different since ops are not always run at the same precision;
+        # however, they should be extremely close.
+        assert diff.mean() < 2e-2
+
     @slow
     @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
     def test_stable_diffusion_text2img_pipeline(self):
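
For reference, a minimal sketch of the half-precision usage this patch targets: load the `fp16` weights with `torch_dtype=torch.float16` and call the pipeline directly, with no `torch.autocast` wrapper, since the UNet now casts the float32 timestep embeddings to its own dtype. The model id, revision, and prompt mirror the snippets in the diff above; this is an illustrative sketch, not an excerpt from the patch.

```python
# Sketch: fp16 text-to-image without autocast (assumes a CUDA GPU and access to the
# gated CompVis/stable-diffusion-v1-4 weights after `huggingface-cli login`).
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=True,
)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
# No `with autocast("cuda"):` block is needed; the pipeline runs the UNet in fp16
# and casts the float32 timestep embeddings to the UNet's dtype internally.
image = pipe(prompt).images[0]
image.save("astronaut_rides_horse_fp16.png")
```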
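
The latent-initialization change follows the same device-aware pattern outside the pipeline: seeded noise is drawn on the CPU and moved over when running on `mps`, and drawn directly on the target device otherwise. A standalone sketch of that pattern; the shape, seed, and device-detection logic here are illustrative assumptions, not code from the diff.

```python
# Sketch of the device-aware, seeded latent initialization used by the pipeline change:
# on mps, generator-seeded randn is drawn on the CPU and then moved to the device.
import torch

# (batch_size, unet.in_channels, height // 8, width // 8) for a 512x512 image
latents_shape = (1, 4, 64, 64)
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cuda")

if device.type == "mps":
    generator = torch.Generator(device="cpu").manual_seed(0)
    latents = torch.randn(latents_shape, generator=generator, device="cpu").to(device)
else:
    generator = torch.Generator(device=device).manual_seed(0)
    latents = torch.randn(latents_shape, generator=generator, device=device)
```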