Skip to content

Commit 31bce95

Browse files
wbrunastduhpf
authored andcommitted
non-square VAE tiling (#3)
* refactor tile number calculation * support non-square tiles * add env var to change tile overlap * add safeguards and better error messages for SD_TILE_OVERLAP * add safeguards and include overlapping factor for SD_TILE_SIZE * avoid rounding issues when specifying SD_TILE_SIZE as a factor * lower SD_TILE_OVERLAP limit * zero-init empty output buffer
1 parent 532aacb commit 31bce95

File tree

2 files changed

+143
-77
lines changed

2 files changed

+143
-77
lines changed

ggml_extend.hpp

Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -748,8 +748,38 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_chunk(struct ggml_contex
748748

749749
typedef std::function<void(ggml_tensor*, ggml_tensor*, bool)> on_tile_process;
750750

751+
__STATIC_INLINE__ void
752+
sd_tiling_calc_tiles(int &num_tiles_dim, float& tile_overlap_factor_dim, int small_dim, int tile_size, const float tile_overlap_factor) {
753+
754+
int tile_overlap = (tile_size * tile_overlap_factor);
755+
int non_tile_overlap = tile_size - tile_overlap;
756+
757+
num_tiles_dim = (small_dim - tile_overlap) / non_tile_overlap;
758+
int overshoot_dim = ((num_tiles_dim + 1) * non_tile_overlap + tile_overlap) % small_dim;
759+
760+
if ((overshoot_dim != non_tile_overlap) && (overshoot_dim <= num_tiles_dim * (tile_size / 2 - tile_overlap))) {
761+
// if tiles don't fit perfectly using the desired overlap
762+
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
763+
num_tiles_dim++;
764+
}
765+
766+
tile_overlap_factor_dim = (float)(tile_size * num_tiles_dim - small_dim) / (float)(tile_size * (num_tiles_dim - 1));
767+
if (num_tiles_dim <= 2) {
768+
if (small_dim <= tile_size) {
769+
num_tiles_dim = 1;
770+
tile_overlap_factor_dim = 0;
771+
} else {
772+
num_tiles_dim = 2;
773+
tile_overlap_factor_dim = (2 * tile_size - small_dim) / (float)tile_size;
774+
}
775+
}
776+
}
777+
751778
// Tiling
752-
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
779+
__STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input, ggml_tensor* output, const int scale,
780+
const int p_tile_size_x, const int p_tile_size_y,
781+
const float tile_overlap_factor, on_tile_process on_processing) {
782+
753783
output = ggml_set_f32(output, 0);
754784

755785
int input_width = (int)input->ne[0];
@@ -770,62 +800,27 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
770800
small_height = input_height;
771801
}
772802

773-
int tile_overlap = (tile_size * tile_overlap_factor);
774-
int non_tile_overlap = tile_size - tile_overlap;
775-
776-
int num_tiles_x = (small_width - tile_overlap) / non_tile_overlap;
777-
int overshoot_x = ((num_tiles_x + 1) * non_tile_overlap + tile_overlap) % small_width;
778-
779-
if ((overshoot_x != non_tile_overlap) && (overshoot_x <= num_tiles_x * (tile_size / 2 - tile_overlap))) {
780-
// if tiles don't fit perfectly using the desired overlap
781-
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
782-
num_tiles_x++;
783-
}
784-
785-
float tile_overlap_factor_x = (float)(tile_size * num_tiles_x - small_width) / (float)(tile_size * (num_tiles_x - 1));
786-
if (num_tiles_x <= 2) {
787-
if (small_width <= tile_size) {
788-
num_tiles_x = 1;
789-
tile_overlap_factor_x = 0;
790-
} else {
791-
num_tiles_x = 2;
792-
tile_overlap_factor_x = (2 * tile_size - small_width) / (float)tile_size;
793-
}
794-
}
795-
796-
int num_tiles_y = (small_height - tile_overlap) / non_tile_overlap;
797-
int overshoot_y = ((num_tiles_y + 1) * non_tile_overlap + tile_overlap) % small_height;
803+
int num_tiles_x;
804+
float tile_overlap_factor_x;
805+
sd_tiling_calc_tiles(num_tiles_x, tile_overlap_factor_x, small_width, p_tile_size_x, tile_overlap_factor);
798806

799-
if ((overshoot_y != non_tile_overlap) && (overshoot_y <= num_tiles_y * (tile_size / 2 - tile_overlap))) {
800-
// if tiles don't fit perfectly using the desired overlap
801-
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
802-
num_tiles_y++;
803-
}
804-
805-
float tile_overlap_factor_y = (float)(tile_size * num_tiles_y - small_height) / (float)(tile_size * (num_tiles_y - 1));
806-
if (num_tiles_y <= 2) {
807-
if (small_height <= tile_size) {
808-
num_tiles_y = 1;
809-
tile_overlap_factor_y = 0;
810-
} else {
811-
num_tiles_y = 2;
812-
tile_overlap_factor_y = (2 * tile_size - small_height) / (float)tile_size;
813-
}
814-
}
807+
int num_tiles_y;
808+
float tile_overlap_factor_y;
809+
sd_tiling_calc_tiles(num_tiles_y, tile_overlap_factor_y, small_height, p_tile_size_y, tile_overlap_factor);
815810

816811
LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y);
817812
LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor);
818813

