103 changes: 103 additions & 0 deletions examples/conditional_image_generation/README.md
@@ -0,0 +1,103 @@
## Training examples

Creating a training image set is [described in a different document](https://huggingface.co/docs/datasets/image_process#image-datasets).

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

```bash
pip install diffusers[training] accelerate datasets
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```
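
If you prefer to skip the interactive prompts, recent versions of 🤗 Accelerate can also write a default configuration for you:

```bash
accelerate config default
```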

### Conditional example

TODO: prepare examples

### Using your own data

To use your own dataset, there are two ways:
- you can either provide your own folder as `--train_data_dir`,
- or you can upload your dataset to the hub (possibly as a private repo, if you prefer), and simply pass the `--dataset_name` argument.

Below, we explain both in more detail.

#### Provide the dataset as a folder

If you provide your own folders with images, the script expects the following directory structure:

```bash
data_dir/xxx.png
data_dir/xxy.png
data_dir/[...]/xxz.png
```

In other words, the script will take care of gathering all images inside the folder. You can then run the script like this:

```bash
accelerate launch train_conditional.py \
--train_data_dir <path-to-train-directory> \
<other-arguments>
```

Internally, the script will use the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature which will automatically turn the folders into 🤗 Dataset objects.
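
For reference, this is roughly what the script runs internally (the path below is a placeholder):

```python
from datasets import load_dataset

# Equivalent to passing --train_data_dir on the command line
dataset = load_dataset("imagefolder", data_dir="<path-to-train-directory>", split="train")
```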

#### Upload your data to the hub, as a (possibly private) repo

It's very easy (and convenient) to upload your image dataset to the hub using the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature available in 🤗 Datasets. Simply do the following:

```python
from datasets import load_dataset

# example 1: local folder
dataset = load_dataset("imagefolder", data_dir="path_to_your_folder")

# example 2: local files (supported formats are tar, gzip, zip, xz, rar, zstd)
dataset = load_dataset("imagefolder", data_files="path_to_zip_file")

# example 3: remote files (supported formats are tar, gzip, zip, xz, rar, zstd)
dataset = load_dataset("imagefolder", data_files="https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip")

# example 4: providing several splits
dataset = load_dataset("imagefolder", data_files={"train": ["path/to/file1", "path/to/file2"], "test": ["path/to/file3", "path/to/file4"]})
```

`ImageFolder` will create an `image` column containing the PIL-encoded images.
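
As a quick sanity check (the path below is a placeholder), you can confirm the column layout after loading:

```python
from datasets import load_dataset

dataset = load_dataset("imagefolder", data_dir="path_to_your_folder", split="train")
print(dataset.features)     # should include an "image" column
print(dataset[0]["image"])  # a PIL image
```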

Next, push it to the hub!

```python
# assuming you have run the huggingface-cli login command in a terminal
dataset.push_to_hub("name_of_your_dataset")

# if you want to push to a private repo, simply pass private=True:
dataset.push_to_hub("name_of_your_dataset", private=True)
```

And that's it! You can now train your model by simply setting the `--dataset_name` argument to the name of your dataset on the hub.
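
For example (the dataset name is a placeholder; `--use_auth_token` is only needed for private repos):

```bash
accelerate launch train_conditional.py \
  --dataset_name <your-username>/name_of_your_dataset \
  --use_auth_token \
  <other-arguments>
```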

More on this can also be found in [this blog post](https://huggingface.co/blog/image-search-datasets).

#### Using the trained model in a pipeline

```python
# make sure you're logged in with `huggingface-cli login`
from torch import autocast
from diffusers import StableDiffusionPipeline, UNet2DConditionModel

# Replace this with the model you want to use.
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", use_auth_token=True)

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", unet=unet, use_auth_token=True)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt)["sample"][0]
```
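
The pipeline returns PIL images, so the result can be saved directly:

```python
image.save("astronaut_rides_horse.png")
```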
3 changes: 3 additions & 0 deletions examples/conditional_image_generation/requirements.txt
@@ -0,0 +1,3 @@
accelerate
torchvision
datasets
265 changes: 265 additions & 0 deletions examples/conditional_image_generation/train_conditional.py
@@ -0,0 +1,265 @@
import argparse
import math
import os

import torch
import torch.nn.functional as F

from accelerate import Accelerator
from accelerate.logging import get_logger
from datasets import load_dataset
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DConditionModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from torchvision.transforms import (
CenterCrop,
Compose,
InterpolationMode,
Normalize,
RandomHorizontalFlip,
Resize,
ToTensor,
)
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer


logger = get_logger(__name__)


def main(args):
logging_dir = os.path.join(args.output_dir, args.logging_dir)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with="tensorboard",
logging_dir=logging_dir,
)

# FIXME implement training script
model = UNet2DConditionModel(
sample_size=args.resolution,
in_channels=3,
out_channels=3,
layers_per_block=2,
block_out_channels=(128, 128, 256, 256, 512, 512),
down_block_types=(
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
"AttnDownBlock2D",
"DownBlock2D",
),
up_block_types=(
"UpBlock2D",
"AttnUpBlock2D",
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
),
)

noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
optimizer = torch.optim.AdamW(
model.parameters(),
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)

    # The text encoder and tokenizer are needed to produce the conditioning embeddings during training.
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

augmentations = Compose(
[
Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
CenterCrop(args.resolution),
RandomHorizontalFlip(),
ToTensor(),
Normalize([0.5], [0.5]),
]
)

if args.dataset_name is not None:
dataset = load_dataset(
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
use_auth_token=True if args.use_auth_token else None,
split="train",
)
else:
dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")

def transforms(examples):
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
return {"input": images}

dataset.set_transform(transforms)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)

lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps,
num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
)

model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, lr_scheduler
)

num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

    # Keep an exponential moving average (EMA) of the model weights for smoother checkpoints
    ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)

if args.push_to_hub:
repo = init_git_repo(args, at_init=True)

if accelerator.is_main_process:
run = os.path.split(__file__)[-1].split(".")[0]
accelerator.init_trackers(run)

global_step = 0
for epoch in range(args.num_epochs):
model.train()
progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
progress_bar.set_description(f"Epoch {epoch}")
for step, batch in enumerate(train_dataloader):
clean_images = batch["input"]
# Sample noise that we'll add to the images
noise = torch.randn(clean_images.shape).to(clean_images.device)
bsz = clean_images.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
0, noise_scheduler.num_train_timesteps, (bsz,), device=clean_images.device
).long()

# FIXME The input should probably select the appropriate one from the dataset.
# Sample a text input
            # Use the actual batch size so the embeddings match the image batch
            uncond_input = tokenizer(
                [""] * bsz, padding="max_length", max_length=77, return_tensors="pt"
            )
uncond_embeddings = text_encoder(uncond_input.input_ids.to(clean_images.device))[0]
hidden_state = uncond_embeddings

# Add noise to the clean images according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

with accelerator.accumulate(model):
# Predict the noise residual
# FIXME Implement a successfully trainable model and training script
noise_pred = model(noisy_images, timesteps, encoder_hidden_states=hidden_state)["sample"]
loss = F.mse_loss(noise_pred, noise)
accelerator.backward(loss)

accelerator.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
lr_scheduler.step()
if args.use_ema:
ema_model.step(model)
optimizer.zero_grad()

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1

logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
if args.use_ema:
logs["ema_decay"] = ema_model.decay
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
progress_bar.close()

accelerator.wait_for_everyone()

# Generate sample images for visual inspection
if accelerator.is_main_process:
if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
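                # FIXME DDPMPipeline samples unconditionally (it calls the UNet without
                # encoder_hidden_states), so a UNet2DConditionModel will need a
                # conditional sampling pipeline here instead.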
pipeline = DDPMPipeline(
unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model),
scheduler=noise_scheduler,
)

generator = torch.manual_seed(0)
# run pipeline in inference (sample random noise and denoise)
images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy")["sample"]

# denormalize the images and save to tensorboard
images_processed = (images * 255).round().astype("uint8")
accelerator.trackers[0].writer.add_images(
"test_samples", images_processed.transpose(0, 3, 1, 2), epoch
)

if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
# save the model
if args.push_to_hub:
push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
else:
pipeline.save_pretrained(args.output_dir)
accelerator.wait_for_everyone()

accelerator.end_training()


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--dataset_name", type=str, default=None)
parser.add_argument("--dataset_config_name", type=str, default=None)
parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.")
parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
parser.add_argument("--overwrite_output_dir", action="store_true")
parser.add_argument("--cache_dir", type=str, default=None)
parser.add_argument("--resolution", type=int, default=64)
parser.add_argument("--train_batch_size", type=int, default=16)
parser.add_argument("--eval_batch_size", type=int, default=16)
parser.add_argument("--num_epochs", type=int, default=100)
parser.add_argument("--save_images_epochs", type=int, default=10)
parser.add_argument("--save_model_epochs", type=int, default=10)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--lr_scheduler", type=str, default="cosine")
parser.add_argument("--lr_warmup_steps", type=int, default=500)
parser.add_argument("--adam_beta1", type=float, default=0.95)
parser.add_argument("--adam_beta2", type=float, default=0.999)
parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
parser.add_argument("--adam_epsilon", type=float, default=1e-08)
parser.add_argument("--use_ema", action="store_true", default=True)
parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
parser.add_argument("--ema_power", type=float, default=3 / 4)
parser.add_argument("--ema_max_decay", type=float, default=0.9999)
parser.add_argument("--push_to_hub", action="store_true")
parser.add_argument("--use_auth_token", action="store_true")
parser.add_argument("--hub_token", type=str, default=None)
parser.add_argument("--hub_model_id", type=str, default=None)
parser.add_argument("--hub_private_repo", action="store_true")
parser.add_argument("--logging_dir", type=str, default="logs")
parser.add_argument(
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). "
            "Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU."
        ),
)

args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank

if args.dataset_name is None and args.train_data_dir is None:
raise ValueError("You must specify either a dataset name from the hub or a train data directory.")

main(args)