import argparse
import math
import os

import torch
import torch.nn.functional as F

from accelerate import Accelerator
from accelerate.logging import get_logger
from datasets import load_dataset
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DConditionModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from torchvision.transforms import (
    CenterCrop,
    Compose,
    InterpolationMode,
    Normalize,
    RandomHorizontalFlip,
    Resize,
    ToTensor,
)
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer


logger = get_logger(__name__)

def main(args):
    logging_dir = os.path.join(args.output_dir, args.logging_dir)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with="tensorboard",
        logging_dir=logging_dir,
    )

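    # A text-conditional UNet: every block receives the timestep embedding,
    # and the CrossAttn blocks additionally attend over the text-encoder
    # hidden states passed below as encoder_hidden_states.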
    model = UNet2DConditionModel(
        sample_size=args.resolution,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=(128, 128, 256, 256, 512, 512),
        # Must match the hidden size of the text encoder used below
        # (512 for openai/clip-vit-base-patch32).
        cross_attention_dim=512,
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            # Attn*Block2D variants are self-attention only and cannot consume
            # encoder_hidden_states; the CrossAttn* variants are required here.
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "CrossAttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )

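    # The DDPM scheduler defines the noise schedule used both to corrupt
    # images during training and to denoise them at sampling time.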
    noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # The tokenizer and text encoder produce the conditioning input for the
    # UNet. The text encoder is frozen; only the UNet is trained.
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    text_encoder.to(accelerator.device)
    text_encoder.requires_grad_(False)

    augmentations = Compose(
        [
            Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
            CenterCrop(args.resolution),
            RandomHorizontalFlip(),
            ToTensor(),
            Normalize([0.5], [0.5]),
        ]
    )

    if args.dataset_name is not None:
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
            use_auth_token=True if args.use_auth_token else None,
            split="train",
        )
    else:
        dataset = load_dataset(
            "imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train"
        )

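    # Map each PIL image to a normalized tensor. set_transform applies this
    # lazily at access time instead of materializing the transformed dataset.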
    def transforms(examples):
        images = [augmentations(image.convert("RGB")) for image in examples["image"]]
        return {"input": images}

    dataset.set_transform(transforms)
    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)

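    # num_training_steps counts optimizer updates, not micro-batches, hence
    # the division by gradient_accumulation_steps.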
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps,
        num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
    )

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

    ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)

    if args.push_to_hub:
        repo = init_git_repo(args, at_init=True)

    if accelerator.is_main_process:
        run = os.path.split(__file__)[-1].split(".")[0]
        accelerator.init_trackers(run)

    global_step = 0
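    # Training loop: standard DDPM objective, i.e. predict the noise that was
    # added to each image and regress it with an MSE loss.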
    for epoch in range(args.num_epochs):
        model.train()
        progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
        for step, batch in enumerate(train_dataloader):
            clean_images = batch["input"]
            # Sample noise that we'll add to the images
            noise = torch.randn_like(clean_images)
            bsz = clean_images.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.num_train_timesteps, (bsz,), device=clean_images.device
            ).long()

            # FIXME The text input should come from captions in the dataset;
            # for now every sample is trained with the empty (unconditional)
            # prompt. The prompt batch must match the image batch size.
            uncond_input = tokenizer(
                [""] * bsz, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
            )
            with torch.no_grad():
                uncond_embeddings = text_encoder(uncond_input.input_ids.to(clean_images.device))[0]
            hidden_state = uncond_embeddings

            # Add noise to the clean images according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

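            # Inside accelerator.accumulate, backward skips gradient sync on
            # all but the last micro-batch of each accumulation cycle, and the
            # prepared optimizer only applies a real step once gradients sync.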
            with accelerator.accumulate(model):
                # Predict the noise residual
                noise_pred = model(noisy_images, timesteps, encoder_hidden_states=hidden_state)["sample"]
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                if args.use_ema:
                    ema_model.step(model)
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            if args.use_ema:
                logs["ema_decay"] = ema_model.decay
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
        progress_bar.close()

        accelerator.wait_for_everyone()

        # Generate sample images for visual inspection and/or save the model
        if accelerator.is_main_process:
            save_images = epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1
            save_model = epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1

            # Build the pipeline once so it is defined for both branches below
            if save_images or save_model:
                pipeline = DDPMPipeline(
                    unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model),
                    scheduler=noise_scheduler,
                )

            if save_images:
                # FIXME DDPMPipeline does not forward encoder_hidden_states to
                # the UNet, so sampling from a conditional model needs a custom
                # pipeline or a manual denoising loop.
                generator = torch.manual_seed(0)
                # run pipeline in inference (sample random noise and denoise)
                images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy")["sample"]

                # denormalize the images and save them to tensorboard
                images_processed = (images * 255).round().astype("uint8")
                accelerator.trackers[0].writer.add_images(
                    "test_samples", images_processed.transpose(0, 3, 1, 2), epoch
                )

            if save_model:
                if args.push_to_hub:
                    push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
                else:
                    pipeline.save_pretrained(args.output_dir)
        accelerator.wait_for_everyone()

    accelerator.end_training()


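# Example launch with `accelerate`; the script name and dataset below are
# illustrative placeholders, not values fixed by this script:
#
#   accelerate launch train_text_conditional.py \
#       --dataset_name huggan/flowers-102-categories \
#       --resolution 64 --train_batch_size 16 --num_epochs 100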
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--dataset_name", type=str, default=None)
    parser.add_argument("--dataset_config_name", type=str, default=None)
    parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.")
    parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
    parser.add_argument("--overwrite_output_dir", action="store_true")
    parser.add_argument("--cache_dir", type=str, default=None)
    parser.add_argument("--resolution", type=int, default=64)
    parser.add_argument("--train_batch_size", type=int, default=16)
    parser.add_argument("--eval_batch_size", type=int, default=16)
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--save_images_epochs", type=int, default=10)
    parser.add_argument("--save_model_epochs", type=int, default=10)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--lr_scheduler", type=str, default="cosine")
    parser.add_argument("--lr_warmup_steps", type=int, default=500)
    parser.add_argument("--adam_beta1", type=float, default=0.95)
    parser.add_argument("--adam_beta2", type=float, default=0.999)
    parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
    parser.add_argument("--adam_epsilon", type=float, default=1e-08)
    parser.add_argument("--use_ema", action="store_true", default=True)
    parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
    parser.add_argument("--ema_power", type=float, default=3 / 4)
    parser.add_argument("--ema_max_decay", type=float, default=0.9999)
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument("--use_auth_token", action="store_true")
    parser.add_argument("--hub_token", type=str, default=None)
    parser.add_argument("--hub_model_id", type=str, default=None)
    parser.add_argument("--hub_private_repo", action="store_true")
    parser.add_argument("--logging_dir", type=str, default="logs")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16"
            " (bfloat16). Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU."
        ),
    )

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("You must specify either a dataset name from the hub or a train data directory.")

    main(args)