Skip to content

Commit 2e9242e

Browse files
authored
feat: add Qwen Image Edit support (#877)
* add ref latent support for qwen image
* optimize clip_preprocess and fix get_first_stage_encoding
* add qwen2vl vit support
* add qwen image edit support
* fix qwen image edit pipeline
* add mmproj file support
* support dynamic number of Qwen image transformer blocks
* set prompt_template_encode_start_idx every time
* to_add_out precision fix
* to_out.0 precision fix
* update docs
1 parent c64994d commit 2e9242e

18 files changed

+1339
-365
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ API and command-line option may change frequently.***
2424
- [Qwen Image](./docs/qwen_image.md)
2525
- Image Edit Models
2626
- [FLUX.1-Kontext-dev](./docs/kontext.md)
27+
- [Qwen Image Edit/Qwen Image Edit 2509](./docs/qwen_image_edit.md)
2728
- Video Models
2829
- [Wan2.1/Wan2.2](./docs/wan.md)
2930
- [PhotoMaker](https://github.com/TencentARC/PhotoMaker) support.
@@ -298,6 +299,7 @@ arguments:
298299
--clip_vision path to the clip-vision encoder
299300
--t5xxl path to the t5xxl text encoder
300301
--qwen2vl path to the qwen2vl text encoder
302+
--qwen2vl_vision path to the qwen2vl vit
301303
--vae [VAE] path to vae
302304
--taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
303305
--control-net [CONTROL_PATH] path to control net model

assets/qwen/qwen_image_edit.png

457 KB
Loading
415 KB
Loading

conditioner.hpp

Lines changed: 159 additions & 86 deletions
Large diffs are not rendered by default.

diffusion_model.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,8 @@ struct QwenImageModel : public DiffusionModel {
313313
diffusion_params.x,
314314
diffusion_params.timesteps,
315315
diffusion_params.context,
316+
diffusion_params.ref_latents,
317+
true, // increase_ref_index
316318
output,
317319
output_ctx);
318320
}

docs/qwen_image_edit.md

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# How to Use
2+
3+
## Download weights
4+
5+
- Download Qwen Image Edit
6+
- Qwen Image Edit
7+
- safetensors: https://huggingface.co/Comfy-Org/Qwen-Image-Edit_ComfyUI/tree/main/split_files/diffusion_models
8+
- gguf: https://huggingface.co/QuantStack/Qwen-Image-Edit-GGUF/tree/main
9+
- Qwen Image Edit 2509
10+
- safetensors: https://huggingface.co/Comfy-Org/Qwen-Image-Edit_ComfyUI/tree/main/split_files/diffusion_models
11+
- gguf: https://huggingface.co/QuantStack/Qwen-Image-Edit-2509-GGUF/tree/main
12+
- Download vae
13+
- safetensors: https://huggingface.co/Comfy-Org/Qwen-Image_ComfyUI/tree/main/split_files/vae
14+
- Download qwen_2.5_vl 7b
15+
- safetensors: https://huggingface.co/Comfy-Org/Qwen-Image_ComfyUI/tree/main/split_files/text_encoders
16+
- gguf: https://huggingface.co/mradermacher/Qwen2.5-VL-7B-Instruct-GGUF/tree/main
17+
18+
## Examples
19+
20+
### Qwen Image Edit
21+
22+
```
23+
.\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Qwen_Image_Edit-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --qwen2vl ..\..\ComfyUI\models\text_encoders\qwen_2.5_vl_7b.safetensors --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --diffusion-fa --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'edit.cpp'" --seed 1118877715456453
24+
```
25+
26+
<img alt="qwen_image_edit" src="../assets/qwen/qwen_image_edit.png" />
27+
28+
29+
### Qwen Image Edit 2509
30+
31+
```
32+
.\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Qwen-Image-Edit-2509-Q4_K_S.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --qwen2vl ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf --qwen2vl_vision ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct.mmproj-Q8_0.gguf --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --diffusion-fa --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'Qwen Image Edit 2509'"
33+
```
34+
35+
<img alt="qwen_image_edit_2509" src="../assets/qwen/qwen_image_edit_2509.png" />

examples/cli/main.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ struct SDParams {
6262
std::string clip_vision_path;
6363
std::string t5xxl_path;
6464
std::string qwen2vl_path;
65+
std::string qwen2vl_vision_path;
6566
std::string diffusion_model_path;
6667
std::string high_noise_diffusion_model_path;
6768
std::string vae_path;
@@ -148,6 +149,7 @@ void print_params(SDParams params) {
148149
printf(" clip_vision_path: %s\n", params.clip_vision_path.c_str());
149150
printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str());
150151
printf(" qwen2vl_path: %s\n", params.qwen2vl_path.c_str());
152+
printf(" qwen2vl_vision_path: %s\n", params.qwen2vl_vision_path.c_str());
151153
printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str());
152154
printf(" high_noise_diffusion_model_path: %s\n", params.high_noise_diffusion_model_path.c_str());
153155
printf(" vae_path: %s\n", params.vae_path.c_str());
@@ -220,6 +222,7 @@ void print_usage(int argc, const char* argv[]) {
220222
printf(" --clip_vision path to the clip-vision encoder\n");
221223
printf(" --t5xxl path to the t5xxl text encoder\n");
222224
printf(" --qwen2vl path to the qwen2vl text encoder\n");
225+
printf(" --qwen2vl_vision path to the qwen2vl vit\n");
223226
printf(" --vae [VAE] path to vae\n");
224227
printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
225228
printf(" --control-net [CONTROL_PATH] path to control net model\n");
@@ -490,6 +493,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
490493
{"", "--clip_vision", "", &params.clip_vision_path},
491494
{"", "--t5xxl", "", &params.t5xxl_path},
492495
{"", "--qwen2vl", "", &params.qwen2vl_path},
496+
{"", "--qwen2vl_vision", "", &params.qwen2vl_vision_path},
493497
{"", "--diffusion-model", "", &params.diffusion_model_path},
494498
{"", "--high-noise-diffusion-model", "", &params.high_noise_diffusion_model_path},
495499
{"", "--vae", "", &params.vae_path},
@@ -952,7 +956,7 @@ std::string get_image_params(SDParams params, int64_t seed) {
952956
parameter_string += " " + std::string(sd_schedule_name(params.sample_params.scheduler));
953957
}
954958
parameter_string += ", ";
955-
for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path}) {
959+
for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path, params.qwen2vl_vision_path}) {
956960
if (!te.empty()) {
957961
parameter_string += "TE: " + sd_basename(te) + ", ";
958962
}
@@ -1336,6 +1340,7 @@ int main(int argc, const char* argv[]) {
13361340
params.clip_vision_path.c_str(),
13371341
params.t5xxl_path.c_str(),
13381342
params.qwen2vl_path.c_str(),
1343+
params.qwen2vl_vision_path.c_str(),
13391344
params.diffusion_model_path.c_str(),
13401345
params.high_noise_diffusion_model_path.c_str(),
13411346
params.vae_path.c_str(),

flux.hpp

Lines changed: 6 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -81,57 +81,6 @@ namespace Flux {
8181
}
8282
};
8383

84-
__STATIC_INLINE__ struct ggml_tensor* apply_rope(struct ggml_context* ctx,
85-
struct ggml_tensor* x,
86-
struct ggml_tensor* pe) {
87-
// x: [N, L, n_head, d_head]
88-
// pe: [L, d_head/2, 2, 2]
89-
int64_t d_head = x->ne[0];
90-
int64_t n_head = x->ne[1];
91-
int64_t L = x->ne[2];
92-
int64_t N = x->ne[3];
93-
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, n_head, L, d_head]
94-
x = ggml_reshape_4d(ctx, x, 2, d_head / 2, L, n_head * N); // [N * n_head, L, d_head/2, 2]
95-
x = ggml_cont(ctx, ggml_permute(ctx, x, 3, 0, 1, 2)); // [2, N * n_head, L, d_head/2]
96-
97-
int64_t offset = x->nb[2] * x->ne[2];
98-
auto x_0 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 0); // [N * n_head, L, d_head/2]
99-
auto x_1 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 1); // [N * n_head, L, d_head/2]
100-
x_0 = ggml_reshape_4d(ctx, x_0, 1, x_0->ne[0], x_0->ne[1], x_0->ne[2]); // [N * n_head, L, d_head/2, 1]
101-
x_1 = ggml_reshape_4d(ctx, x_1, 1, x_1->ne[0], x_1->ne[1], x_1->ne[2]); // [N * n_head, L, d_head/2, 1]
102-
auto temp_x = ggml_new_tensor_4d(ctx, x_0->type, 2, x_0->ne[1], x_0->ne[2], x_0->ne[3]);
103-
x_0 = ggml_repeat(ctx, x_0, temp_x); // [N * n_head, L, d_head/2, 2]
104-
x_1 = ggml_repeat(ctx, x_1, temp_x); // [N * n_head, L, d_head/2, 2]
105-
106-
pe = ggml_cont(ctx, ggml_permute(ctx, pe, 3, 0, 1, 2)); // [2, L, d_head/2, 2]
107-
offset = pe->nb[2] * pe->ne[2];
108-
auto pe_0 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 0); // [L, d_head/2, 2]
109-
auto pe_1 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 1); // [L, d_head/2, 2]
110-
111-
auto x_out = ggml_add_inplace(ctx, ggml_mul(ctx, x_0, pe_0), ggml_mul(ctx, x_1, pe_1)); // [N * n_head, L, d_head/2, 2]
112-
x_out = ggml_reshape_3d(ctx, x_out, d_head, L, n_head * N); // [N*n_head, L, d_head]
113-
return x_out;
114-
}
115-
116-
__STATIC_INLINE__ struct ggml_tensor* attention(struct ggml_context* ctx,
117-
ggml_backend_t backend,
118-
struct ggml_tensor* q,
119-
struct ggml_tensor* k,
120-
struct ggml_tensor* v,
121-
struct ggml_tensor* pe,
122-
struct ggml_tensor* mask,
123-
bool flash_attn,
124-
float kv_scale = 1.0f) {
125-
// q,k,v: [N, L, n_head, d_head]
126-
// pe: [L, d_head/2, 2, 2]
127-
// return: [N, L, n_head*d_head]
128-
q = apply_rope(ctx, q, pe); // [N*n_head, L, d_head]
129-
k = apply_rope(ctx, k, pe); // [N*n_head, L, d_head]
130-
131-
auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn, kv_scale); // [N, L, n_head*d_head]
132-
return x;
133-
}
134-
13584
struct SelfAttention : public GGMLBlock {
13685
public:
13786
int64_t num_heads;
@@ -179,9 +128,9 @@ namespace Flux {
179128
// x: [N, n_token, dim]
180129
// pe: [n_token, d_head/2, 2, 2]
181130
// return [N, n_token, dim]
182-
auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head]
183-
x = attention(ctx, backend, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn); // [N, n_token, dim]
184-
x = post_attention(ctx, x); // [N, n_token, dim]
131+
auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head]
132+
x = Rope::attention(ctx, backend, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn); // [N, n_token, dim]
133+
x = post_attention(ctx, x); // [N, n_token, dim]
185134
return x;
186135
}
187136
};
@@ -369,8 +318,8 @@ namespace Flux {
369318
auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
370319
auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
371320

372-
auto attn = attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head]
373-
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
321+
auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head]
322+
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
374323
auto txt_attn_out = ggml_view_3d(ctx,
375324
attn,
376325
attn->ne[0],
@@ -504,7 +453,7 @@ namespace Flux {
504453
auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
505454
q = norm->query_norm(ctx, q);
506455
k = norm->key_norm(ctx, k);
507-
auto attn = attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_token, hidden_size]
456+
auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_token, hidden_size]
508457

