Skip to content

Commit 6c1eaa7

Browse files
committed
Fix edge case when tile is bigger than latent
1 parent 33bbf7f commit 6c1eaa7

File tree

1 file changed: +20 −13 lines changed

ggml_extend.hpp

Lines changed: 20 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -812,18 +812,25 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
812812
int tile_overlap_y = (int32_t)(tile_size * tile_overlap_factor_y);
813813
int non_tile_overlap_y = tile_size - tile_overlap_y;
814814

815-
int input_tile_size = tile_size;
816-
int output_tile_size = tile_size;
815+
int tile_size_x = tile_size < small_width ? tile_size : small_width;
816+
int tile_size_y = tile_size < small_height ? tile_size : small_height;
817+
818+
int input_tile_size_x = tile_size_x;
819+
int input_tile_size_y = tile_size_y;
820+
int output_tile_size_x = tile_size_x;
821+
int output_tile_size_y = tile_size_y;
817822

818823
if (big_out) {
819-
output_tile_size *= scale;
824+
output_tile_size_x *= scale;
825+
output_tile_size_y *= scale;
820826
} else {
821-
input_tile_size *= scale;
827+
input_tile_size_x *= scale;
828+
input_tile_size_y *= scale;
822829
}
823830

824831
struct ggml_init_params params = {};
825-
params.mem_size += input_tile_size * input_tile_size * input->ne[2] * sizeof(float); // input chunk
826-
params.mem_size += output_tile_size * output_tile_size * output->ne[2] * sizeof(float); // output chunk
832+
params.mem_size += input_tile_size_x * input_tile_size_y * input->ne[2] * sizeof(float); // input chunk
833+
params.mem_size += output_tile_size_x * output_tile_size_y * output->ne[2] * sizeof(float); // output chunk
827834
params.mem_size += 3 * ggml_tensor_overhead();
828835
params.mem_buffer = NULL;
829836
params.no_alloc = false;
@@ -838,19 +845,19 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
838845
}
839846

840847
// tiling
841-
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size, input_tile_size, input->ne[2], 1);
842-
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size, output_tile_size, output->ne[2], 1);
843-
int num_tiles = num_tiles_x * num_tiles_y;
848+
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size_x, input_tile_size_y, input->ne[2], 1);
849+
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size_x, output_tile_size_y, output->ne[2], 1);
850+
int num_tiles = num_tiles_x * num_tiles_y;
844851
LOG_INFO("processing %i tiles", num_tiles);
845852
pretty_progress(0, num_tiles, 0.0f);
846853
int tile_count = 1;
847854
bool last_y = false, last_x = false;
848855
float last_time = 0.0f;
849856
for (int y = 0; y < small_height && !last_y; y += non_tile_overlap_y) {
850857
int dy = 0;
851-
if (y + tile_size >= small_height) {
858+
if (y + tile_size_y >= small_height) {
852859
int _y = y;
853-
y = small_height - tile_size;
860+
y = small_height - tile_size_y;
854861
dy = _y - y;
855862
if (big_out) {
856863
dy *= scale;
@@ -859,9 +866,9 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
859866
}
860867
for (int x = 0; x < small_width && !last_x; x += non_tile_overlap_x) {
861868
int dx = 0;
862-
if (x + tile_size >= small_width) {
869+
if (x + tile_size_x >= small_width) {
863870
int _x = x;
864-
x = small_width - tile_size;
871+
x = small_width - tile_size_x;
865872
dx = _x - x;
866873
if (big_out) {
867874
dx *= scale;

0 commit comments

Comments
 (0)