@@ -2532,14 +2532,12 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
25322532        sd_image_to_ggml_tensor (sd_img_gen_params->mask_image , mask_img);
25332533        sd_image_to_ggml_tensor (sd_img_gen_params->init_image , init_img);
25342534
2535-         init_latent = sd_ctx->sd ->encode_first_stage (work_ctx, init_img);
2536- 
25372535        if  (sd_version_is_inpaint (sd_ctx->sd ->version )) {
25382536            int64_t  mask_channels = 1 ;
25392537            if  (sd_ctx->sd ->version  == VERSION_FLUX_FILL) {
2540-                 mask_channels = 8  * 8 ;  //  flatten the whole mask
2538+                 mask_channels = vae_scale_factor  * vae_scale_factor ;  //  flatten the whole mask
25412539            } else  if  (sd_ctx->sd ->version  == VERSION_FLEX_2) {
2542-                 mask_channels = 1  + init_latent-> ne [ 2 ] ;
2540+                 mask_channels = 1  + sd_ctx-> sd -> get_latent_channel () ;
25432541            }
25442542            ggml_tensor* masked_latent = nullptr ;
25452543
@@ -2548,8 +2546,10 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
25482546                ggml_tensor* masked_img = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32, width, height, 3 , 1 );
25492547                ggml_ext_tensor_apply_mask (init_img, mask_img, masked_img);
25502548                masked_latent = sd_ctx->sd ->encode_first_stage (work_ctx, masked_img);
2549+                 init_latent   = sd_ctx->sd ->encode_first_stage (work_ctx, init_img);
25512550            } else  {
25522551                //  mask after vae
2552+                 init_latent   = sd_ctx->sd ->encode_first_stage (work_ctx, init_img);
25532553                masked_latent = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32, init_latent->ne [0 ], init_latent->ne [1 ], init_latent->ne [2 ], 1 );
25542554                ggml_ext_tensor_apply_mask (init_latent, mask_img, masked_latent, 0 .);
25552555            }
@@ -2590,9 +2590,18 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
25902590                        for  (int  k = 0 ; k < masked_latent->ne [2 ]; k++) {
25912591                            ggml_ext_tensor_set_f32 (concat_latent, 0 , ix, iy, masked_latent->ne [2 ] + 1  + k);
25922592                        }
2593+                     } else  {
2594+                         float  m = ggml_ext_tensor_get_f32 (mask_img, mx, my);
2595+                         ggml_ext_tensor_set_f32 (concat_latent, m, ix, iy, 0 );
2596+                         for  (int  k = 0 ; k < masked_latent->ne [2 ]; k++) {
2597+                             float  v = ggml_ext_tensor_get_f32 (masked_latent, ix, iy, k);
2598+                             ggml_ext_tensor_set_f32 (concat_latent, v, ix, iy, k + mask_channels);
2599+                         }
25932600                    }
25942601                }
25952602            }
2603+         } else  {
2604+             init_latent = sd_ctx->sd ->encode_first_stage (work_ctx, init_img);
25962605        }
25972606
25982607        {
0 commit comments