Skip to content

Commit 21ce9fe

Browse files
stduhpf and leejet authored
feat: add support for timestep boundary based automatic expert routing in Wan MoE (#779)
* Wan MoE: Automatic expert routing based on timestep boundary
* unify code style and fix some issues

---------

Co-authored-by: leejet <[email protected]>
1 parent cb1d975 commit 21ce9fe

File tree

3 files changed

+39
-12
lines changed

3 files changed

+39
-12
lines changed

examples/cli/main.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ struct SDParams {
8989
std::vector<int> high_noise_skip_layers = {7, 8, 9};
9090
sd_sample_params_t high_noise_sample_params;
9191

92+
float moe_boundary = 0.875f;
93+
9294
int video_frames = 1;
9395
int fps = 16;
9496

@@ -117,6 +119,7 @@ struct SDParams {
117119
SDParams() {
118120
sd_sample_params_init(&sample_params);
119121
sd_sample_params_init(&high_noise_sample_params);
122+
high_noise_sample_params.sample_steps = -1;
120123
}
121124
};
122125

@@ -167,6 +170,7 @@ void print_params(SDParams params) {
167170
printf(" height: %d\n", params.height);
168171
printf(" sample_params: %s\n", SAFE_STR(sample_params_str));
169172
printf(" high_noise_sample_params: %s\n", SAFE_STR(high_noise_sample_params_str));
173+
printf(" moe_boundary: %.3f\n", params.moe_boundary);
170174
printf(" strength(img2img): %.2f\n", params.strength);
171175
printf(" rng: %s\n", sd_rng_type_name(params.rng_type));
172176
printf(" seed: %ld\n", params.seed);
@@ -243,7 +247,7 @@ void print_usage(int argc, const char* argv[]) {
243247
printf(" --high-noise-scheduler {discrete, karras, exponential, ays, gits} Denoiser sigma scheduler (default: discrete)\n");
244248
printf(" --high-noise-sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}\n");
245249
printf(" (high noise) sampling method (default: \"euler_a\")\n");
246-
printf(" --high-noise-steps STEPS (high noise) number of sample steps (default: 20)\n");
250+
printf(" --high-noise-steps STEPS (high noise) number of sample steps (default: -1 = auto)\n");
247251
printf(" SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END])\n");
248252
printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n");
249253
printf(" --style-ratio STYLE-RATIO strength for keeping input identity (default: 20)\n");
@@ -274,6 +278,8 @@ void print_usage(int argc, const char* argv[]) {
274278
printf(" --chroma-t5-mask-pad PAD_SIZE t5 mask pad size of chroma\n");
275279
printf(" --video-frames video frames (default: 1)\n");
276280
printf(" --fps fps (default: 24)\n");
281+
printf(" --moe-boundary BOUNDARY Timestep boundary for Wan2.2 MoE model. (default: 0.875)\n");
282+
printf(" Only enabled if `--high-noise-steps` is set to -1\n");
277283
printf(" -v, --verbose print extra info\n");
278284
}
279285

@@ -362,7 +368,7 @@ bool parse_options(int argc, const char** argv, ArgOptions& options) {
362368
std::string arg;
363369
for (int i = 1; i < argc; i++) {
364370
bool found_arg = false;
365-
arg = argv[i];
371+
arg = argv[i];
366372

367373
for (auto& option : options.string_options) {
368374
if ((option.short_name.size() > 0 && arg == option.short_name) || (option.long_name.size() > 0 && arg == option.long_name)) {
@@ -423,7 +429,7 @@ bool parse_options(int argc, const char** argv, ArgOptions& options) {
423429
for (auto& option : options.manual_options) {
424430
if ((option.short_name.size() > 0 && arg == option.short_name) || (option.long_name.size() > 0 && arg == option.long_name)) {
425431
found_arg = true;
426-
int ret = option.cb(argc, argv, i);
432+
int ret = option.cb(argc, argv, i);
427433
if (ret < 0) {
428434
invalid_arg = true;
429435
break;
@@ -435,7 +441,7 @@ bool parse_options(int argc, const char** argv, ArgOptions& options) {
435441
break;
436442
}
437443
if (!found_arg) {
438-
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
444+
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
439445
return false;
440446
}
441447
}
@@ -507,6 +513,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
507513
{"", "--strength", "", &params.strength},
508514
{"", "--style-ratio", "", &params.style_ratio},
509515
{"", "--control-strength", "", &params.control_strength},
516+
{"", "--moe-boundary", "", &params.moe_boundary},
510517
};
511518

512519
options.bool_options = {
@@ -767,8 +774,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
767774
}
768775

769776
if (params.high_noise_sample_params.sample_steps <= 0) {
770-
fprintf(stderr, "error: the high_noise_sample_steps must be greater than 0\n");
771-
exit(1);
777+
params.high_noise_sample_params.sample_steps = -1;
772778
}
773779

774780
if (params.strength < 0.f || params.strength > 1.f) {
@@ -1222,6 +1228,7 @@ int main(int argc, const char* argv[]) {
12221228
params.height,
12231229
params.sample_params,
12241230
params.high_noise_sample_params,
1231+
params.moe_boundary,
12251232
params.strength,
12261233
params.seed,
12271234
params.video_frames,

stable-diffusion.cpp

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1727,11 +1727,13 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
17271727
memset((void*)sd_vid_gen_params, 0, sizeof(sd_vid_gen_params_t));
17281728
sd_sample_params_init(&sd_vid_gen_params->sample_params);
17291729
sd_sample_params_init(&sd_vid_gen_params->high_noise_sample_params);
1730-
sd_vid_gen_params->width = 512;
1731-
sd_vid_gen_params->height = 512;
1732-
sd_vid_gen_params->strength = 0.75f;
1733-
sd_vid_gen_params->seed = -1;
1734-
sd_vid_gen_params->video_frames = 6;
1730+
sd_vid_gen_params->high_noise_sample_params.sample_steps = -1;
1731+
sd_vid_gen_params->width = 512;
1732+
sd_vid_gen_params->height = 512;
1733+
sd_vid_gen_params->strength = 0.75f;
1734+
sd_vid_gen_params->seed = -1;
1735+
sd_vid_gen_params->video_frames = 6;
1736+
sd_vid_gen_params->moe_boundary = 0.875f;
17351737
}
17361738

17371739
struct sd_ctx_t {
@@ -2381,7 +2383,24 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
23812383
high_noise_sample_steps = sd_vid_gen_params->high_noise_sample_params.sample_steps;
23822384
}
23832385

2384-
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps + high_noise_sample_steps);
2386+
int total_steps = sample_steps;
2387+
2388+
if (high_noise_sample_steps > 0) {
2389+
total_steps += high_noise_sample_steps;
2390+
}
2391+
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps);
2392+
2393+
if (high_noise_sample_steps < 0) {
2394+
// timesteps ∝ sigmas for Flow models (like wan2.2 a14b)
2395+
for (size_t i = 0; i < sigmas.size(); ++i) {
2396+
if (sigmas[i] < sd_vid_gen_params->moe_boundary) {
2397+
high_noise_sample_steps = i;
2398+
break;
2399+
}
2400+
}
2401+
LOG_DEBUG("switching from high noise model at step %d", high_noise_sample_steps);
2402+
sample_steps = total_steps - high_noise_sample_steps;
2403+
}
23852404

23862405
struct ggml_init_params params;
23872406
params.mem_size = static_cast<size_t>(200 * 1024) * 1024; // 200 MB

stable-diffusion.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ typedef struct {
205205
int height;
206206
sd_sample_params_t sample_params;
207207
sd_sample_params_t high_noise_sample_params;
208+
float moe_boundary;
208209
float strength;
209210
int64_t seed;
210211
int video_frames;

0 commit comments

Comments
 (0)