819814
GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0); // should be multiple of 2
820815

821-
int tile_overlap_x = (int32_t)(tile_size * tile_overlap_factor_x);
822-
int non_tile_overlap_x = tile_size - tile_overlap_x;
816+
int tile_overlap_x = (int32_t)(p_tile_size_x * tile_overlap_factor_x);
817+
int non_tile_overlap_x = p_tile_size_x - tile_overlap_x;
823818

824-
int tile_overlap_y = (int32_t)(tile_size * tile_overlap_factor_y);
825-
int non_tile_overlap_y = tile_size - tile_overlap_y;
819+
int tile_overlap_y = (int32_t)(p_tile_size_y * tile_overlap_factor_y);
820+
int non_tile_overlap_y = p_tile_size_y - tile_overlap_y;
826821

827-
int tile_size_x = tile_size < small_width ? tile_size : small_width;
828-
int tile_size_y = tile_size < small_height ? tile_size : small_height;
822+
int tile_size_x = p_tile_size_x < small_width ? p_tile_size_x : small_width;
823+
int tile_size_y = p_tile_size_y < small_height ? p_tile_size_y : small_height;
829824

830825
int input_tile_size_x = tile_size_x;
831826
int input_tile_size_y = tile_size_y;
@@ -914,6 +909,11 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
914909
ggml_free(tiles_ctx);
915910
}
916911

