Skip to content

Commit 987ced8

Browse files
committed
tile size params instead of env
1 parent 606976f commit 987ced8

File tree

3 files changed

+138
-116
lines changed

3 files changed

+138
-116
lines changed

examples/cli/main.cpp

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,6 @@ struct SDParams {
101101
rng_type_t rng_type = CUDA_RNG;
102102
int64_t seed = 42;
103103
bool verbose = false;
104-
bool vae_tiling = false;
105104
bool offload_params_to_cpu = false;
106105
bool control_net_cpu = false;
107106
bool normalize_input = false;
@@ -119,6 +118,8 @@ struct SDParams {
119118
int chroma_t5_mask_pad = 1;
120119
float flow_shift = INFINITY;
121120

121+
sd_tiling_params_t vae_tiling_params = {false, 32, 32, 0.5f, false, 0.0f, 0.0f};
122+
122123
SDParams() {
123124
sd_sample_params_init(&sample_params);
124125
sd_sample_params_init(&high_noise_sample_params);
@@ -180,7 +181,7 @@ void print_params(SDParams params) {
180181
printf(" rng: %s\n", sd_rng_type_name(params.rng_type));
181182
printf(" seed: %ld\n", params.seed);
182183
printf(" batch_count: %d\n", params.batch_count);
183-
printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false");
184+
printf(" vae_tiling: %s\n", params.vae_tiling_params.enabled ? "true" : "false");
184185
printf(" upscale_repeats: %d\n", params.upscale_repeats);
185186
printf(" chroma_use_dit_mask: %s\n", params.chroma_use_dit_mask ? "true" : "false");
186187
printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? "true" : "false");
@@ -268,6 +269,9 @@ void print_usage(int argc, const char* argv[]) {
268269
printf(" --clip-skip N ignore last_dot_pos layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n");
269270
printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n");
270271
printf(" --vae-tiling process vae in tiles to reduce memory usage\n");
272+
printf(" --vae-tile-size [X]x[Y] tile size for vae tiling (default: 32x32)\n");
273+
printf(" --vae-relative-tile-size [X]x[Y] relative tile size for vae tiling, in fraction of image size if < 1, in number of tiles per dim if >=1 (overrides --vae-tile-size)\n");
274+
printf(" --vae-tile-overlap OVERLAP tile overlap for vae tiling, in fraction of tile size (default: 0.5)\n");
271275
printf(" --vae-on-cpu keep vae in cpu (for low vram)\n");
272276
printf(" --clip-on-cpu keep clip in cpu (for low vram)\n");
273277
printf(" --diffusion-fa use flash attention in the diffusion model (for low vram)\n");
@@ -485,7 +489,6 @@ void parse_args(int argc, const char** argv, SDParams& params) {
485489
{"-o", "--output", "", &params.output_path},
486490
{"-p", "--prompt", "", &params.prompt},
487491
{"-n", "--negative-prompt", "", &params.negative_prompt},
488-
489492
{"", "--upscale-model", "", &params.esrgan_path},
490493
};
491494

@@ -526,7 +529,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
526529
};
527530

528531
options.bool_options = {
529-
{"", "--vae-tiling", "", true, &params.vae_tiling},
532+
{"", "--vae-tiling", "", true, &params.vae_tiling_params.enabled},
530533
{"", "--offload-to-cpu", "", true, &params.offload_params_to_cpu},
531534
{"", "--control-net-cpu", "", true, &params.control_net_cpu},
532535
{"", "--normalize-input", "", true, &params.normalize_input},
@@ -726,6 +729,62 @@ void parse_args(int argc, const char** argv, SDParams& params) {
726729
return 1;
727730
};
728731

732+
auto on_tile_size_arg = [&](int argc, const char** argv, int index) {
733+
if (++index >= argc) {
734+
return -1;
735+
}
736+
std::string tile_size_str = argv[index];
737+
size_t x_pos = tile_size_str.find('x');
738+
try {
739+
if (x_pos != std::string::npos) {
740+
std::string tile_x_str = tile_size_str.substr(0, x_pos);
741+
std::string tile_y_str = tile_size_str.substr(x_pos + 1);
742+
params.vae_tiling_params.tile_size_x = std::stoi(tile_x_str);
743+
params.vae_tiling_params.tile_size_y = std::stoi(tile_y_str);
744+
} else {
745+
params.vae_tiling_params.tile_size_x = params.vae_tiling_params.tile_size_y = std::stoi(tile_size_str);
746+
}
747+
} catch (const std::invalid_argument& e) {
748+
return -1;
749+
} catch (const std::out_of_range& e) {
750+
return -1;
751+
}
752+
params.vae_tiling_params.relative = false;
753+
return 1;
754+
};
755+
756+
auto on_relative_tile_size_arg = [&](int argc, const char** argv, int index) {
757+
if (++index >= argc) {
758+
return -1;
759+
}
760+
std::string rel_size_str = argv[index];
761+
size_t x_pos = rel_size_str.find('x');
762+
try {
763+
if (x_pos != std::string::npos) {
764+
std::string rel_x_str = rel_size_str.substr(0, x_pos);
765+
std::string rel_y_str = rel_size_str.substr(x_pos + 1);
766+
params.vae_tiling_params.rel_size_x = std::stof(rel_x_str);
767+
params.vae_tiling_params.rel_size_y = std::stof(rel_y_str);
768+
} else {
769+
params.vae_tiling_params.rel_size_x = params.vae_tiling_params.rel_size_y = std::stof(rel_size_str);
770+
}
771+
} catch (const std::invalid_argument& e) {
772+
return -1;
773+
} catch (const std::out_of_range& e) {
774+
return -1;
775+
}
776+
params.vae_tiling_params.relative = true;
777+
return 1;
778+
};
779+
780+
auto on_tile_overlap_arg = [&](int argc, const char** argv, int index) {
781+
if (++index >= argc) {
782+
return -1;
783+
}
784+
params.vae_tiling_params.target_overlap = std::stof(argv[index]);
785+
return 1;
786+
};
787+
729788
options.manual_options = {
730789
{"-M", "--mode", "", on_mode_arg},
731790
{"", "--type", "", on_type_arg},
@@ -739,6 +798,9 @@ void parse_args(int argc, const char** argv, SDParams& params) {
739798
{"", "--high-noise-skip-layers", "", on_high_noise_skip_layers_arg},
740799
{"-r", "--ref-image", "", on_ref_image_arg},
741800
{"-h", "--help", "", on_help_arg},
801+
{"", "--vae-tile-size", "", on_tile_size_arg},
802+
{"", "--vae-relative-tile-size", "", on_relative_tile_size_arg},
803+
{"", "--vae-tile-overlap", "", on_tile_overlap_arg},
742804
};
743805

744806
if (!parse_options(argc, argv, options)) {
@@ -1176,7 +1238,6 @@ int main(int argc, const char* argv[]) {
11761238
params.embedding_dir.c_str(),
11771239
params.stacked_id_embed_dir.c_str(),
11781240
vae_decode_only,
1179-
params.vae_tiling,
11801241
true,
11811242
params.n_threads,
11821243
params.wtype,
@@ -1225,6 +1286,7 @@ int main(int argc, const char* argv[]) {
12251286
params.style_ratio,
12261287
params.normalize_input,
12271288
params.input_id_images_path.c_str(),
1289+
params.vae_tiling_params,
12281290
};
12291291

12301292
results = generate_image(sd_ctx, &img_gen_params);

0 commit comments

Comments
 (0)