From a4c91182085b69e1c9bff03f92ac1d62b0e24fff Mon Sep 17 00:00:00 2001 From: anton- Date: Thu, 5 Jan 2023 16:51:46 +0100 Subject: [PATCH 01/20] improve EMA --- .../unconditional_image_generation/README.md | 2 + .../train_unconditional.py | 194 ++++++++++++++++-- src/diffusers/training_utils.py | 9 +- 3 files changed, 189 insertions(+), 16 deletions(-) diff --git a/examples/unconditional_image_generation/README.md b/examples/unconditional_image_generation/README.md index 010200b5a9e9..01aa32e746e7 100644 --- a/examples/unconditional_image_generation/README.md +++ b/examples/unconditional_image_generation/README.md @@ -39,6 +39,7 @@ accelerate launch train_unconditional.py \ --train_batch_size=16 \ --num_epochs=100 \ --gradient_accumulation_steps=1 \ + --use_ema \ --learning_rate=1e-4 \ --lr_warmup_steps=500 \ --mixed_precision=no \ @@ -63,6 +64,7 @@ accelerate launch train_unconditional.py \ --train_batch_size=16 \ --num_epochs=100 \ --gradient_accumulation_steps=1 \ + --use_ema \ --learning_rate=1e-4 \ --lr_warmup_steps=500 \ --mixed_precision=no \ diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index af3d0ddc2259..c50b7a22ca25 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -1,9 +1,10 @@ import argparse +import copy import inspect import math import os from pathlib import Path -from typing import Optional +from typing import Iterable, Optional import torch import torch.nn.functional as F @@ -13,7 +14,6 @@ from datasets import load_dataset from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel from diffusers.optimization import get_scheduler -from diffusers.training_utils import EMAModel from diffusers.utils import check_min_version from huggingface_hub import HfFolder, Repository, whoami from torchvision.transforms import ( @@ -29,7 +29,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.10.0.dev0") +check_min_version("0.12.0.dev0") logger = get_logger(__name__) @@ -156,7 +156,6 @@ def parse_args(): parser.add_argument( "--use_ema", action="store_true", - default=True, help="Whether to use Exponential Moving Average for the final model weights.", ) parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.") @@ -253,6 +252,153 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: return f"{organization}/{model_id}" +# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 +class EMAModel: + """ + Exponential Moving Average of models weights + """ + + def __init__( + self, + parameters: Iterable[torch.nn.Parameter], + update_after_step=0, + inv_gamma=1.0, + power=2 / 3, + min_value=0.0, + max_value=0.9999, + ): + """ + @crowsonkb's notes on EMA Warmup: + If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan + to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), + gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 + at 215.4k steps). + + Args: + inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. + power (float): Exponential factor of EMA warmup. Default: 2/3. + min_value (float): The minimum EMA decay rate. Default: 0. 
+ """ + parameters = list(parameters) + self.shadow_params = [p.clone().detach() for p in parameters] + + self.collected_params = None + + self.update_after_step = update_after_step + self.inv_gamma = inv_gamma + self.power = power + self.min_value = min_value + self.max_value = max_value + + self.decay = 0.0 + self.optimization_step = 0 + + def get_decay(self, optimization_step): + """ + Compute the decay factor for the exponential moving average. + """ + step = max(0, optimization_step - self.update_after_step - 1) + value = 1 - (1 + step / self.inv_gamma) ** -self.power + + if step <= 0: + return 0.0 + + return max(self.min_value, min(value, self.max_value)) + + @torch.no_grad() + def step(self, parameters): + parameters = list(parameters) + + self.optimization_step += 1 + + # Compute the decay factor for the exponential moving average. + self.decay = self.get_decay(self.optimization_step) + + for s_param, param in zip(self.shadow_params, parameters): + if param.requires_grad: + s_param.mul_(self.decay) + s_param.add_(param.data, alpha=1 - self.decay) + else: + s_param.copy_(param) + + torch.cuda.empty_cache() + + def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: + """ + Copy current averaged parameters into given collection of parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. + """ + parameters = list(parameters) + for s_param, param in zip(self.shadow_params, parameters): + param.data.copy_(s_param.data) + + def to(self, device=None, dtype=None) -> None: + r"""Move internal buffers of the ExponentialMovingAverage to `device`. + + Args: + device: like `device` argument to `torch.Tensor.to` + """ + # .to() on the tensors handles None correctly + self.shadow_params = [ + p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) + for p in self.shadow_params + ] + + def state_dict(self) -> dict: + r""" + Returns the state of the ExponentialMovingAverage as a dict. + This method is used by accelerate during checkpointing to save the ema state dict. + """ + # Following PyTorch conventions, references to tensors are returned: + # "returns a reference to the state and not its copy!" - + # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict + return { + "decay": self.decay, + "optimization_step": self.optimization_step, + "shadow_params": self.shadow_params, + "collected_params": self.collected_params, + } + + def load_state_dict(self, state_dict: dict) -> None: + r""" + Loads the ExponentialMovingAverage state. + This method is used by accelerate during checkpointing to save the ema state dict. + Args: + state_dict (dict): EMA state. Should be an object returned + from a call to :meth:`state_dict`. 
+ """ + # deepcopy, to be consistent with module API + state_dict = copy.deepcopy(state_dict) + + self.decay = state_dict["decay"] + if self.decay < 0.0 or self.decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.optimization_step = state_dict["optimization_step"] + if not isinstance(self.optimization_step, int): + raise ValueError("Invalid optimization_step") + + self.shadow_params = state_dict["shadow_params"] + if not isinstance(self.shadow_params, list): + raise ValueError("shadow_params must be a list") + if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): + raise ValueError("shadow_params must all be Tensors") + + self.collected_params = state_dict["collected_params"] + if self.collected_params is not None: + if not isinstance(self.collected_params, list): + raise ValueError("collected_params must be a list") + if not all(isinstance(p, torch.Tensor) for p in self.collected_params): + raise ValueError("collected_params must all be Tensors") + if len(self.collected_params) != len(self.shadow_params): + raise ValueError("collected_params and shadow_params must have the same length") + + def main(args): logging_dir = os.path.join(args.output_dir, args.logging_dir) accelerator = Accelerator( @@ -342,19 +488,34 @@ def transforms(examples): num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps, ) + if args.use_ema: + ema_model = EMAModel( + model.parameters(), + inv_gamma=args.ema_inv_gamma, + power=args.ema_power, + max_value=args.ema_max_decay, + ) + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( model, optimizer, train_dataloader, lr_scheduler ) - accelerator.register_for_checkpointing(lr_scheduler) + if args.use_ema: + accelerator.register_for_checkpointing(ema_model, lr_scheduler) - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. 
+ weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 - ema_model = EMAModel( - accelerator.unwrap_model(model), - inv_gamma=args.ema_inv_gamma, - power=args.ema_power, - max_value=args.ema_max_decay, - ) + # Move text_encode and vae to gpu and cast to weight_dtype + model.to(accelerator.device, dtype=weight_dtype) + if args.use_ema: + ema_model.to(accelerator.device) + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) # Handle the repository creation if accelerator.is_main_process: @@ -445,12 +606,12 @@ def transforms(examples): accelerator.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() lr_scheduler.step() - if args.use_ema: - ema_model.step(model) optimizer.zero_grad() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: + if args.use_ema: + ema_model.step(model.parameters()) progress_bar.update(1) global_step += 1 @@ -472,8 +633,11 @@ def transforms(examples): # Generate sample images for visual inspection if accelerator.is_main_process: if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1: + unet = copy.deepcopy(accelerator.unwrap_model(model)) + if args.use_ema: + ema_model.copy_to(unet.parameters()) pipeline = DDPMPipeline( - unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model), + unet=unet, scheduler=noise_scheduler, ) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index fefc490c1f01..aa68a01fe4fa 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -5,6 +5,8 @@ import numpy as np import torch +from diffusers.utils import deprecate + def enable_full_determinism(seed: int): """ @@ -66,7 +68,12 @@ def __init__( power (float): Exponential factor of EMA warmup. Default: 2/3. min_value (float): The minimum EMA decay rate. Default: 0. """ - + deprecation_message = ( + f"`diffusers.training_utils.EMAModel` is deprecated in favor of \n" + f"`EMAModel` in `examples/unconditional_image_generation/train_unconditional.py` \n" + f"and will be removed in version v1.0.0" + ) + deprecate("EMAModel", "1.0.0", deprecation_message, standard_warn=False) self.averaged_model = copy.deepcopy(model).eval() self.averaged_model.requires_grad_(False) From 4e32811a3548aececa3771a9cc7ac40a5ddfd554 Mon Sep 17 00:00:00 2001 From: anton- Date: Thu, 5 Jan 2023 16:52:20 +0100 Subject: [PATCH 02/20] style --- src/diffusers/training_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index aa68a01fe4fa..1e742f37ec7b 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -69,9 +69,9 @@ def __init__( min_value (float): The minimum EMA decay rate. Default: 0. 
""" deprecation_message = ( - f"`diffusers.training_utils.EMAModel` is deprecated in favor of \n" - f"`EMAModel` in `examples/unconditional_image_generation/train_unconditional.py` \n" - f"and will be removed in version v1.0.0" + "`diffusers.training_utils.EMAModel` is deprecated in favor of \n" + "`EMAModel` in `examples/unconditional_image_generation/train_unconditional.py` \n" + "and will be removed in version v1.0.0" ) deprecate("EMAModel", "1.0.0", deprecation_message, standard_warn=False) self.averaged_model = copy.deepcopy(model).eval() From 2bb2d385cf4dc49c3250b330b77874f74057a2de Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 17 Jan 2023 13:45:12 +0100 Subject: [PATCH 03/20] one EMA model --- examples/text_to_image/train_text_to_image.py | 110 +--------- .../train_unconditional.py | 151 +------------- src/diffusers/training_utils.py | 189 +++++++++++++----- 3 files changed, 139 insertions(+), 311 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 841849dcf3bf..589e1883580b 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -21,6 +21,7 @@ from datasets import load_dataset from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel from diffusers.utils import check_min_version from diffusers.utils.import_utils import is_xformers_available from huggingface_hub import HfFolder, Repository, whoami @@ -290,115 +291,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: } -# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 -class EMAModel: - """ - Exponential Moving Average of models weights - """ - - def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999): - parameters = list(parameters) - self.shadow_params = [p.clone().detach() for p in parameters] - - self.collected_params = None - - self.decay = decay - self.optimization_step = 0 - - @torch.no_grad() - def step(self, parameters): - parameters = list(parameters) - - self.optimization_step += 1 - - # Compute the decay factor for the exponential moving average. - value = (1 + self.optimization_step) / (10 + self.optimization_step) - one_minus_decay = 1 - min(self.decay, value) - - for s_param, param in zip(self.shadow_params, parameters): - if param.requires_grad: - s_param.sub_(one_minus_decay * (s_param - param)) - else: - s_param.copy_(param) - - torch.cuda.empty_cache() - - def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: - """ - Copy current averaged parameters into given collection of parameters. - - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored moving averages. If `None`, the - parameters with which this `ExponentialMovingAverage` was - initialized will be used. - """ - parameters = list(parameters) - for s_param, param in zip(self.shadow_params, parameters): - param.data.copy_(s_param.data) - - def to(self, device=None, dtype=None) -> None: - r"""Move internal buffers of the ExponentialMovingAverage to `device`. 
- - Args: - device: like `device` argument to `torch.Tensor.to` - """ - # .to() on the tensors handles None correctly - self.shadow_params = [ - p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) - for p in self.shadow_params - ] - - def state_dict(self) -> dict: - r""" - Returns the state of the ExponentialMovingAverage as a dict. - This method is used by accelerate during checkpointing to save the ema state dict. - """ - # Following PyTorch conventions, references to tensors are returned: - # "returns a reference to the state and not its copy!" - - # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict - return { - "decay": self.decay, - "optimization_step": self.optimization_step, - "shadow_params": self.shadow_params, - "collected_params": self.collected_params, - } - - def load_state_dict(self, state_dict: dict) -> None: - r""" - Loads the ExponentialMovingAverage state. - This method is used by accelerate during checkpointing to save the ema state dict. - Args: - state_dict (dict): EMA state. Should be an object returned - from a call to :meth:`state_dict`. - """ - # deepcopy, to be consistent with module API - state_dict = copy.deepcopy(state_dict) - - self.decay = state_dict["decay"] - if self.decay < 0.0 or self.decay > 1.0: - raise ValueError("Decay must be between 0 and 1") - - self.optimization_step = state_dict["optimization_step"] - if not isinstance(self.optimization_step, int): - raise ValueError("Invalid optimization_step") - - self.shadow_params = state_dict["shadow_params"] - if not isinstance(self.shadow_params, list): - raise ValueError("shadow_params must be a list") - if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): - raise ValueError("shadow_params must all be Tensors") - - self.collected_params = state_dict["collected_params"] - if self.collected_params is not None: - if not isinstance(self.collected_params, list): - raise ValueError("collected_params must be a list") - if not all(isinstance(p, torch.Tensor) for p in self.collected_params): - raise ValueError("collected_params must all be Tensors") - if len(self.collected_params) != len(self.shadow_params): - raise ValueError("collected_params and shadow_params must have the same length") - - def main(): args = parse_args() logging_dir = os.path.join(args.output_dir, args.logging_dir) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 5170fbef5880..41466441d72b 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -14,6 +14,7 @@ from datasets import load_dataset from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel from diffusers.utils import check_min_version from huggingface_hub import HfFolder, Repository, whoami from torchvision.transforms import ( @@ -252,153 +253,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: return f"{organization}/{model_id}" -# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 -class EMAModel: - """ - Exponential Moving Average of models weights - """ - - def __init__( - self, - parameters: Iterable[torch.nn.Parameter], - update_after_step=0, - inv_gamma=1.0, - power=2 / 3, - min_value=0.0, - max_value=0.9999, - ): - """ - @crowsonkb's 
notes on EMA Warmup: - If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan - to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), - gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 - at 215.4k steps). - - Args: - inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. - power (float): Exponential factor of EMA warmup. Default: 2/3. - min_value (float): The minimum EMA decay rate. Default: 0. - """ - parameters = list(parameters) - self.shadow_params = [p.clone().detach() for p in parameters] - - self.collected_params = None - - self.update_after_step = update_after_step - self.inv_gamma = inv_gamma - self.power = power - self.min_value = min_value - self.max_value = max_value - - self.decay = 0.0 - self.optimization_step = 0 - - def get_decay(self, optimization_step): - """ - Compute the decay factor for the exponential moving average. - """ - step = max(0, optimization_step - self.update_after_step - 1) - value = 1 - (1 + step / self.inv_gamma) ** -self.power - - if step <= 0: - return 0.0 - - return max(self.min_value, min(value, self.max_value)) - - @torch.no_grad() - def step(self, parameters): - parameters = list(parameters) - - self.optimization_step += 1 - - # Compute the decay factor for the exponential moving average. - self.decay = self.get_decay(self.optimization_step) - - for s_param, param in zip(self.shadow_params, parameters): - if param.requires_grad: - s_param.mul_(self.decay) - s_param.add_(param.data, alpha=1 - self.decay) - else: - s_param.copy_(param) - - torch.cuda.empty_cache() - - def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: - """ - Copy current averaged parameters into given collection of parameters. - - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored moving averages. If `None`, the - parameters with which this `ExponentialMovingAverage` was - initialized will be used. - """ - parameters = list(parameters) - for s_param, param in zip(self.shadow_params, parameters): - param.data.copy_(s_param.data) - - def to(self, device=None, dtype=None) -> None: - r"""Move internal buffers of the ExponentialMovingAverage to `device`. - - Args: - device: like `device` argument to `torch.Tensor.to` - """ - # .to() on the tensors handles None correctly - self.shadow_params = [ - p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) - for p in self.shadow_params - ] - - def state_dict(self) -> dict: - r""" - Returns the state of the ExponentialMovingAverage as a dict. - This method is used by accelerate during checkpointing to save the ema state dict. - """ - # Following PyTorch conventions, references to tensors are returned: - # "returns a reference to the state and not its copy!" - - # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict - return { - "decay": self.decay, - "optimization_step": self.optimization_step, - "shadow_params": self.shadow_params, - "collected_params": self.collected_params, - } - - def load_state_dict(self, state_dict: dict) -> None: - r""" - Loads the ExponentialMovingAverage state. - This method is used by accelerate during checkpointing to save the ema state dict. - Args: - state_dict (dict): EMA state. Should be an object returned - from a call to :meth:`state_dict`. 
- """ - # deepcopy, to be consistent with module API - state_dict = copy.deepcopy(state_dict) - - self.decay = state_dict["decay"] - if self.decay < 0.0 or self.decay > 1.0: - raise ValueError("Decay must be between 0 and 1") - - self.optimization_step = state_dict["optimization_step"] - if not isinstance(self.optimization_step, int): - raise ValueError("Invalid optimization_step") - - self.shadow_params = state_dict["shadow_params"] - if not isinstance(self.shadow_params, list): - raise ValueError("shadow_params must be a list") - if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): - raise ValueError("shadow_params must all be Tensors") - - self.collected_params = state_dict["collected_params"] - if self.collected_params is not None: - if not isinstance(self.collected_params, list): - raise ValueError("collected_params must be a list") - if not all(isinstance(p, torch.Tensor) for p in self.collected_params): - raise ValueError("collected_params must all be Tensors") - if len(self.collected_params) != len(self.shadow_params): - raise ValueError("collected_params and shadow_params must have the same length") - - def main(args): logging_dir = os.path.join(args.output_dir, args.logging_dir) accelerator = Accelerator( @@ -491,9 +345,10 @@ def transforms(examples): if args.use_ema: ema_model = EMAModel( model.parameters(), + decay=args.ema_max_decay, + use_ema_warmup=True, inv_gamma=args.ema_inv_gamma, power=args.ema_power, - max_value=args.ema_max_decay, ) model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 1e742f37ec7b..c1fc91b22786 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -1,6 +1,7 @@ import copy import os import random +from typing import Iterable, Union import numpy as np import torch @@ -41,6 +42,7 @@ def set_seed(seed: int): # ^^ safe to call this function even if cuda is not available +# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 class EMAModel: """ Exponential Moving Average of models weights @@ -48,45 +50,39 @@ class EMAModel: def __init__( self, - model, - update_after_step=0, - inv_gamma=1.0, - power=2 / 3, - min_value=0.0, - max_value=0.9999, - device=None, + parameters: Iterable[torch.nn.Parameter], + decay: float = 0.9999, + update_after_step: int = 0, + use_ema_warmup: bool = False, + inv_gamma: Union[float, int] = 1.0, + power: Union[float, int] = 2 / 3, ): """ + Args: + parameters (Iterable[torch.nn.Parameter]): The parameters to track. + decay (float): The decay factor for the exponential moving average. + update_after_step (int): The number of steps to wait before starting to update the EMA weights. + use_ema_warmup (bool): Whether to use EMA warmup. + inv_gamma (float): + Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True. + power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. + @crowsonkb's notes on EMA Warmup: If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at 215.4k steps). - - Args: - inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. 
- power (float): Exponential factor of EMA warmup. Default: 2/3. - min_value (float): The minimum EMA decay rate. Default: 0. """ - deprecation_message = ( - "`diffusers.training_utils.EMAModel` is deprecated in favor of \n" - "`EMAModel` in `examples/unconditional_image_generation/train_unconditional.py` \n" - "and will be removed in version v1.0.0" - ) - deprecate("EMAModel", "1.0.0", deprecation_message, standard_warn=False) - self.averaged_model = copy.deepcopy(model).eval() - self.averaged_model.requires_grad_(False) + parameters = list(parameters) + self.shadow_params = [p.clone().detach() for p in parameters] + + self.collected_params = None + self.decay = decay self.update_after_step = update_after_step + self.use_ema_warmup = use_ema_warmup self.inv_gamma = inv_gamma self.power = power - self.min_value = min_value - self.max_value = max_value - - if device is not None: - self.averaged_model = self.averaged_model.to(device=device) - - self.decay = 0.0 self.optimization_step = 0 def get_decay(self, optimization_step): @@ -94,40 +90,125 @@ def get_decay(self, optimization_step): Compute the decay factor for the exponential moving average. """ step = max(0, optimization_step - self.update_after_step - 1) - value = 1 - (1 + step / self.inv_gamma) ** -self.power + + if self.use_ema_warmup: + cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power + else: + cur_decay_value = (1 + step) / (10 + step) if step <= 0: return 0.0 - return max(self.min_value, min(value, self.max_value)) + return min(cur_decay_value, self.decay) @torch.no_grad() - def step(self, new_model): - ema_state_dict = {} - ema_params = self.averaged_model.state_dict() - - self.decay = self.get_decay(self.optimization_step) - - for key, param in new_model.named_parameters(): - if isinstance(param, dict): - continue - try: - ema_param = ema_params[key] - except KeyError: - ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param) - ema_params[key] = ema_param - - if not param.requires_grad: - ema_params[key].copy_(param.to(dtype=ema_param.dtype).data) - ema_param = ema_params[key] + def step(self, parameters): + parameters = list(parameters) + + self.optimization_step += 1 + + # Compute the decay factor for the exponential moving average. + decay = self.get_decay(self.optimization_step) + one_minus_decay = 1 - decay + + for s_param, param in zip(self.shadow_params, parameters): + if param.requires_grad: + s_param.sub_(one_minus_decay * (s_param - param)) else: - ema_param.mul_(self.decay) - ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay) + s_param.copy_(param) - ema_state_dict[key] = ema_param + torch.cuda.empty_cache() - for key, param in new_model.named_buffers(): - ema_state_dict[key] = param + def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: + """ + Copy current averaged parameters into given collection of parameters. - self.averaged_model.load_state_dict(ema_state_dict, strict=False) - self.optimization_step += 1 + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. If `None`, the parameters with which this + `ExponentialMovingAverage` was initialized will be used. + """ + parameters = list(parameters) + for s_param, param in zip(self.shadow_params, parameters): + param.data.copy_(s_param.data) + + def to(self, device=None, dtype=None) -> None: + r"""Move internal buffers of the ExponentialMovingAverage to `device`. 
+ + Args: + device: like `device` argument to `torch.Tensor.to` + """ + # .to() on the tensors handles None correctly + self.shadow_params = [ + p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) + for p in self.shadow_params + ] + + def state_dict(self) -> dict: + r""" + Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during + checkpointing to save the ema state dict. + """ + # Following PyTorch conventions, references to tensors are returned: + # "returns a reference to the state and not its copy!" - + # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict + return { + "decay": self.decay, + "optimization_step": self.optimization_step, + "update_after_step": self.update_after_step, + "use_ema_warmup": self.use_ema_warmup, + "inv_gamma": self.inv_gamma, + "power": self.power, + "shadow_params": self.shadow_params, + "collected_params": self.collected_params, + } + + def load_state_dict(self, state_dict: dict) -> None: + r""" + Args: + Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the + ema state dict. + state_dict (dict): EMA state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # deepcopy, to be consistent with module API + state_dict = copy.deepcopy(state_dict) + + self.decay = state_dict["decay"] + if self.decay < 0.0 or self.decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.optimization_step = state_dict["optimization_step"] + if not isinstance(self.optimization_step, int): + raise ValueError("Invalid optimization_step") + + self.update_after_step = state_dict["update_after_step"] + if not isinstance(self.update_after_step, int): + raise ValueError("Invalid update_after_step") + + self.use_ema_warmup = state_dict["use_ema_warmup"] + if not isinstance(self.use_ema_warmup, bool): + raise ValueError("Invalid use_ema_warmup") + + self.inv_gamma = state_dict["inv_gamma"] + if not isinstance(self.inv_gamma, (float, int)): + raise ValueError("Invalid inv_gamma") + + self.power = state_dict["power"] + if not isinstance(self.power, (float, int)): + raise ValueError("Invalid power") + + self.shadow_params = state_dict["shadow_params"] + if not isinstance(self.shadow_params, list): + raise ValueError("shadow_params must be a list") + if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): + raise ValueError("shadow_params must all be Tensors") + + self.collected_params = state_dict["collected_params"] + if self.collected_params is not None: + if not isinstance(self.collected_params, list): + raise ValueError("collected_params must be a list") + if not all(isinstance(p, torch.Tensor) for p in self.collected_params): + raise ValueError("collected_params must all be Tensors") + if len(self.collected_params) != len(self.shadow_params): + raise ValueError("collected_params and shadow_params must have the same length") From 3b274815f062ab23db76a71d5b0b15d72ebdd810 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 17 Jan 2023 13:49:05 +0100 Subject: [PATCH 04/20] quality --- examples/text_to_image/train_text_to_image.py | 3 +-- examples/unconditional_image_generation/train_unconditional.py | 2 +- src/diffusers/dependency_versions_table.py | 2 +- src/diffusers/training_utils.py | 2 -- 4 files changed, 3 insertions(+), 6 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 
589e1883580b..7d38e0e47ebc 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -1,11 +1,10 @@ import argparse -import copy import logging import math import os import random from pathlib import Path -from typing import Iterable, Optional +from typing import Optional import numpy as np import torch diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 41466441d72b..356c40602c73 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -4,7 +4,7 @@ import math import os from pathlib import Path -from typing import Iterable, Optional +from typing import Optional import torch import torch.nn.functional as F diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 1ef1edc14629..7fc779fc543e 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -4,7 +4,7 @@ deps = { "Pillow": "Pillow", "accelerate": "accelerate>=0.11.0", - "black": "black==22.8", + "black": "black==22.12", "datasets": "datasets", "filelock": "filelock", "flake8": "flake8>=3.8.3", diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index c1fc91b22786..e2001fde5a96 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -6,8 +6,6 @@ import numpy as np import torch -from diffusers.utils import deprecate - def enable_full_determinism(seed: int): """ From e9db8cd33771c9154b1321e38dc66d07b399296f Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 17 Jan 2023 14:10:18 +0100 Subject: [PATCH 05/20] fix tests --- tests/test_modeling_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 42f683f887fe..2f4139d87102 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -205,7 +205,7 @@ def test_ema_training(self): model = self.model_class(**init_dict) model.to(torch_device) model.train() - ema_model = EMAModel(model, device=torch_device) + ema_model = EMAModel(model.parameters()) output = model(**inputs_dict) From 43f0fe3cf8f418d547365e417c9be490b0608d7e Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 17 Jan 2023 14:54:30 +0100 Subject: [PATCH 06/20] fix test --- tests/test_modeling_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 2f4139d87102..26acdd419200 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -215,7 +215,7 @@ def test_ema_training(self): noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device) loss = torch.nn.functional.mse_loss(output, noise) loss.backward() - ema_model.step(model) + ema_model.step(model.parameters()) def test_outputs_equivalence(self): def set_nan_tensor_to_zero(t): From db2d359df34837bdbba0f46716f87eb2263dfcd6 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Tue, 17 Jan 2023 15:45:49 +0100 Subject: [PATCH 07/20] Apply suggestions from code review Co-authored-by: Pedro Cuenca --- examples/unconditional_image_generation/train_unconditional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/unconditional_image_generation/train_unconditional.py 
b/examples/unconditional_image_generation/train_unconditional.py index 356c40602c73..3eb5cf23bc33 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -355,7 +355,7 @@ def transforms(examples): model, optimizer, train_dataloader, lr_scheduler ) if args.use_ema: - accelerator.register_for_checkpointing(ema_model, lr_scheduler) + accelerator.register_for_checkpointing(ema_model) # For mixed precision training we cast the text_encoder and vae weights to half-precision # as these models are only used for inference, keeping weights in full precision is not required. From e7ac781f2c1c33ed801ce1b1ac9c07769322c881 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 17 Jan 2023 16:25:04 +0100 Subject: [PATCH 08/20] re organise the unconditional script --- .../train_unconditional.py | 138 +++++++++++------- 1 file changed, 86 insertions(+), 52 deletions(-) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 3eb5cf23bc33..f714df742063 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -1,6 +1,7 @@ import argparse import copy import inspect +import logging import math import os from pathlib import Path @@ -9,6 +10,8 @@ import torch import torch.nn.functional as F +import datasets +import diffusers from accelerate import Accelerator from accelerate.logging import get_logger from datasets import load_dataset @@ -33,7 +36,7 @@ check_min_version("0.12.0.dev0") -logger = get_logger(__name__) +logger = get_logger(__name__, log_level="INFO") def _extract_into_tensor(arr, timesteps, broadcast_shape): @@ -255,6 +258,7 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: def main(args): logging_dir = os.path.join(args.output_dir, args.logging_dir) + accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, @@ -262,6 +266,38 @@ def main(args): logging_dir=logging_dir, ) + # Make one log on every process with the configuration for debugging. 
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Initialize the model model = UNet2DModel( sample_size=args.resolution, in_channels=3, @@ -285,8 +321,19 @@ def main(args): "UpBlock2D", ), ) - accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys()) + # Create EMA for the model. + if args.use_ema: + ema_model = EMAModel( + model.parameters(), + decay=args.ema_max_decay, + use_ema_warmup=True, + inv_gamma=args.ema_inv_gamma, + power=args.ema_power, + ) + + # Initialize the scheduler + accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys()) if accepts_prediction_type: noise_scheduler = DDPMScheduler( num_train_timesteps=args.ddpm_num_steps, @@ -296,6 +343,7 @@ def main(args): else: noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule) + # Initialize the optimizer optimizer = torch.optim.AdamW( model.parameters(), lr=args.learning_rate, @@ -304,16 +352,11 @@ def main(args): eps=args.adam_epsilon, ) - augmentations = Compose( - [ - Resize(args.resolution, interpolation=InterpolationMode.BILINEAR), - CenterCrop(args.resolution), - RandomHorizontalFlip(), - ToTensor(), - Normalize([0.5], [0.5]), - ] - ) + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. if args.dataset_name is not None: dataset = load_dataset( args.dataset_name, @@ -323,6 +366,19 @@ def main(args): ) else: dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train") + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets and DataLoaders creation. 
+ augmentations = Compose( + [ + Resize(args.resolution, interpolation=InterpolationMode.BILINEAR), + CenterCrop(args.resolution), + RandomHorizontalFlip(), + ToTensor(), + Normalize([0.5], [0.5]), + ] + ) def transforms(examples): images = [augmentations(image.convert("RGB")) for image in examples["image"]] @@ -335,6 +391,7 @@ def transforms(examples): dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers ) + # Initialize the learning rate scheduler lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer, @@ -342,60 +399,37 @@ def transforms(examples): num_training_steps=(len(train_dataloader) * args.num_epochs), ) - if args.use_ema: - ema_model = EMAModel( - model.parameters(), - decay=args.ema_max_decay, - use_ema_warmup=True, - inv_gamma=args.ema_inv_gamma, - power=args.ema_power, - ) - + # Prepare everything with our `accelerator`. model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( model, optimizer, train_dataloader, lr_scheduler ) - if args.use_ema: - accelerator.register_for_checkpointing(ema_model) - - # For mixed precision training we cast the text_encoder and vae weights to half-precision - # as these models are only used for inference, keeping weights in full precision is not required. - weight_dtype = torch.float32 - if accelerator.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif accelerator.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - # Move text_encode and vae to gpu and cast to weight_dtype - model.to(accelerator.device, dtype=weight_dtype) if args.use_ema: + accelerator.register_for_checkpointing(ema_model) ema_model.to(accelerator.device) - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - - # Handle the repository creation - if accelerator.is_main_process: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - repo = Repository(args.output_dir, clone_from=repo_name) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. if accelerator.is_main_process: run = os.path.split(__file__)[-1].split(".")[0] accelerator.init_trackers(run) + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(dataset)}") + logger.info(f" Num Epochs = {args.num_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_train_steps}") + global_step = 0 first_epoch = 0 + # Potentially load in the weights and states from a previous save if args.resume_from_checkpoint: if args.resume_from_checkpoint != "latest": path = os.path.basename(args.resume_from_checkpoint) @@ -413,6 +447,7 @@ def transforms(examples): first_epoch = resume_global_step // num_update_steps_per_epoch resume_step = resume_global_step % num_update_steps_per_epoch + # Train! for epoch in range(first_epoch, args.num_epochs): model.train() progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process) @@ -517,7 +552,6 @@ def transforms(examples): pipeline.save_pretrained(args.output_dir) if args.push_to_hub: repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False) - accelerator.wait_for_everyone() accelerator.end_training() From 7092f3a86af56870b246a4c2adcaa1311915d9fa Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 18 Jan 2023 14:17:24 +0100 Subject: [PATCH 09/20] backwards compatibility --- src/diffusers/training_utils.py | 68 +++++++++++++++++++++++++++++++-- 1 file changed, 64 insertions(+), 4 deletions(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index e2001fde5a96..9053e2f818cc 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -1,11 +1,13 @@ import copy import os import random -from typing import Iterable, Union +from typing import Iterable, Optional, Union import numpy as np import torch +from .utils import deprecate + def enable_full_determinism(seed: int): """ @@ -50,15 +52,19 @@ def __init__( self, parameters: Iterable[torch.nn.Parameter], decay: float = 0.9999, + min_decay: float = 0.0, update_after_step: int = 0, use_ema_warmup: bool = False, inv_gamma: Union[float, int] = 1.0, power: Union[float, int] = 2 / 3, + device: Optional[Union[str, torch.device]] = None, + **kwargs, ): """ Args: parameters (Iterable[torch.nn.Parameter]): The parameters to track. decay (float): The decay factor for the exponential moving average. + min_decay (float): The minimum decay factor for the exponential moving average. update_after_step (int): The number of steps to wait before starting to update the EMA weights. use_ema_warmup (bool): Whether to use EMA warmup. inv_gamma (float): @@ -71,12 +77,45 @@ def __init__( gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at 215.4k steps). """ + + if issubclass(parameters, torch.nn.Module): + deprecation_message = ( + "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. " + "Please pass the parameters of the module instead." + ) + deprecate( + "passing a `torch.nn.Module` to `ExponentialMovingAverage`", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + parameters = parameters.parameters() + + # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility + use_ema_warmup = True + + if kwargs.get("max_value", None) is not None: + deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead." + deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False) + decay = kwargs["max_value"] + + if kwargs.get("min_value", None) is not None: + deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead." 
+ deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False) + min_decay = kwargs["min_value"] + parameters = list(parameters) self.shadow_params = [p.clone().detach() for p in parameters] + if device is not None: + self.shadow_params = [ + p.to(device=device) if p.is_floating_point() else p.to(device=device) for p in self.shadow_params + ] + self.collected_params = None self.decay = decay + self.min_decay = min_decay self.update_after_step = update_after_step self.use_ema_warmup = use_ema_warmup self.inv_gamma = inv_gamma @@ -89,18 +128,35 @@ def get_decay(self, optimization_step): """ step = max(0, optimization_step - self.update_after_step - 1) + if step <= 0: + return 0.0 + if self.use_ema_warmup: cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power else: cur_decay_value = (1 + step) / (10 + step) - if step <= 0: - return 0.0 + cur_decay_value = min(cur_decay_value, self.decay) - return min(cur_decay_value, self.decay) + # make sure decay is not smaller than min_decay + cur_decay_value = max(cur_decay_value, self.min_decay) + return cur_decay_value @torch.no_grad() def step(self, parameters): + if issubclass(parameters, torch.nn.Module): + deprecation_message = ( + "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. " + "Please pass the parameters of the module instead." + ) + deprecate( + "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + parameters = parameters.parameters() + parameters = list(parameters) self.optimization_step += 1 @@ -176,6 +232,10 @@ def load_state_dict(self, state_dict: dict) -> None: if self.decay < 0.0 or self.decay > 1.0: raise ValueError("Decay must be between 0 and 1") + self.min_decay = state_dict["min_decay"] + if not isinstance(self.min_decay, float): + raise ValueError("Invalid min_decay") + self.optimization_step = state_dict["optimization_step"] if not isinstance(self.optimization_step, int): raise ValueError("Invalid optimization_step") From f2b5d40684f948abe9a9191227012c60f6cd8835 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 18 Jan 2023 14:21:52 +0100 Subject: [PATCH 10/20] default to init values for some args --- src/diffusers/training_utils.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 9053e2f818cc..ff6c904abe2d 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -228,31 +228,31 @@ def load_state_dict(self, state_dict: dict) -> None: # deepcopy, to be consistent with module API state_dict = copy.deepcopy(state_dict) - self.decay = state_dict["decay"] + self.decay = state_dict.get("decay", self.decay) if self.decay < 0.0 or self.decay > 1.0: raise ValueError("Decay must be between 0 and 1") - self.min_decay = state_dict["min_decay"] + self.min_decay = state_dict.get("min_decay", self.min_decay) if not isinstance(self.min_decay, float): raise ValueError("Invalid min_decay") - self.optimization_step = state_dict["optimization_step"] + self.optimization_step = state_dict.get("optimization_step", self.optimization_step) if not isinstance(self.optimization_step, int): raise ValueError("Invalid optimization_step") - self.update_after_step = state_dict["update_after_step"] + self.update_after_step = state_dict.get("update_after_step", self.update_after_step) if not isinstance(self.update_after_step, int): raise ValueError("Invalid update_after_step") - self.use_ema_warmup = 
state_dict["use_ema_warmup"] + self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup) if not isinstance(self.use_ema_warmup, bool): raise ValueError("Invalid use_ema_warmup") - self.inv_gamma = state_dict["inv_gamma"] + self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma) if not isinstance(self.inv_gamma, (float, int)): raise ValueError("Invalid inv_gamma") - self.power = state_dict["power"] + self.power = state_dict["power"].get("power", self.power) if not isinstance(self.power, (float, int)): raise ValueError("Invalid power") From eeed3af050c0f26369dfe6971761c1a5aa8ae350 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 18 Jan 2023 14:30:35 +0100 Subject: [PATCH 11/20] fix ort script --- .../unconditional_image_generation/README.md | 1 + .../train_unconditional_ort.py | 26 +++++++++++-------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/examples/unconditional_image_generation/README.md b/examples/unconditional_image_generation/README.md index 01aa32e746e7..63616b8aac47 100644 --- a/examples/unconditional_image_generation/README.md +++ b/examples/unconditional_image_generation/README.md @@ -152,6 +152,7 @@ accelerate launch train_unconditional_ort.py \ --dataset_name="huggan/flowers-102-categories" \ --resolution=64 \ --output_dir="ddpm-ema-flowers-64" \ + --use_ema \ --train_batch_size=16 \ --num_epochs=1 \ --gradient_accumulation_steps=1 \ diff --git a/examples/unconditional_image_generation/train_unconditional_ort.py b/examples/unconditional_image_generation/train_unconditional_ort.py index df0a463565b1..44ea5a2d0b1b 100644 --- a/examples/unconditional_image_generation/train_unconditional_ort.py +++ b/examples/unconditional_image_generation/train_unconditional_ort.py @@ -157,7 +157,6 @@ def parse_args(): parser.add_argument( "--use_ema", action="store_true", - default=True, help="Whether to use Exponential Moving Average for the final model weights.", ) parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.") @@ -287,8 +286,17 @@ def main(args): "UpBlock2D", ), ) - accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys()) + if args.use_ema: + ema_model = EMAModel( + model.parameters(), + decay=args.ema_max_decay, + use_ema_warmup=True, + inv_gamma=args.ema_inv_gamma, + power=args.ema_power, + ) + + accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys()) if accepts_prediction_type: noise_scheduler = DDPMScheduler( num_train_timesteps=args.ddpm_num_steps, @@ -347,16 +355,12 @@ def transforms(examples): model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( model, optimizer, train_dataloader, lr_scheduler ) - accelerator.register_for_checkpointing(lr_scheduler) - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.use_ema: + accelerator.register_for_checkpointing(ema_model) + ema_model.to(accelerator.device) - ema_model = EMAModel( - accelerator.unwrap_model(model), - inv_gamma=args.ema_inv_gamma, - power=args.ema_power, - max_value=args.ema_max_decay, - ) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) model = ORTModule(model) @@ -448,7 +452,7 @@ def transforms(examples): optimizer.step() lr_scheduler.step() if args.use_ema: - ema_model.step(model) + ema_model.step(model.parameters()) optimizer.zero_grad() # Checks if the accelerator has performed an 
optimization step behind the scenes From 7f84f1be63b891e2085d07792a6bd70d27c06d93 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 18 Jan 2023 14:34:36 +0100 Subject: [PATCH 12/20] issubclass => isinstance --- src/diffusers/training_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index ff6c904abe2d..60984b10f1f5 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -78,7 +78,7 @@ def __init__( at 215.4k steps). """ - if issubclass(parameters, torch.nn.Module): + if isinstance(parameters, torch.nn.Module): deprecation_message = ( "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. " "Please pass the parameters of the module instead." @@ -144,7 +144,7 @@ def get_decay(self, optimization_step): @torch.no_grad() def step(self, parameters): - if issubclass(parameters, torch.nn.Module): + if isinstance(parameters, torch.nn.Module): deprecation_message = ( "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. " "Please pass the parameters of the module instead." From be0212575520e528f95715e61b9aec65f8d87a6f Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 18 Jan 2023 14:36:02 +0100 Subject: [PATCH 13/20] update state_dict --- src/diffusers/training_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 60984b10f1f5..536e802c0cda 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -137,7 +137,6 @@ def get_decay(self, optimization_step): cur_decay_value = (1 + step) / (10 + step) cur_decay_value = min(cur_decay_value, self.decay) - # make sure decay is not smaller than min_decay cur_decay_value = max(cur_decay_value, self.min_decay) return cur_decay_value @@ -208,6 +207,7 @@ def state_dict(self) -> dict: # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict return { "decay": self.decay, + "min_decay": self.decay, "optimization_step": self.optimization_step, "update_after_step": self.update_after_step, "use_ema_warmup": self.use_ema_warmup, From f61474a36a470a6a5e386a09562bd72634c454a5 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 18 Jan 2023 14:37:55 +0100 Subject: [PATCH 14/20] docstr --- src/diffusers/training_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 536e802c0cda..b69d17ee2255 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -70,6 +70,7 @@ def __init__( inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True. power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. + device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the weights @crowsonkb's notes on EMA Warmup: If gamma=1 and power=1, implements a simple average. 
gamma=1, power=2/3 are good values for models you plan From df2d9e0efaaef443b70b8391acdd666a19c2bfdc Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 18 Jan 2023 14:58:37 +0100 Subject: [PATCH 15/20] doc --- src/diffusers/training_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index b69d17ee2255..0099345021a0 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -123,7 +123,7 @@ def __init__( self.power = power self.optimization_step = 0 - def get_decay(self, optimization_step): + def get_decay(self, optimization_step: int) -> float: """ Compute the decay factor for the exponential moving average. """ @@ -143,7 +143,7 @@ def get_decay(self, optimization_step): return cur_decay_value @torch.no_grad() - def step(self, parameters): + def step(self, parameters: Iterable[torch.nn.Parameter]): if isinstance(parameters, torch.nn.Module): deprecation_message = ( "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. " From a94b53b417bc2205147590a92dcea26e977c2b60 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Wed, 18 Jan 2023 17:10:09 +0100 Subject: [PATCH 16/20] Apply suggestions from code review Co-authored-by: Pedro Cuenca --- src/diffusers/training_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 0099345021a0..31ecf9e16736 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -70,7 +70,8 @@ def __init__( inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True. power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. - device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the weights + device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA + weights will be stored on CPU. @crowsonkb's notes on EMA Warmup: If gamma=1 and power=1, implements a simple average. 
gamma=1, power=2/3 are good values for models you plan From c4409e48cb538168a6170477594f119af8758ae5 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 18 Jan 2023 17:11:20 +0100 Subject: [PATCH 17/20] use .to if device is passed --- src/diffusers/training_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 31ecf9e16736..fa45b770aed3 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -110,9 +110,7 @@ def __init__( self.shadow_params = [p.clone().detach() for p in parameters] if device is not None: - self.shadow_params = [ - p.to(device=device) if p.is_floating_point() else p.to(device=device) for p in self.shadow_params - ] + self.to(device=device) self.collected_params = None From bd9f142c43aa791b42fdeacb0c06c5201107c598 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 18 Jan 2023 17:19:46 +0100 Subject: [PATCH 18/20] deprecate device --- src/diffusers/training_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index fa45b770aed3..655fab60903b 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -57,7 +57,6 @@ def __init__( use_ema_warmup: bool = False, inv_gamma: Union[float, int] = 1.0, power: Union[float, int] = 2 / 3, - device: Optional[Union[str, torch.device]] = None, **kwargs, ): """ @@ -109,8 +108,10 @@ def __init__( parameters = list(parameters) self.shadow_params = [p.clone().detach() for p in parameters] - if device is not None: - self.to(device=device) + if kwargs.get("device", None) is not None: + deprecation_message = "The `device` argument is deprecated. Please use `to` instead." 
+ deprecate("device", "1.0.0", deprecation_message, standard_warn=False) + self.to(device=kwargs["device"]) self.collected_params = None From 0bafc112efa9b45afe0a1932f6665a4273a2256b Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 18 Jan 2023 17:44:31 +0100 Subject: [PATCH 19/20] make flake happy --- src/diffusers/training_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 655fab60903b..7f43f553f6e2 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -1,7 +1,7 @@ import copy import os import random -from typing import Iterable, Optional, Union +from typing import Iterable, Union import numpy as np import torch From dc1935f3882eeca323b7a6e857b5ceed0603e589 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 19 Jan 2023 11:31:21 +0100 Subject: [PATCH 20/20] fix typo --- examples/unconditional_image_generation/train_unconditional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index f714df742063..4b92028a1176 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -416,7 +416,7 @@ def transforms(examples): total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + max_train_steps = args.num_epochs * num_update_steps_per_epoch logger.info("***** Running training *****") logger.info(f" Num examples = {len(dataset)}")
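
The patch series above converges on a single, parameter-based `EMAModel` in `diffusers.training_utils`. As an illustration only (not part of the patches), here is a minimal usage sketch of that API; the toy model, optimizer, and training loop are placeholders, and only the `EMAModel` calls reflect the interface introduced in the diffs:

import torch
from diffusers.training_utils import EMAModel

# Placeholder model and optimizer; any torch.nn.Module's parameters work the same way.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

# Track the parameters directly instead of deep-copying the module,
# which is the central change of this PR.
ema_model = EMAModel(
    model.parameters(),
    decay=0.9999,
    use_ema_warmup=True,
    inv_gamma=1.0,
    power=3 / 4,
)

for _ in range(100):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Update the shadow parameters after each optimizer step.
    ema_model.step(model.parameters())

# Copy the averaged weights back into the model (or a copy of it) for evaluation
# or saving, as the training scripts do before building a pipeline.
ema_model.copy_to(model.parameters())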