@@ -1084,7 +1084,7 @@ def lora_state_dict(
10841084 # Map SDXL blocks correctly.
10851085 if unet_config is not None :
10861086 # use unet config to remap block numbers
1087- state_dict = cls ._map_sgm_blocks_to_diffusers (state_dict , unet_config )
1087+ state_dict = cls ._maybe_map_sgm_blocks_to_diffusers (state_dict , unet_config )
10881088 state_dict , network_alphas = cls ._convert_kohya_lora_to_diffusers (state_dict )
10891089
10901090 return state_dict , network_alphas
@@ -1121,24 +1121,41 @@ def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_ext
11211121 return weight_name
11221122
11231123 @classmethod
1124- def _map_sgm_blocks_to_diffusers (cls , state_dict , unet_config , delimiter = "_" , block_slice_pos = 5 ):
1125- is_all_unet = all (k .startswith ("lora_unet" ) for k in state_dict )
1124+ def _maybe_map_sgm_blocks_to_diffusers (cls , state_dict , unet_config , delimiter = "_" , block_slice_pos = 5 ):
1125+ # 1. get all state_dict_keys
1126+ all_keys = state_dict .keys ()
1127+ sgm_patterns = ["input_blocks" , "middle_block" , "output_blocks" ]
1128+
1129+ # 2. check if needs remapping, if not return original dict
1130+ is_in_sgm_format = False
1131+ for key in all_keys :
1132+ if any (p in key for p in sgm_patterns ):
1133+ is_in_sgm_format = True
1134+ break
1135+
1136+ if not is_in_sgm_format :
1137+ return state_dict
1138+
1139+ # 3. Else remap from SGM patterns
11261140 new_state_dict = {}
11271141 inner_block_map = ["resnets" , "attentions" , "upsamplers" ]
11281142
11291143 # Retrieves # of down, mid and up blocks
11301144 input_block_ids , middle_block_ids , output_block_ids = set (), set (), set ()
1131- for layer in state_dict :
1132- if "text" not in layer :
1145+
1146+ for layer in all_keys :
1147+ if "text" in layer :
1148+ new_state_dict [layer ] = state_dict .pop (layer )
1149+ else :
11331150 layer_id = int (layer .split (delimiter )[:block_slice_pos ][- 1 ])
1134- if "input_blocks" in layer :
1151+ if sgm_patterns [ 0 ] in layer :
11351152 input_block_ids .add (layer_id )
1136- elif "middle_block" in layer :
1153+ elif sgm_patterns [ 1 ] in layer :
11371154 middle_block_ids .add (layer_id )
1138- elif "output_blocks" in layer :
1155+ elif sgm_patterns [ 2 ] in layer :
11391156 output_block_ids .add (layer_id )
11401157 else :
1141- raise ValueError ("Checkpoint not supported" )
1158+ raise ValueError (f "Checkpoint not supported because layer { layer } not supported. " )
11421159
11431160 input_blocks = {
11441161 layer_id : [key for key in state_dict if f"input_blocks{ delimiter } { layer_id } " in key ]
@@ -1201,12 +1218,8 @@ def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", bl
12011218 )
12021219 new_state_dict [new_key ] = state_dict .pop (key )
12031220
1204- if is_all_unet and len (state_dict ) > 0 :
1221+ if len (state_dict ) > 0 :
12051222 raise ValueError ("At this point all state dict entries have to be converted." )
1206- else :
1207- # Remaining is the text encoder state dict.
1208- for k , v in state_dict .items ():
1209- new_state_dict .update ({k : v })
12101223
12111224 return new_state_dict
12121225
0 commit comments