@@ -101,7 +101,6 @@ struct SDParams {
101101 rng_type_t rng_type = CUDA_RNG;
102102 int64_t seed = 42 ;
103103 bool verbose = false ;
104- bool vae_tiling = false ;
105104 bool offload_params_to_cpu = false ;
106105 bool control_net_cpu = false ;
107106 bool normalize_input = false ;
@@ -119,6 +118,8 @@ struct SDParams {
119118 int chroma_t5_mask_pad = 1 ;
120119 float flow_shift = INFINITY;
121120
121+ sd_tiling_params_t vae_tiling_params = {false , 32 , 32 , 0 .5f , false , 0 .0f , 0 .0f };
122+
122123 SDParams () {
123124 sd_sample_params_init (&sample_params);
124125 sd_sample_params_init (&high_noise_sample_params);
@@ -180,7 +181,7 @@ void print_params(SDParams params) {
180181 printf (" rng: %s\n " , sd_rng_type_name (params.rng_type ));
181182 printf (" seed: %ld\n " , params.seed );
182183 printf (" batch_count: %d\n " , params.batch_count );
183- printf (" vae_tiling: %s\n " , params.vae_tiling ? " true" : " false" );
184+ printf (" vae_tiling: %s\n " , params.vae_tiling_params . enabled ? " true" : " false" );
184185 printf (" upscale_repeats: %d\n " , params.upscale_repeats );
185186 printf (" chroma_use_dit_mask: %s\n " , params.chroma_use_dit_mask ? " true" : " false" );
186187 printf (" chroma_use_t5_mask: %s\n " , params.chroma_use_t5_mask ? " true" : " false" );
@@ -268,6 +269,9 @@ void print_usage(int argc, const char* argv[]) {
268269 printf (" --clip-skip N ignore last_dot_pos layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n " );
269270 printf (" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n " );
270271 printf (" --vae-tiling process vae in tiles to reduce memory usage\n " );
272+ printf (" --vae-tile-size [X]x[Y] tile size for vae tiling (default: 32x32)\n " );
273+ printf (" --vae-relative-tile-size [X]x[Y] relative tile size for vae tiling, in fraction of image size if < 1, in number of tiles per dim if >=1 (overrides --vae-tile-size)\n " );
274+ printf (" --vae-tile-overlap OVERLAP tile overlap for vae tiling, in fraction of tile size (default: 0.5)\n " );
271275 printf (" --vae-on-cpu keep vae in cpu (for low vram)\n " );
272276 printf (" --clip-on-cpu keep clip in cpu (for low vram)\n " );
273277 printf (" --diffusion-fa use flash attention in the diffusion model (for low vram)\n " );
@@ -485,7 +489,6 @@ void parse_args(int argc, const char** argv, SDParams& params) {
485489 {" -o" , " --output" , " " , ¶ms.output_path },
486490 {" -p" , " --prompt" , " " , ¶ms.prompt },
487491 {" -n" , " --negative-prompt" , " " , ¶ms.negative_prompt },
488-
489492 {" " , " --upscale-model" , " " , ¶ms.esrgan_path },
490493 };
491494
@@ -526,7 +529,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
526529 };
527530
528531 options.bool_options = {
529- {" " , " --vae-tiling" , " " , true , ¶ms.vae_tiling },
532+ {" " , " --vae-tiling" , " " , true , ¶ms.vae_tiling_params . enabled },
530533 {" " , " --offload-to-cpu" , " " , true , ¶ms.offload_params_to_cpu },
531534 {" " , " --control-net-cpu" , " " , true , ¶ms.control_net_cpu },
532535 {" " , " --normalize-input" , " " , true , ¶ms.normalize_input },
@@ -726,6 +729,62 @@ void parse_args(int argc, const char** argv, SDParams& params) {
726729 return 1 ;
727730 };
728731
732+ auto on_tile_size_arg = [&](int argc, const char ** argv, int index) {
733+ if (++index >= argc) {
734+ return -1 ;
735+ }
736+ std::string tile_size_str = argv[index];
737+ size_t x_pos = tile_size_str.find (' x' );
738+ try {
739+ if (x_pos != std::string::npos) {
740+ std::string tile_x_str = tile_size_str.substr (0 , x_pos);
741+ std::string tile_y_str = tile_size_str.substr (x_pos + 1 );
742+ params.vae_tiling_params .tile_size_x = std::stoi (tile_x_str);
743+ params.vae_tiling_params .tile_size_y = std::stoi (tile_y_str);
744+ } else {
745+ params.vae_tiling_params .tile_size_x = params.vae_tiling_params .tile_size_y = std::stoi (tile_size_str);
746+ }
747+ } catch (const std::invalid_argument& e) {
748+ return -1 ;
749+ } catch (const std::out_of_range& e) {
750+ return -1 ;
751+ }
752+ params.vae_tiling_params .relative = false ;
753+ return 1 ;
754+ };
755+
756+ auto on_relative_tile_size_arg = [&](int argc, const char ** argv, int index) {
757+ if (++index >= argc) {
758+ return -1 ;
759+ }
760+ std::string rel_size_str = argv[index];
761+ size_t x_pos = rel_size_str.find (' x' );
762+ try {
763+ if (x_pos != std::string::npos) {
764+ std::string rel_x_str = rel_size_str.substr (0 , x_pos);
765+ std::string rel_y_str = rel_size_str.substr (x_pos + 1 );
766+ params.vae_tiling_params .rel_size_x = std::stof (rel_x_str);
767+ params.vae_tiling_params .rel_size_y = std::stof (rel_y_str);
768+ } else {
769+ params.vae_tiling_params .rel_size_x = params.vae_tiling_params .rel_size_y = std::stof (rel_size_str);
770+ }
771+ } catch (const std::invalid_argument& e) {
772+ return -1 ;
773+ } catch (const std::out_of_range& e) {
774+ return -1 ;
775+ }
776+ params.vae_tiling_params .relative = true ;
777+ return 1 ;
778+ };
779+
780+ auto on_tile_overlap_arg = [&](int argc, const char ** argv, int index) {
781+ if (++index >= argc) {
782+ return -1 ;
783+ }
784+ params.vae_tiling_params .target_overlap = std::stof (argv[index]);
785+ return 1 ;
786+ };
787+
729788 options.manual_options = {
730789 {" -M" , " --mode" , " " , on_mode_arg},
731790 {" " , " --type" , " " , on_type_arg},
@@ -739,6 +798,9 @@ void parse_args(int argc, const char** argv, SDParams& params) {
739798 {" " , " --high-noise-skip-layers" , " " , on_high_noise_skip_layers_arg},
740799 {" -r" , " --ref-image" , " " , on_ref_image_arg},
741800 {" -h" , " --help" , " " , on_help_arg},
801+ {" " , " --vae-tile-size" , " " , on_tile_size_arg},
802+ {" " , " --vae-relative-tile-size" , " " , on_relative_tile_size_arg},
803+ {" " , " --vae-tile-overlap" , " " , on_tile_overlap_arg},
742804 };
743805
744806 if (!parse_options (argc, argv, options)) {
@@ -1176,7 +1238,6 @@ int main(int argc, const char* argv[]) {
11761238 params.embedding_dir .c_str (),
11771239 params.stacked_id_embed_dir .c_str (),
11781240 vae_decode_only,
1179- params.vae_tiling ,
11801241 true ,
11811242 params.n_threads ,
11821243 params.wtype ,
@@ -1225,6 +1286,7 @@ int main(int argc, const char* argv[]) {
12251286 params.style_ratio ,
12261287 params.normalize_input ,
12271288 params.input_id_images_path .c_str (),
1289+ params.vae_tiling_params ,
12281290 };
12291291
12301292 results = generate_image (sd_ctx, &img_gen_params);
0 commit comments