Skip to content

Commit 46698f2

Browse files
stduhpf authored and leejet committed
Add flow shift parameter (for SD3 and Wan)
1 parent 21ce9fe commit 46698f2

File tree

4 files changed

+22
-5
lines changed

4 files changed

+22
-5
lines changed

denoiser.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -382,7 +382,7 @@ struct DiscreteFlowDenoiser : public Denoiser {
382382

383383
float sigma_data = 1.0f;
384384

385-
DiscreteFlowDenoiser() {
385+
DiscreteFlowDenoiser(float shift = 3.0f) : shift(shift) {
386386
set_parameters();
387387
}
388388

examples/cli/main.cpp

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -115,6 +115,7 @@ struct SDParams {
115115
bool chroma_use_dit_mask = true;
116116
bool chroma_use_t5_mask = false;
117117
int chroma_t5_mask_pad = 1;
118+
float flow_shift = INFINITY;
118119

119120
SDParams() {
120121
sd_sample_params_init(&sample_params);
@@ -278,8 +279,9 @@ void print_usage(int argc, const char* argv[]) {
278279
printf(" --chroma-t5-mask-pad PAD_SIZE t5 mask pad size of chroma\n");
279280
printf(" --video-frames video frames (default: 1)\n");
280281
printf(" --fps fps (default: 24)\n");
281-
printf(" --moe-boundary BOUNDARY Timestep boundary for Wan2.2 MoE model. (default: 0.875)\n");
282-
printf(" Only enabled if `--high-noise-steps` is set to -1\n");
282+
printf(" --moe-boundary BOUNDARY timestep boundary for Wan2.2 MoE model. (default: 0.875)\n");
283+
printf(" only enabled if `--high-noise-steps` is set to -1\n");
284+
printf(" --flow-shift SHIFT shift value for Flow models like SD3.x or WAN (default: auto)\n");
283285
printf(" -v, --verbose print extra info\n");
284286
}
285287

@@ -514,6 +516,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
514516
{"", "--style-ratio", "", &params.style_ratio},
515517
{"", "--control-strength", "", &params.control_strength},
516518
{"", "--moe-boundary", "", &params.moe_boundary},
519+
{"", "--flow-shift", "", &params.flow_shift},
517520
};
518521

519522
options.bool_options = {
@@ -1181,6 +1184,7 @@ int main(int argc, const char* argv[]) {
11811184
params.chroma_use_dit_mask,
11821185
params.chroma_use_t5_mask,
11831186
params.chroma_t5_mask_pad,
1187+
params.flow_shift,
11841188
};
11851189

11861190
sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);

stable-diffusion.cpp

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -681,7 +681,11 @@ class StableDiffusionGGML {
681681

682682
if (sd_version_is_sd3(version)) {
683683
LOG_INFO("running in FLOW mode");
684-
denoiser = std::make_shared<DiscreteFlowDenoiser>();
684+
float shift = sd_ctx_params->flow_shift;
685+
if( shift == INFINITY){
686+
shift = 3.0;
687+
}
688+
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
685689
} else if (sd_version_is_flux(version)) {
686690
LOG_INFO("running in Flux FLOW mode");
687691
float shift = 1.0f; // TODO: validate
@@ -694,7 +698,14 @@ class StableDiffusionGGML {
694698
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
695699
} else if (sd_version_is_wan(version)) {
696700
LOG_INFO("running in FLOW mode");
697-
denoiser = std::make_shared<DiscreteFlowDenoiser>();
701+
float shift = sd_ctx_params->flow_shift;
702+
if(shift == INFINITY) {
703+
shift = 5.0;
704+
if (version == VERSION_WAN2){
705+
shift = 12.0;
706+
}
707+
}
708+
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
698709
} else if (is_using_v_parameterization) {
699710
LOG_INFO("running in v-prediction mode");
700711
denoiser = std::make_shared<CompVisVDenoiser>();
@@ -1553,6 +1564,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
15531564
sd_ctx_params->chroma_use_dit_mask = true;
15541565
sd_ctx_params->chroma_use_t5_mask = false;
15551566
sd_ctx_params->chroma_t5_mask_pad = 1;
1567+
sd_ctx_params->flow_shift = INFINITY;
15561568
}
15571569

15581570
char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {

stable-diffusion.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -142,6 +142,7 @@ typedef struct {
142142
bool chroma_use_dit_mask;
143143
bool chroma_use_t5_mask;
144144
int chroma_t5_mask_pad;
145+
float flow_shift;
145146
} sd_ctx_params_t;
146147

147148
typedef struct {

0 commit comments

Comments
 (0)