509458
auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim]
510459
auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size]

ggml_extend.hpp

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,11 @@ __STATIC_INLINE__ float sd_image_get_f32(sd_image_t image, int iw, int ih, int i
197197
return value;
198198
}
199199

200-
__STATIC_INLINE__ float sd_image_get_f32(sd_image_f32_t image, int iw, int ih, int ic) {
200+
__STATIC_INLINE__ float sd_image_get_f32(sd_image_f32_t image, int iw, int ih, int ic, bool scale = true) {
201201
float value = *(image.data + ih * image.width * image.channel + iw * image.channel + ic);
202+
if (scale) {
203+
value /= 255.f;
204+
}
202205
return value;
203206
}
204207

@@ -458,24 +461,18 @@ __STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
458461
}
459462
}
460463

461-
__STATIC_INLINE__ void sd_image_f32_to_tensor(const float* image_data,
462-
struct ggml_tensor* output,
464+
__STATIC_INLINE__ void sd_image_f32_to_tensor(sd_image_f32_t image,
465+
ggml_tensor* tensor,
463466
bool scale = true) {
464-
int64_t width = output->ne[0];
465-
int64_t height = output->ne[1];
466-
int64_t channels = output->ne[2];
467-
GGML_ASSERT(channels == 3 && output->type == GGML_TYPE_F32);
468-
for (int iy = 0; iy < height; iy++) {
469-
for (int ix = 0; ix < width; ix++) {
470-
for (int k = 0; k < channels; k++) {
471-
int value = *(image_data + iy * width * channels + ix * channels + k);
472-
if (scale) {
473-
value /= 255.f;
474-
}
475-
ggml_tensor_set_f32(output, value, ix, iy, k);
476-
}
477-
}
478-
}
467+
GGML_ASSERT(image.width == tensor->ne[0]);
468+
GGML_ASSERT(image.height == tensor->ne[1]);
469+
GGML_ASSERT(image.channel == tensor->ne[2]);
470+
GGML_ASSERT(1 == tensor->ne[3]);
471+
GGML_ASSERT(tensor->type == GGML_TYPE_F32);
472+
ggml_tensor_iter(tensor, [&](ggml_tensor* tensor, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
473+
float value = sd_image_get_f32(image, i0, i1, i2, scale);
474+
ggml_tensor_set_f32(tensor, value, i0, i1, i2, i3);
475+
});
479476
}
480477

481478
__STATIC_INLINE__ void ggml_split_tensor_2d(struct ggml_tensor* input,

model.cpp

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,6 @@ const char* unused_tensors[] = {
113113
"text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training
114114
"text_encoders.qwen2vl.output.weight",
115115
"text_encoders.qwen2vl.lm_head.",
116-
"text_encoders.qwen2vl.visual.",
117116
};
118117

119118
bool is_unused_tensor(std::string name) {
@@ -212,6 +211,24 @@ std::unordered_map<std::string, std::string> qwenvl_name_map{
212211
{"output_norm.", "model.norm."},
213212
};
214213

214+
std::unordered_map<std::string, std::string> qwenvl_vision_name_map{
215+
{"mm.", "merger.mlp."},
216+
{"v.post_ln.", "merger.ln_q."},
217+
{"v.patch_embd.weight", "patch_embed.proj.0.weight"},
218+
{"patch_embed.proj.0.weight.1", "patch_embed.proj.1.weight"},
219+
{"v.patch_embd.weight.1", "patch_embed.proj.1.weight"},
220+
{"v.blk.", "blocks."},
221+
{"attn_q.", "attn.q_proj."},
222+
{"attn_k.", "attn.k_proj."},
223+
{"attn_v.", "attn.v_proj."},
224+
{"attn_out.", "attn.proj."},
225+
{"ffn_down.", "mlp.down_proj."},
226+
{"ffn_gate.", "mlp.gate_proj."},
227+
{"ffn_up.", "mlp.up_proj."},
228+
{"ln1.", "norm1."},
229+
{"ln2.", "norm2."},
230+
};
231+
215232
std::string convert_cond_model_name(const std::string& name) {
216233
std::string new_name = name;
217234
std::string prefix;
@@ -270,10 +287,19 @@ std::string convert_cond_model_name(const std::string& name) {
270287
new_name.replace(pos, 11, "layer.0.SelfAttention.relative_attention_bias.");
271288
}
272289
} else if (contains(name, "qwen2vl")) {
273-
for (auto kv : qwenvl_name_map) {
274-
size_t pos = new_name.find(kv.first);
275-
if (pos != std::string::npos) {
276-
new_name.replace(pos, kv.first.size(), kv.second);
290+
if (contains(name, "qwen2vl.visual")) {
291+
for (auto kv : qwenvl_vision_name_map) {
292+
size_t pos = new_name.find(kv.first);
293+
if (pos != std::string::npos) {
294+
new_name.replace(pos, kv.first.size(), kv.second);
295+
}
296+
}
297+
} else {
298+
for (auto kv : qwenvl_name_map) {
299+
size_t pos = new_name.find(kv.first);
300+
if (pos != std::string::npos) {
301+
new_name.replace(pos, kv.first.size(), kv.second);
302+
}
277303
}
278304
}
279305
} else if (name == "text_encoders.t5xxl.transformer.token_embd.weight") {

0 commit comments

Comments
 (0)