From 5bf5c1b52e570b726bab32b266623f98012c8c3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Fri, 29 Aug 2025 22:33:37 +0200 Subject: [PATCH 1/2] Wan MoE: Automatic expert routing based on timestep boundary --- examples/cli/main.cpp | 11 ++++++++--- stable-diffusion.cpp | 31 +++++++++++++++++++++++++------ stable-diffusion.h | 1 + 3 files changed, 34 insertions(+), 9 deletions(-) diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 0d08ccbbb..95676fc72 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -113,10 +113,12 @@ struct SDParams { bool chroma_use_dit_mask = true; bool chroma_use_t5_mask = false; int chroma_t5_mask_pad = 1; + float boundary = 0.875; SDParams() { sd_sample_params_init(&sample_params); sd_sample_params_init(&high_noise_sample_params); + high_noise_sample_params.sample_steps = -1; } }; @@ -243,7 +245,7 @@ void print_usage(int argc, const char* argv[]) { printf(" --high-noise-scheduler {discrete, karras, exponential, ays, gits} Denoiser sigma scheduler (default: discrete)\n"); printf(" --high-noise-sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}\n"); printf(" (high noise) sampling method (default: \"euler_a\")\n"); - printf(" --high-noise-steps STEPS (high noise) number of sample steps (default: 20)\n"); + printf(" --high-noise-steps STEPS (high noise) number of sample steps (default: -1 = auto)\n"); printf(" SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END])\n"); printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n"); printf(" --style-ratio STYLE-RATIO strength for keeping input identity (default: 20)\n"); @@ -274,6 +276,8 @@ void print_usage(int argc, const char* argv[]) { printf(" --chroma-t5-mask-pad PAD_SIZE t5 mask pad size of chroma\n"); printf(" --video-frames video frames (default: 1)\n"); printf(" --fps fps (default: 24)\n"); + printf(" --moe-boundary BOUNDARY Timestep boundary for Wan2.2 MoE model. (default: 0.875)"); + printf(" Only enabled if `--high-noise-steps` is set to -1"); printf(" -v, --verbose print extra info\n"); } @@ -507,6 +511,7 @@ void parse_args(int argc, const char** argv, SDParams& params) { {"", "--strength", "", ¶ms.strength}, {"", "--style-ratio", "", ¶ms.style_ratio}, {"", "--control-strength", "", ¶ms.control_strength}, + {"", "--moe-boundary", "", ¶ms.boundary}, }; options.bool_options = { @@ -767,8 +772,7 @@ void parse_args(int argc, const char** argv, SDParams& params) { } if (params.high_noise_sample_params.sample_steps <= 0) { - fprintf(stderr, "error: the high_noise_sample_steps must be greater than 0\n"); - exit(1); + params.high_noise_sample_params.sample_steps = -1; } if (params.strength < 0.f || params.strength > 1.f) { @@ -1225,6 +1229,7 @@ int main(int argc, const char* argv[]) { params.strength, params.seed, params.video_frames, + params.boundary }; results = generate_video(sd_ctx, &vid_gen_params, &num_results); diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 1f69e2be2..7c3ce8a7f 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -1727,11 +1727,13 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) { memset((void*)sd_vid_gen_params, 0, sizeof(sd_vid_gen_params_t)); sd_sample_params_init(&sd_vid_gen_params->sample_params); sd_sample_params_init(&sd_vid_gen_params->high_noise_sample_params); - sd_vid_gen_params->width = 512; - sd_vid_gen_params->height = 512; - sd_vid_gen_params->strength = 0.75f; - sd_vid_gen_params->seed = -1; - sd_vid_gen_params->video_frames = 6; + sd_vid_gen_params->high_noise_sample_params.sample_steps = -1; + sd_vid_gen_params->width = 512; + sd_vid_gen_params->height = 512; + sd_vid_gen_params->strength = 0.75f; + sd_vid_gen_params->seed = -1; + sd_vid_gen_params->video_frames = 6; + sd_vid_gen_params->boundary = 0.875f; } struct sd_ctx_t { @@ -2381,7 +2383,24 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s high_noise_sample_steps = sd_vid_gen_params->high_noise_sample_params.sample_steps; } - std::vector sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps + high_noise_sample_steps); + int total_steps = sample_steps; + + if (high_noise_sample_steps > 0) { + total_steps += high_noise_sample_steps; + } + std::vector sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps); + + if(high_noise_sample_steps < 0) { + // timesteps∝sigmas for Flow models (like wan2.2 a14b) + for (size_t i = 0; i < sigmas.size(); ++i) { + if (sigmas[i] < sd_vid_gen_params->boundary) { + high_noise_sample_steps = i; + break; + } + } + LOG_DEBUG("Switching from high noise model at step %d", high_noise_sample_steps); + sample_steps = total_steps - high_noise_sample_steps; + } struct ggml_init_params params; params.mem_size = static_cast(200 * 1024) * 1024; // 200 MB diff --git a/stable-diffusion.h b/stable-diffusion.h index 52d4aa67d..31b3c06e1 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -208,6 +208,7 @@ typedef struct { float strength; int64_t seed; int video_frames; + float boundary; } sd_vid_gen_params_t; typedef struct sd_ctx_t sd_ctx_t; From 10933550830a4cc59a7cff3eeaf5f2c99fdb4279 Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 7 Sep 2025 01:39:52 +0800 Subject: [PATCH 2/2] unify code style and fix some issues --- examples/cli/main.cpp | 18 ++++++++++-------- stable-diffusion.cpp | 10 +++++----- stable-diffusion.h | 2 +- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 95676fc72..91d74f173 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -89,6 +89,8 @@ struct SDParams { std::vector high_noise_skip_layers = {7, 8, 9}; sd_sample_params_t high_noise_sample_params; + float moe_boundary = 0.875f; + int video_frames = 1; int fps = 16; @@ -113,7 +115,6 @@ struct SDParams { bool chroma_use_dit_mask = true; bool chroma_use_t5_mask = false; int chroma_t5_mask_pad = 1; - float boundary = 0.875; SDParams() { sd_sample_params_init(&sample_params); @@ -169,6 +170,7 @@ void print_params(SDParams params) { printf(" height: %d\n", params.height); printf(" sample_params: %s\n", SAFE_STR(sample_params_str)); printf(" high_noise_sample_params: %s\n", SAFE_STR(high_noise_sample_params_str)); + printf(" moe_boundary: %.3f\n", params.moe_boundary); printf(" strength(img2img): %.2f\n", params.strength); printf(" rng: %s\n", sd_rng_type_name(params.rng_type)); printf(" seed: %ld\n", params.seed); @@ -276,8 +278,8 @@ void print_usage(int argc, const char* argv[]) { printf(" --chroma-t5-mask-pad PAD_SIZE t5 mask pad size of chroma\n"); printf(" --video-frames video frames (default: 1)\n"); printf(" --fps fps (default: 24)\n"); - printf(" --moe-boundary BOUNDARY Timestep boundary for Wan2.2 MoE model. (default: 0.875)"); - printf(" Only enabled if `--high-noise-steps` is set to -1"); + printf(" --moe-boundary BOUNDARY Timestep boundary for Wan2.2 MoE model. (default: 0.875)\n"); + printf(" Only enabled if `--high-noise-steps` is set to -1\n"); printf(" -v, --verbose print extra info\n"); } @@ -366,7 +368,7 @@ bool parse_options(int argc, const char** argv, ArgOptions& options) { std::string arg; for (int i = 1; i < argc; i++) { bool found_arg = false; - arg = argv[i]; + arg = argv[i]; for (auto& option : options.string_options) { if ((option.short_name.size() > 0 && arg == option.short_name) || (option.long_name.size() > 0 && arg == option.long_name)) { @@ -427,7 +429,7 @@ bool parse_options(int argc, const char** argv, ArgOptions& options) { for (auto& option : options.manual_options) { if ((option.short_name.size() > 0 && arg == option.short_name) || (option.long_name.size() > 0 && arg == option.long_name)) { found_arg = true; - int ret = option.cb(argc, argv, i); + int ret = option.cb(argc, argv, i); if (ret < 0) { invalid_arg = true; break; @@ -439,7 +441,7 @@ bool parse_options(int argc, const char** argv, ArgOptions& options) { break; } if (!found_arg) { - fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); return false; } } @@ -511,7 +513,7 @@ void parse_args(int argc, const char** argv, SDParams& params) { {"", "--strength", "", ¶ms.strength}, {"", "--style-ratio", "", ¶ms.style_ratio}, {"", "--control-strength", "", ¶ms.control_strength}, - {"", "--moe-boundary", "", ¶ms.boundary}, + {"", "--moe-boundary", "", ¶ms.moe_boundary}, }; options.bool_options = { @@ -1226,10 +1228,10 @@ int main(int argc, const char* argv[]) { params.height, params.sample_params, params.high_noise_sample_params, + params.moe_boundary, params.strength, params.seed, params.video_frames, - params.boundary }; results = generate_video(sd_ctx, &vid_gen_params, &num_results); diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 7c3ce8a7f..db89cbb74 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -1733,7 +1733,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) { sd_vid_gen_params->strength = 0.75f; sd_vid_gen_params->seed = -1; sd_vid_gen_params->video_frames = 6; - sd_vid_gen_params->boundary = 0.875f; + sd_vid_gen_params->moe_boundary = 0.875f; } struct sd_ctx_t { @@ -2390,15 +2390,15 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s } std::vector sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps); - if(high_noise_sample_steps < 0) { - // timesteps∝sigmas for Flow models (like wan2.2 a14b) + if (high_noise_sample_steps < 0) { + // timesteps ∝ sigmas for Flow models (like wan2.2 a14b) for (size_t i = 0; i < sigmas.size(); ++i) { - if (sigmas[i] < sd_vid_gen_params->boundary) { + if (sigmas[i] < sd_vid_gen_params->moe_boundary) { high_noise_sample_steps = i; break; } } - LOG_DEBUG("Switching from high noise model at step %d", high_noise_sample_steps); + LOG_DEBUG("switching from high noise model at step %d", high_noise_sample_steps); sample_steps = total_steps - high_noise_sample_steps; } diff --git a/stable-diffusion.h b/stable-diffusion.h index 31b3c06e1..5ffe50618 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -205,10 +205,10 @@ typedef struct { int height; sd_sample_params_t sample_params; sd_sample_params_t high_noise_sample_params; + float moe_boundary; float strength; int64_t seed; int video_frames; - float boundary; } sd_vid_gen_params_t; typedef struct sd_ctx_t sd_ctx_t;