Skip to content

Commit 8b5581c

Browse files
committed
Wan MoE: Automatic expert routing based on timestep boundary
1 parent 797d2f9 commit 8b5581c

File tree

3 files changed

+27
-9
lines changed

3 files changed

+27
-9
lines changed

examples/cli/main.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,12 @@ struct SDParams {
112112
bool chroma_use_dit_mask = true;
113113
bool chroma_use_t5_mask = false;
114114
int chroma_t5_mask_pad = 1;
115+
float boundary = 0.875;
115116

116117
SDParams() {
117118
sd_sample_params_init(&sample_params);
118119
sd_sample_params_init(&high_noise_sample_params);
120+
high_noise_sample_params.sample_steps = -1;
119121
}
120122
};
121123

@@ -240,7 +242,7 @@ void print_usage(int argc, const char* argv[]) {
240242
printf(" --high-noise-scheduler {discrete, karras, exponential, ays, gits} Denoiser sigma scheduler (default: discrete)\n");
241243
printf(" --high-noise-sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}\n");
242244
printf(" (high noise) sampling method (default: \"euler_a\")\n");
243-
printf(" --high-noise-steps STEPS (high noise) number of sample steps (default: 20)\n");
245+
printf(" --high-noise-steps STEPS (high noise) number of sample steps (default: -1 = auto)\n");
244246
printf(" SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END])\n");
245247
printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n");
246248
printf(" --style-ratio STYLE-RATIO strength for keeping input identity (default: 20)\n");
@@ -271,6 +273,8 @@ void print_usage(int argc, const char* argv[]) {
271273
printf(" --chroma-t5-mask-pad PAD_SIZE t5 mask pad size of chroma\n");
272274
printf(" --video-frames video frames (default: 1)\n");
273275
printf(" --fps fps (default: 24)\n");
276+
printf(" --moe-boundary BOUNDARY Timestep boundary for Wan2.2 MoE model. (default: 0.875)");
277+
printf(" Only enabled if `--high-noise-steps` is set to -1");
274278
printf(" -v, --verbose print extra info\n");
275279
}
276280

@@ -493,6 +497,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
493497
{"", "--strength", "", &params.strength},
494498
{"", "--style-ratio", "", &params.style_ratio},
495499
{"", "--control-strength", "", &params.control_strength},
500+
{"", "--moe-boundary", "", &params.boundary},
496501
};
497502

498503
options.bool_options = {
@@ -753,8 +758,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
753758
}
754759

755760
if (params.high_noise_sample_params.sample_steps <= 0) {
756-
fprintf(stderr, "error: the high_noise_sample_steps must be greater than 0\n");
757-
exit(1);
761+
params.high_noise_sample_params.sample_steps = -1;
758762
}
759763

760764
if (params.strength < 0.f || params.strength > 1.f) {
@@ -1181,6 +1185,7 @@ int main(int argc, const char* argv[]) {
11811185
params.strength,
11821186
params.seed,
11831187
params.video_frames,
1188+
params.boundary
11841189
};
11851190

11861191
results = generate_video(sd_ctx, &vid_gen_params, &num_results);

stable-diffusion.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1727,11 +1727,13 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
17271727
memset((void*)sd_vid_gen_params, 0, sizeof(sd_vid_gen_params_t));
17281728
sd_sample_params_init(&sd_vid_gen_params->sample_params);
17291729
sd_sample_params_init(&sd_vid_gen_params->high_noise_sample_params);
1730-
sd_vid_gen_params->width = 512;
1731-
sd_vid_gen_params->height = 512;
1732-
sd_vid_gen_params->strength = 0.75f;
1733-
sd_vid_gen_params->seed = -1;
1734-
sd_vid_gen_params->video_frames = 6;
1730+
sd_vid_gen_params->high_noise_sample_params.sample_steps = -1;
1731+
sd_vid_gen_params->width = 512;
1732+
sd_vid_gen_params->height = 512;
1733+
sd_vid_gen_params->strength = 0.75f;
1734+
sd_vid_gen_params->seed = -1;
1735+
sd_vid_gen_params->video_frames = 6;
1736+
sd_vid_gen_params->boundary = 0.875f;
17351737
}
17361738

17371739
struct sd_ctx_t {
@@ -2381,7 +2383,17 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
23812383
high_noise_sample_steps = sd_vid_gen_params->high_noise_sample_params.sample_steps;
23822384
}
23832385

2384-
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps + high_noise_sample_steps);
2386+
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps);
2387+
2388+
if(high_noise_sample_steps < 0) {
2389+
// timesteps∝sigmas for Flow models (like wan2.2 a14b)
2390+
for (size_t i = 0; i < sigmas.size(); ++i) {
2391+
if (sigmas[i] < sd_vid_gen_params->boundary) {
2392+
high_noise_sample_steps = i;
2393+
break;
2394+
}
2395+
}
2396+
LOG_DEBUG("Switching from high noise model at step %d", high_noise_sample_steps);
23852397

23862398
struct ggml_init_params params;
23872399
params.mem_size = static_cast<size_t>(100 * 1024) * 1024; // 100 MB

stable-diffusion.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ typedef struct {
207207
float strength;
208208
int64_t seed;
209209
int video_frames;
210+
float boundary;
210211
} sd_vid_gen_params_t;
211212

212213
typedef struct sd_ctx_t sd_ctx_t;

0 commit comments

Comments
 (0)