Skip to content

Commit c3a830e

Browse files
committed
feat(exaple): add conditional image generation
1 parent b1fe170 commit c3a830e

File tree

4 files changed

+371
-0
lines changed

4 files changed

+371
-0
lines changed
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
## Training examples
2+
3+
Creating a training image set is [described in a different document](https://huggingface.co/docs/datasets/image_process#image-datasets).
4+
5+
### Installing the dependencies
6+
7+
Before running the scipts, make sure to install the library's training dependencies:
8+
9+
```bash
10+
pip install diffusers[training] accelerate datasets
11+
```
12+
13+
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
14+
15+
```bash
16+
accelerate config
17+
```
18+
19+
### conditional example
20+
21+
TODO: prepare examples
22+
23+
### Using your own data
24+
25+
To use your own dataset, there are 2 ways:
26+
- you can either provide your own folder as `--train_data_dir`
27+
- or you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument.
28+
29+
Below, we explain both in more detail.
30+
31+
#### Provide the dataset as a folder
32+
33+
If you provide your own folders with images, the script expects the following directory structure:
34+
35+
```bash
36+
data_dir/xxx.png
37+
data_dir/xxy.png
38+
data_dir/[...]/xxz.png
39+
```
40+
41+
In other words, the script will take care of gathering all images inside the folder. You can then run the script like this:
42+
43+
```bash
44+
accelerate launch train_conditional.py \
45+
--train_data_dir <path-to-train-directory> \
46+
<other-arguments>
47+
```
48+
49+
Internally, the script will use the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature which will automatically turn the folders into 🤗 Dataset objects.
50+
51+
#### Upload your data to the hub, as a (possibly private) repo
52+
53+
It's very easy (and convenient) to upload your image dataset to the hub using the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature available in 🤗 Datasets. Simply do the following:
54+
55+
```python
56+
from datasets import load_dataset
57+
58+
# example 1: local folder
59+
dataset = load_dataset("imagefolder", data_dir="path_to_your_folder")
60+
61+
# example 2: local files (suppoted formats are tar, gzip, zip, xz, rar, zstd)
62+
dataset = load_dataset("imagefolder", data_files="path_to_zip_file")
63+
64+
# example 3: remote files (supported formats are tar, gzip, zip, xz, rar, zstd)
65+
dataset = load_dataset("imagefolder", data_files="https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip")
66+
67+
# example 4: providing several splits
68+
dataset = load_dataset("imagefolder", data_files={"train": ["path/to/file1", "path/to/file2"], "test": ["path/to/file3", "path/to/file4"]})
69+
```
70+
71+
`ImageFolder` will create an `image` column containing the PIL-encoded images.
72+
73+
Next, push it to the hub!
74+
75+
```python
76+
# assuming you have ran the huggingface-cli login command in a terminal
77+
dataset.push_to_hub("name_of_your_dataset")
78+
79+
# if you want to push to a private repo, simply pass private=True:
80+
dataset.push_to_hub("name_of_your_dataset", private=True)
81+
```
82+
83+
and that's it! You can now train your model by simply setting the `--dataset_name` argument to the name of your dataset on the hub.
84+
85+
More on this can also be found in [this blog post](https://huggingface.co/blog/image-search-datasets).
86+
87+
#### How to use in the pipeline
88+
89+
```python
90+
# make sure you're logged in with `huggingface-cli login`
91+
from torch import autocast
92+
from diffusers import StableDiffusionPipeline
93+
94+
# Replace it to model that you want to use.
95+
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", use_auth_token=True)
96+
97+
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", unet=unet use_auth_token=True)
98+
pipe = pipe.to("cuda")
99+
100+
prompt = "a photo of an astronaut riding a horse on mars"
101+
with autocast("cuda"):
102+
image = pipe(prompt)["sample"][0]
103+
```
24.9 KB
Loading
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
accelerate
2+
torchvision
3+
datasets
Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
import argparse
2+
import math
3+
import os
4+
5+
import torch
6+
import torch.nn.functional as F
7+
8+
from accelerate import Accelerator
9+
from accelerate.logging import get_logger
10+
from datasets import load_dataset
11+
# from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
12+
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DConditionModel
13+
from diffusers.hub_utils import init_git_repo, push_to_hub
14+
from diffusers.optimization import get_scheduler
15+
from diffusers.training_utils import EMAModel
16+
from torchvision.transforms import (
17+
CenterCrop,
18+
Compose,
19+
InterpolationMode,
20+
Normalize,
21+
RandomHorizontalFlip,
22+
Resize,
23+
ToTensor,
24+
)
25+
from tqdm.auto import tqdm
26+
from transformers import CLIPTextModel, CLIPTokenizer
27+
28+
29+
logger = get_logger(__name__)
30+
31+
def main(args):
32+
logging_dir = os.path.join(args.output_dir, args.logging_dir)
33+
accelerator = Accelerator(
34+
gradient_accumulation_steps=args.gradient_accumulation_steps,
35+
mixed_precision=args.mixed_precision,
36+
log_with="tensorboard",
37+
logging_dir=logging_dir,
38+
)
39+
40+
# FIXME implement training script
41+
model = UNet2DConditionModel(
42+
sample_size=args.resolution,
43+
in_channels=3,
44+
out_channels=3,
45+
layers_per_block=2,
46+
block_out_channels=(128, 128, 256, 256, 512, 512),
47+
down_block_types=(
48+
"DownBlock2D",
49+
"DownBlock2D",
50+
"DownBlock2D",
51+
"DownBlock2D",
52+
"AttnDownBlock2D",
53+
"DownBlock2D",
54+
),
55+
up_block_types=(
56+
"UpBlock2D",
57+
"AttnUpBlock2D",
58+
"UpBlock2D",
59+
"UpBlock2D",
60+
"UpBlock2D",
61+
"UpBlock2D",
62+
),
63+
)
64+
65+
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
66+
optimizer = torch.optim.AdamW(
67+
model.parameters(),
68+
lr=args.learning_rate,
69+
betas=(args.adam_beta1, args.adam_beta2),
70+
weight_decay=args.adam_weight_decay,
71+
eps=args.adam_epsilon,
72+
)
73+
74+
# it is needed to generate tokenized input to train.
75+
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
76+
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
77+
78+
augmentations = Compose(
79+
[
80+
Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
81+
CenterCrop(args.resolution),
82+
RandomHorizontalFlip(),
83+
ToTensor(),
84+
Normalize([0.5], [0.5]),
85+
]
86+
)
87+
88+
if args.dataset_name is not None:
89+
dataset = load_dataset(
90+
args.dataset_name,
91+
args.dataset_config_name,
92+
cache_dir=args.cache_dir,
93+
use_auth_token=True if args.use_auth_token else None,
94+
split="train",
95+
)
96+
else:
97+
dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")
98+
99+
def transforms(examples):
100+
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
101+
return {"input": images}
102+
103+
dataset.set_transform(transforms)
104+
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)
105+
106+
lr_scheduler = get_scheduler(
107+
args.lr_scheduler,
108+
optimizer=optimizer,
109+
num_warmup_steps=args.lr_warmup_steps,
110+
num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
111+
)
112+
113+
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
114+
model, optimizer, train_dataloader, lr_scheduler
115+
)
116+
117+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
118+
119+
ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
120+
121+
if args.push_to_hub:
122+
repo = init_git_repo(args, at_init=True)
123+
124+
if accelerator.is_main_process:
125+
run = os.path.split(__file__)[-1].split(".")[0]
126+
accelerator.init_trackers(run)
127+
128+
global_step = 0
129+
for epoch in range(args.num_epochs):
130+
model.train()
131+
progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
132+
progress_bar.set_description(f"Epoch {epoch}")
133+
for step, batch in enumerate(train_dataloader):
134+
clean_images = batch["input"]
135+
# Sample noise that we'll add to the images
136+
noise = torch.randn(clean_images.shape).to(clean_images.device)
137+
bsz = clean_images.shape[0]
138+
# Sample a random timestep for each image
139+
timesteps = torch.randint(
140+
0, noise_scheduler.num_train_timesteps, (bsz,), device=clean_images.device
141+
).long()
142+
143+
# FIXME The input should probably select the appropriate one from the dataset.
144+
# Sample a text input
145+
uncond_input = tokenizer(
146+
[""] * args.eval_batch_size, padding="max_length", max_length=77, return_tensors="pt"
147+
)
148+
uncond_embeddings = text_encoder(uncond_input.input_ids.to(clean_images.device))[0]
149+
hidden_state = uncond_embeddings
150+
151+
# Add noise to the clean images according to the noise magnitude at each timestep
152+
# (this is the forward diffusion process)
153+
noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
154+
155+
with accelerator.accumulate(model):
156+
# Predict the noise residual
157+
# FIXME Implement a successfully trainable model and training script
158+
noise_pred = model(noisy_images, timesteps, encoder_hidden_states=hidden_state)["sample"]
159+
loss = F.mse_loss(noise_pred, noise)
160+
accelerator.backward(loss)
161+
162+
accelerator.clip_grad_norm_(model.parameters(), 1.0)
163+
optimizer.step()
164+
lr_scheduler.step()
165+
if args.use_ema:
166+
ema_model.step(model)
167+
optimizer.zero_grad()
168+
169+
# Checks if the accelerator has performed an optimization step behind the scenes
170+
if accelerator.sync_gradients:
171+
progress_bar.update(1)
172+
global_step += 1
173+
174+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
175+
if args.use_ema:
176+
logs["ema_decay"] = ema_model.decay
177+
progress_bar.set_postfix(**logs)
178+
accelerator.log(logs, step=global_step)
179+
progress_bar.close()
180+
181+
accelerator.wait_for_everyone()
182+
183+
# Generate sample images for visual inspection
184+
if accelerator.is_main_process:
185+
if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
186+
pipeline = DDPMPipeline(
187+
unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model),
188+
scheduler=noise_scheduler,
189+
)
190+
191+
generator = torch.manual_seed(0)
192+
# run pipeline in inference (sample random noise and denoise)
193+
images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy")["sample"]
194+
195+
# denormalize the images and save to tensorboard
196+
images_processed = (images * 255).round().astype("uint8")
197+
accelerator.trackers[0].writer.add_images(
198+
"test_samples", images_processed.transpose(0, 3, 1, 2), epoch
199+
)
200+
201+
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
202+
# save the model
203+
if args.push_to_hub:
204+
push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
205+
else:
206+
pipeline.save_pretrained(args.output_dir)
207+
accelerator.wait_for_everyone()
208+
209+
accelerator.end_training()
210+
211+
212+
if __name__ == "__main__":
213+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
214+
parser.add_argument("--local_rank", type=int, default=-1)
215+
parser.add_argument("--dataset_name", type=str, default=None)
216+
parser.add_argument("--dataset_config_name", type=str, default=None)
217+
parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.")
218+
parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
219+
parser.add_argument("--overwrite_output_dir", action="store_true")
220+
parser.add_argument("--cache_dir", type=str, default=None)
221+
parser.add_argument("--resolution", type=int, default=64)
222+
parser.add_argument("--train_batch_size", type=int, default=16)
223+
parser.add_argument("--eval_batch_size", type=int, default=16)
224+
parser.add_argument("--num_epochs", type=int, default=100)
225+
parser.add_argument("--save_images_epochs", type=int, default=10)
226+
parser.add_argument("--save_model_epochs", type=int, default=10)
227+
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
228+
parser.add_argument("--learning_rate", type=float, default=1e-4)
229+
parser.add_argument("--lr_scheduler", type=str, default="cosine")
230+
parser.add_argument("--lr_warmup_steps", type=int, default=500)
231+
parser.add_argument("--adam_beta1", type=float, default=0.95)
232+
parser.add_argument("--adam_beta2", type=float, default=0.999)
233+
parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
234+
parser.add_argument("--adam_epsilon", type=float, default=1e-08)
235+
parser.add_argument("--use_ema", action="store_true", default=True)
236+
parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
237+
parser.add_argument("--ema_power", type=float, default=3 / 4)
238+
parser.add_argument("--ema_max_decay", type=float, default=0.9999)
239+
parser.add_argument("--push_to_hub", action="store_true")
240+
parser.add_argument("--use_auth_token", action="store_true")
241+
parser.add_argument("--hub_token", type=str, default=None)
242+
parser.add_argument("--hub_model_id", type=str, default=None)
243+
parser.add_argument("--hub_private_repo", action="store_true")
244+
parser.add_argument("--logging_dir", type=str, default="logs")
245+
parser.add_argument(
246+
"--mixed_precision",
247+
type=str,
248+
default="no",
249+
choices=["no", "fp16", "bf16"],
250+
help=(
251+
"Whether to use mixed precision. Choose"
252+
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
253+
"and an Nvidia Ampere GPU."
254+
),
255+
)
256+
257+
args = parser.parse_args()
258+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
259+
if env_local_rank != -1 and env_local_rank != args.local_rank:
260+
args.local_rank = env_local_rank
261+
262+
if args.dataset_name is None and args.train_data_dir is None:
263+
raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
264+
265+
main(args)

0 commit comments

Comments
 (0)