From d394eaafa7159a512f2b85fb3729f5a7b62a0ddd Mon Sep 17 00:00:00 2001 From: Adalberto Date: Mon, 31 Oct 2022 20:33:17 -0300 Subject: [PATCH 1/6] Create train_dreambooth_inpaint.py train_dreambooth.py adapted to work with the inpaint model, generating random masks during the training --- .../dreambooth/train_dreambooth_inpaint.py | 697 ++++++++++++++++++ 1 file changed, 697 insertions(+) create mode 100644 examples/dreambooth/train_dreambooth_inpaint.py diff --git a/examples/dreambooth/train_dreambooth_inpaint.py b/examples/dreambooth/train_dreambooth_inpaint.py new file mode 100644 index 000000000000..5844534b2d7a --- /dev/null +++ b/examples/dreambooth/train_dreambooth_inpaint.py @@ -0,0 +1,697 @@ +import argparse +import hashlib +import itertools +import math +import os +from pathlib import Path +from typing import Optional + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.utils.data import Dataset + +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from huggingface_hub import HfFolder, Repository, whoami +from PIL import Image, ImageDraw +import random +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer +import numpy as np + +logger = get_logger(__name__) + +def prepare_mask_and_masked_image(image, mask): + image = np.array(image.convert("RGB")) + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + mask = np.array(mask.convert("L")) + mask = mask.astype(np.float32) / 255.0 + mask = mask[None, None] + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = image * (mask < 0.5) + + return mask, masked_image + +#generate random masks +def random_mask(im_shape, ratio=1): + mask = Image.new("L", im_shape, 0) + draw = ImageDraw.Draw(mask) + size=(random.randint(0, int(im_shape[0]*ratio)), random.randint(0, int(im_shape[1]*ratio))) + # size=(int(im_shape[0]*ratio), int(im_shape[1]*ratio)) #use this to mask the whole image always + limits=(im_shape[0]-size[0]//2, im_shape[1]-size[1]//2) + center = (random.randint(size[0]//2, limits[0]), random.randint(size[1]//2, limits[1])) + draw.rectangle((center[0]-size[0]//2, center[1]-size[1]//2, center[0]+size[0]//2, center[1]+size[1]//2), fill=255) + + return mask + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + 
default=None,
+        help="The prompt to specify images in the same class as provided instance images.",
+    )
+    parser.add_argument(
+        "--with_prior_preservation",
+        default=False,
+        action="store_true",
+        help="Flag to add prior preservation loss.",
+    )
+    parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+    parser.add_argument(
+        "--num_class_images",
+        type=int,
+        default=100,
+        help=(
+            "Minimal class images for prior preservation loss. If there are not enough images, additional images will be"
+            " sampled with class_prompt."
+        ),
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="text-inversion-model",
+        help="The output directory where the model predictions and checkpoints will be written.",
+    )
+    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+    parser.add_argument(
+        "--resolution",
+        type=int,
+        default=512,
+        help=(
+            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+            " resolution"
+        ),
+    )
+    parser.add_argument(
+        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
+    )
+    parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
+    parser.add_argument(
+        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+    )
+    parser.add_argument(
+        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+    )
+    parser.add_argument("--num_train_epochs", type=int, default=1)
+    parser.add_argument(
+        "--max_train_steps",
+        type=int,
+        default=None,
+        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+    )
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=5e-6,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--scale_lr",
+        action="store_true",
+        default=False,
+        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+    )
+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default="constant",
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup"]'
+        ),
+    )
+    parser.add_argument(
+        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument(
+        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+    )
+    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+    parser.add_argument(
+        "--hub_model_id",
+        type=str,
+        default=None,
+        help="The name of the repository to keep in sync with the local `output_dir`.",
+    )
+    parser.add_argument(
+        "--logging_dir",
+        type=str,
+        default="logs",
+        help=(
+            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+        ),
+    )
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default="no",
+        choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose"
+            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+            "and an Nvidia Ampere GPU."
+        ),
+    )
+    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+    args = parser.parse_args()
+    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    if env_local_rank != -1 and env_local_rank != args.local_rank:
+        args.local_rank = env_local_rank
+
+    if args.instance_data_dir is None:
+        raise ValueError("You must specify a train data directory.")
+
+    if args.with_prior_preservation:
+        if args.class_data_dir is None:
+            raise ValueError("You must specify a data directory for class images.")
+        if args.class_prompt is None:
+            raise ValueError("You must specify a prompt for class images.")
+
+    return args
+
+
+class DreamBoothDataset(Dataset):
+    """
+    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+    It pre-processes the images and tokenizes the prompts.
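+    The raw PIL image is kept as well (under "PIL_images") so that a random
+    inpainting mask can be drawn for it at batch-collation time.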
+    """
+
+    def __init__(
+        self,
+        instance_data_root,
+        instance_prompt,
+        tokenizer,
+        class_data_root=None,
+        class_prompt=None,
+        size=512,
+        center_crop=False,
+    ):
+        self.size = size
+        self.center_crop = center_crop
+        self.tokenizer = tokenizer
+
+        self.instance_data_root = Path(instance_data_root)
+        if not self.instance_data_root.exists():
+            raise ValueError("Instance images root doesn't exist.")
+
+        self.instance_images_path = list(Path(instance_data_root).iterdir())
+        self.num_instance_images = len(self.instance_images_path)
+        self.instance_prompt = instance_prompt
+        self._length = self.num_instance_images
+
+        if class_data_root is not None:
+            self.class_data_root = Path(class_data_root)
+            self.class_data_root.mkdir(parents=True, exist_ok=True)
+            self.class_images_path = list(self.class_data_root.iterdir())
+            self.num_class_images = len(self.class_images_path)
+            self._length = max(self.num_class_images, self.num_instance_images)
+            self.class_prompt = class_prompt
+        else:
+            self.class_data_root = None
+
+        self.image_transforms = transforms.Compose(
+            [
+                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
+
+    def __len__(self):
+        return self._length
+
+    def __getitem__(self, index):
+        example = {}
+        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
+        if not instance_image.mode == "RGB":
+            instance_image = instance_image.convert("RGB")
+
+
+        example["PIL_images"] = instance_image
+        example["instance_images"] = self.image_transforms(instance_image)
+
+        example["instance_prompt_ids"] = self.tokenizer(
+            self.instance_prompt,
+            padding="do_not_pad",
+            truncation=True,
+            max_length=self.tokenizer.model_max_length,
+        ).input_ids
+
+        if self.class_data_root:
+            class_image = Image.open(self.class_images_path[index % self.num_class_images])
+            if not class_image.mode == "RGB":
+                class_image = class_image.convert("RGB")
+            example["class_images"] = self.image_transforms(class_image)
+            example["class_prompt_ids"] = self.tokenizer(
+                self.class_prompt,
+                padding="do_not_pad",
+                truncation=True,
+                max_length=self.tokenizer.model_max_length,
+            ).input_ids
+
+        return example
+
+
+class PromptDataset(Dataset):
+    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+    def __init__(self, prompt, num_samples):
+        self.prompt = prompt
+        self.num_samples = num_samples
+
+    def __len__(self):
+        return self.num_samples
+
+    def __getitem__(self, index):
+        example = {}
+        example["prompt"] = self.prompt
+        example["index"] = index
+        return example
+
+
+def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
+    if token is None:
+        token = HfFolder.get_token()
+    if organization is None:
+        username = whoami(token)["name"]
+        return f"{username}/{model_id}"
+    else:
+        return f"{organization}/{model_id}"
+
+
+def main():
+    args = parse_args()
+    logging_dir = Path(args.output_dir, args.logging_dir)
+
+    accelerator = Accelerator(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision,
+        log_with="tensorboard",
+        logging_dir=logging_dir,
+    )
+
+    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
+    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. + if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: + raise ValueError( + "Gradient accumulation is not supported when training the text encoder in distributed training. " + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." + ) + + if args.seed is not None: + set_seed(args.seed) + + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size, num_workers=1) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) + elif args.pretrained_model_name_or_path: + tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") + + # Load models and create wrapper for stable diffusion + text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") + unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") + + vae.requires_grad_(False) + if not args.train_text_encoder: + text_encoder.requires_grad_(False) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder.gradient_checkpointing_enable() + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # 
Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + params_to_optimize = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() + ) + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 + ) + + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + ) + + def collate_fn(examples): + image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + ] + ) + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + masks=[] + masked_images=[] + for example in examples: + pil_image = example["PIL_images"] + #generate a random mask + mask = random_mask(pil_image.size) + #apply transforms + mask=image_transforms(mask) + pil_image=image_transforms(pil_image) + #prepare mask and masked image + mask, masked_image = prepare_mask_and_masked_image(pil_image, mask) + + masks.append(mask) + masked_images.append(masked_image) + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if args.with_prior_preservation: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids + masks=torch.stack(masks) + masked_images=torch.stack(masked_images) + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + "masks": masks, + "masked_images": masked_images + } + return batch + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn + ) + + # Scheduler and math around the number of training steps. 
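+    # As a worked example (illustrative numbers only): with 800 dataloader
+    # batches per epoch and gradient_accumulation_steps=2, there are
+    # ceil(800 / 2) = 400 optimizer updates per epoch, so num_train_epochs=1
+    # corresponds to max_train_steps=400 when --max_train_steps is unset.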
+    overrode_max_train_steps = False
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        overrode_max_train_steps = True
+
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+    )
+
+    if args.train_text_encoder:
+        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+        )
+    else:
+        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, optimizer, train_dataloader, lr_scheduler
+        )
+
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    # Move text_encoder and vae to gpu.
+    # For mixed precision training we cast the text_encoder and vae weights to half-precision
+    # as these models are only used for inference, keeping weights in full precision is not required.
+    vae.to(accelerator.device, dtype=weight_dtype)
+    if not args.train_text_encoder:
+        text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if overrode_max_train_steps:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    # Afterwards we recalculate our number of training epochs
+    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+    # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers are initialized automatically on the main process.
+    if accelerator.is_main_process:
+        accelerator.init_trackers("dreambooth", config=vars(args))
+
+    # Train!
+    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
+    logger.info(f"  Num Epochs = {args.num_train_epochs}")
+    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
+    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f"  Total optimization steps = {args.max_train_steps}")
+    # Only show the progress bar once on each machine.
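+    # (accelerator.is_local_main_process is True for exactly one process per
+    # node, so a multi-node run shows one bar per machine rather than one per GPU.)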
+    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+    progress_bar.set_description("Steps")
+    global_step = 0
+
+    for epoch in range(args.num_train_epochs):
+        unet.train()
+        for step, batch in enumerate(train_dataloader):
+            with accelerator.accumulate(unet):
+                # Convert images to latent space
+
+                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+                latents = latents * 0.18215
+
+                # Convert masked images to latent space
+                masked_latents = vae.encode(batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)).latent_dist.sample()
+                masked_latents = masked_latents * 0.18215
+
+                masks = batch["masks"]
+                # resize the mask to latents shape as we concatenate the mask to the latents
+                mask = torch.stack([torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8)) for mask in masks])
+                mask = mask.reshape(-1, 1, args.resolution//8, args.resolution//8)
+
+
+                # Sample noise that we'll add to the latents
+                noise = torch.randn_like(latents)
+                bsz = latents.shape[0]
+                # Sample a random timestep for each image
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+                timesteps = timesteps.long()
+
+                # Add noise to the latents according to the noise magnitude at each timestep
+                # (this is the forward diffusion process)
+                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                #concatenate the noised latents with the mask and the masked latents
+                latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)
+
+
+                # Get the text embedding for conditioning
+                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+                # Predict the noise residual
+                noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
+
+                if args.with_prior_preservation:
+                    # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
+                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+                    noise, noise_prior = torch.chunk(noise, 2, dim=0)
+
+                    # Compute instance loss
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
+
+                    # Compute prior loss
+                    prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
+
+                    # Add the prior loss to the instance loss.
+                    loss = loss + args.prior_loss_weight * prior_loss
+                else:
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+
+                accelerator.backward(loss)
+                if accelerator.sync_gradients:
+                    params_to_clip = (
+                        itertools.chain(unet.parameters(), text_encoder.parameters())
+                        if args.train_text_encoder
+                        else unet.parameters()
+                    )
+                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad()
+
+            # Checks if the accelerator has performed an optimization step behind the scenes
+            if accelerator.sync_gradients:
+                progress_bar.update(1)
+                global_step += 1
+
+            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+            progress_bar.set_postfix(**logs)
+            accelerator.log(logs, step=global_step)
+
+            if global_step >= args.max_train_steps:
+                break
+
+    accelerator.wait_for_everyone()
+
+    # Create the pipeline using the trained modules and save it.
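+    # Note: the fine-tuned UNet expects the 9-channel inpainting input
+    # (noisy latents + mask + masked-image latents), so inference would
+    # typically go through an inpainting pipeline such as
+    # StableDiffusionInpaintPipeline rather than plain text-to-image.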
+ if accelerator.is_main_process: + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder), + ) + pipeline.save_pretrained(args.output_dir) + + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + + accelerator.end_training() + + +if __name__ == "__main__": + main() From 3d7729e9c94197f6ad7dfd829c38fbcfb6ed9435 Mon Sep 17 00:00:00 2001 From: Adalberto Date: Mon, 31 Oct 2022 21:14:16 -0300 Subject: [PATCH 2/6] Update train_dreambooth_inpaint.py refactored train_dreambooth_inpaint with black --- .../dreambooth/train_dreambooth_inpaint.py | 338 ++++++++++++++---- 1 file changed, 259 insertions(+), 79 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_inpaint.py b/examples/dreambooth/train_dreambooth_inpaint.py index 5844534b2d7a..783ee7c23804 100644 --- a/examples/dreambooth/train_dreambooth_inpaint.py +++ b/examples/dreambooth/train_dreambooth_inpaint.py @@ -14,7 +14,12 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import set_seed -from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) from diffusers.optimization import get_scheduler from huggingface_hub import HfFolder, Repository, whoami from PIL import Image, ImageDraw @@ -26,6 +31,7 @@ logger = get_logger(__name__) + def prepare_mask_and_masked_image(image, mask): image = np.array(image.convert("RGB")) image = image[None].transpose(0, 3, 1, 2) @@ -42,18 +48,36 @@ def prepare_mask_and_masked_image(image, mask): return mask, masked_image -#generate random masks -def random_mask(im_shape, ratio=1): + +# generate random masks +def random_mask(im_shape, ratio=1, mask_full_image=False): mask = Image.new("L", im_shape, 0) draw = ImageDraw.Draw(mask) - size=(random.randint(0, int(im_shape[0]*ratio)), random.randint(0, int(im_shape[1]*ratio))) - # size=(int(im_shape[0]*ratio), int(im_shape[1]*ratio)) #use this to mask the whole image always - limits=(im_shape[0]-size[0]//2, im_shape[1]-size[1]//2) - center = (random.randint(size[0]//2, limits[0]), random.randint(size[1]//2, limits[1])) - draw.rectangle((center[0]-size[0]//2, center[1]-size[1]//2, center[0]+size[0]//2, center[1]+size[1]//2), fill=255) + size = ( + random.randint(0, int(im_shape[0] * ratio)), + random.randint(0, int(im_shape[1] * ratio)), + ) + # use this to always mask the whole image + if mask_full_image: + size = (int(im_shape[0] * ratio), int(im_shape[1] * ratio)) + limits = (im_shape[0] - size[0] // 2, im_shape[1] - size[1] // 2) + center = ( + random.randint(size[0] // 2, limits[0]), + random.randint(size[1] // 2, limits[1]), + ) + draw.rectangle( + ( + center[0] - size[0] // 2, + center[1] - size[1] // 2, + center[0] + size[0] // 2, + center[1] + size[1] // 2, + ), + fill=255, + ) return mask + def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( @@ -101,7 +125,12 @@ def parse_args(): action="store_true", help="Flag to add prior preservation loss.", ) - parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--prior_loss_weight", + type=float, + default=1.0, + help="The weight of prior preservation loss.", + ) 
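+    # Illustrative usage (example values only): prior preservation is enabled by
+    # passing --with_prior_preservation together with --class_data_dir and
+    # --class_prompt, e.g. --prior_loss_weight=1.0 --class_prompt="a photo of a dog".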
parser.add_argument( "--num_class_images", type=int, @@ -117,7 +146,9 @@ def parse_args(): default="text-inversion-model", help="The output directory where the model predictions and checkpoints will be written.", ) - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--seed", type=int, default=None, help="A seed for reproducible training." + ) parser.add_argument( "--resolution", type=int, @@ -128,14 +159,26 @@ def parse_args(): ), ) parser.add_argument( - "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" + "--center_crop", + action="store_true", + help="Whether to center crop images before resizing to resolution", ) - parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") parser.add_argument( - "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder", ) parser.add_argument( - "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + "--train_batch_size", + type=int, + default=4, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--sample_batch_size", + type=int, + default=4, + help="Batch size (per device) for sampling images.", ) parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( @@ -177,18 +220,51 @@ def parse_args(): ), ) parser.add_argument( - "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + "--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.", + ) + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes.", + ) + parser.add_argument( + "--adam_beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--adam_beta2", + type=float, + default=0.999, + help="The beta2 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use." + ) + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer", + ) + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." ) parser.add_argument( - "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 
+ "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) + parser.add_argument( + "--hub_token", + type=str, + default=None, + help="The token to use to push to the Model Hub.", ) - parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") - parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") - parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") - parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") - parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") - parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") - parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") parser.add_argument( "--hub_model_id", type=str, @@ -215,7 +291,12 @@ def parse_args(): "and an Nvidia Ampere GPU." ), ) - parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--local_rank", + type=int, + default=-1, + help="For distributed training: local_rank", + ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -275,8 +356,12 @@ def __init__( self.image_transforms = transforms.Compose( [ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.Resize( + size, interpolation=transforms.InterpolationMode.BILINEAR + ), + transforms.CenterCrop(size) + if center_crop + else transforms.RandomCrop(size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] @@ -287,11 +372,12 @@ def __len__(self): def __getitem__(self, index): example = {} - instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + instance_image = Image.open( + self.instance_images_path[index % self.num_instance_images] + ) if not instance_image.mode == "RGB": instance_image = instance_image.convert("RGB") - example["PIL_images"] = instance_image example["instance_images"] = self.image_transforms(instance_image) @@ -303,7 +389,9 @@ def __getitem__(self, index): ).input_ids if self.class_data_root: - class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = Image.open( + self.class_images_path[index % self.num_class_images] + ) if not class_image.mode == "RGB": class_image = class_image.convert("RGB") example["class_images"] = self.image_transforms(class_image) @@ -334,7 +422,9 @@ def __getitem__(self, index): return example -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): +def get_full_repo_name( + model_id: str, organization: Optional[str] = None, token: Optional[str] = None +): if token is None: token = HfFolder.get_token() if organization is None: @@ -358,7 +448,11 @@ def main(): # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. 
- if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: + if ( + args.train_text_encoder + and args.gradient_accumulation_steps > 1 + and accelerator.num_processes > 1 + ): raise ValueError( "Gradient accumulation is not supported when training the text encoder in distributed training. " "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." @@ -374,9 +468,13 @@ def main(): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + torch_dtype = ( + torch.float16 if accelerator.device.type == "cuda" else torch.float32 + ) pipeline = StableDiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None, ) pipeline.set_progress_bar_config(disable=True) @@ -384,19 +482,26 @@ def main(): logger.info(f"Number of class images to sample: {num_new_images}.") sample_dataset = PromptDataset(args.class_prompt, num_new_images) - sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size, num_workers=1) + sample_dataloader = torch.utils.data.DataLoader( + sample_dataset, batch_size=args.sample_batch_size, num_workers=1 + ) sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) for example in tqdm( - sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + sample_dataloader, + desc="Generating class images", + disable=not accelerator.is_local_main_process, ): images = pipeline(example["prompt"]).images for i, image in enumerate(images): hash_image = hashlib.sha1(image.tobytes()).hexdigest() - image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image_filename = ( + class_images_dir + / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + ) image.save(image_filename) del pipeline @@ -407,7 +512,9 @@ def main(): if accelerator.is_main_process: if args.push_to_hub: if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + repo_name = get_full_repo_name( + Path(args.output_dir).name, token=args.hub_token + ) else: repo_name = args.hub_model_id repo = Repository(args.output_dir, clone_from=repo_name) @@ -424,12 +531,20 @@ def main(): if args.tokenizer_name: tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) elif args.pretrained_model_name_or_path: - tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") + tokenizer = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer" + ) # Load models and create wrapper for stable diffusion - text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") - unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder" + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae" + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet" + ) 
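+    # The checkpoint is assumed to be an inpainting model (e.g.
+    # "runwayml/stable-diffusion-inpainting"), whose UNet conv_in takes
+    # 9 channels: 4 noisy latents + 1 mask + 4 masked-image latents.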
vae.requires_grad_(False) if not args.train_text_encoder: @@ -442,7 +557,10 @@ def main(): if args.scale_lr: args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + args.learning_rate + * args.gradient_accumulation_steps + * args.train_batch_size + * accelerator.num_processes ) # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs @@ -459,7 +577,9 @@ def main(): optimizer_class = torch.optim.AdamW params_to_optimize = ( - itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() + itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder + else unet.parameters() ) optimizer = optimizer_class( params_to_optimize, @@ -470,7 +590,10 @@ def main(): ) noise_scheduler = DDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, ) train_dataset = DreamBoothDataset( @@ -486,23 +609,27 @@ def main(): def collate_fn(examples): image_transforms = transforms.Compose( [ - transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + transforms.Resize( + args.resolution, interpolation=transforms.InterpolationMode.BILINEAR + ), + transforms.CenterCrop(args.resolution) + if args.center_crop + else transforms.RandomCrop(args.resolution), ] ) input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] - masks=[] - masked_images=[] + masks = [] + masked_images = [] for example in examples: pil_image = example["PIL_images"] - #generate a random mask + # generate a random mask mask = random_mask(pil_image.size) - #apply transforms - mask=image_transforms(mask) - pil_image=image_transforms(pil_image) - #prepare mask and masked image + # apply transforms + mask = image_transforms(mask) + pil_image = image_transforms(pil_image) + # prepare mask and masked image mask, masked_image = prepare_mask_and_masked_image(pil_image, mask) masks.append(mask) @@ -517,24 +644,31 @@ def collate_fn(examples): pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids - masks=torch.stack(masks) - masked_images=torch.stack(masked_images) + input_ids = tokenizer.pad( + {"input_ids": input_ids}, padding=True, return_tensors="pt" + ).input_ids + masks = torch.stack(masks) + masked_images = torch.stack(masked_images) batch = { "input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, - "masked_images": masked_images + "masked_images": masked_images, } return batch train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=collate_fn, ) # Scheduler and math around the number of training steps. 
overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True @@ -547,7 +681,13 @@ def collate_fn(examples): ) if args.train_text_encoder: - unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ( + unet, + text_encoder, + optimizer, + train_dataloader, + lr_scheduler, + ) = accelerator.prepare( unet, text_encoder, optimizer, train_dataloader, lr_scheduler ) else: @@ -569,7 +709,9 @@ def collate_fn(examples): text_encoder.to(accelerator.device, dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs @@ -581,18 +723,26 @@ def collate_fn(examples): accelerator.init_trackers("dreambooth", config=vars(args)) # Train! - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + total_batch_size = ( + args.train_batch_size + * accelerator.num_processes + * args.gradient_accumulation_steps + ) logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num batches each epoch = {len(train_dataloader)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info( + f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" + ) logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") # Only show the progress bar once on each machine. 
- progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar = tqdm( + range(args.max_train_steps), disable=not accelerator.is_local_main_process + ) progress_bar.set_description("Steps") global_step = 0 @@ -601,40 +751,60 @@ def collate_fn(examples): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): # Convert images to latent space - - latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + + latents = vae.encode( + batch["pixel_values"].to(dtype=weight_dtype) + ).latent_dist.sample() latents = latents * 0.18215 # Convert masked images to latent space - masked_latents = vae.encode(batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)).latent_dist.sample() + masked_latents = vae.encode( + batch["masked_images"] + .reshape(batch["pixel_values"].shape) + .to(dtype=weight_dtype) + ).latent_dist.sample() masked_latents = masked_latents * 0.18215 masks = batch["masks"] # resize the mask to latents shape as we concatenate the mask to the latents - mask = torch.stack([torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8)) for mask in masks]) - mask = mask.reshape(-1, 1, args.resolution//8, args.resolution//8) - + mask = torch.stack( + [ + torch.nn.functional.interpolate( + mask, size=(args.resolution // 8, args.resolution // 8) + ) + for mask in masks + ] + ) + mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8) # Sample noise that we'll add to the latents noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = torch.randint( + 0, + noise_scheduler.config.num_train_timesteps, + (bsz,), + device=latents.device, + ) timesteps = timesteps.long() # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - #concatenate the noised latents with the mask and the masked latents - latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1) - + # concatenate the noised latents with the mask and the masked latents + latent_model_input = torch.cat( + [noisy_latents, mask, masked_latents], dim=1 + ) # Get the text embedding for conditioning encoder_hidden_states = text_encoder(batch["input_ids"])[0] # Predict the noise residual - noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample + noise_pred = unet( + latent_model_input, timesteps, encoder_hidden_states + ).sample if args.with_prior_preservation: # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. @@ -642,15 +812,23 @@ def collate_fn(examples): noise, noise_prior = torch.chunk(noise, 2, dim=0) # Compute instance loss - loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean() + loss = ( + F.mse_loss(noise_pred.float(), noise.float(), reduction="none") + .mean([1, 2, 3]) + .mean() + ) # Compute prior loss - prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean") + prior_loss = F.mse_loss( + noise_pred_prior.float(), noise_prior.float(), reduction="mean" + ) # Add the prior loss to the instance loss. 
loss = loss + args.prior_loss_weight * prior_loss else: - loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") + loss = F.mse_loss( + noise_pred.float(), noise.float(), reduction="mean" + ) accelerator.backward(loss) if accelerator.sync_gradients: @@ -688,7 +866,9 @@ def collate_fn(examples): pipeline.save_pretrained(args.output_dir) if args.push_to_hub: - repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + repo.push_to_hub( + commit_message="End of training", blocking=False, auto_lfs_prune=True + ) accelerator.end_training() From 1a6f6ff47a9cc6dee38e514b0a60627facfb8d55 Mon Sep 17 00:00:00 2001 From: Adalberto Date: Mon, 31 Oct 2022 23:15:06 -0300 Subject: [PATCH 3/6] Update train_dreambooth_inpaint.py --- .../dreambooth/train_dreambooth_inpaint.py | 164 +++++------------- 1 file changed, 40 insertions(+), 124 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_inpaint.py b/examples/dreambooth/train_dreambooth_inpaint.py index 783ee7c23804..6385727c91fe 100644 --- a/examples/dreambooth/train_dreambooth_inpaint.py +++ b/examples/dreambooth/train_dreambooth_inpaint.py @@ -3,9 +3,11 @@ import itertools import math import os +import random from pathlib import Path from typing import Optional +import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint @@ -14,20 +16,14 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import set_seed -from diffusers import ( - AutoencoderKL, - DDPMScheduler, - StableDiffusionPipeline, - UNet2DConditionModel, -) +from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.optimization import get_scheduler from huggingface_hub import HfFolder, Repository, whoami from PIL import Image, ImageDraw -import random from torchvision import transforms from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer -import numpy as np + logger = get_logger(__name__) @@ -146,9 +142,7 @@ def parse_args(): default="text-inversion-model", help="The output directory where the model predictions and checkpoints will be written.", ) - parser.add_argument( - "--seed", type=int, default=None, help="A seed for reproducible training." - ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") parser.add_argument( "--resolution", type=int, @@ -242,18 +236,14 @@ def parse_args(): default=0.999, help="The beta2 parameter for the Adam optimizer.", ) - parser.add_argument( - "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use." - ) + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") parser.add_argument( "--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer", ) - parser.add_argument( - "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." 
- ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument( "--push_to_hub", action="store_true", @@ -356,12 +346,8 @@ def __init__( self.image_transforms = transforms.Compose( [ - transforms.Resize( - size, interpolation=transforms.InterpolationMode.BILINEAR - ), - transforms.CenterCrop(size) - if center_crop - else transforms.RandomCrop(size), + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] @@ -372,9 +358,7 @@ def __len__(self): def __getitem__(self, index): example = {} - instance_image = Image.open( - self.instance_images_path[index % self.num_instance_images] - ) + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) if not instance_image.mode == "RGB": instance_image = instance_image.convert("RGB") @@ -389,9 +373,7 @@ def __getitem__(self, index): ).input_ids if self.class_data_root: - class_image = Image.open( - self.class_images_path[index % self.num_class_images] - ) + class_image = Image.open(self.class_images_path[index % self.num_class_images]) if not class_image.mode == "RGB": class_image = class_image.convert("RGB") example["class_images"] = self.image_transforms(class_image) @@ -422,9 +404,7 @@ def __getitem__(self, index): return example -def get_full_repo_name( - model_id: str, organization: Optional[str] = None, token: Optional[str] = None -): +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): if token is None: token = HfFolder.get_token() if organization is None: @@ -448,11 +428,7 @@ def main(): # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. - if ( - args.train_text_encoder - and args.gradient_accumulation_steps > 1 - and accelerator.num_processes > 1 - ): + if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: raise ValueError( "Gradient accumulation is not supported when training the text encoder in distributed training. " "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." 
@@ -468,9 +444,7 @@ def main(): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - torch_dtype = ( - torch.float16 if accelerator.device.type == "cuda" else torch.float32 - ) + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 pipeline = StableDiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, @@ -498,10 +472,7 @@ def main(): for i, image in enumerate(images): hash_image = hashlib.sha1(image.tobytes()).hexdigest() - image_filename = ( - class_images_dir - / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" - ) + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" image.save(image_filename) del pipeline @@ -512,9 +483,7 @@ def main(): if accelerator.is_main_process: if args.push_to_hub: if args.hub_model_id is None: - repo_name = get_full_repo_name( - Path(args.output_dir).name, token=args.hub_token - ) + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) else: repo_name = args.hub_model_id repo = Repository(args.output_dir, clone_from=repo_name) @@ -531,20 +500,12 @@ def main(): if args.tokenizer_name: tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) elif args.pretrained_model_name_or_path: - tokenizer = CLIPTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer" - ) + tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") # Load models and create wrapper for stable diffusion - text_encoder = CLIPTextModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder" - ) - vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path, subfolder="vae" - ) - unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet" - ) + text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") + unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") vae.requires_grad_(False) if not args.train_text_encoder: @@ -557,10 +518,7 @@ def main(): if args.scale_lr: args.learning_rate = ( - args.learning_rate - * args.gradient_accumulation_steps - * args.train_batch_size - * accelerator.num_processes + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs @@ -577,9 +535,7 @@ def main(): optimizer_class = torch.optim.AdamW params_to_optimize = ( - itertools.chain(unet.parameters(), text_encoder.parameters()) - if args.train_text_encoder - else unet.parameters() + itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() ) optimizer = optimizer_class( params_to_optimize, @@ -609,12 +565,8 @@ def main(): def collate_fn(examples): image_transforms = transforms.Compose( [ - transforms.Resize( - args.resolution, interpolation=transforms.InterpolationMode.BILINEAR - ), - transforms.CenterCrop(args.resolution) - if args.center_crop - else transforms.RandomCrop(args.resolution), + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), ] ) input_ids = 
[example["instance_prompt_ids"] for example in examples] @@ -644,9 +596,7 @@ def collate_fn(examples): pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - input_ids = tokenizer.pad( - {"input_ids": input_ids}, padding=True, return_tensors="pt" - ).input_ids + input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids masks = torch.stack(masks) masked_images = torch.stack(masked_images) batch = { @@ -666,9 +616,7 @@ def collate_fn(examples): # Scheduler and math around the number of training steps. overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps - ) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True @@ -687,9 +635,7 @@ def collate_fn(examples): optimizer, train_dataloader, lr_scheduler, - ) = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler - ) + ) = accelerator.prepare(unet, text_encoder, optimizer, train_dataloader, lr_scheduler) else: unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( unet, optimizer, train_dataloader, lr_scheduler @@ -709,9 +655,7 @@ def collate_fn(examples): text_encoder.to(accelerator.device, dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps - ) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs @@ -723,26 +667,18 @@ def collate_fn(examples): accelerator.init_trackers("dreambooth", config=vars(args)) # Train! - total_batch_size = ( - args.train_batch_size - * accelerator.num_processes - * args.gradient_accumulation_steps - ) + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num batches each epoch = {len(train_dataloader)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info( - f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" - ) + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") # Only show the progress bar once on each machine. 
-    progress_bar = tqdm(
-        range(args.max_train_steps), disable=not accelerator.is_local_main_process
-    )
+    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
     progress_bar.set_description("Steps")
     global_step = 0
 
@@ -752,16 +688,12 @@ def collate_fn(examples):
             with accelerator.accumulate(unet):
                 # Convert images to latent space
-                latents = vae.encode(
-                    batch["pixel_values"].to(dtype=weight_dtype)
-                ).latent_dist.sample()
+                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                 latents = latents * 0.18215
 
                 # Convert masked images to latent space
                 masked_latents = vae.encode(
-                    batch["masked_images"]
-                    .reshape(batch["pixel_values"].shape)
-                    .to(dtype=weight_dtype)
+                    batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
                 ).latent_dist.sample()
                 masked_latents = masked_latents * 0.18215
 
@@ -769,9 +701,7 @@ def collate_fn(examples):
                 # resize the mask to latents shape as we concatenate the mask to the latents
                 mask = torch.stack(
                     [
-                        torch.nn.functional.interpolate(
-                            mask, size=(args.resolution // 8, args.resolution // 8)
-                        )
+                        torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
                         for mask in masks
                     ]
                 )
@@ -794,17 +724,13 @@ def collate_fn(examples):
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # concatenate the noised latents with the mask and the masked latents
-                latent_model_input = torch.cat(
-                    [noisy_latents, mask, masked_latents], dim=1
-                )
+                latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)
 
                 # Get the text embedding for conditioning
                 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
                 # Predict the noise residual
-                noise_pred = unet(
-                    latent_model_input, timesteps, encoder_hidden_states
-                ).sample
+                noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
 
                 if args.with_prior_preservation:
                     # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
@@ -812,23 +738,15 @@ def collate_fn(examples):
                     noise, noise_prior = torch.chunk(noise, 2, dim=0)
 
                     # Compute instance loss
-                    loss = (
-                        F.mse_loss(noise_pred.float(), noise.float(), reduction="none")
-                        .mean([1, 2, 3])
-                        .mean()
-                    )
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
 
                     # Compute prior loss
-                    prior_loss = F.mse_loss(
-                        noise_pred_prior.float(), noise_prior.float(), reduction="mean"
-                    )
+                    prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
 
                     # Add the prior loss to the instance loss.
                     loss = loss + args.prior_loss_weight * prior_loss
                 else:
-                    loss = F.mse_loss(
-                        noise_pred.float(), noise.float(), reduction="mean"
-                    )
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
@@ -866,9 +784,7 @@ def collate_fn(examples):
         pipeline.save_pretrained(args.output_dir)
 
         if args.push_to_hub:
-            repo.push_to_hub(
-                commit_message="End of training", blocking=False, auto_lfs_prune=True
-            )
+            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
 
     accelerator.end_training()
 

From a58d6aaec6affdfc0f0fe70c92719b3e44855217 Mon Sep 17 00:00:00 2001
From: Adalberto
Date: Tue, 1 Nov 2022 18:24:13 -0300
Subject: [PATCH 4/6] Update train_dreambooth_inpaint.py

---
 .../dreambooth/train_dreambooth_inpaint.py | 183 +++++++-----------
 1 file changed, 73 insertions(+), 110 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_inpaint.py b/examples/dreambooth/train_dreambooth_inpaint.py
index 6385727c91fe..f0f88fbc9953 100644
--- a/examples/dreambooth/train_dreambooth_inpaint.py
+++ b/examples/dreambooth/train_dreambooth_inpaint.py
@@ -74,7 +74,7 @@ def random_mask(im_shape, ratio=1, mask_full_image=False):
     return mask
 
 
-def parse_args():
+def parse_args(input_args=None):
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument(
         "--pretrained_model_name_or_path",
@@ -83,6 +83,13 @@ def parse_args():
         required=True,
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--revision",
+        type=str,
+        default=None,
+        required=False,
+        help="Revision of pretrained model identifier from huggingface.co/models.",
+    )
     parser.add_argument(
         "--tokenizer_name",
         type=str,
@@ -121,12 +128,7 @@ def parse_args(input_args=None):
         action="store_true",
         help="Flag to add prior preservation loss.",
     )
-    parser.add_argument(
-        "--prior_loss_weight",
-        type=float,
-        default=1.0,
-        help="The weight of prior preservation loss.",
-    )
+    parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
     parser.add_argument(
         "--num_class_images",
         type=int,
@@ -153,26 +155,14 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument(
-        "--center_crop",
-        action="store_true",
-        help="Whether to center crop images before resizing to resolution",
+        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
     )
+    parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
     parser.add_argument(
-        "--train_text_encoder",
-        action="store_true",
-        help="Whether to train the text encoder",
-    )
-    parser.add_argument(
-        "--train_batch_size",
-        type=int,
-        default=4,
-        help="Batch size (per device) for the training dataloader.",
+        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
     )
     parser.add_argument(
-        "--sample_batch_size",
-        type=int,
-        default=4,
-        help="Batch size (per device) for sampling images.",
+        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
     )
     parser.add_argument("--num_train_epochs", type=int, default=1)
     parser.add_argument(
@@ -214,47 +204,18 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument(
-        "--lr_warmup_steps",
-        type=int,
-        default=500,
-        help="Number of steps for the warmup in the lr scheduler.",
+        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
-        "--use_8bit_adam",
-        action="store_true",
-        help="Whether or not to use 8-bit Adam from bitsandbytes.",
-    )
-    parser.add_argument(
-        "--adam_beta1",
-        type=float,
-        default=0.9,
-        help="The beta1 parameter for the Adam optimizer.",
-    )
-    parser.add_argument(
-        "--adam_beta2",
-        type=float,
-        default=0.999,
-        help="The beta2 parameter for the Adam optimizer.",
+        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
+    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
-    parser.add_argument(
-        "--adam_epsilon",
-        type=float,
-        default=1e-08,
-        help="Epsilon value for the Adam optimizer",
-    )
+    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
-    parser.add_argument(
-        "--push_to_hub",
-        action="store_true",
-        help="Whether or not to push the model to the Hub.",
-    )
-    parser.add_argument(
-        "--hub_token",
-        type=str,
-        default=None,
-        help="The token to use to push to the Model Hub.",
-    )
+    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
     parser.add_argument(
         "--hub_model_id",
         type=str,
@@ -281,14 +242,13 @@ def parse_args(input_args=None):
             "and an Nvidia Ampere GPU."
         ),
     )
-    parser.add_argument(
-        "--local_rank",
-        type=int,
-        default=-1,
-        help="For distributed training: local_rank",
-    )
+    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+    if input_args is not None:
+        args = parser.parse_args(input_args)
+    else:
+        args = parser.parse_args()
 
-    args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
@@ -361,10 +321,8 @@ def __getitem__(self, index):
         instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
         if not instance_image.mode == "RGB":
             instance_image = instance_image.convert("RGB")
-        example["PIL_images"] = instance_image
         example["instance_images"] = self.image_transforms(instance_image)
-
         example["instance_prompt_ids"] = self.tokenizer(
             self.instance_prompt,
             padding="do_not_pad",
@@ -414,8 +372,7 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
         return f"{organization}/{model_id}"
 
 
-def main():
-    args = parse_args()
+def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)
 
     accelerator = Accelerator(
@@ -449,6 +406,7 @@ def main():
                 args.pretrained_model_name_or_path,
                 torch_dtype=torch_dtype,
                 safety_checker=None,
+                revision=args.revision,
             )
             pipeline.set_progress_bar_config(disable=True)
 
@@ -456,17 +414,13 @@ def main():
             logger.info(f"Number of class images to sample: {num_new_images}.")
 
             sample_dataset = PromptDataset(args.class_prompt, num_new_images)
-            sample_dataloader = torch.utils.data.DataLoader(
-                sample_dataset, batch_size=args.sample_batch_size, num_workers=1
-            )
+            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
 
             sample_dataloader = accelerator.prepare(sample_dataloader)
             pipeline.to(accelerator.device)
 
             for example in tqdm(
-                sample_dataloader,
-                desc="Generating class images",
-                disable=not accelerator.is_local_main_process,
+                sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
             ):
                 images = pipeline(example["prompt"]).images
 
@@ -498,14 +452,33 @@ def main():
 
     # Load the tokenizer
     if args.tokenizer_name:
-        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+        tokenizer = CLIPTokenizer.from_pretrained(
+            args.tokenizer_name,
+            revision=args.revision,
+        )
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+        tokenizer = CLIPTokenizer.from_pretrained(
+            args.pretrained_model_name_or_path,
+            subfolder="tokenizer",
+            revision=args.revision,
+        )
 
     # Load models and create wrapper for stable diffusion
-    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
-    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+    text_encoder = CLIPTextModel.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="text_encoder",
+        revision=args.revision,
+    )
+    vae = AutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="vae",
+        revision=args.revision,
+    )
+    unet = UNet2DConditionModel.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="unet",
+        revision=args.revision,
+    )
 
     vae.requires_grad_(False)
     if not args.train_text_encoder:
@@ -545,12 +518,7 @@ def main():
         eps=args.adam_epsilon,
     )
 
-    noise_scheduler = DDPMScheduler(
-        beta_start=0.00085,
-        beta_end=0.012,
-        beta_schedule="scaled_linear",
-        num_train_timesteps=1000,
-    )
+    noise_scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
 
     train_dataset = DreamBoothDataset(
         instance_data_root=args.instance_data_dir,
@@ -563,17 +531,19 @@ def main():
     )
 
     def collate_fn(examples):
+        input_ids = [example["instance_prompt_ids"] for example in examples]
+        pixel_values = [example["instance_images"] for example in examples]
+
+        masks = []
+        masked_images = []
+
         image_transforms = transforms.Compose(
             [
                 transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
                 transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
             ]
         )
-        input_ids = [example["instance_prompt_ids"] for example in examples]
-        pixel_values = [example["instance_images"] for example in examples]
-        masks = []
-        masked_images = []
         for example in examples:
             pil_image = example["PIL_images"]
             # generate a random mask
@@ -597,8 +567,10 @@ def collate_fn(examples):
         pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
 
         input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
+
         masks = torch.stack(masks)
         masked_images = torch.stack(masked_images)
+
         batch = {
             "input_ids": input_ids,
             "pixel_values": pixel_values,
@@ -608,10 +580,7 @@ def collate_fn(examples):
         return batch
 
     train_dataloader = torch.utils.data.DataLoader(
-        train_dataset,
-        batch_size=args.train_batch_size,
-        shuffle=True,
-        collate_fn=collate_fn,
+        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
     )
 
     # Scheduler and math around the number of training steps.
@@ -629,13 +598,9 @@ def collate_fn(examples):
     )
 
     if args.train_text_encoder:
-        (
-            unet,
-            text_encoder,
-            optimizer,
-            train_dataloader,
-            lr_scheduler,
-        ) = accelerator.prepare(unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
+        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+        )
     else:
         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
             unet, optimizer, train_dataloader, lr_scheduler
@@ -684,10 +649,11 @@ def collate_fn(examples):
 
     for epoch in range(args.num_train_epochs):
         unet.train()
+        if args.train_text_encoder:
+            text_encoder.train()
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
                 # Convert images to latent space
-
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                 latents = latents * 0.18215
 
@@ -711,12 +677,7 @@ def collate_fn(examples):
                 noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
-                timesteps = torch.randint(
-                    0,
-                    noise_scheduler.config.num_train_timesteps,
-                    (bsz,),
-                    device=latents.device,
-                )
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                 timesteps = timesteps.long()
 
                 # Add noise to the latents according to the noise magnitude at each timestep
@@ -780,6 +741,7 @@ def collate_fn(examples):
                 args.pretrained_model_name_or_path,
                 unet=accelerator.unwrap_model(unet),
                 text_encoder=accelerator.unwrap_model(text_encoder),
+                revision=args.revision,
             )
             pipeline.save_pretrained(args.output_dir)
 
@@ -790,4 +752,5 @@ def collate_fn(examples):
 
 
 if __name__ == "__main__":
-    main()
+    args = parse_args()
+    main(args)

From 6dd054b2370d6328e153ce3eea0903828940037a Mon Sep 17 00:00:00 2001
From: Adalberto
Date: Fri, 18 Nov 2022 16:21:03 -0300
Subject: [PATCH 5/6] Update train_dreambooth_inpaint.py

Fix prior preservation
---
 .../dreambooth/train_dreambooth_inpaint.py | 166 ++++++++----------
 1 file changed, 76 insertions(+), 90 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_inpaint.py b/examples/dreambooth/train_dreambooth_inpaint.py
index f0f88fbc9953..0ddc66289c0c 100644
--- a/examples/dreambooth/train_dreambooth_inpaint.py
+++ b/examples/dreambooth/train_dreambooth_inpaint.py
@@ -16,7 +16,13 @@
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers import (
+    AutoencoderKL,
+    DDPMScheduler,
+    StableDiffusionInpaintPipeline,
+    StableDiffusionPipeline,
+    UNet2DConditionModel,
+)
 from diffusers.optimization import get_scheduler
 from huggingface_hub import HfFolder, Repository, whoami
 from PIL import Image, ImageDraw
@@ -49,32 +55,28 @@ def prepare_mask_and_masked_image(image, mask):
 def random_mask(im_shape, ratio=1, mask_full_image=False):
     mask = Image.new("L", im_shape, 0)
     draw = ImageDraw.Draw(mask)
-    size = (
-        random.randint(0, int(im_shape[0] * ratio)),
-        random.randint(0, int(im_shape[1] * ratio)),
-    )
+    size = (random.randint(0, int(im_shape[0] * ratio)), random.randint(0, int(im_shape[1] * ratio)))
     # use this to always mask the whole image
     if mask_full_image:
         size = (int(im_shape[0] * ratio), int(im_shape[1] * ratio))
     limits = (im_shape[0] - size[0] // 2, im_shape[1] - size[1] // 2)
-    center = (
-        random.randint(size[0] // 2, limits[0]),
-        random.randint(size[1] // 2, limits[1]),
-    )
-    draw.rectangle(
-        (
-            center[0] - size[0] // 2,
-            center[1] - size[1] // 2,
-            center[0] + size[0] // 2,
-            center[1] + size[1] // 2,
-        ),
-        fill=255,
-    )
+    center = (random.randint(size[0] // 2, limits[0]), random.randint(size[1] // 2, limits[1]))
+    draw_type = random.randint(0, 1)
+    if draw_type == 0 or mask_full_image:
+        draw.rectangle(
+            (center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2),
+            fill=255,
+        )
+    else:
+        draw.ellipse(
+            (center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2),
+            fill=255,
+        )
 
     return mask
 
 
-def parse_args(input_args=None):
+def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument(
         "--pretrained_model_name_or_path",
@@ -83,13 +85,6 @@ def parse_args(input_args=None):
         required=True,
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
-    parser.add_argument(
-        "--revision",
-        type=str,
-        default=None,
-        required=False,
-        help="Revision of pretrained model identifier from huggingface.co/models.",
-    )
     parser.add_argument(
         "--tokenizer_name",
         type=str,
@@ -244,11 +239,7 @@ def parse_args(input_args=None):
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
 
-    if input_args is not None:
-        args = parser.parse_args(input_args)
-    else:
-        args = parser.parse_args()
-
+    args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
@@ -321,8 +312,10 @@ def __getitem__(self, index):
         instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
         if not instance_image.mode == "RGB":
             instance_image = instance_image.convert("RGB")
+        example["PIL_images"] = instance_image
         example["instance_images"] = self.image_transforms(instance_image)
+
         example["instance_prompt_ids"] = self.tokenizer(
             self.instance_prompt,
             padding="do_not_pad",
@@ -335,6 +328,7 @@ def __getitem__(self, index):
             if not class_image.mode == "RGB":
                 class_image = class_image.convert("RGB")
             example["class_images"] = self.image_transforms(class_image)
+            example["class_PIL_images"] = class_image
             example["class_prompt_ids"] = self.tokenizer(
                 self.class_prompt,
                 padding="do_not_pad",
@@ -372,7 +366,8 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
         return f"{organization}/{model_id}"
 
 
-def main(args):
+def main():
+    args = parse_args()
     logging_dir = Path(args.output_dir, args.logging_dir)
 
     accelerator = Accelerator(
@@ -402,11 +397,8 @@ def main():
 
         if cur_class_images < args.num_class_images:
             torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
-            pipeline = StableDiffusionPipeline.from_pretrained(
-                args.pretrained_model_name_or_path,
-                torch_dtype=torch_dtype,
-                safety_checker=None,
-                revision=args.revision,
+            pipeline = StableDiffusionInpaintPipeline.from_pretrained(
+                args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None
             )
             pipeline.set_progress_bar_config(disable=True)
 
@@ -414,15 +406,24 @@ def main():
             logger.info(f"Number of class images to sample: {num_new_images}.")
 
             sample_dataset = PromptDataset(args.class_prompt, num_new_images)
-            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+            sample_dataloader = torch.utils.data.DataLoader(
+                sample_dataset, batch_size=args.sample_batch_size, num_workers=1
+            )
 
             sample_dataloader = accelerator.prepare(sample_dataloader)
             pipeline.to(accelerator.device)
-
+            transform_to_pil = transforms.ToPILImage()
             for example in tqdm(
                 sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
             ):
-                images = pipeline(example["prompt"]).images
+                bsz = len(example["prompt"])
+                fake_images = torch.rand((3, args.resolution, args.resolution))
+                transform_to_pil = transforms.ToPILImage()
+                fake_pil_images = transform_to_pil(fake_images)
+
+                fake_mask = random_mask((args.resolution, args.resolution), ratio=1, mask_full_image=True)
+
+                images = pipeline(prompt=example["prompt"], mask_image=fake_mask, image=fake_pil_images).images
 
                 for i, image in enumerate(images):
                     hash_image = hashlib.sha1(image.tobytes()).hexdigest()
@@ -452,33 +453,14 @@ def main():
 
     # Load the tokenizer
     if args.tokenizer_name:
-        tokenizer = CLIPTokenizer.from_pretrained(
-            args.tokenizer_name,
-            revision=args.revision,
-        )
+        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(
-            args.pretrained_model_name_or_path,
-            subfolder="tokenizer",
-            revision=args.revision,
-        )
+        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
 
     # Load models and create wrapper for stable diffusion
-    text_encoder = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="text_encoder",
-        revision=args.revision,
-    )
-    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="vae",
-        revision=args.revision,
-    )
-    unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="unet",
-        revision=args.revision,
-    )
+    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
+    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
 
     vae.requires_grad_(False)
     if not args.train_text_encoder:
@@ -518,7 +500,9 @@ def main():
         eps=args.adam_epsilon,
     )
 
-    noise_scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
+    noise_scheduler = DDPMScheduler(
+        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
+    )
 
     train_dataset = DreamBoothDataset(
         instance_data_root=args.instance_data_dir,
@@ -531,23 +515,28 @@ def main():
     )
 
     def collate_fn(examples):
-        input_ids = [example["instance_prompt_ids"] for example in examples]
-        pixel_values = [example["instance_images"] for example in examples]
-
-        masks = []
-        masked_images = []
-
         image_transforms = transforms.Compose(
             [
                 transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
                 transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
             ]
         )
+        input_ids = [example["instance_prompt_ids"] for example in examples]
+        pixel_values = [example["instance_images"] for example in examples]
+
+        # Concat class and instance examples for prior preservation.
+        # We do this to avoid doing two forward passes.
+        if args.with_prior_preservation:
+            input_ids += [example["class_prompt_ids"] for example in examples]
+            pixel_values += [example["class_images"] for example in examples]
+            prior_pil = [example["class_PIL_images"] for example in examples]
+
         masks = []
         masked_images = []
         for example in examples:
             pil_image = example["PIL_images"]
             # generate a random mask
-            mask = random_mask(pil_image.size)
+            mask = random_mask(pil_image.size, 1, False)
             # apply transforms
             mask = image_transforms(mask)
             pil_image = image_transforms(pil_image)
@@ -557,30 +546,30 @@ def collate_fn(examples):
             masks.append(mask)
             masked_images.append(masked_image)
 
-        # Concat class and instance examples for prior preservation.
-        # We do this to avoid doing two forward passes.
         if args.with_prior_preservation:
-            input_ids += [example["class_prompt_ids"] for example in examples]
-            pixel_values += [example["class_images"] for example in examples]
+            for pil_image in prior_pil:
+                # generate a random mask
+                mask = random_mask(pil_image.size, 1, False)
+                # apply transforms
+                mask = image_transforms(mask)
+                pil_image = image_transforms(pil_image)
+                # prepare mask and masked image
+                mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)
+
+                masks.append(mask)
+                masked_images.append(masked_image)
 
         pixel_values = torch.stack(pixel_values)
         pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
 
         input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
-
         masks = torch.stack(masks)
         masked_images = torch.stack(masked_images)
-
-        batch = {
-            "input_ids": input_ids,
-            "pixel_values": pixel_values,
-            "masks": masks,
-            "masked_images": masked_images,
-        }
+        batch = {"input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, "masked_images": masked_images}
         return batch
 
     train_dataloader = torch.utils.data.DataLoader(
-        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
+        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn
     )
 
     # Scheduler and math around the number of training steps.
@@ -649,11 +638,10 @@ def collate_fn(examples):
 
     for epoch in range(args.num_train_epochs):
         unet.train()
-        if args.train_text_encoder:
-            text_encoder.train()
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
                 # Convert images to latent space
+
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                 latents = latents * 0.18215
 
@@ -741,7 +729,6 @@ def collate_fn(examples):
                 args.pretrained_model_name_or_path,
                 unet=accelerator.unwrap_model(unet),
                 text_encoder=accelerator.unwrap_model(text_encoder),
-                revision=args.revision,
             )
             pipeline.save_pretrained(args.output_dir)
 
@@ -752,5 +739,4 @@ def collate_fn(examples):
 
 
 if __name__ == "__main__":
-    args = parse_args()
-    main(args)
+    main()

From 107f4405427c8f99fd41eefba61d51d321a937c6 Mon Sep 17 00:00:00 2001
From: thedarkzeno
Date: Fri, 2 Dec 2022 01:19:20 -0300
Subject: [PATCH 6/6] add instructions to readme, fix SD2 compatibility

---
 examples/dreambooth/README.md                 | 118 ++++++++++++++++++
 .../dreambooth/train_dreambooth_inpaint.py    |  20 +--
 2 files changed, 131 insertions(+), 7 deletions(-)

diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md
index 08ebdd5c32c0..2f43153c7770 100644
--- a/examples/dreambooth/README.md
+++ b/examples/dreambooth/README.md
@@ -304,3 +304,121 @@ python train_dreambooth_flax.py \
   --num_class_images=200 \
   --max_train_steps=800
 ```
+
+## DreamBooth for the inpainting model
+
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-inpainting"
+export INSTANCE_DIR="path-to-instance-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth_inpaint.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --instance_data_dir=$INSTANCE_DIR \
+  --output_dir=$OUTPUT_DIR \
+  --instance_prompt="a photo of sks dog" \
+  --resolution=512 \
+  --train_batch_size=1 \
+  --gradient_accumulation_steps=1 \
+  --learning_rate=5e-6 \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --max_train_steps=400
+```
+
+The script is also compatible with prior-preservation loss and gradient checkpointing.
+
+### Training with prior-preservation loss
+
+Prior-preservation is used to avoid overfitting and language drift. Refer to the paper to learn more about it. For prior-preservation, we first generate images using the model with a class prompt and then use those images during training along with our data.
+According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 images work well for most cases.
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-inpainting"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth_inpaint.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --instance_data_dir=$INSTANCE_DIR \
+  --class_data_dir=$CLASS_DIR \
+  --output_dir=$OUTPUT_DIR \
+  --with_prior_preservation --prior_loss_weight=1.0 \
+  --instance_prompt="a photo of sks dog" \
+  --class_prompt="a photo of dog" \
+  --resolution=512 \
+  --train_batch_size=1 \
+  --gradient_accumulation_steps=1 \
+  --learning_rate=5e-6 \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --num_class_images=200 \
+  --max_train_steps=800
+```
+
+
+### Training with gradient checkpointing and the 8-bit optimizer
+
+With gradient checkpointing and the 8-bit optimizer from bitsandbytes, it's possible to train DreamBooth on a 16GB GPU.
+
+To install `bitsandbytes`, please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation).
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-inpainting"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth_inpaint.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --instance_data_dir=$INSTANCE_DIR \
+  --class_data_dir=$CLASS_DIR \
+  --output_dir=$OUTPUT_DIR \
+  --with_prior_preservation --prior_loss_weight=1.0 \
+  --instance_prompt="a photo of sks dog" \
+  --class_prompt="a photo of dog" \
+  --resolution=512 \
+  --train_batch_size=1 \
+  --gradient_accumulation_steps=2 --gradient_checkpointing \
+  --use_8bit_adam \
+  --learning_rate=5e-6 \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --num_class_images=200 \
+  --max_train_steps=800
+```
+
+### Fine-tune the text encoder with the UNet
+
+The script also allows you to fine-tune the `text_encoder` along with the `unet`. Fine-tuning the `text_encoder` has been observed experimentally to give much better results, especially on faces.
+Pass the `--train_text_encoder` argument to the script to enable training the `text_encoder`.
+
+___Note: Training the text encoder requires more memory; with this option, training won't fit on a 16GB GPU. It needs at least 24GB of VRAM.___
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-inpainting"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth_inpaint.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --train_text_encoder \
+  --instance_data_dir=$INSTANCE_DIR \
+  --class_data_dir=$CLASS_DIR \
+  --output_dir=$OUTPUT_DIR \
+  --with_prior_preservation --prior_loss_weight=1.0 \
+  --instance_prompt="a photo of sks dog" \
+  --class_prompt="a photo of dog" \
+  --resolution=512 \
+  --train_batch_size=1 \
+  --use_8bit_adam \
+  --gradient_checkpointing \
+  --learning_rate=2e-6 \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --num_class_images=200 \
+  --max_train_steps=800
+```
\ No newline at end of file
diff --git a/examples/dreambooth/train_dreambooth_inpaint.py b/examples/dreambooth/train_dreambooth_inpaint.py
index 0ddc66289c0c..bb5672669d5b 100644
--- a/examples/dreambooth/train_dreambooth_inpaint.py
+++ b/examples/dreambooth/train_dreambooth_inpaint.py
@@ -500,9 +500,7 @@ def main():
         eps=args.adam_epsilon,
     )
 
-    noise_scheduler = DDPMScheduler(
-        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
-    )
+    noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
 
     train_dataset = DreamBoothDataset(
         instance_data_root=args.instance_data_dir,
@@ -681,21 +679,29 @@ def collate_fn(examples):
                 # Predict the noise residual
                 noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
 
+                # Get the target for loss depending on the prediction type
+                if noise_scheduler.config.prediction_type == "epsilon":
+                    target = noise
+                elif noise_scheduler.config.prediction_type == "v_prediction":
+                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                else:
+                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
                 if args.with_prior_preservation:
                     # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                     noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-                    noise, noise_prior = torch.chunk(noise, 2, dim=0)
+                    target, target_prior = torch.chunk(target, 2, dim=0)
 
                     # Compute instance loss
-                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
 
                     # Compute prior loss
-                    prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
+                    prior_loss = F.mse_loss(noise_pred_prior.float(), target_prior.float(), reduction="mean")
 
                     # Add the prior loss to the instance loss.
                     loss = loss + args.prior_loss_weight * prior_loss
                 else:
-                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+                    loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
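
Once a training run from the commands above finishes, the weights written by `pipeline.save_pretrained(args.output_dir)` can be loaded back for inpainting inference with the learned `sks` identifier. Below is a minimal sketch; it is illustrative rather than part of the patches above, and the model directory, image files, and output path are placeholder assumptions:

```python
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image

# Load the DreamBooth-tuned inpainting pipeline from the training output
# directory (placeholder path; use the --output_dir you trained with).
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "path-to-save-model", torch_dtype=torch.float16
).to("cuda")

# Resize the base image and mask to the training resolution. Following the
# convention in prepare_mask_and_masked_image, white mask pixels are repainted.
image = Image.open("dog.png").convert("RGB").resize((512, 512))
mask = Image.open("dog_mask.png").convert("L").resize((512, 512))

result = pipe(prompt="a photo of sks dog", image=image, mask_image=mask).images[0]
result.save("inpainted_dog.png")
```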