 import argparse
-import itertools
 import math
 import os
 import random
@@ -147,6 +146,11 @@ def parse_args():
         default=1,
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
+    parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
@@ -383,11 +387,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
         return f"{organization}/{model_id}"


-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
 def main():
     args = parse_args()
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
@@ -460,6 +459,10 @@ def main():
         revision=args.revision,
     )

+    if args.gradient_checkpointing:
+        text_encoder.gradient_checkpointing_enable()
+        unet.enable_gradient_checkpointing()
+
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
@@ -474,15 +477,12 @@ def main():
     token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]

     # Freeze vae and unet
-    freeze_params(vae.parameters())
-    freeze_params(unet.parameters())
+    vae.requires_grad_(False)
+    unet.requires_grad_(False)
     # Freeze all parameters except for the token embeddings in text encoder
-    params_to_freeze = itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-    )
-    freeze_params(params_to_freeze)
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

     if args.scale_lr:
         args.learning_rate = (
@@ -541,9 +541,10 @@ def main():
     unet.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)

-    # Keep vae and unet in eval model as we don't train these
-    vae.eval()
-    unet.eval()
+    # Keep unet in train mode if we are using gradient checkpointing to save memory.
+    # The dropout is 0 so it doesn't matter if we are in eval or train mode.
+    if args.gradient_checkpointing:
+        unet.train()

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -609,12 +610,11 @@ def main():
                 latents = latents * 0.18215

                 # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device).to(dtype=weight_dtype)
+                noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
-                timesteps = torch.randint(
-                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
-                ).long()
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+                timesteps = timesteps.long()

                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
@@ -669,8 +669,7 @@ def main():
             if global_step >= args.max_train_steps:
                 break

-    accelerator.wait_for_everyone()
-
+    accelerator.wait_for_everyone()
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
         if args.push_to_hub and args.only_save_embeds: