Skip to content

Commit e78d630

Browse files
wbrunastduhpf
authored andcommitted
non-square VAE tiling (#3)
* refactor tile number calculation * support non-square tiles * add env var to change tile overlap * add safeguards and better error messages for SD_TILE_OVERLAP * add safeguards and include overlapping factor for SD_TILE_SIZE * avoid rounding issues when specifying SD_TILE_SIZE as a factor * lower SD_TILE_OVERLAP limit * zero-init empty output buffer
1 parent 6c1eaa7 commit e78d630

File tree

2 files changed

+143
-77
lines changed

2 files changed

+143
-77
lines changed

ggml_extend.hpp

Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -736,8 +736,38 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_chunk(struct ggml_contex
736736

737737
typedef std::function<void(ggml_tensor*, ggml_tensor*, bool)> on_tile_process;
738738

739+
__STATIC_INLINE__ void
740+
sd_tiling_calc_tiles(int &num_tiles_dim, float& tile_overlap_factor_dim, int small_dim, int tile_size, const float tile_overlap_factor) {
741+
742+
int tile_overlap = (tile_size * tile_overlap_factor);
743+
int non_tile_overlap = tile_size - tile_overlap;
744+
745+
num_tiles_dim = (small_dim - tile_overlap) / non_tile_overlap;
746+
int overshoot_dim = ((num_tiles_dim + 1) * non_tile_overlap + tile_overlap) % small_dim;
747+
748+
if ((overshoot_dim != non_tile_overlap) && (overshoot_dim <= num_tiles_dim * (tile_size / 2 - tile_overlap))) {
749+
// if tiles don't fit perfectly using the desired overlap
750+
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
751+
num_tiles_dim++;
752+
}
753+
754+
tile_overlap_factor_dim = (float)(tile_size * num_tiles_dim - small_dim) / (float)(tile_size * (num_tiles_dim - 1));
755+
if (num_tiles_dim <= 2) {
756+
if (small_dim <= tile_size) {
757+
num_tiles_dim = 1;
758+
tile_overlap_factor_dim = 0;
759+
} else {
760+
num_tiles_dim = 2;
761+
tile_overlap_factor_dim = (2 * tile_size - small_dim) / (float)tile_size;
762+
}
763+
}
764+
}
765+
739766
// Tiling
740-
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
767+
__STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input, ggml_tensor* output, const int scale,
768+
const int p_tile_size_x, const int p_tile_size_y,
769+
const float tile_overlap_factor, on_tile_process on_processing) {
770+
741771
output = ggml_set_f32(output, 0);
742772

743773
int input_width = (int)input->ne[0];
@@ -758,62 +788,27 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
758788
small_height = input_height;
759789
}
760790

761-
int tile_overlap = (tile_size * tile_overlap_factor);
762-
int non_tile_overlap = tile_size - tile_overlap;
763-
764-
int num_tiles_x = (small_width - tile_overlap) / non_tile_overlap;
765-
int overshoot_x = ((num_tiles_x + 1) * non_tile_overlap + tile_overlap) % small_width;
766-
767-
if ((overshoot_x != non_tile_overlap) && (overshoot_x <= num_tiles_x * (tile_size / 2 - tile_overlap))) {
768-
// if tiles don't fit perfectly using the desired overlap
769-
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
770-
num_tiles_x++;
771-
}
772-
773-
float tile_overlap_factor_x = (float)(tile_size * num_tiles_x - small_width) / (float)(tile_size * (num_tiles_x - 1));
774-
if (num_tiles_x <= 2) {
775-
if (small_width <= tile_size) {
776-
num_tiles_x = 1;
777-
tile_overlap_factor_x = 0;
778-
} else {
779-
num_tiles_x = 2;
780-
tile_overlap_factor_x = (2 * tile_size - small_width) / (float)tile_size;
781-
}
782-
}
783-
784-
int num_tiles_y = (small_height - tile_overlap) / non_tile_overlap;
785-
int overshoot_y = ((num_tiles_y + 1) * non_tile_overlap + tile_overlap) % small_height;
791+
int num_tiles_x;
792+
float tile_overlap_factor_x;
793+
sd_tiling_calc_tiles(num_tiles_x, tile_overlap_factor_x, small_width, p_tile_size_x, tile_overlap_factor);
786794

787-
if ((overshoot_y != non_tile_overlap) && (overshoot_y <= num_tiles_y * (tile_size / 2 - tile_overlap))) {
788-
// if tiles don't fit perfectly using the desired overlap
789-
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
790-
num_tiles_y++;
791-
}
792-
793-
float tile_overlap_factor_y = (float)(tile_size * num_tiles_y - small_height) / (float)(tile_size * (num_tiles_y - 1));
794-
if (num_tiles_y <= 2) {
795-
if (small_height <= tile_size) {
796-
num_tiles_y = 1;
797-
tile_overlap_factor_y = 0;
798-
} else {
799-
num_tiles_y = 2;
800-
tile_overlap_factor_y = (2 * tile_size - small_height) / (float)tile_size;
801-
}
802-
}
795+
int num_tiles_y;
796+
float tile_overlap_factor_y;
797+
sd_tiling_calc_tiles(num_tiles_y, tile_overlap_factor_y, small_height, p_tile_size_y, tile_overlap_factor);
803798

804799
LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y);
805800
LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor);
806801

