@@ -1281,41 +1281,113 @@ class StableDiffusionGGML {
12811281 return latent;
12821282 }
12831283
1284- ggml_tensor* encode_first_stage (ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false ) {
1285- int64_t t0 = ggml_time_ms ();
1286- ggml_tensor* result = NULL ;
1287- int tile_size = 32 ;
1288- // TODO: arg instead of env?
1284+ void get_vae_tile_overlap (float & tile_overlap) {
1285+ const char * SD_TILE_OVERLAP = getenv (" SD_TILE_OVERLAP" );
1286+ if (SD_TILE_OVERLAP != nullptr ) {
1287+ std::string sd_tile_overlap_str = SD_TILE_OVERLAP;
1288+ try {
1289+ tile_overlap = std::stof (sd_tile_overlap_str);
1290+ if (tile_overlap < 0.0 ) {
1291+ LOG_WARN (" SD_TILE_OVERLAP too low, setting it to 0.0" );
1292+ tile_overlap = 0.0 ;
1293+ } else if (tile_overlap > 0.5 ) {
1294+ LOG_WARN (" SD_TILE_OVERLAP too high, setting it to 0.5" );
1295+ tile_overlap = 0.5 ;
1296+ }
1297+ } catch (const std::invalid_argument&) {
1298+ LOG_WARN (" SD_TILE_OVERLAP is invalid, keeping the default" );
1299+ } catch (const std::out_of_range&) {
1300+ LOG_WARN (" SD_TILE_OVERLAP is out of range, keeping the default" );
1301+ }
1302+ }
1303+ if (SD_TILE_OVERLAP != nullptr ) {
1304+ LOG_INFO (" VAE Tile overlap: %.2f" , tile_overlap);
1305+ }
1306+ }
1307+
1308+ void get_vae_tile_sizes (int & tile_size_x, int & tile_size_y, float tile_overlap, int latent_x, int latent_y) {
12891309 const char * SD_TILE_SIZE = getenv (" SD_TILE_SIZE" );
12901310 if (SD_TILE_SIZE != nullptr ) {
1311+ // format is AxB, or just A (equivalent to AxA)
1312+ // A and B can be integers (tile size) or floating point
1313+ // floating point <= 1 means simple fraction of the latent dimension
1314+ // floating point > 1 means number of tiles across that dimension
1315+ // a single number gets applied to both
1316+ auto get_tile_factor = [tile_overlap](const std::string& factor_str) {
1317+ float factor = std::stof (factor_str);
1318+ if (factor > 1.0 )
1319+ factor = 1 / (factor - factor * tile_overlap + tile_overlap);
1320+ return factor;
1321+ };
1322+ const int min_tile_dimension = 4 ;
12911323 std::string sd_tile_size_str = SD_TILE_SIZE;
1324+ size_t x_pos = sd_tile_size_str.find (' x' );
12921325 try {
1293- tile_size = std::stoi (sd_tile_size_str);
1326+ int tmp_x = tile_size_x, tmp_y = tile_size_y;
1327+ if (x_pos != std::string::npos) {
1328+ std::string tile_x_str = sd_tile_size_str.substr (0 , x_pos);
1329+ std::string tile_y_str = sd_tile_size_str.substr (x_pos + 1 );
1330+ if (tile_x_str.find (' .' ) != std::string::npos) {
1331+ tmp_x = std::round (latent_x * get_tile_factor (tile_x_str));
1332+ } else {
1333+ tmp_x = std::stoi (tile_x_str);
1334+ }
1335+ if (tile_y_str.find (' .' ) != std::string::npos) {
1336+ tmp_y = std::round (latent_y * get_tile_factor (tile_y_str));
1337+ } else {
1338+ tmp_y = std::stoi (tile_y_str);
1339+ }
1340+ } else {
1341+ if (sd_tile_size_str.find (' .' ) != std::string::npos) {
1342+ float tile_factor = get_tile_factor (sd_tile_size_str);
1343+ tmp_x = std::round (latent_x * tile_factor);
1344+ tmp_y = std::round (latent_y * tile_factor);
1345+ } else {
1346+ tmp_x = tmp_y = std::stoi (sd_tile_size_str);
1347+ }
1348+ }
1349+ tile_size_x = std::max (std::min (tmp_x, latent_x), min_tile_dimension);
1350+ tile_size_y = std::max (std::min (tmp_y, latent_y), min_tile_dimension);
12941351 } catch (const std::invalid_argument&) {
1295- LOG_WARN (" Invalid " );
1352+ LOG_WARN (" SD_TILE_SIZE is invalid, keeping the default " );
12961353 } catch (const std::out_of_range&) {
1297- LOG_WARN (" OOR " );
1354+ LOG_WARN (" SD_TILE_SIZE is out of range, keeping the default " );
12981355 }
12991356 }
1300- if (!decode){
1301- // TODO: also use and arg for this one?
1302- // to keep the compute buffer size consistent
1303- tile_size*=1.30539 ;
1357+ if (SD_TILE_SIZE != nullptr ) {
1358+ LOG_INFO (" VAE Tile size: %dx%d" , tile_size_x, tile_size_y);
13041359 }
1360+ }
1361+
1362+ ggml_tensor* encode_first_stage (ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false ) {
1363+ int64_t t0 = ggml_time_ms ();
1364+ ggml_tensor* result = NULL ;
1365+ // TODO: args instead of env for tile size / overlap?
13051366 if (!use_tiny_autoencoder) {
1367+ float tile_overlap = 0 .5f ;
1368+ int tile_size_x = 32 ;
1369+ int tile_size_y = 32 ;
1370+
1371+ get_vae_tile_overlap (tile_overlap);
1372+ get_vae_tile_sizes (tile_size_x, tile_size_y, tile_overlap, x->ne [0 ] / 8 , x->ne [1 ] / 8 );
1373+
1374+ // TODO: also use an arg for this one?
1375+ // multiply tile size for encode to keep the compute buffer size consistent
1376+ tile_size_x *= 1.30539 ;
1377+ tile_size_y *= 1.30539 ;
1378+
13061379 process_vae_input_tensor (x);
13071380 if (vae_tiling && !decode_video) {
1308- // split latent in 32x32 tiles and compute in several steps
13091381 auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
13101382 first_stage_model->compute (n_threads, in, true , &out, NULL );
13111383 };
1312- sd_tiling (x, result, 8 , tile_size, 0 . 5f , on_tiling);
1384+ sd_tiling_non_square (x, result, 8 , tile_size_x, tile_size_y, tile_overlap , on_tiling);
13131385 } else {
13141386 first_stage_model->compute (n_threads, x, false , &result, work_ctx);
13151387 }
13161388 first_stage_model->free_compute_buffer ();
13171389 } else {
1318- if (vae_tiling && !decode_video) {
1390+ if (vae_tiling && !decode_video) {
13191391 // split latent in 32x32 tiles and compute in several steps
13201392 auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
13211393 tae_first_stage->compute (n_threads, in, true , &out, NULL );
@@ -1440,29 +1512,23 @@ class StableDiffusionGGML {
14401512 C,
14411513 x->ne [3 ]);
14421514 }
1443- int tile_size = 32 ;
1444- // TODO: arg instead of env?
1445- const char * SD_TILE_SIZE = getenv (" SD_TILE_SIZE" );
1446- if (SD_TILE_SIZE != nullptr ) {
1447- std::string sd_tile_size_str = SD_TILE_SIZE;
1448- try {
1449- tile_size = std::stoi (sd_tile_size_str);
1450- } catch (const std::invalid_argument&) {
1451- LOG_WARN (" Invalid" );
1452- } catch (const std::out_of_range&) {
1453- LOG_WARN (" OOR" );
1454- }
1455- }
14561515 int64_t t0 = ggml_time_ms ();
14571516 if (!use_tiny_autoencoder) {
1517+ float tile_overlap = 0 .5f ;
1518+ int tile_size_x = 32 ;
1519+ int tile_size_y = 32 ;
1520+
1521+ get_vae_tile_overlap (tile_overlap);
1522+ get_vae_tile_sizes (tile_size_x, tile_size_y, tile_overlap, x->ne [0 ] / 8 , x->ne [1 ] / 8 );
1523+
14581524 process_latent_out (x);
14591525 // x = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
14601526 if (vae_tiling && !decode_video) {
14611527 // split latent in 32x32 tiles and compute in several steps
14621528 auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
14631529 first_stage_model->compute (n_threads, in, true , &out, NULL );
14641530 };
1465- sd_tiling (x, result, 8 , tile_size, 0 . 5f , on_tiling);
1531+ sd_tiling_non_square (x, result, 8 , tile_size_x, tile_size_y, tile_overlap , on_tiling);
14661532 } else {
14671533 first_stage_model->compute (n_threads, x, true , &result, work_ctx);
14681534 }
0 commit comments