@@ -1301,41 +1301,113 @@ class StableDiffusionGGML {
13011301 return latent;
13021302 }
13031303
1304- ggml_tensor* encode_first_stage (ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false ) {
1305- int64_t t0 = ggml_time_ms ();
1306- ggml_tensor* result = NULL ;
1307- int tile_size = 32 ;
1308- // TODO: arg instead of env?
1304+ void get_vae_tile_overlap (float & tile_overlap) {
1305+ const char * SD_TILE_OVERLAP = getenv (" SD_TILE_OVERLAP" );
1306+ if (SD_TILE_OVERLAP != nullptr ) {
1307+ std::string sd_tile_overlap_str = SD_TILE_OVERLAP;
1308+ try {
1309+ tile_overlap = std::stof (sd_tile_overlap_str);
1310+ if (tile_overlap < 0.0 ) {
1311+ LOG_WARN (" SD_TILE_OVERLAP too low, setting it to 0.0" );
1312+ tile_overlap = 0.0 ;
1313+ } else if (tile_overlap > 0.5 ) {
1314+ LOG_WARN (" SD_TILE_OVERLAP too high, setting it to 0.5" );
1315+ tile_overlap = 0.5 ;
1316+ }
1317+ } catch (const std::invalid_argument&) {
1318+ LOG_WARN (" SD_TILE_OVERLAP is invalid, keeping the default" );
1319+ } catch (const std::out_of_range&) {
1320+ LOG_WARN (" SD_TILE_OVERLAP is out of range, keeping the default" );
1321+ }
1322+ }
1323+ if (SD_TILE_OVERLAP != nullptr ) {
1324+ LOG_INFO (" VAE Tile overlap: %.2f" , tile_overlap);
1325+ }
1326+ }
1327+
1328+ void get_vae_tile_sizes (int & tile_size_x, int & tile_size_y, float tile_overlap, int latent_x, int latent_y) {
13091329 const char * SD_TILE_SIZE = getenv (" SD_TILE_SIZE" );
13101330 if (SD_TILE_SIZE != nullptr ) {
1331+ // format is AxB, or just A (equivalent to AxA)
1332+ // A and B can be integers (tile size) or floating point
1333+ // floating point <= 1 means simple fraction of the latent dimension
1334+ // floating point > 1 means number of tiles across that dimension
1335+ // a single number gets applied to both
1336+ auto get_tile_factor = [tile_overlap](const std::string& factor_str) {
1337+ float factor = std::stof (factor_str);
1338+ if (factor > 1.0 )
1339+ factor = 1 / (factor - factor * tile_overlap + tile_overlap);
1340+ return factor;
1341+ };
1342+ const int min_tile_dimension = 4 ;
13111343 std::string sd_tile_size_str = SD_TILE_SIZE;
1344+ size_t x_pos = sd_tile_size_str.find (' x' );
13121345 try {
1313- tile_size = std::stoi (sd_tile_size_str);
1346+ int tmp_x = tile_size_x, tmp_y = tile_size_y;
1347+ if (x_pos != std::string::npos) {
1348+ std::string tile_x_str = sd_tile_size_str.substr (0 , x_pos);
1349+ std::string tile_y_str = sd_tile_size_str.substr (x_pos + 1 );
1350+ if (tile_x_str.find (' .' ) != std::string::npos) {
1351+ tmp_x = std::round (latent_x * get_tile_factor (tile_x_str));
1352+ } else {
1353+ tmp_x = std::stoi (tile_x_str);
1354+ }
1355+ if (tile_y_str.find (' .' ) != std::string::npos) {
1356+ tmp_y = std::round (latent_y * get_tile_factor (tile_y_str));
1357+ } else {
1358+ tmp_y = std::stoi (tile_y_str);
1359+ }
1360+ } else {
1361+ if (sd_tile_size_str.find (' .' ) != std::string::npos) {
1362+ float tile_factor = get_tile_factor (sd_tile_size_str);
1363+ tmp_x = std::round (latent_x * tile_factor);
1364+ tmp_y = std::round (latent_y * tile_factor);
1365+ } else {
1366+ tmp_x = tmp_y = std::stoi (sd_tile_size_str);
1367+ }
1368+ }
1369+ tile_size_x = std::max (std::min (tmp_x, latent_x), min_tile_dimension);
1370+ tile_size_y = std::max (std::min (tmp_y, latent_y), min_tile_dimension);
13141371 } catch (const std::invalid_argument&) {
1315- LOG_WARN (" Invalid " );
1372+ LOG_WARN (" SD_TILE_SIZE is invalid, keeping the default " );
13161373 } catch (const std::out_of_range&) {
1317- LOG_WARN (" OOR " );
1374+ LOG_WARN (" SD_TILE_SIZE is out of range, keeping the default " );
13181375 }
13191376 }
1320- if (!decode){
1321- // TODO: also use and arg for this one?
1322- // to keep the compute buffer size consistent
1323- tile_size*=1.30539 ;
1377+ if (SD_TILE_SIZE != nullptr ) {
1378+ LOG_INFO (" VAE Tile size: %dx%d" , tile_size_x, tile_size_y);
13241379 }
1380+ }
1381+
1382+ ggml_tensor* encode_first_stage (ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false ) {
1383+ int64_t t0 = ggml_time_ms ();
1384+ ggml_tensor* result = NULL ;
1385+ // TODO: args instead of env for tile size / overlap?
13251386 if (!use_tiny_autoencoder) {
1387+ float tile_overlap = 0 .5f ;
1388+ int tile_size_x = 32 ;
1389+ int tile_size_y = 32 ;
1390+
1391+ get_vae_tile_overlap (tile_overlap);
1392+ get_vae_tile_sizes (tile_size_x, tile_size_y, tile_overlap, x->ne [0 ] / 8 , x->ne [1 ] / 8 );
1393+
1394+ // TODO: also use an arg for this one?
1395+ // multiply tile size for encode to keep the compute buffer size consistent
1396+ tile_size_x *= 1.30539 ;
1397+ tile_size_y *= 1.30539 ;
1398+
13261399 process_vae_input_tensor (x);
13271400 if (vae_tiling && !decode_video) {
1328- // split latent in 32x32 tiles and compute in several steps
13291401 auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
13301402 first_stage_model->compute (n_threads, in, true , &out, NULL );
13311403 };
1332- sd_tiling (x, result, 8 , tile_size, 0 . 5f , on_tiling);
1404+ sd_tiling_non_square (x, result, 8 , tile_size_x, tile_size_y, tile_overlap , on_tiling);
13331405 } else {
13341406 first_stage_model->compute (n_threads, x, false , &result, work_ctx);
13351407 }
13361408 first_stage_model->free_compute_buffer ();
13371409 } else {
1338- if (vae_tiling && !decode_video) {
1410+ if (vae_tiling && !decode_video) {
13391411 // split latent in 32x32 tiles and compute in several steps
13401412 auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
13411413 tae_first_stage->compute (n_threads, in, true , &out, NULL );
@@ -1460,29 +1532,23 @@ class StableDiffusionGGML {
14601532 C,
14611533 x->ne [3 ]);
14621534 }
1463- int tile_size = 32 ;
1464- // TODO: arg instead of env?
1465- const char * SD_TILE_SIZE = getenv (" SD_TILE_SIZE" );
1466- if (SD_TILE_SIZE != nullptr ) {
1467- std::string sd_tile_size_str = SD_TILE_SIZE;
1468- try {
1469- tile_size = std::stoi (sd_tile_size_str);
1470- } catch (const std::invalid_argument&) {
1471- LOG_WARN (" Invalid" );
1472- } catch (const std::out_of_range&) {
1473- LOG_WARN (" OOR" );
1474- }
1475- }
14761535 int64_t t0 = ggml_time_ms ();
14771536 if (!use_tiny_autoencoder) {
1537+ float tile_overlap = 0 .5f ;
1538+ int tile_size_x = 32 ;
1539+ int tile_size_y = 32 ;
1540+
1541+ get_vae_tile_overlap (tile_overlap);
1542+ get_vae_tile_sizes (tile_size_x, tile_size_y, tile_overlap, x->ne [0 ] / 8 , x->ne [1 ] / 8 );
1543+
14781544 process_latent_out (x);
14791545 // x = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
14801546 if (vae_tiling && !decode_video) {
14811547 // split latent in 32x32 tiles and compute in several steps
14821548 auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
14831549 first_stage_model->compute (n_threads, in, true , &out, NULL );
14841550 };
1485- sd_tiling (x, result, 8 , tile_size, 0 . 5f , on_tiling);
1551+ sd_tiling_non_square (x, result, 8 , tile_size_x, tile_size_y, tile_overlap , on_tiling);
14861552 } else {
14871553 first_stage_model->compute (n_threads, x, true , &result, work_ctx);
14881554 }
0 commit comments