From 22af79f8be9b25391cca9dafd086eda5e61e0f2f Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Thu, 1 Dec 2022 18:34:37 +0000
Subject: [PATCH 1/2] Conversion SD 2

---
 ..._original_stable_diffusion_to_diffusers.py | 26 ++++++++++++++-----
 1 file changed, 19 insertions(+), 7 deletions(-)

diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py
index 2d354df93818..3c50097f06c0 100644
--- a/scripts/convert_original_stable_diffusion_to_diffusers.py
+++ b/scripts/convert_original_stable_diffusion_to_diffusers.py
@@ -666,16 +666,28 @@ def convert_ldm_clip_checkpoint(checkpoint):
 
     args = parser.parse_args()
 
+    checkpoint = torch.load(args.checkpoint_path)
+    checkpoint = checkpoint["state_dict"]
+
+    prediction_type = "epsilon"
     if args.original_config_file is None:
-        os.system(
-            "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
-        )
-        args.original_config_file = "./v1-inference.yaml"
+        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
 
-    original_config = OmegaConf.load(args.original_config_file)
+        if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
+            # model_type = "v2"
+            os.system(
+                "wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
+            )
+            args.original_config_file = "./v2-inference-v.yaml"
+            prediction_type
+        else:
+            # model_type = "v2"
+            os.system(
+                "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
+            )
+            args.original_config_file = "./v1-inference.yaml"
 
-    checkpoint = torch.load(args.checkpoint_path)
-    checkpoint = checkpoint["state_dict"]
+    original_config = OmegaConf.load(args.original_config_file)
 
     num_train_timesteps = original_config.model.params.timesteps
     beta_start = original_config.model.params.linear_start

From 18bf71375bfbe13301fe8d7b859ba1bc45c35008 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 2 Dec 2022 11:19:56 +0000
Subject: [PATCH 2/2] finish

---
 ..._original_stable_diffusion_to_diffusers.py | 132 +++++++++++++-----
 1 file changed, 100 insertions(+), 32 deletions(-)

diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py
index 1d0e74f7d6d6..ef3c76bfc65a 100644
--- a/scripts/convert_original_stable_diffusion_to_diffusers.py
+++ b/scripts/convert_original_stable_diffusion_to_diffusers.py
@@ -33,6 +33,7 @@
     DPMSolverMultistepScheduler,
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
+    HeunDiscreteScheduler,
     LDMTextToImagePipeline,
     LMSDiscreteScheduler,
     PNDMScheduler,
@@ -232,6 +233,15 @@ def create_unet_diffusers_config(original_config, image_size: int):
 
     vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
 
+    head_dim = unet_params.num_heads if "num_heads" in unet_params else None
+    use_linear_projection = (
+        unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
+    )
+    if use_linear_projection:
+        # stable diffusion 2-base-512 and 2-768
+        if head_dim is None:
+            head_dim = [5, 10, 20, 20]
+
     config = dict(
         sample_size=image_size // vae_scale_factor,
         in_channels=unet_params.in_channels,
@@ -241,7 +251,8 @@ def create_unet_diffusers_config(original_config, image_size: int):
         block_out_channels=tuple(block_out_channels),
         layers_per_block=unet_params.num_res_blocks,
         cross_attention_dim=unet_params.context_dim,
-        attention_head_dim=unet_params.num_heads,
+        attention_head_dim=head_dim,
+        use_linear_projection=use_linear_projection,
     )
 
     return config
@@ -636,6 +647,22 @@ def convert_ldm_clip_checkpoint(checkpoint):
     return text_model
 
 
+def convert_open_clip_checkpoint(checkpoint):
+    text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
+
+    # SKIP for now - need openclip -> HF conversion script here
+    # keys = list(checkpoint.keys())
+    #
+    # text_model_dict = {}
+    # for key in keys:
+    #     if key.startswith("cond_stage_model.model.transformer"):
+    #         text_model_dict[key[len("cond_stage_model.model.transformer.") :]] = checkpoint[key]
+    #
+    # text_model.load_state_dict(text_model_dict)
+
+    return text_model
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
 
@@ -657,13 +684,22 @@ def convert_ldm_clip_checkpoint(checkpoint):
     )
     parser.add_argument(
         "--image_size",
-        default=512,
+        default=None,
         type=int,
         help=(
             "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Siffusion v2"
             " Base. Use 768 for Stable Diffusion v2."
         ),
     )
+    parser.add_argument(
+        "--prediction_type",
+        default=None,
+        type=str,
+        help=(
+            "The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X and Stable"
+            " Diffusion v2 Base. Use 'v_prediction' for Stable Diffusion v2."
+        ),
+    )
     parser.add_argument(
         "--extract_ema",
         action="store_true",
@@ -674,13 +710,15 @@ def convert_ldm_clip_checkpoint(checkpoint):
         ),
     )
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
-
     args = parser.parse_args()
 
+    image_size = args.image_size
+    prediction_type = args.prediction_type
+
     checkpoint = torch.load(args.checkpoint_path)
+    global_step = checkpoint["global_step"]
     checkpoint = checkpoint["state_dict"]
 
-    prediction_type = "epsilon"
     if args.original_config_file is None:
         key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
 
@@ -690,9 +728,8 @@ def convert_ldm_clip_checkpoint(checkpoint):
                 "wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
             )
             args.original_config_file = "./v2-inference-v.yaml"
-            prediction_type
         else:
-            # model_type = "v2"
+            # model_type = "v1"
             os.system(
                 "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
             )
@@ -700,51 +737,69 @@ def convert_ldm_clip_checkpoint(checkpoint):
 
     original_config = OmegaConf.load(args.original_config_file)
 
+    if (
+        "parameterization" in original_config["model"]["params"]
+        and original_config["model"]["params"]["parameterization"] == "v"
+    ):
+        if prediction_type is None:
+            # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
+            # as it relies on a brittle global step parameter here
+            prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
+        if image_size is None:
+            # NOTE: For stable diffusion 2 base one has to pass `image_size==512`
+            # as it relies on a brittle global step parameter here
+            image_size = 512 if global_step == 875000 else 768
+    else:
+        if prediction_type is None:
+            prediction_type = "epsilon"
+        if image_size is None:
+            image_size = 512
+
     num_train_timesteps = original_config.model.params.timesteps
     beta_start = original_config.model.params.linear_start
     beta_end = original_config.model.params.linear_end
+
+    scheduler = DDIMScheduler(
+        beta_end=beta_end,
+        beta_schedule="scaled_linear",
+        beta_start=beta_start,
+        num_train_timesteps=num_train_timesteps,
+        steps_offset=1,
+        clip_sample=False,
+        set_alpha_to_one=False,
+        prediction_type=prediction_type,
+    )
     if args.scheduler_type == "pndm":
-        scheduler = PNDMScheduler(
-            beta_end=beta_end,
-            beta_schedule="scaled_linear",
-            beta_start=beta_start,
-            num_train_timesteps=num_train_timesteps,
-            skip_prk_steps=True,
-        )
+        config = dict(scheduler.config)
+        config["skip_prk_steps"] = True
+        scheduler = PNDMScheduler.from_config(config)
     elif args.scheduler_type == "lms":
-        scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
+        scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
+    elif args.scheduler_type == "heun":
+        scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "euler":
-        scheduler = EulerDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
+        scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "euler-ancestral":
-        scheduler = EulerAncestralDiscreteScheduler(
-            beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
-        )
+        scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "dpm":
-        scheduler = DPMSolverMultistepScheduler(
-            beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
-        )
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "ddim":
-        scheduler = DDIMScheduler(
-            beta_start=beta_start,
-            beta_end=beta_end,
-            beta_schedule="scaled_linear",
-            clip_sample=False,
-            set_alpha_to_one=False,
-        )
+        scheduler = scheduler
     else:
         raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
 
     # Convert the UNet2DConditionModel model.
-    unet_config = create_unet_diffusers_config(original_config, image_size=args.image_size)
+    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
+    unet = UNet2DConditionModel(**unet_config)
+
     converted_unet_checkpoint = convert_ldm_unet_checkpoint(
         checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema
     )
 
-    unet = UNet2DConditionModel(**unet_config)
     unet.load_state_dict(converted_unet_checkpoint)
 
     # Convert the VAE model.
-    vae_config = create_vae_diffusers_config(original_config, image_size=args.image_size)
+    vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
     converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
 
     vae = AutoencoderKL(**vae_config)
@@ -752,7 +807,20 @@ def convert_ldm_clip_checkpoint(checkpoint):
 
     # Convert the text model.
     text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
-    if text_model_type == "FrozenCLIPEmbedder":
+    if text_model_type == "FrozenOpenCLIPEmbedder":
+        text_model = convert_open_clip_checkpoint(checkpoint)
+        tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
+        pipe = StableDiffusionPipeline(
+            vae=vae,
+            text_encoder=text_model,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=None,
+            feature_extractor=None,
+            requires_safety_checker=False,
+        )
+    elif text_model_type == "FrozenCLIPEmbedder":
         text_model = convert_ldm_clip_checkpoint(checkpoint)
         tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
         safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
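
A minimal usage sketch (not part of the patch): once a Stable Diffusion 2 checkpoint has been converted with this script, the dumped directory can be loaded like any other diffusers pipeline. The checkpoint filename, dump path, and prompt below are assumptions chosen only for illustration; the command-line flags are the ones defined by the script above.

    # Hypothetical invocation of the conversion script (paths/values are examples):
    #   python scripts/convert_original_stable_diffusion_to_diffusers.py \
    #       --checkpoint_path ./768-v-ema.ckpt \
    #       --image_size 768 \
    #       --prediction_type v_prediction \
    #       --scheduler_type ddim \
    #       --dump_path ./stable-diffusion-2-diffusers
    import torch
    from diffusers import StableDiffusionPipeline

    # Load the converted pipeline from the dump_path used above; the SD2 branch in the
    # patch constructs it without a safety checker or feature extractor.
    pipe = StableDiffusionPipeline.from_pretrained(
        "./stable-diffusion-2-diffusers", torch_dtype=torch.float16
    )
    pipe = pipe.to("cuda")

    image = pipe("a photograph of an astronaut riding a horse").images[0]
    image.save("astronaut.png")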