88from diffusers import UNet2DConditionModel
99from diffusers .models .prior_transformer import PriorTransformer
1010from diffusers .models .vq_model import VQModel
11- from diffusers .pipelines .kandinsky .text_proj import KandinskyTextProjModel
1211
1312
1413"""
@@ -225,37 +224,55 @@ def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix
225224
# Configuration for the Kandinsky text2img decoder U-net (UNet2DConditionModel).
# Text/image conditioning is supplied through the "text_image" additional
# embedding and the "text_image_proj" encoder-hidden-states projection rather
# than a separate text-projection model.
UNET_CONFIG = {
    "act_fn": "silu",
    "addition_embed_type": "text_image",
    "addition_embed_type_num_heads": 64,
    "attention_head_dim": 64,
    "block_out_channels": [384, 768, 1152, 1536],
    "center_input_sample": False,
    "class_embed_type": None,
    "class_embeddings_concat": False,
    "conv_in_kernel": 3,
    "conv_out_kernel": 3,
    "cross_attention_dim": 768,
    "cross_attention_norm": None,
    "down_block_types": [
        "ResnetDownsampleBlock2D",
        "SimpleCrossAttnDownBlock2D",
        "SimpleCrossAttnDownBlock2D",
        "SimpleCrossAttnDownBlock2D",
    ],
    "downsample_padding": 1,
    "dual_cross_attention": False,
    "encoder_hid_dim": 1024,
    "encoder_hid_dim_type": "text_image_proj",
    "flip_sin_to_cos": True,
    "freq_shift": 0,
    "in_channels": 4,
    "layers_per_block": 3,
    "mid_block_only_cross_attention": None,
    "mid_block_scale_factor": 1,
    "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
    "norm_eps": 1e-05,
    "norm_num_groups": 32,
    "num_class_embeds": None,
    "only_cross_attention": False,
    "out_channels": 8,
    "projection_class_embeddings_input_dim": None,
    "resnet_out_scale_factor": 1.0,
    "resnet_skip_time_act": False,
    "resnet_time_scale_shift": "scale_shift",
    "sample_size": 64,
    "time_cond_proj_dim": None,
    "time_embedding_act_fn": None,
    "time_embedding_dim": None,
    "time_embedding_type": "positional",
    "timestep_post_act": None,
    "up_block_types": [
        "SimpleCrossAttnUpBlock2D",
        "SimpleCrossAttnUpBlock2D",
        "SimpleCrossAttnUpBlock2D",
        "ResnetUpsampleBlock2D",
    ],
    "upcast_attention": False,
    "use_linear_projection": False,
}
@@ -274,6 +291,8 @@ def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
274291
275292 diffusers_checkpoint .update (unet_time_embeddings (checkpoint ))
276293 diffusers_checkpoint .update (unet_conv_in (checkpoint ))
294+ diffusers_checkpoint .update (unet_add_embedding (checkpoint ))
295+ diffusers_checkpoint .update (unet_encoder_hid_proj (checkpoint ))
277296
278297 # <original>.input_blocks -> <diffusers>.down_blocks
279298
@@ -336,37 +355,55 @@ def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
336355
# Configuration for the Kandinsky inpainting decoder U-net. Identical to
# UNET_CONFIG except for the 9-channel input (latents + mask + masked image)
# and `class_embeddings_concat` being None here rather than False.
INPAINT_UNET_CONFIG = {
    "act_fn": "silu",
    "addition_embed_type": "text_image",
    "addition_embed_type_num_heads": 64,
    "attention_head_dim": 64,
    "block_out_channels": [384, 768, 1152, 1536],
    "center_input_sample": False,
    "class_embed_type": None,
    "class_embeddings_concat": None,
    "conv_in_kernel": 3,
    "conv_out_kernel": 3,
    "cross_attention_dim": 768,
    "cross_attention_norm": None,
    "down_block_types": [
        "ResnetDownsampleBlock2D",
        "SimpleCrossAttnDownBlock2D",
        "SimpleCrossAttnDownBlock2D",
        "SimpleCrossAttnDownBlock2D",
    ],
    "downsample_padding": 1,
    "dual_cross_attention": False,
    "encoder_hid_dim": 1024,
    "encoder_hid_dim_type": "text_image_proj",
    "flip_sin_to_cos": True,
    "freq_shift": 0,
    "in_channels": 9,
    "layers_per_block": 3,
    "mid_block_only_cross_attention": None,
    "mid_block_scale_factor": 1,
    "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
    "norm_eps": 1e-05,
    "norm_num_groups": 32,
    "num_class_embeds": None,
    "only_cross_attention": False,
    "out_channels": 8,
    "projection_class_embeddings_input_dim": None,
    "resnet_out_scale_factor": 1.0,
    "resnet_skip_time_act": False,
    "resnet_time_scale_shift": "scale_shift",
    "sample_size": 64,
    "time_cond_proj_dim": None,
    "time_embedding_act_fn": None,
    "time_embedding_dim": None,
    "time_embedding_type": "positional",
    "timestep_post_act": None,
    "up_block_types": [
        "SimpleCrossAttnUpBlock2D",
        "SimpleCrossAttnUpBlock2D",
        "SimpleCrossAttnUpBlock2D",
        "ResnetUpsampleBlock2D",
    ],
    "upcast_attention": False,
    "use_linear_projection": False,
}
@@ -381,10 +418,12 @@ def inpaint_unet_model_from_original_config():
381418def inpaint_unet_original_checkpoint_to_diffusers_checkpoint (model , checkpoint ):
382419 diffusers_checkpoint = {}
383420
384- num_head_channels = UNET_CONFIG ["attention_head_dim" ]
421+ num_head_channels = INPAINT_UNET_CONFIG ["attention_head_dim" ]
385422
386423 diffusers_checkpoint .update (unet_time_embeddings (checkpoint ))
387424 diffusers_checkpoint .update (unet_conv_in (checkpoint ))
425+ diffusers_checkpoint .update (unet_add_embedding (checkpoint ))
426+ diffusers_checkpoint .update (unet_encoder_hid_proj (checkpoint ))
388427
389428 # <original>.input_blocks -> <diffusers>.down_blocks
390429
@@ -440,38 +479,6 @@ def inpaint_unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
440479
441480# done inpaint unet
442481
443- # text proj
444-
445- TEXT_PROJ_CONFIG = {}
446-
447-
448- def text_proj_from_original_config ():
449- model = KandinskyTextProjModel (** TEXT_PROJ_CONFIG )
450- return model
451-
452-
453- # Note that the input checkpoint is the original text2img model checkpoint
454- def text_proj_original_checkpoint_to_diffusers_checkpoint (checkpoint ):
455- diffusers_checkpoint = {
456- # <original>.text_seq_proj.0 -> <diffusers>.encoder_hidden_states_proj
457- "encoder_hidden_states_proj.weight" : checkpoint ["to_model_dim_n.weight" ],
458- "encoder_hidden_states_proj.bias" : checkpoint ["to_model_dim_n.bias" ],
459- # <original>.clip_tok_proj -> <diffusers>.clip_extra_context_tokens_proj
460- "clip_extra_context_tokens_proj.weight" : checkpoint ["clip_to_seq.weight" ],
461- "clip_extra_context_tokens_proj.bias" : checkpoint ["clip_to_seq.bias" ],
462- # <original>.proj_n -> <diffusers>.embedding_proj
463- "embedding_proj.weight" : checkpoint ["proj_n.weight" ],
464- "embedding_proj.bias" : checkpoint ["proj_n.bias" ],
465- # <original>.ln_model_n -> <diffusers>.embedding_norm
466- "embedding_norm.weight" : checkpoint ["ln_model_n.weight" ],
467- "embedding_norm.bias" : checkpoint ["ln_model_n.bias" ],
468- # <original>.clip_emb -> <diffusers>.clip_image_embeddings_project_to_time_embeddings
469- "clip_image_embeddings_project_to_time_embeddings.weight" : checkpoint ["img_layer.weight" ],
470- "clip_image_embeddings_project_to_time_embeddings.bias" : checkpoint ["img_layer.bias" ],
471- }
472-
473- return diffusers_checkpoint
474-
475482
476483# unet utils
477484
@@ -506,6 +513,38 @@ def unet_conv_in(checkpoint):
506513 return diffusers_checkpoint
507514
508515
def unet_add_embedding(checkpoint):
    """Map the original checkpoint's text/image conditioning weights onto the
    diffusers U-net's `add_embedding` (TextImageTimeEmbedding) parameters.

    <original>.ln_model_n -> <diffusers>.add_embedding.text_norm
    <original>.proj_n     -> <diffusers>.add_embedding.text_proj
    <original>.img_layer  -> <diffusers>.add_embedding.image_proj
    """
    return {
        "add_embedding.text_norm.weight": checkpoint["ln_model_n.weight"],
        "add_embedding.text_norm.bias": checkpoint["ln_model_n.bias"],
        "add_embedding.text_proj.weight": checkpoint["proj_n.weight"],
        "add_embedding.text_proj.bias": checkpoint["proj_n.bias"],
        "add_embedding.image_proj.weight": checkpoint["img_layer.weight"],
        "add_embedding.image_proj.bias": checkpoint["img_layer.bias"],
    }
532+
def unet_encoder_hid_proj(checkpoint):
    """Map the original checkpoint's projection weights onto the diffusers
    U-net's `encoder_hid_proj` (TextImageProjection) parameters.

    <original>.clip_to_seq     -> <diffusers>.encoder_hid_proj.image_embeds
    <original>.to_model_dim_n  -> <diffusers>.encoder_hid_proj.text_proj
    """
    return {
        "encoder_hid_proj.image_embeds.weight": checkpoint["clip_to_seq.weight"],
        "encoder_hid_proj.image_embeds.bias": checkpoint["clip_to_seq.bias"],
        "encoder_hid_proj.text_proj.weight": checkpoint["to_model_dim_n.weight"],
        "encoder_hid_proj.text_proj.bias": checkpoint["to_model_dim_n.bias"],
    }
547+
509548# <original>.out.0 -> <diffusers>.conv_norm_out
510549def unet_conv_norm_out (checkpoint ):
511550 diffusers_checkpoint = {}
@@ -857,25 +896,13 @@ def text2img(*, args, checkpoint_map_location):
857896
858897 unet_diffusers_checkpoint = unet_original_checkpoint_to_diffusers_checkpoint (unet_model , text2img_checkpoint )
859898
860- # text proj interlude
861-
862- # The original decoder implementation includes a set of parameters that are used
863- # for creating the `encoder_hidden_states` which are what the U-net is conditioned
864- # on. The diffusers conditional unet directly takes the encoder_hidden_states. We pull
865- # the parameters into the KandinskyTextProjModel class
866- text_proj_model = text_proj_from_original_config ()
867-
868- text_proj_checkpoint = text_proj_original_checkpoint_to_diffusers_checkpoint (text2img_checkpoint )
869-
870- load_checkpoint_to_model (text_proj_checkpoint , text_proj_model , strict = True )
871-
872899 del text2img_checkpoint
873900
874901 load_checkpoint_to_model (unet_diffusers_checkpoint , unet_model , strict = True )
875902
876903 print ("done loading text2img" )
877904
878- return unet_model , text_proj_model
905+ return unet_model
879906
880907
881908def inpaint_text2img (* , args , checkpoint_map_location ):
@@ -891,25 +918,13 @@ def inpaint_text2img(*, args, checkpoint_map_location):
891918 inpaint_unet_model , inpaint_text2img_checkpoint
892919 )
893920
894- # text proj interlude
895-
896- # The original decoder implementation includes a set of parameters that are used
897- # for creating the `encoder_hidden_states` which are what the U-net is conditioned
898- # on. The diffusers conditional unet directly takes the encoder_hidden_states. We pull
899- # the parameters into the KandinskyTextProjModel class
900- text_proj_model = text_proj_from_original_config ()
901-
902- text_proj_checkpoint = text_proj_original_checkpoint_to_diffusers_checkpoint (inpaint_text2img_checkpoint )
903-
904- load_checkpoint_to_model (text_proj_checkpoint , text_proj_model , strict = True )
905-
906921 del inpaint_text2img_checkpoint
907922
908923 load_checkpoint_to_model (inpaint_unet_diffusers_checkpoint , inpaint_unet_model , strict = True )
909924
910925 print ("done loading inpaint text2img" )
911926
912- return inpaint_unet_model , text_proj_model
927+ return inpaint_unet_model
913928
914929
915930# movq
@@ -1384,15 +1399,11 @@ def load_checkpoint_to_model(checkpoint, model, strict=False):
13841399 prior_model = prior (args = args , checkpoint_map_location = checkpoint_map_location )
13851400 prior_model .save_pretrained (args .dump_path )
13861401 elif args .debug == "text2img" :
1387- unet_model , text_proj_model = text2img (args = args , checkpoint_map_location = checkpoint_map_location )
1402+ unet_model = text2img (args = args , checkpoint_map_location = checkpoint_map_location )
13881403 unet_model .save_pretrained (f"{ args .dump_path } /unet" )
1389- text_proj_model .save_pretrained (f"{ args .dump_path } /text_proj" )
13901404 elif args .debug == "inpaint_text2img" :
1391- inpaint_unet_model , inpaint_text_proj_model = inpaint_text2img (
1392- args = args , checkpoint_map_location = checkpoint_map_location
1393- )
1405+ inpaint_unet_model = inpaint_text2img (args = args , checkpoint_map_location = checkpoint_map_location )
13941406 inpaint_unet_model .save_pretrained (f"{ args .dump_path } /inpaint_unet" )
1395- inpaint_text_proj_model .save_pretrained (f"{ args .dump_path } /inpaint_text_proj" )
13961407 elif args .debug == "decoder" :
13971408 decoder = movq (args = args , checkpoint_map_location = checkpoint_map_location )
13981409 decoder .save_pretrained (f"{ args .dump_path } /decoder" )
0 commit comments