Skip to content

Commit 11362ae

Browse files
bglick13Ben GlickenhausNathan Lambert
authored
V prediction ddim (#1313)
* v diffusion support for ddpm * quality and style * variable name consistency * missing base case * pass prediction type along in the pipeline * put prediction type in scheduler config * style * try to train on ddim * changes to ddim * ddim v prediction works to train butterflies example * fix bad merge, style and quality * try to fix broken doc strings * second pass * one more * white space * Update src/diffusers/schedulers/scheduling_ddim.py * remove extra lines * Update src/diffusers/schedulers/scheduling_ddim.py Co-authored-by: Ben Glickenhaus <[email protected]> Co-authored-by: Nathan Lambert <[email protected]>
1 parent 56164f5 commit 11362ae

File tree

2 files changed

+299
-17
lines changed

2 files changed

+299
-17
lines changed
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
import glob
2+
import os
3+
from dataclasses import dataclass
4+
5+
import torch
6+
import torch.nn.functional as F
7+
8+
from accelerate import Accelerator
9+
from datasets import load_dataset
10+
from diffusers import DDIMPipeline, DDIMScheduler, DDPMPipeline, DDPMScheduler, UNet2DModel
11+
from diffusers.hub_utils import init_git_repo, push_to_hub
12+
from diffusers.optimization import get_cosine_schedule_with_warmup
13+
from PIL import Image
14+
from torchvision import transforms
15+
from tqdm.auto import tqdm
16+
17+
18+
@dataclass
19+
class TrainingConfig:
20+
image_size = 128 # the generated image resolution
21+
train_batch_size = 16
22+
eval_batch_size = 16 # how many images to sample during evaluation
23+
num_epochs = 50
24+
gradient_accumulation_steps = 1
25+
learning_rate = 5e-5
26+
lr_warmup_steps = 500
27+
save_image_epochs = 10
28+
save_model_epochs = 30
29+
mixed_precision = "fp16" # `no` for float32, `fp16` for automatic mixed precision
30+
output_dir = "ddim-butterflies-128-v-diffusion" # the model namy locally and on the HF Hub
31+
32+
push_to_hub = False # whether to upload the saved model to the HF Hub
33+
hub_private_repo = False
34+
overwrite_output_dir = True # overwrite the old model when re-running the notebook
35+
seed = 0
36+
37+
38+
config = TrainingConfig()
39+
40+
41+
config.dataset_name = "huggan/smithsonian_butterflies_subset"
42+
dataset = load_dataset(config.dataset_name, split="train")
43+
44+
45+
preprocess = transforms.Compose(
46+
[
47+
transforms.Resize((config.image_size, config.image_size)),
48+
transforms.RandomHorizontalFlip(),
49+
transforms.ToTensor(),
50+
transforms.Normalize([0.5], [0.5]),
51+
]
52+
)
53+
54+
55+
def transform(examples):
56+
images = [preprocess(image.convert("RGB")) for image in examples["image"]]
57+
return {"images": images}
58+
59+
60+
dataset.set_transform(transform)
61+
62+
63+
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)
64+
65+
66+
model = UNet2DModel(
67+
sample_size=config.image_size, # the target image resolution
68+
in_channels=3, # the number of input channels, 3 for RGB images
69+
out_channels=3, # the number of output channels
70+
layers_per_block=2, # how many ResNet layers to use per UNet block
71+
block_out_channels=(128, 128, 256, 256, 512, 512), # the number of output channes for each UNet block
72+
down_block_types=(
73+
"DownBlock2D", # a regular ResNet downsampling block
74+
"DownBlock2D",
75+
"DownBlock2D",
76+
"DownBlock2D",
77+
"AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention
78+
"DownBlock2D",
79+
),
80+
up_block_types=(
81+
"UpBlock2D", # a regular ResNet upsampling block
82+
"AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention
83+
"UpBlock2D",
84+
"UpBlock2D",
85+
"UpBlock2D",
86+
"UpBlock2D",
87+
),
88+
)
89+
90+
91+
if config.output_dir.startswith("ddpm"):
92+
noise_scheduler = DDPMScheduler(
93+
num_train_timesteps=1000,
94+
beta_schedule="squaredcos_cap_v2",
95+
variance_type="v_diffusion",
96+
prediction_type="v",
97+
)
98+
else:
99+
noise_scheduler = DDIMScheduler(
100+
num_train_timesteps=1000,
101+
beta_schedule="squaredcos_cap_v2",
102+
variance_type="v_diffusion",
103+
prediction_type="v",
104+
)
105+
106+
107+
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
108+
109+
110+
lr_scheduler = get_cosine_schedule_with_warmup(
111+
optimizer=optimizer,
112+
num_warmup_steps=config.lr_warmup_steps,
113+
num_training_steps=(len(train_dataloader) * config.num_epochs),
114+
)
115+
116+
117+
def make_grid(images, rows, cols):
118+
w, h = images[0].size
119+
grid = Image.new("RGB", size=(cols * w, rows * h))
120+
for i, image in enumerate(images):
121+
grid.paste(image, box=(i % cols * w, i // cols * h))
122+
return grid
123+
124+
125+
def evaluate(config, epoch, pipeline):
126+
# Sample some images from random noise (this is the backward diffusion process).
127+
# The default pipeline output type is `List[PIL.Image]`
128+
images = pipeline(
129+
batch_size=config.eval_batch_size,
130+
generator=torch.manual_seed(config.seed),
131+
).images
132+
133+
# Make a grid out of the images
134+
image_grid = make_grid(images, rows=4, cols=4)
135+
136+
# Save the images
137+
test_dir = os.path.join(config.output_dir, "samples")
138+
os.makedirs(test_dir, exist_ok=True)
139+
image_grid.save(f"{test_dir}/{epoch:04d}.png")
140+
141+
142+
def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
143+
# Initialize accelerator and tensorboard logging
144+
accelerator = Accelerator(
145+
mixed_precision=config.mixed_precision,
146+
gradient_accumulation_steps=config.gradient_accumulation_steps,
147+
log_with="tensorboard",
148+
logging_dir=os.path.join(config.output_dir, "logs"),
149+
)
150+
if accelerator.is_main_process:
151+
if config.push_to_hub:
152+
repo = init_git_repo(config, at_init=True)
153+
accelerator.init_trackers("train_example")
154+
155+
# Prepare everything
156+
# There is no specific order to remember, you just need to unpack the
157+
# objects in the same order you gave them to the prepare method.
158+
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
159+
model, optimizer, train_dataloader, lr_scheduler
160+
)
161+
162+
global_step = 0
163+
164+
if config.output_dir.startswith("ddpm"):
165+
pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
166+
else:
167+
pipeline = DDIMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
168+
169+
evaluate(config, 0, pipeline)
170+
171+
# Now you train the model
172+
for epoch in range(config.num_epochs):
173+
progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
174+
progress_bar.set_description(f"Epoch {epoch}")
175+
176+
for step, batch in enumerate(train_dataloader):
177+
clean_images = batch["images"]
178+
# Sample noise to add to the images
179+
noise = torch.randn(clean_images.shape).to(clean_images.device)
180+
bs = clean_images.shape[0]
181+
182+
# Sample a random timestep for each image
183+
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device).long()
184+
185+
with accelerator.accumulate(model):
186+
# Predict the noise residual
187+
alpha_t, sigma_t = noise_scheduler.get_alpha_sigma(clean_images, timesteps, accelerator.device)
188+
z_t = alpha_t * clean_images + sigma_t * noise
189+
noise_pred = model(z_t, timesteps).sample
190+
v = alpha_t * noise - sigma_t * clean_images
191+
loss = F.mse_loss(noise_pred, v)
192+
accelerator.backward(loss)
193+
194+
accelerator.clip_grad_norm_(model.parameters(), 1.0)
195+
optimizer.step()
196+
lr_scheduler.step()
197+
optimizer.zero_grad()
198+
199+
progress_bar.update(1)
200+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
201+
progress_bar.set_postfix(**logs)
202+
accelerator.log(logs, step=global_step)
203+
global_step += 1
204+
205+
# After each epoch you optionally sample some demo images with evaluate() and save the model
206+
if accelerator.is_main_process:
207+
if config.output_dir.startswith("ddpm"):
208+
pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
209+
else:
210+
pipeline = DDIMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
211+
212+
if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
213+
evaluate(config, epoch, pipeline)
214+
215+
if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
216+
if config.push_to_hub:
217+
push_to_hub(config, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=True)
218+
else:
219+
pipeline.save_pretrained(config.output_dir)
220+
221+
222+
args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)
223+
224+
train_loop(*args)
225+
226+
sample_images = sorted(glob.glob(f"{config.output_dir}/samples/*.png"))
227+
Image.open(sample_images[-1])

0 commit comments

Comments
 (0)