diff --git a/examples/v_prediction/train_butterflies.py b/examples/v_prediction/train_butterflies.py new file mode 100644 index 000000000000..5074ece86a98 --- /dev/null +++ b/examples/v_prediction/train_butterflies.py @@ -0,0 +1,227 @@ +import glob +import os +from dataclasses import dataclass + +import torch +import torch.nn.functional as F + +from accelerate import Accelerator +from datasets import load_dataset +from diffusers import DDIMPipeline, DDIMScheduler, DDPMPipeline, DDPMScheduler, UNet2DModel +from diffusers.hub_utils import init_git_repo, push_to_hub +from diffusers.optimization import get_cosine_schedule_with_warmup +from PIL import Image +from torchvision import transforms +from tqdm.auto import tqdm + + +@dataclass +class TrainingConfig: + image_size = 128 # the generated image resolution + train_batch_size = 16 + eval_batch_size = 16 # how many images to sample during evaluation + num_epochs = 50 + gradient_accumulation_steps = 1 + learning_rate = 5e-5 + lr_warmup_steps = 500 + save_image_epochs = 10 + save_model_epochs = 30 + mixed_precision = "fp16" # `no` for float32, `fp16` for automatic mixed precision + output_dir = "ddim-butterflies-128-v-diffusion" # the model namy locally and on the HF Hub + + push_to_hub = False # whether to upload the saved model to the HF Hub + hub_private_repo = False + overwrite_output_dir = True # overwrite the old model when re-running the notebook + seed = 0 + + +config = TrainingConfig() + + +config.dataset_name = "huggan/smithsonian_butterflies_subset" +dataset = load_dataset(config.dataset_name, split="train") + + +preprocess = transforms.Compose( + [ + transforms.Resize((config.image_size, config.image_size)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] +) + + +def transform(examples): + images = [preprocess(image.convert("RGB")) for image in examples["image"]] + return {"images": images} + + +dataset.set_transform(transform) + + +train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True) + + +model = UNet2DModel( + sample_size=config.image_size, # the target image resolution + in_channels=3, # the number of input channels, 3 for RGB images + out_channels=3, # the number of output channels + layers_per_block=2, # how many ResNet layers to use per UNet block + block_out_channels=(128, 128, 256, 256, 512, 512), # the number of output channes for each UNet block + down_block_types=( + "DownBlock2D", # a regular ResNet downsampling block + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + "AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention + "DownBlock2D", + ), + up_block_types=( + "UpBlock2D", # a regular ResNet upsampling block + "AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + ), +) + + +if config.output_dir.startswith("ddpm"): + noise_scheduler = DDPMScheduler( + num_train_timesteps=1000, + beta_schedule="squaredcos_cap_v2", + variance_type="v_diffusion", + prediction_type="v", + ) +else: + noise_scheduler = DDIMScheduler( + num_train_timesteps=1000, + beta_schedule="squaredcos_cap_v2", + variance_type="v_diffusion", + prediction_type="v", + ) + + +optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate) + + +lr_scheduler = get_cosine_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=config.lr_warmup_steps, + num_training_steps=(len(train_dataloader) * config.num_epochs), +) + + +def make_grid(images, rows, cols): + w, h = images[0].size + grid = Image.new("RGB", size=(cols * w, rows * h)) + for i, image in enumerate(images): + grid.paste(image, box=(i % cols * w, i // cols * h)) + return grid + + +def evaluate(config, epoch, pipeline): + # Sample some images from random noise (this is the backward diffusion process). + # The default pipeline output type is `List[PIL.Image]` + images = pipeline( + batch_size=config.eval_batch_size, + generator=torch.manual_seed(config.seed), + ).images + + # Make a grid out of the images + image_grid = make_grid(images, rows=4, cols=4) + + # Save the images + test_dir = os.path.join(config.output_dir, "samples") + os.makedirs(test_dir, exist_ok=True) + image_grid.save(f"{test_dir}/{epoch:04d}.png") + + +def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler): + # Initialize accelerator and tensorboard logging + accelerator = Accelerator( + mixed_precision=config.mixed_precision, + gradient_accumulation_steps=config.gradient_accumulation_steps, + log_with="tensorboard", + logging_dir=os.path.join(config.output_dir, "logs"), + ) + if accelerator.is_main_process: + if config.push_to_hub: + repo = init_git_repo(config, at_init=True) + accelerator.init_trackers("train_example") + + # Prepare everything + # There is no specific order to remember, you just need to unpack the + # objects in the same order you gave them to the prepare method. + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + global_step = 0 + + if config.output_dir.startswith("ddpm"): + pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler) + else: + pipeline = DDIMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler) + + evaluate(config, 0, pipeline) + + # Now you train the model + for epoch in range(config.num_epochs): + progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process) + progress_bar.set_description(f"Epoch {epoch}") + + for step, batch in enumerate(train_dataloader): + clean_images = batch["images"] + # Sample noise to add to the images + noise = torch.randn(clean_images.shape).to(clean_images.device) + bs = clean_images.shape[0] + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device).long() + + with accelerator.accumulate(model): + # Predict the noise residual + alpha_t, sigma_t = noise_scheduler.get_alpha_sigma(clean_images, timesteps, accelerator.device) + z_t = alpha_t * clean_images + sigma_t * noise + noise_pred = model(z_t, timesteps).sample + v = alpha_t * noise - sigma_t * clean_images + loss = F.mse_loss(noise_pred, v) + accelerator.backward(loss) + + accelerator.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + progress_bar.update(1) + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + global_step += 1 + + # After each epoch you optionally sample some demo images with evaluate() and save the model + if accelerator.is_main_process: + if config.output_dir.startswith("ddpm"): + pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler) + else: + pipeline = DDIMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler) + + if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1: + evaluate(config, epoch, pipeline) + + if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1: + if config.push_to_hub: + push_to_hub(config, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=True) + else: + pipeline.save_pretrained(config.output_dir) + + +args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler) + +train_loop(*args) + +sample_images = sorted(glob.glob(f"{config.output_dir}/samples/*.png")) +Image.open(sample_images[-1]) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 972e4d45b079..89d90ba60ad4 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -17,7 +17,7 @@ import math from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Literal, Optional, Tuple, Union import numpy as np import torch @@ -27,6 +27,17 @@ from .scheduling_utils import SchedulerMixin +def expand_to_shape(input, timesteps, shape, device): + """ + Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast + nicely with `shape`. Useful for parellizing operations over `shape[0]` number of diffusion steps at once. + """ + out = torch.gather(input.to(device), 0, timesteps.to(device)) + reshape = [shape[0]] + [1] * (len(shape) - 1) + out = out.reshape(*reshape) + return out + + @dataclass # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM class DDIMSchedulerOutput(BaseOutput): @@ -75,6 +86,18 @@ def alpha_bar(time_step): return torch.tensor(betas) +def t_to_alpha_sigma(num_diffusion_timesteps): + """Returns the scaling factors for the clean image and for the noise, given + a timestep.""" + alphas = torch.cos( + torch.tensor([(t / num_diffusion_timesteps) * math.pi / 2 for t in range(num_diffusion_timesteps)]) + ) + sigmas = torch.sin( + torch.tensor([(t / num_diffusion_timesteps) * math.pi / 2 for t in range(num_diffusion_timesteps)]) + ) + return alphas, sigmas + + class DDIMScheduler(SchedulerMixin, ConfigMixin): """ Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising @@ -128,7 +151,10 @@ def __init__( trained_betas: Optional[np.ndarray] = None, clip_sample: bool = True, set_alpha_to_one: bool = True, + variance_type: str = "fixed", steps_offset: int = 0, + prediction_type: Literal["epsilon", "sample", "v"] = "epsilon", + **kwargs, ): if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) @@ -145,15 +171,18 @@ def __init__( else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + self.variance_type = variance_type self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - self.sigmas = 1 - self.alphas**2 + if prediction_type == "v": + self.alphas, self.sigmas = t_to_alpha_sigma(num_train_timesteps) # At every step in ddim, we are looking into the previous alphas_cumprod # For the final step, there is no previous alphas_cumprod because we are already at 0 # `set_alpha_to_one` decides whether we set this parameter simply to one or # whether we use the final alpha of the "non-previous" one. self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + self.final_sigma = torch.tensor(0.0) if set_alpha_to_one else self.sigmas[0] # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 @@ -161,6 +190,8 @@ def __init__( # setable values self.num_inference_steps = None self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + self.variance_type = variance_type + self.prediction_type = prediction_type def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: """ @@ -170,20 +201,31 @@ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = Args: sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep - Returns: `torch.FloatTensor`: scaled input sample """ return sample - def _get_variance(self, timestep, prev_timestep): + def _get_variance(self, timestep, prev_timestep, eta=0): alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev - variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) - + if self.variance_type == "fixed": + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + elif self.variance_type == "v_diffusion": + # If eta > 0, adjust the scaling factor for the predicted noise + # downward according to the amount of additional noise to add + alpha_prev = self.alphas[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + sigma_prev = self.sigmas[prev_timestep] if prev_timestep >= 0 else self.final_sigma + if eta: + numerator = eta * (sigma_prev**2 / self.sigmas[timestep] ** 2).clamp(min=1.0e-7).sqrt() + else: + numerator = 0 + denominator = (1 - self.alphas[timestep] ** 2 / alpha_prev**2).clamp(min=1.0e-7).sqrt() + ddim_sigma = (numerator * denominator).clamp(min=1.0e-7) + variance = (sigma_prev**2 - ddim_sigma**2).clamp(min=1.0e-7).sqrt() return variance def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): @@ -207,7 +249,6 @@ def step( model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, - prediction_type: str = "epsilon", eta: float = 0.0, use_clipped_model_output: bool = False, generator=None, @@ -271,19 +312,21 @@ def step( # 3. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - if prediction_type == "epsilon": + if self.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) eps = torch.tensor(1) - elif prediction_type == "sample": + elif self.prediction_type == "sample": pred_original_sample = model_output eps = torch.tensor(1) - elif prediction_type == "v": + elif self.prediction_type == "v": # v_t = alpha_t * epsilon - sigma_t * x # need to merge the PRs for sigma to be available in DDPM pred_original_sample = sample * self.alphas[timestep] - model_output * self.sigmas[timestep] - eps = model_output * self.alphas[timestep] - sample * self.sigmas[timestep] + eps = model_output * self.alphas[timestep] + sample * self.sigmas[timestep] else: - raise ValueError(f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or `v`") + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `v`" + ) # 4. Clip "predicted x_0" if self.config.clip_sample: @@ -291,7 +334,7 @@ def step( # 5. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) - variance = self._get_variance(timestep, prev_timestep) + variance = self._get_variance(timestep, prev_timestep, eta) std_dev_t = eta * variance ** (0.5) if use_clipped_model_output: @@ -299,10 +342,14 @@ def step( model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output + if self.prediction_type == "epsilon": + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output - # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + eps * pred_sample_direction + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + eps * pred_sample_direction + else: + alpha_prev = self.alphas[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + prev_sample = pred_original_sample * alpha_prev + eps * variance if eta > 0: # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 @@ -325,7 +372,6 @@ def step( variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise prev_sample = prev_sample + variance - if not return_dict: return (prev_sample,) @@ -337,6 +383,10 @@ def add_noise( noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: + if self.variance_type == "v_diffusion": + alpha, sigma = self.get_alpha_sigma(original_samples, timesteps, original_samples.device) + z_t = alpha * original_samples + sigma * noise + return z_t # Make sure alphas_cumprod and timestep have same device and dtype as original_samples self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) @@ -356,3 +406,8 @@ def add_noise( def __len__(self): return self.config.num_train_timesteps + + def get_alpha_sigma(self, sample, timesteps, device): + alpha = expand_to_shape(self.alphas, timesteps, sample.shape, device) + sigma = expand_to_shape(self.sigmas, timesteps, sample.shape, device) + return alpha, sigma