@@ -2137,9 +2137,18 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     converted_state_dict = {}
     keys = list(checkpoint.keys())
 
+    variant = "chroma" if "distilled_guidance_layer.in_proj.weight" in checkpoint else "flux"
+
     for k in keys:
         if "model.diffusion_model." in k:
             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+        if variant == "chroma" and "distilled_guidance_layer." in k:
+            new_key = k
+            if k.startswith("distilled_guidance_layer.norms"):
+                new_key = k.replace(".scale", ".weight")
+            elif k.startswith("distilled_guidance_layer.layer"):
+                new_key = k.replace("in_layer", "linear_1").replace("out_layer", "linear_2")
+            converted_state_dict[new_key] = checkpoint.pop(k)
 
     num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
     num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
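For context, the new "chroma" branch above only renames keys under the distilled guidance layer before the rest of the conversion runs. A minimal, self-contained sketch of that remapping, where the sample key names are illustrative rather than taken from a real checkpoint:

# Illustrative sketch of the chroma key remapping introduced in the hunk above.
sample_keys = [
    "distilled_guidance_layer.norms.0.scale",
    "distilled_guidance_layer.layers.0.in_layer.weight",
    "distilled_guidance_layer.layers.0.out_layer.bias",
    "distilled_guidance_layer.in_proj.weight",
]

for k in sample_keys:
    new_key = k
    if k.startswith("distilled_guidance_layer.norms"):
        # RMSNorm-style ".scale" parameters become ".weight"
        new_key = k.replace(".scale", ".weight")
    elif k.startswith("distilled_guidance_layer.layer"):
        # MLP layers are renamed to diffusers' linear_1 / linear_2 convention
        new_key = k.replace("in_layer", "linear_1").replace("out_layer", "linear_2")
    print(f"{k} -> {new_key}")

# distilled_guidance_layer.norms.0.scale            -> distilled_guidance_layer.norms.0.weight
# distilled_guidance_layer.layers.0.in_layer.weight -> distilled_guidance_layer.layers.0.linear_1.weight
# distilled_guidance_layer.layers.0.out_layer.bias  -> distilled_guidance_layer.layers.0.linear_2.bias
# distilled_guidance_layer.in_proj.weight           -> unchanged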
@@ -2153,40 +2162,49 @@ def swap_scale_shift(weight):
         new_weight = torch.cat([scale, shift], dim=0)
         return new_weight
 
-    ## time_text_embed.timestep_embedder <- time_in
-    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
-        "time_in.in_layer.weight"
-    )
-    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias")
-    converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
-        "time_in.out_layer.weight"
-    )
-    converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias")
-
-    ## time_text_embed.text_embedder <- vector_in
-    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight")
-    converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
-    converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
-        "vector_in.out_layer.weight"
-    )
-    converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias")
-
-    # guidance
-    has_guidance = any("guidance" in k for k in checkpoint)
-    if has_guidance:
-        converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
-            "guidance_in.in_layer.weight"
+    if variant == "flux":
+        ## time_text_embed.timestep_embedder <- time_in
+        converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
+            "time_in.in_layer.weight"
         )
-        converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
-            "guidance_in.in_layer.bias"
+        converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop(
+            "time_in.in_layer.bias"
         )
-        converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
-            "guidance_in.out_layer.weight"
+        converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
+            "time_in.out_layer.weight"
         )
-        converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
-            "guidance_in.out_layer.bias"
+        converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop(
+            "time_in.out_layer.bias"
         )
 
+        ## time_text_embed.text_embedder <- vector_in
+        converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop(
+            "vector_in.in_layer.weight"
+        )
+        converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
+        converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
+            "vector_in.out_layer.weight"
+        )
+        converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop(
+            "vector_in.out_layer.bias"
+        )
+
+        # guidance
+        has_guidance = any("guidance" in k for k in checkpoint)
+        if has_guidance:
+            converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
+                "guidance_in.in_layer.weight"
+            )
+            converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
+                "guidance_in.in_layer.bias"
+            )
+            converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
+                "guidance_in.out_layer.weight"
+            )
+            converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
+                "guidance_in.out_layer.bias"
+            )
+
     # context_embedder
     converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
     converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
@@ -2199,20 +2217,21 @@ def swap_scale_shift(weight):
     for i in range(num_layers):
         block_prefix = f"transformer_blocks.{i}."
         # norms.
-        ## norm1
-        converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
-            f"double_blocks.{i}.img_mod.lin.weight"
-        )
-        converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
-            f"double_blocks.{i}.img_mod.lin.bias"
-        )
-        ## norm1_context
-        converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
-            f"double_blocks.{i}.txt_mod.lin.weight"
-        )
-        converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
-            f"double_blocks.{i}.txt_mod.lin.bias"
-        )
+        if variant == "flux":
+            ## norm1
+            converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
+                f"double_blocks.{i}.img_mod.lin.weight"
+            )
+            converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
+                f"double_blocks.{i}.img_mod.lin.bias"
+            )
+            ## norm1_context
+            converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
+                f"double_blocks.{i}.txt_mod.lin.weight"
+            )
+            converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
+                f"double_blocks.{i}.txt_mod.lin.bias"
+            )
         # Q, K, V
         sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
         context_q, context_k, context_v = torch.chunk(
@@ -2285,13 +2304,15 @@ def swap_scale_shift(weight):
     # single transformer blocks
     for i in range(num_single_layers):
         block_prefix = f"single_transformer_blocks.{i}."
-        # norm.linear <- single_blocks.0.modulation.lin
-        converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
-            f"single_blocks.{i}.modulation.lin.weight"
-        )
-        converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
-            f"single_blocks.{i}.modulation.lin.bias"
-        )
+
+        if variant == "flux":
+            # norm.linear <- single_blocks.0.modulation.lin
+            converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
+                f"single_blocks.{i}.modulation.lin.weight"
+            )
+            converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
+                f"single_blocks.{i}.modulation.lin.bias"
+            )
         # Q, K, V, mlp
         mlp_hidden_dim = int(inner_dim * mlp_ratio)
         split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
@@ -2320,12 +2341,14 @@ def swap_scale_shift(weight):
 
     converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
     converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
-    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
-        checkpoint.pop("final_layer.adaLN_modulation.1.weight")
-    )
-    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
-        checkpoint.pop("final_layer.adaLN_modulation.1.bias")
-    )
+
+    if variant == "flux":
+        converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+            checkpoint.pop("final_layer.adaLN_modulation.1.weight")
+        )
+        converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+            checkpoint.pop("final_layer.adaLN_modulation.1.bias")
+        )
 
     return converted_state_dict
 
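For reference, a hedged sketch of how the converter might be exercised on a raw single-file checkpoint. The file path is a placeholder, and nothing beyond safetensors loading and the conversion call itself is implied by this diff:

from safetensors.torch import load_file

# Placeholder path; any Flux- or Chroma-style single-file transformer checkpoint would do.
state_dict = load_file("chroma-checkpoint.safetensors")

# The presence of "distilled_guidance_layer.in_proj.weight" selects the "chroma" variant;
# otherwise the checkpoint is converted as plain "flux", including the time/vector/guidance
# embedders and the modulation layers that Chroma checkpoints do not carry.
converted = convert_flux_transformer_checkpoint_to_diffusers(state_dict)
print(len(converted), "converted keys")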