@@ -643,7 +643,7 @@ namespace Flux {
643643 Flux () {}
644644 Flux (FluxParams params)
645645 : params(params) {
646- int64_t pe_dim = params.hidden_size / params.num_heads ;
646+ int64_t pe_dim = params.hidden_size / params.num_heads ;
647647
648648 blocks[" img_in" ] = std::shared_ptr<GGMLBlock>(new Linear (params.in_channels , params.hidden_size , true ));
649649 blocks[" time_in" ] = std::shared_ptr<GGMLBlock>(new MLPEmbedder (256 , params.hidden_size ));
@@ -789,6 +789,7 @@ namespace Flux {
789789 struct ggml_tensor * x,
790790 struct ggml_tensor * timestep,
791791 struct ggml_tensor * context,
792+ struct ggml_tensor * c_concat,
792793 struct ggml_tensor * y,
793794 struct ggml_tensor * guidance,
794795 struct ggml_tensor * pe,
@@ -797,15 +798,18 @@ namespace Flux {
797798 // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
798799 // timestep: (N,) tensor of diffusion timesteps
799800 // context: (N, L, D)
801+ // c_concat: NULL, or (N, C+M, H, W) for Fill (first C channels: masked, remaining M channels: mask)
800802 // y: (N, adm_in_channels) tensor of class labels
801803 // guidance: (N,)
802804 // pe: (L, d_head/2, 2, 2)
803805 // return: (N, C, H, W)
804806
805807 GGML_ASSERT (x->ne [3 ] == 1 );
806808
809+
807810 int64_t W = x->ne [0 ];
808811 int64_t H = x->ne [1 ];
812+ int64_t C = x->ne [2 ];
809813 int64_t patch_size = 2 ;
810814 int pad_h = (patch_size - H % patch_size) % patch_size;
811815 int pad_w = (patch_size - W % patch_size) % patch_size;
@@ -814,6 +818,21 @@ namespace Flux {
814818 // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
815819 auto img = patchify (ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]
816820
821+ if (c_concat != NULL ) {
822+ ggml_tensor* masked = ggml_cont (ctx,
823+ ggml_view_4d (ctx, c_concat, c_concat->ne [0 ], c_concat->ne [1 ], C, 1 , c_concat->nb [1 ], c_concat->nb [2 ], c_concat->nb [3 ], c_concat->nb [0 ] * 0 ));
824+ ggml_tensor* mask = ggml_cont (ctx,
825+ ggml_view_4d (ctx, c_concat, c_concat->ne [0 ], c_concat->ne [1 ], 8 * 8 , 1 , c_concat->nb [1 ], c_concat->nb [2 ], c_concat->nb [3 ], c_concat->nb [2 ] * C));
826+
827+ masked = ggml_pad (ctx, masked, pad_w, pad_h, 0 , 0 );
828+ mask = ggml_pad (ctx, mask, pad_w, pad_h, 0 , 0 );
829+
830+ masked = patchify (ctx, masked, patch_size);
831+ mask = patchify (ctx, mask, patch_size);
832+
833+ img = ggml_concat (ctx, img, ggml_cont (ctx, ggml_concat (ctx, masked, mask, 0 )), 0 );
834+ }
835+
817836 auto out = forward_orig (ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size]
818837
819838 // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
@@ -841,7 +860,7 @@ namespace Flux {
841860 flux_params.guidance_embed = false ;
842861 flux_params.depth = 0 ;
843862 flux_params.depth_single_blocks = 0 ;
844- if (version == VERSION_FLUX_INPAINT ) {
863+ if (version == VERSION_FLUX_FILL ) {
845864 flux_params.in_channels = 384 ;
846865 }
847866 for (auto pair : tensor_types) {
@@ -890,14 +909,18 @@ namespace Flux {
890909 struct ggml_cgraph * build_graph (struct ggml_tensor * x,
891910 struct ggml_tensor * timesteps,
892911 struct ggml_tensor * context,
912+ struct ggml_tensor * c_concat,
893913 struct ggml_tensor * y,
894914 struct ggml_tensor * guidance,
895915 std::vector<int > skip_layers = std::vector<int >()) {
896916 GGML_ASSERT (x->ne [3 ] == 1 );
897917 struct ggml_cgraph * gf = ggml_new_graph_custom (compute_ctx, FLUX_GRAPH_SIZE, false );
898918
899- x = to_backend (x);
900- context = to_backend (context);
919+ x = to_backend (x);
920+ context = to_backend (context);
921+ if (c_concat != NULL ) {
922+ c_concat = to_backend (c_concat);
923+ }
901924 y = to_backend (y);
902925 timesteps = to_backend (timesteps);
903926 if (flux_params.guidance_embed ) {
@@ -917,6 +940,7 @@ namespace Flux {
917940 x,
918941 timesteps,
919942 context,
943+ c_concat,
920944 y,
921945 guidance,
922946 pe,
@@ -931,6 +955,7 @@ namespace Flux {
931955 struct ggml_tensor * x,
932956 struct ggml_tensor * timesteps,
933957 struct ggml_tensor * context,
958+ struct ggml_tensor * c_concat,
934959 struct ggml_tensor * y,
935960 struct ggml_tensor * guidance,
936961 struct ggml_tensor ** output = NULL ,
@@ -942,7 +967,7 @@ namespace Flux {
942967 // y: [N, adm_in_channels] or [1, adm_in_channels]
943968 // guidance: [N, ]
944969 auto get_graph = [&]() -> struct ggml_cgraph * {
945- return build_graph(x, timesteps, context, y, guidance, skip_layers);
970+ return build_graph(x, timesteps, context, c_concat, y, guidance, skip_layers);
946971 };
947972
948973 GGMLRunner::compute (get_graph, n_threads, false , output, output_ctx);
@@ -982,7 +1007,7 @@ namespace Flux {
9821007 struct ggml_tensor * out = NULL ;
9831008
9841009 int t0 = ggml_time_ms ();
985- compute (8 , x, timesteps, context, y, guidance, &out, work_ctx);
1010+ compute (8 , x, timesteps, context, NULL , y, guidance, &out, work_ctx);
9861011 int t1 = ggml_time_ms ();
9871012
9881013 print_ggml_tensor (out);
0 commit comments