Skip to content

Commit 91dca1a

Browse files
committed
fix VAE tiling for Qwen Image
1 parent 6ea2a75 commit 91dca1a

File tree

1 file changed

+30
-24
lines changed

1 file changed

+30
-24
lines changed

ggml_extend.hpp

Lines changed: 30 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -479,12 +479,15 @@ __STATIC_INLINE__ void ggml_split_tensor_2d(struct ggml_tensor* input,
479479
int64_t width = output->ne[0];
480480
int64_t height = output->ne[1];
481481
int64_t channels = output->ne[2];
482+
int64_t ne3 = output->ne[3];
482483
GGML_ASSERT(input->type == GGML_TYPE_F32 && output->type == GGML_TYPE_F32);
483484
for (int iy = 0; iy < height; iy++) {
484485
for (int ix = 0; ix < width; ix++) {
485486
for (int k = 0; k < channels; k++) {
486-
float value = ggml_tensor_get_f32(input, ix + x, iy + y, k);
487-
ggml_tensor_set_f32(output, value, ix, iy, k);
487+
for (int l = 0; l < ne3; l++) {
488+
float value = ggml_tensor_get_f32(input, ix + x, iy + y, k, l);
489+
ggml_tensor_set_f32(output, value, ix, iy, k, l);
490+
}
488491
}
489492
}
490493
}
@@ -507,6 +510,7 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input,
507510
int64_t width = input->ne[0];
508511
int64_t height = input->ne[1];
509512
int64_t channels = input->ne[2];
513+
int64_t ne3 = input->ne[3];
510514

511515
int64_t img_width = output->ne[0];
512516
int64_t img_height = output->ne[1];
@@ -515,24 +519,26 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input,
515519
for (int iy = y_skip; iy < height; iy++) {
516520
for (int ix = x_skip; ix < width; ix++) {
517521
for (int k = 0; k < channels; k++) {
518-
float new_value = ggml_tensor_get_f32(input, ix, iy, k);
519-
if (overlap_x > 0 || overlap_y > 0) { // blend colors in overlapped area
520-
float old_value = ggml_tensor_get_f32(output, x + ix, y + iy, k);
521-
522-
const float x_f_0 = (overlap_x > 0 && x > 0) ? (ix - x_skip) / float(overlap_x) : 1;
523-
const float x_f_1 = (overlap_x > 0 && x < (img_width - width)) ? (width - ix) / float(overlap_x) : 1;
524-
const float y_f_0 = (overlap_y > 0 && y > 0) ? (iy - y_skip) / float(overlap_y) : 1;
525-
const float y_f_1 = (overlap_y > 0 && y < (img_height - height)) ? (height - iy) / float(overlap_y) : 1;
526-
527-
const float x_f = std::min(std::min(x_f_0, x_f_1), 1.f);
528-
const float y_f = std::min(std::min(y_f_0, y_f_1), 1.f);
529-
530-
ggml_tensor_set_f32(
531-
output,
532-
old_value + new_value * ggml_smootherstep_f32(y_f) * ggml_smootherstep_f32(x_f),
533-
x + ix, y + iy, k);
534-
} else {
535-
ggml_tensor_set_f32(output, new_value, x + ix, y + iy, k);
522+
for (int l = 0; l < ne3; l++) {
523+
float new_value = ggml_tensor_get_f32(input, ix, iy, k, l);
524+
if (overlap_x > 0 || overlap_y > 0) { // blend colors in overlapped area
525+
float old_value = ggml_tensor_get_f32(output, x + ix, y + iy, k, l);
526+
527+
const float x_f_0 = (overlap_x > 0 && x > 0) ? (ix - x_skip) / float(overlap_x) : 1;
528+
const float x_f_1 = (overlap_x > 0 && x < (img_width - width)) ? (width - ix) / float(overlap_x) : 1;
529+
const float y_f_0 = (overlap_y > 0 && y > 0) ? (iy - y_skip) / float(overlap_y) : 1;
530+
const float y_f_1 = (overlap_y > 0 && y < (img_height - height)) ? (height - iy) / float(overlap_y) : 1;
531+
532+
const float x_f = std::min(std::min(x_f_0, x_f_1), 1.f);
533+
const float y_f = std::min(std::min(y_f_0, y_f_1), 1.f);
534+
535+
ggml_tensor_set_f32(
536+
output,
537+
old_value + new_value * ggml_smootherstep_f32(y_f) * ggml_smootherstep_f32(x_f),
538+
x + ix, y + iy, k, l);
539+
} else {
540+
ggml_tensor_set_f32(output, new_value, x + ix, y + iy, k, l);
541+
}
536542
}
537543
}
538544
}
@@ -848,8 +854,8 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
848854
}
849855

850856
struct ggml_init_params params = {};
851-
params.mem_size += input_tile_size_x * input_tile_size_y * input->ne[2] * sizeof(float); // input chunk
852-
params.mem_size += output_tile_size_x * output_tile_size_y * output->ne[2] * sizeof(float); // output chunk
857+
params.mem_size += input_tile_size_x * input_tile_size_y * input->ne[2] * input->ne[3] * sizeof(float); // input chunk
858+
params.mem_size += output_tile_size_x * output_tile_size_y * output->ne[2] * output->ne[3] * sizeof(float); // output chunk
853859
params.mem_size += 3 * ggml_tensor_overhead();
854860
params.mem_buffer = NULL;
855861
params.no_alloc = false;
@@ -864,8 +870,8 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
864870
}
865871

866872
// tiling
867-
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size_x, input_tile_size_y, input->ne[2], 1);
868-
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size_x, output_tile_size_y, output->ne[2], 1);
873+
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size_x, input_tile_size_y, input->ne[2], input->ne[3]);
874+
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size_x, output_tile_size_y, output->ne[2], output->ne[3]);
869875
int num_tiles = num_tiles_x * num_tiles_y;
870876
LOG_INFO("processing %i tiles", num_tiles);
871877
pretty_progress(0, num_tiles, 0.0f);

0 commit comments

Comments
 (0)