@@ -1,4 +1,5 @@
 import argparse
+import itertools
 import math
 import os
 from pathlib import Path
@@ -100,6 +101,7 @@ def parse_args():
     parser.add_argument(
         "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
     )
+    parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
     parser.add_argument(
         "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
     )
@@ -320,6 +322,15 @@ def main():
         logging_dir=logging_dir,
     )
 
+    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
+    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
+    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
+    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+        raise ValueError(
+            "Gradient accumulation is not supported when training the text encoder in distributed training. "
+            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+        )
+
     if args.seed is not None:
         set_seed(args.seed)
 
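For context, the single-model accumulation pattern that accelerate.accumulate supports (and that the guard above protects) looks roughly like the sketch below; the model, data, and loss are illustrative stand-ins, not taken from the patch:

import torch
import torch.nn as nn
from accelerate import Accelerator
from torch.utils.data import DataLoader, TensorDataset

accelerator = Accelerator(gradient_accumulation_steps=2)
model = nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loader = DataLoader(TensorDataset(torch.randn(32, 4), torch.randn(32, 1)), batch_size=8)
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

for x, y in loader:
    # accumulate() is entered with a single model; only that model's gradient
    # sync is managed across accumulation steps, which is why the script
    # rejects accumulation when a second trained model (the text encoder)
    # is prepared alongside the unet in distributed runs.
    with accelerator.accumulate(model):
        loss = nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()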
@@ -385,8 +396,14 @@ def main():
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
     unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
 
+    vae.requires_grad_(False)
+    if not args.train_text_encoder:
+        text_encoder.requires_grad_(False)
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
+        if args.train_text_encoder:
+            text_encoder.gradient_checkpointing_enable()
 
     if args.scale_lr:
         args.learning_rate = (
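The requires_grad_(False) calls are what exclude the frozen modules from training: calling it on a module flips requires_grad on every parameter it owns, so autograd and the optimizer simply ignore them; gradient checkpointing is likewise only enabled on modules that are actually trained, since it trades recompute for activation memory during backward. A tiny stand-alone illustration (stand-in module, not the real VAE):

import torch.nn as nn

vae_like = nn.Linear(8, 8)      # stand-in for the frozen VAE
vae_like.requires_grad_(False)  # recursively sets requires_grad=False on all parameters

print(any(p.requires_grad for p in vae_like.parameters()))  # False: nothing here will be trained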
@@ -406,8 +423,11 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW
 
+    params_to_optimize = (
+        itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
+    )
     optimizer = optimizer_class(
-        unet.parameters(),  # only optimize unet
+        params_to_optimize,
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
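When the text encoder is trainable, both models share the single optimizer, so their parameter generators are concatenated lazily with itertools.chain. A small sketch of the same pattern with toy stand-in modules:

import itertools

import torch
import torch.nn as nn

unet_like = nn.Linear(4, 4)               # stand-in for the unet
text_encoder_like = nn.Embedding(10, 4)   # stand-in for the text encoder

params = itertools.chain(unet_like.parameters(), text_encoder_like.parameters())
optimizer = torch.optim.AdamW(params, lr=1e-4)

# The optimizer materializes the chained generator into one param group:
print(len(optimizer.param_groups[0]["params"]))  # 3: Linear weight + bias, plus the Embedding weight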
@@ -467,9 +487,14 @@ def collate_fn(examples):
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
     )
 
-    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        unet, optimizer, train_dataloader, lr_scheduler
-    )
+    if args.train_text_encoder:
+        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+        )
+    else:
+        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, optimizer, train_dataloader, lr_scheduler
+        )
 
     weight_dtype = torch.float32
     if args.mixed_precision == "fp16":
@@ -480,8 +505,9 @@ def collate_fn(examples):
     # Move text_encode and vae to gpu.
     # For mixed precision training we cast the text_encoder and vae weights to half-precision
     # as these models are only used for inference, keeping weights in full precision is not required.
-    text_encoder.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
+    if not args.train_text_encoder:
+        text_encoder.to(accelerator.device, dtype=weight_dtype)
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -516,9 +542,8 @@ def collate_fn(examples):
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
                 # Convert images to latent space
-                with torch.no_grad():
-                    latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-                    latents = latents * 0.18215
+                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+                latents = latents * 0.18215
 
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)
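Dropping the torch.no_grad() around the VAE encode is safe because the VAE was frozen with requires_grad_(False) above: when neither the inputs nor the parameters require grad, autograd builds no graph for that call anyway. A quick check of that behaviour with a stand-in module:

import torch
import torch.nn as nn

enc = nn.Linear(4, 2)        # stand-in for the frozen VAE encoder
enc.requires_grad_(False)

x = torch.randn(1, 4)        # dataloader tensors do not require grad
print(enc(x).requires_grad)  # False: no autograd graph, so the no_grad wrapper was redundant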
@@ -532,8 +557,7 @@ def collate_fn(examples):
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # Get the text embedding for conditioning
-                with torch.no_grad():
-                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
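Here, by contrast, the no_grad wrapper had to go so that gradients can reach the text encoder when --train_text_encoder is set; with frozen weights the behaviour is unchanged. A sketch with a stand-in embedding:

import torch
import torch.nn as nn

text_enc = nn.Embedding(10, 4)           # stand-in for the trainable text encoder
hidden = text_enc(torch.tensor([[1, 2, 3]]))
hidden.sum().backward()

print(text_enc.weight.grad is not None)  # True: gradients flow into the encoder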
@@ -556,7 +580,12 @@ def collate_fn(examples):
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+                    params_to_clip = (
+                        itertools.chain(unet.parameters(), text_encoder.parameters())
+                        if args.train_text_encoder
+                        else unet.parameters()
+                    )
+                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
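params_to_clip is rebuilt inside the loop rather than reused because itertools.chain produces a one-shot iterator: clip_grad_norm_ exhausts it, so a cached chain would clip nothing on later steps (the chained params_to_optimize is fine, since the optimizer materializes it once at construction). Minimal demonstration:

import itertools

it = itertools.chain([1, 2], [3])
print(list(it))  # [1, 2, 3]
print(list(it))  # []  -- already exhausted; hence a fresh chain per clipping step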
@@ -578,7 +607,9 @@ def collate_fn(examples):
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
         pipeline = StableDiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet)
+            args.pretrained_model_name_or_path,
+            unet=accelerator.unwrap_model(unet),
+            text_encoder=accelerator.unwrap_model(text_encoder),
         )
         pipeline.save_pretrained(args.output_dir)
 
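Because the (possibly fine-tuned) text encoder is now passed to StableDiffusionPipeline.from_pretrained before saving, it is bundled into the saved pipeline and reloaded automatically. A usage sketch; the output path and prompt below are placeholders, not values from the patch:

import torch
from diffusers import StableDiffusionPipeline

# "path/to/output_dir" stands in for whatever --output_dir was during training.
pipe = StableDiffusionPipeline.from_pretrained("path/to/output_dir", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

image = pipe("a photo of sks dog in a bucket").images[0]  # placeholder DreamBooth-style prompt
image.save("result.png")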