@@ -643,7 +643,7 @@ namespace Flux {
643643 Flux () {}
644644 Flux (FluxParams params)
645645 : params(params) {
646- int64_t pe_dim = params.hidden_size / params.num_heads ;
646+ int64_t pe_dim = params.hidden_size / params.num_heads ;
647647
648648 blocks[" img_in" ] = std::shared_ptr<GGMLBlock>(new Linear (params.in_channels , params.hidden_size , true ));
649649 blocks[" time_in" ] = std::shared_ptr<GGMLBlock>(new MLPEmbedder (256 , params.hidden_size ));
@@ -789,6 +789,7 @@ namespace Flux {
789789 struct ggml_tensor * x,
790790 struct ggml_tensor * timestep,
791791 struct ggml_tensor * context,
792+ struct ggml_tensor * c_concat,
792793 struct ggml_tensor * y,
793794 struct ggml_tensor * guidance,
794795 struct ggml_tensor * pe,
@@ -797,6 +798,7 @@ namespace Flux {
797798 // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
798799 // timestep: (N,) tensor of diffusion timesteps
799800 // context: (N, L, D)
801+ // c_concat: NULL, or a (N, C+M, H, W) tensor for Fill (C masked-image channels followed by M mask channels)
800802 // y: (N, adm_in_channels) tensor of class labels
801803 // guidance: (N,)
802804 // pe: (L, d_head/2, 2, 2)
@@ -806,6 +808,7 @@ namespace Flux {
806808
807809 int64_t W = x->ne [0 ];
808810 int64_t H = x->ne [1 ];
811+ int64_t C = x->ne [2 ];
809812 int64_t patch_size = 2 ;
810813 int pad_h = (patch_size - H % patch_size) % patch_size;
811814 int pad_w = (patch_size - W % patch_size) % patch_size;
@@ -814,6 +817,19 @@ namespace Flux {
814817 // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
815818 auto img = patchify (ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]
816819
820+ if (c_concat != NULL ) {
821+ ggml_tensor* masked = ggml_view_4d (ctx, c_concat, c_concat->ne [0 ], c_concat->ne [1 ], C, 1 , c_concat->nb [1 ], c_concat->nb [2 ], c_concat->nb [3 ], 0 );
822+ ggml_tensor* mask = ggml_view_4d (ctx, c_concat, c_concat->ne [0 ], c_concat->ne [1 ], 8 * 8 , 1 , c_concat->nb [1 ], c_concat->nb [2 ], c_concat->nb [3 ], c_concat->nb [2 ] * C);
823+
824+ masked = ggml_pad (ctx, masked, pad_w, pad_h, 0 , 0 );
825+ mask = ggml_pad (ctx, mask, pad_w, pad_h, 0 , 0 );
826+
827+ masked = patchify (ctx, masked, patch_size);
828+ mask = patchify (ctx, mask, patch_size);
829+
830+ img = ggml_concat (ctx, img, ggml_cont (ctx, ggml_concat (ctx, masked, mask, 0 )), 0 );
831+ }
832+
817833 auto out = forward_orig (ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size]
818834
819835 // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
@@ -841,7 +857,7 @@ namespace Flux {
841857 flux_params.guidance_embed = false ;
842858 flux_params.depth = 0 ;
843859 flux_params.depth_single_blocks = 0 ;
844- if (version == VERSION_FLUX_INPAINT ) {
860+ if (version == VERSION_FLUX_FILL ) {
845861 flux_params.in_channels = 384 ;
846862 }
847863 for (auto pair : tensor_types) {
@@ -890,14 +906,18 @@ namespace Flux {
890906 struct ggml_cgraph * build_graph (struct ggml_tensor * x,
891907 struct ggml_tensor * timesteps,
892908 struct ggml_tensor * context,
909+ struct ggml_tensor * c_concat,
893910 struct ggml_tensor * y,
894911 struct ggml_tensor * guidance,
895912 std::vector<int > skip_layers = std::vector<int >()) {
896913 GGML_ASSERT (x->ne [3 ] == 1 );
897914 struct ggml_cgraph * gf = ggml_new_graph_custom (compute_ctx, FLUX_GRAPH_SIZE, false );
898915
899- x = to_backend (x);
900- context = to_backend (context);
916+ x = to_backend (x);
917+ context = to_backend (context);
918+ if (c_concat != NULL ) {
919+ c_concat = to_backend (c_concat);
920+ }
901921 y = to_backend (y);
902922 timesteps = to_backend (timesteps);
903923 if (flux_params.guidance_embed ) {
@@ -917,6 +937,7 @@ namespace Flux {
917937 x,
918938 timesteps,
919939 context,
940+ c_concat,
920941 y,
921942 guidance,
922943 pe,
@@ -931,6 +952,7 @@ namespace Flux {
931952 struct ggml_tensor * x,
932953 struct ggml_tensor * timesteps,
933954 struct ggml_tensor * context,
955+ struct ggml_tensor * c_concat,
934956 struct ggml_tensor * y,
935957 struct ggml_tensor * guidance,
936958 struct ggml_tensor ** output = NULL ,
@@ -942,7 +964,7 @@ namespace Flux {
942964 // y: [N, adm_in_channels] or [1, adm_in_channels]
943965 // guidance: [N, ]
944966 auto get_graph = [&]() -> struct ggml_cgraph * {
945- return build_graph(x, timesteps, context, y, guidance, skip_layers);
967+ return build_graph(x, timesteps, context, c_concat, y, guidance, skip_layers);
946968 };
947969
948970 GGMLRunner::compute (get_graph, n_threads, false , output, output_ctx);
@@ -982,7 +1004,7 @@ namespace Flux {
9821004 struct ggml_tensor * out = NULL ;
9831005
9841006 int t0 = ggml_time_ms ();
985- compute (8 , x, timesteps, context, y, guidance, &out, work_ctx);
1007+ compute (8 , x, timesteps, context, NULL , y, guidance, &out, work_ctx);
9861008 int t1 = ggml_time_ms ();
9871009
9881010 print_ggml_tensor (out);
0 commit comments