@@ -1,4 +1,5 @@
 import argparse
+import copy
 import logging
 import math
 import os
@@ -11,6 +12,9 @@
 import torch.nn.functional as F
 import torch.utils.checkpoint

+import datasets
+import diffusers
+import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
@@ -28,7 +32,7 @@
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.10.0.dev0")

-logger = get_logger(__name__)
+logger = get_logger(__name__, log_level="INFO")


 def parse_args():
@@ -171,7 +175,25 @@ def parse_args():
     parser.add_argument(
         "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
+    parser.add_argument(
+        "--allow_tf32",
+        action="store_true",
+        help=(
+            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+        ),
+    )
     parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+    parser.add_argument(
+        "--non_ema_revision",
+        type=str,
+        default=None,
+        required=False,
+        help=(
+            "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+            " remote repository specified with --pretrained_model_name_or_path."
+        ),
+    )
     parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
@@ -247,6 +269,10 @@ def parse_args():
     if args.dataset_name is None and args.train_data_dir is None:
         raise ValueError("Need either a dataset name or a training folder.")

+    # default to using the same revision for the non-ema model if not specified
+    if args.non_ema_revision is None:
+        args.non_ema_revision = args.revision
+
     return args


@@ -275,6 +301,8 @@ def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
         parameters = list(parameters)
         self.shadow_params = [p.clone().detach() for p in parameters]

+        self.collected_params = None
+
         self.decay = decay
         self.optimization_step = 0

@@ -322,6 +350,55 @@ def to(self, device=None, dtype=None) -> None:
             for p in self.shadow_params
         ]

+    def state_dict(self) -> dict:
+        r"""
+        Returns the state of the ExponentialMovingAverage as a dict.
+        This method is used by accelerate during checkpointing to save the ema state dict.
+        """
+        # Following PyTorch conventions, references to tensors are returned:
+        # "returns a reference to the state and not its copy!" -
+        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
+        return {
+            "decay": self.decay,
+            "optimization_step": self.optimization_step,
+            "shadow_params": self.shadow_params,
+            "collected_params": self.collected_params,
+        }
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        r"""
+        Loads the ExponentialMovingAverage state.
+        This method is used by accelerate during checkpointing to load the ema state dict.
+        Args:
+            state_dict (dict): EMA state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        # deepcopy, to be consistent with module API
+        state_dict = copy.deepcopy(state_dict)
+
+        self.decay = state_dict["decay"]
+        if self.decay < 0.0 or self.decay > 1.0:
+            raise ValueError("Decay must be between 0 and 1")
+
+        self.optimization_step = state_dict["optimization_step"]
+        if not isinstance(self.optimization_step, int):
+            raise ValueError("Invalid optimization_step")
+
+        self.shadow_params = state_dict["shadow_params"]
+        if not isinstance(self.shadow_params, list):
+            raise ValueError("shadow_params must be a list")
+        if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
+            raise ValueError("shadow_params must all be Tensors")
+
+        self.collected_params = state_dict["collected_params"]
+        if self.collected_params is not None:
+            if not isinstance(self.collected_params, list):
+                raise ValueError("collected_params must be a list")
+            if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
+                raise ValueError("collected_params must all be Tensors")
+            if len(self.collected_params) != len(self.shadow_params):
+                raise ValueError("collected_params and shadow_params must have the same length")
+

 def main():
     args = parse_args()
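The two new methods above are what let accelerate checkpoint the EMA weights: Accelerator.register_for_checkpointing accepts any object that exposes state_dict() and load_state_dict(). A minimal illustrative sketch of the round trip (not code from this commit; the toy torch.nn.Linear stands in for the unet, the checkpoint path is a placeholder, and EMAModel is the class defined above):

import torch
from accelerate import Accelerator

model = torch.nn.Linear(4, 4)                  # stand-in for the unet
ema = EMAModel(model.parameters())             # EMAModel as defined in this script

accelerator = Accelerator()
accelerator.register_for_checkpointing(ema)    # accepted because EMAModel now has state_dict/load_state_dict

accelerator.save_state("ema-checkpoint")       # EMA shadow params are written alongside the registered models
accelerator.load_state("ema-checkpoint")       # and restored through EMAModel.load_state_dict on resume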
@@ -339,6 +416,15 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         level=logging.INFO,
     )
+    logger.info(accelerator.state, main_process_only=False)
+    if accelerator.is_local_main_process:
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_info()
+        diffusers.utils.logging.set_verbosity_info()
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+        diffusers.utils.logging.set_verbosity_error()

     # If passed along, set the training seed now.
     if args.seed is not None:
@@ -361,39 +447,44 @@ def main():
     elif args.output_dir is not None:
         os.makedirs(args.output_dir, exist_ok=True)

-    # Load models and create wrapper for stable diffusion
+    # Load scheduler, tokenizer and models.
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     tokenizer = CLIPTokenizer.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
     )
     text_encoder = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="text_encoder",
-        revision=args.revision,
-    )
-    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="vae",
-        revision=args.revision,
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
     )
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="unet",
-        revision=args.revision,
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
     )

+    # Freeze vae and text_encoder
+    vae.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+
+    # Create EMA for the unet.
+    if args.use_ema:
+        ema_unet = UNet2DConditionModel.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+        )
+        ema_unet = EMAModel(ema_unet.parameters())
+
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")

-    # Freeze vae and text_encoder
-    vae.requires_grad_(False)
-    text_encoder.requires_grad_(False)
-
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()

+    # Enable TF32 for faster training on Ampere GPUs,
+    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+    if args.allow_tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
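For reference, the new --allow_tf32 path amounts to a single PyTorch backend switch. The short sketch below is an illustration rather than code from this commit; it shows the flag the script flips plus the separate cuDNN flag that governs convolutions:

import torch

# What --allow_tf32 turns on: TF32 tensor-core math for float32 matmuls on
# Ampere or newer GPUs (reduced mantissa precision, same dynamic range).
torch.backends.cuda.matmul.allow_tf32 = True

# Convolutions are controlled by a separate cuDNN flag that the script leaves untouched.
torch.backends.cudnn.allow_tf32 = True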
@@ -419,7 +510,6 @@ def main():
         weight_decay=args.adam_weight_decay,
         eps=args.adam_epsilon,
     )
-    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

     # Get the datasets: you can either provide your own training and evaluation files (see below)
     # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
@@ -482,9 +572,10 @@ def tokenize_captions(examples, is_train=True):
                 raise ValueError(
                     f"Caption column `{caption_column}` should contain either strings or lists of strings."
                 )
-        inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True)
-        input_ids = inputs.input_ids
-        return input_ids
+        inputs = tokenizer(
+            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+        )
+        return inputs.input_ids

     train_transforms = transforms.Compose(
         [
@@ -500,7 +591,6 @@ def preprocess_train(examples):
         images = [image.convert("RGB") for image in examples[image_column]]
         examples["pixel_values"] = [train_transforms(image) for image in images]
         examples["input_ids"] = tokenize_captions(examples)
-
         return examples

     with accelerator.main_process_first():
@@ -512,13 +602,8 @@ def preprocess_train(examples):
     def collate_fn(examples):
         pixel_values = torch.stack([example["pixel_values"] for example in examples])
         pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
-        input_ids = [example["input_ids"] for example in examples]
-        padded_tokens = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt")
-        return {
-            "pixel_values": pixel_values,
-            "input_ids": padded_tokens.input_ids,
-            "attention_mask": padded_tokens.attention_mask,
-        }
+        input_ids = torch.stack([example["input_ids"] for example in examples])
+        return {"pixel_values": pixel_values, "input_ids": input_ids}

     train_dataloader = torch.utils.data.DataLoader(
         train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size
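The tokenization and collate changes above go together: because tokenize_captions now pads every caption to tokenizer.model_max_length and returns PyTorch tensors, all examples share one sequence length, so collate_fn can simply torch.stack the ids instead of calling tokenizer.pad, and the attention mask is dropped from the batch. A small self-contained sketch of that shape contract (the CLIP checkpoint name is an assumption used only for illustration):

import torch
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # assumed checkpoint for the example
captions = ["a photo of a cat", "an astronaut riding a horse"]

input_ids = tokenizer(
    captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
).input_ids
print(input_ids.shape)  # torch.Size([2, 77]) -- fixed length, so per-example rows can simply be stacked in collate_fn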
@@ -541,23 +626,22 @@ def collate_fn(examples):
     unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
         unet, optimizer, train_dataloader, lr_scheduler
     )
-    accelerator.register_for_checkpointing(lr_scheduler)
+    if args.use_ema:
+        accelerator.register_for_checkpointing(ema_unet)

+    # For mixed precision training we cast the text_encoder and vae weights to half-precision
+    # as these models are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
         weight_dtype = torch.float16
     elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16

-    # Move text_encode and vae to gpu.
-    # For mixed precision training we cast the text_encoder and vae weights to half-precision
-    # as these models are only used for inference, keeping weights in full precision is not required.
+    # Move text_encoder and vae to gpu and cast to weight_dtype
     text_encoder.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
-
-    # Create EMA for the unet.
     if args.use_ema:
-        ema_unet = EMAModel(unet.parameters())
+        ema_unet.to(accelerator.device)

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
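Since ema_unet is now built before accelerator.prepare and only moved to the device here, the rest of the training loop (outside this diff) is expected to keep the EMA copy in sync and to export it at the end. A hedged, toy sketch of that pattern, assuming EMAModel also provides the step(parameters) and copy_to(parameters) helpers used elsewhere in the script; torch.nn.Linear stands in for the unet:

import torch

model = torch.nn.Linear(8, 8)                      # toy stand-in for the unet
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
ema = EMAModel(model.parameters())                 # EMAModel from this script

for _ in range(10):
    loss = model(torch.randn(2, 8)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.step(model.parameters())                   # assumed helper: update the shadow (EMA) copy after each step

ema.copy_to(model.parameters())                    # assumed helper: load the averaged weights back before saving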