@@ -490,6 +490,7 @@ namespace Flux {
490490
491491 struct FluxParams {
492492 int64_t in_channels = 64 ;
493+ int64_t out_channels = 64 ;
493494 int64_t vec_in_dim = 768 ;
494495 int64_t context_in_dim = 4096 ;
495496 int64_t hidden_size = 3072 ;
@@ -642,7 +643,6 @@ namespace Flux {
642643 Flux () {}
643644 Flux (FluxParams params)
644645 : params(params) {
645- int64_t out_channels = params.in_channels ;
646646 int64_t pe_dim = params.hidden_size / params.num_heads ;
647647
648648 blocks[" img_in" ] = std::shared_ptr<GGMLBlock>(new Linear (params.in_channels , params.hidden_size , true ));
@@ -669,7 +669,7 @@ namespace Flux {
669669 params.flash_attn ));
670670 }
671671
672- blocks[" final_layer" ] = std::shared_ptr<GGMLBlock>(new LastLayer (params.hidden_size , 1 , out_channels));
672+ blocks[" final_layer" ] = std::shared_ptr<GGMLBlock>(new LastLayer (params.hidden_size , 1 , params. out_channels ));
673673 }
674674
675675 struct ggml_tensor * patchify (struct ggml_context * ctx,
@@ -834,12 +834,16 @@ namespace Flux {
834834 FluxRunner (ggml_backend_t backend,
835835 std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
836836 const std::string prefix = " " ,
837+ SDVersion version = VERSION_FLUX,
837838 bool flash_attn = false )
838839 : GGMLRunner(backend) {
839840 flux_params.flash_attn = flash_attn;
840841 flux_params.guidance_embed = false ;
841842 flux_params.depth = 0 ;
842843 flux_params.depth_single_blocks = 0 ;
844+ if (version == VERSION_FLUX_INPAINT) {
845+ flux_params.in_channels = 384 ;
846+ }
843847 for (auto pair : tensor_types) {
844848 std::string tensor_name = pair.first ;
845849 if (tensor_name.find (" model.diffusion_model." ) == std::string::npos)
0 commit comments