28 changes: 20 additions & 8 deletions examples/unconditional_image_generation/train_unconditional.py
@@ -194,16 +194,28 @@ def parse_args():
     )

     parser.add_argument(
-        "--predict_epsilon",
-        action="store_true",
-        default=True,
-        help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
+        "--prediction_type",
+        type=str,
+        default="epsilon",
+        help=(
+            "Whether the model should predict the 'epsilon'/noise error, directly the reconstructed image 'x0', or the"
+            " velocity of the ODE 'velocity'."
+        ),
     )

     parser.add_argument("--ddpm_num_steps", type=int, default=1000)
     parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")

     args = parser.parse_args()

+    message = (
+        "Please make sure to instantiate your training with `--prediction_type=epsilon` instead. E.g. `scheduler ="
+        " DDPMScheduler.from_config(<model_id>, prediction_type=epsilon)`."
+    )
+    predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=args)
+    if predict_epsilon:
+        args.prediction_type = "epsilon"
+
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
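
The deprecation shim above keeps the old boolean flag working for one release while pointing users at the new string argument. In user code the migration looks like this (a sketch; the checkpoint id is illustrative, not part of this PR):

    # Old, deprecated: DDPMScheduler(..., predict_epsilon=True)
    # New: pass a prediction_type string, as the deprecation message instructs
    from diffusers import DDPMScheduler

    scheduler = DDPMScheduler.from_config("google/ddpm-cifar10-32", prediction_type="epsilon")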
@@ -256,13 +268,13 @@ def main(args):
             "UpBlock2D",
         ),
     )
-    accepts_predict_epsilon = "predict_epsilon" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
+    accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())

-    if accepts_predict_epsilon:
+    if accepts_prediction_type:
         noise_scheduler = DDPMScheduler(
             num_train_timesteps=args.ddpm_num_steps,
             beta_schedule=args.ddpm_beta_schedule,
-            predict_epsilon=args.predict_epsilon,
+            prediction_type=args.prediction_type,
         )
     else:
         noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
@@ -365,7 +377,7 @@ def transforms(examples):
                 # Predict the noise residual
                 model_output = model(noisy_images, timesteps).sample

-                if args.predict_epsilon:
+                if args.prediction_type == "epsilon":
                     loss = F.mse_loss(model_output, noise)  # this could have different weights!
                 else:
                     alpha_t = _extract_into_tensor(
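
For reference, each prediction type trains against a different regression target. A minimal sketch, using the names from the help string above ('epsilon', 'x0', 'velocity') and the usual coefficients alpha_t = sqrt(alphas_cumprod_t), sigma_t = sqrt(1 - alphas_cumprod_t):

    def training_target(prediction_type, clean_images, noise, alpha_t, sigma_t):
        if prediction_type == "epsilon":  # regress onto the added noise (default branch above)
            return noise
        if prediction_type == "x0":  # regress onto the reconstructed image directly
            return clean_images
        return alpha_t * noise - sigma_t * clean_images  # 'velocity' (v-prediction)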
227 changes: 227 additions & 0 deletions examples/v_prediction/train_butterflies.py
@@ -0,0 +1,227 @@
import glob
import os
from dataclasses import dataclass

import torch
import torch.nn.functional as F

from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDIMPipeline, DDIMScheduler, DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_cosine_schedule_with_warmup
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm


@dataclass
class TrainingConfig:
image_size = 128 # the generated image resolution
train_batch_size = 16
eval_batch_size = 16 # how many images to sample during evaluation
num_epochs = 50
gradient_accumulation_steps = 1
learning_rate = 5e-5
lr_warmup_steps = 500
save_image_epochs = 10
save_model_epochs = 30
mixed_precision = "fp16" # `no` for float32, `fp16` for automatic mixed precision
output_dir = "ddim-butterflies-128-v-diffusion" # the model namy locally and on the HF Hub

push_to_hub = False # whether to upload the saved model to the HF Hub
hub_private_repo = False
overwrite_output_dir = True # overwrite the old model when re-running the notebook
seed = 0


config = TrainingConfig()


config.dataset_name = "huggan/smithsonian_butterflies_subset"
dataset = load_dataset(config.dataset_name, split="train")


preprocess = transforms.Compose(
[
transforms.Resize((config.image_size, config.image_size)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)


def transform(examples):
images = [preprocess(image.convert("RGB")) for image in examples["image"]]
return {"images": images}


dataset.set_transform(transform)


train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)


model = UNet2DModel(
sample_size=config.image_size, # the target image resolution
in_channels=3, # the number of input channels, 3 for RGB images
out_channels=3, # the number of output channels
layers_per_block=2, # how many ResNet layers to use per UNet block
    block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block
down_block_types=(
"DownBlock2D", # a regular ResNet downsampling block
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
"AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention
"DownBlock2D",
),
up_block_types=(
"UpBlock2D", # a regular ResNet upsampling block
"AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
),
)
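
# Note: with six down blocks, the 128x128 input is spatially halved five times
# (128 -> 64 -> 32 -> 16 -> 8 -> 4); the up path mirrors it back to 128x128.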


if config.output_dir.startswith("ddpm"):
noise_scheduler = DDPMScheduler(
num_train_timesteps=1000,
beta_schedule="squaredcos_cap_v2",
variance_type="v_diffusion",
prediction_type="velocity",
)
else:
noise_scheduler = DDIMScheduler(
num_train_timesteps=1000,
beta_schedule="squaredcos_cap_v2",
variance_type="v_diffusion",
prediction_type="velocity",
)
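
# `get_alpha_sigma` used below comes from the experimental v-diffusion scheduler
# API in this PR. A minimal sketch of what it is assumed to return, under the
# standard v-diffusion convention:
#
#     alpha_t = alphas_cumprod[t].sqrt()        # signal rate
#     sigma_t = (1 - alphas_cumprod[t]).sqrt()  # noise rate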


optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)


lr_scheduler = get_cosine_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=config.lr_warmup_steps,
num_training_steps=(len(train_dataloader) * config.num_epochs),
)


def make_grid(images, rows, cols):
w, h = images[0].size
grid = Image.new("RGB", size=(cols * w, rows * h))
for i, image in enumerate(images):
grid.paste(image, box=(i % cols * w, i // cols * h))
return grid


def evaluate(config, epoch, pipeline):
    # Sample some images from random noise (this is the reverse diffusion process).
# The default pipeline output type is `List[PIL.Image]`
images = pipeline(
batch_size=config.eval_batch_size,
generator=torch.manual_seed(config.seed),
).images

# Make a grid out of the images
image_grid = make_grid(images, rows=4, cols=4)

# Save the images
test_dir = os.path.join(config.output_dir, "samples")
os.makedirs(test_dir, exist_ok=True)
image_grid.save(f"{test_dir}/{epoch:04d}.png")


def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
# Initialize accelerator and tensorboard logging
accelerator = Accelerator(
mixed_precision=config.mixed_precision,
gradient_accumulation_steps=config.gradient_accumulation_steps,
log_with="tensorboard",
logging_dir=os.path.join(config.output_dir, "logs"),
)
if accelerator.is_main_process:
if config.push_to_hub:
repo = init_git_repo(config, at_init=True)
accelerator.init_trackers("train_example")

# Prepare everything
    # There is no specific order to remember; you just need to unpack the
    # objects in the same order you gave them to the prepare method.
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, lr_scheduler
)

global_step = 0

if config.output_dir.startswith("ddpm"):
pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
else:
pipeline = DDIMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)

evaluate(config, 0, pipeline)

# Now you train the model
for epoch in range(config.num_epochs):
progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
progress_bar.set_description(f"Epoch {epoch}")

for step, batch in enumerate(train_dataloader):
clean_images = batch["images"]
# Sample noise to add to the images
noise = torch.randn(clean_images.shape).to(clean_images.device)
bs = clean_images.shape[0]

# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device).long()

with accelerator.accumulate(model):
                # v-prediction: form the noisy sample z_t, then regress the model
                # output onto the velocity target v = alpha_t * noise - sigma_t * x0
                alpha_t, sigma_t = noise_scheduler.get_alpha_sigma(clean_images, timesteps, accelerator.device)
                z_t = alpha_t * clean_images + sigma_t * noise
                v_pred = model(z_t, timesteps).sample
                v = alpha_t * noise - sigma_t * clean_images
                loss = F.mse_loss(v_pred, v)
accelerator.backward(loss)

accelerator.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()

progress_bar.update(1)
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
global_step += 1

# After each epoch you optionally sample some demo images with evaluate() and save the model
if accelerator.is_main_process:
if config.output_dir.startswith("ddpm"):
pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
else:
pipeline = DDIMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)

if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
evaluate(config, epoch, pipeline)

if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
if config.push_to_hub:
push_to_hub(config, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=True)
else:
pipeline.save_pretrained(config.output_dir)


args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)

train_loop(*args)

# Display the latest sample grid saved during training
sample_images = sorted(glob.glob(f"{config.output_dir}/samples/*.png"))
Image.open(sample_images[-1]).show()
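
# To sample from the trained model later, reload the saved pipeline
# (a sketch, assuming the default DDIM output directory above):
#     pipeline = DDIMPipeline.from_pretrained(config.output_dir)
#     images = pipeline(batch_size=4).images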