Skip to content

Commit 11f436c

Browse files
authored
feat: add support for Flux Controls and Flex.2 (#692)
1 parent 35843c7 commit 11f436c

File tree

7 files changed: 156 additions (+156) and 34 deletions (-34).

examples/cli/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1246,7 +1246,7 @@ int main(int argc, const char* argv[]) {
12461246
}
12471247
}
12481248

1249-
if (params.control_net_path.size() > 0 && params.control_image_path.size() > 0) {
1249+
if (params.control_image_path.size() > 0) {
12501250
int width = 0;
12511251
int height = 0;
12521252
control_image.data = load_image(params.control_image_path.c_str(), width, height, params.width, params.height);

flux.hpp

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,7 @@ namespace Flux {
615615
bool guidance_embed = true;
616616
bool flash_attn = true;
617617
bool is_chroma = false;
618+
SDVersion version = VERSION_FLUX;
618619
};
619620

620621
struct Flux : public GGMLBlock {
@@ -720,6 +721,7 @@ namespace Flux {
720721
auto final_layer = std::dynamic_pointer_cast<LastLayer>(blocks["final_layer"]);
721722

722723
img = img_in->forward(ctx, img);
724+
723725
struct ggml_tensor* vec;
724726
struct ggml_tensor* txt_img_mask = NULL;
725727
if (params.is_chroma) {
@@ -849,14 +851,36 @@ namespace Flux {
849851
auto img = process_img(ctx, x);
850852
uint64_t img_tokens = img->ne[1];
851853

852-
if (c_concat != NULL) {
854+
if (params.version == VERSION_FLUX_FILL) {
855+
GGML_ASSERT(c_concat != NULL);
853856
ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
854857
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
855858

856859
masked = process_img(ctx, masked);
857860
mask = process_img(ctx, mask);
858861

859862
img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0);
863+
} else if (params.version == VERSION_FLEX_2) {
864+
GGML_ASSERT(c_concat != NULL);
865+
ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
866+
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
867+
ggml_tensor* control = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
868+
869+
masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0);
870+
mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0);
871+
control = ggml_pad(ctx, control, pad_w, pad_h, 0, 0);
872+
873+
masked = patchify(ctx, masked, patch_size);
874+
mask = patchify(ctx, mask, patch_size);
875+
control = patchify(ctx, control, patch_size);
876+
877+
img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
878+
} else if (params.version == VERSION_FLUX_CONTROLS) {
879+
GGML_ASSERT(c_concat != NULL);
880+
881+
ggml_tensor* control = ggml_pad(ctx, c_concat, pad_w, pad_h, 0, 0);
882+
control = patchify(ctx, control, patch_size);
883+
img = ggml_concat(ctx, img, control, 0);
860884
}
861885

862886
if (ref_latents.size() > 0) {
@@ -867,6 +891,7 @@ namespace Flux {
867891
}
868892

869893
auto out = forward_orig(ctx, backend, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size]
894+
870895
if (out->ne[1] > img_tokens) {
871896
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
872897
out = ggml_view_3d(ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
@@ -896,13 +921,18 @@ namespace Flux {
896921
SDVersion version = VERSION_FLUX,
897922
bool flash_attn = false,
898923
bool use_mask = false)
899-
: GGMLRunner(backend, offload_params_to_cpu), use_mask(use_mask) {
924+
: GGMLRunner(backend, offload_params_to_cpu), version(version), use_mask(use_mask) {
925+
flux_params.version = version;
900926
flux_params.flash_attn = flash_attn;
901927
flux_params.guidance_embed = false;
902928
flux_params.depth = 0;
903929
flux_params.depth_single_blocks = 0;
904930
if (version == VERSION_FLUX_FILL) {
905931
flux_params.in_channels = 384;
932+
} else if (version == VERSION_FLUX_CONTROLS) {
933+
flux_params.in_channels = 128;
934+
} else if (version == VERSION_FLEX_2) {
935+
flux_params.in_channels = 196;
906936
}
907937
for (auto pair : tensor_types) {
908938
std::string tensor_name = pair.first;

ggml_extend.hpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -428,18 +428,24 @@ __STATIC_INLINE__ void sd_image_to_tensor(sd_image_t image,
428428

429429
__STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
430430
struct ggml_tensor* mask,
431-
struct ggml_tensor* output) {
431+
struct ggml_tensor* output,
432+
float masked_value = 0.5f) {
432433
int64_t width = output->ne[0];
433434
int64_t height = output->ne[1];
434435
int64_t channels = output->ne[2];
436+
float rescale_mx = mask->ne[0] / output->ne[0];
437+
float rescale_my = mask->ne[1] / output->ne[1];
435438
GGML_ASSERT(output->type == GGML_TYPE_F32);
436439
for (int ix = 0; ix < width; ix++) {
437440
for (int iy = 0; iy < height; iy++) {
438-
float m = ggml_tensor_get_f32(mask, ix, iy);
441+
int mx = (int)(ix * rescale_mx);
442+
int my = (int)(iy * rescale_my);
443+
float m = ggml_tensor_get_f32(mask, mx, my);
439444
m = round(m); // inpaint models need binary masks
440-
ggml_tensor_set_f32(mask, m, ix, iy);
445+
ggml_tensor_set_f32(mask, m, mx, my);
441446
for (int k = 0; k < channels; k++) {
442-
float value = (1 - m) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5;
447+
float value = ggml_tensor_get_f32(image_data, ix, iy, k);
448+
value = (1 - m) * (value - masked_value) + masked_value;
443449
ggml_tensor_set_f32(output, value, ix, iy, k);
444450
}
445451
}

model.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1803,10 +1803,15 @@ SDVersion ModelLoader::get_sd_version() {
18031803
}
18041804

18051805
if (is_flux) {
1806-
is_inpaint = input_block_weight.ne[0] == 384;
1807-
if (is_inpaint) {
1806+
if (input_block_weight.ne[0] == 384) {
18081807
return VERSION_FLUX_FILL;
18091808
}
1809+
if (input_block_weight.ne[0] == 128) {
1810+
return VERSION_FLUX_CONTROLS;
1811+
}
1812+
if (input_block_weight.ne[0] == 196) {
1813+
return VERSION_FLEX_2;
1814+
}
18101815
return VERSION_FLUX;
18111816
}
18121817

model.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ enum SDVersion {
3131
VERSION_SD3,
3232
VERSION_FLUX,
3333
VERSION_FLUX_FILL,
34+
VERSION_FLUX_CONTROLS,
35+
VERSION_FLEX_2,
3436
VERSION_WAN2,
3537
VERSION_WAN2_2_I2V,
3638
VERSION_WAN2_2_TI2V,
@@ -66,7 +68,7 @@ static inline bool sd_version_is_sd3(SDVersion version) {
6668
}
6769

6870
static inline bool sd_version_is_flux(SDVersion version) {
69-
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) {
71+
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2) {
7072
return true;
7173
}
7274
return false;
@@ -80,7 +82,7 @@ static inline bool sd_version_is_wan(SDVersion version) {
8082
}
8183

8284
static inline bool sd_version_is_inpaint(SDVersion version) {
83-
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) {
85+
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) {
8486
return true;
8587
}
8688
return false;
@@ -97,8 +99,12 @@ static inline bool sd_version_is_unet_edit(SDVersion version) {
9799
return version == VERSION_SD1_PIX2PIX || version == VERSION_SDXL_PIX2PIX;
98100
}
99101

102+
// Returns true when `version` is VERSION_FLUX_CONTROLS or VERSION_FLEX_2,
// the two variants this predicate groups together as "control" models;
// every other SDVersion value yields false.
// (The bare `NNN+` lines below are diff line-number markers from the commit
// view, not part of the function body.)
static inline bool sd_version_is_control(SDVersion version) {
103+
return version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2;
104+
}
105+
100106
// Diff view of sd_version_is_inpaint_or_unet_edit(): the `return` preceded by
// the `101-` marker is the line removed by this commit, and the `return`
// preceded by the `107+` marker is its replacement.
// After the change the predicate is true when any of sd_version_is_unet_edit(),
// sd_version_is_inpaint() or sd_version_is_control() is true for `version`.
// NOTE(review): the function name still mentions only "inpaint or unet edit"
// although control versions are now accepted as well — confirm callers expect
// the wider set.
static bool sd_version_is_inpaint_or_unet_edit(SDVersion version) {
101-
return sd_version_is_unet_edit(version) || sd_version_is_inpaint(version);
107+
return sd_version_is_unet_edit(version) || sd_version_is_inpaint(version) || sd_version_is_control(version);
102108
}
103109

104110
enum PMVersion {

0 commit comments

Comments
 (0)