diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml
index f63d4ffda464..c25aa888f459 100644
--- a/.github/workflows/pr_tests.yml
+++ b/.github/workflows/pr_tests.yml
@@ -21,7 +21,7 @@ jobs:
runs-on: [ self-hosted, docker-gpu ]
container:
image: python:3.7
- options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
+ options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
steps:
- name: Checkout diffusers
diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml
index 3db6814e071d..3e4a81c91c01 100644
--- a/.github/workflows/push_tests.yml
+++ b/.github/workflows/push_tests.yml
@@ -15,14 +15,10 @@ env:
jobs:
run_tests_single_gpu:
name: Diffusers tests
- strategy:
- fail-fast: false
- matrix:
- machine_type: [ single-gpu ]
- runs-on: [ self-hosted, docker-gpu, '${{ matrix.machine_type }}' ]
+ runs-on: [ self-hosted, docker-gpu, single-gpu ]
container:
image: nvcr.io/nvidia/pytorch:22.07-py3
- options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
+ options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache
steps:
- name: Checkout diffusers
@@ -66,14 +62,10 @@ jobs:
run_examples_single_gpu:
name: Examples tests
- strategy:
- fail-fast: false
- matrix:
- machine_type: [ single-gpu ]
- runs-on: [ self-hosted, docker-gpu, '${{ matrix.machine_type }}' ]
+ runs-on: [ self-hosted, docker-gpu, single-gpu ]
container:
image: nvcr.io/nvidia/pytorch:22.07-py3
- options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
+ options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache
steps:
- name: Checkout diffusers
diff --git a/README.md b/README.md
index 5a25ce501263..f2abe1978a02 100644
--- a/README.md
+++ b/README.md
@@ -74,17 +74,18 @@ You need to accept the model license before downloading or using the Stable Diff
### Text-to-Image generation with Stable Diffusion
+We recommend using the model in [half-precision (`fp16`)](https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/) as it gives almost always the same results as full
+precision while being roughly twice as fast and requiring half the amount of GPU RAM.
+
```python
# make sure you're logged in with `huggingface-cli login`
-from torch import autocast
from diffusers import StableDiffusionPipeline
-pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
+pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_type=torch.float16, revision="fp16")
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
- image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
```
**Note**: If you don't want to use the token, you can also simply download the model weights
@@ -104,12 +105,11 @@ pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-4")
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
- image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
```
-If you are limited by GPU memory, you might want to consider using the model in `fp16` as
-well as chunking the attention computation.
+If you are limited by GPU memory, you might want to consider chunking the attention computation in addition
+to using `fp16`.
The following snippet should result in less than 4GB VRAM.
```python
@@ -117,17 +117,15 @@ pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
revision="fp16",
torch_dtype=torch.float16,
- use_auth_token=True
)
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_attention_slicing()
-with autocast("cuda"):
- image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
```
-Finally, if you wish to use a different scheduler, you can simply instantiate
+If you wish to use a different scheduler, you can simply instantiate
it before the pipeline and pass it to `from_pretrained`.
```python
@@ -144,13 +142,29 @@ pipe = StableDiffusionPipeline.from_pretrained(
revision="fp16",
torch_dtype=torch.float16,
scheduler=lms,
- use_auth_token=True
)
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
- image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
+
+image.save("astronaut_rides_horse.png")
+```
+
+If you want to run Stable Diffusion on CPU or you want to have maximum precision on GPU,
+please run the model in the default *full-precision* setting:
+
+```python
+# make sure you're logged in with `huggingface-cli login`
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+
+# disable the following line if you run on CPU
+pipe = pipe.to("cuda")
+
+prompt = "a photo of an astronaut riding a horse on mars"
+image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
```
@@ -160,7 +174,6 @@ image.save("astronaut_rides_horse.png")
The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.
```python
-from torch import autocast
import requests
import torch
from PIL import Image
@@ -175,10 +188,9 @@ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
model_id_or_path,
revision="fp16",
torch_dtype=torch.float16,
- use_auth_token=True
)
# or download via git clone https://huggingface.co/CompVis/stable-diffusion-v1-4
-# and pass `model_id_or_path="./stable-diffusion-v1-4"` without having to use `use_auth_token=True`.
+# and pass `model_id_or_path="./stable-diffusion-v1-4"`.
pipe = pipe.to(device)
# let's download an initial image
@@ -190,8 +202,7 @@ init_image = init_image.resize((768, 512))
prompt = "A fantasy landscape, trending on artstation"
-with autocast("cuda"):
- images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
+images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
images[0].save("fantasy_landscape.png")
```
@@ -204,7 +215,6 @@ The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by
```python
from io import BytesIO
-from torch import autocast
import torch
import requests
import PIL
@@ -227,15 +237,13 @@ pipe = StableDiffusionInpaintPipeline.from_pretrained(
model_id_or_path,
revision="fp16",
torch_dtype=torch.float16,
- use_auth_token=True
)
# or download via git clone https://huggingface.co/CompVis/stable-diffusion-v1-4
-# and pass `model_id_or_path="./stable-diffusion-v1-4"` without having to use `use_auth_token=True`.
+# and pass `model_id_or_path="./stable-diffusion-v1-4"`.
pipe = pipe.to(device)
prompt = "a cat sitting on a bench"
-with autocast("cuda"):
- images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
+images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
images[0].save("cat_on_bench.png")
```
@@ -258,7 +266,6 @@ If you want to run the code yourself 💻, you can try out:
- [Text-to-Image Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256)
```python
# !pip install diffusers transformers
-from torch import autocast
from diffusers import DiffusionPipeline
device = "cuda"
@@ -270,8 +277,7 @@ ldm = ldm.to(device)
# run pipeline in inference (sample random noise and denoise)
prompt = "A painting of a squirrel eating a burger"
-with autocast(device):
- image = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images[0]
+image = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images[0]
# save image
image.save("squirrel.png")
@@ -279,7 +285,6 @@ image.save("squirrel.png")
- [Unconditional Diffusion with discrete scheduler](https://huggingface.co/google/ddpm-celebahq-256)
```python
# !pip install diffusers
-from torch import autocast
from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline
model_id = "google/ddpm-celebahq-256"
@@ -290,8 +295,7 @@ ddpm = DDPMPipeline.from_pretrained(model_id) # you can replace DDPMPipeline wi
ddpm.to(device)
# run pipeline in inference (sample random noise and denoise)
-with autocast("cuda"):
- image = ddpm().images[0]
+image = ddpm().images[0]
# save image
image.save("ddpm_generated_image.png")
@@ -377,3 +381,16 @@ This library concretizes previous work by many different authors and would not h
- @yang-song's Score-VE and Score-VP implementations, available [here](https://github.com/yang-song/score_sde_pytorch)
We also want to thank @heejkoo for the very helpful overview of papers, code and resources on diffusion models, available [here](https://github.com/heejkoo/Awesome-Diffusion-Models) as well as @crowsonkb and @rromb for useful discussions and insights.
+
+## Citation
+
+```bibtex
+@misc{von-platen-etal-2022-diffusers,
+ author = {Patrick von Platen and Suraj Patil and Anton Lozhkov and Pedro Cuenca and Nathan Lambert and Kashif Rasul and Mishig Davaadorj and Thomas Wolf},
+ title = {Diffusers: State-of-the-art diffusion models},
+ year = {2022},
+ publisher = {GitHub},
+ journal = {GitHub repository},
+ howpublished = {\url{https://github.com/huggingface/diffusers}}
+}
+```
diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 3d1bd4929d88..390b8ff042db 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -12,6 +12,8 @@
title: "Loading Pipelines, Models, and Schedulers"
- local: using-diffusers/configuration
title: "Configuring Pipelines, Models, and Schedulers"
+ - local: using-diffusers/custom_pipelines
+ title: "Loading and Creating Custom Pipelines"
title: "Loading"
- sections:
- local: using-diffusers/unconditional_image_generation
diff --git a/docs/source/api/pipelines/overview.mdx b/docs/source/api/pipelines/overview.mdx
index a9b1bb282153..7b2d89e849d6 100644
--- a/docs/source/api/pipelines/overview.mdx
+++ b/docs/source/api/pipelines/overview.mdx
@@ -98,15 +98,13 @@ logic including pre-processing, an unrolled diffusion loop, and post-processing
```python
# make sure you're logged in with `huggingface-cli login`
-from torch import autocast
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
-pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
+pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
- image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
```
@@ -116,7 +114,6 @@ image.save("astronaut_rides_horse.png")
The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.
```python
-from torch import autocast
import requests
from PIL import Image
from io import BytesIO
@@ -126,7 +123,7 @@ from diffusers import StableDiffusionImg2ImgPipeline
# load the pipeline
device = "cuda"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
- "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=True
+ "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
).to(device)
# let's download an initial image
@@ -138,8 +135,7 @@ init_image = init_image.resize((768, 512))
prompt = "A fantasy landscape, trending on artstation"
-with autocast("cuda"):
- images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
+images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
images[0].save("fantasy_landscape.png")
```
@@ -157,7 +153,6 @@ The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by
```python
from io import BytesIO
-from torch import autocast
import requests
import PIL
@@ -177,12 +172,11 @@ mask_image = download_image(mask_url).resize((512, 512))
device = "cuda"
pipe = StableDiffusionInpaintPipeline.from_pretrained(
- "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=True
+ "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
).to(device)
prompt = "a cat sitting on a bench"
-with autocast("cuda"):
- images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
+images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
images[0].save("cat_on_bench.png")
```
diff --git a/docs/source/api/schedulers.mdx b/docs/source/api/schedulers.mdx
index 2b6e58fe128d..12a6b5c587bc 100644
--- a/docs/source/api/schedulers.mdx
+++ b/docs/source/api/schedulers.mdx
@@ -36,7 +36,7 @@ This allows for rapid experimentation and cleaner abstractions in the code, wher
To this end, the design of schedulers is such that:
- Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality.
-- Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Numpy support currently exists).
+- Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Jax support currently exists).
## API
@@ -44,8 +44,7 @@ To this end, the design of schedulers is such that:
The core API for any new scheduler must follow a limited structure.
- Schedulers should provide one or more `def step(...)` functions that should be called to update the generated sample iteratively.
- Schedulers should provide a `set_timesteps(...)` method that configures the parameters of a schedule function for a specific inference task.
-- Schedulers should be framework-agnostic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch
-with a `set_format(...)` method.
+- Schedulers should be framework-specific.
The base class [`SchedulerMixin`] implements low level utilities used by multiple schedulers.
diff --git a/docs/source/index.mdx b/docs/source/index.mdx
index 434c58cc8b27..392b22399908 100644
--- a/docs/source/index.mdx
+++ b/docs/source/index.mdx
@@ -35,7 +35,7 @@ available a colab notebook to directly try them out.
| Pipeline | Paper | Tasks | Colab
|---|---|:---:|:---:|
| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
-| [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
+| [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
| [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
diff --git a/docs/source/optimization/fp16.mdx b/docs/source/optimization/fp16.mdx
index 064bc58f8c2b..b561aedbfe4a 100644
--- a/docs/source/optimization/fp16.mdx
+++ b/docs/source/optimization/fp16.mdx
@@ -14,7 +14,41 @@ specific language governing permissions and limitations under the License.
We present some techniques and ideas to optimize 🤗 Diffusers _inference_ for memory or speed.
-## CUDA `autocast`
+
+| | Latency | Speedup |
+|------------------|---------|---------|
+| original | 9.50s | x1 |
+| cuDNN auto-tuner | 9.37s | x1.01 |
+| autocast (fp16) | 5.47s | x1.91 |
+| fp16 | 3.61s | x2.91 |
+| channels last | 3.30s | x2.87 |
+| traced UNet | 3.21s | x2.96 |
+
+obtained on NVIDIA TITAN RTX by generating a single image of size 512x512 from the prompt "a photo of an astronaut riding a horse on mars" with 50 DDIM steps.
+
+## Enable cuDNN auto-tuner
+
+[NVIDIA cuDNN](https://developer.nvidia.com/cudnn)Â supports many algorithms to compute a convolution. Autotuner runs a short benchmark and selects the kernel with the best performance on a given hardware for a given input size.
+
+Since we’re using **convolutional networks** (other types currently not supported), we can enable cuDNN autotuner before launching the inference by setting:
+
+```python
+import torch
+
+torch.backends.cudnn.benchmark = True
+```
+
+### Use tf32 instead of fp32 (on Ampere and later CUDA devices)
+
+On Ampere and later CUDA devices matrix multiplications and convolutions can use the TensorFloat32 (TF32) mode for faster but slightly less accurate computations. By default PyTorch enables TF32 mode for convolutions but not matrix multiplications, and unless a network requires full float32 precision we recommend enabling this setting for matrix multiplications, too. It can significantly speed up computations with typically negligible loss of numerical accuracy. You can read more about it [here](https://huggingface.co/docs/transformers/v4.18.0/en/performance#tf32). All you need to do is to add this before your inference:
+
+```python
+import torch
+
+torch.backends.cuda.matmul.allow_tf32 = True
+```
+
+## Automatic mixed precision (AMP)
If you use a CUDA GPU, you can take advantage of `torch.autocast` to perform inference roughly twice as fast at the cost of slightly lower precision. All you need to do is put your inference call inside an `autocast` context manager. The following example shows how to do it using Stable Diffusion text-to-image generation as an example:
@@ -22,7 +56,7 @@ If you use a CUDA GPU, you can take advantage of `torch.autocast` to perform inf
from torch import autocast
from diffusers import StableDiffusionPipeline
-pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
+pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
@@ -34,20 +68,23 @@ Despite the precision loss, in our experience the final image results look the s
## Half precision weights
-To save more GPU memory, you can load the model weights directly in half precision. This involves loading the float16 version of the weights, which was saved to a branch named `fp16`, and telling PyTorch to use the `float16` type when loading them:
+To save more GPU memory and get even more speed, you can load and run the model weights directly in half precision. This involves loading the float16 version of the weights, which was saved to a branch named `fp16`, and telling PyTorch to use the `float16` type when loading them:
```Python
pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
revision="fp16",
torch_dtype=torch.float16,
- use_auth_token=True
)
+pipe = pipe.to("cuda")
+
+prompt = "a photo of an astronaut riding a horse on mars"
+image = pipe(prompt).images[0]
```
## Sliced attention for additional memory savings
-For even additional memory savings, you can use a sliced version of attention that performs the computation in steps instead of all at once.
+For even additional memory savings, you can use a sliced version of attention that performs the computation in steps instead of all at once.
Attention slicing is useful even if a batch size of just 1 is used - as long as the model uses more than one attention head. If there is more than one attention head the *QK^T* attention matrix can be computed sequentially for each head which can save a significant amount of memory.
@@ -63,14 +100,143 @@ pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
revision="fp16",
torch_dtype=torch.float16,
- use_auth_token=True
)
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_attention_slicing()
-with torch.autocast("cuda"):
- image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
```
-There's a small performance penalty of about 10% slower inference times, but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM!
+There's a small performance penalty of about 10% slower inference times, but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM!
+
+## Using Channels Last memory format
+
+Channels last memory format is an alternative way of ordering NCHW tensors in memory preserving dimensions ordering. Channels last tensors ordered in such a way that channels become the densest dimension (aka storing images pixel-per-pixel). Since not all operators currently support channels last format it may result in a worst performance, so it's better to try it and see if it works for your model.
+
+For example, in order to set the UNet model in our pipeline to use channels last format, we can use the following:
+
+```python
+print(pipe.unet.conv_out.state_dict()["weight"].stride()) # (2880, 9, 3, 1)
+pipe.unet.to(memory_format=torch.channels_last) # in-place operation
+print(
+ pipe.unet.conv_out.state_dict()["weight"].stride()
+) # (2880, 1, 960, 320) having a stride of 1 for the 2nd dimension proves that it works
+```
+
+## Tracing
+
+Tracing runs an example input tensor through your model, and captures the operations that are invoked as that input makes its way through the model's layers so that an executable or `ScriptFunction` is returned that will be optimized using just-in-time compilation.
+
+To trace our UNet model, we can use the following:
+
+```python
+import time
+import torch
+from diffusers import StableDiffusionPipeline
+import functools
+
+# torch disable grad
+torch.set_grad_enabled(False)
+
+# set variables
+n_experiments = 2
+unet_runs_per_experiment = 50
+
+# load inputs
+def generate_inputs():
+ sample = torch.randn(2, 4, 64, 64).half().cuda()
+ timestep = torch.rand(1).half().cuda() * 999
+ encoder_hidden_states = torch.randn(2, 77, 768).half().cuda()
+ return sample, timestep, encoder_hidden_states
+
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ revision="fp16",
+ torch_dtype=torch.float16,
+).to("cuda")
+unet = pipe.unet
+unet.eval()
+unet.to(memory_format=torch.channels_last) # use channels_last memory format
+unet.forward = functools.partial(unet.forward, return_dict=False) # set return_dict=False as default
+
+# warmup
+for _ in range(3):
+ with torch.inference_mode():
+ inputs = generate_inputs()
+ orig_output = unet(*inputs)
+
+# trace
+print("tracing..")
+unet_traced = torch.jit.trace(unet, inputs)
+unet_traced.eval()
+print("done tracing")
+
+
+# warmup and optimize graph
+for _ in range(5):
+ with torch.inference_mode():
+ inputs = generate_inputs()
+ orig_output = unet_traced(*inputs)
+
+
+# benchmarking
+with torch.inference_mode():
+ for _ in range(n_experiments):
+ torch.cuda.synchronize()
+ start_time = time.time()
+ for _ in range(unet_runs_per_experiment):
+ orig_output = unet_traced(*inputs)
+ torch.cuda.synchronize()
+ print(f"unet traced inference took {time.time() - start_time:.2f} seconds")
+ for _ in range(n_experiments):
+ torch.cuda.synchronize()
+ start_time = time.time()
+ for _ in range(unet_runs_per_experiment):
+ orig_output = unet(*inputs)
+ torch.cuda.synchronize()
+ print(f"unet inference took {time.time() - start_time:.2f} seconds")
+
+# save the model
+unet_traced.save("unet_traced.pt")
+```
+
+Then we can replace the `unet` attribute of the pipeline with the traced model like the following
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+from dataclasses import dataclass
+
+
+@dataclass
+class UNet2DConditionOutput:
+ sample: torch.FloatTensor
+
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ revision="fp16",
+ torch_dtype=torch.float16,
+).to("cuda")
+
+# use jitted unet
+unet_traced = torch.jit.load("unet_traced.pt")
+# del pipe.unet
+class TracedUNet(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.in_channels = pipe.unet.in_channels
+ self.device = pipe.unet.device
+
+ def forward(self, latent_model_input, t, encoder_hidden_states):
+ sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
+ return UNet2DConditionOutput(sample=sample)
+
+
+pipe.unet = TracedUNet()
+
+with torch.inference_mode():
+ image = pipe([prompt] * 1, num_inference_steps=50).images[0]
+```
diff --git a/docs/source/optimization/mps.mdx b/docs/source/optimization/mps.mdx
index 56cdbbde2818..ff9d614c870f 100644
--- a/docs/source/optimization/mps.mdx
+++ b/docs/source/optimization/mps.mdx
@@ -19,7 +19,7 @@ specific language governing permissions and limitations under the License.
- Mac computer with Apple silicon (M1/M2) hardware.
- macOS 12.3 or later.
- arm64 version of Python.
-- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.13.0.dev20220830` or later.
+- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.14.0.dev20221007` or later.
## Inference Pipeline
@@ -31,7 +31,7 @@ We recommend to "prime" the pipeline using an additional one-time pass through i
# make sure you're logged in with `huggingface-cli login`
from diffusers import StableDiffusionPipeline
-pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
+pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("mps")
prompt = "a photo of an astronaut riding a horse on mars"
diff --git a/docs/source/optimization/onnx.mdx b/docs/source/optimization/onnx.mdx
index 95fd59c86dcf..9bbc4f2077c2 100644
--- a/docs/source/optimization/onnx.mdx
+++ b/docs/source/optimization/onnx.mdx
@@ -31,7 +31,6 @@ pipe = StableDiffusionOnnxPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
revision="onnx",
provider="CUDAExecutionProvider",
- use_auth_token=True,
)
prompt = "a photo of an astronaut riding a horse on mars"
diff --git a/docs/source/training/text_inversion.mdx b/docs/source/training/text_inversion.mdx
index 8c53421e2184..13ea7c942b4e 100644
--- a/docs/source/training/text_inversion.mdx
+++ b/docs/source/training/text_inversion.mdx
@@ -74,7 +74,7 @@ Run the following command to authenticate your token
huggingface-cli login
```
-If you have already cloned the repo, then you won't need to go through these steps. You can simple remove the `--use_auth_token` arg from the following command.
+If you have already cloned the repo, then you won't need to go through these steps.
@@ -87,7 +87,7 @@ export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export DATA_DIR="path-to-dir-containing-images"
accelerate launch textual_inversion.py \
- --pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \
+ --pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$DATA_DIR \
--learnable_property="object" \
--placeholder_token="" --initializer_token="toy" \
@@ -109,7 +109,6 @@ A full training run takes ~1 hour on one V100 GPU.
Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `placeholder_token` in your prompt.
```python
-from torch import autocast
from diffusers import StableDiffusionPipeline
model_id = "path-to-your-trained-model"
@@ -117,8 +116,7 @@ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float
prompt = "A backpack"
-with autocast("cuda"):
- image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("cat-backpack.png")
```
diff --git a/docs/source/using-diffusers/custom_pipelines.mdx b/docs/source/using-diffusers/custom_pipelines.mdx
new file mode 100644
index 000000000000..b52d405581b1
--- /dev/null
+++ b/docs/source/using-diffusers/custom_pipelines.mdx
@@ -0,0 +1,121 @@
+
+
+# Custom Pipelines
+
+Diffusers allows you to conveniently load any custom pipeline from the Hugging Face Hub as well as any [official community pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community)
+via the [`DiffusionPipeline`] class.
+
+## Loading custom pipelines from the Hub
+
+Custom pipelines can be easily loaded from any model repository on the Hub that defines a diffusion pipeline in a `pipeline.py` file.
+Let's load a dummy pipeline from [hf-internal-testing/diffusers-dummy-pipeline](https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline).
+
+All you need to do is pass the custom pipeline repo id with the `custom_pipeline` argument alongside the repo from where you wish to load the pipeline modules.
+
+```python
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
+)
+```
+
+This will load the custom pipeline as defined in the [model repository](https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py).
+
+
+
+By loading a custom pipeline from the Hugging Face Hub, you are trusting that the code you are loading
+is safe 🔒. Make sure to check out the code online before loading & running it automatically.
+
+
+
+## Loading official community pipelines
+
+Community pipelines are summarized in the [community examples folder](https://github.com/huggingface/diffusers/tree/main/examples/community)
+
+Similarly, you need to pass both the *repo id* from where you wish to load the weights as well as the `custom_pipeline` argument. Here the `custom_pipeline` argument should consist simply of the filename of the community pipeline excluding the `.py` suffix, *e.g.* `clip_guided_stable_diffusion`.
+
+Since community pipelines are often more complex, one can mix loading weights from an official *repo id*
+and passing pipeline modules directly.
+
+```python
+from diffusers import DiffusionPipeline
+from transformers import CLIPFeatureExtractor, CLIPModel
+
+clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
+
+feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
+clip_model = CLIPModel.from_pretrained(clip_model_id)
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ custom_pipeline="clip_guided_stable_diffusion",
+ clip_model=clip_model,
+ feature_extractor=feature_extractor,
+)
+```
+
+## Adding custom pipelines to the Hub
+
+To add a custom pipeline to the Hub, all you need to do is to define a pipeline class that inherits
+from [`DiffusionPipeline`] in a `pipeline.py` file.
+Make sure that the whole pipeline is encapsulated within a single class and that the `pipeline.py` file
+has only one such class.
+
+Let's quickly define an example pipeline.
+
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+
+class MyPipeline(DiffusionPipeline):
+ def __init__(self, unet, scheduler):
+ super().__init__()
+
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(self, batch_size: int = 1, num_inference_steps: int = 50):
+ # Sample gaussian noise to begin loop
+ image = torch.randn((batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size))
+
+ image = image.to(self.device)
+
+ # set step values
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # 1. predict noise model_output
+ model_output = self.unet(image, t).sample
+
+ # 2. predict previous mean of image x_t-1 and add variance depending on eta
+ # eta corresponds to η in paper and should be between [0, 1]
+ # do x_t -> x_t-1
+ image = self.scheduler.step(model_output, t, image, eta).prev_sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ return image
+```
+
+Now you can upload this short file under the name `pipeline.py` in your preferred [model repository](https://huggingface.co/docs/hub/models-uploading). For Stable Diffusion pipelines, you may also [join the community organisation for shared pipelines](https://huggingface.co/organizations/sd-diffusers-pipelines-library/share/BUPyDUuHcciGTOKaExlqtfFcyCZsVFdrjr) to upload yours.
+Finally, we can load the custom pipeline by passing the model repository name, *e.g.* `sd-diffusers-pipelines-library/my_custom_pipeline` alongside the model repository from where we want to load the `unet` and `scheduler` components.
+
+```python
+my_pipeline = DiffusionPipeline.from_pretrained(
+ "google/ddpm-cifar10-32", custom_pipeline="patrickvonplaten/my_custom_pipeline"
+)
+```
diff --git a/docs/source/using-diffusers/img2img.mdx b/docs/source/using-diffusers/img2img.mdx
index e3b06871445d..62eaeea911c9 100644
--- a/docs/source/using-diffusers/img2img.mdx
+++ b/docs/source/using-diffusers/img2img.mdx
@@ -15,7 +15,7 @@ specific language governing permissions and limitations under the License.
The [`StableDiffusionImg2ImgPipeline`] lets you pass a text prompt and an initial image to condition the generation of new images.
```python
-from torch import autocast
+import torch
import requests
from PIL import Image
from io import BytesIO
@@ -25,7 +25,7 @@ from diffusers import StableDiffusionImg2ImgPipeline
# load the pipeline
device = "cuda"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
- "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=True
+ "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
).to(device)
# let's download an initial image
@@ -37,8 +37,7 @@ init_image = init_image.resize((768, 512))
prompt = "A fantasy landscape, trending on artstation"
-with autocast("cuda"):
- images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
+images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
images[0].save("fantasy_landscape.png")
```
diff --git a/docs/source/using-diffusers/inpaint.mdx b/docs/source/using-diffusers/inpaint.mdx
index 215b2c80730d..7b4687c21204 100644
--- a/docs/source/using-diffusers/inpaint.mdx
+++ b/docs/source/using-diffusers/inpaint.mdx
@@ -17,7 +17,6 @@ The [`StableDiffusionInpaintPipeline`] lets you edit specific parts of an image
```python
from io import BytesIO
-from torch import autocast
import requests
import PIL
@@ -37,12 +36,11 @@ mask_image = download_image(mask_url).resize((512, 512))
device = "cuda"
pipe = StableDiffusionInpaintPipeline.from_pretrained(
- "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=True
+ "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
).to(device)
prompt = "a cat sitting on a bench"
-with autocast("cuda"):
- images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
+images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
images[0].save("cat_on_bench.png")
```
diff --git a/examples/README.md b/examples/README.md
index 573b692c1a74..2680b638d585 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -38,7 +38,7 @@ Training examples show how to pretrain or fine-tune diffusion models for a varie
| Task | 🤗 Accelerate | 🤗 Datasets | Colab
|---|---|:---:|:---:|
-| [**Unconditional Image Generation**](https://github.com/huggingface/transformers/tree/main/examples/training/train_unconditional.py) | ✅ | ✅ | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
+| [**Unconditional Image Generation**](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py) | ✅ | ✅ | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
## Community
diff --git a/examples/community/clip_guided_stable_diffusion.py b/examples/community/clip_guided_stable_diffusion.py
index f78175735931..974f4ab2e883 100644
--- a/examples/community/clip_guided_stable_diffusion.py
+++ b/examples/community/clip_guided_stable_diffusion.py
@@ -60,7 +60,6 @@ def __init__(
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
- scheduler = scheduler.set_format("pt")
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -147,9 +146,9 @@ def cond_fn(
image = self.make_cutouts(image, num_cutouts)
else:
image = transforms.Resize(self.feature_extractor.size)(image)
- image = self.normalize(image)
+ image = self.normalize(image).to(latents.dtype)
- image_embeddings_clip = self.clip_model.get_image_features(image).float()
+ image_embeddings_clip = self.clip_model.get_image_features(image)
image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
if use_cutouts:
@@ -176,6 +175,7 @@ def __call__(
width: Optional[int] = 512,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
clip_guidance_scale: Optional[float] = 100,
clip_prompt: Optional[Union[str, List[str]]] = None,
num_cutouts: Optional[int] = 4,
@@ -204,6 +204,8 @@ def __call__(
return_tensors="pt",
)
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+ # duplicate text embeddings for each generation per prompt
+ text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
if clip_guidance_scale > 0:
if clip_prompt is not None:
@@ -218,6 +220,8 @@ def __call__(
clip_text_input = text_input.input_ids.to(self.device)
text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
+ # duplicate text embeddings clip for each generation per prompt
+ text_embeddings_clip = text_embeddings_clip.repeat_interleave(num_images_per_prompt, dim=0)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -226,10 +230,10 @@ def __call__(
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
max_length = text_input.input_ids.shape[-1]
- uncond_input = self.tokenizer(
- [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
- )
+ uncond_input = self.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+ # duplicate unconditional embeddings for each generation per prompt
+ uncond_embeddings = uncond_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
@@ -241,18 +245,20 @@ def __call__(
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
- latents_device = "cpu" if self.device.type == "mps" else self.device
- latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+ latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+ latents_dtype = text_embeddings.dtype
if latents is None:
- latents = torch.randn(
- latents_shape,
- generator=generator,
- device=latents_device,
- )
+ if self.device.type == "mps":
+ # randn does not exist on mps
+ latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+ self.device
+ )
+ else:
+ latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
- latents = latents.to(self.device)
+ latents = latents.to(self.device)
# set timesteps
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
@@ -262,19 +268,19 @@ def __call__(
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
- # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
- if isinstance(self.scheduler, LMSDiscreteScheduler):
- latents = latents * self.scheduler.sigmas[0]
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimized to move all timesteps to correct device beforehand
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
- for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
- if isinstance(self.scheduler, LMSDiscreteScheduler):
- sigma = self.scheduler.sigmas[i]
- # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
- latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
- # # predict the noise residual
+ # predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform classifier free guidance
@@ -285,7 +291,7 @@ def __call__(
# perform clip guidance
if clip_guidance_scale > 0:
text_embeddings_for_guidance = (
- text_embeddings.chunk(2)[0] if do_classifier_free_guidance else text_embeddings
+ text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings
)
noise_pred, latents = self.cond_fn(
latents,
@@ -300,10 +306,7 @@ def __call__(
)
# compute the previous noisy sample x_t -> x_t-1
- if isinstance(self.scheduler, LMSDiscreteScheduler):
- latents = self.scheduler.step(noise_pred, i, latents).prev_sample
- else:
- latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
diff --git a/examples/conftest.py b/examples/conftest.py
index a72bc85310d2..d2f9600313a1 100644
--- a/examples/conftest.py
+++ b/examples/conftest.py
@@ -32,13 +32,13 @@
def pytest_addoption(parser):
- from diffusers.testing_utils import pytest_addoption_shared
+ from diffusers.utils.testing_utils import pytest_addoption_shared
pytest_addoption_shared(parser)
def pytest_terminal_summary(terminalreporter):
- from diffusers.testing_utils import pytest_terminal_summary_main
+ from diffusers.utils.testing_utils import pytest_terminal_summary_main
make_reports = terminalreporter.config.getoption("--make-reports")
if make_reports:
diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md
new file mode 100644
index 000000000000..9ff90ea809a7
--- /dev/null
+++ b/examples/dreambooth/README.md
@@ -0,0 +1,178 @@
+# DreamBooth training example
+
+[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject.
+The `train_dreambooth.py` script shows how to implement the training procedure and adapt it for stable diffusion.
+
+
+## Running locally
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+```bash
+pip install git+https://github.com/huggingface/diffusers.git
+pip install -U -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+### Dog toy example
+
+You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.
+
+You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
+
+Run the following command to authenticate your token
+
+```bash
+huggingface-cli login
+```
+
+If you have already cloned the repo, then you won't need to go through these steps.
+
+
+
+Now let's get our dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. This will be our training data.
+
+And launch the training using
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path-to-instance-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=400
+```
+
+### Training with prior-preservation loss
+
+Prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data.
+According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases.
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800
+```
+
+### Training on a 16GB GPU:
+
+With the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes it's possible to run train dreambooth on a 16GB GPU.
+
+Install `bitsandbytes` with `pip install bitsandbytes`
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=2 --gradient_checkpointing \
+ --use_8bit_adam \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800
+```
+
+### Training on a 8 GB GPU:
+
+By using [DeepSpeed](https://www.deepspeed.ai/) it's possible to offload some
+tensors from VRAM to either CPU or NVME allowing to train with less VRAM.
+
+DeepSpeed needs to be enabled with `accelerate config`. During configuration
+answer yes to "Do you want to use DeepSpeed?". With DeepSpeed stage 2, fp16
+mixed precision and offloading both parameters and optimizer state to cpu it's
+possible to train on under 8 GB VRAM with a drawback of requiring significantly
+more RAM (about 25 GB). See [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more DeepSpeed configuration options.
+
+Changing the default Adam optimizer to DeepSpeed's special version of Adam
+`deepspeed.ops.adam.DeepSpeedCPUAdam` gives a substantial speedup but enabling
+it requires CUDA toolchain with the same version as pytorch. 8-bit optimizer
+does not seem to be compatible with DeepSpeed at the moment.
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 --gradient_checkpointing \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800 \
+ --mixed_precision=fp16
+```
+
+## Inference
+
+Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier`(e.g. sks in above example) in your prompt.
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+
+model_id = "path-to-your-trained-model"
+pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+prompt = "A photo of sks dog in a bucket"
+image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+
+image.save("dog-bucket.png")
+```
diff --git a/examples/dreambooth/requirements.txt b/examples/dreambooth/requirements.txt
new file mode 100644
index 000000000000..c0649bbe2bef
--- /dev/null
+++ b/examples/dreambooth/requirements.txt
@@ -0,0 +1,6 @@
+accelerate
+torchvision
+transformers>=4.21.0
+ftfy
+tensorboard
+modelcards
\ No newline at end of file
diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
new file mode 100644
index 000000000000..fe4741d5e2db
--- /dev/null
+++ b/examples/dreambooth/train_dreambooth.py
@@ -0,0 +1,592 @@
+import argparse
+import math
+import os
+from pathlib import Path
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch.utils.data import Dataset
+
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import set_seed
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from huggingface_hub import HfFolder, Repository, whoami
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+
+logger = get_logger(__name__)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ required=True,
+ help="A folder containing the training data of instance images.",
+ )
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ help="The prompt with identifier specifying the instance",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If not have enough images, additional images will be"
+ " sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose"
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+ "and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.instance_data_dir is None:
+ raise ValueError("You must specify a train data directory.")
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images and the tokenizes prompts.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ tokenizer,
+ class_data_root=None,
+ class_prompt=None,
+ size=512,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+ self.tokenizer = tokenizer
+
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
+ self.num_instance_images = len(self.instance_images_path)
+ self.instance_prompt = instance_prompt
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ self.class_prompt = class_prompt
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
+ if not instance_image.mode == "RGB":
+ instance_image = instance_image.convert("RGB")
+ example["instance_images"] = self.image_transforms(instance_image)
+ example["instance_prompt_ids"] = self.tokenizer(
+ self.instance_prompt,
+ padding="do_not_pad",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ ).input_ids
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_prompt_ids"] = self.tokenizer(
+ self.class_prompt,
+ padding="do_not_pad",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ ).input_ids
+
+ return example
+
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
+ if token is None:
+ token = HfFolder.get_token()
+ if organization is None:
+ username = whoami(token)["name"]
+ return f"{username}/{model_id}"
+ else:
+ return f"{organization}/{model_id}"
+
+
+def main():
+ args = parse_args()
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with="tensorboard",
+ logging_dir=logging_dir,
+ )
+
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, torch_dtype=torch_dtype
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg")
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.push_to_hub:
+ if args.hub_model_id is None:
+ repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
+ else:
+ repo_name = args.hub_model_id
+ repo = Repository(args.output_dir, clone_from=repo_name)
+
+ with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
+ if "step_*" not in gitignore:
+ gitignore.write("step_*\n")
+ if "epoch_*" not in gitignore:
+ gitignore.write("epoch_*\n")
+ elif args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # Load the tokenizer
+ if args.tokenizer_name:
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+
+ # Load models and create wrapper for stable diffusion
+ text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
+ unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ unet.parameters(), # only optimize unet
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ noise_scheduler = DDPMScheduler(
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
+ )
+
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_prompt=args.class_prompt,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ center_crop=args.center_crop,
+ )
+
+ def collate_fn(examples):
+ input_ids = [example["instance_prompt_ids"] for example in examples]
+ pixel_values = [example["instance_images"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if args.with_prior_preservation:
+ input_ids += [example["class_prompt_ids"] for example in examples]
+ pixel_values += [example["class_images"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
+
+ batch = {
+ "input_ids": input_ids,
+ "pixel_values": pixel_values,
+ }
+ return batch
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ weight_dtype = torch.float32
+ if args.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif args.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move text_encode and vae to gpu.
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("dreambooth", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+ global_step = 0
+
+ for epoch in range(args.num_train_epochs):
+ unet.train()
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ with torch.no_grad():
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+ latents = latents * 0.18215
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ with torch.no_grad():
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Predict the noise residual
+ noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+ if args.with_prior_preservation:
+ # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
+ noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+ noise, noise_prior = torch.chunk(noise, 2, dim=0)
+
+ # Compute instance loss
+ loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+
+ # Compute prior loss
+ prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
+
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+ else:
+ loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ accelerator.wait_for_everyone()
+
+ # Create the pipeline using using the trained modules and save it.
+ if accelerator.is_main_process:
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet)
+ )
+ pipeline.save_pretrained(args.output_dir)
+
+ if args.push_to_hub:
+ repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/test_examples.py b/examples/test_examples.py
index 0099d17e638d..15c8e05a5c12 100644
--- a/examples/test_examples.py
+++ b/examples/test_examples.py
@@ -24,7 +24,7 @@
from typing import List
from accelerate.utils import write_basic_config
-from diffusers.testing_utils import slow
+from diffusers.utils import slow
logging.basicConfig(level=logging.DEBUG)
@@ -40,7 +40,7 @@ class SubprocessCallException(Exception):
def run_command(command: List[str], return_stdout=False):
"""
Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture
- if an error occured while running `command`
+ if an error occurred while running `command`
"""
try:
output = subprocess.check_output(command, stderr=subprocess.STDOUT)
@@ -102,7 +102,6 @@ def test_textual_inversion(self):
test_args = f"""
examples/textual_inversion/textual_inversion.py
--pretrained_model_name_or_path CompVis/stable-diffusion-v1-4
- --use_auth_token
--train_data_dir docs/source/imgs
--learnable_property object
--placeholder_token
diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md
new file mode 100644
index 000000000000..6aca642cda4a
--- /dev/null
+++ b/examples/text_to_image/README.md
@@ -0,0 +1,101 @@
+# Stable Diffusion text-to-image fine-tuning
+
+The `train_text_to_image.py` script shows how to fine-tune stable diffusion model on your own dataset.
+
+___Note___:
+
+___This script is experimental. The script fine-tunes the whole model and often times the model overifits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparamters to get the best result on your dataset.___
+
+
+## Running locally
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+```bash
+pip install git+https://github.com/huggingface/diffusers.git
+pip install -U -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+### Pokemon example
+
+You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.
+
+You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
+
+Run the following command to authenticate your token
+
+```bash
+huggingface-cli login
+```
+
+If you have already cloned the repo, then you won't need to go through these steps.
+
+
+
+#### Hardware
+With `gradient_checkpointing` and `mixed_precision` it should be possible to fine tune the model on a single 24GB GPU. For higher `batch_size` and faster training it's better to use GPUs with >30GB memory.
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export dataset_name="lambdalabs/pokemon-blip-captions"
+
+accelerate launch train_text_to_image.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$dataset_name \
+ --use_ema \
+ --resolution=512 --center_crop --random_flip \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --mixed_precision="fp16" \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --output_dir="sd-pokemon-model"
+```
+
+
+To run on your own training files prepare the dataset according to the format required by `datasets`, you can find the instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata).
+If you wish to use custom loading logic, you should modify the script, we have left pointers for that in the training script.
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export TRAIN_DIR="path_to_your_dataset"
+
+accelerate launch train_text_to_image.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$TRAIN_DIR \
+ --use_ema \
+ --resolution=512 --center_crop --random_flip \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --mixed_precision="fp16" \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --output_dir="sd-pokemon-model"
+```
+
+Once the training is finished the model will be saved in the `output_dir` specified in the command. In this example it's `sd-pokemon-model`. To load the fine-tuned model for inference just pass that path to `StableDiffusionPipeline`
+
+
+```python
+from diffusers import StableDiffusionPipeline
+
+model_path = "path_to_saved_model"
+pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
+pipe.to("cuda")
+
+image = pipe(prompt="yoda").images[0]
+image.save("yoda-pokemon.png")
+```
diff --git a/examples/text_to_image/requirements.txt b/examples/text_to_image/requirements.txt
new file mode 100644
index 000000000000..a80836a32027
--- /dev/null
+++ b/examples/text_to_image/requirements.txt
@@ -0,0 +1,7 @@
+diffusers==0.4.1
+accelerate
+torchvision
+transformers>=4.21.0
+ftfy
+tensorboard
+modelcards
\ No newline at end of file
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
new file mode 100644
index 000000000000..e4a91ff5c8b3
--- /dev/null
+++ b/examples/text_to_image/train_text_to_image.py
@@ -0,0 +1,621 @@
+import argparse
+import copy
+import logging
+import math
+import os
+import random
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import set_seed
+from datasets import load_dataset
+from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
+from huggingface_hub import HfFolder, Repository, whoami
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+
+logger = get_logger(__name__)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ action="store_true",
+ help="Whether to center crop images before resizing to resolution (if not set, random crop will be used)",
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose"
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+ "and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+ ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+ "Only applicable when `--with_tracking` is passed."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ return args
+
+
+def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
+ if token is None:
+ token = HfFolder.get_token()
+ if organization is None:
+ username = whoami(token)["name"]
+ return f"{username}/{model_id}"
+ else:
+ return f"{organization}/{model_id}"
+
+
+dataset_name_mapping = {
+ "lambdalabs/pokemon-blip-captions": ("image", "text"),
+}
+
+
+class EMAModel:
+ """
+ Exponential Moving Average of models weights
+ """
+
+ def __init__(
+ self,
+ model,
+ decay=0.9999,
+ device=None,
+ ):
+ self.averaged_model = copy.deepcopy(model).eval()
+ self.averaged_model.requires_grad_(False)
+
+ self.decay = decay
+
+ if device is not None:
+ self.averaged_model = self.averaged_model.to(device=device)
+
+ self.optimization_step = 0
+
+ def get_decay(self, optimization_step):
+ """
+ Compute the decay factor for the exponential moving average.
+ """
+ value = (1 + optimization_step) / (10 + optimization_step)
+ return 1 - min(self.decay, value)
+
+ @torch.no_grad()
+ def step(self, new_model):
+ ema_state_dict = self.averaged_model.state_dict()
+
+ self.optimization_step += 1
+ self.decay = self.get_decay(self.optimization_step)
+
+ for key, param in new_model.named_parameters():
+ if isinstance(param, dict):
+ continue
+ try:
+ ema_param = ema_state_dict[key]
+ except KeyError:
+ ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
+ ema_state_dict[key] = ema_param
+
+ param = param.clone().detach().to(ema_param.dtype).to(ema_param.device)
+
+ if param.requires_grad:
+ ema_state_dict[key].sub_(self.decay * (ema_param - param))
+ else:
+ ema_state_dict[key].copy_(param)
+
+ for key, param in new_model.named_buffers():
+ ema_state_dict[key] = param
+
+ self.averaged_model.load_state_dict(ema_state_dict, strict=False)
+ torch.cuda.empty_cache()
+
+
+def main():
+ args = parse_args()
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ logging_dir=logging_dir,
+ )
+
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.push_to_hub:
+ if args.hub_model_id is None:
+ repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
+ else:
+ repo_name = args.hub_model_id
+ repo = Repository(args.output_dir, clone_from=repo_name)
+
+ with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
+ if "step_*" not in gitignore:
+ gitignore.write("step_*\n")
+ if "epoch_*" not in gitignore:
+ gitignore.write("epoch_*\n")
+ elif args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # Load models and create wrapper for stable diffusion
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+ text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
+ unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+
+ if args.use_ema:
+ ema_unet = EMAModel(unet)
+
+ # Freeze vae and text_encoder
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ optimizer = optimizer_cls(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # TODO (patil-suraj): load scheduler using args
+ noise_scheduler = DDPMScheduler(
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = dataset_name_mapping.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True)
+ input_ids = inputs.input_ids
+ return input_ids
+
+ train_transforms = transforms.Compose(
+ [
+ transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["pixel_values"] = [train_transforms(image) for image in images]
+ examples["input_ids"] = tokenize_captions(examples)
+
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ input_ids = [example["input_ids"] for example in examples]
+ padded_tokens = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt")
+ return {
+ "pixel_values": pixel_values,
+ "input_ids": padded_tokens.input_ids,
+ "attention_mask": padded_tokens.attention_mask,
+ }
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ weight_dtype = torch.float32
+ if args.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif args.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move text_encode and vae to gpu.
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ # Move the ema_unet to gpu.
+ ema_unet.averaged_model.to(accelerator.device)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("text2image-fine-tune", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+ global_step = 0
+
+ for epoch in range(args.num_train_epochs):
+ unet.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
+ latents = latents * 0.18215
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Predict the noise residual and compute loss
+ noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+ loss = F.mse_loss(noise_pred, noise, reduction="mean")
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_unet.step(unet)
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ pipeline = StableDiffusionPipeline(
+ text_encoder=text_encoder,
+ vae=vae,
+ unet=accelerator.unwrap_model(ema_unet.averaged_model if args.use_ema else unet),
+ tokenizer=tokenizer,
+ scheduler=PNDMScheduler(
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+ ),
+ safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
+ feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+ )
+ pipeline.save_pretrained(args.output_dir)
+
+ if args.push_to_hub:
+ repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/textual_inversion/README.md b/examples/textual_inversion/README.md
index 65b9d4958b95..05d8ffb8c9f2 100644
--- a/examples/textual_inversion/README.md
+++ b/examples/textual_inversion/README.md
@@ -17,7 +17,7 @@ Colab for inference
Before running the scripts, make sure to install the library's training dependencies:
```bash
-pip install diffusers[training] accelerate transformers
+pip install diffusers"[training]" accelerate "transformers>=4.21.0"
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
@@ -39,7 +39,7 @@ Run the following command to authenticate your token
huggingface-cli login
```
-If you have already cloned the repo, then you won't need to go through these steps. You can simple remove the `--use_auth_token` arg from the following command.
+If you have already cloned the repo, then you won't need to go through these steps.
@@ -52,7 +52,7 @@ export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export DATA_DIR="path-to-dir-containing-images"
accelerate launch textual_inversion.py \
- --pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \
+ --pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$DATA_DIR \
--learnable_property="object" \
--placeholder_token="" --initializer_token="toy" \
@@ -74,8 +74,6 @@ A full training run takes ~1 hour on one V100 GPU.
Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `placeholder_token` in your prompt.
```python
-
-from torch import autocast
from diffusers import StableDiffusionPipeline
model_id = "path-to-your-trained-model"
@@ -83,8 +81,7 @@ pipe = StableDiffusionPipeline.from_pretrained(model_id,torch_dtype=torch.float1
prompt = "A backpack"
-with autocast("cuda"):
- image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("cat-backpack.png")
```
diff --git a/examples/textual_inversion/requirements.txt b/examples/textual_inversion/requirements.txt
index cdd8abaf87da..9f8d9832dfb2 100644
--- a/examples/textual_inversion/requirements.txt
+++ b/examples/textual_inversion/requirements.txt
@@ -1,3 +1,3 @@
accelerate
torchvision
-transformers
+transformers>=4.21.0
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index de5761646a00..18469af3d4ae 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -29,8 +29,21 @@
logger = get_logger(__name__)
+def save_progress(text_encoder, placeholder_token_id, accelerator, args):
+ logger.info("Saving embeddings")
+ learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
+ learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
+ torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin"))
+
+
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--save_steps",
+ type=int,
+ default=500,
+ help="Save learned_embeds.bin every X updates steps.",
+ )
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
@@ -123,14 +136,6 @@ def parse_args():
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
- parser.add_argument(
- "--use_auth_token",
- action="store_true",
- help=(
- "Will use the token generated when running `huggingface-cli login` (necessary to use this script with"
- " private models)."
- ),
- )
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--hub_model_id",
@@ -358,9 +363,7 @@ def main():
if args.tokenizer_name:
tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
elif args.pretrained_model_name_or_path:
- tokenizer = CLIPTokenizer.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="tokenizer", use_auth_token=args.use_auth_token
- )
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
# Add the placeholder token in tokenizer
num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
@@ -380,15 +383,9 @@ def main():
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
# Load models and create wrapper for stable diffusion
- text_encoder = CLIPTextModel.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder", use_auth_token=args.use_auth_token
- )
- vae = AutoencoderKL.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="vae", use_auth_token=args.use_auth_token
- )
- unet = UNet2DConditionModel.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="unet", use_auth_token=args.use_auth_token
- )
+ text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
+ unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))
@@ -424,7 +421,10 @@ def main():
# TODO (patil-suraj): load scheduler using args
noise_scheduler = DDPMScheduler(
- beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ num_train_timesteps=1000,
)
train_dataset = TextualInversionDataset(
@@ -539,6 +539,8 @@ def main():
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
+ if global_step % args.save_steps == 0:
+ save_progress(text_encoder, placeholder_token_id, accelerator, args)
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
@@ -564,14 +566,10 @@ def main():
)
pipeline.save_pretrained(args.output_dir)
# Also save the newly trained embeddings
- learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
- learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
- torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin"))
+ save_progress(text_encoder, placeholder_token_id, accelerator, args)
if args.push_to_hub:
- repo.push_to_hub(
- args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True
- )
+ repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
accelerator.end_training()
diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py
index f6affe8a1400..8ddbdd77ba42 100644
--- a/examples/unconditional_image_generation/train_unconditional.py
+++ b/examples/unconditional_image_generation/train_unconditional.py
@@ -9,7 +9,7 @@
from accelerate.logging import get_logger
from datasets import load_dataset
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
-from diffusers.hub_utils import init_git_repo, push_to_hub
+from diffusers.hub_utils import init_git_repo
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from torchvision.transforms import (
@@ -59,7 +59,7 @@ def main(args):
"UpBlock2D",
),
)
- noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
+ noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=args.learning_rate,
@@ -83,7 +83,6 @@ def main(args):
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
- use_auth_token=True if args.use_auth_token else None,
split="train",
)
else:
@@ -143,7 +142,8 @@ def transforms(examples):
loss = F.mse_loss(noise_pred, noise)
accelerator.backward(loss)
- accelerator.clip_grad_norm_(model.parameters(), 1.0)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
lr_scheduler.step()
if args.use_ema:
@@ -185,7 +185,7 @@ def transforms(examples):
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
# save the model
if args.push_to_hub:
- push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
+ repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)
else:
pipeline.save_pretrained(args.output_dir)
accelerator.wait_for_everyone()
@@ -221,7 +221,6 @@ def transforms(examples):
parser.add_argument("--ema_power", type=float, default=3 / 4)
parser.add_argument("--ema_max_decay", type=float, default=0.9999)
parser.add_argument("--push_to_hub", action="store_true")
- parser.add_argument("--use_auth_token", action="store_true")
parser.add_argument("--hub_token", type=str, default=None)
parser.add_argument("--hub_model_id", type=str, default=None)
parser.add_argument("--hub_private_repo", action="store_true")
diff --git a/scripts/convert_diffusers_to_original_stable_diffusion.py b/scripts/convert_diffusers_to_original_stable_diffusion.py
new file mode 100644
index 000000000000..9888f628a9e3
--- /dev/null
+++ b/scripts/convert_diffusers_to_original_stable_diffusion.py
@@ -0,0 +1,234 @@
+# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
+# *Only* converts the UNet, VAE, and Text Encoder.
+# Does not convert optimizer state or any other thing.
+
+import argparse
+import os.path as osp
+
+import torch
+
+
+# =================#
+# UNet Conversion #
+# =================#
+
+unet_conversion_map = [
+ # (stable-diffusion, HF Diffusers)
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
+ ("input_blocks.0.0.weight", "conv_in.weight"),
+ ("input_blocks.0.0.bias", "conv_in.bias"),
+ ("out.0.weight", "conv_norm_out.weight"),
+ ("out.0.bias", "conv_norm_out.bias"),
+ ("out.2.weight", "conv_out.weight"),
+ ("out.2.bias", "conv_out.bias"),
+]
+
+unet_conversion_map_resnet = [
+ # (stable-diffusion, HF Diffusers)
+ ("in_layers.0", "norm1"),
+ ("in_layers.2", "conv1"),
+ ("out_layers.0", "norm2"),
+ ("out_layers.3", "conv2"),
+ ("emb_layers.1", "time_emb_proj"),
+ ("skip_connection", "conv_shortcut"),
+]
+
+unet_conversion_map_layer = []
+# hardcoded number of downblocks and resnets/attentions...
+# would need smarter logic for other networks.
+for i in range(4):
+ # loop over downblocks/upblocks
+
+ for j in range(2):
+ # loop over resnets/attentions for downblocks
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+ if i < 3:
+ # no attention layers in down_blocks.3
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+ for j in range(3):
+ # loop over resnets/attentions for upblocks
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+ if i > 0:
+ # no attention layers in up_blocks.0
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+ if i < 3:
+ # no downsample in down_blocks.3
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+ # no upsample in up_blocks.3
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+
+hf_mid_atn_prefix = "mid_block.attentions.0."
+sd_mid_atn_prefix = "middle_block.1."
+unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+for j in range(2):
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
+ sd_mid_res_prefix = f"middle_block.{2*j}."
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+
+def convert_unet_state_dict(unet_state_dict):
+ # buyer beware: this is a *brittle* function,
+ # and correct output requires that all of these pieces interact in
+ # the exact order in which I have arranged them.
+ mapping = {k: k for k in unet_state_dict.keys()}
+ for sd_name, hf_name in unet_conversion_map:
+ mapping[hf_name] = sd_name
+ for k, v in mapping.items():
+ if "resnets" in k:
+ for sd_part, hf_part in unet_conversion_map_resnet:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ for k, v in mapping.items():
+ for sd_part, hf_part in unet_conversion_map_layer:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
+ return new_state_dict
+
+
+# ================#
+# VAE Conversion #
+# ================#
+
+vae_conversion_map = [
+ # (stable-diffusion, HF Diffusers)
+ ("nin_shortcut", "conv_shortcut"),
+ ("norm_out", "conv_norm_out"),
+ ("mid.attn_1.", "mid_block.attentions.0."),
+]
+
+for i in range(4):
+ # down_blocks have two resnets
+ for j in range(2):
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
+
+ if i < 3:
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
+ sd_downsample_prefix = f"down.{i}.downsample."
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
+
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+ sd_upsample_prefix = f"up.{3-i}.upsample."
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
+
+ # up_blocks have three resnets
+ # also, up blocks in hf are numbered in reverse from sd
+ for j in range(3):
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
+
+# this part accounts for mid blocks in both the encoder and the decoder
+for i in range(2):
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
+ sd_mid_res_prefix = f"mid.block_{i+1}."
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+
+vae_conversion_map_attn = [
+ # (stable-diffusion, HF Diffusers)
+ ("norm.", "group_norm."),
+ ("q.", "query."),
+ ("k.", "key."),
+ ("v.", "value."),
+ ("proj_out.", "proj_attn."),
+]
+
+
+def reshape_weight_for_sd(w):
+ # convert HF linear weights to SD conv2d weights
+ return w.reshape(*w.shape, 1, 1)
+
+
+def convert_vae_state_dict(vae_state_dict):
+ mapping = {k: k for k in vae_state_dict.keys()}
+ for k, v in mapping.items():
+ for sd_part, hf_part in vae_conversion_map:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ for k, v in mapping.items():
+ if "attentions" in k:
+ for sd_part, hf_part in vae_conversion_map_attn:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
+ weights_to_convert = ["q", "k", "v", "proj_out"]
+ for k, v in new_state_dict.items():
+ for weight_name in weights_to_convert:
+ if f"mid.attn_1.{weight_name}.weight" in k:
+ print(f"Reshaping {k} for SD format")
+ new_state_dict[k] = reshape_weight_for_sd(v)
+ return new_state_dict
+
+
+# =========================#
+# Text Encoder Conversion #
+# =========================#
+# pretty much a no-op
+
+
+def convert_text_enc_state_dict(text_enc_dict):
+ return text_enc_dict
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
+ parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
+ parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
+
+ args = parser.parse_args()
+
+ assert args.model_path is not None, "Must provide a model path!"
+
+ assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
+
+ unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
+ vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
+ text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
+
+ # Convert the UNet model
+ unet_state_dict = torch.load(unet_path, map_location="cpu")
+ unet_state_dict = convert_unet_state_dict(unet_state_dict)
+ unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
+
+ # Convert the VAE model
+ vae_state_dict = torch.load(vae_path, map_location="cpu")
+ vae_state_dict = convert_vae_state_dict(vae_state_dict)
+ vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
+
+ # Convert the text encoder model
+ text_enc_dict = torch.load(text_enc_path, map_location="cpu")
+ text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
+ text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
+
+ # Put together new checkpoint
+ state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
+ if args.half:
+ state_dict = {k: v.half() for k, v in state_dict.items()}
+ state_dict = {"state_dict": state_dict}
+ torch.save(state_dict, args.checkpoint_path)
diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py
index ee7fc335438f..db1b30736984 100644
--- a/scripts/convert_original_stable_diffusion_to_diffusers.py
+++ b/scripts/convert_original_stable_diffusion_to_diffusers.py
@@ -595,6 +595,22 @@ def _copy_layers(hf_layers, pt_layers):
return hf_model
+def convert_ldm_clip_checkpoint(checkpoint):
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+
+ keys = list(checkpoint.keys())
+
+ text_model_dict = {}
+
+ for key in keys:
+ if key.startswith("cond_stage_model.transformer"):
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
+
+ text_model.load_state_dict(text_model_dict)
+
+ return text_model
+
+
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -668,7 +684,7 @@ def _copy_layers(hf_layers, pt_layers):
# Convert the text model.
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
if text_model_type == "FrozenCLIPEmbedder":
- text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+ text_model = convert_ldm_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
diff --git a/scripts/convert_stable_diffusion_checkpoint_to_onnx.py b/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
index beeacfe37761..a388c0d078dc 100644
--- a/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
+++ b/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
@@ -70,7 +70,7 @@ def onnx_export(
@torch.no_grad()
def convert_models(model_path: str, output_path: str, opset: int):
- pipeline = StableDiffusionPipeline.from_pretrained(model_path, use_auth_token=True)
+ pipeline = StableDiffusionPipeline.from_pretrained(model_path)
output_path = Path(output_path)
# TEXT ENCODER
@@ -206,7 +206,7 @@ def convert_models(model_path: str, output_path: str, opset: int):
parser.add_argument(
"--opset",
default=14,
- type=str,
+ type=int,
help="The version of the ONNX operator set to use.",
)
diff --git a/setup.py b/setup.py
index 20c9ea61f5f2..908dd3fa96b8 100644
--- a/setup.py
+++ b/setup.py
@@ -67,12 +67,13 @@
you need to go back to main before executing this.
"""
-import re
import os
+import re
from distutils.core import Command
from setuptools import find_packages, setup
+
# IMPORTANT:
# 1. all dependencies should be listed here with their version requirements if any
# 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
@@ -85,14 +86,14 @@
"flake8>=3.8.3",
"flax>=0.4.1",
"hf-doc-builder>=0.3.0",
- "huggingface-hub>=0.9.1",
+ "huggingface-hub>=0.10.0",
"importlib_metadata",
"isort>=5.5.4",
"jax>=0.2.8,!=0.3.2,<=0.3.6",
"jaxlib>=0.1.65,<=0.3.6",
"modelcards>=0.1.4",
"numpy",
- "onnxruntime-gpu",
+ "onnxruntime",
"pytest",
"pytest-timeout",
"pytest-xdist",
@@ -177,8 +178,9 @@ def run(self):
extras["docs"] = deps_list("hf-doc-builder")
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
extras["test"] = deps_list(
+ "accelerate",
"datasets",
- "onnxruntime-gpu",
+ "onnxruntime",
"pytest",
"pytest-timeout",
"pytest-xdist",
@@ -193,7 +195,9 @@ def run(self):
else:
extras["flax"] = deps_list("jax", "jaxlib", "flax")
-extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
+extras["dev"] = (
+ extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
+)
install_requires = [
deps["importlib_metadata"],
@@ -207,7 +211,7 @@ def run(self):
setup(
name="diffusers",
- version="0.4.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+ version="0.5.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="Diffusers",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 078ba9ba7bc1..6f01e05c8e61 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -9,7 +9,7 @@
)
-__version__ = "0.4.0.dev0"
+__version__ = "0.5.0.dev0"
from .configuration_utils import ConfigMixin
from .onnx_utils import OnnxRuntimeModel
@@ -74,6 +74,7 @@
FlaxKarrasVeScheduler,
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,
+ FlaxSchedulerMixin,
FlaxScoreSdeVeScheduler,
)
else:
diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index 19f58fd8165d..6f866c6b5fff 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -58,6 +58,10 @@ def register_to_config(self, **kwargs):
kwargs["_class_name"] = self.__class__.__name__
kwargs["_diffusers_version"] = __version__
+ # Special case for `kwargs` used in deprecation warning added to schedulers
+ # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
+ # or solve in a more general way.
+ kwargs.pop("kwargs", None)
for key, value in kwargs.items():
try:
setattr(self, key, value)
@@ -141,7 +145,8 @@ def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], ret
- Passing `use_auth_token=True`` is required when you want to use a private model.
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
@@ -234,7 +239,7 @@ def get_config_dict(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
" listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
" token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
- " login` and pass `use_auth_token=True`."
+ " login`."
)
except RevisionNotFoundError:
raise EnvironmentError(
diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py
index 09a7baad560d..8b10d70a26f7 100644
--- a/src/diffusers/dependency_versions_table.py
+++ b/src/diffusers/dependency_versions_table.py
@@ -10,7 +10,7 @@
"flake8": "flake8>=3.8.3",
"flax": "flax>=0.4.1",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
- "huggingface-hub": "huggingface-hub>=0.9.1",
+ "huggingface-hub": "huggingface-hub>=0.10.0",
"importlib_metadata": "importlib_metadata",
"isort": "isort>=5.5.4",
"jax": "jax>=0.2.8,!=0.3.2,<=0.3.6",
@@ -18,7 +18,6 @@
"modelcards": "modelcards>=0.1.4",
"numpy": "numpy",
"onnxruntime": "onnxruntime",
- "onnxruntime-gpu": "onnxruntime-gpu",
"pytest": "pytest",
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
diff --git a/src/diffusers/dynamic_modules_utils.py b/src/diffusers/dynamic_modules_utils.py
index 0ebf916e7af5..8b8e2b1de421 100644
--- a/src/diffusers/dynamic_modules_utils.py
+++ b/src/diffusers/dynamic_modules_utils.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2021 The HuggingFace Inc. team.
+# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,6 +15,7 @@
"""Utilities to dynamically load objects from the Hub."""
import importlib
+import inspect
import os
import re
import shutil
@@ -22,11 +23,16 @@
from pathlib import Path
from typing import Dict, Optional, Union
-from huggingface_hub import cached_download
+from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info
from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
+COMMUNITY_PIPELINES_URL = (
+ "https://raw.githubusercontent.com/huggingface/diffusers/main/examples/community/{pipeline}.py"
+)
+
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -145,9 +151,35 @@ def get_class_in_module(class_name, module_path):
"""
module_path = module_path.replace(os.path.sep, ".")
module = importlib.import_module(module_path)
+
+ if class_name is None:
+ return find_pipeline_class(module)
return getattr(module, class_name)
+def find_pipeline_class(loaded_module):
+ """
+ Retrieve pipeline class that inherits from `DiffusionPipeline`. Note that there has to be exactly one class
+ inheriting from `DiffusionPipeline`.
+ """
+ from .pipeline_utils import DiffusionPipeline
+
+ cls_members = dict(inspect.getmembers(loaded_module, inspect.isclass))
+
+ pipeline_class = None
+ for cls_name, cls in cls_members.items():
+ if cls_name != DiffusionPipeline.__name__ and issubclass(cls, DiffusionPipeline):
+ if pipeline_class is not None:
+ raise ValueError(
+ f"Multiple classes that inherit from {DiffusionPipeline.__name__} have been found:"
+ f" {pipeline_class.__name__}, and {cls_name}. Please make sure to define only one in"
+ f" {loaded_module}."
+ )
+ pipeline_class = cls
+
+ return pipeline_class
+
+
def get_cached_module_file(
pretrained_model_name_or_path: Union[str, os.PathLike],
module_file: str,
@@ -198,7 +230,8 @@ def get_cached_module_file(
- Passing `use_auth_token=True` is required when you want to use a private model.
+ You may pass a token in `use_auth_token` if you are not logged in (`huggingface-cli long`) and want to use private
+ or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
@@ -207,16 +240,36 @@ def get_cached_module_file(
"""
# Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
- submodule = "local"
if os.path.isfile(module_file_or_url):
resolved_module_file = module_file_or_url
+ submodule = "local"
+ elif pretrained_model_name_or_path.count("/") == 0:
+ # community pipeline on GitHub
+ github_url = COMMUNITY_PIPELINES_URL.format(pipeline=pretrained_model_name_or_path)
+ try:
+ resolved_module_file = cached_download(
+ github_url,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=False,
+ )
+ submodule = "git"
+ module_file = pretrained_model_name_or_path + ".py"
+ except EnvironmentError:
+ logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
+ raise
else:
try:
# Load from URL or cache if already cached
- resolved_module_file = cached_download(
- module_file_or_url,
+ resolved_module_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ module_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
@@ -224,7 +277,7 @@ def get_cached_module_file(
local_files_only=local_files_only,
use_auth_token=use_auth_token,
)
-
+ submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
except EnvironmentError:
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
raise
@@ -236,20 +289,55 @@ def get_cached_module_file(
full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
create_dynamic_module(full_submodule)
submodule_path = Path(HF_MODULES_CACHE) / full_submodule
- # We always copy local files (we could hash the file to see if there was a change, and give them the name of
- # that hash, to only copy when there is a modification but it seems overkill for now).
- # The only reason we do the copy is to avoid putting too many folders in sys.path.
- shutil.copy(resolved_module_file, submodule_path / module_file)
- for module_needed in modules_needed:
- module_needed = f"{module_needed}.py"
- shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
+ if submodule == "local" or submodule == "git":
+ # We always copy local files (we could hash the file to see if there was a change, and give them the name of
+ # that hash, to only copy when there is a modification but it seems overkill for now).
+ # The only reason we do the copy is to avoid putting too many folders in sys.path.
+ shutil.copy(resolved_module_file, submodule_path / module_file)
+ for module_needed in modules_needed:
+ module_needed = f"{module_needed}.py"
+ shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
+ else:
+ # Get the commit hash
+ # TODO: we will get this info in the etag soon, so retrieve it from there and not here.
+ if isinstance(use_auth_token, str):
+ token = use_auth_token
+ elif use_auth_token is True:
+ token = HfFolder.get_token()
+ else:
+ token = None
+
+ commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha
+
+ # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
+ # benefit of versioning.
+ submodule_path = submodule_path / commit_hash
+ full_submodule = full_submodule + os.path.sep + commit_hash
+ create_dynamic_module(full_submodule)
+
+ if not (submodule_path / module_file).exists():
+ shutil.copy(resolved_module_file, submodule_path / module_file)
+ # Make sure we also have every file with relative
+ for module_needed in modules_needed:
+ if not (submodule_path / module_needed).exists():
+ get_cached_module_file(
+ pretrained_model_name_or_path,
+ f"{module_needed}.py",
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ local_files_only=local_files_only,
+ )
return os.path.join(full_submodule, module_file)
def get_class_from_dynamic_module(
pretrained_model_name_or_path: Union[str, os.PathLike],
module_file: str,
- class_name: str,
+ class_name: Optional[str] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
resume_download: bool = False,
@@ -306,7 +394,8 @@ def get_class_from_dynamic_module(
- Passing `use_auth_token=True` is required when you want to use a private model.
+ You may pass a token in `use_auth_token` if you are not logged in (`huggingface-cli long`) and want to use private
+ or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py
index 7f1d65e2edc0..28cd29d2264e 100644
--- a/src/diffusers/modeling_flax_utils.py
+++ b/src/diffusers/modeling_flax_utils.py
@@ -27,8 +27,8 @@
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError
+from . import is_torch_available
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
-from .modeling_utils import load_state_dict
from .utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
@@ -357,7 +357,7 @@ def from_pretrained(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
- "login` and pass `use_auth_token=True`."
+ "login`."
)
except RevisionNotFoundError:
raise EnvironmentError(
@@ -391,6 +391,14 @@ def from_pretrained(
)
if from_pt:
+ if is_torch_available():
+ from .modeling_utils import load_state_dict
+ else:
+ raise EnvironmentError(
+ "Can't load the model in PyTorch format because PyTorch is not installed. "
+ "Please, install PyTorch or use native Flax weights."
+ )
+
# Step 1: Get the pytorch file
pytorch_model_file = load_state_dict(model_file)
@@ -436,9 +444,6 @@ def from_pretrained(
)
cls._missing_keys = missing_keys
- # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
- # matching the weights in the model.
- mismatched_keys = []
for key in state.keys():
if key in shape_state and state[key].shape != shape_state[key].shape:
raise ValueError(
@@ -466,26 +471,13 @@ def from_pretrained(
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
- elif len(mismatched_keys) == 0:
+ else:
logger.info(
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
" training."
)
- if len(mismatched_keys) > 0:
- mismatched_warning = "\n".join(
- [
- f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
- for key, shape1, shape2 in mismatched_keys
- ]
- )
- logger.warning(
- f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
- f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
- f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
- " to use it for predictions and inference."
- )
# dictionary of key: dtypes for the model params
param_dtypes = jax.tree_map(lambda x: x.dtype, state)
diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py
index 659f2ee8a66a..8bb5e728c17b 100644
--- a/src/diffusers/modeling_utils.py
+++ b/src/diffusers/modeling_utils.py
@@ -21,6 +21,7 @@
import torch
from torch import Tensor, device
+from diffusers.utils import is_accelerate_available
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError
@@ -269,7 +270,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
- Passing `use_auth_token=True`` is required when you want to use a private model.
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
@@ -293,33 +295,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
from_auto_class = kwargs.pop("_from_auto", False)
torch_dtype = kwargs.pop("torch_dtype", None)
subfolder = kwargs.pop("subfolder", None)
+ device_map = kwargs.pop("device_map", None)
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
# Load config if we don't provide a configuration
config_path = pretrained_model_name_or_path
- model, unused_kwargs = cls.from_config(
- config_path,
- cache_dir=cache_dir,
- return_unused_kwargs=True,
- force_download=force_download,
- resume_download=resume_download,
- proxies=proxies,
- local_files_only=local_files_only,
- use_auth_token=use_auth_token,
- revision=revision,
- subfolder=subfolder,
- **kwargs,
- )
- if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
- raise ValueError(
- f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
- )
- elif torch_dtype is not None:
- model = model.to(torch_dtype)
-
- model.register_to_config(_name_or_path=pretrained_model_name_or_path)
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# Load model
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
@@ -357,7 +339,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
- "login` and pass `use_auth_token=True`."
+ "login`."
)
except RevisionNotFoundError:
raise EnvironmentError(
@@ -391,25 +373,81 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
)
# restore default dtype
- state_dict = load_state_dict(model_file)
- model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
- model,
- state_dict,
- model_file,
- pretrained_model_name_or_path,
- ignore_mismatched_sizes=ignore_mismatched_sizes,
- )
- # Set model in evaluation mode to deactivate DropOut modules by default
- model.eval()
+ if device_map == "auto":
+ if is_accelerate_available():
+ import accelerate
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ with accelerate.init_empty_weights():
+ model, unused_kwargs = cls.from_config(
+ config_path,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ device_map=device_map,
+ **kwargs,
+ )
+
+ accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
+
+ loading_info = {
+ "missing_keys": [],
+ "unexpected_keys": [],
+ "mismatched_keys": [],
+ "error_msgs": [],
+ }
+ else:
+ model, unused_kwargs = cls.from_config(
+ config_path,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ device_map=device_map,
+ **kwargs,
+ )
+
+ state_dict = load_state_dict(model_file)
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
+ model,
+ state_dict,
+ model_file,
+ pretrained_model_name_or_path,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ )
- if output_loading_info:
loading_info = {
"missing_keys": missing_keys,
"unexpected_keys": unexpected_keys,
"mismatched_keys": mismatched_keys,
"error_msgs": error_msgs,
}
+
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+ raise ValueError(
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+ )
+ elif torch_dtype is not None:
+ model = model.to(torch_dtype)
+
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+
+ # Set model in evaluation mode to deactivate DropOut modules by default
+ model.eval()
+ if output_loading_info:
return model, loading_info
return model
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 25e1ea28dcf0..c2f27bd9282d 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -72,8 +72,7 @@ def forward(self, hidden_states):
# get scores
scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
-
- attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
+ attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) # TODO: use baddmm
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
# compute attention output
@@ -144,10 +143,11 @@ def forward(self, hidden_states, context=None):
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
for block in self.transformer_blocks:
hidden_states = block(hidden_states, context=context)
- hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
hidden_states = self.proj_out(hidden_states)
return hidden_states + residual
@@ -274,6 +274,7 @@ def forward(self, hidden_states, context=None, mask=None):
return self.to_out(hidden_states)
def _attention(self, query, key, value):
+ # TODO: use baddbmm for better performance
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
attention_probs = attention_scores.softmax(dim=-1)
# compute attention output
@@ -291,7 +292,9 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
for i in range(hidden_states.shape[0] // slice_size):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size
- attn_slice = torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
+ attn_slice = (
+ torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
+ ) # TODO: use baddbmm for better performance
attn_slice = attn_slice.softmax(dim=-1)
attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 86ac074c1d0e..06b814e2bbcd 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -37,10 +37,12 @@ def get_timestep_embedding(
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
- exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32)
+ exponent = -math.log(max_period) * torch.arange(
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
+ )
exponent = exponent / (half_dim - downscale_freq_shift)
- emb = torch.exp(exponent).to(device=timesteps.device)
+ emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 97f3c02a8ccf..d4cb367ebc0b 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -9,9 +9,10 @@ class Upsample2D(nn.Module):
"""
An upsampling layer with an optional convolution.
- :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
- applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
- upsampling occurs in the inner-two dimensions.
+ Parameters:
+ channels: channels in the inputs and outputs.
+ use_conv: a bool determining if a convolution is applied.
+ dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
@@ -34,30 +35,48 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann
else:
self.Conv2d_0 = conv
- def forward(self, x):
- assert x.shape[1] == self.channels
+ def forward(self, hidden_states, output_size=None):
+ assert hidden_states.shape[1] == self.channels
+
if self.use_conv_transpose:
- return self.conv(x)
+ return self.conv(hidden_states)
+
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
+ # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
+ # https://github.com/pytorch/pytorch/issues/86679
+ dtype = hidden_states.dtype
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(torch.float32)
+
+ # if `output_size` is passed we force the interpolation output
+ # size and do not make use of `scale_factor=2`
+ if output_size is None:
+ hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+ else:
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
- x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+ # If the input is bfloat16, we cast back to bfloat16
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(dtype)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if self.use_conv:
if self.name == "conv":
- x = self.conv(x)
+ hidden_states = self.conv(hidden_states)
else:
- x = self.Conv2d_0(x)
+ hidden_states = self.Conv2d_0(hidden_states)
- return x
+ return hidden_states
class Downsample2D(nn.Module):
"""
A downsampling layer with an optional convolution.
- :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
- applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
- downsampling occurs in the inner-two dimensions.
+ Parameters:
+ channels: channels in the inputs and outputs.
+ use_conv: a bool determining if a convolution is applied.
+ dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
@@ -84,16 +103,16 @@ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name=
else:
self.conv = conv
- def forward(self, x):
- assert x.shape[1] == self.channels
+ def forward(self, hidden_states):
+ assert hidden_states.shape[1] == self.channels
if self.use_conv and self.padding == 0:
pad = (0, 1, 0, 1)
- x = F.pad(x, pad, mode="constant", value=0)
+ hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
- assert x.shape[1] == self.channels
- x = self.conv(x)
+ assert hidden_states.shape[1] == self.channels
+ hidden_states = self.conv(hidden_states)
- return x
+ return hidden_states
class FirUpsample2D(nn.Module):
@@ -106,24 +125,25 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=
self.fir_kernel = fir_kernel
self.out_channels = out_channels
- def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
+ def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
"""Fused `upsample_2d()` followed by `Conv2d()`.
- Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
- efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
- order.
- x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
- C]`.
- weight: Weight tensor of the shape `[filterH, filterW, inChannels,
- outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
- kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
- (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
- factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
+ arbitrary order.
+
+ Args:
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ weight: Weight tensor of the shape `[filterH, filterW, inChannels,
+ outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
+ factor: Integer upsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
- Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
- `x`.
+ output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
+ datatype as `hidden_states`.
"""
assert isinstance(factor, int) and factor >= 1
@@ -145,41 +165,52 @@ def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
convW = weight.shape[3]
inC = weight.shape[1]
- p = (kernel.shape[0] - factor) - (convW - 1)
+ pad_value = (kernel.shape[0] - factor) - (convW - 1)
stride = (factor, factor)
# Determine data dimensions.
- output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
+ output_shape = (
+ (hidden_states.shape[2] - 1) * factor + convH,
+ (hidden_states.shape[3] - 1) * factor + convW,
+ )
output_padding = (
- output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
- output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
+ output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
+ output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
)
assert output_padding[0] >= 0 and output_padding[1] >= 0
- inC = weight.shape[1]
- num_groups = x.shape[1] // inC
+ num_groups = hidden_states.shape[1] // inC
# Transpose weights.
weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
- x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0)
+ inverse_conv = F.conv_transpose2d(
+ hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
+ )
- x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
+ output = upfirdn2d_native(
+ inverse_conv,
+ torch.tensor(kernel, device=inverse_conv.device),
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
+ )
else:
- p = kernel.shape[0] - factor
- x = upfirdn2d_native(
- x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
+ pad_value = kernel.shape[0] - factor
+ output = upfirdn2d_native(
+ hidden_states,
+ torch.tensor(kernel, device=hidden_states.device),
+ up=factor,
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
)
- return x
+ return output
- def forward(self, x):
+ def forward(self, hidden_states):
if self.use_conv:
- height = self._upsample_2d(x, self.Conv2d_0.weight, kernel=self.fir_kernel)
+ height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
- height = self._upsample_2d(x, kernel=self.fir_kernel, factor=2)
+ height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
return height
@@ -194,22 +225,25 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=
self.use_conv = use_conv
self.out_channels = out_channels
- def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
+ def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
"""Fused `Conv2d()` followed by `downsample_2d()`.
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
+ arbitrary order.
Args:
- Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
- efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
- order.
- x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH,
- filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] //
- numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
- factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain:
- Scaling factor for signal magnitude (default: 1.0).
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ weight:
+ Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
+ performed by `inChannels = x.shape[0] // numGroups`.
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
+ factor`, which corresponds to average pooling.
+ factor: Integer downsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
- Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
- datatype as `x`.
+ output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
+ same datatype as `x`.
"""
assert isinstance(factor, int) and factor >= 1
@@ -226,24 +260,33 @@ def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
if self.use_conv:
_, _, convH, convW = weight.shape
- p = (kernel.shape[0] - factor) + (convW - 1)
- s = [factor, factor]
- x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2, p // 2))
- x = F.conv2d(x, weight, stride=s, padding=0)
+ pad_value = (kernel.shape[0] - factor) + (convW - 1)
+ stride_value = [factor, factor]
+ upfirdn_input = upfirdn2d_native(
+ hidden_states,
+ torch.tensor(kernel, device=hidden_states.device),
+ pad=((pad_value + 1) // 2, pad_value // 2),
+ )
+ output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
else:
- p = kernel.shape[0] - factor
- x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
+ pad_value = kernel.shape[0] - factor
+ output = upfirdn2d_native(
+ hidden_states,
+ torch.tensor(kernel, device=hidden_states.device),
+ down=factor,
+ pad=((pad_value + 1) // 2, pad_value // 2),
+ )
- return x
+ return output
- def forward(self, x):
+ def forward(self, hidden_states):
if self.use_conv:
- x = self._downsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
- x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+ downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
+ hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
- x = self._downsample_2d(x, kernel=self.fir_kernel, factor=2)
+ hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
- return x
+ return hidden_states
class ResnetBlock2D(nn.Module):
@@ -326,19 +369,17 @@ def __init__(
if self.use_in_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
- def forward(self, x, temb):
- hidden_states = x
+ def forward(self, input_tensor, temb):
+ hidden_states = input_tensor
- # make sure hidden states is in float32
- # when running in half-precision
- hidden_states = self.norm1(hidden_states).type(hidden_states.dtype)
+ hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
if self.upsample is not None:
- x = self.upsample(x)
+ input_tensor = self.upsample(input_tensor)
hidden_states = self.upsample(hidden_states)
elif self.downsample is not None:
- x = self.downsample(x)
+ input_tensor = self.downsample(input_tensor)
hidden_states = self.downsample(hidden_states)
hidden_states = self.conv1(hidden_states)
@@ -347,43 +388,41 @@ def forward(self, x, temb):
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
hidden_states = hidden_states + temb
- # make sure hidden states is in float32
- # when running in half-precision
- hidden_states = self.norm2(hidden_states).type(hidden_states.dtype)
+ hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
- x = self.conv_shortcut(x)
+ input_tensor = self.conv_shortcut(input_tensor)
- out = (x + hidden_states) / self.output_scale_factor
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
- return out
+ return output_tensor
class Mish(torch.nn.Module):
- def forward(self, x):
- return x * torch.tanh(torch.nn.functional.softplus(x))
+ def forward(self, hidden_states):
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
-def upsample_2d(x, kernel=None, factor=2, gain=1):
+def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
r"""Upsample2D a batch of 2D images with the given filter.
-
- Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
- `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a:
- multiple of the upsampling factor.
- x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
- C]`.
- k: FIR filter of the shape `[firH, firW]` or `[firN]`
+ `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
+ a: multiple of the upsampling factor.
+
+ Args:
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
- factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
+ factor: Integer upsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
- Tensor of the shape `[N, C, H * factor, W * factor]`
+ output: Tensor of the shape `[N, C, H * factor, W * factor]`
"""
assert isinstance(factor, int) and factor >= 1
if kernel is None:
@@ -395,26 +434,32 @@ def upsample_2d(x, kernel=None, factor=2, gain=1):
kernel /= torch.sum(kernel)
kernel = kernel * (gain * (factor**2))
- p = kernel.shape[0] - factor
- return upfirdn2d_native(x, kernel.to(device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
+ pad_value = kernel.shape[0] - factor
+ output = upfirdn2d_native(
+ hidden_states,
+ kernel.to(device=hidden_states.device),
+ up=factor,
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
+ )
+ return output
-def downsample_2d(x, kernel=None, factor=2, gain=1):
+def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
r"""Downsample2D a batch of 2D images with the given filter.
-
- Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
shape is a multiple of the downsampling factor.
- x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
- C]`.
+
+ Args:
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to average pooling.
- factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
+ factor: Integer downsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
- Tensor of the shape `[N, C, H // factor, W // factor]`
+ output: Tensor of the shape `[N, C, H // factor, W // factor]`
"""
assert isinstance(factor, int) and factor >= 1
@@ -427,32 +472,35 @@ def downsample_2d(x, kernel=None, factor=2, gain=1):
kernel /= torch.sum(kernel)
kernel = kernel * gain
- p = kernel.shape[0] - factor
- return upfirdn2d_native(x, kernel.to(device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
+ pad_value = kernel.shape[0] - factor
+ output = upfirdn2d_native(
+ hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
+ )
+ return output
-def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
+def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
up_x = up_y = up
down_x = down_y = down
pad_x0 = pad_y0 = pad[0]
pad_x1 = pad_y1 = pad[1]
- _, channel, in_h, in_w = input.shape
- input = input.reshape(-1, in_h, in_w, 1)
+ _, channel, in_h, in_w = tensor.shape
+ tensor = tensor.reshape(-1, in_h, in_w, 1)
- _, in_h, in_w, minor = input.shape
+ _, in_h, in_w, minor = tensor.shape
kernel_h, kernel_w = kernel.shape
- out = input.view(-1, in_h, 1, in_w, 1, minor)
+ out = tensor.view(-1, in_h, 1, in_w, 1, minor)
# Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
- if input.device.type == "mps":
+ if tensor.device.type == "mps":
out = out.to("cpu")
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
- out = out.to(input.device) # Move back to mps if necessary
+ out = out.to(tensor.device) # Move back to mps if necessary
out = out[
:,
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py
index 89321a5503a9..2415bf4ee78d 100644
--- a/src/diffusers/models/unet_2d.py
+++ b/src/diffusers/models/unet_2d.py
@@ -170,7 +170,7 @@ def forward(
timestep: Union[torch.Tensor, float, int],
return_dict: bool = True,
) -> Union[UNet2DOutput, Tuple]:
- """r
+ r"""
Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 5e3ee091c311..4e4eaddf5dfe 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -7,7 +7,7 @@
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
-from ..utils import BaseOutput
+from ..utils import BaseOutput, logging
from .embeddings import TimestepEmbedding, Timesteps
from .unet_blocks import (
CrossAttnDownBlock2D,
@@ -20,6 +20,9 @@
)
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
@dataclass
class UNet2DConditionOutput(BaseOutput):
"""
@@ -145,15 +148,25 @@ def __init__(
resnet_groups=norm_num_groups,
)
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
- is_final_block = i == len(block_out_channels) - 1
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
up_block = get_up_block(
up_block_type,
@@ -162,7 +175,7 @@ def __init__(
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
- add_upsample=not is_final_block,
+ add_upsample=add_upsample,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
@@ -210,7 +223,7 @@ def forward(
encoder_hidden_states: torch.Tensor,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
- """r
+ r"""
Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
@@ -223,6 +236,20 @@ def forward(
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ logger.info("Forward upsample size to force interpolation output size.")
+ forward_upsample_size = True
+
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
@@ -230,15 +257,20 @@ def forward(
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
- timesteps = timesteps.to(dtype=torch.float32)
- timesteps = timesteps[None].to(device=sample.device)
+ timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
# 2. pre-process
@@ -262,24 +294,31 @@ def forward(
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
# 5. up
- for upsample_block in self.up_blocks:
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
+ upsample_size=upsample_size,
)
else:
- sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)
-
+ sample = upsample_block(
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+ )
# 6. post-process
- # make sure hidden states is in float32
- # when running in half-precision
- sample = self.conv_norm_out(sample.float()).type(sample.dtype)
+ sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py
index 9e84da068757..6c83bbdbd08a 100644
--- a/src/diffusers/models/unet_2d_condition_flax.py
+++ b/src/diffusers/models/unet_2d_condition_flax.py
@@ -215,7 +215,7 @@ def __call__(
return_dict: bool = True,
train: bool = False,
) -> Union[FlaxUNet2DConditionOutput, Tuple]:
- """r
+ r"""
Args:
sample (`jnp.ndarray`): (channel, height, width) noisy inputs tensor
timestep (`jnp.ndarray` or `float` or `int`): timesteps
diff --git a/src/diffusers/models/unet_blocks.py b/src/diffusers/models/unet_blocks.py
index f42389b98562..a17b1d2a5333 100644
--- a/src/diffusers/models/unet_blocks.py
+++ b/src/diffusers/models/unet_blocks.py
@@ -1126,6 +1126,7 @@ def forward(
res_hidden_states_tuple,
temb=None,
encoder_hidden_states=None,
+ upsample_size=None,
):
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
@@ -1151,7 +1152,7 @@ def custom_forward(*inputs):
if self.upsamplers is not None:
for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states)
+ hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
@@ -1204,7 +1205,7 @@ def __init__(
self.gradient_checkpointing = False
- def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
@@ -1225,7 +1226,7 @@ def custom_forward(*inputs):
if self.upsamplers is not None:
for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states)
+ hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py
index fe89b41c074e..7ce2f98eee27 100644
--- a/src/diffusers/models/vae.py
+++ b/src/diffusers/models/vae.py
@@ -337,12 +337,16 @@ def __init__(self, parameters, deterministic=False):
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
- self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+ self.var = self.std = torch.zeros_like(
+ self.mean, device=self.parameters.device, dtype=self.parameters.dtype
+ )
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
device = self.parameters.device
sample_device = "cpu" if device.type == "mps" else device
- sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device)
+ sample = torch.randn(self.mean.shape, generator=generator, device=sample_device)
+ # make sure sample is on the same device as the parameters and has same dtype
+ sample = sample.to(device=device, dtype=self.parameters.dtype)
x = self.mean + self.std * sample
return x
diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py
index b3261b11cf6c..074133a05c4a 100644
--- a/src/diffusers/models/vae_flax.py
+++ b/src/diffusers/models/vae_flax.py
@@ -119,6 +119,8 @@ class FlaxResnetBlock2D(nn.Module):
Output channels
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
+ groups (:obj:`int`, *optional*, defaults to `32`):
+ The number of groups to use for group norm.
use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`):
Whether to use `nin_shortcut`. This activates a new layer inside ResNet block
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -128,13 +130,14 @@ class FlaxResnetBlock2D(nn.Module):
in_channels: int
out_channels: int = None
dropout: float = 0.0
+ groups: int = 32
use_nin_shortcut: bool = None
dtype: jnp.dtype = jnp.float32
def setup(self):
out_channels = self.in_channels if self.out_channels is None else self.out_channels
- self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+ self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
self.conv1 = nn.Conv(
out_channels,
kernel_size=(3, 3),
@@ -143,7 +146,7 @@ def setup(self):
dtype=self.dtype,
)
- self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+ self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
self.dropout_layer = nn.Dropout(self.dropout)
self.conv2 = nn.Conv(
out_channels,
@@ -191,12 +194,15 @@ class FlaxAttentionBlock(nn.Module):
Input channels
num_head_channels (:obj:`int`, *optional*, defaults to `None`):
Number of attention heads
+ num_groups (:obj:`int`, *optional*, defaults to `32`):
+ The number of groups to use for group norm
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
channels: int
num_head_channels: int = None
+ num_groups: int = 32
dtype: jnp.dtype = jnp.float32
def setup(self):
@@ -204,7 +210,7 @@ def setup(self):
dense = partial(nn.Dense, self.channels, dtype=self.dtype)
- self.group_norm = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+ self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6)
self.query, self.key, self.value = dense(), dense(), dense()
self.proj_attn = dense()
@@ -264,6 +270,8 @@ class FlaxDownEncoderBlock2D(nn.Module):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of Resnet layer block
+ resnet_groups (:obj:`int`, *optional*, defaults to `32`):
+ The number of groups to use for the Resnet block group norm
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add downsample layer
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -273,6 +281,7 @@ class FlaxDownEncoderBlock2D(nn.Module):
out_channels: int
dropout: float = 0.0
num_layers: int = 1
+ resnet_groups: int = 32
add_downsample: bool = True
dtype: jnp.dtype = jnp.float32
@@ -285,6 +294,7 @@ def setup(self):
in_channels=in_channels,
out_channels=self.out_channels,
dropout=self.dropout,
+ groups=self.resnet_groups,
dtype=self.dtype,
)
resnets.append(res_block)
@@ -303,9 +313,9 @@ def __call__(self, hidden_states, deterministic=True):
return hidden_states
-class FlaxUpEncoderBlock2D(nn.Module):
+class FlaxUpDecoderBlock2D(nn.Module):
r"""
- Flax Resnet blocks-based Encoder block for diffusion-based VAE.
+ Flax Resnet blocks-based Decoder block for diffusion-based VAE.
Parameters:
in_channels (:obj:`int`):
@@ -316,8 +326,10 @@ class FlaxUpEncoderBlock2D(nn.Module):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of Resnet layer block
- add_downsample (:obj:`bool`, *optional*, defaults to `True`):
- Whether to add downsample layer
+ resnet_groups (:obj:`int`, *optional*, defaults to `32`):
+ The number of groups to use for the Resnet block group norm
+ add_upsample (:obj:`bool`, *optional*, defaults to `True`):
+ Whether to add upsample layer
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
@@ -325,6 +337,7 @@ class FlaxUpEncoderBlock2D(nn.Module):
out_channels: int
dropout: float = 0.0
num_layers: int = 1
+ resnet_groups: int = 32
add_upsample: bool = True
dtype: jnp.dtype = jnp.float32
@@ -336,6 +349,7 @@ def setup(self):
in_channels=in_channels,
out_channels=self.out_channels,
dropout=self.dropout,
+ groups=self.resnet_groups,
dtype=self.dtype,
)
resnets.append(res_block)
@@ -366,6 +380,8 @@ class FlaxUNetMidBlock2D(nn.Module):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of Resnet layer block
+ resnet_groups (:obj:`int`, *optional*, defaults to `32`):
+ The number of groups to use for the Resnet and Attention block group norm
attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`):
Number of attention heads for each attention block
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -374,16 +390,20 @@ class FlaxUNetMidBlock2D(nn.Module):
in_channels: int
dropout: float = 0.0
num_layers: int = 1
+ resnet_groups: int = 32
attn_num_head_channels: int = 1
dtype: jnp.dtype = jnp.float32
def setup(self):
+ resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)
+
# there is always at least one resnet
resnets = [
FlaxResnetBlock2D(
in_channels=self.in_channels,
out_channels=self.in_channels,
dropout=self.dropout,
+ groups=resnet_groups,
dtype=self.dtype,
)
]
@@ -392,7 +412,10 @@ def setup(self):
for _ in range(self.num_layers):
attn_block = FlaxAttentionBlock(
- channels=self.in_channels, num_head_channels=self.attn_num_head_channels, dtype=self.dtype
+ channels=self.in_channels,
+ num_head_channels=self.attn_num_head_channels,
+ num_groups=resnet_groups,
+ dtype=self.dtype,
)
attentions.append(attn_block)
@@ -400,6 +423,7 @@ def setup(self):
in_channels=self.in_channels,
out_channels=self.in_channels,
dropout=self.dropout,
+ groups=resnet_groups,
dtype=self.dtype,
)
resnets.append(res_block)
@@ -441,7 +465,7 @@ class FlaxEncoder(nn.Module):
Tuple containing the number of output channels for each block
layers_per_block (:obj:`int`, *optional*, defaults to `2`):
Number of Resnet layer for each block
- norm_num_groups (:obj:`int`, *optional*, defaults to `2`):
+ norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
norm num group
act_fn (:obj:`str`, *optional*, defaults to `silu`):
Activation function
@@ -483,6 +507,7 @@ def setup(self):
in_channels=input_channel,
out_channels=output_channel,
num_layers=self.layers_per_block,
+ resnet_groups=self.norm_num_groups,
add_downsample=not is_final_block,
dtype=self.dtype,
)
@@ -491,12 +516,15 @@ def setup(self):
# middle
self.mid_block = FlaxUNetMidBlock2D(
- in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
+ in_channels=block_out_channels[-1],
+ resnet_groups=self.norm_num_groups,
+ attn_num_head_channels=None,
+ dtype=self.dtype,
)
# end
conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels
- self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+ self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
self.conv_out = nn.Conv(
conv_out_channels,
kernel_size=(3, 3),
@@ -581,7 +609,10 @@ def setup(self):
# middle
self.mid_block = FlaxUNetMidBlock2D(
- in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
+ in_channels=block_out_channels[-1],
+ resnet_groups=self.norm_num_groups,
+ attn_num_head_channels=None,
+ dtype=self.dtype,
)
# upsampling
@@ -594,10 +625,11 @@ def setup(self):
is_final_block = i == len(block_out_channels) - 1
- up_block = FlaxUpEncoderBlock2D(
+ up_block = FlaxUpDecoderBlock2D(
in_channels=prev_output_channel,
out_channels=output_channel,
num_layers=self.layers_per_block + 1,
+ resnet_groups=self.norm_num_groups,
add_upsample=not is_final_block,
dtype=self.dtype,
)
@@ -607,7 +639,7 @@ def setup(self):
self.up_blocks = up_blocks
# end
- self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+ self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
self.conv_out = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
diff --git a/src/diffusers/onnx_utils.py b/src/diffusers/onnx_utils.py
index 2282f411aed8..142174f6e101 100644
--- a/src/diffusers/onnx_utils.py
+++ b/src/diffusers/onnx_utils.py
@@ -79,8 +79,10 @@ def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional
src_path = self.model_save_dir.joinpath(self.latest_model_name)
dst_path = Path(save_directory).joinpath(model_file_name)
- if not src_path.samefile(dst_path):
+ try:
shutil.copyfile(src_path, dst_path)
+ except shutil.SameFileError:
+ pass
def save_pretrained(
self,
diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py
index 9ea94ee0f2e1..92b71caeb343 100644
--- a/src/diffusers/pipeline_flax_utils.py
+++ b/src/diffusers/pipeline_flax_utils.py
@@ -30,7 +30,7 @@
from .configuration_utils import ConfigMixin
from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
-from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin
+from .schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging
@@ -46,7 +46,7 @@
LOADABLE_CLASSES = {
"diffusers": {
"FlaxModelMixin": ["save_pretrained", "from_pretrained"],
- "SchedulerMixin": ["save_config", "from_config"],
+ "FlaxSchedulerMixin": ["save_config", "from_config"],
"FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"],
},
"transformers": {
@@ -249,8 +249,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
- Passing `use_auth_token=True`` is required when you want to use a private model, *e.g.*
- `"CompVis/stable-diffusion-v1-4"`
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+ models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"CompVis/stable-diffusion-v1-4"`
@@ -272,15 +272,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
>>> # Download pipeline that requires an authorization token
>>> # For more information on access tokens, please refer to this section
>>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
- >>> pipeline = FlaxDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
+ >>> pipeline = FlaxDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
>>> # Download pipeline, but overwrite scheduler
>>> from diffusers import LMSDiscreteScheduler
>>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
- >>> pipeline = FlaxDiffusionPipeline.from_pretrained(
- ... "CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True
- ... )
+ >>> pipeline = FlaxDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler)
```
"""
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
@@ -436,7 +434,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
else:
loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
params[name] = loaded_params
- elif issubclass(class_obj, SchedulerMixin):
+ elif issubclass(class_obj, FlaxSchedulerMixin):
loaded_sub_model, scheduler_state = load_method(loadable_folder)
params[name] = scheduler_state
else:
diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index fb8801bc959a..81118967aade 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -30,11 +30,25 @@
from tqdm.auto import tqdm
from .configuration_utils import ConfigMixin
+from .dynamic_modules_utils import get_class_from_dynamic_module
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
-from .utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
+from .utils import (
+ CONFIG_NAME,
+ DIFFUSERS_CACHE,
+ ONNX_WEIGHTS_NAME,
+ WEIGHTS_NAME,
+ BaseOutput,
+ is_transformers_available,
+ logging,
+)
+
+
+if is_transformers_available():
+ from transformers import PreTrainedModel
INDEX_FILE = "diffusion_pytorch_model.bin"
+CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
logger = logging.get_logger(__name__)
@@ -166,6 +180,14 @@ def to(self, torch_device: Optional[Union[str, torch.device]] = None):
for name in module_names.keys():
module = getattr(self, name)
if isinstance(module, torch.nn.Module):
+ if module.dtype == torch.float16 and str(torch_device) in ["cpu", "mps"]:
+ logger.warning(
+ "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` or `mps` device. It"
+ " is not recommended to move them to `cpu` or `mps` as running them will fail. Please make"
+ " sure to use a `cuda` device to run the pipeline in inference. due to the lack of support for"
+ " `float16` operations on those devices in PyTorch. Please remove the"
+ " `torch_dtype=torch.float16` argument, or use a `cuda` device to run inference."
+ )
module.to(torch_device)
return self
@@ -208,6 +230,52 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
will be automatically derived from the model's weights.
+ custom_pipeline (`str`, *optional*):
+
+
+
+ This is an experimental feature and is likely to change in the future.
+
+
+
+ Can be either:
+
+ - A string, the *repo id* of a custom pipeline hosted inside a model repo on
+ https://huggingface.co/. Valid repo ids have to be located under a user or organization name,
+ like `hf-internal-testing/diffusers-dummy-pipeline`.
+
+
+
+ It is required that the model repo has a file, called `pipeline.py` that defines the custom
+ pipeline.
+
+
+
+ - A string, the *file name* of a community pipeline hosted on GitHub under
+ https://github.com/huggingface/diffusers/tree/main/examples/community. Valid file names have to
+ match exactly the file name without `.py` located under the above link, *e.g.*
+ `clip_guided_stable_diffusion`.
+
+
+
+ Community pipelines are always loaded from the current `main` branch of GitHub.
+
+
+
+ - A path to a *directory* containing a custom pipeline, e.g., `./my_pipeline_directory/`.
+
+
+
+ It is required that the directory has a file, called `pipeline.py` that defines the custom
+ pipeline.
+
+
+
+ For more information on how to load and create custom pipelines, please have a look at [Loading and
+ Creating Custom
+ Pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/custom_pipelines)
+
+ torch_dtype (`str` or `torch.dtype`, *optional*):
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
@@ -240,8 +308,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
- Passing `use_auth_token=True`` is required when you want to use a private model, *e.g.*
- `"CompVis/stable-diffusion-v1-4"`
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+ models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"CompVis/stable-diffusion-v1-4"`
@@ -263,15 +331,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
>>> # Download pipeline that requires an authorization token
>>> # For more information on access tokens, please refer to this section
>>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
- >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
+ >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
>>> # Download pipeline, but overwrite scheduler
>>> from diffusers import LMSDiscreteScheduler
>>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
- >>> pipeline = DiffusionPipeline.from_pretrained(
- ... "CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True
- ... )
+ >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler)
```
"""
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
@@ -281,8 +347,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
+ custom_pipeline = kwargs.pop("custom_pipeline", None)
provider = kwargs.pop("provider", None)
sess_options = kwargs.pop("sess_options", None)
+ device_map = kwargs.pop("device_map", None)
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
@@ -301,6 +369,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
allow_patterns = [os.path.join(k, "*") for k in folder_names]
allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
+ if custom_pipeline is not None:
+ allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
+
# download all allow_patterns
cached_folder = snapshot_download(
pretrained_model_name_or_path,
@@ -319,7 +390,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# 2. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
- if cls != DiffusionPipeline:
+ if custom_pipeline is not None:
+ pipeline_class = get_class_from_dynamic_module(
+ custom_pipeline, module_file=CUSTOM_PIPELINE_FILE_NAME, cache_dir=custom_pipeline
+ )
+ elif cls != DiffusionPipeline:
pipeline_class = cls
else:
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
@@ -328,7 +403,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# some modules can be passed directly to the init
# in this case they are already instantiated in `kwargs`
# extract them here
- expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
+ expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"])
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
@@ -401,6 +476,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
loading_kwargs["provider"] = provider
loading_kwargs["sess_options"] = sess_options
+ if (
+ issubclass(class_obj, diffusers.ModelMixin)
+ or is_transformers_available()
+ and issubclass(class_obj, PreTrainedModel)
+ ):
+ loading_kwargs["device_map"] = device_map
+
# check if the module is in a subdirectory
if os.path.isdir(os.path.join(cached_folder, name)):
loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
@@ -410,7 +492,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
- # 4. Instantiate the pipeline
+ # 4. Potentially add passed objects if expected
+ missing_modules = set(expected_modules) - set(init_kwargs.keys())
+ if len(missing_modules) > 0 and missing_modules <= set(passed_class_obj.keys()):
+ for module in missing_modules:
+ init_kwargs[module] = passed_class_obj[module]
+ elif len(missing_modules) > 0:
+ passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys()))
+ raise ValueError(
+ f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
+ )
+
+ # 5. Instantiate the pipeline
model = pipeline_class(**init_kwargs)
return model
diff --git a/src/diffusers/pipelines/README.md b/src/diffusers/pipelines/README.md
index 3462f5ff518d..b5ea112feafc 100644
--- a/src/diffusers/pipelines/README.md
+++ b/src/diffusers/pipelines/README.md
@@ -86,15 +86,13 @@ logic including pre-processing, an unrolled diffusion loop, and post-processing
```python
# make sure you're logged in with `huggingface-cli login`
-from torch import autocast
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
-pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
+pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
- image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
```
@@ -104,7 +102,6 @@ image.save("astronaut_rides_horse.png")
The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.
```python
-from torch import autocast
import requests
from PIL import Image
from io import BytesIO
@@ -117,7 +114,6 @@ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
revision="fp16",
torch_dtype=torch.float16,
- use_auth_token=True
).to(device)
# let's download an initial image
@@ -129,8 +125,7 @@ init_image = init_image.resize((768, 512))
prompt = "A fantasy landscape, trending on artstation"
-with autocast("cuda"):
- images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
+images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
images[0].save("fantasy_landscape.png")
```
@@ -148,7 +143,6 @@ The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by
```python
from io import BytesIO
-from torch import autocast
import requests
import PIL
@@ -169,12 +163,10 @@ pipe = StableDiffusionInpaintPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
revision="fp16",
torch_dtype=torch.float16,
- use_auth_token=True
).to(device)
prompt = "a cat sitting on a bench"
-with autocast("cuda"):
- images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
+images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
images[0].save("cat_on_bench.png")
```
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 8e3c8592a258..1c31595fb0cf 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -1,12 +1,16 @@
from ..utils import is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
-from .ddim import DDIMPipeline
-from .ddpm import DDPMPipeline
-from .latent_diffusion_uncond import LDMPipeline
-from .pndm import PNDMPipeline
-from .score_sde_ve import ScoreSdeVePipeline
-from .stochastic_karras_ve import KarrasVePipeline
+if is_torch_available():
+ from .ddim import DDIMPipeline
+ from .ddpm import DDPMPipeline
+ from .latent_diffusion_uncond import LDMPipeline
+ from .pndm import PNDMPipeline
+ from .score_sde_ve import ScoreSdeVePipeline
+ from .stochastic_karras_ve import KarrasVePipeline
+else:
+ from ..utils.dummy_pt_objects import * # noqa F403
+
if is_torch_available() and is_transformers_available():
from .latent_diffusion import LDMTextToImagePipeline
from .stable_diffusion import (
diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py
index 95b49e045f67..74607fe87a3d 100644
--- a/src/diffusers/pipelines/ddim/pipeline_ddim.py
+++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -14,7 +14,6 @@
# limitations under the License.
-import warnings
from typing import Optional, Tuple, Union
import torch
@@ -36,7 +35,6 @@ class DDIMPipeline(DiffusionPipeline):
def __init__(self, unet, scheduler):
super().__init__()
- scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
@@ -74,20 +72,6 @@ def __call__(
generated images.
"""
- if "torch_device" in kwargs:
- device = kwargs.pop("torch_device")
- warnings.warn(
- "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
- " Consider using `pipe.to(torch_device)` instead."
- )
-
- # Set device as before (to be removed in 0.3.0)
- if device is None:
- device = "cuda" if torch.cuda.is_available() else "cpu"
- self.to(device)
-
- # eta corresponds to η in paper and should be between [0, 1]
-
# Sample gaussian noise to begin loop
image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
@@ -103,6 +87,7 @@ def __call__(
model_output = self.unet(image, t).sample
# 2. predict previous mean of image x_t-1 and add variance depending on eta
+ # eta corresponds to η in paper and should be between [0, 1]
# do x_t -> x_t-1
image = self.scheduler.step(model_output, t, image, eta).prev_sample
diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
index b7f7093e379b..aae29737aae3 100644
--- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -14,7 +14,6 @@
# limitations under the License.
-import warnings
from typing import Optional, Tuple, Union
import torch
@@ -36,7 +35,6 @@ class DDPMPipeline(DiffusionPipeline):
def __init__(self, unet, scheduler):
super().__init__()
- scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
@@ -66,17 +64,6 @@ def __call__(
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
generated images.
"""
- if "torch_device" in kwargs:
- device = kwargs.pop("torch_device")
- warnings.warn(
- "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
- " Consider using `pipe.to(torch_device)` instead."
- )
-
- # Set device as before (to be removed in 0.3.0)
- if device is None:
- device = "cuda" if torch.cuda.is_available() else "cpu"
- self.to(device)
# Sample gaussian noise to begin loop
image = torch.randn(
diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
index 4a4f29be7f75..62a5785beb2f 100644
--- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
+++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
@@ -1,5 +1,4 @@
import inspect
-import warnings
from typing import List, Optional, Tuple, Union
import torch
@@ -46,7 +45,6 @@ def __init__(
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
):
super().__init__()
- scheduler = scheduler.set_format("pt")
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
@torch.no_grad()
@@ -94,17 +92,6 @@ def __call__(
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
generated images.
"""
- if "torch_device" in kwargs:
- device = kwargs.pop("torch_device")
- warnings.warn(
- "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
- " Consider using `pipe.to(torch_device)` instead."
- )
-
- # Set device as before (to be removed in 0.3.0)
- if device is None:
- device = "cuda" if torch.cuda.is_available() else "cpu"
- self.to(device)
if isinstance(prompt, str):
batch_size = 1
@@ -192,7 +179,7 @@ def __call__(
LDMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "ldm-bert": "https://huggingface.co/ldm-bert/resolve/main/config.json",
+ "ldm-bert": "https://huggingface.co/valhalla/ldm-bert/blob/main/config.json",
}
diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
index 5574b65df9f8..ef82abb7e6cb 100644
--- a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
+++ b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
@@ -1,5 +1,4 @@
import inspect
-import warnings
from typing import Optional, Tuple, Union
import torch
@@ -24,7 +23,6 @@ class LDMPipeline(DiffusionPipeline):
def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler):
super().__init__()
- scheduler = scheduler.set_format("pt")
self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
@torch.no_grad()
@@ -60,18 +58,6 @@ def __call__(
generated images.
"""
- if "torch_device" in kwargs:
- device = kwargs.pop("torch_device")
- warnings.warn(
- "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
- " Consider using `pipe.to(torch_device)` instead."
- )
-
- # Set device as before (to be removed in 0.3.0)
- if device is None:
- device = "cuda" if torch.cuda.is_available() else "cpu"
- self.to(device)
-
latents = torch.randn(
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
generator=generator,
diff --git a/src/diffusers/pipelines/pndm/pipeline_pndm.py b/src/diffusers/pipelines/pndm/pipeline_pndm.py
index ae6c10e9e967..f360da09ac94 100644
--- a/src/diffusers/pipelines/pndm/pipeline_pndm.py
+++ b/src/diffusers/pipelines/pndm/pipeline_pndm.py
@@ -14,7 +14,6 @@
# limitations under the License.
-import warnings
from typing import Optional, Tuple, Union
import torch
@@ -40,7 +39,6 @@ class PNDMPipeline(DiffusionPipeline):
def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler):
super().__init__()
- scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
@@ -75,18 +73,6 @@ def __call__(
# For more information on the sampling method you can take a look at Algorithm 2 of
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
- if "torch_device" in kwargs:
- device = kwargs.pop("torch_device")
- warnings.warn(
- "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
- " Consider using `pipe.to(torch_device)` instead."
- )
-
- # Set device as before (to be removed in 0.3.0)
- if device is None:
- device = "cuda" if torch.cuda.is_available() else "cpu"
- self.to(device)
-
# Sample gaussian noise to begin loop
image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
diff --git a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
index b29795e7f661..7f63820eec28 100644
--- a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
+++ b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
-import warnings
from typing import Optional, Tuple, Union
import torch
@@ -53,24 +52,12 @@ def __call__(
generated images.
"""
- if "torch_device" in kwargs:
- device = kwargs.pop("torch_device")
- warnings.warn(
- "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
- " Consider using `pipe.to(torch_device)` instead."
- )
-
- # Set device as before (to be removed in 0.3.0)
- if device is None:
- device = "cuda" if torch.cuda.is_available() else "cpu"
- self.to(device)
-
img_size = self.unet.config.sample_size
shape = (batch_size, 3, img_size, img_size)
model = self.unet
- sample = torch.randn(*shape, generator=generator) * self.scheduler.config.sigma_max
+ sample = torch.randn(*shape, generator=generator) * self.scheduler.init_noise_sigma
sample = sample.to(self.device)
self.scheduler.set_timesteps(num_inference_steps)
diff --git a/src/diffusers/pipelines/stable_diffusion/README.md b/src/diffusers/pipelines/stable_diffusion/README.md
index 3a600c5859e9..47c38acbdb35 100644
--- a/src/diffusers/pipelines/stable_diffusion/README.md
+++ b/src/diffusers/pipelines/stable_diffusion/README.md
@@ -28,16 +28,12 @@ download the weights with `git lfs install; git clone https://huggingface.co/Com
### Using Stable Diffusion without being logged into the Hub.
-If you want to download the model weights using a single Python line, you need to pass the token
-to `use_auth_token` or be logged in via `huggingface-cli login`.
-For more information on access tokens, please refer to [this section](https://huggingface.co/docs/hub/security-tokens) of the documentation.
-
-Assuming your token is stored under YOUR_TOKEN, you can download the stable diffusion pipeline as follows:
+If you want to download the model weights using a single Python line, you need to be logged in via `huggingface-cli login`.
```python
from diffusers import DiffusionPipeline
-pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=YOUR_TOKEN)
+pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
```
This however can make it difficult to build applications on top of `diffusers` as you will always have to pass the token around. A potential way to solve this issue is by downloading the weights to a local path `"./stable-diffusion-v1-4"`:
@@ -59,15 +55,13 @@ pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-4")
```python
# make sure you're logged in with `huggingface-cli login`
-from torch import autocast
from diffusers import StableDiffusionPipeline
-pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
+pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
- image = pipe(prompt).images[0]
+image = pipe(prompt).sample[0]
image.save("astronaut_rides_horse.png")
```
@@ -76,7 +70,6 @@ image.save("astronaut_rides_horse.png")
```python
# make sure you're logged in with `huggingface-cli login`
-from torch import autocast
from diffusers import StableDiffusionPipeline, DDIMScheduler
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
@@ -84,12 +77,10 @@ scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="sca
pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
scheduler=scheduler,
- use_auth_token=True
).to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
- image = pipe(prompt).images[0]
+image = pipe(prompt).sample[0]
image.save("astronaut_rides_horse.png")
```
@@ -98,7 +89,6 @@ image.save("astronaut_rides_horse.png")
```python
# make sure you're logged in with `huggingface-cli login`
-from torch import autocast
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
lms = LMSDiscreteScheduler(
@@ -110,12 +100,10 @@ lms = LMSDiscreteScheduler(
pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
scheduler=lms,
- use_auth_token=True
).to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
- image = pipe(prompt).images[0]
+image = pipe(prompt).sample[0]
image.save("astronaut_rides_horse.png")
```
diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py
index 1016ce69e450..615fa404da0b 100644
--- a/src/diffusers/pipelines/stable_diffusion/__init__.py
+++ b/src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -6,7 +6,7 @@
import PIL
from PIL import Image
-from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_transformers_available
+from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
@dataclass
@@ -27,7 +27,7 @@ class StableDiffusionPipelineOutput(BaseOutput):
nsfw_content_detected: List[bool]
-if is_transformers_available():
+if is_transformers_available() and is_torch_available():
from .pipeline_stable_diffusion import StableDiffusionPipeline
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
@@ -56,5 +56,6 @@ class FlaxStableDiffusionPipelineOutput(BaseOutput):
images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: List[bool]
+ from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
index 870a715ef516..102a26583db4 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -186,7 +186,9 @@ def loop_body(step, args):
latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
return latents, scheduler_state
- scheduler_state = self.scheduler.set_timesteps(params["scheduler"], num_inference_steps=num_inference_steps)
+ scheduler_state = self.scheduler.set_timesteps(
+ params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
+ )
if debug:
# run with python for loop
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 71ab67f71691..c17bf1b8f7ce 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -1,6 +1,5 @@
import inspect
-import warnings
-from typing import List, Optional, Union
+from typing import Callable, List, Optional, Union
import torch
@@ -10,10 +9,14 @@
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, EulerAScheduler
+from ...utils import deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
class StableDiffusionPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
@@ -53,18 +56,17 @@ def __init__(
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
- scheduler = scheduler.set_format("pt")
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
- warnings.warn(
+ deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
- " file",
- DeprecationWarning,
+ " file"
)
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
@@ -110,15 +112,19 @@ def disable_attention_slicing(self):
def __call__(
self,
prompt: Union[str, List[str]],
- height: Optional[int] = 512,
- width: Optional[int] = 512,
- num_inference_steps: Optional[int] = 50,
- guidance_scale: Optional[float] = 7.5,
- eta: Optional[float] = 0.0,
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
@@ -140,6 +146,11 @@ def __call__(
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
@@ -156,6 +167,12 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
@@ -165,18 +182,6 @@ def __call__(
(nsfw) content, according to the `safety_checker`.
"""
- if "torch_device" in kwargs:
- device = kwargs.pop("torch_device")
- warnings.warn(
- "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
- " Consider using `pipe.to(torch_device)` instead."
- )
-
- # Set device as before (to be removed in 0.3.0)
- if device is None:
- device = "cuda" if torch.cuda.is_available() else "cpu"
- self.to(device)
-
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
@@ -187,15 +192,36 @@ def __call__(
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
# get prompt text embeddings
- text_input = self.tokenizer(
+ text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
- truncation=True,
return_tensors="pt",
)
- text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+ text_input_ids = text_inputs.input_ids
+
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+ text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -203,12 +229,40 @@ def __call__(
do_classifier_free_guidance = guidance_scale > 1.0
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
- max_length = text_input.input_ids.shape[-1]
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""]
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ "`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ " {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
- [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
@@ -219,29 +273,31 @@ def __call__(
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
- latents_device = "cpu" if self.device.type == "mps" else self.device
- latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+ latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+ latents_dtype = text_embeddings.dtype
if latents is None:
- latents = torch.randn(
- latents_shape,
- generator=generator,
- device=latents_device,
- )
+ if self.device.type == "mps":
+ # randn does not exist on mps
+ latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+ self.device
+ )
+ else:
+ latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
- latents = latents.to(self.device)
+ latents = latents.to(self.device)
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
- # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
- if isinstance(self.scheduler, LMSDiscreteScheduler):
- latents = latents * self.scheduler.sigmas[0]
- elif isinstance(self.scheduler, EulerAScheduler):
- sigma = self.scheduler.timesteps[0]
- latents = latents * sigma
-
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimized to move all timesteps to correct device beforehand
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
@@ -253,29 +309,25 @@ def __call__(
if generator is not None:
extra_step_kwargs["generator"] = generator
- for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
- if isinstance(self.scheduler, LMSDiscreteScheduler):
- sigma = self.scheduler.sigmas[i]
- # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
- latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
-
- noise_pred = None
- if isinstance(self.scheduler, EulerAScheduler):
- # sigma = t.reshape(1) #A# potential bug: doesn't work on samples > 1
- # sigma_in = torch.cat([sigma] * 2)
- # # noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, text_embeddings , guidance_scale,DSsigmas=self.scheduler.DSsigmas)
- # # noise_pred = DiscreteEpsDDPMDenoiserForward(self.unet,latent_model_input, sigma_in,DSsigmas=self.scheduler.DSsigmas, cond=cond_in)
- # c_out, c_in = [self.scheduler.append_dims(x, latent_model_input.ndim) for x in self.scheduler.get_scalings(sigma_in)]
- c_out, c_in, sigma_in = self.scheduler.prepare_input(latent_model_input, t, batch_size)
-
- eps = self.unet(latent_model_input * c_in, sigma_in , encoder_hidden_states=text_embeddings).sample
- noise_pred = latent_model_input + eps * c_out
+ # latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ #TODO merge conform the EulerAScheduler interface to the standardized scheduler interface
+ if isinstance(self.scheduler, EulerAScheduler):
+ latent_unscaled = latent_model_input # store the unscaled latent
+ c_out, c_in, sigma_in = self.scheduler.prepare_input(latent_model_input, t)
+ # latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ latent_model_input = latent_unscaled * c_in
+ # sigma_in = self.scheduler.get_sigma_in(latent_model_input,t)
+ eps = self.unet(latent_model_input, sigma_in , encoder_hidden_states=text_embeddings).sample
+ noise_pred = latent_unscaled + eps * c_out
- # noise_pred = self.unet(latent_model_input, sigma_in, encoder_hidden_states=text_embeddings).sample
else:
# predict the noise residual
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
@@ -283,26 +335,31 @@ def __call__(
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ if isinstance(self.scheduler, EulerAScheduler):
+ # change from self.scheduler.timesteps.shape[0] - 1 to num_inference_steps
+ if i < self.scheduler.num_inference_steps: #avoid out of bound error
+ # t_prev = self.scheduler.timesteps[i+1]
+ latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
+ else:
# compute the previous noisy sample x_t -> x_t-1
- if isinstance(self.scheduler, LMSDiscreteScheduler):
- latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
- elif isinstance(self.scheduler, EulerAScheduler):
- if i < self.scheduler.timesteps.shape[0] - 1: #avoid out of bound error
- t_prev = self.scheduler.timesteps[i+1]
- latents = self.scheduler.step(noise_pred, t, t_prev, latents, **extra_step_kwargs).prev_sample
- else:
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
- # scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
- image = image.cpu().permute(0, 2, 3, 1).numpy()
- # run safety checker
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
- image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+ )
if output_type == "pil":
image = self.numpy_to_pil(image)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index 46299bf3b3e7..72e15f4f904b 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -1,6 +1,5 @@
import inspect
-import warnings
-from typing import List, Optional, Union
+from typing import Callable, List, Optional, Union
import numpy as np
import torch
@@ -12,10 +11,14 @@
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
@@ -65,18 +68,17 @@ def __init__(
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
- scheduler = scheduler.set_format("pt")
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
- warnings.warn(
+ deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
- " file",
- DeprecationWarning,
+ " file"
)
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
@@ -126,10 +128,15 @@ def __call__(
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ **kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -155,6 +162,11 @@ def __call__(
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
@@ -167,6 +179,12 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
@@ -185,45 +203,40 @@ def __call__(
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
if isinstance(init_image, PIL.Image.Image):
init_image = preprocess(init_image)
- # encode the init image into latents and scale the latents
- init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
- init_latents = init_latent_dist.sample(generator=generator)
- init_latents = 0.18215 * init_latents
-
- # expand init_latents for batch_size
- init_latents = torch.cat([init_latents] * batch_size)
-
- # get the original timestep using init_timestep
- offset = self.scheduler.config.get("steps_offset", 0)
- init_timestep = int(num_inference_steps * strength) + offset
- init_timestep = min(init_timestep, num_inference_steps)
- if isinstance(self.scheduler, LMSDiscreteScheduler):
- timesteps = torch.tensor(
- [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
- )
- else:
- timesteps = self.scheduler.timesteps[-init_timestep]
- timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
-
- # add noise to latents using the timesteps
- noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
- init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
-
# get prompt text embeddings
- text_input = self.tokenizer(
+ text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
- truncation=True,
return_tensors="pt",
)
- text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+ text_input_ids = text_inputs.input_ids
+
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+ text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+
+ # duplicate text embeddings for each generation per prompt
+ text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -231,17 +244,61 @@ def __call__(
do_classifier_free_guidance = guidance_scale > 1.0
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
- max_length = text_input.input_ids.shape[-1]
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""]
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ "`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ " {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError("The length of `negative_prompt` should be equal to batch_size.")
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
- [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+ # duplicate unconditional embeddings for each generation per prompt
+ uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
+
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+ # encode the init image into latents and scale the latents
+ latents_dtype = text_embeddings.dtype
+ init_image = init_image.to(device=self.device, dtype=latents_dtype)
+ init_latent_dist = self.vae.encode(init_image).latent_dist
+ init_latents = init_latent_dist.sample(generator=generator)
+ init_latents = 0.18215 * init_latents
+
+ # expand init_latents for batch_size
+ init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
+
+ # get the original timestep using init_timestep
+ offset = self.scheduler.config.get("steps_offset", 0)
+ init_timestep = int(num_inference_steps * strength) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+
+ timesteps = self.scheduler.timesteps[-init_timestep]
+ timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
+
+ # add noise to latents using the timesteps
+ noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
+ init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
@@ -254,17 +311,15 @@ def __call__(
latents = init_latents
t_start = max(num_inference_steps - init_timestep + offset, 0)
- for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[t_start:])):
- t_index = t_start + i
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimized to move all timesteps to correct device beforehand
+ timesteps = self.scheduler.timesteps[t_start:].to(self.device)
+
+ for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-
- # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
- if isinstance(self.scheduler, LMSDiscreteScheduler):
- sigma = self.scheduler.sigmas[t_index]
- # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
- latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -275,21 +330,22 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
- if isinstance(self.scheduler, LMSDiscreteScheduler):
- latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
- else:
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
- # scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
- # run safety checker
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
- image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+ )
if output_type == "pil":
image = self.numpy_to_pil(image)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 7de7925a302b..30a588e754b3 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -1,6 +1,5 @@
import inspect
-import warnings
-from typing import List, Optional, Union
+from typing import Callable, List, Optional, Union
import numpy as np
import torch
@@ -13,7 +12,7 @@
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from ...utils import logging
+from ...utils import deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
@@ -83,19 +82,18 @@ def __init__(
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
- scheduler = scheduler.set_format("pt")
logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
- warnings.warn(
+ deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
- " file",
- DeprecationWarning,
+ " file"
)
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
@@ -146,10 +144,15 @@ def __call__(
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ **kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -179,6 +182,11 @@ def __call__(
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
@@ -191,6 +199,12 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
@@ -209,29 +223,101 @@ def __call__(
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+ text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+
+ # duplicate text embeddings for each generation per prompt
+ text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""]
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ "`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ " {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # duplicate unconditional embeddings for each generation per prompt
+ uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
# preprocess image
if not isinstance(init_image, torch.FloatTensor):
init_image = preprocess_image(init_image)
- init_image = init_image.to(self.device)
# encode the init image into latents and scale the latents
+ latents_dtype = text_embeddings.dtype
+ init_image = init_image.to(device=self.device, dtype=latents_dtype)
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
-
init_latents = 0.18215 * init_latents
- # Expand init_latents for batch_size
- init_latents = torch.cat([init_latents] * batch_size)
+ # Expand init_latents for batch_size and num_images_per_prompt
+ init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
init_latents_orig = init_latents
# preprocess mask
if not isinstance(mask_image, torch.FloatTensor):
mask_image = preprocess_mask(mask_image)
- mask_image = mask_image.to(self.device)
- mask = torch.cat([mask_image] * batch_size)
+ mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
+ mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
# check sizes
if not mask.shape == init_latents.shape:
@@ -241,45 +327,14 @@ def __call__(
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
- if isinstance(self.scheduler, LMSDiscreteScheduler):
- timesteps = torch.tensor(
- [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
- )
- else:
- timesteps = self.scheduler.timesteps[-init_timestep]
- timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+
+ timesteps = self.scheduler.timesteps[-init_timestep]
+ timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
# add noise to latents using the timesteps
- noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
+ noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
- # get prompt text embeddings
- text_input = self.tokenizer(
- prompt,
- padding="max_length",
- max_length=self.tokenizer.model_max_length,
- truncation=True,
- return_tensors="pt",
- )
- text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
-
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- do_classifier_free_guidance = guidance_scale > 1.0
- # get unconditional embeddings for classifier free guidance
- if do_classifier_free_guidance:
- max_length = text_input.input_ids.shape[-1]
- uncond_input = self.tokenizer(
- [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
- )
- uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
-
- # For classifier free guidance, we need to do two forward passes.
- # Here we concatenate the unconditional and text embeddings into a single batch
- # to avoid doing two forward passes
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
@@ -290,15 +345,17 @@ def __call__(
extra_step_kwargs["eta"] = eta
latents = init_latents
+
t_start = max(num_inference_steps - init_timestep + offset, 0)
- for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
- t_index = t_start + i
+
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimized to move all timesteps to correct device beforehand
+ timesteps = self.scheduler.timesteps[t_start:].to(self.device)
+
+ for i, t in tqdm(enumerate(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
- if isinstance(self.scheduler, LMSDiscreteScheduler):
- sigma = self.scheduler.sigmas[t_index]
- # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
- latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -309,25 +366,22 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
- if isinstance(self.scheduler, LMSDiscreteScheduler):
- latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
- # masking
- init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor(t_index))
- else:
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
- # masking
- init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ # masking
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
latents = (init_latents_proper * mask) + (latents * (1 - mask))
- # scale and decode the image latents with vae
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
- # run safety checker
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
index ccba29ade5d3..4bd6c2c8bb3e 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
@@ -1,5 +1,5 @@
import inspect
-from typing import List, Optional, Union
+from typing import Callable, List, Optional, Union
import numpy as np
@@ -8,9 +8,13 @@
from ...onnx_utils import OnnxRuntimeModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import logging
from . import StableDiffusionPipelineOutput
+logger = logging.get_logger(__name__)
+
+
class StableDiffusionOnnxPipeline(DiffusionPipeline):
vae_decoder: OnnxRuntimeModel
text_encoder: OnnxRuntimeModel
@@ -31,7 +35,6 @@ def __init__(
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
- scheduler = scheduler.set_format("np")
self.register_modules(
vae_decoder=vae_decoder,
text_encoder=text_encoder,
@@ -49,10 +52,13 @@ def __call__(
width: Optional[int] = 512,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
eta: Optional[float] = 0.0,
latents: Optional[np.ndarray] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
+ callback_steps: Optional[int] = 1,
**kwargs,
):
if isinstance(prompt, str):
@@ -65,15 +71,31 @@ def __call__(
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
# get prompt text embeddings
- text_input = self.tokenizer(
+ text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
- truncation=True,
return_tensors="np",
)
- text_embeddings = self.text_encoder(input_ids=text_input.input_ids.astype(np.int32))[0]
+ text_input_ids = text_inputs.input_ids
+
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+ text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -81,9 +103,32 @@ def __call__(
do_classifier_free_guidance = guidance_scale > 1.0
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
- max_length = text_input.input_ids.shape[-1]
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ "`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ " {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt] * batch_size
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
- [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="np",
)
uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
@@ -102,9 +147,7 @@ def __call__(
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
- # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
- if isinstance(self.scheduler, LMSDiscreteScheduler):
- latents = latents * self.scheduler.sigmas[0]
+ latents = latents * self.scheduler.init_noise_sigma
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -118,10 +161,7 @@ def __call__(
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
- if isinstance(self.scheduler, LMSDiscreteScheduler):
- sigma = self.scheduler.sigmas[i]
- # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
- latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(
@@ -135,19 +175,19 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
- if isinstance(self.scheduler, LMSDiscreteScheduler):
- latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
- else:
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ latents = np.array(latents)
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
- # scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vae_decoder(latent_sample=latents)[0]
image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))
- # run safety checker
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py
index 3eb8828cdb0a..09d7a3bbf95a 100644
--- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py
+++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py
@@ -19,6 +19,8 @@ def cosine_distance(image_embeds, text_embeds):
class StableDiffusionSafetyChecker(PreTrainedModel):
config_class = CLIPConfig
+ _no_split_modules = ["CLIPEncoderLayer"]
+
def __init__(self, config: CLIPConfig):
super().__init__(config)
@@ -28,16 +30,17 @@ def __init__(self, config: CLIPConfig):
self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
- self.register_buffer("concept_embeds_weights", torch.ones(17))
- self.register_buffer("special_care_embeds_weights", torch.ones(3))
+ self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
+ self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
@torch.no_grad()
def forward(self, clip_input, images):
pooled_output = self.vision_model(clip_input)[1] # pooled_output
image_embeds = self.visual_projection(pooled_output)
- special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().numpy()
- cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().numpy()
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
+ cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()
result = []
batch_size = image_embeds.shape[0]
@@ -79,7 +82,7 @@ def forward(self, clip_input, images):
return images, has_nsfw_concepts
- @torch.inference_mode()
+ @torch.no_grad()
def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
pooled_output = self.vision_model(clip_input)[1] # pooled_output
image_embeds = self.visual_projection(pooled_output)
diff --git a/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py b/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py
index 1984a25ac0c6..9e8864b4ca76 100644
--- a/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py
+++ b/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
-import warnings
from typing import Optional, Tuple, Union
import torch
@@ -30,7 +29,6 @@ class KarrasVePipeline(DiffusionPipeline):
def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):
super().__init__()
- scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
@@ -64,17 +62,6 @@ def __call__(
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
generated images.
"""
- if "torch_device" in kwargs:
- device = kwargs.pop("torch_device")
- warnings.warn(
- "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
- " Consider using `pipe.to(torch_device)` instead."
- )
-
- # Set device as before (to be removed in 0.3.0)
- if device is None:
- device = "cuda" if torch.cuda.is_available() else "cpu"
- self.to(device)
img_size = self.unet.config.sample_size
shape = (batch_size, 3, img_size, img_size)
@@ -82,7 +69,7 @@ def __call__(
model = self.unet
# sample x_0 ~ N(0, sigma_0^2 * I)
- sample = torch.randn(*shape) * self.scheduler.config.sigma_max
+ sample = torch.randn(*shape) * self.scheduler.init_noise_sigma
sample = sample.to(self.device)
self.scheduler.set_timesteps(num_inference_steps)
diff --git a/src/diffusers/schedulers/README.md b/src/diffusers/schedulers/README.md
index edf2299446fe..6a01c503a909 100644
--- a/src/diffusers/schedulers/README.md
+++ b/src/diffusers/schedulers/README.md
@@ -2,17 +2,16 @@
- Schedulers are the algorithms to use diffusion models in inference as well as for training. They include the noise schedules and define algorithm-specific diffusion steps.
- Schedulers can be used interchangeable between diffusion models in inference to find the preferred trade-off between speed and generation quality.
-- Schedulers are available in numpy, but can easily be transformed into PyTorch.
+- Schedulers are available in PyTorch and Jax.
## API
- Schedulers should provide one or more `def step(...)` functions that should be called iteratively to unroll the diffusion loop during
the forward pass.
-- Schedulers should be framework-agnostic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch
-with a `set_format(...)` method.
+- Schedulers should be framework specific.
## Examples
-- The DDPM scheduler was proposed in [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) and can be found in [scheduling_ddpm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py). An example of how to use this scheduler can be found in [pipeline_ddpm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_ddpm.py).
-- The DDIM scheduler was proposed in [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) and can be found in [scheduling_ddim.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py). An example of how to use this scheduler can be found in [pipeline_ddim.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_ddim.py).
-- The PNDM scheduler was proposed in [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778) and can be found in [scheduling_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py). An example of how to use this scheduler can be found in [pipeline_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py).
+- The DDPM scheduler was proposed in [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) and can be found in [scheduling_ddpm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py). An example of how to use this scheduler can be found in [pipeline_ddpm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddpm/pipeline_ddpm.py).
+- The DDIM scheduler was proposed in [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) and can be found in [scheduling_ddim.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py). An example of how to use this scheduler can be found in [pipeline_ddim.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddim/pipeline_ddim.py).
+- The PNDM scheduler was proposed in [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778) and can be found in [scheduling_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py). An example of how to use this scheduler can be found in [pipeline_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm/pipeline_pndm.py).
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index 9dcd2b3627a5..98a5fde2e13c 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -35,10 +35,12 @@
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
from .scheduling_pndm_flax import FlaxPNDMScheduler
from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
+ from .scheduling_utils_flax import FlaxSchedulerMixin
else:
from ..utils.dummy_flax_objects import * # noqa F403
-if is_scipy_available():
+
+if is_scipy_available() and is_torch_available():
from .scheduling_lms_discrete import LMSDiscreteScheduler
else:
from ..utils.dummy_torch_and_scipy_objects import * # noqa F403
diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py
index 0613ffd41d0e..33d9bafb8aed 100644
--- a/src/diffusers/schedulers/scheduling_ddim.py
+++ b/src/diffusers/schedulers/scheduling_ddim.py
@@ -16,7 +16,6 @@
# and https://github.com/hojonathanho/diffusion
import math
-import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
@@ -24,7 +23,7 @@
import torch
from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, deprecate
from .scheduling_utils import SchedulerMixin
@@ -46,7 +45,7 @@ class DDIMSchedulerOutput(BaseOutput):
pred_original_sample: Optional[torch.FloatTensor] = None
-def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -72,7 +71,7 @@ def alpha_bar(time_step):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
- return np.array(betas, dtype=np.float32)
+ return torch.tensor(betas)
class DDIMScheduler(SchedulerMixin, ConfigMixin):
@@ -106,7 +105,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
- tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
"""
@@ -121,15 +119,24 @@ def __init__(
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
- tensor_format: str = "pt",
+ **kwargs,
):
+ deprecate(
+ "tensor_format",
+ "0.6.0",
+ "If you're running your code in PyTorch, you can safely remove this argument.",
+ take_from=kwargs,
+ )
+
if trained_betas is not None:
- self.betas = np.asarray(trained_betas)
- if beta_schedule == "linear":
- self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+ self.betas = torch.from_numpy(trained_betas)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
- self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -137,20 +144,34 @@ def __init__(
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
- self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
- self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
# setable values
self.num_inference_steps = None
- self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
+
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`int`, optional): current timestep
- self.tensor_format = tensor_format
- self.set_format(tensor_format=tensor_format)
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
def _get_variance(self, timestep, prev_timestep):
alpha_prod_t = self.alphas_cumprod[timestep]
@@ -162,7 +183,7 @@ def _get_variance(self, timestep, prev_timestep):
return variance
- def set_timesteps(self, num_inference_steps: int, **kwargs):
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs):
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -170,31 +191,24 @@ def set_timesteps(self, num_inference_steps: int, **kwargs):
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
-
- offset = self.config.steps_offset
-
- if "offset" in kwargs:
- warnings.warn(
- "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
- " Please pass `steps_offset` to `__init__` instead.",
- DeprecationWarning,
- )
-
- offset = kwargs["offset"]
+ deprecated_offset = deprecate(
+ "offset", "0.7.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs
+ )
+ offset = deprecated_offset or self.config.steps_offset
self.num_inference_steps = num_inference_steps
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
- self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
+ self.timesteps = torch.from_numpy(timesteps).to(device)
self.timesteps += offset
- self.set_format(tensor_format=self.tensor_format)
def step(
self,
- model_output: Union[torch.FloatTensor, np.ndarray],
+ model_output: torch.FloatTensor,
timestep: int,
- sample: Union[torch.FloatTensor, np.ndarray],
+ sample: torch.FloatTensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
@@ -205,9 +219,9 @@ def step(
process from the learned model outputs (most often the predicted noise).
Args:
- model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
- sample (`torch.FloatTensor` or `np.ndarray`):
+ sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): TODO
@@ -251,7 +265,7 @@ def step(
# 4. Clip "predicted x_0"
if self.config.clip_sample:
- pred_original_sample = self.clip(pred_original_sample, -1, 1)
+ pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
@@ -269,13 +283,11 @@ def step(
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
if eta > 0:
+ # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
device = model_output.device if torch.is_tensor(model_output) else "cpu"
- noise = torch.randn(model_output.shape, generator=generator).to(device)
+ noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
- if not torch.is_tensor(model_output):
- variance = variance.numpy()
-
prev_sample = prev_sample + variance
if not return_dict:
@@ -285,16 +297,23 @@ def step(
def add_noise(
self,
- original_samples: Union[torch.FloatTensor, np.ndarray],
- noise: Union[torch.FloatTensor, np.ndarray],
- timesteps: Union[torch.IntTensor, np.ndarray],
- ) -> Union[torch.FloatTensor, np.ndarray]:
- if self.tensor_format == "pt":
- timesteps = timesteps.to(self.alphas_cumprod.device)
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
- sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
- sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py
index d81d66607147..9a1f339b49da 100644
--- a/src/diffusers/schedulers/scheduling_ddim_flax.py
+++ b/src/diffusers/schedulers/scheduling_ddim_flax.py
@@ -23,7 +23,7 @@
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -68,11 +68,11 @@ def create(cls, num_train_timesteps: int, alphas_cumprod: jnp.ndarray):
@dataclass
-class FlaxSchedulerOutput(SchedulerOutput):
+class FlaxDDIMSchedulerOutput(FlaxSchedulerOutput):
state: DDIMSchedulerState
-class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
+class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
diffusion probabilistic models (DDPMs) with non-Markovian guidance.
@@ -156,7 +156,7 @@ def _get_variance(self, timestep, prev_timestep, alphas_cumprod):
return variance
- def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int) -> DDIMSchedulerState:
+ def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDIMSchedulerState:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -183,7 +183,7 @@ def step(
timestep: int,
sample: jnp.ndarray,
return_dict: bool = True,
- ) -> Union[FlaxSchedulerOutput, Tuple]:
+ ) -> Union[FlaxDDIMSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
@@ -197,11 +197,11 @@ def step(
key (`random.KeyArray`): a PRNG key.
eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): TODO
- return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+ return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class
Returns:
- [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
- When returning a tuple, the first element is the sample tensor.
+ [`FlaxDDIMSchedulerOutput`] or `tuple`: [`FlaxDDIMSchedulerOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is the sample tensor.
"""
if state.num_inference_steps is None:
@@ -252,7 +252,7 @@ def step(
if not return_dict:
return (prev_sample, state)
- return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+ return FlaxDDIMSchedulerOutput(prev_sample=prev_sample, state=state)
def add_noise(
self,
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index 440b880385d4..04c92904a660 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -22,7 +22,7 @@
import torch
from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, deprecate
from .scheduling_utils import SchedulerMixin
@@ -70,7 +70,7 @@ def alpha_bar(time_step):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
- return np.array(betas, dtype=np.float32)
+ return torch.tensor(betas, dtype=torch.float32)
class DDPMScheduler(SchedulerMixin, ConfigMixin):
@@ -99,7 +99,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
- tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
"""
@@ -113,35 +112,62 @@ def __init__(
trained_betas: Optional[np.ndarray] = None,
variance_type: str = "fixed_small",
clip_sample: bool = True,
- tensor_format: str = "pt",
+ **kwargs,
):
+ deprecate(
+ "tensor_format",
+ "0.6.0",
+ "If you're running your code in PyTorch, you can safely remove this argument.",
+ take_from=kwargs,
+ )
+
if trained_betas is not None:
- self.betas = np.asarray(trained_betas)
+ self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear":
- self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
- self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
+ elif beta_schedule == "sigmoid":
+ # GeoDiff sigmoid schedule
+ betas = torch.linspace(-6, 6, num_train_timesteps)
+ self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
- self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
- self.one = np.array(1.0)
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+ self.one = torch.tensor(1.0)
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
# setable values
self.num_inference_steps = None
- self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
-
- self.tensor_format = tensor_format
- self.set_format(tensor_format=tensor_format)
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
self.variance_type = variance_type
- def set_timesteps(self, num_inference_steps: int):
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`int`, optional): current timestep
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -151,10 +177,10 @@ def set_timesteps(self, num_inference_steps: int):
"""
num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
self.num_inference_steps = num_inference_steps
- self.timesteps = np.arange(
+ timesteps = np.arange(
0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
)[::-1].copy()
- self.set_format(tensor_format=self.tensor_format)
+ self.timesteps = torch.from_numpy(timesteps).to(device)
def _get_variance(self, t, predicted_variance=None, variance_type=None):
alpha_prod_t = self.alphas_cumprod[t]
@@ -170,15 +196,15 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None):
# hacks - were probably added for training stability
if variance_type == "fixed_small":
- variance = self.clip(variance, min_value=1e-20)
+ variance = torch.clamp(variance, min=1e-20)
# for rl-diffuser https://arxiv.org/abs/2205.09991
elif variance_type == "fixed_small_log":
- variance = self.log(self.clip(variance, min_value=1e-20))
+ variance = torch.log(torch.clamp(variance, min=1e-20))
elif variance_type == "fixed_large":
variance = self.betas[t]
elif variance_type == "fixed_large_log":
# Glide max_log
- variance = self.log(self.betas[t])
+ variance = torch.log(self.betas[t])
elif variance_type == "learned":
return predicted_variance
elif variance_type == "learned_range":
@@ -191,9 +217,9 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None):
def step(
self,
- model_output: Union[torch.FloatTensor, np.ndarray],
+ model_output: torch.FloatTensor,
timestep: int,
- sample: Union[torch.FloatTensor, np.ndarray],
+ sample: torch.FloatTensor,
predict_epsilon=True,
generator=None,
return_dict: bool = True,
@@ -203,9 +229,9 @@ def step(
process from the learned model outputs (most often the predicted noise).
Args:
- model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
- sample (`torch.FloatTensor` or `np.ndarray`):
+ sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
predict_epsilon (`bool`):
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
@@ -240,7 +266,7 @@ def step(
# 3. Clip "predicted x_0"
if self.config.clip_sample:
- pred_original_sample = self.clip(pred_original_sample, -1, 1)
+ pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
@@ -254,7 +280,9 @@ def step(
# 6. Add noise
variance = 0
if t > 0:
- noise = self.randn_like(model_output, generator=generator)
+ noise = torch.randn(
+ model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator
+ ).to(model_output.device)
variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
pred_prev_sample = pred_prev_sample + variance
@@ -266,16 +294,23 @@ def step(
def add_noise(
self,
- original_samples: Union[torch.FloatTensor, np.ndarray],
- noise: Union[torch.FloatTensor, np.ndarray],
- timesteps: Union[torch.IntTensor, np.ndarray],
- ) -> Union[torch.FloatTensor, np.ndarray]:
- if self.tensor_format == "pt":
- timesteps = timesteps.to(self.alphas_cumprod.device)
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
- sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
- sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py
index 7c7b8d29ab52..7b3265611101 100644
--- a/src/diffusers/schedulers/scheduling_ddpm_flax.py
+++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py
@@ -23,7 +23,7 @@
from jax import random
from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -67,11 +67,11 @@ def create(cls, num_train_timesteps: int):
@dataclass
-class FlaxSchedulerOutput(SchedulerOutput):
+class FlaxDDPMSchedulerOutput(FlaxSchedulerOutput):
state: DDPMSchedulerState
-class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
+class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
Langevin dynamics sampling.
@@ -133,7 +133,7 @@ def __init__(
self.variance_type = variance_type
- def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int) -> DDPMSchedulerState:
+ def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDPMSchedulerState:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -191,7 +191,7 @@ def step(
key: random.KeyArray,
predict_epsilon: bool = True,
return_dict: bool = True,
- ) -> Union[FlaxSchedulerOutput, Tuple]:
+ ) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
@@ -205,11 +205,11 @@ def step(
key (`random.KeyArray`): a PRNG key.
predict_epsilon (`bool`):
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
- return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+ return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class
Returns:
- [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
- When returning a tuple, the first element is the sample tensor.
+ [`FlaxDDPMSchedulerOutput`] or `tuple`: [`FlaxDDPMSchedulerOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is the sample tensor.
"""
t = timestep
@@ -257,7 +257,7 @@ def step(
if not return_dict:
return (pred_prev_sample, state)
- return FlaxSchedulerOutput(prev_sample=pred_prev_sample, state=state)
+ return FlaxDDPMSchedulerOutput(prev_sample=pred_prev_sample, state=state)
def add_noise(
self,
diff --git a/src/diffusers/schedulers/scheduling_euler_a.py b/src/diffusers/schedulers/scheduling_euler_a.py
index 3044ed7921de..7c717bcc7153 100644
--- a/src/diffusers/schedulers/scheduling_euler_a.py
+++ b/src/diffusers/schedulers/scheduling_euler_a.py
@@ -101,12 +101,13 @@ def __init__(
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
- elif beta_schedule == "squaredcos_cap_v2":
- # Glide cosine schedule
- self.betas = betas_for_alpha_bar(num_train_timesteps)
+ # elif beta_schedule == "squaredcos_cap_v2":
+ # # Glide cosine schedule
+ # self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
self.device = device
self.alphas = 1.0 - torch.from_numpy(self.betas).to(self.device)
self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)
@@ -124,6 +125,10 @@ def __init__(
self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
self.sigmas = self.get_sigmas(self.DSsigmas,self.num_inference_steps)
self.tensor_format = tensor_format
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = self.sigmas[0]
+
self.set_format(tensor_format=tensor_format)
@@ -158,6 +163,29 @@ def set_timesteps(self, num_inference_steps: int, **kwargs):
self.timesteps = self.sigmas
self.set_format(tensor_format=self.tensor_format)
+
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`int`, optional): current timestep
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ # c_out, c_in, sigma_in = self.prepare_input(sample, timestep)
+
+
+ # noise_pred = latent_model_input + eps * c_out
+ # sample * c_in, sigma_in
+ # return sample *c_in
+ c_out, c_in, sigma_in = self.prepare_input(sample, timestep)
+
+ return sample * c_in
+
def add_noise_to_input(
self, sample: Union[torch.FloatTensor, np.ndarray], sigma: float, generator: Optional[torch.Generator] = None
) -> Tuple[Union[torch.FloatTensor, np.ndarray], float]:
@@ -183,7 +211,7 @@ def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
- timestep_prev: int,
+ # timestep_prev: int,
sample:float,
generator: Optional[torch.Generator] = None,
# ,sigma_hat: float,
@@ -209,16 +237,23 @@ def step(
returning a tuple, the first element is the sample tensor.
"""
- latents = sample
- sigma_down, sigma_up = self.get_ancestral_step(timestep, timestep_prev)
+ latents = sample
+ # ideally we could pass the index aka step to the this method
+ # which will allow as to get the current timestep and the previous timestep
+ i = timestep # we are passing timestep as index
+ timestep = self.timesteps[i]
+ prev_timestep = self.timesteps[i + 1]
+ sigma_down, sigma_up = self.get_ancestral_step(timestep, prev_timestep)
# if callback is not None:
# callback({'x': latents, 'i': i, 'sigma': timestep, 'sigma_hat': timestep, 'denoised': model_output})
d = self.to_d(latents, timestep, model_output)
# Euler method
dt = sigma_down - timestep
latents = latents + d * dt
- latents = latents + self.randn_like(latents,generator=generator) * sigma_up # use self.randn_like instead of torch.randn_like to get deterministic output
+ # latents = latents + self.randn_like(latents,generator=generator) * sigma_up # use self.randn_like instead of torch.randn_like to get deterministic output
+ noise = torch.randn(latents.shape, dtype=latents.dtype, generator=generator).to(self.device)
+ latents = latents + noise * sigma_up
return SchedulerOutput(prev_sample=latents)
@@ -311,7 +346,7 @@ def get_scalings(self, sigma):
c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5
return c_out, c_in
- #DiscreteSchedule DS
+ #DiscreteSchedule -> DS
def DSsigma_to_t(self, sigma, quantize=None):
# quantize = self.quantize if quantize is None else quantize
quantize = False
@@ -325,10 +360,10 @@ def DSsigma_to_t(self, sigma, quantize=None):
t = (1 - w) * low_idx + w * high_idx
return t.view(sigma.shape)
- def prepare_input(self,latent_in, t, batch_size):
- sigma = t.reshape(1) #A# potential bug: doesn't work on samples > 1
+ def prepare_input(self,latent_in, t):
+ sigma = t.reshape(1)
- sigma_in = torch.cat([sigma] * 2 * batch_size)
+ sigma_in = torch.cat([sigma] * latent_in.shape[0])# latent_in.shape[0] => 2 * batch_size
# noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, text_embeddings , guidance_scale,DSsigmas=self.scheduler.DSsigmas)
# noise_pred = DiscreteEpsDDPMDenoiserForward(self.unet,latent_model_input, sigma_in,DSsigmas=self.scheduler.DSsigmas, cond=cond_in)
c_out, c_in = [self.append_dims(x, latent_in.ndim) for x in self.get_scalings(sigma_in)]
@@ -337,4 +372,14 @@ def prepare_input(self,latent_in, t, batch_size):
# s_in = latent_in.new_ones([latent_in.shape[0]])
# sigma_in = sigma_in * s_in
- return c_out, c_in, sigma_in
\ No newline at end of file
+ return c_out, c_in, sigma_in
+
+ def get_sigma_in(self,latent_in, t):
+ sigma = t.reshape(1)
+
+ sigma_in = torch.cat([sigma] * latent_in.shape[0])# latent_in.shape[0] => 2 * batch_size
+
+ sigma_in = self.DSsigma_to_t(sigma_in)
+
+ return sigma_in
+
\ No newline at end of file
diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py
index 98dafc72a734..3b0ec91ed157 100644
--- a/src/diffusers/schedulers/scheduling_karras_ve.py
+++ b/src/diffusers/schedulers/scheduling_karras_ve.py
@@ -20,7 +20,7 @@
import torch
from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, deprecate
from .scheduling_utils import SchedulerMixin
@@ -74,7 +74,6 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
A reasonable range is [0, 10].
s_max (`float`): the end value of the sigma range where we add noise.
A reasonable range is [0.2, 80].
- tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
"""
@@ -87,17 +86,38 @@ def __init__(
s_churn: float = 80,
s_min: float = 0.05,
s_max: float = 50,
- tensor_format: str = "pt",
+ **kwargs,
):
+ deprecate(
+ "tensor_format",
+ "0.6.0",
+ "If you're running your code in PyTorch, you can safely remove this argument.",
+ take_from=kwargs,
+ )
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = sigma_max
+
# setable values
- self.num_inference_steps = None
- self.timesteps = None
- self.schedule = None # sigma(t_i)
+ self.num_inference_steps: int = None
+ self.timesteps: np.IntTensor = None
+ self.schedule: torch.FloatTensor = None # sigma(t_i)
+
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
- self.tensor_format = tensor_format
- self.set_format(tensor_format=tensor_format)
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`int`, optional): current timestep
- def set_timesteps(self, num_inference_steps: int):
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -107,21 +127,20 @@ def set_timesteps(self, num_inference_steps: int):
"""
self.num_inference_steps = num_inference_steps
- self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
- self.schedule = [
+ timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+ schedule = [
(
self.config.sigma_max**2
* (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
)
for i in self.timesteps
]
- self.schedule = np.array(self.schedule, dtype=np.float32)
-
- self.set_format(tensor_format=self.tensor_format)
+ self.schedule = torch.tensor(schedule, dtype=torch.float32, device=device)
def add_noise_to_input(
- self, sample: Union[torch.FloatTensor, np.ndarray], sigma: float, generator: Optional[torch.Generator] = None
- ) -> Tuple[Union[torch.FloatTensor, np.ndarray], float]:
+ self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
+ ) -> Tuple[torch.FloatTensor, float]:
"""
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
@@ -142,10 +161,10 @@ def add_noise_to_input(
def step(
self,
- model_output: Union[torch.FloatTensor, np.ndarray],
+ model_output: torch.FloatTensor,
sigma_hat: float,
sigma_prev: float,
- sample_hat: Union[torch.FloatTensor, np.ndarray],
+ sample_hat: torch.FloatTensor,
return_dict: bool = True,
) -> Union[KarrasVeOutput, Tuple]:
"""
@@ -153,10 +172,10 @@ def step(
process from the learned model outputs (most often the predicted noise).
Args:
- model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
sigma_hat (`float`): TODO
sigma_prev (`float`): TODO
- sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
+ sample_hat (`torch.FloatTensor`): TODO
return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check).
@@ -180,24 +199,24 @@ def step(
def step_correct(
self,
- model_output: Union[torch.FloatTensor, np.ndarray],
+ model_output: torch.FloatTensor,
sigma_hat: float,
sigma_prev: float,
- sample_hat: Union[torch.FloatTensor, np.ndarray],
- sample_prev: Union[torch.FloatTensor, np.ndarray],
- derivative: Union[torch.FloatTensor, np.ndarray],
+ sample_hat: torch.FloatTensor,
+ sample_prev: torch.FloatTensor,
+ derivative: torch.FloatTensor,
return_dict: bool = True,
) -> Union[KarrasVeOutput, Tuple]:
"""
Correct the predicted sample based on the output model_output of the network. TODO complete description
Args:
- model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
sigma_hat (`float`): TODO
sigma_prev (`float`): TODO
- sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
- sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
- derivative (`torch.FloatTensor` or `np.ndarray`): TODO
+ sample_hat (`torch.FloatTensor`): TODO
+ sample_prev (`torch.FloatTensor`): TODO
+ derivative (`torch.FloatTensor`): TODO
return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
Returns:
diff --git a/src/diffusers/schedulers/scheduling_karras_ve_flax.py b/src/diffusers/schedulers/scheduling_karras_ve_flax.py
index c320b79e6dcd..caf27aa4c226 100644
--- a/src/diffusers/schedulers/scheduling_karras_ve_flax.py
+++ b/src/diffusers/schedulers/scheduling_karras_ve_flax.py
@@ -22,7 +22,7 @@
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
-from .scheduling_utils import SchedulerMixin
+from .scheduling_utils_flax import FlaxSchedulerMixin
@flax.struct.dataclass
@@ -56,7 +56,7 @@ class FlaxKarrasVeOutput(BaseOutput):
state: KarrasVeSchedulerState
-class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
+class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
the VE column of Table 1 from [1] for reference.
@@ -99,7 +99,9 @@ def __init__(
):
self.state = KarrasVeSchedulerState.create()
- def set_timesteps(self, state: KarrasVeSchedulerState, num_inference_steps: int) -> KarrasVeSchedulerState:
+ def set_timesteps(
+ self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple
+ ) -> KarrasVeSchedulerState:
"""
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -170,7 +172,7 @@ def step(
sigma_hat (`float`): TODO
sigma_prev (`float`): TODO
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
- return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+ return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class
Returns:
[`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] or `tuple`: Updated sample in the diffusion
@@ -209,7 +211,7 @@ def step_correct(
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
derivative (`torch.FloatTensor` or `np.ndarray`): TODO
- return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+ return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class
Returns:
prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py
index 1dd6dbda1e19..12dc473f63ae 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
@@ -21,7 +21,7 @@
from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, deprecate
from .scheduling_utils import SchedulerMixin
@@ -63,9 +63,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
- options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
- `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
- tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
"""
@@ -77,30 +74,64 @@ def __init__(
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
- tensor_format: str = "pt",
+ **kwargs,
):
+ deprecate(
+ "tensor_format",
+ "0.6.0",
+ "If you're running your code in PyTorch, you can safely remove this argument.",
+ take_from=kwargs,
+ )
+
if trained_betas is not None:
- self.betas = np.asarray(trained_betas)
- if beta_schedule == "linear":
- self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+ self.betas = torch.from_numpy(trained_betas)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
- self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
- self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
+ self.sigmas = torch.from_numpy(sigmas)
- self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = self.sigmas.max()
# setable values
self.num_inference_steps = None
- self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
+ self.timesteps = torch.from_numpy(timesteps)
self.derivatives = []
+ self.is_scale_input_called = False
+
+ def scale_model_input(
+ self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
+ ) -> torch.FloatTensor:
+ """
+ Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.
- self.tensor_format = tensor_format
- self.set_format(tensor_format=tensor_format)
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ step_index = (self.timesteps == timestep).nonzero().item()
+ sigma = self.sigmas[step_index]
+ sample = sample / ((sigma**2 + 1) ** 0.5)
+ self.is_scale_input_called = True
+ return sample
def get_lms_coefficient(self, order, t, current_order):
"""
@@ -124,33 +155,32 @@ def lms_derivative(tau):
return integrated_coeff
- def set_timesteps(self, num_inference_steps: int):
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, optional):
+ the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps
- self.timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
- low_idx = np.floor(self.timesteps).astype(int)
- high_idx = np.ceil(self.timesteps).astype(int)
- frac = np.mod(self.timesteps, 1.0)
+ timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
- sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
- self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+ sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
+ self.sigmas = torch.from_numpy(sigmas).to(device=device)
+ self.timesteps = torch.from_numpy(timesteps).to(device=device)
self.derivatives = []
- self.set_format(tensor_format=self.tensor_format)
-
def step(
self,
- model_output: Union[torch.FloatTensor, np.ndarray],
- timestep: int,
- sample: Union[torch.FloatTensor, np.ndarray],
+ model_output: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ sample: torch.FloatTensor,
order: int = 4,
return_dict: bool = True,
) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
@@ -159,9 +189,9 @@ def step(
process from the learned model outputs (most often the predicted noise).
Args:
- model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
- timestep (`int`): current discrete timestep in the diffusion chain.
- sample (`torch.FloatTensor` or `np.ndarray`):
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`float`): current timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
order: coefficient for multi-step inference.
return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class
@@ -172,7 +202,31 @@ def step(
When returning a tuple, the first element is the sample tensor.
"""
- sigma = self.sigmas[timestep]
+ if not self.is_scale_input_called:
+ warnings.warn(
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
+ "See `StableDiffusionPipeline` for a usage example."
+ )
+
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ if (
+ isinstance(timestep, int)
+ or isinstance(timestep, torch.IntTensor)
+ or isinstance(timestep, torch.LongTensor)
+ ):
+ deprecate(
+ "timestep as an index",
+ "0.7.0",
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+ " `LMSDiscreteScheduler.step()` will not be supported in future versions. Make sure to pass"
+ " one of the `scheduler.timesteps` as a timestep.",
+ standard_warn=False,
+ )
+ step_index = timestep
+ else:
+ step_index = (self.timesteps == timestep).nonzero().item()
+ sigma = self.sigmas[step_index]
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
pred_original_sample = sample - sigma * model_output
@@ -184,8 +238,8 @@ def step(
self.derivatives.pop(0)
# 3. Compute linear multistep coefficients
- order = min(timestep + 1, order)
- lms_coeffs = [self.get_lms_coefficient(order, timestep, curr_order) for curr_order in range(order)]
+ order = min(step_index + 1, order)
+ lms_coeffs = [self.get_lms_coefficient(order, step_index, curr_order) for curr_order in range(order)]
# 4. Compute previous sample based on the derivatives path
prev_sample = sample + sum(
@@ -199,15 +253,35 @@ def step(
def add_noise(
self,
- original_samples: Union[torch.FloatTensor, np.ndarray],
- noise: Union[torch.FloatTensor, np.ndarray],
- timesteps: Union[torch.IntTensor, np.ndarray],
- ) -> Union[torch.FloatTensor, np.ndarray]:
- if self.tensor_format == "pt":
- timesteps = timesteps.to(self.sigmas.device)
- sigmas = self.match_shape(self.sigmas[timesteps], noise)
- noisy_samples = original_samples + noise * sigmas
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+ self.timesteps = self.timesteps.to(original_samples.device)
+ timesteps = timesteps.to(original_samples.device)
+
+ schedule_timesteps = self.timesteps
+
+ if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor):
+ deprecate(
+ "timesteps as indices",
+ "0.7.0",
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+ " `LMSDiscreteScheduler.add_noise()` will not be supported in future versions. Make sure to"
+ " pass values from `scheduler.timesteps` as timesteps.",
+ standard_warn=False,
+ )
+ step_indices = timesteps
+ else:
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = self.sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(original_samples.shape):
+ sigma = sigma.unsqueeze(-1)
+ noisy_samples = original_samples + noise * sigma
return noisy_samples
def __len__(self):
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
index 7f4c076b54d1..cd71e1960e8c 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
@@ -20,7 +20,7 @@
from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
@flax.struct.dataclass
@@ -37,11 +37,11 @@ def create(cls, num_train_timesteps: int, sigmas: jnp.ndarray):
@dataclass
-class FlaxSchedulerOutput(SchedulerOutput):
+class FlaxLMSSchedulerOutput(FlaxSchedulerOutput):
state: LMSDiscreteSchedulerState
-class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
+class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
Katherine Crowson:
@@ -61,8 +61,6 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
`linear` or `scaled_linear`.
trained_betas (`jnp.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
- options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
- `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
"""
@register_to_config
@@ -76,7 +74,7 @@ def __init__(
):
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
- if beta_schedule == "linear":
+ elif beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
@@ -113,7 +111,9 @@ def lms_derivative(tau):
return integrated_coeff
- def set_timesteps(self, state: LMSDiscreteSchedulerState, num_inference_steps: int) -> LMSDiscreteSchedulerState:
+ def set_timesteps(
+ self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple
+ ) -> LMSDiscreteSchedulerState:
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -134,7 +134,7 @@ def set_timesteps(self, state: LMSDiscreteSchedulerState, num_inference_steps: i
return state.replace(
num_inference_steps=num_inference_steps,
- timesteps=timesteps,
+ timesteps=timesteps.astype(int),
derivatives=jnp.array([]),
sigmas=sigmas,
)
@@ -147,7 +147,7 @@ def step(
sample: jnp.ndarray,
order: int = 4,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> Union[FlaxLMSSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
@@ -159,11 +159,11 @@ def step(
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
order: coefficient for multi-step inference.
- return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+ return_dict (`bool`): option for returning tuple rather than FlaxLMSSchedulerOutput class
Returns:
- [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
- When returning a tuple, the first element is the sample tensor.
+ [`FlaxLMSSchedulerOutput`] or `tuple`: [`FlaxLMSSchedulerOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is the sample tensor.
"""
sigma = state.sigmas[timestep]
@@ -189,7 +189,7 @@ def step(
if not return_dict:
return (prev_sample, state)
- return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+ return FlaxLMSSchedulerOutput(prev_sample=prev_sample, state=state)
def add_noise(
self,
diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py
index 09e8a7e240c2..b26840ea1997 100644
--- a/src/diffusers/schedulers/scheduling_pndm.py
+++ b/src/diffusers/schedulers/scheduling_pndm.py
@@ -15,13 +15,13 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
-import warnings
from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import deprecate
from .scheduling_utils import SchedulerMixin, SchedulerOutput
@@ -51,7 +51,7 @@ def alpha_bar(time_step):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
- return np.array(betas, dtype=np.float32)
+ return torch.tensor(betas, dtype=torch.float32)
class PNDMScheduler(SchedulerMixin, ConfigMixin):
@@ -86,7 +86,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
- tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays
"""
@@ -101,15 +100,24 @@ def __init__(
skip_prk_steps: bool = False,
set_alpha_to_one: bool = False,
steps_offset: int = 0,
- tensor_format: str = "pt",
+ **kwargs,
):
+ deprecate(
+ "tensor_format",
+ "0.6.0",
+ "If you're running your code in PyTorch, you can safely remove this argument.",
+ take_from=kwargs,
+ )
+
if trained_betas is not None:
- self.betas = np.asarray(trained_betas)
- if beta_schedule == "linear":
- self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+ self.betas = torch.from_numpy(trained_betas)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
- self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -117,9 +125,12 @@ def __init__(
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
- self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
- self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
# For now we only support F-PNDM, i.e. the runge-kutta method
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
@@ -139,10 +150,7 @@ def __init__(
self.plms_timesteps = None
self.timesteps = None
- self.tensor_format = tensor_format
- self.set_format(tensor_format=tensor_format)
-
- def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor:
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs):
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -150,16 +158,10 @@ def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
-
- offset = self.config.steps_offset
-
- if "offset" in kwargs:
- warnings.warn(
- "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
- " Please pass `steps_offset` to `__init__` instead."
- )
-
- offset = kwargs["offset"]
+ deprecated_offset = deprecate(
+ "offset", "0.7.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs
+ )
+ offset = deprecated_offset or self.config.steps_offset
self.num_inference_steps = num_inference_steps
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
@@ -185,17 +187,17 @@ def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor
::-1
].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy
- self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
+ timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
+ self.timesteps = torch.from_numpy(timesteps).to(device)
self.ets = []
self.counter = 0
- self.set_format(tensor_format=self.tensor_format)
def step(
self,
- model_output: Union[torch.FloatTensor, np.ndarray],
+ model_output: torch.FloatTensor,
timestep: int,
- sample: Union[torch.FloatTensor, np.ndarray],
+ sample: torch.FloatTensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
@@ -205,9 +207,9 @@ def step(
This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.
Args:
- model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
- sample (`torch.FloatTensor` or `np.ndarray`):
+ sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
@@ -224,9 +226,9 @@ def step(
def step_prk(
self,
- model_output: Union[torch.FloatTensor, np.ndarray],
+ model_output: torch.FloatTensor,
timestep: int,
- sample: Union[torch.FloatTensor, np.ndarray],
+ sample: torch.FloatTensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
@@ -234,9 +236,9 @@ def step_prk(
solution to the differential equation.
Args:
- model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
- sample (`torch.FloatTensor` or `np.ndarray`):
+ sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
@@ -279,9 +281,9 @@ def step_prk(
def step_plms(
self,
- model_output: Union[torch.FloatTensor, np.ndarray],
+ model_output: torch.FloatTensor,
timestep: int,
- sample: Union[torch.FloatTensor, np.ndarray],
+ sample: torch.FloatTensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
@@ -289,9 +291,9 @@ def step_plms(
times to approximate the solution.
Args:
- model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
- sample (`torch.FloatTensor` or `np.ndarray`):
+ sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
@@ -343,6 +345,19 @@ def step_plms(
return SchedulerOutput(prev_sample=prev_sample)
+ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
# this function computes x_(t−δ) using the formula of (9)
@@ -381,16 +396,23 @@ def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
def add_noise(
self,
- original_samples: Union[torch.FloatTensor, np.ndarray],
- noise: Union[torch.FloatTensor, np.ndarray],
- timesteps: Union[torch.IntTensor, np.ndarray],
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
) -> torch.Tensor:
- if self.tensor_format == "pt":
- timesteps = timesteps.to(self.alphas_cumprod.device)
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
- sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
- sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py
index 8344505620c4..0a1bcce564b8 100644
--- a/src/diffusers/schedulers/scheduling_pndm_flax.py
+++ b/src/diffusers/schedulers/scheduling_pndm_flax.py
@@ -19,10 +19,11 @@
from typing import Optional, Tuple, Union
import flax
+import jax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
@@ -75,11 +76,11 @@ def create(cls, num_train_timesteps: int):
@dataclass
-class FlaxSchedulerOutput(SchedulerOutput):
+class FlaxPNDMSchedulerOutput(FlaxSchedulerOutput):
state: PNDMSchedulerState
-class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
+class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
namely Runge-Kutta method and a linear multi-step method.
@@ -131,7 +132,7 @@ def __init__(
):
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
- if beta_schedule == "linear":
+ elif beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
@@ -155,7 +156,7 @@ def __init__(
def create_state(self):
return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
- def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) -> PNDMSchedulerState:
+ def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, shape: Tuple) -> PNDMSchedulerState:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -196,8 +197,11 @@ def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) ->
return state.replace(
timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int64),
- ets=jnp.array([]),
counter=0,
+ # Reserve space for the state variables
+ cur_model_output=jnp.zeros(shape),
+ cur_sample=jnp.zeros(shape),
+ ets=jnp.zeros((4,) + shape),
)
def step(
@@ -207,7 +211,7 @@ def step(
timestep: int,
sample: jnp.ndarray,
return_dict: bool = True,
- ) -> Union[FlaxSchedulerOutput, Tuple]:
+ ) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
@@ -220,30 +224,40 @@ def step(
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
- return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+ return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class
Returns:
- [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
- When returning a tuple, the first element is the sample tensor.
+ [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is the sample tensor.
"""
- if state.counter < len(state.prk_timesteps) and not self.config.skip_prk_steps:
- return self.step_prk(
- state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict
+ if self.config.skip_prk_steps:
+ prev_sample, state = self.step_plms(
+ state=state, model_output=model_output, timestep=timestep, sample=sample
)
else:
- return self.step_plms(
- state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict
+ prev_sample, state = jax.lax.switch(
+ jnp.where(state.counter < len(state.prk_timesteps), 0, 1),
+ (self.step_prk, self.step_plms),
+ # Args to either branch
+ state,
+ model_output,
+ timestep,
+ sample,
)
+ if not return_dict:
+ return (prev_sample, state)
+
+ return FlaxPNDMSchedulerOutput(prev_sample=prev_sample, state=state)
+
def step_prk(
self,
state: PNDMSchedulerState,
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
- return_dict: bool = True,
- ) -> Union[FlaxSchedulerOutput, Tuple]:
+ ) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
"""
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
solution to the differential equation.
@@ -254,11 +268,11 @@ def step_prk(
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
- return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+ return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class
Returns:
- [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
- When returning a tuple, the first element is the sample tensor.
+ [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is the sample tensor.
"""
if state.num_inference_steps is None:
@@ -266,34 +280,46 @@ def step_prk(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
- diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2
+ diff_to_prev = jnp.where(
+ state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2
+ )
prev_timestep = timestep - diff_to_prev
timestep = state.prk_timesteps[state.counter // 4 * 4]
- if state.counter % 4 == 0:
- state = state.replace(
- cur_model_output=state.cur_model_output + 1 / 6 * model_output,
- ets=state.ets.append(model_output),
- cur_sample=sample,
+ def remainder_0(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+ return (
+ state.replace(
+ cur_model_output=state.cur_model_output + 1 / 6 * model_output,
+ ets=state.ets.at[ets_at].set(model_output),
+ cur_sample=sample,
+ ),
+ model_output,
)
- elif (self.counter - 1) % 4 == 0:
- state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
- elif (self.counter - 2) % 4 == 0:
- state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
- elif (self.counter - 3) % 4 == 0:
- model_output = state.cur_model_output + 1 / 6 * model_output
- state = state.replace(cur_model_output=0)
- # cur_sample should not be `None`
- cur_sample = state.cur_sample if state.cur_sample is not None else sample
+ def remainder_1(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+ return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output
+ def remainder_2(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+ return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output
+
+ def remainder_3(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+ model_output = state.cur_model_output + 1 / 6 * model_output
+ return state.replace(cur_model_output=jnp.zeros_like(state.cur_model_output)), model_output
+
+ state, model_output = jax.lax.switch(
+ state.counter % 4,
+ (remainder_0, remainder_1, remainder_2, remainder_3),
+ # Args to either branch
+ state,
+ model_output,
+ state.counter // 4,
+ )
+
+ cur_sample = state.cur_sample
prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
state = state.replace(counter=state.counter + 1)
- if not return_dict:
- return (prev_sample, state)
-
- return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+ return (prev_sample, state)
def step_plms(
self,
@@ -301,8 +327,7 @@ def step_plms(
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
- return_dict: bool = True,
- ) -> Union[FlaxSchedulerOutput, Tuple]:
+ ) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
"""
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
times to approximate the solution.
@@ -313,11 +338,11 @@ def step_plms(
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
- return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+ return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class
Returns:
- [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
- When returning a tuple, the first element is the sample tensor.
+ [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is the sample tensor.
"""
if state.num_inference_steps is None:
@@ -334,36 +359,91 @@ def step_plms(
)
prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
+ prev_timestep = jnp.where(prev_timestep > 0, prev_timestep, 0)
+
+ # Reference:
+ # if state.counter != 1:
+ # state.ets.append(model_output)
+ # else:
+ # prev_timestep = timestep
+ # timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps
+
+ prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep)
+ timestep = jnp.where(
+ state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep
+ )
- if state.counter != 1:
- state = state.replace(ets=state.ets.append(model_output))
- else:
- prev_timestep = timestep
- timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps
-
- if len(state.ets) == 1 and state.counter == 0:
- model_output = model_output
- state = state.replace(cur_sample=sample)
- elif len(state.ets) == 1 and state.counter == 1:
- model_output = (model_output + state.ets[-1]) / 2
- sample = state.cur_sample
- state = state.replace(cur_sample=None)
- elif len(state.ets) == 2:
- model_output = (3 * state.ets[-1] - state.ets[-2]) / 2
- elif len(state.ets) == 3:
- model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12
- else:
- model_output = (1 / 24) * (
- 55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]
+ # Reference:
+ # if len(state.ets) == 1 and state.counter == 0:
+ # model_output = model_output
+ # state.cur_sample = sample
+ # elif len(state.ets) == 1 and state.counter == 1:
+ # model_output = (model_output + state.ets[-1]) / 2
+ # sample = state.cur_sample
+ # state.cur_sample = None
+ # elif len(state.ets) == 2:
+ # model_output = (3 * state.ets[-1] - state.ets[-2]) / 2
+ # elif len(state.ets) == 3:
+ # model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12
+ # else:
+ # model_output = (1 / 24) * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4])
+
+ def counter_0(state: PNDMSchedulerState):
+ ets = state.ets.at[0].set(model_output)
+ return state.replace(
+ ets=ets,
+ cur_sample=sample,
+ cur_model_output=jnp.array(model_output, dtype=jnp.float32),
+ )
+
+ def counter_1(state: PNDMSchedulerState):
+ return state.replace(
+ cur_model_output=(model_output + state.ets[0]) / 2,
+ )
+
+ def counter_2(state: PNDMSchedulerState):
+ ets = state.ets.at[1].set(model_output)
+ return state.replace(
+ ets=ets,
+ cur_model_output=(3 * ets[1] - ets[0]) / 2,
+ cur_sample=sample,
+ )
+
+ def counter_3(state: PNDMSchedulerState):
+ ets = state.ets.at[2].set(model_output)
+ return state.replace(
+ ets=ets,
+ cur_model_output=(23 * ets[2] - 16 * ets[1] + 5 * ets[0]) / 12,
+ cur_sample=sample,
)
+ def counter_other(state: PNDMSchedulerState):
+ ets = state.ets.at[3].set(model_output)
+ next_model_output = (1 / 24) * (55 * ets[3] - 59 * ets[2] + 37 * ets[1] - 9 * ets[0])
+
+ ets = ets.at[0].set(ets[1])
+ ets = ets.at[1].set(ets[2])
+ ets = ets.at[2].set(ets[3])
+
+ return state.replace(
+ ets=ets,
+ cur_model_output=next_model_output,
+ cur_sample=sample,
+ )
+
+ counter = jnp.clip(state.counter, 0, 4)
+ state = jax.lax.switch(
+ counter,
+ [counter_0, counter_1, counter_2, counter_3, counter_other],
+ state,
+ )
+
+ sample = state.cur_sample
+ model_output = state.cur_model_output
prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
state = state.replace(counter=state.counter + 1)
- if not return_dict:
- return (prev_sample, state)
-
- return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+ return (prev_sample, state)
def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
@@ -379,7 +459,7 @@ def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
# model_output -> e_θ(x_t, t)
# prev_sample -> x_(t−δ)
alpha_prod_t = self.alphas_cumprod[timestep]
- alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ alpha_prod_t_prev = jnp.where(prev_timestep >= 0, self.alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py
index 4af8f4fdad7d..01fe222be97e 100644
--- a/src/diffusers/schedulers/scheduling_sde_ve.py
+++ b/src/diffusers/schedulers/scheduling_sde_ve.py
@@ -14,15 +14,14 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
-import warnings
+import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
-import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, deprecate
from .scheduling_utils import SchedulerMixin, SchedulerOutput
@@ -65,7 +64,6 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to
epsilon.
correct_steps (`int`): number of correction steps performed on a produced sample.
- tensor_format (`str`): "np" or "pt" for the expected format of samples passed to the Scheduler.
"""
@register_to_config
@@ -77,17 +75,40 @@ def __init__(
sigma_max: float = 1348.0,
sampling_eps: float = 1e-5,
correct_steps: int = 1,
- tensor_format: str = "pt",
+ **kwargs,
):
+ deprecate(
+ "tensor_format",
+ "0.6.0",
+ "If you're running your code in PyTorch, you can safely remove this argument.",
+ take_from=kwargs,
+ )
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = sigma_max
+
# setable values
self.timesteps = None
self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)
- self.tensor_format = tensor_format
- self.set_format(tensor_format=tensor_format)
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`int`, optional): current timestep
- def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None):
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ def set_timesteps(
+ self, num_inference_steps: int, sampling_eps: float = None, device: Union[str, torch.device] = None
+ ):
"""
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -98,13 +119,8 @@ def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None):
"""
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
- tensor_format = getattr(self, "tensor_format", "pt")
- if tensor_format == "np":
- self.timesteps = np.linspace(1, sampling_eps, num_inference_steps)
- elif tensor_format == "pt":
- self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps)
- else:
- raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps, device=device)
def set_sigmas(
self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None
@@ -129,60 +145,33 @@ def set_sigmas(
if self.timesteps is None:
self.set_timesteps(num_inference_steps, sampling_eps)
- tensor_format = getattr(self, "tensor_format", "pt")
- if tensor_format == "np":
- self.discrete_sigmas = np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
- self.sigmas = np.array([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
- elif tensor_format == "pt":
- self.discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
- self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
- else:
- raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+ self.sigmas = sigma_min * (sigma_max / sigma_min) ** (self.timesteps / sampling_eps)
+ self.discrete_sigmas = torch.exp(torch.linspace(math.log(sigma_min), math.log(sigma_max), num_inference_steps))
+ self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
def get_adjacent_sigma(self, timesteps, t):
- tensor_format = getattr(self, "tensor_format", "pt")
- if tensor_format == "np":
- return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1])
- elif tensor_format == "pt":
- return torch.where(
- timesteps == 0,
- torch.zeros_like(t.to(timesteps.device)),
- self.discrete_sigmas[timesteps - 1].to(timesteps.device),
- )
-
- raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
-
- def set_seed(self, seed):
- warnings.warn(
- "The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a"
- " generator instead.",
- DeprecationWarning,
+ return torch.where(
+ timesteps == 0,
+ torch.zeros_like(t.to(timesteps.device)),
+ self.discrete_sigmas[timesteps - 1].to(timesteps.device),
)
- tensor_format = getattr(self, "tensor_format", "pt")
- if tensor_format == "np":
- np.random.seed(seed)
- elif tensor_format == "pt":
- torch.manual_seed(seed)
- else:
- raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
def step_pred(
self,
- model_output: Union[torch.FloatTensor, np.ndarray],
+ model_output: torch.FloatTensor,
timestep: int,
- sample: Union[torch.FloatTensor, np.ndarray],
+ sample: torch.FloatTensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
- **kwargs,
) -> Union[SdeVeOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
- model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
- sample (`torch.FloatTensor` or `np.ndarray`):
+ sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
@@ -192,9 +181,6 @@ def step_pred(
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
- if "seed" in kwargs and kwargs["seed"] is not None:
- self.set_seed(kwargs["seed"])
-
if self.timesteps is None:
raise ValueError(
"`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
@@ -210,18 +196,21 @@ def step_pred(
sigma = self.discrete_sigmas[timesteps].to(sample.device)
adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device)
- drift = self.zeros_like(sample)
+ drift = torch.zeros_like(sample)
diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5
# equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
# also equation 47 shows the analog from SDE models to ancestral sampling methods
- drift = drift - diffusion[:, None, None, None] ** 2 * model_output
+ diffusion = diffusion.flatten()
+ while len(diffusion.shape) < len(sample.shape):
+ diffusion = diffusion.unsqueeze(-1)
+ drift = drift - diffusion**2 * model_output
# equation 6: sample noise for the diffusion term of
- noise = self.randn_like(sample, generator=generator)
+ noise = torch.randn(sample.shape, layout=sample.layout, generator=generator).to(sample.device)
prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep
# TODO is the variable diffusion the correct scaling term for the noise?
- prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g
+ prev_sample = prev_sample_mean + diffusion * noise # add impact of diffusion field g
if not return_dict:
return (prev_sample, prev_sample_mean)
@@ -230,19 +219,18 @@ def step_pred(
def step_correct(
self,
- model_output: Union[torch.FloatTensor, np.ndarray],
- sample: Union[torch.FloatTensor, np.ndarray],
+ model_output: torch.FloatTensor,
+ sample: torch.FloatTensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
- **kwargs,
) -> Union[SchedulerOutput, Tuple]:
"""
Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
after making the prediction for the previous timestep.
Args:
- model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
- sample (`torch.FloatTensor` or `np.ndarray`):
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
@@ -252,9 +240,6 @@ def step_correct(
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
- if "seed" in kwargs and kwargs["seed"] is not None:
- self.set_seed(kwargs["seed"])
-
if self.timesteps is None:
raise ValueError(
"`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
@@ -262,18 +247,21 @@ def step_correct(
# For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
# sample noise for correction
- noise = self.randn_like(sample, generator=generator)
+ noise = torch.randn(sample.shape, layout=sample.layout, generator=generator).to(sample.device)
# compute step size from the model_output, the noise, and the snr
- grad_norm = self.norm(model_output)
- noise_norm = self.norm(noise)
+ grad_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean()
+ noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
step_size = step_size * torch.ones(sample.shape[0]).to(sample.device)
# self.repeat_scalar(step_size, sample.shape[0])
# compute corrected sample: model_output term and noise term
- prev_sample_mean = sample + step_size[:, None, None, None] * model_output
- prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
+ step_size = step_size.flatten()
+ while len(step_size.shape) < len(sample.shape):
+ step_size = step_size.unsqueeze(-1)
+ prev_sample_mean = sample + step_size * model_output
+ prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise
if not return_dict:
return (prev_sample,)
diff --git a/src/diffusers/schedulers/scheduling_sde_ve_flax.py b/src/diffusers/schedulers/scheduling_sde_ve_flax.py
index 08fbe14732da..c4d802f83f94 100644
--- a/src/diffusers/schedulers/scheduling_sde_ve_flax.py
+++ b/src/diffusers/schedulers/scheduling_sde_ve_flax.py
@@ -22,7 +22,7 @@
from jax import random
from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
@flax.struct.dataclass
@@ -38,7 +38,7 @@ def create(cls):
@dataclass
-class FlaxSdeVeOutput(SchedulerOutput):
+class FlaxSdeVeOutput(FlaxSchedulerOutput):
"""
Output class for the ScoreSdeVeScheduler's step function output.
@@ -56,7 +56,7 @@ class FlaxSdeVeOutput(SchedulerOutput):
prev_sample_mean: Optional[jnp.ndarray] = None
-class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
+class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
The variance exploding stochastic differential equation (SDE) scheduler.
@@ -95,7 +95,7 @@ def __init__(
self.state = self.set_sigmas(state, num_train_timesteps, sigma_min, sigma_max, sampling_eps)
def set_timesteps(
- self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, sampling_eps: float = None
+ self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple, sampling_eps: float = None
) -> ScoreSdeVeSchedulerState:
"""
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -168,7 +168,7 @@ def step_pred(
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
generator: random number generator.
- return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+ return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class
Returns:
[`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
@@ -216,7 +216,7 @@ def step_correct(
sample: jnp.ndarray,
key: random.KeyArray,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> Union[FlaxSdeVeOutput, Tuple]:
"""
Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
after making the prediction for the previous timestep.
@@ -227,7 +227,7 @@ def step_correct(
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
generator: random number generator.
- return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+ return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class
Returns:
[`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py
index f19a5ad76f81..614e473eb8af 100644
--- a/src/diffusers/schedulers/scheduling_sde_vp.py
+++ b/src/diffusers/schedulers/scheduling_sde_vp.py
@@ -14,12 +14,13 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
-# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit
+import math
+from typing import Union
-import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import deprecate
from .scheduling_utils import SchedulerMixin
@@ -39,15 +40,21 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
"""
@register_to_config
- def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"):
+ def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, **kwargs):
+ deprecate(
+ "tensor_format",
+ "0.6.0",
+ "If you're running your code in PyTorch, you can safely remove this argument.",
+ take_from=kwargs,
+ )
self.sigmas = None
self.discrete_sigmas = None
self.timesteps = None
- def set_timesteps(self, num_inference_steps):
- self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
+ def set_timesteps(self, num_inference_steps, device: Union[str, torch.device] = None):
+ self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps, device=device)
- def step_pred(self, score, x, t):
+ def step_pred(self, score, x, t, generator=None):
if self.timesteps is None:
raise ValueError(
"`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
@@ -59,20 +66,27 @@ def step_pred(self, score, x, t):
-0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
)
std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
- score = -score / std[:, None, None, None]
+ std = std.flatten()
+ while len(std.shape) < len(score.shape):
+ std = std.unsqueeze(-1)
+ score = -score / std
# compute
dt = -1.0 / len(self.timesteps)
beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min)
- drift = -0.5 * beta_t[:, None, None, None] * x
+ beta_t = beta_t.flatten()
+ while len(beta_t.shape) < len(x.shape):
+ beta_t = beta_t.unsqueeze(-1)
+ drift = -0.5 * beta_t * x
+
diffusion = torch.sqrt(beta_t)
- drift = drift - diffusion[:, None, None, None] ** 2 * score
+ drift = drift - diffusion**2 * score
x_mean = x + drift * dt
# add noise
- noise = torch.randn_like(x)
- x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * noise
+ noise = torch.randn(x.shape, layout=x.layout, generator=generator).to(x.device)
+ x = x_mean + diffusion * math.sqrt(-dt) * noise
return x, x_mean
diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py
index f2bcd73acf32..b83bf3b84626 100644
--- a/src/diffusers/schedulers/scheduling_utils.py
+++ b/src/diffusers/schedulers/scheduling_utils.py
@@ -12,12 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
-from typing import Union
-import numpy as np
import torch
-from ..utils import BaseOutput
+from ..utils import BaseOutput, deprecate
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
@@ -43,83 +41,12 @@ class SchedulerMixin:
"""
config_name = SCHEDULER_CONFIG_NAME
- ignore_for_config = ["tensor_format"]
def set_format(self, tensor_format="pt"):
- self.tensor_format = tensor_format
- if tensor_format == "pt":
- for key, value in vars(self).items():
- if isinstance(value, np.ndarray):
- setattr(self, key, torch.from_numpy(value))
-
+ deprecate(
+ "set_format",
+ "0.6.0",
+ "If you're running your code in PyTorch, you can safely remove this function as the schedulers are always"
+ " in Pytorch",
+ )
return self
-
- def clip(self, tensor, min_value=None, max_value=None):
- tensor_format = getattr(self, "tensor_format", "pt")
-
- if tensor_format == "np":
- return np.clip(tensor, min_value, max_value)
- elif tensor_format == "pt":
- return torch.clamp(tensor, min_value, max_value)
-
- raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
-
- def log(self, tensor):
- tensor_format = getattr(self, "tensor_format", "pt")
-
- if tensor_format == "np":
- return np.log(tensor)
- elif tensor_format == "pt":
- return torch.log(tensor)
-
- raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
-
- def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array: Union[np.ndarray, torch.Tensor]):
- """
- Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.
-
- Args:
- values: an array or tensor of values to extract.
- broadcast_array: an array with a larger shape of K dimensions with the batch
- dimension equal to the length of timesteps.
- Returns:
- a tensor of shape [batch_size, 1, ...] where the shape has K dims.
- """
-
- tensor_format = getattr(self, "tensor_format", "pt")
- values = values.flatten()
-
- while len(values.shape) < len(broadcast_array.shape):
- values = values[..., None]
- if tensor_format == "pt":
- values = values.to(broadcast_array.device)
-
- return values
-
- def norm(self, tensor):
- tensor_format = getattr(self, "tensor_format", "pt")
- if tensor_format == "np":
- return np.linalg.norm(tensor)
- elif tensor_format == "pt":
- return torch.norm(tensor.reshape(tensor.shape[0], -1), dim=-1).mean()
-
- raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
-
- def randn_like(self, tensor, generator=None):
- tensor_format = getattr(self, "tensor_format", "pt")
- if tensor_format == "np":
- return np.random.randn(*np.shape(tensor))
- elif tensor_format == "pt":
- # return torch.randn_like(tensor)
- return torch.randn(tensor.shape, layout=tensor.layout, generator=generator).to(tensor.device)
-
- raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
-
- def zeros_like(self, tensor):
- tensor_format = getattr(self, "tensor_format", "pt")
- if tensor_format == "np":
- return np.zeros_like(tensor)
- elif tensor_format == "pt":
- return torch.zeros_like(tensor)
-
- raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
diff --git a/src/diffusers/schedulers/scheduling_utils_flax.py b/src/diffusers/schedulers/scheduling_utils_flax.py
new file mode 100644
index 000000000000..63de51f146f5
--- /dev/null
+++ b/src/diffusers/schedulers/scheduling_utils_flax.py
@@ -0,0 +1,43 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+
+import jax.numpy as jnp
+
+from ..utils import BaseOutput
+
+
+SCHEDULER_CONFIG_NAME = "scheduler_config.json"
+
+
+@dataclass
+class FlaxSchedulerOutput(BaseOutput):
+ """
+ Base class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ """
+
+ prev_sample: jnp.ndarray
+
+
+class FlaxSchedulerMixin:
+ """
+ Mixin containing common functions for the schedulers.
+ """
+
+ config_name = SCHEDULER_CONFIG_NAME
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index b63dbd2b285c..51798e2ad765 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -15,6 +15,7 @@
import os
+from .deprecation_utils import deprecate
from .import_utils import (
ENV_VARS_TRUE_AND_AUTO_VALUES,
ENV_VARS_TRUE_VALUES,
@@ -22,6 +23,7 @@
USE_TF,
USE_TORCH,
DummyObject,
+ is_accelerate_available,
is_flax_available,
is_inflect_available,
is_modelcards_available,
@@ -37,6 +39,10 @@
from .outputs import BaseOutput
+if is_torch_available():
+ from .testing_utils import floats_tensor, load_image, parse_flag_from_env, slow, torch_device
+
+
logger = get_logger(__name__)
diff --git a/src/diffusers/utils/deprecation_utils.py b/src/diffusers/utils/deprecation_utils.py
new file mode 100644
index 000000000000..eac43031574f
--- /dev/null
+++ b/src/diffusers/utils/deprecation_utils.py
@@ -0,0 +1,49 @@
+import inspect
+import warnings
+from typing import Any, Dict, Optional, Union
+
+from packaging import version
+
+
+def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True):
+ from .. import __version__
+
+ deprecated_kwargs = take_from
+ values = ()
+ if not isinstance(args[0], tuple):
+ args = (args,)
+
+ for attribute, version_name, message in args:
+ if version.parse(version.parse(__version__).base_version) >= version.parse(version_name):
+ raise ValueError(
+ f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'"
+ f" version {__version__} is >= {version_name}"
+ )
+
+ warning = None
+ if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs:
+ values += (deprecated_kwargs.pop(attribute),)
+ warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}."
+ elif hasattr(deprecated_kwargs, attribute):
+ values += (getattr(deprecated_kwargs, attribute),)
+ warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}."
+ elif deprecated_kwargs is None:
+ warning = f"`{attribute}` is deprecated and will be removed in version {version_name}."
+
+ if warning is not None:
+ warning = warning + " " if standard_warn else ""
+ warnings.warn(warning + message, DeprecationWarning)
+
+ if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0:
+ call_frame = inspect.getouterframes(inspect.currentframe())[1]
+ filename = call_frame.filename
+ line_number = call_frame.lineno
+ function = call_frame.function
+ key, value = next(iter(deprecated_kwargs.items()))
+ raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`")
+
+ if len(values) == 0:
+ return
+ elif len(values) == 1:
+ return values[0]
+ return values
diff --git a/src/diffusers/utils/dummy_flax_objects.py b/src/diffusers/utils/dummy_flax_objects.py
index 1e3ac002a609..4ab14f752c24 100644
--- a/src/diffusers/utils/dummy_flax_objects.py
+++ b/src/diffusers/utils/dummy_flax_objects.py
@@ -67,6 +67,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
+class FlaxSchedulerMixin(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+
class FlaxScoreSdeVeScheduler(metaclass=DummyObject):
_backends = ["flax"]
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index de344d074da0..b2aabee70c92 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -159,6 +159,13 @@
except importlib_metadata.PackageNotFoundError:
_scipy_available = False
+_accelerate_available = importlib.util.find_spec("accelerate") is not None
+try:
+ _accelerate_version = importlib_metadata.version("accelerate")
+ logger.debug(f"Successfully imported accelerate version {_accelerate_version}")
+except importlib_metadata.PackageNotFoundError:
+ _accelerate_available = False
+
def is_torch_available():
return _torch_available
@@ -196,6 +203,10 @@ def is_scipy_available():
return _scipy_available
+def is_accelerate_available():
+ return _accelerate_available
+
+
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
diff --git a/src/diffusers/utils/outputs.py b/src/diffusers/utils/outputs.py
index 45d483ce7b1d..10cffeeb0d41 100644
--- a/src/diffusers/utils/outputs.py
+++ b/src/diffusers/utils/outputs.py
@@ -15,13 +15,13 @@
Generic utilities
"""
-import warnings
from collections import OrderedDict
from dataclasses import fields
from typing import Any, Tuple
import numpy as np
+from .deprecation_utils import deprecate
from .import_utils import is_torch_available
@@ -87,11 +87,7 @@ def __getitem__(self, k):
if isinstance(k, str):
inner_dict = {k: v for (k, v) in self.items()}
if self.__class__.__name__ in ["StableDiffusionPipelineOutput", "ImagePipelineOutput"] and k == "sample":
- warnings.warn(
- "The keyword 'samples' is deprecated and will be removed in version 0.4.0. Please use `.images` or"
- " `'images'` instead.",
- DeprecationWarning,
- )
+ deprecate("samples", "0.6.0", "Please use `.images` or `'images'` instead.")
return inner_dict["images"]
return inner_dict[k]
else:
diff --git a/src/diffusers/testing_utils.py b/src/diffusers/utils/testing_utils.py
similarity index 89%
rename from src/diffusers/testing_utils.py
rename to src/diffusers/utils/testing_utils.py
index d3f6fa628d9d..f44b9cd394c9 100644
--- a/src/diffusers/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -1,3 +1,4 @@
+import inspect
import os
import random
import re
@@ -13,6 +14,8 @@
import requests
from packaging import version
+from .import_utils import is_flax_available
+
global_rng = random.Random()
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -22,6 +25,27 @@
torch_device = "mps" if torch.backends.mps.is_available() else torch_device
+def get_tests_dir(append_path=None):
+ """
+ Args:
+ append_path: optional path to append to the tests dir path
+ Return:
+ The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is
+ joined after the `tests` dir the former is provided.
+ """
+ # this function caller's __file__
+ caller__file__ = inspect.stack()[1][1]
+ tests_dir = os.path.abspath(os.path.dirname(caller__file__))
+
+ while not tests_dir.endswith("tests"):
+ tests_dir = os.path.dirname(tests_dir)
+
+ if append_path:
+ return os.path.join(tests_dir, append_path)
+ else:
+ return tests_dir
+
+
def parse_flag_from_env(key, default=False):
try:
value = os.environ[key]
@@ -67,6 +91,13 @@ def slow(test_case):
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
+def require_flax(test_case):
+ """
+ Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
+ """
+ return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
+
+
def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
"""
Args:
diff --git a/tests/conftest.py b/tests/conftest.py
index e116f40e6461..3cfab533e43c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -31,13 +31,13 @@
def pytest_addoption(parser):
- from diffusers.testing_utils import pytest_addoption_shared
+ from diffusers.utils.testing_utils import pytest_addoption_shared
pytest_addoption_shared(parser)
def pytest_terminal_summary(terminalreporter):
- from diffusers.testing_utils import pytest_terminal_summary_main
+ from diffusers.utils.testing_utils import pytest_terminal_summary_main
make_reports = terminalreporter.config.getoption("--make-reports")
if make_reports:
diff --git a/tests/fixtures/custom_pipeline/pipeline.py b/tests/fixtures/custom_pipeline/pipeline.py
new file mode 100644
index 000000000000..10a22edaa490
--- /dev/null
+++ b/tests/fixtures/custom_pipeline/pipeline.py
@@ -0,0 +1,102 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+# limitations under the License.
+
+
+from typing import Optional, Tuple, Union
+
+import torch
+
+from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+class CustomLocalPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
+ [`DDPMScheduler`], or [`DDIMScheduler`].
+ """
+
+ def __init__(self, unet, scheduler):
+ super().__init__()
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ generator: Optional[torch.Generator] = None,
+ eta: float = 0.0,
+ num_inference_steps: int = 50,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ eta (`float`, *optional*, defaults to 0.0):
+ The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM).
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ # Sample gaussian noise to begin loop
+ image = torch.randn(
+ (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+ generator=generator,
+ )
+ image = image.to(self.device)
+
+ # set step values
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # 1. predict noise model_output
+ model_output = self.unet(image, t).sample
+
+ # 2. predict previous mean of image x_t-1 and add variance depending on eta
+ # eta corresponds to η in paper and should be between [0, 1]
+ # do x_t -> x_t-1
+ image = self.scheduler.step(model_output, t, image, eta).prev_sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,), "This is a local test"
+
+ return ImagePipelineOutput(images=image), "This is a local test"
diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py
index 4c9b17caa74c..f6cb184651ef 100755
--- a/tests/test_layers_utils.py
+++ b/tests/test_layers_utils.py
@@ -22,7 +22,7 @@
from diffusers.models.attention import AttentionBlock, SpatialTransformer
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.resnet import Downsample2D, Upsample2D
-from diffusers.testing_utils import torch_device
+from diffusers.utils import torch_device
torch.backends.cuda.matmul.allow_tf32 = False
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index b0d00b863a78..e4e546e55ac3 100644
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -22,8 +22,8 @@
import torch
from diffusers.modeling_utils import ModelMixin
-from diffusers.testing_utils import torch_device
from diffusers.training_utils import EMAModel
+from diffusers.utils import torch_device
class ModelTesterMixin:
diff --git a/tests/test_modeling_common_flax.py b/tests/test_modeling_common_flax.py
new file mode 100644
index 000000000000..61849b22318f
--- /dev/null
+++ b/tests/test_modeling_common_flax.py
@@ -0,0 +1,44 @@
+from diffusers.utils import is_flax_available
+from diffusers.utils.testing_utils import require_flax
+
+
+if is_flax_available():
+ import jax
+
+
+@require_flax
+class FlaxModelTesterMixin:
+ def test_output(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ model = self.model_class(**init_dict)
+ variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
+ jax.lax.stop_gradient(variables)
+
+ output = model.apply(variables, inputs_dict["sample"])
+
+ if isinstance(output, dict):
+ output = output.sample
+
+ self.assertIsNotNone(output)
+ expected_shape = inputs_dict["sample"].shape
+ self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
+ def test_forward_with_norm_groups(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ init_dict["norm_num_groups"] = 16
+ init_dict["block_out_channels"] = (16, 32)
+
+ model = self.model_class(**init_dict)
+ variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
+ jax.lax.stop_gradient(variables)
+
+ output = model.apply(variables, inputs_dict["sample"])
+
+ if isinstance(output, dict):
+ output = output.sample
+
+ self.assertIsNotNone(output)
+ expected_shape = inputs_dict["sample"].shape
+ self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py
index 80055c1a10f8..b2f16aef5825 100644
--- a/tests/test_models_unet.py
+++ b/tests/test_models_unet.py
@@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import gc
import math
+import tracemalloc
import unittest
import torch
from diffusers import UNet2DConditionModel, UNet2DModel
-from diffusers.testing_utils import floats_tensor, slow, torch_device
+from diffusers.utils import floats_tensor, slow, torch_device
from .test_modeling_common import ModelTesterMixin
@@ -133,6 +135,74 @@ def test_from_pretrained_hub(self):
assert image is not None, "Make sure output is not None"
+ @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
+ def test_from_pretrained_accelerate(self):
+ model, _ = UNet2DModel.from_pretrained(
+ "fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
+ )
+ model.to(torch_device)
+ image = model(**self.dummy_input).sample
+
+ assert image is not None, "Make sure output is not None"
+
+ @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
+ def test_from_pretrained_accelerate_wont_change_results(self):
+ model_accelerate, _ = UNet2DModel.from_pretrained(
+ "fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
+ )
+ model_accelerate.to(torch_device)
+ model_accelerate.eval()
+
+ noise = torch.randn(
+ 1,
+ model_accelerate.config.in_channels,
+ model_accelerate.config.sample_size,
+ model_accelerate.config.sample_size,
+ generator=torch.manual_seed(0),
+ )
+ noise = noise.to(torch_device)
+ time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)
+
+ arr_accelerate = model_accelerate(noise, time_step)["sample"]
+
+ # two models don't need to stay in the device at the same time
+ del model_accelerate
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ model_normal_load, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
+ model_normal_load.to(torch_device)
+ model_normal_load.eval()
+ arr_normal_load = model_normal_load(noise, time_step)["sample"]
+
+ assert torch.allclose(arr_accelerate, arr_normal_load, rtol=1e-3)
+
+ @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
+ def test_memory_footprint_gets_reduced(self):
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ tracemalloc.start()
+ model_accelerate, _ = UNet2DModel.from_pretrained(
+ "fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
+ )
+ model_accelerate.to(torch_device)
+ model_accelerate.eval()
+ _, peak_accelerate = tracemalloc.get_traced_memory()
+
+ del model_accelerate
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ model_normal_load, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
+ model_normal_load.to(torch_device)
+ model_normal_load.eval()
+ _, peak_normal = tracemalloc.get_traced_memory()
+
+ tracemalloc.stop()
+
+ assert peak_accelerate < peak_normal
+
def test_output_pretrained(self):
model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update")
model.eval()
@@ -198,41 +268,44 @@ def prepare_init_args_and_inputs_for_common(self):
return init_dict, inputs_dict
def test_gradient_checkpointing(self):
+ # enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
+ assert not model.is_gradient_checkpointing and model.training
+
out = model(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model.zero_grad()
- out.sum().backward()
- # now we save the output and parameter gradients that we will use for comparison purposes with
- # the non-checkpointed run.
- output_not_checkpointed = out.data.clone()
- grad_not_checkpointed = {}
- for name, param in model.named_parameters():
- grad_not_checkpointed[name] = param.grad.data.clone()
+ labels = torch.randn_like(out)
+ loss = (out - labels).mean()
+ loss.backward()
- model.enable_gradient_checkpointing()
- out = model(**inputs_dict).sample
+ # re-instantiate the model now enabling gradient checkpointing
+ model_2 = self.model_class(**init_dict)
+ # clone model
+ model_2.load_state_dict(model.state_dict())
+ model_2.to(torch_device)
+ model_2.enable_gradient_checkpointing()
+
+ assert model_2.is_gradient_checkpointing and model_2.training
+
+ out_2 = model_2(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
- model.zero_grad()
- out.sum().backward()
-
- # now we save the output and parameter gradients that we will use for comparison purposes with
- # the non-checkpointed run.
- output_checkpointed = out.data.clone()
- grad_checkpointed = {}
- for name, param in model.named_parameters():
- grad_checkpointed[name] = param.grad.data.clone()
+ model_2.zero_grad()
+ loss_2 = (out_2 - labels).mean()
+ loss_2.backward()
# compare the output and parameters gradients
- self.assertTrue((output_checkpointed == output_not_checkpointed).all())
- for name in grad_checkpointed:
- self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=5e-5))
+ self.assertTrue((loss - loss_2).abs() < 1e-5)
+ named_params = dict(model.named_parameters())
+ named_params_2 = dict(model_2.named_parameters())
+ for name, param in named_params.items():
+ self.assertTrue(torch.allclose(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
# TODO(Patrick) - Re-add this test after having cleaned up LDM
diff --git a/tests/test_models_vae.py b/tests/test_models_vae.py
index 361eb618ab22..9fb7e8ea3bb7 100644
--- a/tests/test_models_vae.py
+++ b/tests/test_models_vae.py
@@ -19,7 +19,7 @@
from diffusers import AutoencoderKL
from diffusers.modeling_utils import ModelMixin
-from diffusers.testing_utils import floats_tensor, torch_device
+from diffusers.utils import floats_tensor, torch_device
from .test_modeling_common import ModelTesterMixin
diff --git a/tests/test_models_vae_flax.py b/tests/test_models_vae_flax.py
new file mode 100644
index 000000000000..e5c56b61a5a4
--- /dev/null
+++ b/tests/test_models_vae_flax.py
@@ -0,0 +1,39 @@
+import unittest
+
+from diffusers import FlaxAutoencoderKL
+from diffusers.utils import is_flax_available
+from diffusers.utils.testing_utils import require_flax
+
+from .test_modeling_common_flax import FlaxModelTesterMixin
+
+
+if is_flax_available():
+ import jax
+
+
+@require_flax
+class FlaxAutoencoderKLTests(FlaxModelTesterMixin, unittest.TestCase):
+ model_class = FlaxAutoencoderKL
+
+ @property
+ def dummy_input(self):
+ batch_size = 4
+ num_channels = 3
+ sizes = (32, 32)
+
+ prng_key = jax.random.PRNGKey(0)
+ image = jax.random.uniform(prng_key, ((batch_size, num_channels) + sizes))
+
+ return {"sample": image, "prng_key": prng_key}
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "block_out_channels": [32, 64],
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ "latent_channels": 4,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
diff --git a/tests/test_models_vq.py b/tests/test_models_vq.py
index 7cce0ed13e01..9a2094d46cb4 100644
--- a/tests/test_models_vq.py
+++ b/tests/test_models_vq.py
@@ -18,7 +18,7 @@
import torch
from diffusers import VQModel
-from diffusers.testing_utils import floats_tensor, torch_device
+from diffusers.utils import floats_tensor, torch_device
from .test_modeling_common import ModelTesterMixin
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index dddf42bd03f2..30beb033fca7 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -17,12 +17,15 @@
import os
import random
import tempfile
+import tracemalloc
import unittest
import numpy as np
import torch
+import accelerate
import PIL
+import transformers
from diffusers import (
AutoencoderKL,
DDIMPipeline,
@@ -48,10 +51,11 @@
)
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
-from diffusers.testing_utils import floats_tensor, load_image, slow, torch_device
-from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME
+from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device
+from diffusers.utils.testing_utils import get_tests_dir
+from packaging import version
from PIL import Image
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
torch.backends.cuda.matmul.allow_tf32 = False
@@ -80,6 +84,64 @@ def test_progress_bar(capsys):
assert captured.err == "", "Progress bar should be disabled"
+class CustomPipelineTests(unittest.TestCase):
+ def test_load_custom_pipeline(self):
+ pipeline = DiffusionPipeline.from_pretrained(
+ "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
+ )
+ # NOTE that `"CustomPipeline"` is not a class that is defined in this library, but solely on the Hub
+ # under https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L24
+ assert pipeline.__class__.__name__ == "CustomPipeline"
+
+ def test_run_custom_pipeline(self):
+ pipeline = DiffusionPipeline.from_pretrained(
+ "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
+ )
+ images, output_str = pipeline(num_inference_steps=2, output_type="np")
+
+ assert images[0].shape == (1, 32, 32, 3)
+ # compare output to https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L102
+ assert output_str == "This is a test"
+
+ def test_local_custom_pipeline(self):
+ local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline")
+ pipeline = DiffusionPipeline.from_pretrained(
+ "google/ddpm-cifar10-32", custom_pipeline=local_custom_pipeline_path
+ )
+ images, output_str = pipeline(num_inference_steps=2, output_type="np")
+
+ assert pipeline.__class__.__name__ == "CustomLocalPipeline"
+ assert images[0].shape == (1, 32, 32, 3)
+ # compare to https://github.com/huggingface/diffusers/blob/main/tests/fixtures/custom_pipeline/pipeline.py#L102
+ assert output_str == "This is a local test"
+
+ @slow
+ @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+ def test_load_pipeline_from_git(self):
+ clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
+
+ feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
+ clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16)
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ custom_pipeline="clip_guided_stable_diffusion",
+ clip_model=clip_model,
+ feature_extractor=feature_extractor,
+ torch_dtype=torch.float16,
+ revision="fp16",
+ )
+ pipeline.enable_attention_slicing()
+ pipeline = pipeline.to(torch_device)
+
+ # NOTE that `"CLIPGuidedStableDiffusion"` is not a class that is defined in the pypi package of th e library, but solely on the community examples folder of GitHub under:
+ # https://github.com/huggingface/diffusers/blob/main/examples/community/clip_guided_stable_diffusion.py
+ assert pipeline.__class__.__name__ == "CLIPGuidedStableDiffusion"
+
+ image = pipeline("a prompt", num_inference_steps=2, output_type="np").images[0]
+ assert image.shape == (512, 512, 3)
+
+
class PipelineFastTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
@@ -191,7 +253,7 @@ def to(self, device):
def test_ddim(self):
unet = self.dummy_uncond_unet
- scheduler = DDIMScheduler(tensor_format="pt")
+ scheduler = DDIMScheduler()
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
ddpm.to(torch_device)
@@ -220,7 +282,7 @@ def test_ddim(self):
def test_pndm_cifar10(self):
unet = self.dummy_uncond_unet
- scheduler = PNDMScheduler(tensor_format="pt")
+ scheduler = PNDMScheduler()
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
pndm.to(torch_device)
@@ -242,7 +304,7 @@ def test_pndm_cifar10(self):
def test_ldm_text2img(self):
unet = self.dummy_cond_unet
- scheduler = DDIMScheduler(tensor_format="pt")
+ scheduler = DDIMScheduler()
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
@@ -336,10 +398,59 @@ def test_stable_diffusion_ddim(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+ def test_stable_diffusion_ddim_factor_8(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ unet = self.dummy_cond_unet
+ scheduler = DDIMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ clip_sample=False,
+ set_alpha_to_one=False,
+ )
+
+ vae = self.dummy_vae
+ bert = self.dummy_text_encoder
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ # make sure here that pndm scheduler skips prk
+ sd_pipe = StableDiffusionPipeline(
+ unet=unet,
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=bert,
+ tokenizer=tokenizer,
+ safety_checker=self.dummy_safety_checker,
+ feature_extractor=self.dummy_extractor,
+ )
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ prompt = "A painting of a squirrel eating a burger"
+
+ generator = torch.Generator(device=device).manual_seed(0)
+ output = sd_pipe(
+ [prompt],
+ generator=generator,
+ guidance_scale=6.0,
+ height=536,
+ width=536,
+ num_inference_steps=2,
+ output_type="np",
+ )
+ image = output.images
+
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 134, 134, 3)
+ expected_slice = np.array([0.7834, 0.5488, 0.5781, 0.46, 0.3609, 0.5369, 0.542, 0.4855, 0.5557])
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
def test_stable_diffusion_pndm(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
- scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
+ scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
@@ -460,7 +571,7 @@ def test_stable_diffusion_attention_chunk(self):
def test_score_sde_ve_pipeline(self):
unet = self.dummy_uncond_unet
- scheduler = ScoreSdeVeScheduler(tensor_format="pt")
+ scheduler = ScoreSdeVeScheduler()
sde_ve = ScoreSdeVePipeline(unet=unet, scheduler=scheduler)
sde_ve.to(torch_device)
@@ -484,7 +595,7 @@ def test_score_sde_ve_pipeline(self):
def test_ldm_uncond(self):
unet = self.dummy_uncond_unet
- scheduler = DDIMScheduler(tensor_format="pt")
+ scheduler = DDIMScheduler()
vae = self.dummy_vq_model
ldm = LDMPipeline(unet=unet, vqvae=vae, scheduler=scheduler)
@@ -512,7 +623,7 @@ def test_ldm_uncond(self):
def test_karras_ve_pipeline(self):
unet = self.dummy_uncond_unet
- scheduler = KarrasVeScheduler(tensor_format="pt")
+ scheduler = KarrasVeScheduler()
pipe = KarrasVePipeline(unet=unet, scheduler=scheduler)
pipe.to(torch_device)
@@ -535,7 +646,7 @@ def test_karras_ve_pipeline(self):
def test_stable_diffusion_img2img(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
- scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
+ scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
@@ -646,7 +757,7 @@ def test_stable_diffusion_img2img_k_lms(self):
def test_stable_diffusion_inpaint(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
- scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
+ scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
@@ -702,6 +813,320 @@ def test_stable_diffusion_inpaint(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+ def test_stable_diffusion_num_images_per_prompt(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ unet = self.dummy_cond_unet
+ scheduler = PNDMScheduler(skip_prk_steps=True)
+ vae = self.dummy_vae
+ bert = self.dummy_text_encoder
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ # make sure here that pndm scheduler skips prk
+ sd_pipe = StableDiffusionPipeline(
+ unet=unet,
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=bert,
+ tokenizer=tokenizer,
+ safety_checker=self.dummy_safety_checker,
+ feature_extractor=self.dummy_extractor,
+ )
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ prompt = "A painting of a squirrel eating a burger"
+
+ # test num_images_per_prompt=1 (default)
+ images = sd_pipe(prompt, num_inference_steps=2, output_type="np").images
+
+ assert images.shape == (1, 128, 128, 3)
+
+ # test num_images_per_prompt=1 (default) for batch of prompts
+ batch_size = 2
+ images = sd_pipe([prompt] * batch_size, num_inference_steps=2, output_type="np").images
+
+ assert images.shape == (batch_size, 128, 128, 3)
+
+ # test num_images_per_prompt for single prompt
+ num_images_per_prompt = 2
+ images = sd_pipe(
+ prompt, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt
+ ).images
+
+ assert images.shape == (num_images_per_prompt, 128, 128, 3)
+
+ # test num_images_per_prompt for batch of prompts
+ batch_size = 2
+ images = sd_pipe(
+ [prompt] * batch_size, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt
+ ).images
+
+ assert images.shape == (batch_size * num_images_per_prompt, 128, 128, 3)
+
+ def test_stable_diffusion_img2img_num_images_per_prompt(self):
+ device = "cpu"
+ unet = self.dummy_cond_unet
+ scheduler = PNDMScheduler(skip_prk_steps=True)
+ vae = self.dummy_vae
+ bert = self.dummy_text_encoder
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ init_image = self.dummy_image.to(device)
+
+ # make sure here that pndm scheduler skips prk
+ sd_pipe = StableDiffusionImg2ImgPipeline(
+ unet=unet,
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=bert,
+ tokenizer=tokenizer,
+ safety_checker=self.dummy_safety_checker,
+ feature_extractor=self.dummy_extractor,
+ )
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ prompt = "A painting of a squirrel eating a burger"
+
+ # test num_images_per_prompt=1 (default)
+ images = sd_pipe(
+ prompt,
+ num_inference_steps=2,
+ output_type="np",
+ init_image=init_image,
+ ).images
+
+ assert images.shape == (1, 32, 32, 3)
+
+ # test num_images_per_prompt=1 (default) for batch of prompts
+ batch_size = 2
+ images = sd_pipe(
+ [prompt] * batch_size,
+ num_inference_steps=2,
+ output_type="np",
+ init_image=init_image,
+ ).images
+
+ assert images.shape == (batch_size, 32, 32, 3)
+
+ # test num_images_per_prompt for single prompt
+ num_images_per_prompt = 2
+ images = sd_pipe(
+ prompt,
+ num_inference_steps=2,
+ output_type="np",
+ init_image=init_image,
+ num_images_per_prompt=num_images_per_prompt,
+ ).images
+
+ assert images.shape == (num_images_per_prompt, 32, 32, 3)
+
+ # test num_images_per_prompt for batch of prompts
+ batch_size = 2
+ images = sd_pipe(
+ [prompt] * batch_size,
+ num_inference_steps=2,
+ output_type="np",
+ init_image=init_image,
+ num_images_per_prompt=num_images_per_prompt,
+ ).images
+
+ assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
+
+ def test_stable_diffusion_inpaint_num_images_per_prompt(self):
+ device = "cpu"
+ unet = self.dummy_cond_unet
+ scheduler = PNDMScheduler(skip_prk_steps=True)
+ vae = self.dummy_vae
+ bert = self.dummy_text_encoder
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
+ init_image = Image.fromarray(np.uint8(image)).convert("RGB")
+ mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
+
+ # make sure here that pndm scheduler skips prk
+ sd_pipe = StableDiffusionInpaintPipeline(
+ unet=unet,
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=bert,
+ tokenizer=tokenizer,
+ safety_checker=self.dummy_safety_checker,
+ feature_extractor=self.dummy_extractor,
+ )
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ prompt = "A painting of a squirrel eating a burger"
+
+ # test num_images_per_prompt=1 (default)
+ images = sd_pipe(
+ prompt,
+ num_inference_steps=2,
+ output_type="np",
+ init_image=init_image,
+ mask_image=mask_image,
+ ).images
+
+ assert images.shape == (1, 32, 32, 3)
+
+ # test num_images_per_prompt=1 (default) for batch of prompts
+ batch_size = 2
+ images = sd_pipe(
+ [prompt] * batch_size,
+ num_inference_steps=2,
+ output_type="np",
+ init_image=init_image,
+ mask_image=mask_image,
+ ).images
+
+ assert images.shape == (batch_size, 32, 32, 3)
+
+ # test num_images_per_prompt for single prompt
+ num_images_per_prompt = 2
+ images = sd_pipe(
+ prompt,
+ num_inference_steps=2,
+ output_type="np",
+ init_image=init_image,
+ mask_image=mask_image,
+ num_images_per_prompt=num_images_per_prompt,
+ ).images
+
+ assert images.shape == (num_images_per_prompt, 32, 32, 3)
+
+ # test num_images_per_prompt for batch of prompts
+ batch_size = 2
+ images = sd_pipe(
+ [prompt] * batch_size,
+ num_inference_steps=2,
+ output_type="np",
+ init_image=init_image,
+ mask_image=mask_image,
+ num_images_per_prompt=num_images_per_prompt,
+ ).images
+
+ assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
+
+ @unittest.skipIf(torch_device == "cpu", "This test requires a GPU")
+ def test_stable_diffusion_fp16(self):
+ """Test that stable diffusion works with fp16"""
+ unet = self.dummy_cond_unet
+ scheduler = PNDMScheduler(skip_prk_steps=True)
+ vae = self.dummy_vae
+ bert = self.dummy_text_encoder
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ # put models in fp16
+ unet = unet.half()
+ vae = vae.half()
+ bert = bert.half()
+
+ # make sure here that pndm scheduler skips prk
+ sd_pipe = StableDiffusionPipeline(
+ unet=unet,
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=bert,
+ tokenizer=tokenizer,
+ safety_checker=self.dummy_safety_checker,
+ feature_extractor=self.dummy_extractor,
+ )
+ sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ prompt = "A painting of a squirrel eating a burger"
+ generator = torch.Generator(device=torch_device).manual_seed(0)
+ image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images
+
+ assert image.shape == (1, 128, 128, 3)
+
+ @unittest.skipIf(torch_device == "cpu", "This test requires a GPU")
+ def test_stable_diffusion_img2img_fp16(self):
+ """Test that stable diffusion img2img works with fp16"""
+ unet = self.dummy_cond_unet
+ scheduler = PNDMScheduler(skip_prk_steps=True)
+ vae = self.dummy_vae
+ bert = self.dummy_text_encoder
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ init_image = self.dummy_image.to(torch_device)
+
+ # put models in fp16
+ unet = unet.half()
+ vae = vae.half()
+ bert = bert.half()
+
+ # make sure here that pndm scheduler skips prk
+ sd_pipe = StableDiffusionImg2ImgPipeline(
+ unet=unet,
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=bert,
+ tokenizer=tokenizer,
+ safety_checker=self.dummy_safety_checker,
+ feature_extractor=self.dummy_extractor,
+ )
+ sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ prompt = "A painting of a squirrel eating a burger"
+ generator = torch.Generator(device=torch_device).manual_seed(0)
+ image = sd_pipe(
+ [prompt],
+ generator=generator,
+ num_inference_steps=2,
+ output_type="np",
+ init_image=init_image,
+ ).images
+
+ assert image.shape == (1, 32, 32, 3)
+
+ @unittest.skipIf(torch_device == "cpu", "This test requires a GPU")
+ def test_stable_diffusion_inpaint_fp16(self):
+ """Test that stable diffusion inpaint works with fp16"""
+ unet = self.dummy_cond_unet
+ scheduler = PNDMScheduler(skip_prk_steps=True)
+ vae = self.dummy_vae
+ bert = self.dummy_text_encoder
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
+ init_image = Image.fromarray(np.uint8(image)).convert("RGB")
+ mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
+
+ # put models in fp16
+ unet = unet.half()
+ vae = vae.half()
+ bert = bert.half()
+
+ # make sure here that pndm scheduler skips prk
+ sd_pipe = StableDiffusionInpaintPipeline(
+ unet=unet,
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=bert,
+ tokenizer=tokenizer,
+ safety_checker=self.dummy_safety_checker,
+ feature_extractor=self.dummy_extractor,
+ )
+ sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ prompt = "A painting of a squirrel eating a burger"
+ generator = torch.Generator(device=torch_device).manual_seed(0)
+ image = sd_pipe(
+ [prompt],
+ generator=generator,
+ num_inference_steps=2,
+ output_type="np",
+ init_image=init_image,
+ mask_image=mask_image,
+ ).images
+
+ assert image.shape == (1, 32, 32, 3)
+
class PipelineTesterMixin(unittest.TestCase):
def tearDown(self):
@@ -842,7 +1267,6 @@ def test_ddpm_cifar10(self):
unet = UNet2DModel.from_pretrained(model_id)
scheduler = DDPMScheduler.from_config(model_id)
- scheduler = scheduler.set_format("pt")
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
ddpm.to(torch_device)
@@ -882,7 +1306,7 @@ def test_ddim_cifar10(self):
model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id)
- scheduler = DDIMScheduler(tensor_format="pt")
+ scheduler = DDIMScheduler()
ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
ddim.to(torch_device)
@@ -902,7 +1326,7 @@ def test_pndm_cifar10(self):
model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id)
- scheduler = PNDMScheduler(tensor_format="pt")
+ scheduler = PNDMScheduler()
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
pndm.to(torch_device)
@@ -954,7 +1378,7 @@ def test_ldm_text2img_fast(self):
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion(self):
# make sure here that pndm scheduler skips prk
- sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", use_auth_token=True)
+ sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1")
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
@@ -976,7 +1400,7 @@ def test_stable_diffusion(self):
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_fast_ddim(self):
- sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", use_auth_token=True)
+ sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1")
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
@@ -1043,8 +1467,8 @@ def test_ddpm_ddim_equality(self):
model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id)
- ddpm_scheduler = DDPMScheduler(tensor_format="pt")
- ddim_scheduler = DDIMScheduler(tensor_format="pt")
+ ddpm_scheduler = DDPMScheduler()
+ ddim_scheduler = DDIMScheduler()
ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
ddpm.to(torch_device)
@@ -1067,8 +1491,8 @@ def test_ddpm_ddim_equality_batched(self):
model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id)
- ddpm_scheduler = DDPMScheduler(tensor_format="pt")
- ddim_scheduler = DDIMScheduler(tensor_format="pt")
+ ddpm_scheduler = DDPMScheduler()
+ ddim_scheduler = DDIMScheduler()
ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
ddpm.to(torch_device)
@@ -1093,7 +1517,7 @@ def test_ddpm_ddim_equality_batched(self):
def test_karras_ve_pipeline(self):
model_id = "google/ncsnpp-celebahq-256"
model = UNet2DModel.from_pretrained(model_id)
- scheduler = KarrasVeScheduler(tensor_format="pt")
+ scheduler = KarrasVeScheduler()
pipe = KarrasVePipeline(unet=model, scheduler=scheduler)
pipe.to(torch_device)
@@ -1111,9 +1535,9 @@ def test_karras_ve_pipeline(self):
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_lms_stable_diffusion_pipeline(self):
model_id = "CompVis/stable-diffusion-v1-1"
- pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True).to(torch_device)
+ pipe = StableDiffusionPipeline.from_pretrained(model_id).to(torch_device)
pipe.set_progress_bar_config(disable=None)
- scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler", use_auth_token=True)
+ scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
pipe.scheduler = scheduler
prompt = "a photograph of an astronaut riding a horse"
@@ -1132,9 +1556,9 @@ def test_lms_stable_diffusion_pipeline(self):
def test_stable_diffusion_memory_chunking(self):
torch.cuda.reset_peak_memory_stats()
model_id = "CompVis/stable-diffusion-v1-4"
- pipe = StableDiffusionPipeline.from_pretrained(
- model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
- ).to(torch_device)
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16).to(
+ torch_device
+ )
pipe.set_progress_bar_config(disable=None)
prompt = "a photograph of an astronaut riding a horse"
@@ -1167,6 +1591,37 @@ def test_stable_diffusion_memory_chunking(self):
assert mem_bytes > 3.75 * 10**9
assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3
+ @slow
+ @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+ def test_stable_diffusion_text2img_pipeline_fp16(self):
+ torch.cuda.reset_peak_memory_stats()
+ model_id = "CompVis/stable-diffusion-v1-4"
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16).to(
+ torch_device
+ )
+ pipe.set_progress_bar_config(disable=None)
+
+ prompt = "a photograph of an astronaut riding a horse"
+
+ generator = torch.Generator(device=torch_device).manual_seed(0)
+ output_chunked = pipe(
+ [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
+ )
+ image_chunked = output_chunked.images
+
+ generator = torch.Generator(device=torch_device).manual_seed(0)
+ with torch.autocast(torch_device):
+ output = pipe(
+ [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
+ )
+ image = output.images
+
+ # Make sure results are close enough
+ diff = np.abs(image_chunked.flatten() - image.flatten())
+ # They ARE different since ops are not run always at the same precision
+ # however, they should be extremely close.
+ assert diff.mean() < 2e-2
+
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_text2img_pipeline(self):
@@ -1180,7 +1635,6 @@ def test_stable_diffusion_text2img_pipeline(self):
pipe = StableDiffusionPipeline.from_pretrained(
model_id,
safety_checker=self.dummy_safety_checker,
- use_auth_token=True,
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -1213,7 +1667,6 @@ def test_stable_diffusion_img2img_pipeline(self):
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
model_id,
safety_checker=self.dummy_safety_checker,
- use_auth_token=True,
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -1257,7 +1710,6 @@ def test_stable_diffusion_img2img_pipeline_k_lms(self):
model_id,
scheduler=lms,
safety_checker=self.dummy_safety_checker,
- use_auth_token=True,
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -1301,7 +1753,6 @@ def test_stable_diffusion_inpaint_pipeline(self):
pipe = StableDiffusionInpaintPipeline.from_pretrained(
model_id,
safety_checker=self.dummy_safety_checker,
- use_auth_token=True,
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -1348,7 +1799,6 @@ def test_stable_diffusion_inpaint_pipeline_k_lms(self):
model_id,
scheduler=lms,
safety_checker=self.dummy_safety_checker,
- use_auth_token=True,
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -1374,16 +1824,267 @@ def test_stable_diffusion_inpaint_pipeline_k_lms(self):
@slow
def test_stable_diffusion_onnx(self):
sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(
- "CompVis/stable-diffusion-v1-4", revision="onnx", provider="CUDAExecutionProvider", use_auth_token=True
+ "CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
)
prompt = "A painting of a squirrel eating a burger"
np.random.seed(0)
- output = sd_pipe([prompt], guidance_scale=6.0, num_inference_steps=20, output_type="np")
+ output = sd_pipe([prompt], guidance_scale=6.0, num_inference_steps=5, output_type="np")
image = output.images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.0385, 0.0252, 0.0234, 0.0287, 0.0358, 0.0287, 0.0276, 0.0235, 0.0010])
+ expected_slice = np.array([0.3602, 0.3688, 0.3652, 0.3895, 0.3782, 0.3747, 0.3927, 0.4241, 0.4327])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+
+ @slow
+ @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+ def test_stable_diffusion_text2img_intermediate_state(self):
+ number_of_steps = 0
+
+ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
+ test_callback_fn.has_been_called = True
+ nonlocal number_of_steps
+ number_of_steps += 1
+ if step == 0:
+ latents = latents.detach().cpu().numpy()
+ assert latents.shape == (1, 4, 64, 64)
+ latents_slice = latents[0, -3:, -3:, -1]
+ expected_slice = np.array(
+ [1.8285, 1.2857, -0.1024, 1.2406, -2.3068, 1.0747, -0.0818, -0.6520, -2.9506]
+ )
+ assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
+ elif step == 50:
+ latents = latents.detach().cpu().numpy()
+ assert latents.shape == (1, 4, 64, 64)
+ latents_slice = latents[0, -3:, -3:, -1]
+ expected_slice = np.array(
+ [1.1078, 1.5803, 0.2773, -0.0589, -1.7928, -0.3665, -0.4695, -1.0727, -1.1601]
+ )
+ assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2
+
+ test_callback_fn.has_been_called = False
+
+ pipe = StableDiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
+ )
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.enable_attention_slicing()
+
+ prompt = "Andromeda galaxy in a bottle"
+
+ generator = torch.Generator(device=torch_device).manual_seed(0)
+ with torch.autocast(torch_device):
+ pipe(
+ prompt=prompt,
+ num_inference_steps=50,
+ guidance_scale=7.5,
+ generator=generator,
+ callback=test_callback_fn,
+ callback_steps=1,
+ )
+ assert test_callback_fn.has_been_called
+ assert number_of_steps == 51
+
+ @slow
+ @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+ def test_stable_diffusion_img2img_intermediate_state(self):
+ number_of_steps = 0
+
+ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
+ test_callback_fn.has_been_called = True
+ nonlocal number_of_steps
+ number_of_steps += 1
+ if step == 0:
+ latents = latents.detach().cpu().numpy()
+ assert latents.shape == (1, 4, 64, 96)
+ latents_slice = latents[0, -3:, -3:, -1]
+ expected_slice = np.array([0.9052, -0.0184, 0.4810, 0.2898, 0.5851, 1.4920, 0.5362, 1.9838, 0.0530])
+ assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
+ elif step == 37:
+ latents = latents.detach().cpu().numpy()
+ assert latents.shape == (1, 4, 64, 96)
+ latents_slice = latents[0, -3:, -3:, -1]
+ expected_slice = np.array([0.7071, 0.7831, 0.8300, 1.8140, 1.7840, 1.9402, 1.3651, 1.6590, 1.2828])
+ assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2
+
+ test_callback_fn.has_been_called = False
+
+ init_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ "/img2img/sketch-mountains-input.jpg"
+ )
+ init_image = init_image.resize((768, 512))
+
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
+ )
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.enable_attention_slicing()
+
+ prompt = "A fantasy landscape, trending on artstation"
+
+ generator = torch.Generator(device=torch_device).manual_seed(0)
+ with torch.autocast(torch_device):
+ pipe(
+ prompt=prompt,
+ init_image=init_image,
+ strength=0.75,
+ num_inference_steps=50,
+ guidance_scale=7.5,
+ generator=generator,
+ callback=test_callback_fn,
+ callback_steps=1,
+ )
+ assert test_callback_fn.has_been_called
+ assert number_of_steps == 38
+
+ @slow
+ @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+ def test_stable_diffusion_inpaint_intermediate_state(self):
+ number_of_steps = 0
+
+ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
+ test_callback_fn.has_been_called = True
+ nonlocal number_of_steps
+ number_of_steps += 1
+ if step == 0:
+ latents = latents.detach().cpu().numpy()
+ assert latents.shape == (1, 4, 64, 64)
+ latents_slice = latents[0, -3:, -3:, -1]
+ expected_slice = np.array(
+ [-0.5472, 1.1218, -0.5505, -0.9390, -1.0794, 0.4063, 0.5158, 0.6429, -1.5246]
+ )
+ assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
+ elif step == 37:
+ latents = latents.detach().cpu().numpy()
+ assert latents.shape == (1, 4, 64, 64)
+ latents_slice = latents[0, -3:, -3:, -1]
+ expected_slice = np.array([0.4781, 1.1572, 0.6258, 0.2291, 0.2554, -0.1443, 0.7085, -0.1598, -0.5659])
+ assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
+
+ test_callback_fn.has_been_called = False
+
+ init_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ "/in_paint/overture-creations-5sI6fQgYIuo.png"
+ )
+ mask_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
+ )
+
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
+ )
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.enable_attention_slicing()
+
+ prompt = "A red cat sitting on a park bench"
+
+ generator = torch.Generator(device=torch_device).manual_seed(0)
+ with torch.autocast(torch_device):
+ pipe(
+ prompt=prompt,
+ init_image=init_image,
+ mask_image=mask_image,
+ strength=0.75,
+ num_inference_steps=50,
+ guidance_scale=7.5,
+ generator=generator,
+ callback=test_callback_fn,
+ callback_steps=1,
+ )
+ assert test_callback_fn.has_been_called
+ assert number_of_steps == 38
+
+ @slow
+ def test_stable_diffusion_onnx_intermediate_state(self):
+ number_of_steps = 0
+
+ def test_callback_fn(step: int, timestep: int, latents: np.ndarray) -> None:
+ test_callback_fn.has_been_called = True
+ nonlocal number_of_steps
+ number_of_steps += 1
+ if step == 0:
+ assert latents.shape == (1, 4, 64, 64)
+ latents_slice = latents[0, -3:, -3:, -1]
+ expected_slice = np.array(
+ [-0.5950, -0.3039, -1.1672, 0.1594, -1.1572, 0.6719, -1.9712, -0.0403, 0.9592]
+ )
+ assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
+ elif step == 5:
+ assert latents.shape == (1, 4, 64, 64)
+ latents_slice = latents[0, -3:, -3:, -1]
+ expected_slice = np.array(
+ [-0.4776, -0.0119, -0.8519, -0.0275, -0.9764, 0.9820, -0.3843, 0.3788, 1.2264]
+ )
+ assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
+
+ test_callback_fn.has_been_called = False
+
+ pipe = StableDiffusionOnnxPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
+ )
+ pipe.set_progress_bar_config(disable=None)
+
+ prompt = "Andromeda galaxy in a bottle"
+
+ np.random.seed(0)
+ pipe(prompt=prompt, num_inference_steps=5, guidance_scale=7.5, callback=test_callback_fn, callback_steps=1)
+ assert test_callback_fn.has_been_called
+ assert number_of_steps == 6
+
+ @slow
+ @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+ def test_stable_diffusion_accelerate_load_works(self):
+ if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
+ return
+
+ if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"):
+ return
+
+ model_id = "CompVis/stable-diffusion-v1-4"
+ _ = StableDiffusionPipeline.from_pretrained(
+ model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
+ ).to(torch_device)
+
+ @slow
+ @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
+ def test_stable_diffusion_accelerate_load_reduces_memory_footprint(self):
+ if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
+ return
+
+ if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"):
+ return
+
+ pipeline_id = "CompVis/stable-diffusion-v1-4"
+
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ tracemalloc.start()
+ pipeline_normal_load = StableDiffusionPipeline.from_pretrained(
+ pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
+ )
+ pipeline_normal_load.to(torch_device)
+ _, peak_normal = tracemalloc.get_traced_memory()
+ tracemalloc.stop()
+
+ del pipeline_normal_load
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ tracemalloc.start()
+ _ = StableDiffusionPipeline.from_pretrained(
+ pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
+ )
+ _, peak_accelerate = tracemalloc.get_traced_memory()
+
+ tracemalloc.stop()
+
+ assert peak_accelerate < peak_normal
diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
index 7377797bebfa..c3d4b9bc76f9 100755
--- a/tests/test_scheduler.py
+++ b/tests/test_scheduler.py
@@ -173,34 +173,6 @@ def test_step_shape(self):
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
- def test_pytorch_equal_numpy(self):
- kwargs = dict(self.forward_default_kwargs)
-
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- sample_pt = self.dummy_sample
- residual_pt = 0.1 * sample_pt
-
- sample = sample_pt.numpy()
- residual = 0.1 * sample
-
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(tensor_format="np", **scheduler_config)
-
- scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- scheduler.set_timesteps(num_inference_steps)
- scheduler_pt.set_timesteps(num_inference_steps)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- output = scheduler.step(residual, 1, sample, **kwargs).prev_sample
- output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **kwargs).prev_sample
-
- assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
-
def test_scheduler_outputs_equivalence(self):
def set_nan_tensor_to_zero(t):
t[t != t] = 0
@@ -229,7 +201,7 @@ def recursive_check(tuple_object, dict_object):
)
kwargs = dict(self.forward_default_kwargs)
- num_inference_steps = kwargs.pop("num_inference_steps", None)
+ num_inference_steps = kwargs.pop("num_inference_steps", 50)
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
@@ -254,6 +226,27 @@ def recursive_check(tuple_object, dict_object):
recursive_check(outputs_tuple, outputs_dict)
+ def test_scheduler_public_api(self):
+ for scheduler_class in self.scheduler_classes:
+ scheduler_config = self.get_scheduler_config()
+ scheduler = scheduler_class(**scheduler_config)
+ self.assertTrue(
+ hasattr(scheduler, "init_noise_sigma"),
+ f"{scheduler_class} does not implement a required attribute `init_noise_sigma`",
+ )
+ self.assertTrue(
+ hasattr(scheduler, "scale_model_input"),
+ f"{scheduler_class} does not implement a required class method `scale_model_input(sample, timestep)`",
+ )
+ self.assertTrue(
+ hasattr(scheduler, "step"),
+ f"{scheduler_class} does not implement a required class method `step(...)`",
+ )
+
+ sample = self.dummy_sample
+ scaled_sample = scheduler.scale_model_input(sample, 0.0)
+ self.assertEqual(sample.shape, scaled_sample.shape)
+
class DDPMSchedulerTest(SchedulerCommonTest):
scheduler_classes = (DDPMScheduler,)
@@ -266,7 +259,6 @@ def get_scheduler_config(self, **kwargs):
"beta_schedule": "linear",
"variance_type": "fixed_small",
"clip_sample": True,
- "tensor_format": "pt",
}
config.update(**kwargs)
@@ -305,10 +297,6 @@ def test_variance(self):
assert torch.sum(torch.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5
assert torch.sum(torch.abs(scheduler._get_variance(999) - 0.02)) < 1e-5
- # TODO Make DDPM Numpy compatible
- def test_pytorch_equal_numpy(self):
- pass
-
def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
@@ -387,7 +375,7 @@ def test_steps_offset(self):
scheduler_config = self.get_scheduler_config(steps_offset=1)
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(5)
- assert torch.equal(scheduler.timesteps, torch.tensor([801, 601, 401, 201, 1]))
+ assert torch.equal(scheduler.timesteps, torch.LongTensor([801, 601, 401, 201, 1]))
def test_betas(self):
for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
@@ -556,72 +544,6 @@ def full_loop(self, **config):
return sample
- def test_pytorch_equal_numpy(self):
- kwargs = dict(self.forward_default_kwargs)
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- sample_pt = self.dummy_sample
- residual_pt = 0.1 * sample_pt
- dummy_past_residuals_pt = [residual_pt + 0.2, residual_pt + 0.15, residual_pt + 0.1, residual_pt + 0.05]
-
- sample = sample_pt.numpy()
- residual = 0.1 * sample
- dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
-
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(tensor_format="np", **scheduler_config)
-
- scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- scheduler.set_timesteps(num_inference_steps)
- scheduler_pt.set_timesteps(num_inference_steps)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- # copy over dummy past residuals (must be done after set_timesteps)
- scheduler.ets = dummy_past_residuals[:]
- scheduler_pt.ets = dummy_past_residuals_pt[:]
-
- output = scheduler.step_prk(residual, 1, sample, **kwargs).prev_sample
- output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs).prev_sample
- assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
-
- output = scheduler.step_plms(residual, 1, sample, **kwargs).prev_sample
- output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, **kwargs).prev_sample
-
- assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
-
- def test_set_format(self):
- kwargs = dict(self.forward_default_kwargs)
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(tensor_format="np", **scheduler_config)
- scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- scheduler.set_timesteps(num_inference_steps)
- scheduler_pt.set_timesteps(num_inference_steps)
-
- for key, value in vars(scheduler).items():
- # we only allow `ets` attr to be a list
- assert not isinstance(value, list) or key in [
- "ets"
- ], f"Scheduler is not correctly set to np format, the attribute {key} is {type(value)}"
-
- # check if `scheduler.set_format` does convert correctly attrs to pt format
- for key, value in vars(scheduler_pt).items():
- # we only allow `ets` attr to be a list
- assert not isinstance(value, list) or key in [
- "ets"
- ], f"Scheduler is not correctly set to pt format, the attribute {key} is {type(value)}"
- assert not isinstance(
- value, np.ndarray
- ), f"Scheduler is not correctly set to pt format, the attribute {key} is {type(value)}"
-
def test_step_shape(self):
kwargs = dict(self.forward_default_kwargs)
@@ -669,7 +591,7 @@ def test_steps_offset(self):
scheduler.set_timesteps(10)
assert torch.equal(
scheduler.timesteps,
- torch.tensor(
+ torch.LongTensor(
[901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]
),
)
@@ -786,7 +708,6 @@ def get_scheduler_config(self, **kwargs):
"sigma_min": 0.01,
"sigma_max": 1348,
"sampling_eps": 1e-5,
- "tensor_format": "pt", # TODO add test for tensor formats
}
config.update(**kwargs)
@@ -936,7 +857,6 @@ def get_scheduler_config(self, **kwargs):
"beta_end": 0.02,
"beta_schedule": "linear",
"trained_betas": None,
- "tensor_format": "pt",
}
config.update(**kwargs)
@@ -947,7 +867,7 @@ def test_timesteps(self):
self.check_over_configs(num_train_timesteps=timesteps)
def test_betas(self):
- for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
+ for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]):
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
def test_schedules(self):
@@ -958,28 +878,6 @@ def test_time_indices(self):
for t in [0, 500, 800]:
self.check_over_forward(time_step=t)
- def test_pytorch_equal_numpy(self):
- for scheduler_class in self.scheduler_classes:
- sample_pt = self.dummy_sample
- residual_pt = 0.1 * sample_pt
-
- sample = sample_pt.numpy()
- residual = 0.1 * sample
-
- scheduler_config = self.get_scheduler_config()
- scheduler_config["tensor_format"] = "np"
- scheduler = scheduler_class(**scheduler_config)
-
- scheduler_config["tensor_format"] = "pt"
- scheduler_pt = scheduler_class(**scheduler_config)
-
- scheduler.set_timesteps(self.num_inference_steps)
- scheduler_pt.set_timesteps(self.num_inference_steps)
-
- output = scheduler.step(residual, 1, sample).prev_sample
- output_pt = scheduler_pt.step(residual_pt, 1, sample_pt).prev_sample
- assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
-
def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
@@ -988,14 +886,14 @@ def test_full_loop_no_noise(self):
scheduler.set_timesteps(self.num_inference_steps)
model = self.dummy_model()
- sample = self.dummy_sample_deter * scheduler.sigmas[0]
+ sample = self.dummy_sample_deter * scheduler.init_noise_sigma
for i, t in enumerate(scheduler.timesteps):
- sample = sample / ((scheduler.sigmas[i] ** 2 + 1) ** 0.5)
+ sample = scheduler.scale_model_input(sample, t)
model_output = model(sample, t)
- output = scheduler.step(model_output, i, sample)
+ output = scheduler.step(model_output, t, sample)
sample = output.prev_sample
result_sum = torch.sum(torch.abs(sample))
diff --git a/tests/test_training.py b/tests/test_training.py
index 519c5ab9e716..fd0828329ebd 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -18,8 +18,8 @@
import torch
from diffusers import DDIMScheduler, DDPMScheduler, UNet2DModel
-from diffusers.testing_utils import slow
from diffusers.training_utils import set_seed
+from diffusers.utils.testing_utils import slow
torch.backends.cuda.matmul.allow_tf32 = False
@@ -41,7 +41,6 @@ def test_training_step_equality(self):
beta_end=0.02,
beta_schedule="linear",
clip_sample=True,
- tensor_format="pt",
)
ddim_scheduler = DDIMScheduler(
num_train_timesteps=1000,
@@ -49,7 +48,6 @@ def test_training_step_equality(self):
beta_end=0.02,
beta_schedule="linear",
clip_sample=True,
- tensor_format="pt",
)
assert ddpm_scheduler.config.num_train_timesteps == ddim_scheduler.config.num_train_timesteps
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 000000000000..35cf57421014
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,164 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from diffusers import __version__
+from diffusers.utils import deprecate
+
+
+class DeprecateTester(unittest.TestCase):
+ higher_version = ".".join([str(int(__version__.split(".")[0]) + 1)] + __version__.split(".")[1:])
+ lower_version = "0.0.1"
+
+ def test_deprecate_function_arg(self):
+ kwargs = {"deprecated_arg": 4}
+
+ with self.assertWarns(DeprecationWarning) as warning:
+ output = deprecate("deprecated_arg", self.higher_version, "message", take_from=kwargs)
+
+ assert output == 4
+ assert (
+ str(warning.warning)
+ == f"The `deprecated_arg` argument is deprecated and will be removed in version {self.higher_version}."
+ " message"
+ )
+
+ def test_deprecate_function_arg_tuple(self):
+ kwargs = {"deprecated_arg": 4}
+
+ with self.assertWarns(DeprecationWarning) as warning:
+ output = deprecate(("deprecated_arg", self.higher_version, "message"), take_from=kwargs)
+
+ assert output == 4
+ assert (
+ str(warning.warning)
+ == f"The `deprecated_arg` argument is deprecated and will be removed in version {self.higher_version}."
+ " message"
+ )
+
+ def test_deprecate_function_args(self):
+ kwargs = {"deprecated_arg_1": 4, "deprecated_arg_2": 8}
+ with self.assertWarns(DeprecationWarning) as warning:
+ output_1, output_2 = deprecate(
+ ("deprecated_arg_1", self.higher_version, "Hey"),
+ ("deprecated_arg_2", self.higher_version, "Hey"),
+ take_from=kwargs,
+ )
+ assert output_1 == 4
+ assert output_2 == 8
+ assert (
+ str(warning.warnings[0].message)
+ == "The `deprecated_arg_1` argument is deprecated and will be removed in version"
+ f" {self.higher_version}. Hey"
+ )
+ assert (
+ str(warning.warnings[1].message)
+ == "The `deprecated_arg_2` argument is deprecated and will be removed in version"
+ f" {self.higher_version}. Hey"
+ )
+
+ def test_deprecate_function_incorrect_arg(self):
+ kwargs = {"deprecated_arg": 4}
+
+ with self.assertRaises(TypeError) as error:
+ deprecate(("wrong_arg", self.higher_version, "message"), take_from=kwargs)
+
+ assert "test_deprecate_function_incorrect_arg in" in str(error.exception)
+ assert "line" in str(error.exception)
+ assert "got an unexpected keyword argument `deprecated_arg`" in str(error.exception)
+
+ def test_deprecate_arg_no_kwarg(self):
+ with self.assertWarns(DeprecationWarning) as warning:
+ deprecate(("deprecated_arg", self.higher_version, "message"))
+
+ assert (
+ str(warning.warning)
+ == f"`deprecated_arg` is deprecated and will be removed in version {self.higher_version}. message"
+ )
+
+ def test_deprecate_args_no_kwarg(self):
+ with self.assertWarns(DeprecationWarning) as warning:
+ deprecate(
+ ("deprecated_arg_1", self.higher_version, "Hey"),
+ ("deprecated_arg_2", self.higher_version, "Hey"),
+ )
+ assert (
+ str(warning.warnings[0].message)
+ == f"`deprecated_arg_1` is deprecated and will be removed in version {self.higher_version}. Hey"
+ )
+ assert (
+ str(warning.warnings[1].message)
+ == f"`deprecated_arg_2` is deprecated and will be removed in version {self.higher_version}. Hey"
+ )
+
+ def test_deprecate_class_obj(self):
+ class Args:
+ arg = 5
+
+ with self.assertWarns(DeprecationWarning) as warning:
+ arg = deprecate(("arg", self.higher_version, "message"), take_from=Args())
+
+ assert arg == 5
+ assert (
+ str(warning.warning)
+ == f"The `arg` attribute is deprecated and will be removed in version {self.higher_version}. message"
+ )
+
+ def test_deprecate_class_objs(self):
+ class Args:
+ arg = 5
+ foo = 7
+
+ with self.assertWarns(DeprecationWarning) as warning:
+ arg_1, arg_2 = deprecate(
+ ("arg", self.higher_version, "message"),
+ ("foo", self.higher_version, "message"),
+ ("does not exist", self.higher_version, "message"),
+ take_from=Args(),
+ )
+
+ assert arg_1 == 5
+ assert arg_2 == 7
+ assert (
+ str(warning.warning)
+ == f"The `arg` attribute is deprecated and will be removed in version {self.higher_version}. message"
+ )
+ assert (
+ str(warning.warnings[0].message)
+ == f"The `arg` attribute is deprecated and will be removed in version {self.higher_version}. message"
+ )
+ assert (
+ str(warning.warnings[1].message)
+ == f"The `foo` attribute is deprecated and will be removed in version {self.higher_version}. message"
+ )
+
+ def test_deprecate_incorrect_version(self):
+ kwargs = {"deprecated_arg": 4}
+
+ with self.assertRaises(ValueError) as error:
+ deprecate(("wrong_arg", self.lower_version, "message"), take_from=kwargs)
+
+ assert (
+ str(error.exception)
+ == "The deprecation tuple ('wrong_arg', '0.0.1', 'message') should be removed since diffusers' version"
+ f" {__version__} is >= {self.lower_version}"
+ )
+
+ def test_deprecate_incorrect_no_standard_warn(self):
+ with self.assertWarns(DeprecationWarning) as warning:
+ deprecate(("deprecated_arg", self.higher_version, "This message is better!!!"), standard_warn=False)
+
+ assert str(warning.warning) == "This message is better!!!"
diff --git a/utils/custom_init_isort.py b/utils/custom_init_isort.py
index e1e079a99cde..44165d1fce23 100644
--- a/utils/custom_init_isort.py
+++ b/utils/custom_init_isort.py
@@ -200,7 +200,7 @@ def sort_imports(file, check_only=True):
indent = get_indent(block_lines[1])
# Slit the internal block into blocks of indent level 1.
internal_blocks = split_code_in_indented_blocks(internal_block_code, indent_level=indent)
- # We have two categories of import key: list or _import_structu[key].append/extend
+ # We have two categories of import key: list or _import_structure[key].append/extend
pattern = _re_direct_key if "_import_structure" in block_lines[0] else _re_indirect_key
# Grab the keys, but there is a trap: some lines are empty or just comments.
keys = [(pattern.search(b).groups()[0] if pattern.search(b) is not None else None) for b in internal_blocks]
@@ -210,17 +210,17 @@ def sort_imports(file, check_only=True):
# We reorder the blocks by leaving empty lines/comments as they were and reorder the rest.
count = 0
- reorderded_blocks = []
+ reordered_blocks = []
for i in range(len(internal_blocks)):
if keys[i] is None:
- reorderded_blocks.append(internal_blocks[i])
+ reordered_blocks.append(internal_blocks[i])
else:
block = sort_objects_in_import(internal_blocks[sorted_indices[count]])
- reorderded_blocks.append(block)
+ reordered_blocks.append(block)
count += 1
# And we put our main block back together with its first and last line.
- main_blocks[block_idx] = "\n".join(block_lines[:line_idx] + reorderded_blocks + [block_lines[-1]])
+ main_blocks[block_idx] = "\n".join(block_lines[:line_idx] + reordered_blocks + [block_lines[-1]])
if code != "\n".join(main_blocks):
if check_only: