@@ -615,6 +615,7 @@ namespace Flux {
615615 bool guidance_embed = true ;
616616 bool flash_attn = true ;
617617 bool is_chroma = false ;
618+ SDVersion version = VERSION_FLUX;
618619 };
619620
620621 struct Flux : public GGMLBlock {
@@ -720,6 +721,7 @@ namespace Flux {
720721 auto final_layer = std::dynamic_pointer_cast<LastLayer>(blocks[" final_layer" ]);
721722
722723 img = img_in->forward (ctx, img);
724+
723725 struct ggml_tensor * vec;
724726 struct ggml_tensor * txt_img_mask = NULL ;
725727 if (params.is_chroma ) {
@@ -849,14 +851,36 @@ namespace Flux {
849851 auto img = process_img (ctx, x);
850852 uint64_t img_tokens = img->ne [1 ];
851853
852- if (c_concat != NULL ) {
854+ if (params.version == VERSION_FLUX_FILL) {
855+ GGML_ASSERT (c_concat != NULL );
853856 ggml_tensor* masked = ggml_view_4d (ctx, c_concat, c_concat->ne [0 ], c_concat->ne [1 ], C, 1 , c_concat->nb [1 ], c_concat->nb [2 ], c_concat->nb [3 ], 0 );
854857 ggml_tensor* mask = ggml_view_4d (ctx, c_concat, c_concat->ne [0 ], c_concat->ne [1 ], 8 * 8 , 1 , c_concat->nb [1 ], c_concat->nb [2 ], c_concat->nb [3 ], c_concat->nb [2 ] * C);
855858
856859 masked = process_img (ctx, masked);
857860 mask = process_img (ctx, mask);
858861
859862 img = ggml_concat (ctx, img, ggml_concat (ctx, masked, mask, 0 ), 0 );
863+ } else if (params.version == VERSION_FLEX_2) {
864+ GGML_ASSERT (c_concat != NULL );
865+ ggml_tensor* masked = ggml_view_4d (ctx, c_concat, c_concat->ne [0 ], c_concat->ne [1 ], C, 1 , c_concat->nb [1 ], c_concat->nb [2 ], c_concat->nb [3 ], 0 );
866+ ggml_tensor* mask = ggml_view_4d (ctx, c_concat, c_concat->ne [0 ], c_concat->ne [1 ], 1 , 1 , c_concat->nb [1 ], c_concat->nb [2 ], c_concat->nb [3 ], c_concat->nb [2 ] * C);
867+ ggml_tensor* control = ggml_view_4d (ctx, c_concat, c_concat->ne [0 ], c_concat->ne [1 ], C, 1 , c_concat->nb [1 ], c_concat->nb [2 ], c_concat->nb [3 ], c_concat->nb [2 ] * (C + 1 ));
868+
869+ masked = ggml_pad (ctx, masked, pad_w, pad_h, 0 , 0 );
870+ mask = ggml_pad (ctx, mask, pad_w, pad_h, 0 , 0 );
871+ control = ggml_pad (ctx, control, pad_w, pad_h, 0 , 0 );
872+
873+ masked = patchify (ctx, masked, patch_size);
874+ mask = patchify (ctx, mask, patch_size);
875+ control = patchify (ctx, control, patch_size);
876+
877+ img = ggml_concat (ctx, img, ggml_concat (ctx, ggml_concat (ctx, masked, mask, 0 ), control, 0 ), 0 );
878+ } else if (params.version == VERSION_FLUX_CONTROLS) {
879+ GGML_ASSERT (c_concat != NULL );
880+
881+ ggml_tensor* control = ggml_pad (ctx, c_concat, pad_w, pad_h, 0 , 0 );
882+ control = patchify (ctx, control, patch_size);
883+ img = ggml_concat (ctx, img, control, 0 );
860884 }
861885
862886 if (ref_latents.size () > 0 ) {
@@ -867,6 +891,7 @@ namespace Flux {
867891 }
868892
869893 auto out = forward_orig (ctx, backend, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size]
894+
870895 if (out->ne [1 ] > img_tokens) {
871896 out = ggml_cont (ctx, ggml_permute (ctx, out, 0 , 2 , 1 , 3 )); // [num_tokens, N, C * patch_size * patch_size]
872897 out = ggml_view_3d (ctx, out, out->ne [0 ], out->ne [1 ], img_tokens, out->nb [1 ], out->nb [2 ], 0 );
@@ -896,13 +921,18 @@ namespace Flux {
896921 SDVersion version = VERSION_FLUX,
897922 bool flash_attn = false ,
898923 bool use_mask = false )
899- : GGMLRunner(backend, offload_params_to_cpu), use_mask(use_mask) {
924+ : GGMLRunner(backend, offload_params_to_cpu), version(version), use_mask(use_mask) {
925+ flux_params.version = version;
900926 flux_params.flash_attn = flash_attn;
901927 flux_params.guidance_embed = false ;
902928 flux_params.depth = 0 ;
903929 flux_params.depth_single_blocks = 0 ;
904930 if (version == VERSION_FLUX_FILL) {
905931 flux_params.in_channels = 384 ;
932+ } else if (version == VERSION_FLUX_CONTROLS) {
933+ flux_params.in_channels = 128 ;
934+ } else if (version == VERSION_FLEX_2) {
935+ flux_params.in_channels = 196 ;
906936 }
907937 for (auto pair : tensor_types) {
908938 std::string tensor_name = pair.first ;
0 commit comments