@@ -285,15 +285,34 @@ def create_ldm_bert_config(original_config):
     return config
 
 
-def convert_ldm_unet_checkpoint(checkpoint, config):
+def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
     """
     Takes a state dict and a config, and returns a converted checkpoint.
     """
 
     # extract state_dict for UNet
     unet_state_dict = {}
-    unet_key = "model.diffusion_model."
     keys = list(checkpoint.keys())
+
+    unet_key = "model.diffusion_model."
+    # at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
+    if sum(k.startswith("model_ema") for k in keys) > 100:
+        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
+        if extract_ema:
+            print(
+                "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
+                " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
+            )
+            for key in keys:
+                if key.startswith("model.diffusion_model"):
+                    flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
+                    unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
+        else:
+            print(
+                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
+                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
+            )
+
     for key in keys:
         if key.startswith(unet_key):
             unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
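The EMA branch above depends on how LDM's EMA wrapper names its parameters: the EMA copy of `model.diffusion_model.<name>` is stored flat, with the dots stripped from everything after the `model.` prefix. A minimal sketch of that key mapping, using a hypothetical parameter name:

```python
# Sketch of the flat EMA key derivation used in the diff above;
# the parameter name here is hypothetical.
key = "model.diffusion_model.input_blocks.0.0.weight"

# Drop the leading "model", join the remaining components with no
# separator, and prepend the EMA prefix.
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])

print(flat_ema_key)  # model_ema.diffusion_modelinput_blocks00weight
```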
@@ -630,6 +649,15 @@ def convert_ldm_clip_checkpoint(checkpoint):
         type=str,
         help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']",
     )
+    parser.add_argument(
+        "--extract_ema",
+        action="store_true",
+        help=(
+            "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
+            " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
+            " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
+        ),
+    )
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
 
     args = parser.parse_args()
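For reference, `action="store_true"` gives the new flag an implicit default of `False`, so omitting `--extract_ema` preserves the old non-EMA behavior. A minimal standalone sketch (the real parser defines many more arguments than shown here):

```python
import argparse

# Minimal reproduction of the new flag's behavior, detached from the
# script's other arguments.
parser = argparse.ArgumentParser()
parser.add_argument("--extract_ema", action="store_true")

print(parser.parse_args([]).extract_ema)                 # False
print(parser.parse_args(["--extract_ema"]).extract_ema)  # True
```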
@@ -641,7 +669,9 @@ def convert_ldm_clip_checkpoint(checkpoint):
         args.original_config_file = "./v1-inference.yaml"
 
     original_config = OmegaConf.load(args.original_config_file)
-    checkpoint = torch.load(args.checkpoint_path)["state_dict"]
+
+    checkpoint = torch.load(args.checkpoint_path)
+    checkpoint = checkpoint["state_dict"]
 
     num_train_timesteps = original_config.model.params.timesteps
     beta_start = original_config.model.params.linear_start
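Splitting the load into two statements does not change behavior: original LDM `.ckpt` files nest all tensors under a top-level `state_dict` key. A sketch of the assumed layout, with hypothetical keys and shapes:

```python
import torch

# Hypothetical miniature of an LDM checkpoint: tensors sit under
# "state_dict", next to training metadata.
ckpt = {
    "state_dict": {
        "model.diffusion_model.input_blocks.0.0.weight": torch.zeros(320, 4, 3, 3),
        "first_stage_model.encoder.conv_in.weight": torch.zeros(128, 3, 3, 3),
    },
    "global_step": 0,  # hypothetical metadata field
}

checkpoint = ckpt["state_dict"]  # what the conversion functions consume
```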
@@ -669,7 +699,9 @@ def convert_ldm_clip_checkpoint(checkpoint):
 
     # Convert the UNet2DConditionModel model.
     unet_config = create_unet_diffusers_config(original_config)
-    converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config)
+    converted_unet_checkpoint = convert_ldm_unet_checkpoint(
+        checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema
+    )
 
     unet = UNet2DConditionModel(**unet_config)
     unet.load_state_dict(converted_unet_checkpoint)