807802
GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0); // should be multiple of 2
808803

809-
int tile_overlap_x = (int32_t)(tile_size * tile_overlap_factor_x);
810-
int non_tile_overlap_x = tile_size - tile_overlap_x;
804+
int tile_overlap_x = (int32_t)(p_tile_size_x * tile_overlap_factor_x);
805+
int non_tile_overlap_x = p_tile_size_x - tile_overlap_x;
811806

812-
int tile_overlap_y = (int32_t)(tile_size * tile_overlap_factor_y);
813-
int non_tile_overlap_y = tile_size - tile_overlap_y;
807+
int tile_overlap_y = (int32_t)(p_tile_size_y * tile_overlap_factor_y);
808+
int non_tile_overlap_y = p_tile_size_y - tile_overlap_y;
814809

815-
int tile_size_x = tile_size < small_width ? tile_size : small_width;
816-
int tile_size_y = tile_size < small_height ? tile_size : small_height;
810+
int tile_size_x = p_tile_size_x < small_width ? p_tile_size_x : small_width;
811+
int tile_size_y = p_tile_size_y < small_height ? p_tile_size_y : small_height;
817812

818813
int input_tile_size_x = tile_size_x;
819814
int input_tile_size_y = tile_size_y;
@@ -902,6 +897,11 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
902897
ggml_free(tiles_ctx);
903898
}
904899

