From 01ab0ae3c42930c8d403318daa3e5158d9915156 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 13 Jun 2023 05:19:58 +0000 Subject: [PATCH 1/2] update kandinsky conversion script --- scripts/convert_kandinsky_to_diffusers.py | 184 ++++++++++++---------- 1 file changed, 102 insertions(+), 82 deletions(-) diff --git a/scripts/convert_kandinsky_to_diffusers.py b/scripts/convert_kandinsky_to_diffusers.py index de9879f7f03b..df6e3b74871e 100644 --- a/scripts/convert_kandinsky_to_diffusers.py +++ b/scripts/convert_kandinsky_to_diffusers.py @@ -8,8 +8,6 @@ from diffusers import UNet2DConditionModel from diffusers.models.prior_transformer import PriorTransformer from diffusers.models.vq_model import VQModel -from diffusers.pipelines.kandinsky.text_proj import KandinskyTextProjModel - """ Example - From the diffusers root directory: @@ -225,39 +223,57 @@ def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix UNET_CONFIG = { "act_fn": "silu", + "addition_embed_type": "text_image", + "addition_embed_type_num_heads": 64, "attention_head_dim": 64, - "block_out_channels": (384, 768, 1152, 1536), + "block_out_channels": [384,768,1152,1536], "center_input_sample": False, - "class_embed_type": "identity", + "class_embed_type": None, + "class_embeddings_concat": False, + "conv_in_kernel": 3, + "conv_out_kernel": 3, "cross_attention_dim": 768, - "down_block_types": ( + "cross_attention_norm": None, + "down_block_types": [ "ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D", "SimpleCrossAttnDownBlock2D", - "SimpleCrossAttnDownBlock2D", - ), + "SimpleCrossAttnDownBlock2D" + ], "downsample_padding": 1, "dual_cross_attention": False, + "encoder_hid_dim": 1024, + "encoder_hid_dim_type": "text_image_proj", "flip_sin_to_cos": True, "freq_shift": 0, "in_channels": 4, "layers_per_block": 3, + "mid_block_only_cross_attention": None, "mid_block_scale_factor": 1, "mid_block_type": "UNetMidBlock2DSimpleCrossAttn", "norm_eps": 1e-05, "norm_num_groups": 32, + "num_class_embeds": None, "only_cross_attention": False, "out_channels": 8, + "projection_class_embeddings_input_dim": None, + "resnet_out_scale_factor": 1.0, + "resnet_skip_time_act": False, "resnet_time_scale_shift": "scale_shift", "sample_size": 64, - "up_block_types": ( + "time_cond_proj_dim": None, + "time_embedding_act_fn": None, + "time_embedding_dim": None, + "time_embedding_type": "positional", + "timestep_post_act": None, + "up_block_types": [ "SimpleCrossAttnUpBlock2D", "SimpleCrossAttnUpBlock2D", "SimpleCrossAttnUpBlock2D", - "ResnetUpsampleBlock2D", - ), + "ResnetUpsampleBlock2D" + ], "upcast_attention": False, - "use_linear_projection": False, + "use_linear_projection": False } @@ -274,6 +290,9 @@ def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): diffusers_checkpoint.update(unet_time_embeddings(checkpoint)) diffusers_checkpoint.update(unet_conv_in(checkpoint)) + diffusers_checkpoint.update(unet_add_embedding(checkpoint)) + diffusers_checkpoint.update(unet_encoder_hid_proj(checkpoint)) + # .input_blocks -> .down_blocks @@ -336,39 +355,62 @@ def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): INPAINT_UNET_CONFIG = { "act_fn": "silu", + "addition_embed_type": "text_image", + "addition_embed_type_num_heads": 64, "attention_head_dim": 64, - "block_out_channels": (384, 768, 1152, 1536), + "block_out_channels": [ + 384, + 768, + 1152, + 1536 + ], "center_input_sample": False, - "class_embed_type": "identity", + "class_embed_type": None, + "class_embeddings_concat": None, + 
"conv_in_kernel": 3, + "conv_out_kernel": 3, "cross_attention_dim": 768, - "down_block_types": ( + "cross_attention_norm": None, + "down_block_types": [ "ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D", "SimpleCrossAttnDownBlock2D", - "SimpleCrossAttnDownBlock2D", - ), + "SimpleCrossAttnDownBlock2D" + ], "downsample_padding": 1, "dual_cross_attention": False, + "encoder_hid_dim": 1024, + "encoder_hid_dim_type": "text_image_proj", "flip_sin_to_cos": True, "freq_shift": 0, "in_channels": 9, "layers_per_block": 3, + "mid_block_only_cross_attention": None, "mid_block_scale_factor": 1, "mid_block_type": "UNetMidBlock2DSimpleCrossAttn", "norm_eps": 1e-05, "norm_num_groups": 32, + "num_class_embeds": None, "only_cross_attention": False, "out_channels": 8, + "projection_class_embeddings_input_dim": None, + "resnet_out_scale_factor": 1.0, + "resnet_skip_time_act": False, "resnet_time_scale_shift": "scale_shift", "sample_size": 64, - "up_block_types": ( + "time_cond_proj_dim": None, + "time_embedding_act_fn": None, + "time_embedding_dim": None, + "time_embedding_type": "positional", + "timestep_post_act": None, + "up_block_types": [ "SimpleCrossAttnUpBlock2D", "SimpleCrossAttnUpBlock2D", "SimpleCrossAttnUpBlock2D", - "ResnetUpsampleBlock2D", - ), + "ResnetUpsampleBlock2D" + ], "upcast_attention": False, - "use_linear_projection": False, + "use_linear_projection": False } @@ -381,10 +423,12 @@ def inpaint_unet_model_from_original_config(): def inpaint_unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): diffusers_checkpoint = {} - num_head_channels = UNET_CONFIG["attention_head_dim"] + num_head_channels = INPAINT_UNET_CONFIG["attention_head_dim"] diffusers_checkpoint.update(unet_time_embeddings(checkpoint)) diffusers_checkpoint.update(unet_conv_in(checkpoint)) + diffusers_checkpoint.update(unet_add_embedding(checkpoint)) + diffusers_checkpoint.update(unet_encoder_hid_proj(checkpoint)) # .input_blocks -> .down_blocks @@ -440,38 +484,6 @@ def inpaint_unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): # done inpaint unet -# text proj - -TEXT_PROJ_CONFIG = {} - - -def text_proj_from_original_config(): - model = KandinskyTextProjModel(**TEXT_PROJ_CONFIG) - return model - - -# Note that the input checkpoint is the original text2img model checkpoint -def text_proj_original_checkpoint_to_diffusers_checkpoint(checkpoint): - diffusers_checkpoint = { - # .text_seq_proj.0 -> .encoder_hidden_states_proj - "encoder_hidden_states_proj.weight": checkpoint["to_model_dim_n.weight"], - "encoder_hidden_states_proj.bias": checkpoint["to_model_dim_n.bias"], - # .clip_tok_proj -> .clip_extra_context_tokens_proj - "clip_extra_context_tokens_proj.weight": checkpoint["clip_to_seq.weight"], - "clip_extra_context_tokens_proj.bias": checkpoint["clip_to_seq.bias"], - # .proj_n -> .embedding_proj - "embedding_proj.weight": checkpoint["proj_n.weight"], - "embedding_proj.bias": checkpoint["proj_n.bias"], - # .ln_model_n -> .embedding_norm - "embedding_norm.weight": checkpoint["ln_model_n.weight"], - "embedding_norm.bias": checkpoint["ln_model_n.bias"], - # .clip_emb -> .clip_image_embeddings_project_to_time_embeddings - "clip_image_embeddings_project_to_time_embeddings.weight": checkpoint["img_layer.weight"], - "clip_image_embeddings_project_to_time_embeddings.bias": checkpoint["img_layer.bias"], - } - - return diffusers_checkpoint - # unet utils @@ -505,6 +517,39 @@ def unet_conv_in(checkpoint): return diffusers_checkpoint +def unet_add_embedding(checkpoint): + diffusers_checkpoint = {} + + 
+ diffusers_checkpoint.update( + { + "add_embedding.text_norm.weight": checkpoint["ln_model_n.weight"], + "add_embedding.text_norm.bias": checkpoint["ln_model_n.bias"], + "add_embedding.text_proj.weight": checkpoint["proj_n.weight"], + "add_embedding.text_proj.bias": checkpoint["proj_n.bias"], + "add_embedding.image_proj.weight": checkpoint["img_layer.weight"], + "add_embedding.image_proj.bias": checkpoint["img_layer.bias"] + } + ) + + return diffusers_checkpoint + +def unet_encoder_hid_proj(checkpoint): + diffusers_checkpoint = {} + + + diffusers_checkpoint.update( + { + "encoder_hid_proj.image_embeds.weight": checkpoint["clip_to_seq.weight"], + "encoder_hid_proj.image_embeds.bias": checkpoint["clip_to_seq.bias"], + "encoder_hid_proj.text_proj.weight": checkpoint["to_model_dim_n.weight"], + "encoder_hid_proj.text_proj.bias": checkpoint["to_model_dim_n.bias"], + } + ) + + return diffusers_checkpoint + + # .out.0 -> .conv_norm_out def unet_conv_norm_out(checkpoint): @@ -857,25 +902,13 @@ def text2img(*, args, checkpoint_map_location): unet_diffusers_checkpoint = unet_original_checkpoint_to_diffusers_checkpoint(unet_model, text2img_checkpoint) - # text proj interlude - - # The original decoder implementation includes a set of parameters that are used - # for creating the `encoder_hidden_states` which are what the U-net is conditioned - # on. The diffusers conditional unet directly takes the encoder_hidden_states. We pull - # the parameters into the KandinskyTextProjModel class - text_proj_model = text_proj_from_original_config() - - text_proj_checkpoint = text_proj_original_checkpoint_to_diffusers_checkpoint(text2img_checkpoint) - - load_checkpoint_to_model(text_proj_checkpoint, text_proj_model, strict=True) - del text2img_checkpoint load_checkpoint_to_model(unet_diffusers_checkpoint, unet_model, strict=True) print("done loading text2img") - return unet_model, text_proj_model + return unet_model def inpaint_text2img(*, args, checkpoint_map_location): @@ -891,17 +924,6 @@ def inpaint_text2img(*, args, checkpoint_map_location): inpaint_unet_model, inpaint_text2img_checkpoint ) - # text proj interlude - - # The original decoder implementation includes a set of parameters that are used - # for creating the `encoder_hidden_states` which are what the U-net is conditioned - # on. The diffusers conditional unet directly takes the encoder_hidden_states. 
We pull - # the parameters into the KandinskyTextProjModel class - text_proj_model = text_proj_from_original_config() - - text_proj_checkpoint = text_proj_original_checkpoint_to_diffusers_checkpoint(inpaint_text2img_checkpoint) - - load_checkpoint_to_model(text_proj_checkpoint, text_proj_model, strict=True) del inpaint_text2img_checkpoint @@ -909,7 +931,7 @@ def inpaint_text2img(*, args, checkpoint_map_location): print("done loading inpaint text2img") - return inpaint_unet_model, text_proj_model + return inpaint_unet_model # movq @@ -1384,15 +1406,13 @@ def load_checkpoint_to_model(checkpoint, model, strict=False): prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location) prior_model.save_pretrained(args.dump_path) elif args.debug == "text2img": - unet_model, text_proj_model = text2img(args=args, checkpoint_map_location=checkpoint_map_location) + unet_model = text2img(args=args, checkpoint_map_location=checkpoint_map_location) unet_model.save_pretrained(f"{args.dump_path}/unet") - text_proj_model.save_pretrained(f"{args.dump_path}/text_proj") elif args.debug == "inpaint_text2img": - inpaint_unet_model, inpaint_text_proj_model = inpaint_text2img( + inpaint_unet_model = inpaint_text2img( args=args, checkpoint_map_location=checkpoint_map_location ) inpaint_unet_model.save_pretrained(f"{args.dump_path}/inpaint_unet") - inpaint_text_proj_model.save_pretrained(f"{args.dump_path}/inpaint_text_proj") elif args.debug == "decoder": decoder = movq(args=args, checkpoint_map_location=checkpoint_map_location) decoder.save_pretrained(f"{args.dump_path}/decoder") From 91c51b3aa97a92abaeccc49cf6af5f2f2bdbb728 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 13 Jun 2023 05:20:49 +0000 Subject: [PATCH 2/2] style --- scripts/convert_kandinsky_to_diffusers.py | 35 +++++++++-------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/scripts/convert_kandinsky_to_diffusers.py b/scripts/convert_kandinsky_to_diffusers.py index df6e3b74871e..1b5722f5d5f3 100644 --- a/scripts/convert_kandinsky_to_diffusers.py +++ b/scripts/convert_kandinsky_to_diffusers.py @@ -9,6 +9,7 @@ from diffusers.models.prior_transformer import PriorTransformer from diffusers.models.vq_model import VQModel + """ Example - From the diffusers root directory: @@ -226,7 +227,7 @@ def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix "addition_embed_type": "text_image", "addition_embed_type_num_heads": 64, "attention_head_dim": 64, - "block_out_channels": [384,768,1152,1536], + "block_out_channels": [384, 768, 1152, 1536], "center_input_sample": False, "class_embed_type": None, "class_embeddings_concat": False, @@ -238,7 +239,7 @@ def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix "ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D", "SimpleCrossAttnDownBlock2D", - "SimpleCrossAttnDownBlock2D" + "SimpleCrossAttnDownBlock2D", ], "downsample_padding": 1, "dual_cross_attention": False, @@ -270,10 +271,10 @@ def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix "SimpleCrossAttnUpBlock2D", "SimpleCrossAttnUpBlock2D", "SimpleCrossAttnUpBlock2D", - "ResnetUpsampleBlock2D" + "ResnetUpsampleBlock2D", ], "upcast_attention": False, - "use_linear_projection": False + "use_linear_projection": False, } @@ -293,7 +294,6 @@ def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): diffusers_checkpoint.update(unet_add_embedding(checkpoint)) diffusers_checkpoint.update(unet_encoder_hid_proj(checkpoint)) - # .input_blocks -> 
.down_blocks original_down_block_idx = 1 @@ -358,12 +358,7 @@ def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): "addition_embed_type": "text_image", "addition_embed_type_num_heads": 64, "attention_head_dim": 64, - "block_out_channels": [ - 384, - 768, - 1152, - 1536 - ], + "block_out_channels": [384, 768, 1152, 1536], "center_input_sample": False, "class_embed_type": None, "class_embeddings_concat": None, @@ -375,7 +370,7 @@ def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): "ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D", "SimpleCrossAttnDownBlock2D", - "SimpleCrossAttnDownBlock2D" + "SimpleCrossAttnDownBlock2D", ], "downsample_padding": 1, "dual_cross_attention": False, @@ -407,10 +402,10 @@ def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): "SimpleCrossAttnUpBlock2D", "SimpleCrossAttnUpBlock2D", "SimpleCrossAttnUpBlock2D", - "ResnetUpsampleBlock2D" + "ResnetUpsampleBlock2D", ], "upcast_attention": False, - "use_linear_projection": False + "use_linear_projection": False, } @@ -517,10 +512,10 @@ def unet_conv_in(checkpoint): return diffusers_checkpoint + def unet_add_embedding(checkpoint): diffusers_checkpoint = {} - diffusers_checkpoint.update( { "add_embedding.text_norm.weight": checkpoint["ln_model_n.weight"], @@ -528,16 +523,16 @@ def unet_add_embedding(checkpoint): "add_embedding.text_proj.weight": checkpoint["proj_n.weight"], "add_embedding.text_proj.bias": checkpoint["proj_n.bias"], "add_embedding.image_proj.weight": checkpoint["img_layer.weight"], - "add_embedding.image_proj.bias": checkpoint["img_layer.bias"] + "add_embedding.image_proj.bias": checkpoint["img_layer.bias"], } ) return diffusers_checkpoint + def unet_encoder_hid_proj(checkpoint): diffusers_checkpoint = {} - diffusers_checkpoint.update( { "encoder_hid_proj.image_embeds.weight": checkpoint["clip_to_seq.weight"], @@ -550,7 +545,6 @@ def unet_encoder_hid_proj(checkpoint): return diffusers_checkpoint - # .out.0 -> .conv_norm_out def unet_conv_norm_out(checkpoint): diffusers_checkpoint = {} @@ -924,7 +918,6 @@ def inpaint_text2img(*, args, checkpoint_map_location): inpaint_unet_model, inpaint_text2img_checkpoint ) - del inpaint_text2img_checkpoint load_checkpoint_to_model(inpaint_unet_diffusers_checkpoint, inpaint_unet_model, strict=True) @@ -1409,9 +1402,7 @@ def load_checkpoint_to_model(checkpoint, model, strict=False): unet_model = text2img(args=args, checkpoint_map_location=checkpoint_map_location) unet_model.save_pretrained(f"{args.dump_path}/unet") elif args.debug == "inpaint_text2img": - inpaint_unet_model = inpaint_text2img( - args=args, checkpoint_map_location=checkpoint_map_location - ) + inpaint_unet_model = inpaint_text2img(args=args, checkpoint_map_location=checkpoint_map_location) inpaint_unet_model.save_pretrained(f"{args.dump_path}/inpaint_unet") elif args.debug == "decoder": decoder = movq(args=args, checkpoint_map_location=checkpoint_map_location)