@@ -490,6 +490,7 @@ namespace Flux {
490490
491491 struct FluxParams {
492492 int64_t in_channels = 64 ;
493+ int64_t out_channels = 64 ;
493494 int64_t vec_in_dim = 768 ;
494495 int64_t context_in_dim = 4096 ;
495496 int64_t hidden_size = 3072 ;
@@ -642,8 +643,7 @@ namespace Flux {
642643 Flux () {}
643644 Flux (FluxParams params)
644645 : params(params) {
645- int64_t out_channels = params.in_channels ;
646- int64_t pe_dim = params.hidden_size / params.num_heads ;
646+ int64_t pe_dim = params.hidden_size / params.num_heads ;
647647
648648 blocks[" img_in" ] = std::shared_ptr<GGMLBlock>(new Linear (params.in_channels , params.hidden_size , true ));
649649 blocks[" time_in" ] = std::shared_ptr<GGMLBlock>(new MLPEmbedder (256 , params.hidden_size ));
@@ -669,7 +669,7 @@ namespace Flux {
669669 params.flash_attn ));
670670 }
671671
672- blocks[" final_layer" ] = std::shared_ptr<GGMLBlock>(new LastLayer (params.hidden_size , 1 , out_channels));
672+ blocks[" final_layer" ] = std::shared_ptr<GGMLBlock>(new LastLayer (params.hidden_size , 1 , params. out_channels ));
673673 }
674674
675675 struct ggml_tensor * patchify (struct ggml_context * ctx,
@@ -789,6 +789,7 @@ namespace Flux {
789789 struct ggml_tensor * x,
790790 struct ggml_tensor * timestep,
791791 struct ggml_tensor * context,
792+ struct ggml_tensor * c_concat,
792793 struct ggml_tensor * y,
793794 struct ggml_tensor * guidance,
794795 struct ggml_tensor * pe,
@@ -797,6 +798,7 @@ namespace Flux {
797798 // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
798799 // timestep: (N,) tensor of diffusion timesteps
799800 // context: (N, L, D)
801+ // c_concat: NULL, or (N, C+M, H, W) for Fill
800802 // y: (N, adm_in_channels) tensor of class labels
801803 // guidance: (N,)
802804 // pe: (L, d_head/2, 2, 2)
@@ -806,6 +808,7 @@ namespace Flux {
806808
807809 int64_t W = x->ne [0 ];
808810 int64_t H = x->ne [1 ];
811+ int64_t C = x->ne [2 ];
809812 int64_t patch_size = 2 ;
810813 int pad_h = (patch_size - H % patch_size) % patch_size;
811814 int pad_w = (patch_size - W % patch_size) % patch_size;
@@ -814,6 +817,19 @@ namespace Flux {
814817 // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
815818 auto img = patchify (ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]
816819
820+ if (c_concat != NULL ) {
821+ ggml_tensor* masked = ggml_view_4d (ctx, c_concat, c_concat->ne [0 ], c_concat->ne [1 ], C, 1 , c_concat->nb [1 ], c_concat->nb [2 ], c_concat->nb [3 ], 0 );
822+ ggml_tensor* mask = ggml_view_4d (ctx, c_concat, c_concat->ne [0 ], c_concat->ne [1 ], 8 * 8 , 1 , c_concat->nb [1 ], c_concat->nb [2 ], c_concat->nb [3 ], c_concat->nb [2 ] * C);
823+
824+ masked = ggml_pad (ctx, masked, pad_w, pad_h, 0 , 0 );
825+ mask = ggml_pad (ctx, mask, pad_w, pad_h, 0 , 0 );
826+
827+ masked = patchify (ctx, masked, patch_size);
828+ mask = patchify (ctx, mask, patch_size);
829+
830+ img = ggml_concat (ctx, img, ggml_concat (ctx, masked, mask, 0 ), 0 );
831+ }
832+
817833 auto out = forward_orig (ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size]
818834
819835 // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
@@ -834,12 +850,16 @@ namespace Flux {
834850 FluxRunner (ggml_backend_t backend,
835851 std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
836852 const std::string prefix = " " ,
853+ SDVersion version = VERSION_FLUX,
837854 bool flash_attn = false )
838855 : GGMLRunner(backend) {
839856 flux_params.flash_attn = flash_attn;
840857 flux_params.guidance_embed = false ;
841858 flux_params.depth = 0 ;
842859 flux_params.depth_single_blocks = 0 ;
860+ if (version == VERSION_FLUX_FILL) {
861+ flux_params.in_channels = 384 ;
862+ }
843863 for (auto pair : tensor_types) {
844864 std::string tensor_name = pair.first ;
845865 if (tensor_name.find (" model.diffusion_model." ) == std::string::npos)
@@ -886,14 +906,18 @@ namespace Flux {
886906 struct ggml_cgraph * build_graph (struct ggml_tensor * x,
887907 struct ggml_tensor * timesteps,
888908 struct ggml_tensor * context,
909+ struct ggml_tensor * c_concat,
889910 struct ggml_tensor * y,
890911 struct ggml_tensor * guidance,
891912 std::vector<int > skip_layers = std::vector<int >()) {
892913 GGML_ASSERT (x->ne [3 ] == 1 );
893914 struct ggml_cgraph * gf = ggml_new_graph_custom (compute_ctx, FLUX_GRAPH_SIZE, false );
894915
895- x = to_backend (x);
896- context = to_backend (context);
916+ x = to_backend (x);
917+ context = to_backend (context);
918+ if (c_concat != NULL ) {
919+ c_concat = to_backend (c_concat);
920+ }
897921 y = to_backend (y);
898922 timesteps = to_backend (timesteps);
899923 if (flux_params.guidance_embed ) {
@@ -913,6 +937,7 @@ namespace Flux {
913937 x,
914938 timesteps,
915939 context,
940+ c_concat,
916941 y,
917942 guidance,
918943 pe,
@@ -927,6 +952,7 @@ namespace Flux {
927952 struct ggml_tensor * x,
928953 struct ggml_tensor * timesteps,
929954 struct ggml_tensor * context,
955+ struct ggml_tensor * c_concat,
930956 struct ggml_tensor * y,
931957 struct ggml_tensor * guidance,
932958 struct ggml_tensor ** output = NULL ,
@@ -938,7 +964,7 @@ namespace Flux {
938964 // y: [N, adm_in_channels] or [1, adm_in_channels]
939965 // guidance: [N, ]
940966 auto get_graph = [&]() -> struct ggml_cgraph * {
941- return build_graph(x, timesteps, context, y, guidance, skip_layers);
967+ return build_graph(x, timesteps, context, c_concat, y, guidance, skip_layers);
942968 };
943969
944970 GGMLRunner::compute (get_graph, n_threads, false , output, output_ctx);
@@ -978,7 +1004,7 @@ namespace Flux {
9781004 struct ggml_tensor * out = NULL ;
9791005
9801006 int t0 = ggml_time_ms ();
981- compute (8 , x, timesteps, context, y, guidance, &out, work_ctx);
1007+ compute (8 , x, timesteps, context, NULL , y, guidance, &out, work_ctx);
9821008 int t1 = ggml_time_ms ();
9831009
9841010 print_ggml_tensor (out);
0 commit comments