912+
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale,
913+
const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
914+
sd_tiling_non_square(input, output, scale, tile_size, tile_size, tile_overlap_factor, on_processing);
915+
}
916+
917917
__STATIC_INLINE__ struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ctx,
918918
struct ggml_tensor* a) {
919919
const float eps = 1e-6f; // default eps parameter

stable-diffusion.cpp

Lines changed: 95 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1301,41 +1301,113 @@ class StableDiffusionGGML {
13011301
return latent;
13021302
}
13031303

1304-
ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
1305-
int64_t t0 = ggml_time_ms();
1306-
ggml_tensor* result = NULL;
1307-
int tile_size = 32;
1308-
// TODO: arg instead of env?
1304+
void get_vae_tile_overlap(float& tile_overlap) {
1305+
const char* SD_TILE_OVERLAP = getenv("SD_TILE_OVERLAP");
1306+
if (SD_TILE_OVERLAP != nullptr) {
1307+
std::string sd_tile_overlap_str = SD_TILE_OVERLAP;
1308+
try {
1309+
tile_overlap = std::stof(sd_tile_overlap_str);
1310+
if (tile_overlap < 0.0) {
1311+
LOG_WARN("SD_TILE_OVERLAP too low, setting it to 0.0");
1312+
tile_overlap = 0.0;
1313+
} else if (tile_overlap > 0.5) {
1314+
LOG_WARN("SD_TILE_OVERLAP too high, setting it to 0.5");
1315+
tile_overlap = 0.5;
1316+
}
1317+
} catch (const std::invalid_argument&) {
1318+
LOG_WARN("SD_TILE_OVERLAP is invalid, keeping the default");
1319+
} catch (const std::out_of_range&) {
1320+
LOG_WARN("SD_TILE_OVERLAP is out of range, keeping the default");
1321+
}
1322+
}
1323+
if (SD_TILE_OVERLAP != nullptr) {
1324+
LOG_INFO("VAE Tile overlap: %.2f", tile_overlap);
1325+
}
1326+
}
1327+
1328+
void get_vae_tile_sizes(int& tile_size_x, int& tile_size_y, float tile_overlap, int latent_x, int latent_y) {
13091329
const char* SD_TILE_SIZE = getenv("SD_TILE_SIZE");
13101330
if (SD_TILE_SIZE != nullptr) {
1331+
// format is AxB, or just A (equivalent to AxA)
1332+
// A and B can be integers (tile size) or floating point
1333+
// floating point <= 1 means simple fraction of the latent dimension
1334+
// floating point > 1 means number of tiles across that dimension
1335+
// a single number gets applied to both
1336+
auto get_tile_factor = [tile_overlap](const std::string& factor_str) {
1337+
float factor = std::stof(factor_str);
1338+
if (factor > 1.0)
1339+
factor = 1 / (factor - factor * tile_overlap + tile_overlap);
1340+
return factor;
1341+
};
1342+
const int min_tile_dimension = 4;
13111343
std::string sd_tile_size_str = SD_TILE_SIZE;
1344+
size_t x_pos = sd_tile_size_str.find('x');
13121345
try {
1313-
tile_size = std::stoi(sd_tile_size_str);
1346+
int tmp_x = tile_size_x, tmp_y = tile_size_y;
1347+
if (x_pos != std::string::npos) {
1348+
std::string tile_x_str = sd_tile_size_str.substr(0, x_pos);
1349+
std::string tile_y_str = sd_tile_size_str.substr(x_pos + 1);
1350+
if (tile_x_str.find('.') != std::string::npos) {
1351+
tmp_x = std::round(latent_x * get_tile_factor(tile_x_str));
1352+
} else {
1353+
tmp_x = std::stoi(tile_x_str);
1354+
}
1355+
if (tile_y_str.find('.') != std::string::npos) {
1356+
tmp_y = std::round(latent_y * get_tile_factor(tile_y_str));
1357+
} else {
1358+
tmp_y = std::stoi(tile_y_str);
1359+
}
1360+
} else {
1361+
if (sd_tile_size_str.find('.') != std::string::npos) {
1362+
float tile_factor = get_tile_factor(sd_tile_size_str);
1363+
tmp_x = std::round(latent_x * tile_factor);
1364+
tmp_y = std::round(latent_y * tile_factor);
1365+
} else {
1366+
tmp_x = tmp_y = std::stoi(sd_tile_size_str);
1367+
}
1368+
}
1369+
tile_size_x = std::max(std::min(tmp_x, latent_x), min_tile_dimension);
1370+
tile_size_y = std::max(std::min(tmp_y, latent_y), min_tile_dimension);
13141371
} catch (const std::invalid_argument&) {
1315-
LOG_WARN("Invalid");
1372+
LOG_WARN("SD_TILE_SIZE is invalid, keeping the default");
13161373
} catch (const std::out_of_range&) {
1317-
LOG_WARN("OOR");
1374+
LOG_WARN("SD_TILE_SIZE is out of range, keeping the default");
13181375
}
13191376
}
1320-
if(!decode){
1321-
// TODO: also use and arg for this one?
1322-
// to keep the compute buffer size consistent
1323-
tile_size*=1.30539;
1377+
if (SD_TILE_SIZE != nullptr) {
1378+
LOG_INFO("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
13241379
}
1380+
}
1381+
1382+
ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
1383+
int64_t t0 = ggml_time_ms();
1384+
ggml_tensor* result = NULL;
1385+
// TODO: args instead of env for tile size / overlap?
13251386
if (!use_tiny_autoencoder) {
1387+
float tile_overlap = 0.5f;
1388+
int tile_size_x = 32;
1389+
int tile_size_y = 32;
1390+
1391+
get_vae_tile_overlap(tile_overlap);
1392+
get_vae_tile_sizes(tile_size_x, tile_size_y, tile_overlap, x->ne[0] / 8, x->ne[1] / 8);
1393+
1394+
// TODO: also use an arg for this one?
1395+
// multiply tile size for encode to keep the compute buffer size consistent
1396+
tile_size_x *= 1.30539;
1397+
tile_size_y *= 1.30539;
1398+
13261399
process_vae_input_tensor(x);
13271400
if (vae_tiling && !decode_video) {
1328-
// split latent in 32x32 tiles and compute in several steps
13291401
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
13301402
first_stage_model->compute(n_threads, in, true, &out, NULL);
13311403
};
1332-
sd_tiling(x, result, 8, tile_size, 0.5f, on_tiling);
1404+
sd_tiling_non_square(x, result, 8, tile_size_x, tile_size_y, tile_overlap, on_tiling);
13331405
} else {
13341406
first_stage_model->compute(n_threads, x, false, &result, work_ctx);
13351407
}
13361408
first_stage_model->free_compute_buffer();
13371409
} else {
1338-
if (vae_tiling && !decode_video) {
1410+
if (vae_tiling && !decode_video) {
13391411
// split latent in 32x32 tiles and compute in several steps
13401412
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
13411413
tae_first_stage->compute(n_threads, in, true, &out, NULL);
@@ -1460,29 +1532,23 @@ class StableDiffusionGGML {
14601532
C,
14611533
x->ne[3]);
14621534
}
1463-
int tile_size = 32;
1464-
// TODO: arg instead of env?
1465-
const char* SD_TILE_SIZE = getenv("SD_TILE_SIZE");
1466-
if (SD_TILE_SIZE != nullptr) {
1467-
std::string sd_tile_size_str = SD_TILE_SIZE;
1468-
try {
1469-
tile_size = std::stoi(sd_tile_size_str);
1470-
} catch (const std::invalid_argument&) {
1471-
LOG_WARN("Invalid");
1472-
} catch (const std::out_of_range&) {
1473-
LOG_WARN("OOR");
1474-
}
1475-
}
14761535
int64_t t0 = ggml_time_ms();
14771536
if (!use_tiny_autoencoder) {
1537+
float tile_overlap = 0.5f;
1538+
int tile_size_x = 32;
1539+
int tile_size_y = 32;
1540+
1541+
get_vae_tile_overlap(tile_overlap);
1542+
get_vae_tile_sizes(tile_size_x, tile_size_y, tile_overlap, x->ne[0] / 8, x->ne[1] / 8);
1543+
14781544
process_latent_out(x);
14791545
// x = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
14801546
if (vae_tiling && !decode_video) {
14811547
// split latent in 32x32 tiles and compute in several steps
14821548
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
14831549
first_stage_model->compute(n_threads, in, true, &out, NULL);
14841550
};
1485-
sd_tiling(x, result, 8, tile_size, 0.5f, on_tiling);
1551+
sd_tiling_non_square(x, result, 8, tile_size_x, tile_size_y, tile_overlap, on_tiling);
14861552
} else {
14871553
first_stage_model->compute(n_threads, x, true, &result, work_ctx);
14881554
}

0 commit comments

Comments
 (0)