     DPMSolverMultistepScheduler,
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
+    HeunDiscreteScheduler,
     LDMTextToImagePipeline,
     LMSDiscreteScheduler,
     PNDMScheduler,
@@ -232,6 +233,15 @@ def create_unet_diffusers_config(original_config, image_size: int):
 
     vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
 
+    head_dim = unet_params.num_heads if "num_heads" in unet_params else None
+    use_linear_projection = (
+        unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
+    )
+    if use_linear_projection:
+        # stable diffusion 2-base-512 and 2-768
+        if head_dim is None:
+            head_dim = [5, 10, 20, 20]
+
     config = dict(
         sample_size=image_size // vae_scale_factor,
         in_channels=unet_params.in_channels,
@@ -241,7 +251,8 @@ def create_unet_diffusers_config(original_config, image_size: int):
         block_out_channels=tuple(block_out_channels),
         layers_per_block=unet_params.num_res_blocks,
         cross_attention_dim=unet_params.context_dim,
-        attention_head_dim=unet_params.num_heads,
+        attention_head_dim=head_dim,
+        use_linear_projection=use_linear_projection,
     )
 
     return config
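
The `[5, 10, 20, 20]` fallback is the per-block attention-head count for the Stable Diffusion 2 UNet, whose original LDM config specifies `num_head_channels: 64` instead of `num_heads`. A quick sketch of where those numbers come from, assuming the channel widths published in `v2-inference-v.yaml` (the values below are illustrative, not read by this hunk):

# Illustrative: SD2 fixes the per-head width at 64 channels, so the head count
# for each UNet block is simply the block width divided by 64.
block_out_channels = [320, 640, 1280, 1280]  # model_channels 320 * ch_mult [1, 2, 4, 4]
num_head_channels = 64
head_dim = [c // num_head_channels for c in block_out_channels]
assert head_dim == [5, 10, 20, 20]
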
@@ -636,6 +647,22 @@ def convert_ldm_clip_checkpoint(checkpoint):
     return text_model
 
 
+def convert_open_clip_checkpoint(checkpoint):
+    text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
+
+    # SKIP for now - need openclip -> HF conversion script here
+    # keys = list(checkpoint.keys())
+    #
+    # text_model_dict = {}
+    # for key in keys:
+    #     if key.startswith("cond_stage_model.model.transformer"):
+    #         text_model_dict[key[len("cond_stage_model.model.transformer.") :]] = checkpoint[key]
+    #
+    # text_model.load_state_dict(text_model_dict)
+
+    return text_model
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
 
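
The commented-out block above only strips the `cond_stage_model.model.transformer.` prefix, which would not be enough on its own: OpenCLIP stores each attention layer's q/k/v weights as one fused `in_proj` tensor, while HF's `CLIPTextModel` expects separate `q_proj`/`k_proj`/`v_proj` modules. A hypothetical sketch of that missing step (the function and tensor names are illustrative, not part of this commit):

import torch

# Hypothetical: split OpenCLIP's fused attention projection into the three
# per-projection tensors HF CLIP expects; in_proj_weight has shape [3*d, d].
def split_in_proj(in_proj_weight: torch.Tensor, in_proj_bias: torch.Tensor):
    q_w, k_w, v_w = in_proj_weight.chunk(3, dim=0)
    q_b, k_b, v_b = in_proj_bias.chunk(3, dim=0)
    return (q_w, q_b), (k_w, k_b), (v_w, v_b)
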
@@ -657,13 +684,22 @@ def convert_ldm_clip_checkpoint(checkpoint):
     )
     parser.add_argument(
         "--image_size",
-        default=512,
+        default=None,
         type=int,
         help=(
             "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2"
             " Base. Use 768 for Stable Diffusion v2."
         ),
     )
+    parser.add_argument(
+        "--prediction_type",
+        default=None,
+        type=str,
+        help=(
+            "The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X and Stable"
+            " Diffusion v2 Base. Use 'v_prediction' for Stable Diffusion v2."
+        ),
+    )
     parser.add_argument(
         "--extract_ema",
         action="store_true",
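
Because `--image_size` and `--prediction_type` now default to `None` and are otherwise inferred from a brittle `global_step` check (see the next hunk), SD 2-base checkpoints are safest converted with both flags set explicitly. A hypothetical invocation, assuming the script's usual filename and a local checkpoint path:

python convert_original_stable_diffusion_to_diffusers.py \
    --checkpoint_path ./sd2-base.ckpt \
    --image_size 512 \
    --prediction_type epsilon \
    --dump_path ./sd2-base-diffusers
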
@@ -674,73 +710,117 @@ def convert_ldm_clip_checkpoint(checkpoint):
         ),
     )
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
-
     args = parser.parse_args()
 
+    image_size = args.image_size
+    prediction_type = args.prediction_type
+
+    checkpoint = torch.load(args.checkpoint_path)
+    global_step = checkpoint["global_step"]
+    checkpoint = checkpoint["state_dict"]
+
     if args.original_config_file is None:
-        os.system(
-            "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
-        )
-        args.original_config_file = "./v1-inference.yaml"
+        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
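+        # this key's last dimension is the cross-attention context dim: 1024 for SD2's
+        # OpenCLIP ViT-H text encoder vs. 768 for SD1's CLIP ViT-L, which is what
+        # distinguishes v2 checkpoints from v1 checkpoints here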
+
+        if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
+            # model_type = "v2"
+            os.system(
+                "wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
+            )
+            args.original_config_file = "./v2-inference-v.yaml"
+        else:
+            # model_type = "v1"
+            os.system(
+                "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
+            )
+            args.original_config_file = "./v1-inference.yaml"
 
     original_config = OmegaConf.load(args.original_config_file)
 
-    checkpoint = torch.load(args.checkpoint_path)
-    checkpoint = checkpoint["state_dict"]
+    if (
+        "parameterization" in original_config["model"]["params"]
+        and original_config["model"]["params"]["parameterization"] == "v"
+    ):
+        if prediction_type is None:
+            # NOTE: for stable diffusion 2 base it is recommended to pass `--prediction_type epsilon`
+            # explicitly, because this fallback relies on a brittle global-step check
+            prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
+        if image_size is None:
+            # NOTE: for stable diffusion 2 base one has to pass `--image_size 512`
+            # explicitly, because this fallback relies on a brittle global-step check
+            image_size = 512 if global_step == 875000 else 768
+    else:
+        if prediction_type is None:
+            prediction_type = "epsilon"
+        if image_size is None:
+            image_size = 512
 
     num_train_timesteps = original_config.model.params.timesteps
     beta_start = original_config.model.params.linear_start
     beta_end = original_config.model.params.linear_end
+
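+    # build one fully specified scheduler config up front; each sampler choice below is
+    # derived from it via `from_config`, so the betas and prediction_type stay consistent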
+    scheduler = DDIMScheduler(
+        beta_end=beta_end,
+        beta_schedule="scaled_linear",
+        beta_start=beta_start,
+        num_train_timesteps=num_train_timesteps,
+        steps_offset=1,
+        clip_sample=False,
+        set_alpha_to_one=False,
+        prediction_type=prediction_type,
+    )
     if args.scheduler_type == "pndm":
-        scheduler = PNDMScheduler(
-            beta_end=beta_end,
-            beta_schedule="scaled_linear",
-            beta_start=beta_start,
-            num_train_timesteps=num_train_timesteps,
-            skip_prk_steps=True,
-        )
+        config = dict(scheduler.config)
+        config["skip_prk_steps"] = True
+        scheduler = PNDMScheduler.from_config(config)
     elif args.scheduler_type == "lms":
-        scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
+        scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
+    elif args.scheduler_type == "heun":
+        scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "euler":
-        scheduler = EulerDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
+        scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "euler-ancestral":
-        scheduler = EulerAncestralDiscreteScheduler(
-            beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
-        )
+        scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "dpm":
-        scheduler = DPMSolverMultistepScheduler(
-            beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
-        )
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "ddim":
-        scheduler = DDIMScheduler(
-            beta_start=beta_start,
-            beta_end=beta_end,
-            beta_schedule="scaled_linear",
-            clip_sample=False,
-            set_alpha_to_one=False,
-        )
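+        # nothing to derive: the shared base scheduler built above is already a DDIMScheduler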
+        scheduler = scheduler
     else:
         raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
 
     # Convert the UNet2DConditionModel model.
-    unet_config = create_unet_diffusers_config(original_config, image_size=args.image_size)
+    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
+    unet = UNet2DConditionModel(**unet_config)
+
     converted_unet_checkpoint = convert_ldm_unet_checkpoint(
         checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema
     )
 
-    unet = UNet2DConditionModel(**unet_config)
     unet.load_state_dict(converted_unet_checkpoint)
 
     # Convert the VAE model.
-    vae_config = create_vae_diffusers_config(original_config, image_size=args.image_size)
+    vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
     converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
 
     vae = AutoencoderKL(**vae_config)
     vae.load_state_dict(converted_vae_checkpoint)
 
     # Convert the text model.
     text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
-    if text_model_type == "FrozenCLIPEmbedder":
+    if text_model_type == "FrozenOpenCLIPEmbedder":
+        text_model = convert_open_clip_checkpoint(checkpoint)
+        tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
+        pipe = StableDiffusionPipeline(
+            vae=vae,
+            text_encoder=text_model,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=None,
+            feature_extractor=None,
+            requires_safety_checker=False,
+        )
+    elif text_model_type == "FrozenCLIPEmbedder":
         text_model = convert_ldm_clip_checkpoint(checkpoint)
         tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
         safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
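
Past the end of this excerpt the script saves the assembled pipeline with `save_pretrained(args.dump_path)`; the converted weights then load like any diffusers pipeline. A minimal sketch, assuming the dump path from the invocation shown earlier:

from diffusers import StableDiffusionPipeline

# Load the converted pipeline from the directory given as --dump_path
# (the path is illustrative; use whatever was passed to the conversion script).
pipe = StableDiffusionPipeline.from_pretrained("./sd2-base-diffusers")
image = pipe("a photograph of an astronaut riding a horse").images[0]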