
Commit 79eb3d0

Ttl, williamberman, sayakpaul, and patrickvonplaten authored
Controlnet training (#2545)

* Controlnet training code initial commit. Works with circle dataset: https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md
* Script for adding a controlnet to existing model
* Fix control image transform. Control image should be in 0..1 range.
* Add license header and remove more unused configs
* controlnet training readme
* Allow nonlocal model in add_controlnet.py
* Formatting
* Remove unused code
* Code quality
* Initialize controlnet in training script
* Formatting
* Address review comments
* doc style
* explicit constructor args and submodule names
* hub dataset (NOTE - not tested)
* empty prompts
* add conditioning image
* rename
* remove instance data dir
* image_transforms -> -1,1 . conditioning_image_transformers -> 0, 1
* nits
* remove local rank config. I think this isn't necessary in any of our training scripts
* validation images
* proportion_empty_prompts typo
* weight copying to controlnet bug
* call log validation fix
* fix
* gitignore wandb
* fix progress bar and resume from checkpoint iteration
* initial step fix
* log multiple images
* fix
* fixes
* tracker project name configurable
* misc
* add controlnet requirements.txt
* update docs
* image labels
* small fixes
* log validation using existing models for pipeline
* fix for deepspeed saving
* memory usage docs
* Update examples/controlnet/train_controlnet.py (2 commits, Co-authored-by: Sayak Paul <[email protected]>)
* Update examples/controlnet/README.md (8 commits, Co-authored-by: Sayak Paul <[email protected]>)
* remove extra is main process check
* link to dataset in intro paragraph
* remove unnecessary paragraph
* note on deepspeed
* Update examples/controlnet/README.md (Co-authored-by: Patrick von Platen <[email protected]>)
* assert -> value error
* weights and biases note
* move images out of git
* remove .gitignore

---------

Co-authored-by: William Berman <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>

1 parent 279f744 commit 79eb3d0

File tree

6 files changed: +1406 −3 lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -172,3 +172,5 @@ tags

 # ruff
 .ruff_cache
+
+wandb

examples/controlnet/README.md

Lines changed: 269 additions & 0 deletions
@@ -0,0 +1,269 @@
# ControlNet training example

[Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543) by Lvmin Zhang and Maneesh Agrawala.

This example is based on the [training example in the original ControlNet repository](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md). It trains a ControlNet to fill circles using a [small synthetic dataset](https://huggingface.co/datasets/fusing/fill50k).

## Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the installation up to date, as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:

```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```

Then cd into the example folder and run

```bash
pip install -r requirements.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

Or, for a default accelerate configuration without answering questions about your environment:

```bash
accelerate config default
```

Or, if your environment doesn't support an interactive shell (e.g., a notebook):

```python
from accelerate.utils import write_basic_config
write_basic_config()
```

## Circle filling dataset

The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script.
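
If you want a quick look at the data before training, the following minimal sketch loads the re-uploaded dataset with `datasets`. The column names used below (`image`, `conditioning_image`, `text`) are the training script's defaults; verify them against your copy of the dataset.

```python
from datasets import load_dataset

# load the training split of the re-uploaded fill50k dataset
dataset = load_dataset("fusing/fill50k", split="train")
print(dataset)  # row count and column names

example = dataset[0]
print(example["text"])                                    # caption for the example
example["image"].save("./example_image.png")              # target image
example["conditioning_image"].save("./example_cond.png")  # conditioning image
```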

Our training examples use [Stable Diffusion 1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) because the original set of ControlNet models was trained from it. However, ControlNet can be trained to augment any Stable Diffusion compatible model (such as [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) or [stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1)).

## Training

Our training examples use two test conditioning images. They can be downloaded by running

```sh
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
```

```bash
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="path to save model"

accelerate launch train_controlnet.py \
 --pretrained_model_name_or_path=$MODEL_DIR \
 --output_dir=$OUTPUT_DIR \
 --dataset_name=fusing/fill50k \
 --resolution=512 \
 --learning_rate=1e-5 \
 --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
 --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
 --train_batch_size=4
```

This default configuration requires ~38 GB of VRAM.

By default, the training script logs outputs to tensorboard. Pass `--report_to wandb` to use Weights & Biases.
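
For example, the flag is simply appended to the launch command (a sketch reusing the flags from the command above; it assumes you have already run `wandb login`):

```bash
accelerate launch train_controlnet.py \
 --pretrained_model_name_or_path=$MODEL_DIR \
 --output_dir=$OUTPUT_DIR \
 --dataset_name=fusing/fill50k \
 --resolution=512 \
 --learning_rate=1e-5 \
 --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
 --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
 --train_batch_size=4 \
 --report_to wandb
```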

Gradient accumulation with a smaller batch size can be used to reduce training memory requirements to ~20 GB of VRAM.

```bash
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="path to save model"

accelerate launch train_controlnet.py \
 --pretrained_model_name_or_path=$MODEL_DIR \
 --output_dir=$OUTPUT_DIR \
 --dataset_name=fusing/fill50k \
 --resolution=512 \
 --learning_rate=1e-5 \
 --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
 --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
 --train_batch_size=1 \
 --gradient_accumulation_steps=4
```

## Example results

#### After 300 steps with batch size 8

| conditioning image | output |
|--------------------|:------:|
| | red circle with blue background |
| ![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![red circle with blue background](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/red_circle_with_blue_background_300_steps.png) |
| | cyan circle with brown floral background |
| ![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png) | ![cyan circle with brown floral background](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/cyan_circle_with_brown_floral_background_300_steps.png) |

#### After 6000 steps with batch size 8

| conditioning image | output |
|--------------------|:------:|
| | red circle with blue background |
| ![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![red circle with blue background](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/red_circle_with_blue_background_6000_steps.png) |
| | cyan circle with brown floral background |
| ![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png) | ![cyan circle with brown floral background](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/cyan_circle_with_brown_floral_background_6000_steps.png) |

## Training on a 16 GB GPU

Optimizations:
- Gradient checkpointing
- bitsandbytes' 8-bit optimizer

See the [bitsandbytes installation instructions](https://github.com/TimDettmers/bitsandbytes#requirements--installation).

```bash
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="path to save model"

accelerate launch train_controlnet.py \
 --pretrained_model_name_or_path=$MODEL_DIR \
 --output_dir=$OUTPUT_DIR \
 --dataset_name=fusing/fill50k \
 --resolution=512 \
 --learning_rate=1e-5 \
 --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
 --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
 --train_batch_size=1 \
 --gradient_accumulation_steps=4 \
 --gradient_checkpointing \
 --use_8bit_adam
```

## Training on a 12 GB GPU

Optimizations:
- Gradient checkpointing
- bitsandbytes' 8-bit optimizer
- xformers
- setting gradients to `None`

```bash
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="path to save model"

accelerate launch train_controlnet.py \
 --pretrained_model_name_or_path=$MODEL_DIR \
 --output_dir=$OUTPUT_DIR \
 --dataset_name=fusing/fill50k \
 --resolution=512 \
 --learning_rate=1e-5 \
 --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
 --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
 --train_batch_size=1 \
 --gradient_accumulation_steps=4 \
 --gradient_checkpointing \
 --use_8bit_adam \
 --enable_xformers_memory_efficient_attention \
 --set_grads_to_none
```

When using `enable_xformers_memory_efficient_attention`, please make sure to install `xformers` with `pip install xformers`.

## Training on an 8 GB GPU

We have not exhaustively tested DeepSpeed support for ControlNet. While the configuration does
save memory, we have not confirmed that it trains successfully. You will very likely
have to make changes to the config to have a successful training run.

Optimizations:
- Gradient checkpointing
- xformers
- setting gradients to `None`
- DeepSpeed stage 2 with parameter and optimizer offloading
- fp16 mixed precision

[DeepSpeed](https://www.deepspeed.ai/) can offload tensors from VRAM to either
CPU or NVMe. This requires significantly more RAM (about 25 GB).

Use `accelerate config` to enable DeepSpeed stage 2.

The relevant parts of the resulting accelerate config file are:

```yaml
compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 4
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
```

See the [Accelerate DeepSpeed documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more DeepSpeed configuration options.

Changing the default Adam optimizer to DeepSpeed's Adam
`deepspeed.ops.adam.DeepSpeedCPUAdam` gives a substantial speedup, but
it requires a CUDA toolchain with the same version as PyTorch. The 8-bit optimizer
does not seem to be compatible with DeepSpeed at the moment.
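
As a rough sketch of what that swap looks like inside a training loop (the example script does not do this for you; `controlnet` stands in for the module being trained, and the hyperparameters are placeholders):

```python
# hypothetical sketch: swap the default optimizer for DeepSpeed's CPU-offload Adam
from deepspeed.ops.adam import DeepSpeedCPUAdam

optimizer = DeepSpeedCPUAdam(
    controlnet.parameters(),  # parameters being trained
    lr=1e-5,                  # placeholder learning rate
    weight_decay=1e-2,        # placeholder weight decay
)
```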

```bash
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="path to save model"

accelerate launch train_controlnet.py \
 --pretrained_model_name_or_path=$MODEL_DIR \
 --output_dir=$OUTPUT_DIR \
 --dataset_name=fusing/fill50k \
 --resolution=512 \
 --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
 --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
 --train_batch_size=1 \
 --gradient_accumulation_steps=4 \
 --gradient_checkpointing \
 --enable_xformers_memory_efficient_attention \
 --set_grads_to_none \
 --mixed_precision fp16
```

## Performing inference with the trained ControlNet

The trained model can be run with the same pipeline as the original ControlNet, substituting in the newly trained ControlNet.
Set `base_model_path` and `controlnet_path` to the values `--pretrained_model_name_or_path` and
`--output_dir` were respectively set to in the training script.

```py
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
import torch

base_model_path = "path to model"
controlnet_path = "path to controlnet"

controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    base_model_path, controlnet=controlnet, torch_dtype=torch.float16
)

# speed up the diffusion process with a faster scheduler and memory optimizations
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# remove the following line if xformers is not installed
pipe.enable_xformers_memory_efficient_attention()

pipe.enable_model_cpu_offload()

control_image = load_image("./conditioning_image_1.png")
prompt = "pale golden rod circle with old lace background"

# generate image
generator = torch.manual_seed(0)
image = pipe(
    prompt, num_inference_steps=20, generator=generator, image=control_image
).images[0]

image.save("./output.png")
```

examples/controlnet/requirements.txt

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
accelerate
torchvision
transformers>=4.25.1
ftfy
tensorboard
datasets
