diff --git a/clip.hpp b/clip.hpp
index dc891c77..eb37638c 100644
--- a/clip.hpp
+++ b/clip.hpp
@@ -476,11 +476,12 @@ struct CLIPLayer : public GGMLBlock {
 public:
     CLIPLayer(int64_t d_model,
               int64_t n_head,
-              int64_t intermediate_size)
+              int64_t intermediate_size,
+              bool proj_in = false)
         : d_model(d_model),
           n_head(n_head),
           intermediate_size(intermediate_size) {
-        blocks["self_attn"] = std::shared_ptr<GGMLBlock>(new MultiheadAttention(d_model, n_head, true, true));
+        blocks["self_attn"] = std::shared_ptr<GGMLBlock>(new MultiheadAttention(d_model, n_head, true, true, proj_in));
 
         blocks["layer_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(d_model));
         blocks["layer_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(d_model));
@@ -509,11 +510,12 @@ struct CLIPEncoder : public GGMLBlock {
     CLIPEncoder(int64_t n_layer,
                 int64_t d_model,
                 int64_t n_head,
-                int64_t intermediate_size)
+                int64_t intermediate_size,
+                bool proj_in = false)
         : n_layer(n_layer) {
         for (int i = 0; i < n_layer; i++) {
             std::string name = "layers." + std::to_string(i);
-            blocks[name] = std::shared_ptr<GGMLBlock>(new CLIPLayer(d_model, n_head, intermediate_size));
+            blocks[name] = std::shared_ptr<GGMLBlock>(new CLIPLayer(d_model, n_head, intermediate_size, proj_in));
         }
     }
@@ -549,10 +551,10 @@ class CLIPEmbeddings : public GGMLBlock {
     int64_t num_positions;
     bool force_clip_f32;
 
-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
         enum ggml_type token_wtype = GGML_TYPE_F32;
         if (!force_clip_f32) {
-            token_wtype = get_type(prefix + "token_embedding.weight", tensor_types, GGML_TYPE_F32);
+            token_wtype = get_type(prefix + "token_embedding.weight", tensor_storage_map, GGML_TYPE_F32);
             if (!support_get_rows(token_wtype)) {
                 token_wtype = GGML_TYPE_F32;
             }
@@ -605,7 +607,8 @@ class CLIPVisionEmbeddings : public GGMLBlock {
     int64_t image_size;
     int64_t num_patches;
     int64_t num_positions;
-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
+
+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
         enum ggml_type patch_wtype    = GGML_TYPE_F16;
         enum ggml_type class_wtype    = GGML_TYPE_F32;
         enum ggml_type position_wtype = GGML_TYPE_F32;
@@ -668,7 +671,7 @@ enum CLIPVersion {
 
 class CLIPTextModel : public GGMLBlock {
 protected:
-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
         if (version == OPEN_CLIP_VIT_BIGG_14) {
             enum ggml_type wtype = GGML_TYPE_F32;
             params["text_projection"] = ggml_new_tensor_2d(ctx, wtype, projection_dim, hidden_size);
@@ -689,7 +692,8 @@ class CLIPTextModel : public GGMLBlock {
 
     CLIPTextModel(CLIPVersion version = OPENAI_CLIP_VIT_L_14,
                   bool with_final_ln = true,
-                  bool force_clip_f32 = false)
+                  bool force_clip_f32 = false,
+                  bool proj_in = false)
         : version(version), with_final_ln(with_final_ln) {
         if (version == OPEN_CLIP_VIT_H_14) {
             hidden_size = 1024;
@@ -704,7 +708,7 @@ class CLIPTextModel : public GGMLBlock {
         }
 
         blocks["embeddings"] = std::shared_ptr<GGMLBlock>(new CLIPEmbeddings(hidden_size, vocab_size, n_token, force_clip_f32));
-        blocks["encoder"] = std::shared_ptr<GGMLBlock>(new CLIPEncoder(n_layer, hidden_size, n_head, intermediate_size));
+        blocks["encoder"] = std::shared_ptr<GGMLBlock>(new CLIPEncoder(n_layer, hidden_size, n_head, intermediate_size, proj_in));
 
         blocks["final_layer_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size));
     }
@@ -758,7 +762,7 @@ class CLIPVisionModel : public GGMLBlock {
     int32_t n_layer = 24;
 
 public:
-    CLIPVisionModel(CLIPVersion version = OPENAI_CLIP_VIT_L_14) {
+    CLIPVisionModel(CLIPVersion version = OPENAI_CLIP_VIT_L_14, bool proj_in = false) {
         if (version == OPEN_CLIP_VIT_H_14) {
             hidden_size = 1280;
             intermediate_size = 5120;
@@ -773,7 +777,7 @@ class CLIPVisionModel : public GGMLBlock {
 
         blocks["embeddings"] = std::shared_ptr<GGMLBlock>(new CLIPVisionEmbeddings(hidden_size, num_channels, patch_size, image_size));
         blocks["pre_layernorm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size));
-        blocks["encoder"] = std::shared_ptr<GGMLBlock>(new CLIPEncoder(n_layer, hidden_size, n_head, intermediate_size));
+        blocks["encoder"] = std::shared_ptr<GGMLBlock>(new CLIPEncoder(n_layer, hidden_size, n_head, intermediate_size, proj_in));
         blocks["post_layernorm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size));
     }
@@ -811,8 +815,8 @@ class CLIPProjection : public UnaryBlock {
     int64_t out_features;
     bool transpose_weight;
 
-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
-        enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
+        enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
         if (transpose_weight) {
             params["weight"] = ggml_new_tensor_2d(ctx, wtype, out_features, in_features);
         } else {
@@ -845,7 +849,8 @@ class CLIPVisionModelProjection : public GGMLBlock {
 
 public:
     CLIPVisionModelProjection(CLIPVersion version = OPENAI_CLIP_VIT_L_14,
-                              bool transpose_proj_w = false) {
+                              bool transpose_proj_w = false,
+                              bool proj_in = false) {
         if (version == OPEN_CLIP_VIT_H_14) {
             hidden_size = 1280;
             projection_dim = 1024;
@@ -853,7 +858,7 @@ class CLIPVisionModelProjection : public GGMLBlock {
             hidden_size = 1664;
         }
 
-        blocks["vision_model"] = std::shared_ptr<GGMLBlock>(new CLIPVisionModel(version));
+        blocks["vision_model"] = std::shared_ptr<GGMLBlock>(new CLIPVisionModel(version, proj_in));
         blocks["visual_projection"] = std::shared_ptr<GGMLBlock>(new CLIPProjection(hidden_size, projection_dim, transpose_proj_w));
     }
@@ -881,13 +886,24 @@ struct CLIPTextModelRunner : public GGMLRunner {
 
     CLIPTextModelRunner(ggml_backend_t backend,
                         bool offload_params_to_cpu,
-                        const String2GGMLType& tensor_types,
+                        const String2TensorStorage& tensor_storage_map,
                         const std::string prefix,
                         CLIPVersion version = OPENAI_CLIP_VIT_L_14,
                         bool with_final_ln = true,
                         bool force_clip_f32 = false)
-        : GGMLRunner(backend, offload_params_to_cpu), model(version, with_final_ln, force_clip_f32) {
-        model.init(params_ctx, tensor_types, prefix);
+        : GGMLRunner(backend, offload_params_to_cpu) {
+        bool proj_in = false;
+        for (const auto& [name, tensor_storage] : tensor_storage_map) {
+            if (!starts_with(name, prefix)) {
+                continue;
+            }
+            if (contains(name, "self_attn.in_proj")) {
+                proj_in = true;
+                break;
+            }
+        }
+        model = CLIPTextModel(version, with_final_ln, force_clip_f32, proj_in);
+        model.init(params_ctx, tensor_storage_map, prefix);
     }
 
     std::string get_desc() override {
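Note on the clip.hpp change: keeping the checkpoint's fused `attn.in_proj` tensor and splitting the result after a single matmul is mathematically identical to the three separate `q_proj`/`k_proj`/`v_proj` matmuls it replaces, because the fused weight is just the three projections stacked row-wise. A minimal standalone sketch of that equivalence (toy sizes, plain C++ in place of ggml; `matvec` is a hypothetical helper, not part of the codebase):

```cpp
#include <cassert>
#include <cstdio>
#include <vector>

// y = W x for a row-major [rows x cols] weight matrix.
static std::vector<float> matvec(const std::vector<float>& W, int rows, int cols,
                                 const std::vector<float>& x) {
    std::vector<float> y(rows, 0.0f);
    for (int r = 0; r < rows; r++)
        for (int c = 0; c < cols; c++)
            y[r] += W[r * cols + c] * x[c];
    return y;
}

int main() {
    const int d = 4;  // toy embed_dim
    // Fused in_proj weight: [3*d x d]; rows 0..d-1 = q, d..2d-1 = k, 2d..3d-1 = v.
    std::vector<float> in_proj(3 * d * d);
    for (size_t i = 0; i < in_proj.size(); i++) in_proj[i] = 0.01f * (float)i;
    std::vector<float> x(d, 1.0f);

    // One matmul with the fused weight, then split the output in three...
    auto qkv = matvec(in_proj, 3 * d, d, x);

    // ...equals three matmuls with the row-chunked weights.
    std::vector<float> Wq(in_proj.begin(), in_proj.begin() + d * d);
    std::vector<float> Wk(in_proj.begin() + d * d, in_proj.begin() + 2 * d * d);
    std::vector<float> Wv(in_proj.begin() + 2 * d * d, in_proj.end());
    auto q = matvec(Wq, d, d, x), k = matvec(Wk, d, d, x), v = matvec(Wv, d, d, x);

    for (int i = 0; i < d; i++) {
        assert(qkv[i] == q[i] && qkv[d + i] == k[i] && qkv[2 * d + i] == v[i]);
    }
    printf("fused in_proj == split q/k/v projections\n");
    return 0;
}
```

This is why `CLIPTextModelRunner` can decide per checkpoint: when the file ships fused `self_attn.in_proj` tensors, it builds the fused layout instead of chunking the weights at load time.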
diff --git a/common.hpp b/common.hpp
index 03c931bd..59540752 100644
--- a/common.hpp
+++ b/common.hpp
@@ -182,8 +182,8 @@ class GEGLU : public UnaryBlock {
     int64_t dim_in;
     int64_t dim_out;
 
-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") override {
-        enum ggml_type wtype = get_type(prefix + "proj.weight", tensor_types, GGML_TYPE_F32);
+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override {
+        enum ggml_type wtype = get_type(prefix + "proj.weight", tensor_storage_map, GGML_TYPE_F32);
         enum ggml_type bias_wtype = GGML_TYPE_F32;
         params["proj.weight"] = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2);
         params["proj.bias"] = ggml_new_tensor_1d(ctx, bias_wtype, dim_out * 2);
@@ -408,30 +408,40 @@ class SpatialTransformer : public GGMLBlock {
     int64_t d_head;
     int64_t depth       = 1;    // 1
     int64_t context_dim = 768;  // hidden_size, 1024 for VERSION_SD2
+    bool use_linear     = false;
 
 public:
     SpatialTransformer(int64_t in_channels,
                        int64_t n_head,
                        int64_t d_head,
                        int64_t depth,
-                       int64_t context_dim)
+                       int64_t context_dim,
+                       bool use_linear)
         : in_channels(in_channels),
           n_head(n_head),
           d_head(d_head),
          depth(depth),
-          context_dim(context_dim) {
-        // We will convert unet transformer linear to conv2d 1x1 when loading the weights, so use_linear is always False
+          context_dim(context_dim),
+          use_linear(use_linear) {
         // disable_self_attn is always False
         int64_t inner_dim = n_head * d_head;  // in_channels
 
         blocks["norm"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
-        blocks["proj_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, inner_dim, {1, 1}));
+        if (use_linear) {
+            blocks["proj_in"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, inner_dim));
+        } else {
+            blocks["proj_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, inner_dim, {1, 1}));
+        }
 
         for (int i = 0; i < depth; i++) {
             std::string name = "transformer_blocks." + std::to_string(i);
             blocks[name] = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim, false));
         }
 
-        blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}));
+        if (use_linear) {
+            blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, in_channels));
+        } else {
+            blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}));
+        }
     }
 
     virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx,
@@ -440,8 +450,8 @@ class SpatialTransformer : public GGMLBlock {
         // x: [N, in_channels, h, w]
         // context: [N, max_position(aka n_token), hidden_size(aka context_dim)]
         auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]);
-        auto proj_in  = std::dynamic_pointer_cast<Conv2d>(blocks["proj_in"]);
-        auto proj_out = std::dynamic_pointer_cast<Conv2d>(blocks["proj_out"]);
+        auto proj_in  = std::dynamic_pointer_cast<UnaryBlock>(blocks["proj_in"]);
+        auto proj_out = std::dynamic_pointer_cast<UnaryBlock>(blocks["proj_out"]);
 
         auto x_in = x;
         int64_t n = x->ne[3];
@@ -450,10 +460,15 @@ class SpatialTransformer : public GGMLBlock {
         int64_t inner_dim = n_head * d_head;
 
         x = norm->forward(ctx, x);
-        x = proj_in->forward(ctx, x);  // [N, inner_dim, h, w]
-
-        x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 2, 0, 3));  // [N, h, w, inner_dim]
-        x = ggml_reshape_3d(ctx->ggml_ctx, x, inner_dim, w * h, n);                // [N, h * w, inner_dim]
+        if (use_linear) {
+            x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 2, 0, 3));  // [N, h, w, inner_dim]
+            x = ggml_reshape_3d(ctx->ggml_ctx, x, inner_dim, w * h, n);                // [N, h * w, inner_dim]
+            x = proj_in->forward(ctx, x);  // [N, h * w, inner_dim]
+        } else {
+            x = proj_in->forward(ctx, x);  // [N, inner_dim, h, w]
+            x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 2, 0, 3));  // [N, h, w, inner_dim]
+            x = ggml_reshape_3d(ctx->ggml_ctx, x, inner_dim, w * h, n);                // [N, h * w, inner_dim]
+        }
 
         for (int i = 0; i < depth; i++) {
             std::string name = "transformer_blocks." + std::to_string(i);
@@ -462,11 +477,19 @@ class SpatialTransformer : public GGMLBlock {
             x = transformer_block->forward(ctx, x, context);
         }
 
-        x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3));  // [N, inner_dim, h * w]
-        x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, inner_dim, n);                 // [N, inner_dim, h, w]
+        if (use_linear) {
+            // proj_out
+            x = proj_out->forward(ctx, x);  // [N, h * w, in_channels]
 
-        // proj_out
-        x = proj_out->forward(ctx, x);  // [N, in_channels, h, w]
+            x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3));  // [N, inner_dim, h * w]
+            x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, inner_dim, n);                 // [N, inner_dim, h, w]
+        } else {
+            x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3));  // [N, inner_dim, h * w]
+            x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, inner_dim, n);                 // [N, inner_dim, h, w]
+
+            // proj_out
+            x = proj_out->forward(ctx, x);  // [N, in_channels, h, w]
+        }
 
         x = ggml_add(ctx->ggml_ctx, x, x_in);
         return x;
@@ -475,7 +498,7 @@ class SpatialTransformer : public GGMLBlock {
 
 class AlphaBlender : public GGMLBlock {
 protected:
-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") override {
+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override {
         // Get the type of the "mix_factor" tensor from the input tensors map with the specified prefix
         enum ggml_type wtype = GGML_TYPE_F32;
         params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1);
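Why the `use_linear` branches only reorder the permute/reshape around the projection: a 1x1 Conv2d over NCHW is exactly the same per-pixel channel map as a Linear over `[h*w, channels]`, so the projection can run before or after flattening the spatial dims, as long as the layout matches the operator. A self-contained sketch of that equivalence (toy sizes, assumed row-major layout, plain C++ rather than ggml):

```cpp
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
    const int C_in = 2, C_out = 3, H = 2, W = 2;
    std::vector<float> x(C_in * H * W), weight(C_out * C_in);
    for (size_t i = 0; i < x.size(); i++) x[i] = (float)i;
    for (size_t i = 0; i < weight.size(); i++) weight[i] = 0.5f * (float)i;

    // Conv2d 1x1 on CHW: project first, pixels stay in place.
    std::vector<float> conv(C_out * H * W, 0.0f);
    for (int co = 0; co < C_out; co++)
        for (int ci = 0; ci < C_in; ci++)
            for (int p = 0; p < H * W; p++)
                conv[co * H * W + p] += weight[co * C_in + ci] * x[ci * H * W + p];

    // Linear on [H*W, C_in]: flatten/permute first, then project each row.
    for (int p = 0; p < H * W; p++)
        for (int co = 0; co < C_out; co++) {
            float acc = 0.0f;
            for (int ci = 0; ci < C_in; ci++)
                acc += weight[co * C_in + ci] * x[ci * H * W + p];
            assert(acc == conv[co * H * W + p]);  // same values, different layout order
        }
    printf("conv2d 1x1 == linear over channels\n");
    return 0;
}
```

With the loader no longer converting transformer linears to 1x1 convolutions, `SpatialTransformer` has to support both parameter shapes natively, which is what the `use_linear` flag selects.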
diff --git a/conditioner.hpp b/conditioner.hpp
index 86cdfb87..b7d80595 100644
--- a/conditioner.hpp
+++ b/conditioner.hpp
@@ -63,19 +63,19 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
 
     FrozenCLIPEmbedderWithCustomWords(ggml_backend_t backend,
                                       bool offload_params_to_cpu,
-                                      const String2GGMLType& tensor_types,
+                                      const String2TensorStorage& tensor_storage_map,
                                       const std::string& embd_dir,
                                       SDVersion version = VERSION_SD1,
                                       PMVersion pv = PM_VERSION_1)
         : version(version), pm_version(pv), tokenizer(sd_version_is_sd2(version) ? 0 : 49407), embd_dir(embd_dir) {
         bool force_clip_f32 = embd_dir.size() > 0;
         if (sd_version_is_sd1(version)) {
-            text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, true, force_clip_f32);
+            text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_storage_map, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, true, force_clip_f32);
         } else if (sd_version_is_sd2(version)) {
-            text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14, true, force_clip_f32);
+            text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_storage_map, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14, true, force_clip_f32);
         } else if (sd_version_is_sdxl(version)) {
-            text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, false, force_clip_f32);
-            text_model2 = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false, force_clip_f32);
+            text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_storage_map, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, false, force_clip_f32);
+            text_model2 = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_storage_map, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false, force_clip_f32);
         }
     }
@@ -623,9 +623,21 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
 
     FrozenCLIPVisionEmbedder(ggml_backend_t backend,
                              bool offload_params_to_cpu,
-                             const String2GGMLType& tensor_types = {})
-        : vision_model(OPEN_CLIP_VIT_H_14), GGMLRunner(backend, offload_params_to_cpu) {
-        vision_model.init(params_ctx, tensor_types, "cond_stage_model.transformer");
+                             const String2TensorStorage& tensor_storage_map = {})
+        : GGMLRunner(backend, offload_params_to_cpu) {
+        std::string prefix = "cond_stage_model.transformer";
+        bool proj_in       = false;
+        for (const auto& [name, tensor_storage] : tensor_storage_map) {
+            if (!starts_with(name, prefix)) {
+                continue;
+            }
+            if (contains(name, "self_attn.in_proj")) {
+                proj_in = true;
+                break;
+            }
+        }
+        vision_model = CLIPVisionModelProjection(OPEN_CLIP_VIT_H_14, false, proj_in);
+        vision_model.init(params_ctx, tensor_storage_map, prefix);
     }
 
     std::string get_desc() override {
@@ -673,12 +685,12 @@ struct SD3CLIPEmbedder : public Conditioner {
 
     SD3CLIPEmbedder(ggml_backend_t backend,
                     bool offload_params_to_cpu,
-                    const String2GGMLType& tensor_types = {})
+                    const String2TensorStorage& tensor_storage_map = {})
         : clip_g_tokenizer(0) {
         bool use_clip_l = false;
         bool use_clip_g = false;
         bool use_t5     = false;
-        for (auto pair : tensor_types) {
+        for (auto pair : tensor_storage_map) {
            if (pair.first.find("text_encoders.clip_l") != std::string::npos) {
                 use_clip_l = true;
             } else if (pair.first.find("text_encoders.clip_g") != std::string::npos) {
@@ -692,13 +704,13 @@ struct SD3CLIPEmbedder : public Conditioner {
             return;
         }
         if (use_clip_l) {
-            clip_l = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, false);
+            clip_l = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_storage_map, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, false);
         }
         if (use_clip_g) {
-            clip_g = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false);
+            clip_g = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_storage_map, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false);
         }
         if (use_t5) {
-            t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.t5xxl.transformer");
+            t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_storage_map, "text_encoders.t5xxl.transformer");
         }
     }
@@ -1082,10 +1094,10 @@ struct FluxCLIPEmbedder : public Conditioner {
 
     FluxCLIPEmbedder(ggml_backend_t backend,
                      bool offload_params_to_cpu,
-                     const String2GGMLType& tensor_types = {}) {
+                     const String2TensorStorage& tensor_storage_map = {}) {
         bool use_clip_l = false;
         bool use_t5     = false;
-        for (auto pair : tensor_types) {
+        for (auto pair : tensor_storage_map) {
             if (pair.first.find("text_encoders.clip_l") != std::string::npos) {
                 use_clip_l = true;
             } else if (pair.first.find("text_encoders.t5xxl") != std::string::npos) {
@@ -1099,12 +1111,12 @@ struct FluxCLIPEmbedder : public Conditioner {
         }
 
         if (use_clip_l) {
-            clip_l = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, true);
+            clip_l = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_storage_map, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, true);
         } else {
             LOG_WARN("clip_l text encoder not found! Prompt adherence might be degraded.");
         }
         if (use_t5) {
-            t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.t5xxl.transformer");
+            t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_storage_map, "text_encoders.t5xxl.transformer");
         } else {
             LOG_WARN("t5xxl text encoder not found! Prompt adherence might be degraded.");
         }
@@ -1342,13 +1354,13 @@ struct T5CLIPEmbedder : public Conditioner {
 
     T5CLIPEmbedder(ggml_backend_t backend,
                    bool offload_params_to_cpu,
-                   const String2GGMLType& tensor_types = {},
-                   bool use_mask                       = false,
-                   int mask_pad                        = 1,
-                   bool is_umt5                        = false)
+                   const String2TensorStorage& tensor_storage_map = {},
+                   bool use_mask                                  = false,
+                   int mask_pad                                   = 1,
+                   bool is_umt5                                   = false)
         : use_mask(use_mask), mask_pad(mask_pad), t5_tokenizer(is_umt5) {
         bool use_t5 = false;
-        for (auto pair : tensor_types) {
+        for (auto pair : tensor_storage_map) {
             if (pair.first.find("text_encoders.t5xxl") != std::string::npos) {
                 use_t5 = true;
             }
@@ -1358,7 +1370,7 @@ struct T5CLIPEmbedder : public Conditioner {
             LOG_WARN("IMPORTANT NOTICE: No text encoders provided, cannot process prompts!");
             return;
         } else {
-            t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.t5xxl.transformer", is_umt5);
+            t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_storage_map, "text_encoders.t5xxl.transformer", is_umt5);
         }
     }
@@ -1549,12 +1561,12 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
 
     Qwen2_5_VLCLIPEmbedder(ggml_backend_t backend,
                            bool offload_params_to_cpu,
-                           const String2GGMLType& tensor_types = {},
-                           const std::string prefix            = "",
-                           bool enable_vision                  = false) {
+                           const String2TensorStorage& tensor_storage_map = {},
+                           const std::string prefix                       = "",
+                           bool enable_vision                             = false) {
         qwenvl = std::make_shared<Qwen::Qwen2_5_VLEmbedder>(backend,
                                                             offload_params_to_cpu,
-                                                            tensor_types,
+                                                            tensor_storage_map,
                                                             "text_encoders.qwen2vl",
                                                             enable_vision);
     }
diff --git a/control.hpp b/control.hpp
index 72886dd0..856bde81 100644
--- a/control.hpp
+++ b/control.hpp
@@ -27,6 +27,7 @@ class ControlNetBlock : public GGMLBlock {
     int num_heads         = 8;
     int num_head_channels = -1;   // channels // num_heads
     int context_dim       = 768;  // 1024 for VERSION_SD2, 2048 for VERSION_SDXL
+    bool use_linear_projection = false;
 
 public:
     int model_channels = 320;
@@ -82,7 +83,7 @@ class ControlNetBlock : public GGMLBlock {
                                        int64_t d_head,
                                        int64_t depth,
                                        int64_t context_dim) -> SpatialTransformer* {
-            return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim);
+            return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim, use_linear_projection);
         };
 
         auto make_zero_conv = [&](int64_t channels) {
@@ -318,10 +319,10 @@ struct ControlNet : public GGMLRunner {
 
     ControlNet(ggml_backend_t backend,
                bool offload_params_to_cpu,
-               const String2GGMLType& tensor_types = {},
-               SDVersion version                   = VERSION_SD1)
+               const String2TensorStorage& tensor_storage_map = {},
+               SDVersion version                              = VERSION_SD1)
         : GGMLRunner(backend, offload_params_to_cpu), control_net(version) {
-        control_net.init(params_ctx, tensor_types, "");
+        control_net.init(params_ctx, tensor_storage_map, "");
     }
 
     ~ControlNet() override {
diff --git a/diffusion_model.hpp b/diffusion_model.hpp
index 651f7a45..30704981 100644
--- a/diffusion_model.hpp
+++ b/diffusion_model.hpp
@@ -44,9 +44,9 @@ struct UNetModel : public DiffusionModel {
 
     UNetModel(ggml_backend_t backend,
               bool offload_params_to_cpu,
-              const String2GGMLType& tensor_types = {},
-              SDVersion version                   = VERSION_SD1)
-        : unet(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version) {
+              const String2TensorStorage& tensor_storage_map = {},
+              SDVersion version                              = VERSION_SD1)
+        : unet(backend, offload_params_to_cpu, tensor_storage_map, "model.diffusion_model", version) {
     }
 
     std::string get_desc() override {
@@ -102,8 +102,8 @@ struct MMDiTModel : public DiffusionModel {
 
     MMDiTModel(ggml_backend_t backend,
                bool offload_params_to_cpu,
-               const String2GGMLType& tensor_types = {})
-        : mmdit(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model") {
+               const String2TensorStorage& tensor_storage_map = {})
+        : mmdit(backend, offload_params_to_cpu, tensor_storage_map, "model.diffusion_model") {
     }
 
     std::string get_desc() override {
@@ -158,10 +158,10 @@ struct FluxModel : public DiffusionModel {
 
     FluxModel(ggml_backend_t backend,
               bool offload_params_to_cpu,
-              const String2GGMLType& tensor_types = {},
-              SDVersion version                   = VERSION_FLUX,
-              bool use_mask                       = false)
-        : flux(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, use_mask) {
+              const String2TensorStorage& tensor_storage_map = {},
+              SDVersion version                              = VERSION_FLUX,
+              bool use_mask                                  = false)
+        : flux(backend, offload_params_to_cpu, tensor_storage_map, "model.diffusion_model", version, use_mask) {
     }
 
     std::string get_desc() override {
@@ -221,10 +221,10 @@ struct WanModel : public DiffusionModel {
 
     WanModel(ggml_backend_t backend,
              bool offload_params_to_cpu,
-             const String2GGMLType& tensor_types = {},
-             const std::string prefix            = "model.diffusion_model",
-             SDVersion version                   = VERSION_WAN2)
-        : prefix(prefix), wan(backend, offload_params_to_cpu, tensor_types, prefix, version) {
+             const String2TensorStorage& tensor_storage_map = {},
+             const std::string prefix                       = "model.diffusion_model",
+             SDVersion version                              = VERSION_WAN2)
+        : prefix(prefix), wan(backend, offload_params_to_cpu, tensor_storage_map, prefix, version) {
     }
 
     std::string get_desc() override {
@@ -283,10 +283,10 @@ struct QwenImageModel : public DiffusionModel {
 
     QwenImageModel(ggml_backend_t backend,
                    bool offload_params_to_cpu,
-                   const String2GGMLType& tensor_types = {},
-                   const std::string prefix            = "model.diffusion_model",
-                   SDVersion version                   = VERSION_QWEN_IMAGE)
-        : prefix(prefix), qwen_image(backend, offload_params_to_cpu, tensor_types, prefix, version) {
+                   const String2TensorStorage& tensor_storage_map = {},
+                   const std::string prefix                       = "model.diffusion_model",
+                   SDVersion version                              = VERSION_QWEN_IMAGE)
+        : prefix(prefix), qwen_image(backend, offload_params_to_cpu, tensor_storage_map, prefix, version) {
     }
 
     std::string get_desc() override {
diff --git a/esrgan.hpp b/esrgan.hpp
index 5a24436c..dd112439 100644
--- a/esrgan.hpp
+++ b/esrgan.hpp
@@ -156,7 +156,7 @@ struct ESRGAN : public GGMLRunner {
 
     ESRGAN(ggml_backend_t backend,
            bool offload_params_to_cpu,
-           const String2GGMLType& tensor_types = {})
+           const String2TensorStorage& tensor_storage_map = {})
         : GGMLRunner(backend, offload_params_to_cpu) {
         // rrdb_net will be created in load_from_file
     }
diff --git a/flux.hpp b/flux.hpp
index 9dd2c9f7..95927f8b 100644
--- a/flux.hpp
+++ b/flux.hpp
@@ -37,7 +37,7 @@ namespace Flux {
         int64_t hidden_size;
         float eps;
 
-        void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
+        void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
             ggml_type wtype = GGML_TYPE_F32;
             params["scale"] = ggml_new_tensor_1d(ctx, wtype, hidden_size);
         }
@@ -1115,10 +1115,10 @@ namespace Flux {
 
         FluxRunner(ggml_backend_t backend,
                    bool offload_params_to_cpu,
-                   const String2GGMLType& tensor_types = {},
-                   const std::string prefix            = "",
-                   SDVersion version                   = VERSION_FLUX,
-                   bool use_mask                       = false)
+                   const String2TensorStorage& tensor_storage_map = {},
+                   const std::string prefix                       = "",
+                   SDVersion version                              = VERSION_FLUX,
+                   bool use_mask                                  = false)
             : GGMLRunner(backend, offload_params_to_cpu), version(version), use_mask(use_mask) {
             flux_params.version        = version;
             flux_params.guidance_embed = false;
@@ -1134,7 +1134,7 @@ namespace Flux {
                 flux_params.in_channels = 3;
                 flux_params.patch_size  = 16;
             }
-            for (auto pair : tensor_types) {
+            for (auto pair : tensor_storage_map) {
                 std::string tensor_name = pair.first;
                 if (!starts_with(tensor_name, prefix))
                     continue;
@@ -1172,7 +1172,7 @@ namespace Flux {
             }
 
             flux = Flux(flux_params);
-            flux.init(params_ctx, tensor_types, prefix);
+            flux.init(params_ctx, tensor_storage_map, prefix);
         }
 
         std::string get_desc() override {
@@ -1403,17 +1403,16 @@ namespace Flux {
             return;
         }
 
-        auto tensor_types = model_loader.tensor_storages_types;
-        for (auto& item : tensor_types) {
-            // LOG_DEBUG("%s %u", item.first.c_str(), item.second);
-            if (ends_with(item.first, "weight")) {
-                // item.second = model_data_type;
+        auto& tensor_storage_map = model_loader.get_tensor_storage_map();
+        for (auto& [name, tensor_storage] : tensor_storage_map) {
+            if (ends_with(name, "weight")) {
+                tensor_storage.expected_type = model_data_type;
             }
         }
 
         std::shared_ptr<FluxRunner> flux = std::make_shared<FluxRunner>(backend,
                                                                         false,
-                                                                        tensor_types,
+                                                                        tensor_storage_map,
                                                                         "model.diffusion_model",
                                                                         VERSION_CHROMA_RADIANCE,
                                                                         false);
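The loader-side override in flux.hpp (and again in qwen_image.hpp further down) now works by stamping `expected_type` onto entries of the shared storage map instead of mutating a separate name-to-type map. A compilable toy model of that flow, with a simplified `Type` enum standing in for `ggml_type` and `TYPE_COUNT` playing the role of `GGML_TYPE_COUNT` as the "no override" sentinel:

```cpp
#include <cstdio>
#include <map>
#include <string>

enum Type { TYPE_F32, TYPE_F16, TYPE_Q8_0, TYPE_COUNT };

struct TensorStorage {
    Type type          = TYPE_F32;   // type as stored in the file
    Type expected_type = TYPE_COUNT; // TYPE_COUNT means "no override"
};

static bool ends_with(const std::string& s, const std::string& suffix) {
    return s.size() >= suffix.size() &&
           s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
}

int main() {
    std::map<std::string, TensorStorage> tensor_storage_map = {
        {"model.diffusion_model.img_in.weight", {TYPE_F16, TYPE_COUNT}},
        {"model.diffusion_model.img_in.bias", {TYPE_F32, TYPE_COUNT}},
    };

    // Equivalent of the loop in flux.hpp / qwen_image.hpp: override weights only.
    for (auto& [name, ts] : tensor_storage_map) {
        if (ends_with(name, "weight")) {
            ts.expected_type = TYPE_Q8_0;
        }
    }

    // Effective type resolution, mirroring GGMLBlock::get_type below.
    for (const auto& [name, ts] : tensor_storage_map) {
        Type effective = ts.expected_type != TYPE_COUNT ? ts.expected_type : ts.type;
        printf("%s -> %d\n", name.c_str(), (int)effective);  // weight -> Q8_0, bias -> F32
    }
    return 0;
}
```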
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 41d59e48..d11e07a1 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -1460,8 +1460,6 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
 #define MAX_PARAMS_TENSOR_NUM 32768
 #define MAX_GRAPH_SIZE 327680
 
-typedef std::map<std::string, enum ggml_type> String2GGMLType;
-
 struct GGMLRunnerContext {
     ggml_backend_t backend = nullptr;
     ggml_context* ggml_ctx = nullptr;
@@ -1900,30 +1898,36 @@ class GGMLBlock {
     GGMLBlockMap blocks;
     ParameterMap params;
 
-    ggml_type get_type(const std::string& name, const String2GGMLType& tensor_types, ggml_type default_type) {
-        auto iter = tensor_types.find(name);
-        if (iter != tensor_types.end()) {
-            return iter->second;
+    ggml_type get_type(const std::string& name, const String2TensorStorage& tensor_storage_map, ggml_type default_type) {
+        ggml_type wtype = default_type;
+        auto iter       = tensor_storage_map.find(name);
+        if (iter != tensor_storage_map.end()) {
+            const TensorStorage& tensor_storage = iter->second;
+            if (tensor_storage.expected_type != GGML_TYPE_COUNT) {
+                wtype = tensor_storage.expected_type;
+            } else {
+                wtype = tensor_storage.type;
+            }
         }
-        return default_type;
+        return wtype;
     }
 
-    void init_blocks(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
+    void init_blocks(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") {
         for (auto& pair : blocks) {
             auto& block = pair.second;
-            block->init(ctx, tensor_types, prefix + pair.first);
+            block->init(ctx, tensor_storage_map, prefix + pair.first);
         }
     }
 
-    virtual void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {}
+    virtual void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") {}
 
 public:
-    void init(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
+    void init(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") {
         if (prefix.size() > 0) {
             prefix = prefix + ".";
         }
-        init_blocks(ctx, tensor_types, prefix);
-        init_params(ctx, tensor_types, prefix);
+        init_blocks(ctx, tensor_storage_map, prefix);
+        init_params(ctx, tensor_storage_map, prefix);
     }
 
     size_t get_params_num() {
@@ -2001,8 +2005,8 @@ class Linear : public UnaryBlock {
     bool force_prec_f32;
     float scale;
 
-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
-        enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
+        enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
         if (in_features % ggml_blck_size(wtype) != 0 || force_f32) {
             wtype = GGML_TYPE_F32;
         }
@@ -2049,8 +2053,8 @@ class Embedding : public UnaryBlock {
 protected:
     int64_t embedding_dim;
     int64_t num_embeddings;
-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") override {
-        enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override {
+        enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
         if (!support_get_rows(wtype)) {
             wtype = GGML_TYPE_F32;
         }
@@ -2093,7 +2097,7 @@ class Conv2d : public UnaryBlock {
     bool bias;
     float scale = 1.f;
 
-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") override {
+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override {
         enum ggml_type wtype = GGML_TYPE_F16;
         params["weight"] = ggml_new_tensor_4d(ctx, wtype, kernel_size.second, kernel_size.first, in_channels, out_channels);
         if (bias) {
@@ -2157,7 +2161,7 @@ class Conv3dnx1x1 : public UnaryBlock {
     int64_t dilation;
     bool bias;
 
-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") override {
+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override {
         enum ggml_type wtype = GGML_TYPE_F16;
         params["weight"] = ggml_new_tensor_4d(ctx, wtype, 1, kernel_size, in_channels, out_channels);  // 5d => 4d
         if (bias) {
@@ -2204,7 +2208,7 @@ class Conv3d : public UnaryBlock {
     std::tuple<int64_t, int64_t, int64_t> dilation;
     bool bias;
 
-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") override {
+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override {
         enum ggml_type wtype = GGML_TYPE_F16;
         params["weight"] = ggml_new_tensor_4d(ctx,
                                               wtype,
@@ -2253,7 +2257,7 @@ class LayerNorm : public UnaryBlock {
     bool elementwise_affine;
     bool bias;
 
-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
         if (elementwise_affine) {
             enum ggml_type wtype = GGML_TYPE_F32;
             params["weight"] = ggml_new_tensor_1d(ctx, wtype, normalized_shape);
@@ -2295,7 +2299,7 @@ class GroupNorm : public GGMLBlock {
     float eps;
     bool affine;
 
-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
         if (affine) {
             enum ggml_type wtype      = GGML_TYPE_F32;
             enum ggml_type bias_wtype = GGML_TYPE_F32;
@@ -2336,7 +2340,7 @@ class RMSNorm : public UnaryBlock {
     int64_t hidden_size;
     float eps;
 
-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") override {
+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override {
         enum ggml_type wtype = GGML_TYPE_F32;
         params["weight"] = ggml_new_tensor_1d(ctx, wtype, hidden_size);
     }
@@ -2359,9 +2363,11 @@ class MultiheadAttention : public GGMLBlock {
 protected:
     int64_t embed_dim;
     int64_t n_head;
+    bool proj_in;
     std::string q_proj_name;
     std::string k_proj_name;
     std::string v_proj_name;
+    std::string in_proj_name;
     std::string out_proj_name;
 
 public:
@@ -2369,19 +2375,27 @@ class MultiheadAttention : public GGMLBlock {
                        int64_t n_head,
                        bool qkv_proj_bias        = true,
                        bool out_proj_bias        = true,
+                       bool proj_in              = false,
                        std::string q_proj_name   = "q_proj",
                        std::string k_proj_name   = "k_proj",
                        std::string v_proj_name   = "v_proj",
+                       std::string in_proj_name  = "in_proj",
                        std::string out_proj_name = "out_proj")
         : embed_dim(embed_dim),
           n_head(n_head),
+          proj_in(proj_in),
          q_proj_name(q_proj_name),
           k_proj_name(k_proj_name),
           v_proj_name(v_proj_name),
+          in_proj_name(in_proj_name),
           out_proj_name(out_proj_name) {
-        blocks[q_proj_name] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim, qkv_proj_bias));
-        blocks[k_proj_name] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim, qkv_proj_bias));
-        blocks[v_proj_name] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim, qkv_proj_bias));
+        if (proj_in) {
+            blocks[in_proj_name] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim * 3, qkv_proj_bias));
+        } else {
+            blocks[q_proj_name] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim, qkv_proj_bias));
+            blocks[k_proj_name] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim, qkv_proj_bias));
+            blocks[v_proj_name] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim, qkv_proj_bias));
+        }
         blocks[out_proj_name] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim, out_proj_bias));
     }
 
@@ -2389,14 +2403,27 @@ class MultiheadAttention : public GGMLBlock {
     struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                 struct ggml_tensor* x,
                                 bool mask = false) {
-        auto q_proj   = std::dynamic_pointer_cast<Linear>(blocks[q_proj_name]);
-        auto k_proj   = std::dynamic_pointer_cast<Linear>(blocks[k_proj_name]);
-        auto v_proj   = std::dynamic_pointer_cast<Linear>(blocks[v_proj_name]);
         auto out_proj = std::dynamic_pointer_cast<Linear>(blocks[out_proj_name]);
 
-        struct ggml_tensor* q = q_proj->forward(ctx, x);
-        struct ggml_tensor* k = k_proj->forward(ctx, x);
-        struct ggml_tensor* v = v_proj->forward(ctx, x);
+        ggml_tensor* q;
+        ggml_tensor* k;
+        ggml_tensor* v;
+        if (proj_in) {
+            auto in_proj = std::dynamic_pointer_cast<Linear>(blocks[in_proj_name]);
+            auto qkv     = in_proj->forward(ctx, x);
+            auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv);
+            q            = qkv_vec[0];
+            k            = qkv_vec[1];
+            v            = qkv_vec[2];
+        } else {
+            auto q_proj = std::dynamic_pointer_cast<Linear>(blocks[q_proj_name]);
+            auto k_proj = std::dynamic_pointer_cast<Linear>(blocks[k_proj_name]);
+            auto v_proj = std::dynamic_pointer_cast<Linear>(blocks[v_proj_name]);
+
+            q = q_proj->forward(ctx, x);
+            k = k_proj->forward(ctx, x);
+            v = v_proj->forward(ctx, x);
+        }
 
         x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, mask);  // [N, n_token, embed_dim]
diff --git a/mmdit.hpp b/mmdit.hpp
index 6189783c..7249a13e 100644
--- a/mmdit.hpp
+++ b/mmdit.hpp
@@ -633,13 +633,13 @@ struct MMDiT : public GGMLBlock {
     int64_t hidden_size;
     std::string qk_norm;
 
-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") override {
+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override {
         enum ggml_type wtype = GGML_TYPE_F32;
         params["pos_embed"] = ggml_new_tensor_3d(ctx, wtype, hidden_size, num_patchs, 1);
     }
 
 public:
-    MMDiT(const String2GGMLType& tensor_types = {}) {
+    MMDiT(const String2TensorStorage& tensor_storage_map = {}) {
         // input_size is always None
         // learn_sigma is always False
         // register_length is alwalys 0
@@ -652,8 +652,7 @@ struct MMDiT : public GGMLBlock {
         // pos_embed_offset is not used
         // context_embedder_config is always {'target': 'torch.nn.Linear', 'params': {'in_features': 4096, 'out_features': 1536}}
 
-        // read tensors from tensor_types
-        for (auto pair : tensor_types) {
+        for (auto pair : tensor_storage_map) {
             std::string tensor_name = pair.first;
             if (tensor_name.find("model.diffusion_model.") == std::string::npos)
                 continue;
@@ -852,10 +851,10 @@ struct MMDiTRunner : public GGMLRunner {
 
     MMDiTRunner(ggml_backend_t backend,
                 bool offload_params_to_cpu,
-                const String2GGMLType& tensor_types = {},
-                const std::string prefix            = "")
-        : GGMLRunner(backend, offload_params_to_cpu), mmdit(tensor_types) {
-        mmdit.init(params_ctx, tensor_types, prefix);
+                const String2TensorStorage& tensor_storage_map = {},
+                const std::string prefix                       = "")
+        : GGMLRunner(backend, offload_params_to_cpu), mmdit(tensor_storage_map) {
+        mmdit.init(params_ctx, tensor_storage_map, prefix);
     }
 
     std::string get_desc() override {
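For readers new to `GGMLBlock::init`: the recursion above only threads the storage map through unchanged; parameter names are built by appending `"<prefix>.<block name>"` at each level before `init_params` registers tensors under the accumulated prefix. A minimal sketch of just that naming behavior (no ggml; `Leaf` is a hypothetical block used only for illustration):

```cpp
#include <cstdio>
#include <map>
#include <memory>
#include <string>

struct Block {
    std::map<std::string, std::shared_ptr<Block>> blocks;

    virtual void init_params(const std::string& prefix) {}
    virtual ~Block() = default;

    // Mirrors GGMLBlock::init / init_blocks: append "." only for non-root
    // prefixes, recurse into child blocks, then register own parameters.
    void init(std::string prefix = "") {
        if (!prefix.empty())
            prefix += ".";
        for (auto& pair : blocks)
            pair.second->init(prefix + pair.first);
        init_params(prefix);
    }
};

struct Leaf : Block {
    void init_params(const std::string& prefix) override {
        printf("param: %sweight\n", prefix.c_str());
    }
};

int main() {
    Block root;
    auto encoder = std::make_shared<Block>();
    encoder->blocks["layers.0"] = std::make_shared<Leaf>();
    root.blocks["encoder"] = encoder;
    root.init();  // prints "param: encoder.layers.0.weight"
    return 0;
}
```

This is why the `proj_in` detection in the runners can match full names like `...self_attn.in_proj.weight` against the same prefixes the blocks will later register under.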
diff --git a/model.cpp b/model.cpp
index da77afed..cec69663 100644
--- a/model.cpp
+++ b/model.cpp
@@ -140,7 +140,9 @@ std::unordered_map<std::string, std::string> open_clip_to_hf_clip_model = {
     {"model.visual.proj", "transformer.visual_projection.weight"},
 };
 
-std::unordered_map<std::string, std::string> open_clip_to_hk_clip_resblock = {
+std::unordered_map<std::string, std::string> open_clip_to_hf_clip_resblock = {
+    {"attn.in_proj_bias", "self_attn.in_proj.bias"},
+    {"attn.in_proj_weight", "self_attn.in_proj.weight"},
     {"attn.out_proj.bias", "self_attn.out_proj.bias"},
     {"attn.out_proj.weight", "self_attn.out_proj.weight"},
     {"ln_1.bias", "layer_norm1.bias"},
@@ -351,10 +353,8 @@ std::string convert_cond_model_name(const std::string& name) {
         std::string idx    = remain.substr(0, remain.find("."));
         std::string suffix = remain.substr(idx.length() + 1);
 
-        if (suffix == "attn.in_proj_weight" || suffix == "attn.in_proj_bias") {
-            new_name = hf_clip_resblock_prefix + idx + "." + suffix;
-        } else if (open_clip_to_hk_clip_resblock.find(suffix) != open_clip_to_hk_clip_resblock.end()) {
-            std::string new_suffix = open_clip_to_hk_clip_resblock[suffix];
+        if (open_clip_to_hf_clip_resblock.find(suffix) != open_clip_to_hf_clip_resblock.end()) {
+            std::string new_suffix = open_clip_to_hf_clip_resblock[suffix];
             new_name = hf_clip_resblock_prefix + idx + "." + new_suffix;
         }
     }
@@ -740,80 +740,6 @@ std::string convert_tensor_name(std::string name) {
     return new_name;
 }
 
-void add_preprocess_tensor_storage_types(String2GGMLType& tensor_storages_types, std::string name, enum ggml_type type) {
-    std::string new_name = convert_tensor_name(name);
-
-    if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_weight")) {
-        size_t prefix_size = new_name.find("attn.in_proj_weight");
-        std::string prefix = new_name.substr(0, prefix_size);
-        tensor_storages_types[prefix + "self_attn.q_proj.weight"] = type;
-        tensor_storages_types[prefix + "self_attn.k_proj.weight"] = type;
-        tensor_storages_types[prefix + "self_attn.v_proj.weight"] = type;
-    } else if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_bias")) {
-        size_t prefix_size = new_name.find("attn.in_proj_bias");
-        std::string prefix = new_name.substr(0, prefix_size);
-        tensor_storages_types[prefix + "self_attn.q_proj.bias"] = type;
-        tensor_storages_types[prefix + "self_attn.k_proj.bias"] = type;
-        tensor_storages_types[prefix + "self_attn.v_proj.bias"] = type;
-    } else {
-        tensor_storages_types[new_name] = type;
-    }
-}
-
-void preprocess_tensor(TensorStorage tensor_storage,
-                       std::vector<TensorStorage>& processed_tensor_storages) {
-    std::vector<TensorStorage> result;
-    std::string new_name = convert_tensor_name(tensor_storage.name);
-
-    // convert unet transformer linear to conv2d 1x1
-    if (starts_with(new_name, "model.diffusion_model.") &&
-        !starts_with(new_name, "model.diffusion_model.proj_out.") &&
-        (ends_with(new_name, "proj_in.weight") || ends_with(new_name, "proj_out.weight"))) {
-        tensor_storage.unsqueeze();
-    }
-
-    // convert vae attn block linear to conv2d 1x1
-    if (starts_with(new_name, "first_stage_model.") && new_name.find("attn_1") != std::string::npos) {
-        tensor_storage.unsqueeze();
-    }
-
-    // wan vae
-    if (ends_with(new_name, "gamma")) {
-        tensor_storage.reverse_ne();
-        tensor_storage.n_dims = 1;
-        tensor_storage.reverse_ne();
-    }
-
-    tensor_storage.name = new_name;
-
-    if (new_name.find("cond_stage_model") != std::string::npos &&
-        ends_with(new_name, "attn.in_proj_weight")) {
-        size_t prefix_size = new_name.find("attn.in_proj_weight");
-        std::string prefix = new_name.substr(0, prefix_size);
-
-        std::vector<TensorStorage> chunks = tensor_storage.chunk(3);
-        chunks[0].name = prefix + "self_attn.q_proj.weight";
-        chunks[1].name = prefix + "self_attn.k_proj.weight";
-        chunks[2].name = prefix + "self_attn.v_proj.weight";
-
-        processed_tensor_storages.insert(processed_tensor_storages.end(), chunks.begin(), chunks.end());
-
-    } else if (new_name.find("cond_stage_model") != std::string::npos &&
-               ends_with(new_name, "attn.in_proj_bias")) {
-        size_t prefix_size = new_name.find("attn.in_proj_bias");
-        std::string prefix = new_name.substr(0, prefix_size);
-
-        std::vector<TensorStorage> chunks = tensor_storage.chunk(3);
-        chunks[0].name = prefix + "self_attn.q_proj.bias";
-        chunks[1].name = prefix + "self_attn.k_proj.bias";
-        chunks[2].name = prefix + "self_attn.v_proj.bias";
-
-        processed_tensor_storages.insert(processed_tensor_storages.end(), chunks.begin(), chunks.end());
-    } else {
-        processed_tensor_storages.push_back(tensor_storage);
-    }
-}
-
 float bf16_to_f32(uint16_t bfloat16) {
     uint32_t val_bits = (static_cast<uint32_t>(bfloat16) << 16);
     return *reinterpret_cast<float*>(&val_bits);
@@ -989,44 +915,10 @@ void convert_tensor(void* src,
 
 /*================================================= ModelLoader ==================================================*/
 
-// ported from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py#L16
-std::map<char32_t, int> unicode_to_byte() {
-    std::map<int, char32_t> byte_to_unicode;
-
-    // List of utf-8 byte ranges
-    for (int b = static_cast<int>('!'); b <= static_cast<int>('~'); ++b) {
-        byte_to_unicode[b] = static_cast<char32_t>(b);
-    }
-
-    for (int b = 49825; b <= 49836; ++b) {
-        byte_to_unicode[b] = static_cast<char32_t>(b);
-    }
-
-    for (int b = 49838; b <= 50111; ++b) {
-        byte_to_unicode[b] = static_cast<char32_t>(b);
-    }
-    // printf("%d %d %d %d\n", static_cast<int>('¡'), static_cast<int>('¬'), static_cast<int>('®'), static_cast<int>('ÿ'));
-    // exit(1);
-
-    int n = 0;
-    for (int b = 0; b < 256; ++b) {
-        if (byte_to_unicode.find(b) == byte_to_unicode.end()) {
-            byte_to_unicode[b] = static_cast<char32_t>(256 + n);
-            n++;
-        }
-    }
-
-    // byte_encoder = bytes_to_unicode()
-    // byte_decoder = {v: k for k, v in byte_encoder.items()}
-    std::map<char32_t, int> byte_decoder;
-
-    for (const auto& entry : byte_to_unicode) {
-        byte_decoder[entry.second] = entry.first;
-    }
-
-    byte_to_unicode.clear();
-
-    return byte_decoder;
+void ModelLoader::add_tensor_storage(const TensorStorage& tensor_storage) {
+    TensorStorage copy = tensor_storage;
+    copy.name = convert_tensor_name(copy.name);
+    tensor_storage_map[copy.name] = std::move(copy);
 }
 
 bool is_zip_file(const std::string& file_path) {
@@ -1156,8 +1048,7 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s
 
         // LOG_DEBUG("%s %s", name.c_str(), tensor_storage.to_string().c_str());
 
-        tensor_storages.push_back(tensor_storage);
-        add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
+        add_tensor_storage(tensor_storage);
     }
 
     return true;
@@ -1182,8 +1073,7 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s
 
         GGML_ASSERT(ggml_nbytes(dummy) == tensor_storage.nbytes());
 
-        tensor_storages.push_back(tensor_storage);
-        add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
+        add_tensor_storage(tensor_storage);
     }
 
     gguf_free(ctx_gguf_);
@@ -1350,8 +1240,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
             GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size);
         }
 
-        tensor_storages.push_back(tensor_storage);
-        add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
+        add_tensor_storage(tensor_storage);
 
         // LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str());
     }
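`add_tensor_storage` replaces both `tensor_storages.push_back(...)` and `add_preprocess_tensor_storage_types(...)`: every entry is canonicalized once via `convert_tensor_name` and keyed by the converted name, so duplicates collapse by map overwrite instead of needing the explicit dedup pass that `load_tensors` used to run. A sketch under assumptions (the `convert_tensor_name` here is a hypothetical one-rule stand-in for the real converter):

```cpp
#include <cstdio>
#include <map>
#include <string>

// Hypothetical single-rule converter; the real one applies many mappings.
static std::string convert_tensor_name(const std::string& name) {
    const std::string open_clip = "cond_stage_model.model.";
    if (name.rfind(open_clip, 0) == 0)  // starts_with
        return "cond_stage_model.transformer." + name.substr(open_clip.size());
    return name;
}

struct TensorStorage {
    std::string name;
    int file_index = 0;
};

static std::map<std::string, TensorStorage> tensor_storage_map;

static void add_tensor_storage(const TensorStorage& tensor_storage) {
    TensorStorage copy = tensor_storage;
    copy.name = convert_tensor_name(copy.name);
    tensor_storage_map[copy.name] = copy;  // std::move in the real code
}

int main() {
    add_tensor_storage({"cond_stage_model.model.ln_final.weight", 0});
    add_tensor_storage({"cond_stage_model.transformer.ln_final.weight", 1});  // same key, overwrites
    for (const auto& [name, ts] : tensor_storage_map)
        printf("%s (from file %d)\n", name.c_str(), ts.file_index);  // one entry, file 1
    return 0;
}
```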
@@ -1370,11 +1259,13 @@ bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const s
     if (!init_from_safetensors_file(unet_path, "unet.")) {
         return false;
     }
-    for (auto ts : tensor_storages) {
-        if (ts.name.find("add_embedding") != std::string::npos || ts.name.find("label_emb") != std::string::npos) {
+    for (auto& [name, tensor_storage] : tensor_storage_map) {
+        if (name.find("add_embedding") != std::string::npos || name.find("label_emb") != std::string::npos) {
             // probably SDXL
             LOG_DEBUG("Fixing name for SDXL output blocks.2.2");
-            for (auto& tensor_storage : tensor_storages) {
+            String2TensorStorage new_tensor_storage_map;
+
+            for (auto& [name, tensor_storage] : tensor_storage_map) {
                 int len  = 34;
                 auto pos = tensor_storage.name.find("unet.up_blocks.0.upsamplers.0.conv");
                 if (pos == std::string::npos) {
@@ -1382,11 +1273,15 @@ bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const s
                     pos = tensor_storage.name.find("model.diffusion_model.output_blocks.2.1.conv");
                 }
                 if (pos != std::string::npos) {
-                    tensor_storage.name = "model.diffusion_model.output_blocks.2.2.conv" + tensor_storage.name.substr(len);
-                    LOG_DEBUG("NEW NAME: %s", tensor_storage.name.c_str());
-                    add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
+                    std::string new_name = "model.diffusion_model.output_blocks.2.2.conv" + name.substr(len);
+                    LOG_DEBUG("NEW NAME: %s", new_name.c_str());
+                    tensor_storage.name = new_name;
+                    new_tensor_storage_map[new_name] = tensor_storage;
+                } else {
+                    new_tensor_storage_map[name] = tensor_storage;
                 }
             }
+            tensor_storage_map = new_tensor_storage_map;
             break;
         }
     }
@@ -1712,8 +1607,7 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer,
                     name = prefix + name;
                 }
                 reader.tensor_storage.name = name;
-                tensor_storages.push_back(reader.tensor_storage);
-                add_preprocess_tensor_storage_types(tensor_storages_types, reader.tensor_storage.name, reader.tensor_storage.type);
+                add_tensor_storage(reader.tensor_storage);
 
                 // LOG_DEBUG("%s", reader.tensor_storage.name.c_str());
 
                 // reset
@@ -1767,15 +1661,6 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
     return true;
 }
 
-bool ModelLoader::model_is_unet() {
-    for (auto& tensor_storage : tensor_storages) {
-        if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) {
-            return true;
-        }
-    }
-    return false;
-}
-
 SDVersion ModelLoader::get_sd_version() {
     TensorStorage token_embedding_weight, input_block_weight;
@@ -1789,7 +1674,7 @@ SDVersion ModelLoader::get_sd_version() {
     bool has_img_emb        = false;
     bool has_middle_block_1 = false;
 
-    for (auto& tensor_storage : tensor_storages) {
+    for (auto& [name, tensor_storage] : tensor_storage_map) {
         if (!(is_xl)) {
             if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
                 is_flux = true;
@@ -1910,7 +1795,7 @@ std::map<ggml_type, uint32_t> ModelLoader::get_wtype_stat() {
     std::map<ggml_type, uint32_t> wtype_stat;
 
-    for (auto& tensor_storage : tensor_storages) {
+    for (auto& [name, tensor_storage] : tensor_storage_map) {
         if (is_unused_tensor(tensor_storage.name)) {
             continue;
         }
@@ -1927,7 +1812,7 @@ std::map<ggml_type, uint32_t> ModelLoader::get_conditioner_wtype_stat() {
     std::map<ggml_type, uint32_t> wtype_stat;
 
-    for (auto& tensor_storage : tensor_storages) {
+    for (auto& [name, tensor_storage] : tensor_storage_map) {
         if (is_unused_tensor(tensor_storage.name)) {
             continue;
         }
@@ -1951,7 +1836,7 @@ std::map<ggml_type, uint32_t> ModelLoader::get_diffusion_model_wtype_stat() {
     std::map<ggml_type, uint32_t> wtype_stat;
 
-    for (auto& tensor_storage : tensor_storages) {
+    for (auto& [name, tensor_storage] : tensor_storage_map) {
         if (is_unused_tensor(tensor_storage.name)) {
             continue;
         }
@@ -1972,7 +1857,7 @@ std::map<ggml_type, uint32_t> ModelLoader::get_vae_wtype_stat() {
     std::map<ggml_type, uint32_t> wtype_stat;
 
-    for (auto& tensor_storage : tensor_storages) {
+    for (auto& [name, tensor_storage] : tensor_storage_map) {
        if (is_unused_tensor(tensor_storage.name)) {
             continue;
         }
@@ -1993,26 +1878,14 @@ std::map<ggml_type, uint32_t> ModelLoader::get_vae_wtype_stat() {
 }
 
 void ModelLoader::set_wtype_override(ggml_type wtype, std::string prefix) {
-    for (auto& pair : tensor_storages_types) {
-        if (prefix.size() < 1 || pair.first.substr(0, prefix.size()) == prefix) {
-            bool found = false;
-            for (auto& tensor_storage : tensor_storages) {
-                std::map<std::string, enum ggml_type> temp;
-                add_preprocess_tensor_storage_types(temp, tensor_storage.name, tensor_storage.type);
-                for (auto& preprocessed_name : temp) {
-                    if (preprocessed_name.first == pair.first) {
-                        if (tensor_should_be_converted(tensor_storage, wtype)) {
-                            pair.second = wtype;
-                        }
-                        found = true;
-                        break;
-                    }
-                }
-                if (found) {
-                    break;
-                }
-            }
+    for (auto& [name, tensor_storage] : tensor_storage_map) {
+        if (!starts_with(name, prefix)) {
+            continue;
         }
+        if (!tensor_should_be_converted(tensor_storage, wtype)) {
+            continue;
+        }
+        tensor_storage.expected_type = wtype;
     }
 }
@@ -2047,74 +1920,13 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
     LOG_DEBUG("using %d threads for model loading", num_threads_to_use);
     int64_t start_time = ggml_time_ms();
 
-    std::vector<TensorStorage> processed_tensor_storages;
-
-    {
-        struct IndexedStorage {
-            size_t index;
-            TensorStorage ts;
-        };
-
-        std::mutex vec_mutex;
-        std::vector<IndexedStorage> all_results;
-
-        int n_threads = std::min(num_threads_to_use, (int)tensor_storages.size());
-        if (n_threads < 1) {
-            n_threads = 1;
-        }
-        std::vector<std::thread> workers;
-
-        for (int i = 0; i < n_threads; ++i) {
-            workers.emplace_back([&, thread_id = i]() {
-                std::vector<IndexedStorage> local_results;
-                std::vector<TensorStorage> temp_storages;
-
-                for (size_t j = thread_id; j < tensor_storages.size(); j += n_threads) {
-                    const auto& tensor_storage = tensor_storages[j];
-                    if (is_unused_tensor(tensor_storage.name)) {
-                        continue;
-                    }
-
-                    temp_storages.clear();
-                    preprocess_tensor(tensor_storage, temp_storages);
-
-                    for (const auto& ts : temp_storages) {
-                        local_results.push_back({j, ts});
-                    }
-                }
-
-                if (!local_results.empty()) {
-                    std::lock_guard<std::mutex> lock(vec_mutex);
-                    all_results.insert(all_results.end(),
-                                       local_results.begin(), local_results.end());
-                }
-            });
-        }
-        for (auto& w : workers) {
-            w.join();
-        }
-
-        std::vector<IndexedStorage> deduplicated;
-        deduplicated.reserve(all_results.size());
-        std::unordered_map<std::string, size_t> name_to_pos;
-        for (auto& entry : all_results) {
-            auto it = name_to_pos.find(entry.ts.name);
-            if (it == name_to_pos.end()) {
-                name_to_pos.emplace(entry.ts.name, deduplicated.size());
-                deduplicated.push_back(entry);
-            } else if (deduplicated[it->second].index < entry.index) {
-                deduplicated[it->second] = entry;
-            }
-        }
-
-        std::sort(deduplicated.begin(), deduplicated.end(), [](const IndexedStorage& a, const IndexedStorage& b) {
-            return a.index < b.index;
-        });
-
-        processed_tensor_storages.reserve(deduplicated.size());
-        for (auto& entry : deduplicated) {
-            processed_tensor_storages.push_back(entry.ts);
+    std::vector<TensorStorage> processed_tensor_storages;
+    for (auto& [name, tensor_storage] : tensor_storage_map) {
+        if (is_unused_tensor(tensor_storage.name)) {
+            continue;
         }
+        processed_tensor_storages.push_back(tensor_storage);
     }
 
     process_time_ms = ggml_time_ms() - start_time;
@@ -2231,106 +2043,71 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
                 }
             };
 
+            char* read_buf    = nullptr;
+            char* target_buf  = nullptr;
+            char* convert_buf = nullptr;
             if (dst_tensor->buffer == nullptr || ggml_backend_buffer_is_host(dst_tensor->buffer)) {
                 if (tensor_storage.type == dst_tensor->type) {
                     GGML_ASSERT(ggml_nbytes(dst_tensor) == tensor_storage.nbytes());
                     if (tensor_storage.is_f64 || tensor_storage.is_i64) {
                         read_buffer.resize(tensor_storage.nbytes_to_read());
-                        read_data((char*)read_buffer.data(), nbytes_to_read);
+                        read_buf = (char*)read_buffer.data();
                     } else {
-                        read_data((char*)dst_tensor->data, nbytes_to_read);
-                    }
-
-                    t1 = ggml_time_ms();
-                    read_time_ms.fetch_add(t1 - t0);
-
-                    t0 = ggml_time_ms();
-                    if (tensor_storage.is_bf16) {
-                        // inplace op
-                        bf16_to_f32_vec((uint16_t*)dst_tensor->data, (float*)dst_tensor->data, tensor_storage.nelements());
-                    } else if (tensor_storage.is_f8_e4m3) {
-                        // inplace op
-                        f8_e4m3_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
-                    } else if (tensor_storage.is_f8_e5m2) {
-                        // inplace op
-                        f8_e5m2_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
-                    } else if (tensor_storage.is_f64) {
-                        f64_to_f32_vec((double*)read_buffer.data(), (float*)dst_tensor->data, tensor_storage.nelements());
-                    } else if (tensor_storage.is_i64) {
-                        i64_to_i32_vec((int64_t*)read_buffer.data(), (int32_t*)dst_tensor->data, tensor_storage.nelements());
+                        read_buf = (char*)dst_tensor->data;
                     }
-                    t1 = ggml_time_ms();
-                    convert_time_ms.fetch_add(t1 - t0);
+                    target_buf = (char*)dst_tensor->data;
                 } else {
                     read_buffer.resize(std::max(tensor_storage.nbytes(), tensor_storage.nbytes_to_read()));
-                    read_data((char*)read_buffer.data(), nbytes_to_read);
-                    t1 = ggml_time_ms();
-                    read_time_ms.fetch_add(t1 - t0);
-
-                    t0 = ggml_time_ms();
-                    if (tensor_storage.is_bf16) {
-                        // inplace op
-                        bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
-                    } else if (tensor_storage.is_f8_e4m3) {
-                        // inplace op
-                        f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
-                    } else if (tensor_storage.is_f8_e5m2) {
-                        // inplace op
-                        f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
-                    } else if (tensor_storage.is_f64) {
-                        // inplace op
-                        f64_to_f32_vec((double*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
-                    } else if (tensor_storage.is_i64) {
-                        // inplace op
-                        i64_to_i32_vec((int64_t*)read_buffer.data(), (int32_t*)read_buffer.data(), tensor_storage.nelements());
-                    }
-
-                    convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data, dst_tensor->type, (int)tensor_storage.nelements() / (int)tensor_storage.ne[0], (int)tensor_storage.ne[0]);
-                    t1 = ggml_time_ms();
-                    convert_time_ms.fetch_add(t1 - t0);
+                    read_buf    = (char*)read_buffer.data();
+                    target_buf  = read_buf;
+                    convert_buf = (char*)dst_tensor->data;
                 }
             } else {
                 read_buffer.resize(std::max(tensor_storage.nbytes(), tensor_storage.nbytes_to_read()));
-                read_data((char*)read_buffer.data(), nbytes_to_read);
-                t1 = ggml_time_ms();
-                read_time_ms.fetch_add(t1 - t0);
+                read_buf   = (char*)read_buffer.data();
+                target_buf = read_buf;
 
-                t0 = ggml_time_ms();
-                if (tensor_storage.is_bf16) {
-                    // inplace op
-                    bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
-                } else if (tensor_storage.is_f8_e4m3) {
-                    // inplace op
-                    f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
-                } else if (tensor_storage.is_f8_e5m2) {
-                    // inplace op
-                    f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
-                } else if (tensor_storage.is_f64) {
-                    // inplace op
-                    f64_to_f32_vec((double*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
-                } else if (tensor_storage.is_i64) {
-                    // inplace op
-                    i64_to_i32_vec((int64_t*)read_buffer.data(), (int32_t*)read_buffer.data(), tensor_storage.nelements());
+                if (tensor_storage.type != dst_tensor->type) {
+                    convert_buffer.resize(ggml_nbytes(dst_tensor));
+                    convert_buf = (char*)convert_buffer.data();
                 }
+            }
 
-                if (tensor_storage.type == dst_tensor->type) {
-                    // copy to device memory
-                    t1 = ggml_time_ms();
-                    convert_time_ms.fetch_add(t1 - t0);
-                    t0 = ggml_time_ms();
-                    ggml_backend_tensor_set(dst_tensor, read_buffer.data(), 0, ggml_nbytes(dst_tensor));
-                    t1 = ggml_time_ms();
-                    copy_to_backend_time_ms.fetch_add(t1 - t0);
-                } else {
-                    // convert first, then copy to device memory
+            t0 = ggml_time_ms();
+            read_data(read_buf, nbytes_to_read);
+            t1 = ggml_time_ms();
+            read_time_ms.fetch_add(t1 - t0);
 
-                    convert_buffer.resize(ggml_nbytes(dst_tensor));
-                    convert_tensor((void*)read_buffer.data(), tensor_storage.type, (void*)convert_buffer.data(), dst_tensor->type, (int)tensor_storage.nelements() / (int)tensor_storage.ne[0], (int)tensor_storage.ne[0]);
-                    t1 = ggml_time_ms();
-                    convert_time_ms.fetch_add(t1 - t0);
-                    t0 = ggml_time_ms();
-                    ggml_backend_tensor_set(dst_tensor, convert_buffer.data(), 0, ggml_nbytes(dst_tensor));
-                    t1 = ggml_time_ms();
-                    copy_to_backend_time_ms.fetch_add(t1 - t0);
-                }
+            t0 = ggml_time_ms();
+            if (tensor_storage.is_bf16) {
+                bf16_to_f32_vec((uint16_t*)read_buf, (float*)target_buf, tensor_storage.nelements());
+            } else if (tensor_storage.is_f8_e4m3) {
+                f8_e4m3_to_f16_vec((uint8_t*)read_buf, (uint16_t*)target_buf, tensor_storage.nelements());
+            } else if (tensor_storage.is_f8_e5m2) {
+                f8_e5m2_to_f16_vec((uint8_t*)read_buf, (uint16_t*)target_buf, tensor_storage.nelements());
+            } else if (tensor_storage.is_f64) {
+                f64_to_f32_vec((double*)read_buf, (float*)target_buf, tensor_storage.nelements());
+            } else if (tensor_storage.is_i64) {
+                i64_to_i32_vec((int64_t*)read_buf, (int32_t*)target_buf, tensor_storage.nelements());
+            }
+            if (tensor_storage.type != dst_tensor->type) {
+                convert_tensor((void*)target_buf,
+                               tensor_storage.type,
+                               convert_buf,
+                               dst_tensor->type,
+                               (int)tensor_storage.nelements() / (int)tensor_storage.ne[0],
+                               (int)tensor_storage.ne[0]);
+            } else {
+                convert_buf = read_buf;
+            }
+            t1 = ggml_time_ms();
+            convert_time_ms.fetch_add(t1 - t0);
+
+            if (dst_tensor->buffer != nullptr && !ggml_backend_buffer_is_host(dst_tensor->buffer)) {
+                t0 = ggml_time_ms();
+                ggml_backend_tensor_set(dst_tensor, convert_buf, 0, ggml_nbytes(dst_tensor));
+                t1 = ggml_time_ms();
+                copy_to_backend_time_ms.fetch_add(t1 - t0);
             }
         }
         if (zip != nullptr) {
@@ -2520,7 +2297,7 @@ bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage
 bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules_str) {
     auto backend    = ggml_backend_cpu_init();
     size_t mem_size = 1 * 1024 * 1024;  // for padding
-    mem_size += tensor_storages.size() * ggml_tensor_overhead();
+    mem_size += tensor_storage_map.size() * ggml_tensor_overhead();
     mem_size += get_params_mem_size(backend, type);
     LOG_INFO("model tensors mem size: %.2fMB", mem_size / 1024.f / 1024.f);
     ggml_context* ggml_ctx = ggml_init({mem_size, nullptr, false});
@@ -2587,14 +2364,10 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type)
     }
     int64_t mem_size = 0;
     std::vector<TensorStorage> processed_tensor_storages;
-    for (auto& tensor_storage : tensor_storages) {
+    for (auto [name, tensor_storage] : tensor_storage_map) {
         if (is_unused_tensor(tensor_storage.name)) {
             continue;
         }
-        preprocess_tensor(tensor_storage, processed_tensor_storages);
-    }
-
-    for (auto& tensor_storage : processed_tensor_storages) {
         if (tensor_should_be_converted(tensor_storage, type)) {
             tensor_storage.type = type;
         }
diff --git a/model.h b/model.h
index f1711e67..a29160cf 100644
--- a/model.h
+++ b/model.h
@@ -65,6 +65,15 @@ static inline bool sd_version_is_sdxl(SDVersion version) {
     return false;
 }

+static inline bool sd_version_is_unet(SDVersion version) {
+    if (sd_version_is_sd1(version) ||
+        sd_version_is_sd2(version) ||
+        sd_version_is_sdxl(version)) {
+        return true;
+    }
+    return false;
+}
+
 static inline bool sd_version_is_sd3(SDVersion version) {
     if (version == VERSION_SD3) {
         return true;
@@ -134,6 +143,7 @@ enum PMVersion {
 struct TensorStorage {
     std::string name;
     ggml_type type          = GGML_TYPE_F32;
+    ggml_type expected_type = GGML_TYPE_COUNT;
     bool is_bf16    = false;
     bool is_f8_e4m3 = false;
     bool is_f8_e5m2 = false;
@@ -242,12 +252,14 @@ struct TensorStorage {

 typedef std::function on_new_tensor_cb_t;

-typedef std::map String2GGMLType;
+typedef std::map<std::string, TensorStorage> String2TensorStorage;

 class ModelLoader {
 protected:
     std::vector file_paths_;
-    std::vector tensor_storages;
+    String2TensorStorage tensor_storage_map;
+
+    void add_tensor_storage(const TensorStorage& tensor_storage);

     bool parse_data_pkl(uint8_t* buffer,
                         size_t buffer_size,
@@ -262,15 +274,13 @@ class ModelLoader {
     bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = "");

 public:
-    String2GGMLType tensor_storages_types;
-
     bool init_from_file(const std::string& file_path, const std::string& prefix = "");
-    bool model_is_unet();
     SDVersion get_sd_version();
     std::map get_wtype_stat();
     std::map get_conditioner_wtype_stat();
     std::map get_diffusion_model_wtype_stat();
     std::map get_vae_wtype_stat();
+    String2TensorStorage& get_tensor_storage_map() { return tensor_storage_map; }
     void set_wtype_override(ggml_type wtype, std::string prefix = "");
     bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0);
     bool load_tensors(std::map& tensors,
@@ -279,8 +289,8 @@ class ModelLoader {

     std::vector<std::string> get_tensor_names() const {
         std::vector<std::string> names;
-        for (const auto& ts : tensor_storages) {
-            names.push_back(ts.name);
+        for (const auto& [name, tensor_storage] : tensor_storage_map) {
+            names.push_back(name);
         }
         return names;
     }
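The model.h hunk above is the core of the refactor: TensorStorage now carries both the on-disk type and an expected_type, with GGML_TYPE_COUNT doubling as "no preference", so consumers can request a conversion without mutating the record of what the file actually contains. A minimal sketch of that split, using stand-in enums and only the fields relevant here (not ggml's real types):

// Minimal sketch of the type/expected_type split (stand-in types).
#include <cstdio>
#include <map>
#include <string>

enum Type { F32, F16, BF16, COUNT };  // COUNT doubles as "no preference"

struct TensorStorage {
    std::string name;
    Type type          = F32;    // what is actually stored in the file
    Type expected_type = COUNT;  // what the consumer wants after loading
};

using String2TensorStorage = std::map<std::string, TensorStorage>;

static bool ends_with(const std::string& s, const std::string& suffix) {
    return s.size() >= suffix.size() &&
           s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
}

int main() {
    String2TensorStorage m;
    m["blk.0.weight"] = {"blk.0.weight", BF16};
    m["blk.0.bias"]   = {"blk.0.bias", F32};

    // The pattern the loaders below use: request F16 for every weight,
    // leave everything else untouched.
    for (auto& [name, ts] : m) {
        if (ends_with(name, "weight")) {
            ts.expected_type = F16;
        }
    }
    for (const auto& [name, ts] : m) {
        std::printf("%s: stored=%d expected=%d\n", name.c_str(), ts.type, ts.expected_type);
    }
}

Because the override now lives on the shared map entries rather than on a per-caller copy of a name-to-type map, the loaders below can take the map by reference instead of copying it.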
diff --git a/pmid.hpp b/pmid.hpp
index ea7c3989..51e8fb76 100644
--- a/pmid.hpp
+++ b/pmid.hpp
@@ -412,7 +412,7 @@ struct PhotoMakerIDEncoder : public GGMLRunner {
 public:
     PhotoMakerIDEncoder(ggml_backend_t backend,
                         bool offload_params_to_cpu,
-                        const String2GGMLType& tensor_types,
+                        const String2TensorStorage& tensor_storage_map,
                         const std::string prefix,
                         SDVersion version = VERSION_SDXL,
                         PMVersion pm_v    = PM_VERSION_1,
@@ -422,9 +422,9 @@ struct PhotoMakerIDEncoder : public GGMLRunner {
           pm_version(pm_v),
           style_strength(sty) {
         if (pm_version == PM_VERSION_1) {
-            id_encoder.init(params_ctx, tensor_types, prefix);
+            id_encoder.init(params_ctx, tensor_storage_map, prefix);
         } else if (pm_version == PM_VERSION_2) {
-            id_encoder2.init(params_ctx, tensor_types, prefix);
+            id_encoder2.init(params_ctx, tensor_storage_map, prefix);
         }
     }
diff --git a/qwen_image.hpp b/qwen_image.hpp
index 6288d9aa..ca3c84ac 100644
--- a/qwen_image.hpp
+++ b/qwen_image.hpp
@@ -502,12 +502,12 @@ namespace Qwen {

     QwenImageRunner(ggml_backend_t backend,
                     bool offload_params_to_cpu,
-                    const String2GGMLType& tensor_types = {},
-                    const std::string prefix            = "",
-                    SDVersion version                   = VERSION_QWEN_IMAGE)
+                    const String2TensorStorage& tensor_storage_map = {},
+                    const std::string prefix                       = "",
+                    SDVersion version                              = VERSION_QWEN_IMAGE)
         : GGMLRunner(backend, offload_params_to_cpu) {
         qwen_image_params.num_layers = 0;
-        for (auto pair : tensor_types) {
+        for (auto pair : tensor_storage_map) {
             std::string tensor_name = pair.first;
             if (tensor_name.find(prefix) == std::string::npos)
                 continue;
@@ -526,7 +526,7 @@ namespace Qwen {
         }
         LOG_INFO("qwen_image_params.num_layers: %ld", qwen_image_params.num_layers);
         qwen_image = QwenImageModel(qwen_image_params);
-        qwen_image.init(params_ctx, tensor_types, prefix);
+        qwen_image.init(params_ctx, tensor_storage_map, prefix);
     }

     std::string get_desc() override {
@@ -649,17 +649,16 @@ namespace Qwen {
             return;
         }

-        auto tensor_types = model_loader.tensor_storages_types;
-        for (auto& item : tensor_types) {
-            // LOG_DEBUG("%s %u", item.first.c_str(), item.second);
-            if (ends_with(item.first, "weight")) {
-                item.second = model_data_type;
+        auto& tensor_storage_map = model_loader.get_tensor_storage_map();
+        for (auto& [name, tensor_storage] : tensor_storage_map) {
+            if (ends_with(name, "weight")) {
+                tensor_storage.expected_type = model_data_type;
             }
         }

         std::shared_ptr qwen_image = std::make_shared(backend,
                                                       false,
-                                                      tensor_types,
+                                                      tensor_storage_map,
                                                       "model.diffusion_model",
                                                       VERSION_QWEN_IMAGE);
diff --git a/qwenvl.hpp b/qwenvl.hpp
index 8918978d..26d18623 100644
--- a/qwenvl.hpp
+++ b/qwenvl.hpp
@@ -910,13 +910,13 @@ namespace Qwen {

     Qwen2_5_VLRunner(ggml_backend_t backend,
                      bool offload_params_to_cpu,
-                     const String2GGMLType& tensor_types,
+                     const String2TensorStorage& tensor_storage_map,
                      const std::string prefix,
                      bool enable_vision_ = false)
         : GGMLRunner(backend, offload_params_to_cpu), enable_vision(enable_vision_) {
         bool have_vision_weight = false;
         bool llama_cpp_style    = false;
-        for (auto pair : tensor_types) {
+        for (auto pair : tensor_storage_map) {
             std::string tensor_name = pair.first;
             if (tensor_name.find(prefix) == std::string::npos)
                 continue;
@@ -940,7 +940,7 @@ namespace Qwen {
             }
         }
         model = Qwen2_5_VL(params, enable_vision, llama_cpp_style);
-        model.init(params_ctx, tensor_types, prefix);
+        model.init(params_ctx, tensor_storage_map, prefix);
     }

     std::string get_desc() override {
@@ -1188,10 +1188,10 @@ namespace Qwen {

     Qwen2_5_VLEmbedder(ggml_backend_t backend,
                        bool offload_params_to_cpu,
-                       const String2GGMLType& tensor_types = {},
-                       const std::string prefix            = "",
-                       bool enable_vision                  = false)
-        : model(backend, offload_params_to_cpu, tensor_types, prefix, enable_vision) {
+                       const String2TensorStorage& tensor_storage_map = {},
+                       const std::string prefix                       = "",
+                       bool enable_vision                             = false)
+        : model(backend, offload_params_to_cpu, tensor_storage_map, prefix, enable_vision) {
     }

     void get_param_tensors(std::map& tensors, const std::string prefix) {
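QwenImageRunner above derives num_layers by scanning the tensor names under its prefix instead of trusting a hard-coded hyperparameter. The exact name pattern it matches is elided in the hunk, so the sketch below uses an illustrative "transformer_blocks.N." pattern; the idea is simply to keep the maximum parsed block index plus one.

// Sketch: derive a layer count from tensor names (illustrative name pattern).
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <map>
#include <string>

int main() {
    std::map<std::string, int> tensor_storage_map = {
        {"model.diffusion_model.transformer_blocks.0.attn.to_q.weight", 0},
        {"model.diffusion_model.transformer_blocks.1.attn.to_q.weight", 0},
        {"model.diffusion_model.transformer_blocks.2.attn.to_q.weight", 0},
    };
    std::string prefix = "model.diffusion_model.";
    std::string marker = "transformer_blocks.";

    int64_t num_layers = 0;
    for (const auto& [name, unused] : tensor_storage_map) {
        if (name.find(prefix) == std::string::npos)
            continue;
        size_t pos = name.find(marker);
        if (pos == std::string::npos)
            continue;
        // Parse the block index right after the marker; keep max + 1.
        int64_t idx = std::stoll(name.substr(pos + marker.size()));
        num_layers  = std::max(num_layers, idx + 1);
    }
    std::printf("num_layers: %lld\n", (long long)num_layers);  // 3
}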
@@ -1347,17 +1347,16 @@ namespace Qwen {
             return;
         }

-        auto tensor_types = model_loader.tensor_storages_types;
-        for (auto& item : tensor_types) {
-            // LOG_DEBUG("%s %u", item.first.c_str(), item.second);
-            if (ends_with(item.first, "weight")) {
-                item.second = model_data_type;
+        auto& tensor_storage_map = model_loader.get_tensor_storage_map();
+        for (auto& [name, tensor_storage] : tensor_storage_map) {
+            if (ends_with(name, "weight")) {
+                tensor_storage.expected_type = model_data_type;
             }
         }

         std::shared_ptr qwenvl = std::make_shared(backend,
                                                   false,
-                                                  tensor_types,
+                                                  tensor_storage_map,
                                                   "qwen2vl",
                                                   true);
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 49f6530f..9faba955 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -213,7 +213,7 @@ class StableDiffusionGGML {
             }
         }

-        bool is_unet = model_loader.model_is_unet();
+        bool is_unet = sd_version_is_unet(model_loader.get_sd_version());

         if (strlen(SAFE_STR(sd_ctx_params->clip_l_path)) > 0) {
             LOG_INFO("loading clip_l from '%s'", sd_ctx_params->clip_l_path);
@@ -273,12 +273,12 @@ class StableDiffusionGGML {
             return false;
         }

-        auto& tensor_types = model_loader.tensor_storages_types;
-        for (auto& item : tensor_types) {
-            // LOG_DEBUG("%s %u", item.first.c_str(), item.second);
-            if (contains(item.first, "qwen2vl") && ends_with(item.first, "weight") && (item.second == GGML_TYPE_F32 || item.second == GGML_TYPE_BF16)) {
-                item.second = GGML_TYPE_F16;
-                // LOG_DEBUG(" change %s %u", item.first.c_str(), item.second);
+        auto& tensor_storage_map = model_loader.get_tensor_storage_map();
+        for (auto& [name, tensor_storage] : tensor_storage_map) {
+            if (contains(name, "qwen2vl") &&
+                ends_with(name, "weight") &&
+                (tensor_storage.type == GGML_TYPE_F32 || tensor_storage.type == GGML_TYPE_BF16)) {
+                tensor_storage.expected_type = GGML_TYPE_F16;
             }
         }

@@ -344,13 +344,13 @@ class StableDiffusionGGML {
         if (sd_version_is_sd3(version)) {
             cond_stage_model = std::make_shared(clip_backend,
                                                 offload_params_to_cpu,
-                                                model_loader.tensor_storages_types);
+                                                tensor_storage_map);
             diffusion_model  = std::make_shared(backend,
                                                 offload_params_to_cpu,
-                                                model_loader.tensor_storages_types);
+                                                tensor_storage_map);
         } else if (sd_version_is_flux(version)) {
             bool is_chroma = false;
-            for (auto pair : model_loader.tensor_storages_types) {
+            for (auto pair : tensor_storage_map) {
                 if (pair.first.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
                     is_chroma = true;
                     break;
@@ -368,42 +368,42 @@ class StableDiffusionGGML {
                 cond_stage_model = std::make_shared(clip_backend,
                                                     offload_params_to_cpu,
-                                                    model_loader.tensor_storages_types,
+                                                    tensor_storage_map,
                                                     sd_ctx_params->chroma_use_t5_mask,
                                                     sd_ctx_params->chroma_t5_mask_pad);
             } else {
                 cond_stage_model = std::make_shared(clip_backend,
                                                     offload_params_to_cpu,
-                                                    model_loader.tensor_storages_types);
+                                                    tensor_storage_map);
             }
             diffusion_model = std::make_shared(backend,
                                                offload_params_to_cpu,
-                                               model_loader.tensor_storages_types,
+                                               tensor_storage_map,
                                                version,
                                                sd_ctx_params->chroma_use_dit_mask);
         } else if (sd_version_is_wan(version)) {
             cond_stage_model = std::make_shared(clip_backend,
                                                 offload_params_to_cpu,
-                                                model_loader.tensor_storages_types,
+                                                tensor_storage_map,
                                                 true,
                                                 1,
                                                 true);
             diffusion_model = std::make_shared(backend,
                                                offload_params_to_cpu,
-                                               model_loader.tensor_storages_types,
+                                               tensor_storage_map,
                                                "model.diffusion_model",
                                                version);
             if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) {
                 high_noise_diffusion_model = std::make_shared(backend,
                                                               offload_params_to_cpu,
-                                                              model_loader.tensor_storages_types,
+                                                              tensor_storage_map,
                                                               "model.high_noise_diffusion_model",
                                                               version);
             }
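The is_chroma probe above is a recurring idiom in this patch: the presence of a single tensor name decides which architecture variant to instantiate. A small self-contained sketch of that probe; the helper name is ours, not the repo's:

// Sketch: architecture detection by probing tensor names (helper name is ours).
#include <cstdio>
#include <map>
#include <string>

using TensorMap = std::map<std::string, int>;  // value type irrelevant here

static bool has_tensor_containing(const TensorMap& m, const std::string& needle) {
    for (const auto& [name, unused] : m) {
        if (name.find(needle) != std::string::npos) {
            return true;  // stop at the first hit, as the is_chroma loop does
        }
    }
    return false;
}

int main() {
    TensorMap m = {{"model.diffusion_model.distilled_guidance_layer.in_proj.weight", 0}};
    bool is_chroma = has_tensor_containing(m, "distilled_guidance_layer.in_proj.weight");
    std::printf("is_chroma: %s\n", is_chroma ? "true" : "false");
}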
             if (diffusion_model->get_desc() == "Wan2.1-I2V-14B" || diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
                 clip_vision = std::make_shared(backend,
                                                offload_params_to_cpu,
-                                               model_loader.tensor_storages_types);
+                                               tensor_storage_map);
                 clip_vision->alloc_params_buffer();
                 clip_vision->get_param_tensors(tensors);
             }
@@ -414,32 +414,32 @@ class StableDiffusionGGML {
             }
             cond_stage_model = std::make_shared(clip_backend,
                                                 offload_params_to_cpu,
-                                                model_loader.tensor_storages_types,
+                                                tensor_storage_map,
                                                 "",
                                                 enable_vision);
             diffusion_model = std::make_shared(backend,
                                                offload_params_to_cpu,
-                                               model_loader.tensor_storages_types,
+                                               tensor_storage_map,
                                                "model.diffusion_model",
                                                version);
         } else {  // SD1.x SD2.x SDXL
             if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
                 cond_stage_model = std::make_shared(clip_backend,
                                                     offload_params_to_cpu,
-                                                    model_loader.tensor_storages_types,
+                                                    tensor_storage_map,
                                                     SAFE_STR(sd_ctx_params->embedding_dir),
                                                     version,
                                                     PM_VERSION_2);
             } else {
                 cond_stage_model = std::make_shared(clip_backend,
                                                     offload_params_to_cpu,
-                                                    model_loader.tensor_storages_types,
+                                                    tensor_storage_map,
                                                     SAFE_STR(sd_ctx_params->embedding_dir),
                                                     version);
             }
             diffusion_model = std::make_shared(backend,
                                                offload_params_to_cpu,
-                                               model_loader.tensor_storages_types,
+                                               tensor_storage_map,
                                                version);
             if (sd_ctx_params->diffusion_conv_direct) {
                 LOG_INFO("Using Conv2d direct in the diffusion model");
@@ -477,7 +477,7 @@ class StableDiffusionGGML {
         if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
             first_stage_model = std::make_shared(vae_backend,
                                                  offload_params_to_cpu,
-                                                 model_loader.tensor_storages_types,
+                                                 tensor_storage_map,
                                                  "first_stage_model",
                                                  vae_decode_only,
                                                  version);
@@ -489,7 +489,7 @@ class StableDiffusionGGML {
         } else if (!use_tiny_autoencoder) {
             first_stage_model = std::make_shared(vae_backend,
                                                  offload_params_to_cpu,
-                                                 model_loader.tensor_storages_types,
+                                                 tensor_storage_map,
                                                  "first_stage_model",
                                                  vae_decode_only,
                                                  false,
@@ -512,7 +512,7 @@ class StableDiffusionGGML {
         } else {
             tae_first_stage = std::make_shared(vae_backend,
                                                offload_params_to_cpu,
-                                               model_loader.tensor_storages_types,
+                                               tensor_storage_map,
                                                "decoder.layers",
                                                vae_decode_only,
                                                version);
@@ -533,7 +533,7 @@ class StableDiffusionGGML {
         }
         control_net = std::make_shared(controlnet_backend,
                                        offload_params_to_cpu,
-                                       model_loader.tensor_storages_types,
+                                       tensor_storage_map,
                                        version);
         if (sd_ctx_params->diffusion_conv_direct) {
             LOG_INFO("Using Conv2d direct in the control net");
@@ -544,7 +544,7 @@ class StableDiffusionGGML {
         if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
             pmid_model = std::make_shared(backend,
                                           offload_params_to_cpu,
-                                          model_loader.tensor_storages_types,
+                                          tensor_storage_map,
                                           "pmid",
                                           version,
                                           PM_VERSION_2);
@@ -552,7 +552,7 @@ class StableDiffusionGGML {
         } else {
             pmid_model = std::make_shared(backend,
                                           offload_params_to_cpu,
-                                          model_loader.tensor_storages_types,
+                                          tensor_storage_map,
                                           "pmid",
                                           version);
         }
@@ -733,12 +733,12 @@ class StableDiffusionGGML {
                 is_using_v_parameterization = true;
             }
         } else if (sd_version_is_sdxl(version)) {
-            if (model_loader.tensor_storages_types.find("edm_vpred.sigma_max") != model_loader.tensor_storages_types.end()) {
+            if (tensor_storage_map.find("edm_vpred.sigma_max") != tensor_storage_map.end()) {
                 // CosXL models
                 // TODO: get sigma_min and sigma_max values from file
                 is_using_edm_v_parameterization = true;
             }
-            if (model_loader.tensor_storages_types.find("v_pred") != model_loader.tensor_storages_types.end()) {
+            if (tensor_storage_map.find("v_pred") != tensor_storage_map.end()) {
                 is_using_v_parameterization = true;
             }
         } else if (version == VERSION_SVD) {
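Unlike the substring probes earlier, the SDXL checks above rely on dedicated marker tensors whose full names are known ("v_pred", "edm_vpred.sigma_max"), so an exact std::map::find suffices. A tiny sketch of the same lookup against a stand-in map:

// Sketch: exact-key lookups for marker tensors such as "v_pred".
#include <cstdio>
#include <map>
#include <string>

int main() {
    std::map<std::string, int> tensor_storage_map = {{"v_pred", 0}};

    bool is_using_v_parameterization =
        tensor_storage_map.find("v_pred") != tensor_storage_map.end();
    bool is_using_edm_v_parameterization =
        tensor_storage_map.find("edm_vpred.sigma_max") != tensor_storage_map.end();

    std::printf("v-pred: %d, edm v-pred: %d\n",
                is_using_v_parameterization, is_using_edm_v_parameterization);
}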
@@ -758,10 +758,9 @@ class StableDiffusionGGML {
         float shift = sd_ctx_params->flow_shift;
         if (shift == INFINITY) {
             shift = 1.0f;  // TODO: validate
-            for (auto pair : model_loader.tensor_storages_types) {
-                if (pair.first.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
+            for (const auto& [name, tensor_storage] : tensor_storage_map) {
+                if (starts_with(name, "model.diffusion_model.guidance_in.in_layer.weight")) {
                     shift = 1.15f;
-                    break;
                 }
             }
         }
diff --git a/t5.hpp b/t5.hpp
index 1f6341f8..89a60665 100644
--- a/t5.hpp
+++ b/t5.hpp
@@ -461,7 +461,7 @@ class T5LayerNorm : public UnaryBlock {
     int64_t hidden_size;
     float eps;

-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
         enum ggml_type wtype = GGML_TYPE_F32;
         params["weight"]     = ggml_new_tensor_1d(ctx, wtype, hidden_size);
     }
@@ -759,7 +759,7 @@ struct T5Runner : public GGMLRunner {

     T5Runner(ggml_backend_t backend,
              bool offload_params_to_cpu,
-             const String2GGMLType& tensor_types,
+             const String2TensorStorage& tensor_storage_map,
              const std::string prefix,
              bool is_umt5 = false)
         : GGMLRunner(backend, offload_params_to_cpu) {
@@ -768,7 +768,7 @@ struct T5Runner : public GGMLRunner {
             params.relative_attention = false;
         }
         model = T5(params);
-        model.init(params_ctx, tensor_types, prefix);
+        model.init(params_ctx, tensor_storage_map, prefix);
     }

     std::string get_desc() override {
@@ -905,10 +905,10 @@ struct T5Embedder {

     T5Embedder(ggml_backend_t backend,
                bool offload_params_to_cpu,
-               const String2GGMLType& tensor_types = {},
-               const std::string prefix            = "",
-               bool is_umt5                        = false)
-        : model(backend, offload_params_to_cpu, tensor_types, prefix, is_umt5), tokenizer(is_umt5) {
+               const String2TensorStorage& tensor_storage_map = {},
+               const std::string prefix                       = "",
+               bool is_umt5                                   = false)
+        : model(backend, offload_params_to_cpu, tensor_storage_map, prefix, is_umt5), tokenizer(is_umt5) {
     }

     void get_param_tensors(std::map& tensors, const std::string prefix) {
@@ -1009,15 +1009,14 @@ struct T5Embedder {
         return;
     }

-    auto tensor_types = model_loader.tensor_storages_types;
-    for (auto& item : tensor_types) {
-        // LOG_DEBUG("%s %u", item.first.c_str(), item.second);
-        if (ends_with(item.first, "weight")) {
-            item.second = model_data_type;
+    auto& tensor_storage_map = model_loader.get_tensor_storage_map();
+    for (auto& [name, tensor_storage] : tensor_storage_map) {
+        if (ends_with(name, "weight")) {
+            tensor_storage.expected_type = model_data_type;
         }
     }

-    std::shared_ptr t5 = std::make_shared(backend, false, tensor_types, "", true);
+    std::shared_ptr t5 = std::make_shared(backend, false, tensor_storage_map, "", true);
     t5->alloc_params_buffer();

     std::map tensors;
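The flow-shift probe above also tightens its matching: `find(...) != std::string::npos` accepts the needle anywhere in the name, while `starts_with` anchors it to the front, which is stricter and avoids scanning the whole string. A sketch of the three helpers the probes in this patch rely on, written with the semantics we assume they have (anchored prefix/suffix checks plus an anywhere-match):

// Sketch of the assumed string-helper semantics (not the repo's util source).
#include <cstdio>
#include <string>

static bool starts_with(const std::string& s, const std::string& p) {
    return s.size() >= p.size() && s.compare(0, p.size(), p) == 0;
}
static bool ends_with(const std::string& s, const std::string& p) {
    return s.size() >= p.size() && s.compare(s.size() - p.size(), p.size(), p) == 0;
}
static bool contains(const std::string& s, const std::string& p) {
    return s.find(p) != std::string::npos;
}

int main() {
    std::string name = "model.diffusion_model.guidance_in.in_layer.weight";
    std::printf("%d %d %d\n",
                starts_with(name, "model.diffusion_model.guidance_in"),  // 1
                ends_with(name, "weight"),                               // 1
                contains(name, "lora"));                                 // 0
}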
diff --git a/tae.hpp b/tae.hpp
index 21617b3f..14cdb578 100644
--- a/tae.hpp
+++ b/tae.hpp
@@ -197,14 +197,14 @@ struct TinyAutoEncoder : public GGMLRunner {

     TinyAutoEncoder(ggml_backend_t backend,
                     bool offload_params_to_cpu,
-                    const String2GGMLType& tensor_types,
+                    const String2TensorStorage& tensor_storage_map,
                     const std::string prefix,
                     bool decoder_only = true,
                     SDVersion version = VERSION_SD1)
         : decode_only(decoder_only), taesd(decoder_only, version), GGMLRunner(backend, offload_params_to_cpu) {
-        taesd.init(params_ctx, tensor_types, prefix);
+        taesd.init(params_ctx, tensor_storage_map, prefix);
     }

     std::string get_desc() override {
diff --git a/unet.hpp b/unet.hpp
index 91af9f7c..0e0d049b 100644
--- a/unet.hpp
+++ b/unet.hpp
@@ -20,9 +20,10 @@ class SpatialVideoTransformer : public SpatialTransformer {
                             int64_t d_head,
                             int64_t depth,
                             int64_t context_dim,
+                            bool use_linear,
                             int64_t time_depth            = 1,
                             int64_t max_time_embed_period = 10000)
-        : SpatialTransformer(in_channels, n_head, d_head, depth, context_dim),
+        : SpatialTransformer(in_channels, n_head, d_head, depth, context_dim, use_linear),
           max_time_embed_period(max_time_embed_period) {
         // We will convert unet transformer linear to conv2d 1x1 when loading the weights, so use_linear is always False
         // use_spatial_context is always True
@@ -178,17 +179,19 @@ class UnetModelBlock : public GGMLBlock {
     int num_heads         = 8;
     int num_head_channels = -1;   // channels // num_heads
     int context_dim       = 768;  // 1024 for VERSION_SD2, 2048 for VERSION_SDXL
+    bool use_linear_projection = false;

public:
     int model_channels  = 320;
     int adm_in_channels = 2816;  // only for VERSION_SDXL/SVD

-    UnetModelBlock(SDVersion version = VERSION_SD1, const String2GGMLType& tensor_types = {})
+    UnetModelBlock(SDVersion version = VERSION_SD1, const String2TensorStorage& tensor_storage_map = {})
         : version(version) {
         if (sd_version_is_sd2(version)) {
-            context_dim       = 1024;
-            num_head_channels = 64;
-            num_heads         = -1;
+            context_dim           = 1024;
+            num_head_channels     = 64;
+            num_heads             = -1;
+            use_linear_projection = true;
         } else if (sd_version_is_sdxl(version)) {
             context_dim           = 2048;
             attention_resolutions = {4, 2};
@@ -196,13 +199,15 @@ class UnetModelBlock : public GGMLBlock {
             transformer_depth     = {1, 2, 10};
             num_head_channels     = 64;
             num_heads             = -1;
+            use_linear_projection = true;
         } else if (version == VERSION_SVD) {
-            in_channels       = 8;
-            out_channels      = 4;
-            context_dim       = 1024;
-            adm_in_channels   = 768;
-            num_head_channels = 64;
-            num_heads         = -1;
+            in_channels           = 8;
+            out_channels          = 4;
+            context_dim           = 1024;
+            adm_in_channels       = 768;
+            num_head_channels     = 64;
+            num_heads             = -1;
+            use_linear_projection = true;
         } else if (version == VERSION_SD1_TINY_UNET) {
             num_res_blocks = 1;
             channel_mult   = {1, 2, 4};
@@ -249,9 +254,9 @@ class UnetModelBlock : public GGMLBlock {
                                       int64_t depth,
                                       int64_t context_dim) -> SpatialTransformer* {
             if (version == VERSION_SVD) {
-                return new SpatialVideoTransformer(in_channels, n_head, d_head, depth, context_dim);
+                return new SpatialVideoTransformer(in_channels, n_head, d_head, depth, context_dim, use_linear_projection);
             } else {
-                return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim);
+                return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim, use_linear_projection);
             }
         };
@@ -581,11 +586,11 @@ struct UNetModelRunner : public GGMLRunner {

     UNetModelRunner(ggml_backend_t backend,
                     bool offload_params_to_cpu,
-                    const String2GGMLType& tensor_types,
+                    const String2TensorStorage& tensor_storage_map,
                     const std::string prefix,
                     SDVersion version = VERSION_SD1)
-        : GGMLRunner(backend, offload_params_to_cpu), unet(version, tensor_types) {
-        unet.init(params_ctx, tensor_types, prefix);
+        : GGMLRunner(backend, offload_params_to_cpu), unet(version, tensor_storage_map) {
+        unet.init(params_ctx, tensor_storage_map, prefix);
     }

     std::string get_desc() override {
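The old comment in SpatialVideoTransformer says linear projections used to be rewritten to conv2d 1x1 at load time; the hunks above instead record, per UNet family, whether the checkpoint ships linear (2-D) projection weights. A sketch of that mapping as a plain function, mirroring the constructor branches above (abridged SDVersion enum, so the values here are illustrative):

// Sketch: which UNet family uses nn.Linear for transformer projections.
#include <cstdio>

enum SDVersion { VERSION_SD1, VERSION_SD2, VERSION_SDXL, VERSION_SVD };

static bool unet_uses_linear_projection(SDVersion v) {
    switch (v) {
        case VERSION_SD2:
        case VERSION_SDXL:
        case VERSION_SVD:
            return true;   // checkpoints ship 2-D proj_in/proj_out weights
        default:
            return false;  // SD1.x keeps conv2d 1x1 projections
    }
}

int main() {
    std::printf("SD1: %d, SDXL: %d\n",
                unet_uses_linear_projection(VERSION_SD1),
                unet_uses_linear_projection(VERSION_SDXL));
}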
diff --git a/upscaler.cpp b/upscaler.cpp
index a9e5f6a1..74048a1a 100644
--- a/upscaler.cpp
+++ b/upscaler.cpp
@@ -51,7 +51,7 @@ struct UpscalerGGML {
         backend = ggml_backend_cpu_init();
     }
     LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type));
-    esrgan_upscaler = std::make_shared(backend, offload_params_to_cpu, model_loader.tensor_storages_types);
+    esrgan_upscaler = std::make_shared(backend, offload_params_to_cpu, model_loader.get_tensor_storage_map());
     if (direct) {
         esrgan_upscaler->set_conv2d_direct_enabled(true);
     }
diff --git a/vae.hpp b/vae.hpp
index 8c82d2f8..ddf970c9 100644
--- a/vae.hpp
+++ b/vae.hpp
@@ -64,25 +64,32 @@ class ResnetBlock : public UnaryBlock {
 class AttnBlock : public UnaryBlock {
 protected:
     int64_t in_channels;
+    bool use_linear;

public:
-    AttnBlock(int64_t in_channels)
-        : in_channels(in_channels) {
+    AttnBlock(int64_t in_channels, bool use_linear)
+        : in_channels(in_channels), use_linear(use_linear) {
         blocks["norm"] = std::shared_ptr(new GroupNorm32(in_channels));
-        blocks["q"] = std::shared_ptr(new Conv2d(in_channels, in_channels, {1, 1}));
-        blocks["k"] = std::shared_ptr(new Conv2d(in_channels, in_channels, {1, 1}));
-        blocks["v"] = std::shared_ptr(new Conv2d(in_channels, in_channels, {1, 1}));
-
-        blocks["proj_out"] = std::shared_ptr(new Conv2d(in_channels, in_channels, {1, 1}));
+        if (use_linear) {
+            blocks["q"]        = std::shared_ptr(new Linear(in_channels, in_channels));
+            blocks["k"]        = std::shared_ptr(new Linear(in_channels, in_channels));
+            blocks["v"]        = std::shared_ptr(new Linear(in_channels, in_channels));
+            blocks["proj_out"] = std::shared_ptr(new Linear(in_channels, in_channels));
+        } else {
+            blocks["q"]        = std::shared_ptr(new Conv2d(in_channels, in_channels, {1, 1}));
+            blocks["k"]        = std::shared_ptr(new Conv2d(in_channels, in_channels, {1, 1}));
+            blocks["v"]        = std::shared_ptr(new Conv2d(in_channels, in_channels, {1, 1}));
+            blocks["proj_out"] = std::shared_ptr(new Conv2d(in_channels, in_channels, {1, 1}));
+        }
     }

     struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
         // x: [N, in_channels, h, w]
         auto norm     = std::dynamic_pointer_cast(blocks["norm"]);
-        auto q_proj   = std::dynamic_pointer_cast(blocks["q"]);
-        auto k_proj   = std::dynamic_pointer_cast(blocks["k"]);
-        auto v_proj   = std::dynamic_pointer_cast(blocks["v"]);
-        auto proj_out = std::dynamic_pointer_cast(blocks["proj_out"]);
+        auto q_proj   = std::dynamic_pointer_cast(blocks["q"]);
+        auto k_proj   = std::dynamic_pointer_cast(blocks["k"]);
+        auto v_proj   = std::dynamic_pointer_cast(blocks["v"]);
+        auto proj_out = std::dynamic_pointer_cast(blocks["proj_out"]);

         auto h_ = norm->forward(ctx, x);
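AttnBlock can swap Linear for Conv2d 1x1 freely because the two compute the same per-pixel channel mix; only the weight layout ([out, in] vs [out, in, 1, 1]) and the activation layout differ. A self-contained demonstration with plain loops, no framework:

// Sketch: a conv2d with a 1x1 kernel equals a linear layer applied to each
// pixel's channel vector; the parameters are the same tensor, reshaped.
#include <cmath>
#include <cstdio>

int main() {
    const int C_IN = 2, C_OUT = 3, H = 2, W = 2;
    float x[C_IN][H][W]         = {{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}};
    float w4[C_OUT][C_IN][1][1] = {{{{1}}, {{0}}}, {{{0}}, {{1}}}, {{{2}}, {{3}}}};

    // conv2d, kernel size 1x1, weight layout [out, in, 1, 1]
    float conv[C_OUT][H][W];
    for (int o = 0; o < C_OUT; o++)
        for (int i = 0; i < H; i++)
            for (int j = 0; j < W; j++) {
                float acc = 0;
                for (int c = 0; c < C_IN; c++)
                    acc += w4[o][c][0][0] * x[c][i][j];
                conv[o][i][j] = acc;
            }

    // linear over flattened pixels, weight layout [out, in]
    float w2[C_OUT][C_IN];
    for (int o = 0; o < C_OUT; o++)
        for (int c = 0; c < C_IN; c++)
            w2[o][c] = w4[o][c][0][0];  // same parameters, reshaped

    float lin[C_OUT][H * W];
    for (int o = 0; o < C_OUT; o++)
        for (int p = 0; p < H * W; p++) {
            float acc = 0;
            for (int c = 0; c < C_IN; c++)
                acc += w2[o][c] * x[c][p / W][p % W];
            lin[o][p] = acc;
        }

    float max_diff = 0;
    for (int o = 0; o < C_OUT; o++)
        for (int p = 0; p < H * W; p++)
            max_diff = std::fmax(max_diff, std::fabs(conv[o][p / W][p % W] - lin[o][p]));
    std::printf("max diff: %f\n", max_diff);  // 0: the two layers are identical
}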
@@ -91,23 +98,44 @@ class AttnBlock : public UnaryBlock {
         const int64_t h = h_->ne[1];
         const int64_t w = h_->ne[0];

-        auto q = q_proj->forward(ctx, h_);                                              // [N, in_channels, h, w]
-        q      = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 1, 2, 0, 3));  // [N, h, w, in_channels]
-        q      = ggml_reshape_3d(ctx->ggml_ctx, q, c, h * w, n);                        // [N, h * w, in_channels]
+        ggml_tensor* q;
+        ggml_tensor* k;
+        ggml_tensor* v;
+        if (use_linear) {
+            h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 2, 0, 3));  // [N, h, w, in_channels]
+            h_ = ggml_reshape_3d(ctx->ggml_ctx, h_, c, h * w, n);                        // [N, h * w, in_channels]

-        auto k = k_proj->forward(ctx, h_);                                              // [N, in_channels, h, w]
-        k      = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 1, 2, 0, 3));  // [N, h, w, in_channels]
-        k      = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n);                        // [N, h * w, in_channels]
+            q = q_proj->forward(ctx, h_);  // [N, h * w, in_channels]
+            k = k_proj->forward(ctx, h_);  // [N, h * w, in_channels]
+            v = v_proj->forward(ctx, h_);  // [N, h * w, in_channels]

-        auto v = v_proj->forward(ctx, h_);                        // [N, in_channels, h, w]
-        v      = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n);  // [N, in_channels, h * w]
+            v = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, v, 1, 0, 2, 3));  // [N, in_channels, h * w]
+        } else {
+            q = q_proj->forward(ctx, h_);                                              // [N, in_channels, h, w]
+            q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 1, 2, 0, 3));  // [N, h, w, in_channels]
+            q = ggml_reshape_3d(ctx->ggml_ctx, q, c, h * w, n);                        // [N, h * w, in_channels]
+
+            k = k_proj->forward(ctx, h_);                                              // [N, in_channels, h, w]
+            k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 1, 2, 0, 3));  // [N, h, w, in_channels]
+            k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n);                        // [N, h * w, in_channels]
+
+            v = v_proj->forward(ctx, h_);                        // [N, in_channels, h, w]
+            v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n);  // [N, in_channels, h * w]
+        }

         h_ = ggml_ext_attention(ctx->ggml_ctx, q, k, v, false);  // [N, h * w, in_channels]

-        h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3));  // [N, in_channels, h * w]
-        h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n);                         // [N, in_channels, h, w]
+        if (use_linear) {
+            h_ = proj_out->forward(ctx, h_);  // [N, h * w, in_channels]

-        h_ = proj_out->forward(ctx, h_);  // [N, in_channels, h, w]
+            h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3));  // [N, in_channels, h * w]
+            h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n);                         // [N, in_channels, h, w]
+        } else {
+            h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3));  // [N, in_channels, h * w]
+            h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n);                         // [N, in_channels, h, w]
+
+            h_ = proj_out->forward(ctx, h_);  // [N, in_channels, h, w]
+        }

         h_ = ggml_add(ctx->ggml_ctx, h_, x);
         return h_;
@@ -163,8 +191,8 @@ class AE3DConv : public Conv2d {

 class VideoResnetBlock : public ResnetBlock {
 protected:
-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
-        enum ggml_type wtype = get_type(prefix + "mix_factor", tensor_types, GGML_TYPE_F32);
+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
+        enum ggml_type wtype = get_type(prefix + "mix_factor", tensor_storage_map, GGML_TYPE_F32);
         params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1);
     }
@@ -233,7 +261,8 @@ class Encoder : public GGMLBlock {
             int num_res_blocks,
             int in_channels,
             int z_channels,
-            bool double_z = true)
+            bool double_z              = true,
+            bool use_linear_projection = false)
         : ch(ch),
           ch_mult(ch_mult),
           num_res_blocks(num_res_blocks),
@@ -264,7 +293,7 @@ class Encoder : public GGMLBlock {
         }

         blocks["mid.block_1"] = std::shared_ptr(new ResnetBlock(block_in, block_in));
-        blocks["mid.attn_1"]  = std::shared_ptr(new AttnBlock(block_in));
+        blocks["mid.attn_1"]  = std::shared_ptr(new AttnBlock(block_in, use_linear_projection));
         blocks["mid.block_2"] = std::shared_ptr(new ResnetBlock(block_in, block_in));

         blocks["norm_out"] = std::shared_ptr(new GroupNorm32(block_in));
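A note on the shape comments above: in ggml, ne[0] is the fastest-varying dimension, so the comment "[N, in_channels, h, w]" corresponds to ne = {w, h, in_channels, N}. That is why the linear path can pay for a single cont+permute of h_ up front and then run all three projections on the token layout, while the conv path permutes q and k separately afterwards. A runnable sketch of the bookkeeping, assuming a recent ggml; shapes are fixed at op-creation time, so no graph needs to be executed to inspect them:

// Sketch: ggml layout bookkeeping behind the comments above.
#include "ggml.h"
#include <cstdio>

static void print_ne(const char* tag, const struct ggml_tensor* t) {
    std::printf("%s: ne = {%lld, %lld, %lld, %lld}\n", tag,
                (long long)t->ne[0], (long long)t->ne[1],
                (long long)t->ne[2], (long long)t->ne[3]);
}

int main() {
    struct ggml_init_params ip = {16 * 1024 * 1024, nullptr, true};  // no_alloc
    struct ggml_context* ctx   = ggml_init(ip);

    const int64_t w = 4, h = 3, c = 8, n = 1;
    struct ggml_tensor* x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w, h, c, n);
    print_ne("x  [N, C, h, w]   ", x);  // {4, 3, 8, 1}

    // The linear path: one permute + reshape turns the feature map into a
    // sequence of h*w tokens with c channels each.
    struct ggml_tensor* t = ggml_cont(ctx, ggml_permute(ctx, x, 1, 2, 0, 3));
    print_ne("t  [N, h, w, C]   ", t);  // {8, 4, 3, 1}
    t = ggml_reshape_3d(ctx, t, c, h * w, n);
    print_ne("t  [N, h*w, C]    ", t);  // {8, 12, 1}

    ggml_free(ctx);
}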
@@ -351,8 +380,9 @@ class Decoder : public GGMLBlock {
             std::vector ch_mult,
             int num_res_blocks,
             int z_channels,
-            bool video_decoder    = false,
-            int video_kernel_size = 3)
+            bool use_linear_projection = false,
+            bool video_decoder         = false,
+            int video_kernel_size      = 3)
         : ch(ch),
           out_ch(out_ch),
           ch_mult(ch_mult),
@@ -366,7 +396,7 @@ class Decoder : public GGMLBlock {
         blocks["conv_in"] = std::shared_ptr(new Conv2d(z_channels, block_in, {3, 3}, {1, 1}, {1, 1}));

         blocks["mid.block_1"] = get_resnet_block(block_in, block_in);
-        blocks["mid.attn_1"]  = std::shared_ptr(new AttnBlock(block_in));
+        blocks["mid.attn_1"]  = std::shared_ptr(new AttnBlock(block_in, use_linear_projection));
         blocks["mid.block_2"] = get_resnet_block(block_in, block_in);

         for (int i = num_resolutions - 1; i >= 0; i--) {
@@ -454,9 +484,10 @@ class AutoencodingEngine : public GGMLBlock {
     } dd_config;

public:
-    AutoencodingEngine(bool decode_only       = true,
-                       bool use_video_decoder = false,
-                       SDVersion version      = VERSION_SD1)
+    AutoencodingEngine(SDVersion version          = VERSION_SD1,
+                       bool decode_only           = true,
+                       bool use_linear_projection = false,
+                       bool use_video_decoder     = false)
         : decode_only(decode_only), use_video_decoder(use_video_decoder) {
         if (sd_version_is_dit(version)) {
             dd_config.z_channels = 16;
@@ -470,6 +501,7 @@ class AutoencodingEngine : public GGMLBlock {
                                       dd_config.ch_mult,
                                       dd_config.num_res_blocks,
                                       dd_config.z_channels,
+                                      use_linear_projection,
                                       use_video_decoder));
         if (use_quant) {
             blocks["post_quant_conv"] = std::shared_ptr(new Conv2d(dd_config.z_channels,
@@ -482,7 +514,8 @@ class AutoencodingEngine : public GGMLBlock {
                                       dd_config.num_res_blocks,
                                       dd_config.in_channels,
                                       dd_config.z_channels,
-                                      dd_config.double_z));
+                                      dd_config.double_z,
+                                      use_linear_projection));

         if (use_quant) {
             int factor = dd_config.double_z ? 2 : 1;
@@ -562,13 +595,26 @@ struct AutoEncoderKL : public VAE {

     AutoEncoderKL(ggml_backend_t backend,
                   bool offload_params_to_cpu,
-                  const String2GGMLType& tensor_types,
+                  const String2TensorStorage& tensor_storage_map,
                   const std::string prefix,
                   bool decode_only       = false,
                   bool use_video_decoder = false,
                   SDVersion version      = VERSION_SD1)
-        : decode_only(decode_only), ae(decode_only, use_video_decoder, version), VAE(backend, offload_params_to_cpu) {
-        ae.init(params_ctx, tensor_types, prefix);
+        : decode_only(decode_only), VAE(backend, offload_params_to_cpu) {
+        bool use_linear_projection = false;
+        for (const auto& [name, tensor_storage] : tensor_storage_map) {
+            if (!starts_with(name, prefix)) {
+                continue;
+            }
+            if (ends_with(name, "attn_1.proj_out.weight")) {
+                if (tensor_storage.n_dims == 2) {
+                    use_linear_projection = true;
+                }
+                break;
+            }
+        }
+        ae = AutoencodingEngine(version, decode_only, use_linear_projection, use_video_decoder);
+        ae.init(params_ctx, tensor_storage_map, prefix);
     }

     void set_conv2d_scale(float scale) override {
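This is where the richer map pays off: because tensor_storage_map carries full TensorStorage records rather than just dtypes, AutoEncoderKL can decide linear vs conv from the rank of the stored weight. A linear proj_out is rank 2 ([out_features, in_features]), a conv2d 1x1 proj_out is rank 4. A sketch of that probe with a stand-in TensorStorage holding only the fields used here:

// Sketch: picking the projection flavor from the stored weight's rank.
#include <cstdint>
#include <cstdio>
#include <map>
#include <string>

struct TensorStorage {
    int n_dims    = 0;
    int64_t ne[4] = {0, 0, 0, 0};
};

static bool starts_with(const std::string& s, const std::string& p) {
    return s.size() >= p.size() && s.compare(0, p.size(), p) == 0;
}
static bool ends_with(const std::string& s, const std::string& p) {
    return s.size() >= p.size() && s.compare(s.size() - p.size(), p.size(), p) == 0;
}

int main() {
    std::map<std::string, TensorStorage> m;
    // Linear proj_out: [out_features, in_features], rank 2.
    // A conv2d 1x1 proj_out would instead be rank 4.
    m["first_stage_model.decoder.mid.attn_1.proj_out.weight"] = {2, {512, 512, 0, 0}};

    std::string prefix         = "first_stage_model.";
    bool use_linear_projection = false;
    for (const auto& [name, ts] : m) {
        if (!starts_with(name, prefix))
            continue;
        if (ends_with(name, "attn_1.proj_out.weight")) {
            use_linear_projection = (ts.n_dims == 2);
            break;  // the first match decides, as in the constructor above
        }
    }
    std::printf("use_linear_projection: %d\n", use_linear_projection);
}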
diff --git a/wan.hpp b/wan.hpp
index db4f5aaa..9720cc63 100644
--- a/wan.hpp
+++ b/wan.hpp
@@ -26,7 +26,7 @@ namespace WAN {
         std::tuple dilation;
         bool bias;

-        void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
+        void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
             params["weight"] = ggml_new_tensor_4d(ctx,
                                                   GGML_TYPE_F16,
                                                   std::get<2>(kernel_size),
@@ -87,9 +87,14 @@ namespace WAN {
     protected:
         int64_t dim;

-        void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
+        void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
             ggml_type wtype = GGML_TYPE_F32;
-            params["gamma"] = ggml_new_tensor_1d(ctx, wtype, dim);
+            auto iter       = tensor_storage_map.find(prefix + "gamma");
+            if (iter != tensor_storage_map.end()) {
+                params["gamma"] = ggml_new_tensor(ctx, wtype, iter->second.n_dims, &iter->second.ne[0]);
+            } else {
+                params["gamma"] = ggml_new_tensor_1d(ctx, wtype, dim);
+            }
         }

     public:
@@ -101,6 +106,7 @@ namespace WAN {
             // assert N == 1
             struct ggml_tensor* w = params["gamma"];
+            w                     = ggml_reshape_1d(ctx->ggml_ctx, w, ggml_nelements(w));
             auto h = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 3, 0, 1, 2));  // [ID, IH, IW, N*IC]
             h      = ggml_rms_norm(ctx->ggml_ctx, h, 1e-12);
             h      = ggml_mul(ctx->ggml_ctx, h, w);
@@ -1110,12 +1116,12 @@ namespace WAN {

         WanVAERunner(ggml_backend_t backend,
                      bool offload_params_to_cpu,
-                     const String2GGMLType& tensor_types = {},
-                     const std::string prefix            = "",
-                     bool decode_only                    = false,
-                     SDVersion version                   = VERSION_WAN2)
+                     const String2TensorStorage& tensor_storage_map = {},
+                     const std::string prefix                       = "",
+                     bool decode_only                               = false,
+                     SDVersion version                              = VERSION_WAN2)
             : decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V), VAE(backend, offload_params_to_cpu) {
-            ae.init(params_ctx, tensor_types, prefix);
+            ae.init(params_ctx, tensor_storage_map, prefix);
         }

         std::string get_desc() override {
@@ -1256,7 +1262,7 @@ namespace WAN {
         // ggml_backend_t backend = ggml_backend_cuda_init(0);
         ggml_backend_t backend    = ggml_backend_cpu_init();
         ggml_type model_data_type = GGML_TYPE_F16;
-        std::shared_ptr vae = std::make_shared(backend, false, String2GGMLType{}, "", false, VERSION_WAN2_2_TI2V);
+        std::shared_ptr vae = std::make_shared(backend, false, String2TensorStorage{}, "", false, VERSION_WAN2_2_TI2V);
         {
             LOG_INFO("loading from '%s'", file_path.c_str());
@@ -1494,8 +1500,8 @@ namespace WAN {
     protected:
         int dim;

-        void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
-            enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
+        void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
+            enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
             params["modulation"] = ggml_new_tensor_3d(ctx, wtype, dim, 6, 1);
         }
@@ -1582,8 +1588,8 @@ namespace WAN {
     class VaceWanAttentionBlock : public WanAttentionBlock {
     protected:
         int block_id;
-        void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
-            enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
+        void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
+            enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
             params["modulation"] = ggml_new_tensor_3d(ctx, wtype, dim, 6, 1);
         }
@@ -1634,8 +1640,8 @@ namespace WAN {
     protected:
         int dim;

-        void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
-            enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
+        void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
+            enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
             params["modulation"] = ggml_new_tensor_3d(ctx, wtype, dim, 2, 1);
         }
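The WanRMSNorm change above makes the gamma parameter shape-tolerant: if the checkpoint records gamma under a different rank (say a rank-4 {dim, 1, 1, 1}), the param is created with exactly that shape so loading succeeds, and the forward pass flattens it to 1-D at the point of use. A runnable sketch of both halves of that pattern, assuming a recent ggml; the two stored shapes are illustrative:

// Sketch: create a param with the checkpoint's stored shape when known,
// then flatten it at the point of use.
#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params ip = {16 * 1024 * 1024, nullptr, true};
    struct ggml_context* ctx   = ggml_init(ip);

    const int64_t dim = 8;

    // Case 1: the file stores gamma as rank 4, e.g. {dim, 1, 1, 1}.
    int64_t stored_ne[4]      = {dim, 1, 1, 1};
    struct ggml_tensor* gamma = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, stored_ne);

    // Case 2 (fallback when the name is absent from the map): plain 1-D.
    struct ggml_tensor* fallback = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim);

    // At use, both collapse to the same 1-D view.
    struct ggml_tensor* g1 = ggml_reshape_1d(ctx, gamma, ggml_nelements(gamma));
    struct ggml_tensor* g2 = ggml_reshape_1d(ctx, fallback, ggml_nelements(fallback));
    std::printf("gamma: %lld, fallback: %lld\n",
                (long long)g1->ne[0], (long long)g2->ne[0]);  // 8, 8

    ggml_free(ctx);
}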
@@ -1681,7 +1687,7 @@ namespace WAN {
         int in_dim;
         int flf_pos_embed_token_number;

-        void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
+        void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
             if (flf_pos_embed_token_number > 0) {
                 params["emb_pos"] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, in_dim, flf_pos_embed_token_number, 1);
             }
@@ -2015,12 +2021,12 @@ namespace WAN {

         WanRunner(ggml_backend_t backend,
                   bool offload_params_to_cpu,
-                  const String2GGMLType& tensor_types = {},
-                  const std::string prefix            = "",
-                  SDVersion version                   = VERSION_WAN2)
+                  const String2TensorStorage& tensor_storage_map = {},
+                  const std::string prefix                       = "",
+                  SDVersion version                              = VERSION_WAN2)
             : GGMLRunner(backend, offload_params_to_cpu) {
             wan_params.num_layers = 0;
-            for (auto pair : tensor_types) {
+            for (auto pair : tensor_storage_map) {
                 std::string tensor_name = pair.first;
                 if (tensor_name.find(prefix) == std::string::npos)
                     continue;
@@ -2117,7 +2123,7 @@ namespace WAN {
             LOG_INFO("%s", desc.c_str());

             wan = Wan(wan_params);
-            wan.init(params_ctx, tensor_types, prefix);
+            wan.init(params_ctx, tensor_storage_map, prefix);
         }

         std::string get_desc() override {
@@ -2254,17 +2260,16 @@ namespace WAN {
             return;
         }

-        auto tensor_types = model_loader.tensor_storages_types;
-        for (auto& item : tensor_types) {
-            // LOG_DEBUG("%s %u", item.first.c_str(), item.second);
-            if (ends_with(item.first, "weight")) {
-                item.second = model_data_type;
+        auto& tensor_storage_map = model_loader.get_tensor_storage_map();
+        for (auto& [name, tensor_storage] : tensor_storage_map) {
+            if (ends_with(name, "weight")) {
+                tensor_storage.expected_type = model_data_type;
             }
         }

         std::shared_ptr wan = std::make_shared(backend,
                                                false,
-                                               tensor_types,
+                                               tensor_storage_map,
                                                "model.diffusion_model",
                                                VERSION_WAN2_2_TI2V);