900+
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale,
901+
const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
902+
sd_tiling_non_square(input, output, scale, tile_size, tile_size, tile_overlap_factor, on_processing);
903+
}
904+
905905
__STATIC_INLINE__ struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ctx,
906906
struct ggml_tensor* a) {
907907
const float eps = 1e-6f; // default eps parameter

stable-diffusion.cpp

Lines changed: 95 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,41 +1281,113 @@ class StableDiffusionGGML {
12811281
return latent;
12821282
}
12831283

1284-
ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
1285-
int64_t t0 = ggml_time_ms();
1286-
ggml_tensor* result = NULL;
1287-
int tile_size = 32;
1288-
// TODO: arg instead of env?
1284+
void get_vae_tile_overlap(float& tile_overlap) {
1285+
const char* SD_TILE_OVERLAP = getenv("SD_TILE_OVERLAP");
1286+
if (SD_TILE_OVERLAP != nullptr) {
1287+
std::string sd_tile_overlap_str = SD_TILE_OVERLAP;
1288+
try {
1289+
tile_overlap = std::stof(sd_tile_overlap_str);
1290+
if (tile_overlap < 0.0) {
1291+
LOG_WARN("SD_TILE_OVERLAP too low, setting it to 0.0");
1292+
tile_overlap = 0.0;
1293+
} else if (tile_overlap > 0.5) {
1294+
LOG_WARN("SD_TILE_OVERLAP too high, setting it to 0.5");
1295+
tile_overlap = 0.5;
1296+
}
1297+
} catch (const std::invalid_argument&) {
1298+
LOG_WARN("SD_TILE_OVERLAP is invalid, keeping the default");
1299+
} catch (const std::out_of_range&) {
1300+
LOG_WARN("SD_TILE_OVERLAP is out of range, keeping the default");
1301+
}
1302+
}
1303+
if (SD_TILE_OVERLAP != nullptr) {
1304+
LOG_INFO("VAE Tile overlap: %.2f", tile_overlap);
1305+
}
1306+
}
1307+
1308+
void get_vae_tile_sizes(int& tile_size_x, int& tile_size_y, float tile_overlap, int latent_x, int latent_y) {
12891309
const char* SD_TILE_SIZE = getenv("SD_TILE_SIZE");
12901310
if (SD_TILE_SIZE != nullptr) {
1311+
// format is AxB, or just A (equivalent to AxA)
1312+
// A and B can be integers (tile size) or floating point
1313+
// floating point <= 1 means simple fraction of the latent dimension
1314+
// floating point > 1 means number of tiles across that dimension
1315+
// a single number gets applied to both
1316+
auto get_tile_factor = [tile_overlap](const std::string& factor_str) {
1317+
float factor = std::stof(factor_str);
1318+
if (factor > 1.0)
1319+
factor = 1 / (factor - factor * tile_overlap + tile_overlap);
1320+
return factor;
1321+
};
1322+
const int min_tile_dimension = 4;
12911323
std::string sd_tile_size_str = SD_TILE_SIZE;
1324+
size_t x_pos = sd_tile_size_str.find('x');
12921325
try {
1293-
tile_size = std::stoi(sd_tile_size_str);
1326+
int tmp_x = tile_size_x, tmp_y = tile_size_y;
1327+
if (x_pos != std::string::npos) {
1328+
std::string tile_x_str = sd_tile_size_str.substr(0, x_pos);
1329+
std::string tile_y_str = sd_tile_size_str.substr(x_pos + 1);
1330+
if (tile_x_str.find('.') != std::string::npos) {
1331+
tmp_x = std::round(latent_x * get_tile_factor(tile_x_str));
1332+
} else {
1333+
tmp_x = std::stoi(tile_x_str);
1334+
}
1335+
if (tile_y_str.find('.') != std::string::npos) {
1336+
tmp_y = std::round(latent_y * get_tile_factor(tile_y_str));
1337+
} else {
1338+
tmp_y = std::stoi(tile_y_str);
1339+
}
1340+
} else {
1341+
if (sd_tile_size_str.find('.') != std::string::npos) {
1342+
float tile_factor = get_tile_factor(sd_tile_size_str);
1343+
tmp_x = std::round(latent_x * tile_factor);
1344+
tmp_y = std::round(latent_y * tile_factor);
1345+
} else {
1346+
tmp_x = tmp_y = std::stoi(sd_tile_size_str);
1347+
}
1348+
}
1349+
tile_size_x = std::max(std::min(tmp_x, latent_x), min_tile_dimension);
1350+
tile_size_y = std::max(std::min(tmp_y, latent_y), min_tile_dimension);
12941351
} catch (const std::invalid_argument&) {
1295-
LOG_WARN("Invalid");
1352+
LOG_WARN("SD_TILE_SIZE is invalid, keeping the default");
12961353
} catch (const std::out_of_range&) {
1297-
LOG_WARN("OOR");
1354+
LOG_WARN("SD_TILE_SIZE is out of range, keeping the default");
12981355
}
12991356
}
1300-
if(!decode){
1301-
// TODO: also use and arg for this one?
1302-
// to keep the compute buffer size consistent
1303-
tile_size*=1.30539;
1357+
if (SD_TILE_SIZE != nullptr) {
1358+
LOG_INFO("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
13041359
}
1360+
}
1361+
1362+
ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
1363+
int64_t t0 = ggml_time_ms();
1364+
ggml_tensor* result = NULL;
1365+
// TODO: args instead of env for tile size / overlap?
13051366
if (!use_tiny_autoencoder) {
1367+
float tile_overlap = 0.5f;
1368+
int tile_size_x = 32;
1369+
int tile_size_y = 32;
1370+
1371+
get_vae_tile_overlap(tile_overlap);
1372+
get_vae_tile_sizes(tile_size_x, tile_size_y, tile_overlap, x->ne[0] / 8, x->ne[1] / 8);
1373+
1374+
// TODO: also use an arg for this one?
1375+
// multiply tile size for encode to keep the compute buffer size consistent
1376+
tile_size_x *= 1.30539;
1377+
tile_size_y *= 1.30539;
1378+
13061379
process_vae_input_tensor(x);
13071380
if (vae_tiling && !decode_video) {
1308-
// split latent in 32x32 tiles and compute in several steps
13091381
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
13101382
first_stage_model->compute(n_threads, in, true, &out, NULL);
13111383
};
1312-
sd_tiling(x, result, 8, tile_size, 0.5f, on_tiling);
1384+
sd_tiling_non_square(x, result, 8, tile_size_x, tile_size_y, tile_overlap, on_tiling);
13131385
} else {
13141386
first_stage_model->compute(n_threads, x, false, &result, work_ctx);
13151387
}
13161388
first_stage_model->free_compute_buffer();
13171389
} else {
1318-
if (vae_tiling && !decode_video) {
1390+
if (vae_tiling && !decode_video) {
13191391
// split latent in 32x32 tiles and compute in several steps
13201392
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
13211393
tae_first_stage->compute(n_threads, in, true, &out, NULL);
@@ -1440,29 +1512,23 @@ class StableDiffusionGGML {
14401512
C,
14411513
x->ne[3]);
14421514
}
1443-
int tile_size = 32;
1444-
// TODO: arg instead of env?
1445-
const char* SD_TILE_SIZE = getenv("SD_TILE_SIZE");
1446-
if (SD_TILE_SIZE != nullptr) {
1447-
std::string sd_tile_size_str = SD_TILE_SIZE;
1448-
try {
1449-
tile_size = std::stoi(sd_tile_size_str);
1450-
} catch (const std::invalid_argument&) {
1451-
LOG_WARN("Invalid");
1452-
} catch (const std::out_of_range&) {
1453-
LOG_WARN("OOR");
1454-
}
1455-
}
14561515
int64_t t0 = ggml_time_ms();
14571516
if (!use_tiny_autoencoder) {
1517+
float tile_overlap = 0.5f;
1518+
int tile_size_x = 32;
1519+
int tile_size_y = 32;
1520+
1521+
get_vae_tile_overlap(tile_overlap);
1522+
get_vae_tile_sizes(tile_size_x, tile_size_y, tile_overlap, x->ne[0] / 8, x->ne[1] / 8);
1523+
14581524
process_latent_out(x);
14591525
// x = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
14601526
if (vae_tiling && !decode_video) {
14611527
// split latent in 32x32 tiles and compute in several steps
14621528
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
14631529
first_stage_model->compute(n_threads, in, true, &out, NULL);
14641530
};
1465-
sd_tiling(x, result, 8, tile_size, 0.5f, on_tiling);
1531+
sd_tiling_non_square(x, result, 8, tile_size_x, tile_size_y, tile_overlap, on_tiling);
14661532
} else {
14671533
first_stage_model->compute(n_threads, x, true, &result, work_ctx);
14681534
}

0 commit comments

Comments
 (0)