
Commit bb93b0d

Merge pull request #3 from huggingface/pipline-model
[testing] pipeline + model
2 parents 0bdf8cc + dd35f1b commit bb93b0d

File tree

9 files changed (+940, -24 lines changed)

scripts/convert_flux_to_diffusers.py

Lines changed: 254 additions & 0 deletions
@@ -0,0 +1,254 @@
import argparse
from contextlib import nullcontext

import torch
from accelerate import init_empty_weights

from diffusers import FluxTransformer2DModel
from diffusers.utils.import_utils import is_accelerate_available
import safetensors.torch
from huggingface_hub import hf_hub_download

"""
python scripts/convert_flux_to_diffusers.py \
    --original_state_dict_repo_id "diffusers-internal-dev/dummy-model-2" \
    --output_path "flux"
"""

# Call the availability check; the bare function object is always truthy.
CTX = init_empty_weights if is_accelerate_available() else nullcontext

parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", type=str)
parser.add_argument("--dtype", type=str, default="bf16")

args = parser.parse_args()
dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32

def load_original_checkpoint(args):
    if args.original_state_dict_repo_id is not None:
        ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename="flux.safetensors")
    elif args.checkpoint_path is not None:
        ckpt_path = args.checkpoint_path
    else:
        raise ValueError("Please provide either `--original_state_dict_repo_id` or a local `--checkpoint_path`.")

    original_state_dict = safetensors.torch.load_file(ckpt_path)
    return original_state_dict

# The original SD3-style AdaLayerNormContinuous splits the linear projection output into (shift, scale),
# while diffusers splits it into (scale, shift). Swap the halves of the linear projection weights so the
# diffusers implementation can be used directly.
def swap_scale_shift(weight):
    shift, scale = weight.chunk(2, dim=0)
    new_weight = torch.cat([scale, shift], dim=0)
    return new_weight

def convert_flux_transformer_checkpoint_to_diffusers(
    original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0
):
    converted_state_dict = {}

    ## time_text_embed.timestep_embedder <- time_in
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
        "time_in.in_layer.weight"
    )
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
        "time_in.in_layer.bias"
    )
    converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
        "time_in.out_layer.weight"
    )
    converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
        "time_in.out_layer.bias"
    )

    ## time_text_embed.text_embedder <- vector_in
    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop(
        "vector_in.in_layer.weight"
    )
    converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop(
        "vector_in.in_layer.bias"
    )
    converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop(
        "vector_in.out_layer.weight"
    )
    converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop(
        "vector_in.out_layer.bias"
    )

    # context_embedder
    converted_state_dict["context_embedder.weight"] = original_state_dict.pop("txt_in.weight")
    converted_state_dict["context_embedder.bias"] = original_state_dict.pop("txt_in.bias")

    # x_embedder
    converted_state_dict["x_embedder.weight"] = original_state_dict.pop("img_in.weight")
    converted_state_dict["x_embedder.bias"] = original_state_dict.pop("img_in.bias")

    # double transformer blocks
    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."
        # norms.
        ## norm1
        converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mod.lin.weight"
        )
        converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mod.lin.bias"
        )
        ## norm1_context
        converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mod.lin.weight"
        )
        converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mod.lin.bias"
        )
        # Q, K, V
        sample_q, sample_k, sample_v = torch.chunk(
            original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0
        )
        context_q, context_k, context_v = torch.chunk(
            original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
        )
        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
            original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
        )
        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
            original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
        )
        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
        converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
        converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
        converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
        converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
        converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
        converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
        # qk_norm
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.norm.key_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
        )
        # ff img_mlp
        converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mlp.0.weight"
        )
        converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mlp.0.bias"
        )
        converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mlp.2.weight"
        )
        converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mlp.2.bias"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mlp.0.weight"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mlp.0.bias"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mlp.2.weight"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mlp.2.bias"
        )
        # output projections.
        converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.proj.weight"
        )
        converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.proj.bias"
        )
        converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.proj.weight"
        )
        converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.proj.bias"
        )

    # single transformer blocks
    for i in range(num_single_layers):
        block_prefix = f"single_transformer_blocks.{i}."
        # norm.linear <- single_blocks.0.modulation.lin
        converted_state_dict[f"{block_prefix}norm.linear.weight"] = original_state_dict.pop(
            f"single_blocks.{i}.modulation.lin.weight"
        )
        converted_state_dict[f"{block_prefix}norm.linear.bias"] = original_state_dict.pop(
            f"single_blocks.{i}.modulation.lin.bias"
        )
        # Q, K, V, mlp
        mlp_hidden_dim = int(inner_dim * mlp_ratio)
        split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
        q, k, v, mlp = torch.split(original_state_dict.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
        q_bias, k_bias, v_bias, mlp_bias = torch.split(
            original_state_dict.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
        )
        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
        converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
        converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
        # qk norm
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
            f"single_blocks.{i}.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
            f"single_blocks.{i}.norm.key_norm.scale"
        )
        # output projections.
        converted_state_dict[f"{block_prefix}proj_out.weight"] = original_state_dict.pop(
            f"single_blocks.{i}.linear2.weight"
        )
        converted_state_dict[f"{block_prefix}proj_out.bias"] = original_state_dict.pop(
            f"single_blocks.{i}.linear2.bias"
        )

    converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
    converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
        original_state_dict.pop("final_layer.adaLN_modulation.1.weight")
    )
    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
        original_state_dict.pop("final_layer.adaLN_modulation.1.bias")
    )

    return converted_state_dict

def main(args):
    original_ckpt = load_original_checkpoint(args)
    num_layers = 19
    num_single_layers = 38
    inner_dim = 3072
    mlp_ratio = 4.0
    converted_transformer_state_dict = convert_flux_transformer_checkpoint_to_diffusers(
        original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio
    )
    transformer = FluxTransformer2DModel()
    transformer.load_state_dict(converted_transformer_state_dict, strict=True)

    print("Saving Flux Transformer in Diffusers format.")
    transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")


if __name__ == "__main__":
    main(args)
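
The conversion above boils down to three small pieces of tensor surgery: chunking the fused qkv projections of the double blocks, splitting the fused linear1 of the single blocks into q, k, v plus the MLP input, and swapping the (shift, scale) halves for the final AdaLayerNorm. The snippet below is a minimal illustrative sketch of those steps on toy tensors; the shapes are made up for illustration and the snippet is not part of the committed file.

# Illustrative sketch only -- toy shapes, not part of the commit.
import torch

inner_dim = 4
mlp_hidden_dim = int(inner_dim * 4.0)

# 1) Double blocks: a fused qkv weight of shape (3 * inner_dim, inner_dim)
#    splits into equal q, k, v chunks along dim 0.
qkv_weight = torch.randn(3 * inner_dim, inner_dim)
q, k, v = torch.chunk(qkv_weight, 3, dim=0)
assert q.shape == (inner_dim, inner_dim)

# 2) Single blocks: linear1 fuses q, k, v and the MLP input projection;
#    torch.split with explicit sizes recovers the four pieces.
linear1_weight = torch.randn(3 * inner_dim + mlp_hidden_dim, inner_dim)
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
q1, k1, v1, mlp = torch.split(linear1_weight, split_size, dim=0)
assert mlp.shape == (mlp_hidden_dim, inner_dim)

# 3) Final layer: swap_scale_shift reorders a [shift, scale] projection into
#    [scale, shift] so diffusers' AdaLayerNormContinuous reads its halves in
#    the expected order.
adaln_weight = torch.arange(8.0)         # [s0 s1 s2 s3 | c0 c1 c2 c3]
shift, scale = adaln_weight.chunk(2, dim=0)
print(torch.cat([scale, shift], dim=0))  # tensor([4., 5., 6., 7., 0., 1., 2., 3.])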

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -84,6 +84,7 @@
 "ControlNetModel",
 "ControlNetXSAdapter",
 "DiTTransformer2DModel",
+"FluxTransformer2DModel",
 "HunyuanDiT2DControlNetModel",
 "HunyuanDiT2DModel",
 "HunyuanDiT2DMultiControlNetModel",
@@ -521,6 +522,7 @@
 ControlNetModel,
 ControlNetXSAdapter,
 DiTTransformer2DModel,
+FluxTransformer2DModel,
 HunyuanDiT2DControlNetModel,
 HunyuanDiT2DModel,
 HunyuanDiT2DMultiControlNetModel,

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -49,6 +49,7 @@
 _import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
 _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
 _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
+_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
 _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
 _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
 _import_structure["unets.unet_1d"] = ["UNet1DModel"]
@@ -90,6 +91,7 @@
 AuraFlowTransformer2DModel,
 DiTTransformer2DModel,
 DualTransformer2DModel,
+FluxTransformer2DModel,
 HunyuanDiT2DModel,
 LatteTransformer3DModel,
 LuminaNextDiT2DModel,
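
With FluxTransformer2DModel registered in both __init__ files, the class is importable from the top-level diffusers namespace, and a checkpoint converted by the script above can be reloaded through the standard from_pretrained path. A minimal usage sketch, assuming the script was run with --output_path "flux" so the weights live under flux/transformer:

import torch
from diffusers import FluxTransformer2DModel

# "flux/transformer" assumes the conversion script above was run with
# --output_path "flux"; adjust the path to wherever save_pretrained wrote.
transformer = FluxTransformer2DModel.from_pretrained("flux/transformer", torch_dtype=torch.bfloat16)
print(sum(p.numel() for p in transformer.parameters()))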
