From 32c7cdeec8a37fea8f1bf962316fc5e40b8eb873 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Mon, 26 Dec 2022 16:26:39 +0100
Subject: [PATCH 01/13] allow using non-ema weights for training

---
 examples/text_to_image/train_text_to_image.py | 146 ++++++++++++++++--
 1 file changed, 135 insertions(+), 11 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 6c45ee0b1b65..5554783c2d29 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -1,4 +1,5 @@
 import argparse
+import copy
 import logging
 import math
 import os
@@ -172,6 +173,13 @@ def parse_args():
     parser.add_argument(
         "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
     parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+    parser.add_argument(
+        "--non_ema_revision",
+        type=str,
+        default=None,
+        required=False,
+        help="Revision of pretrained non-ema model identifier from huggingface.co/models.",
+    )
     parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
@@ -244,6 +252,10 @@ def parse_args():
     if args.dataset_name is None and args.train_data_dir is None:
         raise ValueError("Need either a dataset name or a training folder.")

+    # default to using the same revision for the non-ema model if not specified
+    if args.non_ema_revision is None:
+        args.non_ema_revision = args.revision
+
     return args


@@ -272,6 +284,8 @@ def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
         parameters = list(parameters)
         self.shadow_params = [p.clone().detach() for p in parameters]

+        self.collected_params = None
+
         self.decay = decay
         self.optimization_step = 0

@@ -324,6 +338,101 @@ def to(self, device=None, dtype=None) -> None:
             for p in self.shadow_params
         ]

+    def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+        """
+        Save the current parameters for restoring later.
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                temporarily stored. If `None`, the parameters of with which this
+                `ExponentialMovingAverage` was initialized will be used.
+        """
+        parameters = list(parameters)
+        self.collected_params = [param.clone() for param in parameters]
+
+    def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+        """
+        Restore the parameters stored with the `store` method.
+        Useful to validate the model with EMA parameters without affecting the
+        original optimization process. Store the parameters before the
+        `copy_to` method. After validation (or model saving), use this to
+        restore the former parameters.
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored parameters. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
+        """
+        if self.collected_params is None:
+            raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
+        parameters = list(parameters)
+        for c_param, param in zip(self.collected_params, parameters):
+            param.data.copy_(c_param.data)
+
+        self.collected_params = None
+        torch.cuda.empty_cache()
+
+    def state_dict(self) -> dict:
+        r"""Returns the state of the ExponentialMovingAverage as a dict."""
+        # Following PyTorch conventions, references to tensors are returned:
+        # "returns a reference to the state and not its copy!" -
+        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
+        return {
+            "decay": self.decay,
+            "optimization_step": self.optimization_step,
+            "shadow_params": self.shadow_params,
+            "collected_params": self.collected_params,
+        }
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        r"""Loads the ExponentialMovingAverage state.
+        Args:
+            state_dict (dict): EMA state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        # deepcopy, to be consistent with module API
+        state_dict = copy.deepcopy(state_dict)
+        self.decay = state_dict["decay"]
+        if self.decay < 0.0 or self.decay > 1.0:
+            raise ValueError("Decay must be between 0 and 1")
+        self.optimization_step = state_dict["optimization_step"]
+        assert isinstance(self.optimization_step, int), "Invalid optimization_step"
+
+        self.shadow_params = state_dict["shadow_params"]
+        assert isinstance(self.shadow_params, list), "shadow_params must be a list"
+        assert all(isinstance(p, torch.Tensor) for p in self.shadow_params), "shadow_params must all be Tensors"
+
+        self.collected_params = state_dict["collected_params"]
+        if self.collected_params is not None:
+            assert isinstance(self.collected_params, list), "collected_params must be a list"
+            assert all(
+                isinstance(p, torch.Tensor) for p in self.collected_params
+            ), "collected_params must all be Tensors"
+            assert len(self.collected_params) == len(
+                self.shadow_params
+            ), "collected_params and shadow_params had different lengths"
+
+        # if len(self.shadow_params) == len(self._params_refs):
+        #     # Consistant with torch.optim.Optimizer, cast things to consistant
+        #     # device and dtype with the parameters
+        #     params = [p() for p in self._params_refs]
+        #     # If parameters have been garbage collected, just load the state
+        #     # we were given without change.
+        #     if not any(p is None for p in params):
+        #         # ^ parameter references are still good
+        #         for i, p in enumerate(params):
+        #             self.shadow_params[i] = self.shadow_params[i].to(
+        #                 device=p.device, dtype=p.dtype
+        #             )
+        #             if self.collected_params is not None:
+        #                 self.collected_params[i] = self.collected_params[i].to(
+        #                     device=p.device, dtype=p.dtype
+        #                 )
+        #     else:
+        #         raise ValueError(
+        #             "Tried to `load_state_dict()` with the wrong number of "
+        #             "parameters in the saved state."
+        #         )
+

 def main():
     args = parse_args()
@@ -378,11 +487,16 @@ def main():
         revision=args.revision,
     )
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="unet",
-        revision=args.revision,
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
     )

+    # Create EMA for the unet.
+    if args.use_ema:
+        ema_unet = UNet2DConditionModel.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+        )
+        ema_unet = EMAModel(ema_unet.parameters())
+
     if is_xformers_available():
         try:
             unet.enable_xformers_memory_efficient_attention()
@@ -543,10 +657,23 @@ def collate_fn(examples):
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
     )

-    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        unet, optimizer, train_dataloader, lr_scheduler
-    )
-    accelerator.register_for_checkpointing(lr_scheduler)
+    if args.use_ema:
+        unet, ema_unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet,
+            ema_unet,
+            optimizer,
+            train_dataloader,
+            lr_scheduler,
+        )
+        accelerator.register_for_checkpointing(lr_scheduler, ema_unet)
+    else:
+        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet,
+            optimizer,
+            train_dataloader,
+            lr_scheduler,
+        )
+        accelerator.register_for_checkpointing(lr_scheduler)

     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
@@ -559,10 +686,8 @@ def collate_fn(examples):
     # as these models are only used for inference, keeping weights in full precision is not required.
     text_encoder.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
-
-    # Create EMA for the unet.
     if args.use_ema:
-        ema_unet = EMAModel(unet.parameters())
+        ema_unet.to(accelerator.device)

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -599,7 +724,6 @@ def collate_fn(examples):
             dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
             path = dirs[-1]
         accelerator.print(f"Resuming from checkpoint {path}")
-        accelerator.load_state(os.path.join(args.output_dir, path))
         global_step = int(path.split("-")[1])

         resume_global_step = global_step * args.gradient_accumulation_steps

From a706c96e6edf3a108520bbe114992878182b7a25 Mon Sep 17 00:00:00 2001
From: Suraj Patil
Date: Tue, 27 Dec 2022 15:15:07 +0100
Subject: [PATCH 02/13] Apply suggestions from code review

Co-authored-by: Pedro Cuenca
---
 examples/text_to_image/train_text_to_image.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 5554783c2d29..6244451ddf06 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -178,7 +178,7 @@ def parse_args():
         type=str,
         default=None,
         required=False,
-        help="Revision of pretrained non-ema model identifier from huggingface.co/models.",
+        help="Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or remote repository specified with --pretrained_model_name_or_path.",
     )
     parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
@@ -343,8 +343,7 @@ def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         Save the current parameters for restoring later.
         Args:
             parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-                temporarily stored. If `None`, the parameters of with which this
-                `ExponentialMovingAverage` was initialized will be used.
+                temporarily stored.
         """
         parameters = list(parameters)
         self.collected_params = [param.clone() for param in parameters]
@@ -358,9 +357,7 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
             restore the former parameters.
         Args:
             parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-                updated with the stored parameters. If `None`, the
-                parameters with which this `ExponentialMovingAverage` was
-                initialized will be used.
+                updated with the stored parameters.
         """
         if self.collected_params is None:
             raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
@@ -409,7 +406,7 @@ def load_state_dict(self, state_dict: dict) -> None:
             ), "collected_params must all be Tensors"
             assert len(self.collected_params) == len(
                 self.shadow_params
-            ), "collected_params and shadow_params had different lengths"
+            ), "collected_params and shadow_params must have the same length"

         # if len(self.shadow_params) == len(self._params_refs):
         #     # Consistant with torch.optim.Optimizer, cast things to consistant

From 5918f669bf4fb7f9bad0d3005282e1e42096f6a5 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Tue, 27 Dec 2022 15:25:13 +0100
Subject: [PATCH 03/13] address more review comment

---
 examples/text_to_image/train_text_to_image.py | 45 +++++--------
 1 file changed, 12 insertions(+), 33 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 6244451ddf06..e3a0e5a1f6f1 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -178,7 +178,10 @@ def parse_args():
         type=str,
         default=None,
         required=False,
-        help="Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or remote repository specified with --pretrained_model_name_or_path.",
+        help=(
+            "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+            " remote repository specified with --pretrained_model_name_or_path."
+        ),
     )
     parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
@@ -338,38 +341,11 @@ def to(self, device=None, dtype=None) -> None:
             for p in self.shadow_params
         ]

-    def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
-        """
-        Save the current parameters for restoring later.
-        Args:
-            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-                temporarily stored.
-        """
-        parameters = list(parameters)
-        self.collected_params = [param.clone() for param in parameters]
-
-    def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
-        """
-        Restore the parameters stored with the `store` method.
-        Useful to validate the model with EMA parameters without affecting the
-        original optimization process. Store the parameters before the
-        `copy_to` method. After validation (or model saving), use this to
-        restore the former parameters.
-        Args:
-            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-                updated with the stored parameters.
-        """
-        if self.collected_params is None:
-            raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
-        parameters = list(parameters)
-        for c_param, param in zip(self.collected_params, parameters):
-            param.data.copy_(c_param.data)
-
-        self.collected_params = None
-        torch.cuda.empty_cache()
-
     def state_dict(self) -> dict:
-        r"""Returns the state of the ExponentialMovingAverage as a dict."""
+        r"""
+        Returns the state of the ExponentialMovingAverage as a dict.
+        This method is used by accelerate during checkpointing to save the ema state dict.
+        """
         # Following PyTorch conventions, references to tensors are returned:
         # "returns a reference to the state and not its copy!" -
         # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
@@ -381,7 +357,9 @@ def state_dict(self) -> dict:
         }

     def load_state_dict(self, state_dict: dict) -> None:
-        r"""Loads the ExponentialMovingAverage state.
+        r"""
+        Loads the ExponentialMovingAverage state.
+        This method is used by accelerate during checkpointing to save the ema state dict.
         Args:
             state_dict (dict): EMA state. Should be an object returned
                 from a call to :meth:`state_dict`.
@@ -721,6 +699,7 @@ def collate_fn(examples):
             dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
             path = dirs[-1]
         accelerator.print(f"Resuming from checkpoint {path}")
+        accelerator.load_state(os.path.join(args.output_dir, path))
         global_step = int(path.split("-")[1])

         resume_global_step = global_step * args.gradient_accumulation_steps

From 103eacab6bb147b0c953ecd53a7380bcb42b04c7 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Tue, 27 Dec 2022 15:40:35 +0100
Subject: [PATCH 04/13] reorganise a few lines

---
 examples/text_to_image/train_text_to_image.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index e3a0e5a1f6f1..d9bf28aaf936 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -447,7 +447,8 @@ def main():
     elif args.output_dir is not None:
         os.makedirs(args.output_dir, exist_ok=True)

-    # Load models and create wrapper for stable diffusion
+    # Load scheduler, tokenizer and models.
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     tokenizer = CLIPTokenizer.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
     )
@@ -465,6 +466,10 @@ def main():
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
     )

+    # Freeze vae and text_encoder
+    vae.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+
     # Create EMA for the unet.
     if args.use_ema:
         ema_unet = UNet2DConditionModel.from_pretrained(
@@ -481,10 +486,6 @@ def main():
             f" correctly and a GPU is available: {e}"
         )

-    # Freeze vae and text_encoder
-    vae.requires_grad_(False)
-    text_encoder.requires_grad_(False)
-
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()

@@ -513,7 +514,6 @@ def main():
         weight_decay=args.adam_weight_decay,
         eps=args.adam_epsilon,
     )
-    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

     # Get the datasets: you can either provide your own training and evaluation files (see below)
     # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
From 5c8f0699140bd7662621fbcdfdeed1b538bde3af Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Tue, 27 Dec 2022 16:03:42 +0100
Subject: [PATCH 05/13] always pad text to max_length to match original training

---
 examples/text_to_image/train_text_to_image.py | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 43d540dd41c0..0c3ff2352cf8 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -576,9 +576,10 @@ def tokenize_captions(examples, is_train=True):
                 raise ValueError(
                     f"Caption column `{caption_column}` should contain either strings or lists of strings."
                 )
-        inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True)
-        input_ids = inputs.input_ids
-        return input_ids
+        inputs = tokenizer(
+            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+        )
+        return inputs.input_ids

     train_transforms = transforms.Compose(
         [
@@ -606,13 +607,8 @@ def preprocess_train(examples):
     def collate_fn(examples):
         pixel_values = torch.stack([example["pixel_values"] for example in examples])
         pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
-        input_ids = [example["input_ids"] for example in examples]
-        padded_tokens = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt")
-        return {
-            "pixel_values": pixel_values,
-            "input_ids": padded_tokens.input_ids,
-            "attention_mask": padded_tokens.attention_mask,
-        }
+        input_ids = torch.cat([example["input_ids"] for example in examples], dim=0)
+        return {"pixel_values": pixel_values, "input_ids": input_ids}

     train_dataloader = torch.utils.data.DataLoader(
         train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size

From 008014e402ae930f6ebe0079f9d68e14635f2492 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Tue, 27 Dec 2022 16:50:23 +0100
Subject: [PATCH 06/13] ifx collate_fn

---
 examples/text_to_image/train_text_to_image.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 0c3ff2352cf8..e4d104673ee9 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -595,7 +595,6 @@ def preprocess_train(examples):
         images = [image.convert("RGB") for image in examples[image_column]]
         examples["pixel_values"] = [train_transforms(image) for image in images]
         examples["input_ids"] = tokenize_captions(examples)
-
         return examples

     with accelerator.main_process_first():
@@ -607,7 +606,7 @@ def collate_fn(examples):
         pixel_values = torch.stack([example["pixel_values"] for example in examples])
         pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
-        input_ids = torch.cat([example["input_ids"] for example in examples], dim=0)
+        input_ids = torch.stack([example["input_ids"] for example in examples])
         return {"pixel_values": pixel_values, "input_ids": input_ids}

     train_dataloader = torch.utils.data.DataLoader(

From fe51c7c56111a048b179e7a81a119a233295a46c Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Fri, 30 Dec 2022 12:50:51 +0100
Subject: [PATCH 07/13] remove unused code

---
 examples/text_to_image/train_text_to_image.py | 22 -------
 1 file changed, 22 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 17065427a1cb..12d7f91ca111 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -384,28 +384,6 @@ def load_state_dict(self, state_dict: dict) -> None:
                 self.shadow_params
             ), "collected_params and shadow_params must have the same length"

-        # if len(self.shadow_params) == len(self._params_refs):
-        #     # Consistant with torch.optim.Optimizer, cast things to consistant
-        #     # device and dtype with the parameters
-        #     params = [p() for p in self._params_refs]
-        #     # If parameters have been garbage collected, just load the state
-        #     # we were given without change.
-        #     if not any(p is None for p in params):
-        #         # ^ parameter references are still good
-        #         for i, p in enumerate(params):
-        #             self.shadow_params[i] = self.shadow_params[i].to(
-        #                 device=p.device, dtype=p.dtype
-        #             )
-        #             if self.collected_params is not None:
-        #                 self.collected_params[i] = self.collected_params[i].to(
-        #                     device=p.device, dtype=p.dtype
-        #                 )
-        #     else:
-        #         raise ValueError(
-        #             "Tried to `load_state_dict()` with the wrong number of "
-        #             "parameters in the saved state."
-        #         )
-

 def main():
     args = parse_args()

From 50782ebf8d108075ef090d239703520567e806fb Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Fri, 30 Dec 2022 12:53:37 +0100
Subject: [PATCH 08/13] don't prepare ema_unet, don't register lr scheduler

---
 examples/text_to_image/train_text_to_image.py | 23 ++++++-------
 1 file changed, 7 insertions(+), 16 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 12d7f91ca111..62dad75d7b39 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -600,23 +600,14 @@ def collate_fn(examples):
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
     )

-    if args.use_ema:
-        unet, ema_unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet,
-            ema_unet,
-            optimizer,
-            train_dataloader,
-            lr_scheduler,
-        )
-        accelerator.register_for_checkpointing(lr_scheduler, ema_unet)
-    else:
-        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet,
-            optimizer,
-            train_dataloader,
-            lr_scheduler,
-        )
-        accelerator.register_for_checkpointing(lr_scheduler)
+    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        unet,
+        optimizer,
+        train_dataloader,
+        lr_scheduler,
+    )
+    if args.use_ema:
+        accelerator.register_for_checkpointing(ema_unet)

     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":

From 4079f5b489e50b618cad7dfdfa86c6ede5df0d9f Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Fri, 30 Dec 2022 12:57:27 +0100
Subject: [PATCH 09/13] style

---
 examples/text_to_image/train_text_to_image.py | 15 +++------------
 1 file changed, 3 insertions(+), 12 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 62dad75d7b39..2a52e9c6adc4 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -429,15 +429,9 @@ def main():
         args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
     )
     text_encoder = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="text_encoder",
-        revision=args.revision,
-    )
-    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="vae",
-        revision=args.revision,
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
     )
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
     )
@@ -601,10 +595,7 @@ def collate_fn(examples):
     )

     unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        unet,
-        optimizer,
-        train_dataloader,
-        lr_scheduler,
+        unet, optimizer, train_dataloader, lr_scheduler
     )
     if args.use_ema:
         accelerator.register_for_checkpointing(ema_unet)

From 114969e465e9e5f9c16ea5be22a36ec7aa41223e Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Fri, 30 Dec 2022 13:02:51 +0100
Subject: [PATCH 10/13] assert => ValueError

---
 examples/text_to_image/train_text_to_image.py | 24 +++++++++++--------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 2a52e9c6adc4..9ed25cae4932 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -364,25 +364,29 @@ def load_state_dict(self, state_dict: dict) -> None:
         """
         # deepcopy, to be consistent with module API
         state_dict = copy.deepcopy(state_dict)
+
         self.decay = state_dict["decay"]
         if self.decay < 0.0 or self.decay > 1.0:
             raise ValueError("Decay must be between 0 and 1")
+
         self.optimization_step = state_dict["optimization_step"]
-        assert isinstance(self.optimization_step, int), "Invalid optimization_step"
+        if not isinstance(self.optimization_step, int):
+            raise ValueError("Invalid optimization_step")

         self.shadow_params = state_dict["shadow_params"]
-        assert isinstance(self.shadow_params, list), "shadow_params must be a list"
-        assert all(isinstance(p, torch.Tensor) for p in self.shadow_params), "shadow_params must all be Tensors"
+        if not isinstance(self.shadow_params, list):
+            raise ValueError("shadow_params must be a list")
+        if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
+            raise ValueError("shadow_params must all be Tensors")

         self.collected_params = state_dict["collected_params"]
         if self.collected_params is not None:
-            assert isinstance(self.collected_params, list), "collected_params must be a list"
-            assert all(
-                isinstance(p, torch.Tensor) for p in self.collected_params
-            ), "collected_params must all be Tensors"
-            assert len(self.collected_params) == len(
-                self.shadow_params
-            ), "collected_params and shadow_params must have the same length"
+            if not isinstance(self.collected_params, list):
+                raise ValueError("collected_params must be a list")
+            if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
+                raise ValueError("collected_params must all be Tensors")
+            if len(self.collected_params) != len(self.shadow_params):
+                raise ValueError("collected_params and shadow_params must have the same length")


 def main():

From a54a56bda4790c403a5728919c8dc48a946af9bc Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Fri, 30 Dec 2022 14:43:10 +0100
Subject: [PATCH 11/13] add allow_tf32

---
 examples/text_to_image/train_text_to_image.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 9ed25cae4932..723e47bfe57b 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -172,6 +172,14 @@ def parse_args():
     parser.add_argument(
         "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
+    parser.add_argument(
+        "--allow_tf32",
+        action="store_true",
+        help=(
+            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+        ),
+    )
     parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
     parser.add_argument(
         "--non_ema_revision",
@@ -460,6 +468,11 @@ def main():
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()

+    # Enable TF32 for faster training on Ampere GPUs,
+    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+    if args.allow_tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes

From ad7ebe3eb37a7f45f02c0932cac9039145e2ed7c Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Fri, 30 Dec 2022 15:10:20 +0100
Subject: [PATCH 12/13] set log level

---
 examples/text_to_image/train_text_to_image.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 723e47bfe57b..ffa98fd6e32a 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -12,6 +12,9 @@
 import torch.nn.functional as F
 import torch.utils.checkpoint

+import datasets
+import diffusers
+import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
@@ -29,7 +32,7 @@
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.10.0.dev0")

-logger = get_logger(__name__)
+logger = get_logger(__name__, log_level="INFO")


 def parse_args():
@@ -413,6 +416,15 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         level=logging.INFO,
     )
+    logger.info(accelerator.state, main_process_only=False)
+    if accelerator.is_local_main_process:
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_info()
+        diffusers.utils.logging.set_verbosity_info()
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+        diffusers.utils.logging.set_verbosity_error()

     # If passed along, set the training seed now.
     if args.seed is not None:

From b29ed553f0bb83b1ad64f899d5fe910b5341b4d7 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Fri, 30 Dec 2022 21:47:49 +0100
Subject: [PATCH 13/13] fix comment

---
 examples/text_to_image/train_text_to_image.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index ffa98fd6e32a..986a57b75779 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -629,15 +629,15 @@ def collate_fn(examples):
     if args.use_ema:
         accelerator.register_for_checkpointing(ema_unet)

+    # For mixed precision training we cast the text_encoder and vae weights to half-precision
+    # as these models are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
         weight_dtype = torch.float16
     elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16

-    # Move text_encode and vae to gpu.
-    # For mixed precision training we cast the text_encoder and vae weights to half-precision
-    # as these models are only used for inference, keeping weights in full precision is not required.
+    # Move text_encode and vae to gpu and cast to weight_dtype
     text_encoder.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
     if args.use_ema:
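
Taken together, the series leaves the example with a simpler EMA flow: the trainable unet is loaded from `--non_ema_revision` while its EMA copy is initialized from `--revision` (patch 01), the EMA copy is kept out of `accelerator.prepare` (patch 08), and it is registered with `accelerator.register_for_checkpointing` so that its `state_dict`/`load_state_dict` methods run when accelerate saves or loads a checkpoint (patches 01, 03 and 08). The sketch below is a minimal illustration of the EMA mechanics this relies on; it is a stripped-down stand-in for the EMAModel class edited above, with a toy `torch.nn.Linear` in place of the unet, not a verbatim excerpt from the script:

    # Illustrative sketch only; see examples/text_to_image/train_text_to_image.py
    # for the actual EMAModel implementation the patches above edit.
    from typing import Iterable

    import torch


    class ExponentialMovingAverage:
        def __init__(self, parameters: Iterable[torch.nn.Parameter], decay: float = 0.9999):
            # Shadow copies track an exponentially weighted average of the weights.
            self.shadow_params = [p.clone().detach() for p in parameters]
            self.decay = decay

        @torch.no_grad()
        def step(self, parameters: Iterable[torch.nn.Parameter]) -> None:
            # Standard EMA update: shadow <- decay * shadow + (1 - decay) * param,
            # written as an in-place subtraction.
            for s_param, param in zip(self.shadow_params, parameters):
                s_param.sub_((1.0 - self.decay) * (s_param - param))

        def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
            # Overwrite the live weights with the averaged ones, e.g. at save time.
            for s_param, param in zip(self.shadow_params, parameters):
                param.data.copy_(s_param.data)


    model = torch.nn.Linear(4, 4)  # stands in for the prepared unet
    ema = ExponentialMovingAverage(model.parameters())
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    for _ in range(3):  # stands in for the training loop
        loss = model(torch.randn(2, 4)).pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        ema.step(model.parameters())  # EMA update after each optimizer step

    ema.copy_to(model.parameters())  # swap in the averaged weights before saving

Because the shadow parameters never receive gradients, nothing about them requires `accelerator.prepare`; registering the EMA object for checkpointing is enough to persist and restore `shadow_params` across `--resume_from_checkpoint` runs, which is what patches 03 and 08 rely on.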