From 3aa835bfe677e7f42eab678be4b0e3615ee983a4 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 29 Oct 2025 11:28:45 +0200
Subject: [PATCH 1/8] clip : use FA

---
 ggml/src/ggml-metal/ggml-metal-device.m |  1 +
 ggml/src/ggml-metal/ggml-metal.metal    |  8 ++++++++
 tests/test-backend-ops.cpp              |  4 ++--
 tools/mtmd/clip.cpp                     | 24 ++++++------------------
 4 files changed, 17 insertions(+), 20 deletions(-)

diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index 360fbe19f0fb6..0cadd19a30fe9 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -707,6 +707,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
             if (op->src[0]->ne[0] != 32 &&
                 op->src[0]->ne[0] != 40 &&
                 op->src[0]->ne[0] != 64 &&
+                op->src[0]->ne[0] != 72 &&
                 op->src[0]->ne[0] != 80 &&
                 op->src[0]->ne[0] != 96 &&
                 op->src[0]->ne[0] != 112 &&
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index fa839a1df6e30..424c400f24b9b 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -5362,6 +5362,7 @@ typedef decltype(kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f32_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -5374,6 +5375,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -5387,6 +5389,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_bf16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_bf16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -5400,6 +5403,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -5412,6 +5416,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -5424,6 +5429,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -5436,6 +5442,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -5448,6 +5455,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q8_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q8_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 92361d6f0f4d7..1eaf844353578 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -7168,8 +7168,8 @@ static std::vector> make_test_cases_eval() {
         test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v));
     }

-    for (int hsk : { 40, 64, 80, 96, 128, 192, 256, 576 }) {
-        for (int hsv : { 40, 64, 80, 96, 128, 192, 256, 512 }) {
+    for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 576 }) {
+        for (int hsv : { 40, 64, 72, 80, 96, 128, 192, 256, 512 }) {
             if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
             if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
             if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index b312fda637f3b..5c34ae37cb111 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -2244,28 +2244,16 @@ struct clip_graph {
         ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
         //cb(k, "k", il);

-        ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
-        v = ggml_cont(ctx0, v);
+        ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
         //cb(k, "v", il);

-        ggml_tensor * cur;
+        k = ggml_cast(ctx0, k, GGML_TYPE_F16);
+        v = ggml_cast(ctx0, v, GGML_TYPE_F16);

-        // TODO @ngxson : support flash attention
-        {
-            const auto n_tokens = q->ne[1];
-            const auto n_head = q->ne[2];
-            // const auto n_kv = k->ne[1]; // for flash attention
-
-            ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-            // F32 may not needed for vision encoders?
-            // ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+        ggml_tensor * cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
+        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);

-            kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f);
-
-            ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
-            cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-            cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
-        }
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);

         cb(cur, "kqv_out", il);
From a4b54f2697739ed60e42245fa80b38b8f480b748 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 30 Oct 2025 18:15:34 +0200
Subject: [PATCH 2/8] cont : add warning about unsupported ops

---
 tools/mtmd/clip.cpp | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index 5c34ae37cb111..eed93ba05d130 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -28,6 +28,7 @@
 #include
 #include

+// TODO: allow to pass callback from user code
 struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};

 enum ffn_op_type {
@@ -3188,6 +3189,11 @@ struct clip_model_loader {
                 size / 1024.0 / 1024.0);
             }
         }
+
+        const int n_splits = ggml_backend_sched_get_n_splits(ctx_clip.sched.get());
+        const int n_nodes = ggml_graph_n_nodes(gf);
+
+        LOG_INF("%s: graph splits = %d, nodes = %d\n", __func__, n_splits, n_nodes);
     }

     void get_bool(const std::string & key, bool & output, bool required = true) {
@@ -4373,6 +4379,15 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         return false; // only support batch size of 1
     }

+    if (ggml_backend_sched_get_n_splits(ctx->sched.get()) > 1) {
+        LOG_WRN("%s: *****************************************************************\n", __func__);
+        LOG_WRN("%s: WARNING: the CLIP graph uses unsupported operators by the backend\n", __func__);
+        LOG_WRN("%s: the performance will be suboptimal \n", __func__);
+        LOG_WRN("%s: \n", __func__);
+        LOG_WRN("%s: ref: https://github.com/ggml-org/llama.cpp/pull/16837#issuecomment-3461676118\n", __func__);
+        LOG_WRN("%s: *****************************************************************\n", __func__);
+    }
+
     // build the inference graph
     ctx->debug_print_tensors.clear();
     ggml_backend_sched_reset(ctx->sched.get());

From b4955f0ae6f001898f21c757ca4a98b5ed7be421 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sat, 1 Nov 2025 23:52:40 +0100
Subject: [PATCH 3/8] implement "auto" mode for clip flash attn

---
 tools/mtmd/clip.cpp     | 77 +++++++++++++++++++++++++++++++++++------
 tools/mtmd/clip.h       |  2 ++
 tools/mtmd/mtmd-cli.cpp |  1 +
 tools/mtmd/mtmd.cpp     |  2 ++
 tools/mtmd/mtmd.h       |  1 +
 tools/server/server.cpp |  1 +
 6 files changed, 74 insertions(+), 10 deletions(-)

diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index 70b93e3425668..74ba6c27054c8 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -4,6 +4,7 @@
 // Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
 #include "clip.h"
 #include "clip-impl.h"
+#include "mtmd.h"
 #include "ggml.h"
 #include "ggml-cpp.h"
 #include "ggml-cpu.h"
@@ -427,12 +428,14 @@ struct clip_ctx {
     int max_nodes = 8192;

     ggml_backend_sched_ptr sched;
+    llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;

     // for debugging
     bool debug_graph = false;
     std::vector debug_print_tensors;

     clip_ctx(clip_context_params & ctx_params) {
+        flash_attn_type = ctx_params.flash_attn_type;
         debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr;
         backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
         if (!backend_cpu) {
@@ -2261,16 +2264,36 @@ struct clip_graph {
         ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
         //cb(k, "k", il);

-        ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
-        //cb(k, "v", il);
+        ggml_tensor * cur;

-        k = ggml_cast(ctx0, k, GGML_TYPE_F16);
-        v = ggml_cast(ctx0, v, GGML_TYPE_F16);
+        if (ctx->flash_attn_type == LLAMA_FLASH_ATTN_TYPE_ENABLED) {
+            ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);

-        ggml_tensor * cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
-        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+            k = ggml_cast(ctx0, k, GGML_TYPE_F16);
+            v = ggml_cast(ctx0, v, GGML_TYPE_F16);

-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
+            cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
+            ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+
+            cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
+
+        } else {
+            ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
+            v = ggml_cont(ctx0, v);
+
+            const auto n_tokens = q->ne[1];
+            const auto n_head = q->ne[2];
+
+            ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+            // F32 may not needed for vision encoders?
+            // ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+
+            kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f);
+
+            ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+            cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+            cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        }

         cb(cur, "kqv_out", il);
@@ -3181,7 +3204,30 @@ struct clip_model_loader {
         }
     }

-    void alloc_compute_meta(clip_ctx & ctx_clip) {
+    void warmup(clip_ctx & ctx_clip) {
+        if (ctx_clip.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
+            // try to enable flash attention to see if it's supported
+            ctx_clip.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
+            bool supported = alloc_compute_meta(ctx_clip);
+            if (!supported) {
+                LOG_WRN("%s: flash attention not supported, memory usage will increase\n", __func__);
+                // TODO: maybe log more details about why flash attention is not supported
+                ctx_clip.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
+                alloc_compute_meta(ctx_clip);
+            }
+        } else {
+            bool supported = alloc_compute_meta(ctx_clip);
+            if (!supported) {
+                LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__);
+            }
+        }
+
+        LOG_INF("%s: flash attention is %s\n", __func__,
+            (ctx_clip.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
+    }
+
+    // return false if flash attention is not supported
+    bool alloc_compute_meta(clip_ctx & ctx_clip) {
         const auto & hparams = ctx_clip.model.hparams;
         ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
@@ -3217,6 +3263,17 @@ struct clip_model_loader {
         const int n_nodes = ggml_graph_n_nodes(gf);

         LOG_INF("%s: graph splits = %d, nodes = %d\n", __func__, n_splits, n_nodes);
+
+        // check flash attention support
+        for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+            ggml_tensor * node = ggml_graph_node(gf, i);
+            if (node->op == GGML_OP_FLASH_ATTN_EXT) {
+                if (!ggml_backend_supports_op(ctx_clip.backend, node)) {
+                    return false;
+                }
+            }
+        }
+        return true;
     }

     void get_bool(const std::string & key, bool & output, bool required = true) {
@@ -3306,14 +3363,14 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
             ctx_vision = new clip_ctx(ctx_params);
             loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION);
             loader.load_tensors(*ctx_vision);
-            loader.alloc_compute_meta(*ctx_vision);
+            loader.warmup(*ctx_vision);
         }

         if (loader.has_audio) {
             ctx_audio = new clip_ctx(ctx_params);
             loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO);
             loader.load_tensors(*ctx_audio);
-            loader.alloc_compute_meta(*ctx_audio);
+            loader.warmup(*ctx_audio);
         }

     } catch (const std::exception & e) {
diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h
index 3387cdbd36955..afef15205e20e 100644
--- a/tools/mtmd/clip.h
+++ b/tools/mtmd/clip.h
@@ -1,6 +1,7 @@
 #pragma once

 #include "ggml.h"
+#include "mtmd.h"

 #include
 #include
@@ -25,6 +26,7 @@ enum clip_modality {
 struct clip_context_params {
     bool use_gpu;
     enum ggml_log_level verbosity;
+    llama_flash_attn_type flash_attn_type;
 };

 struct clip_init_result {
diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp
index fd1fb6581b163..17aea1472b3c6 100644
--- a/tools/mtmd/mtmd-cli.cpp
+++ b/tools/mtmd/mtmd-cli.cpp
@@ -136,6 +136,7 @@ struct mtmd_cli_context {
         mparams.print_timings = true;
         mparams.n_threads = params.cpuparams.n_threads;
         mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
+        mparams.flash_attn_type = params.flash_attn_type;
         ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
         if (!ctx_vision.get()) {
             LOG_ERR("Failed to load vision model from %s\n", clip_path);
diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
index 196641dd95ef4..88930bbb5e2f4 100644
--- a/tools/mtmd/mtmd.cpp
+++ b/tools/mtmd/mtmd.cpp
@@ -100,6 +100,7 @@ mtmd_context_params mtmd_context_params_default() {
     params.verbosity = GGML_LOG_LEVEL_INFO;
     params.image_marker = MTMD_DEFAULT_IMAGE_MARKER;
     params.media_marker = mtmd_default_marker();
+    params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
     return params;
 }

@@ -164,6 +165,7 @@ struct mtmd_context {
         clip_context_params ctx_clip_params;
         ctx_clip_params.use_gpu = ctx_params.use_gpu;
         ctx_clip_params.verbosity = ctx_params.verbosity;
+        ctx_clip_params.flash_attn_type = ctx_params.flash_attn_type;
         auto res = clip_init(mmproj_fname, ctx_clip_params);
         ctx_v = res.ctx_v;
         ctx_a = res.ctx_a;
diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h
index 0b5d2ba0c7634..91f4f0e4d705d 100644
--- a/tools/mtmd/mtmd.h
+++ b/tools/mtmd/mtmd.h
@@ -82,6 +82,7 @@ struct mtmd_context_params {
     enum ggml_log_level verbosity;
     const char * image_marker; // deprecated, use media_marker instead
     const char * media_marker;
+    llama_flash_attn_type flash_attn_type;
 };

 MTMD_API const char * mtmd_default_marker(void);
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 92d30664e41f4..fa45531eae254 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -2456,6 +2456,7 @@ struct server_context {
             mparams.print_timings = false;
             mparams.n_threads = params_base.cpuparams.n_threads;
             mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
+            mparams.flash_attn_type = params_base.flash_attn_type;
             mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
             if (mctx == nullptr) {
                 SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
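Note: patch 3 threads a flash_attn_type setting from mtmd_context_params down to the CLIP context, with AUTO probing backend support during warmup. The sketch below is hypothetical caller-side code, not part of the patch; it assumes an already-loaded llama_model pointer and the same mtmd_init_from_file() call pattern shown in mtmd-cli.cpp above, with error handling omitted:

    // hypothetical: pin the CLIP flash-attention mode instead of relying on AUTO
    #include "mtmd.h"

    mtmd_context * init_mmproj(const char * mmproj_path, const llama_model * model, bool force_fa) {
        mtmd_context_params mparams = mtmd_context_params_default(); // flash_attn_type defaults to LLAMA_FLASH_ATTN_TYPE_AUTO
        mparams.use_gpu         = true;
        mparams.flash_attn_type = force_fa ? LLAMA_FLASH_ATTN_TYPE_ENABLED
                                           : LLAMA_FLASH_ATTN_TYPE_DISABLED;

        // with AUTO, clip_model_loader::warmup() probes the backend and falls back
        // to the non-FA path when GGML_OP_FLASH_ATTN_EXT is unsupported
        return mtmd_init_from_file(mmproj_path, model, mparams);
    }

The later patches refine exactly this fallback path: they report which op was rejected and by which backend.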
From bdb43f6e9c0a8efb0ed7c6be8354ed9172099e25 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 2 Nov 2025 10:13:48 +0200
Subject: [PATCH 4/8] clip : print more detailed op support info during warmup

---
 tools/mtmd/clip.cpp | 134 +++++++++++++++++++++++++++++---------------
 tools/mtmd/clip.h   |  10 +++-
 tools/mtmd/mtmd.cpp |  16 ++++--
 tools/mtmd/mtmd.h   |   2 +-
 4 files changed, 108 insertions(+), 54 deletions(-)

diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index 74ba6c27054c8..524b370efaaf7 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -4,10 +4,8 @@
 // Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
 #include "clip.h"
 #include "clip-impl.h"
-#include "mtmd.h"
 #include "ggml.h"
 #include "ggml-cpp.h"
-#include "ggml-cpu.h"
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
 #include "gguf.h"
@@ -18,15 +16,12 @@
 #include
 #include
 #include
-#include
 #include
 #include
 #include
-#include
 #include
 #include
 #include
-#include
 #include

 // TODO: allow to pass callback from user code
@@ -428,7 +423,7 @@ struct clip_ctx {
     int max_nodes = 8192;
     ggml_backend_sched_ptr sched;
-    llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
+    clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO;
     // for debugging
     bool debug_graph = false;
     std::vector debug_print_tensors;
@@ -2266,7 +2261,7 @@ struct clip_graph {

         ggml_tensor * cur;

-        if (ctx->flash_attn_type == LLAMA_FLASH_ATTN_TYPE_ENABLED) {
+        if (ctx->flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
             ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);

             k = ggml_cast(ctx0, k, GGML_TYPE_F16);
@@ -3204,30 +3199,58 @@ struct clip_model_loader {
         }
     }

-    void warmup(clip_ctx & ctx_clip) {
-        if (ctx_clip.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
+    struct support_info_op {
+        ggml_tensor * op;
+
+        // true if the op runs on the accelerated ctx_clip.backend
+        bool is_accel = true;
+    };
+
+    struct support_info_graph {
+        // whether the clip_ctx.backend supports flash attention
+        bool fattn = true;
+
+        std::vector ops;
+    };
+
+    static void warmup(clip_ctx & ctx_clip) {
+        support_info_graph info;
+
+        if (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_AUTO) {
             // try to enable flash attention to see if it's supported
-            ctx_clip.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
-            bool supported = alloc_compute_meta(ctx_clip);
-            if (!supported) {
+            ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_ENABLED;
+            info = alloc_compute_meta(ctx_clip);
+            if (!info.fattn) {
                 LOG_WRN("%s: flash attention not supported, memory usage will increase\n", __func__);
                 // TODO: maybe log more details about why flash attention is not supported
-                ctx_clip.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
+                ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_DISABLED;
                 alloc_compute_meta(ctx_clip);
             }
         } else {
-            bool supported = alloc_compute_meta(ctx_clip);
-            if (!supported) {
+            info = alloc_compute_meta(ctx_clip);
+            if (!info.fattn && ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
                 LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__);
             }
         }

         LOG_INF("%s: flash attention is %s\n", __func__,
-            (ctx_clip.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
+            (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
+
+        // print ops that are not supported by the GPU backend (if there is one)
+        if (ctx_clip.backend && ctx_clip.backend != ctx_clip.backend_cpu) {
+            for (const auto & op : info.ops) {
+                if (!op.is_accel) {
+                    LOG_WRN("%s: op %16s is not supported by the CLIP backend: type = %s, ne = [%d %d %d %d]\n", __func__,
+                            ggml_op_name(op.op->op),
+                            ggml_type_name(op.op->type),
+                            op.op->ne[0], op.op->ne[1], op.op->ne[2], op.op->ne[3]);
+                }
+            }
+        }
     }

-    // return false if flash attention is not supported
-    bool alloc_compute_meta(clip_ctx & ctx_clip) {
+    // return false if at least one op is not supported by the backend
+    static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip) {
         const auto & hparams = ctx_clip.model.hparams;
         ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
@@ -3264,67 +3287,87 @@ struct clip_model_loader {

         LOG_INF("%s: graph splits = %d, nodes = %d\n", __func__, n_splits, n_nodes);

-        // check flash attention support
+        support_info_graph res {
+            /*.fattn = */ true,
+            /*.ops   = */ {},
+        };
+
+        // check op support
         for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
             ggml_tensor * node = ggml_graph_node(gf, i);
-            if (node->op == GGML_OP_FLASH_ATTN_EXT) {
-                if (!ggml_backend_supports_op(ctx_clip.backend, node)) {
-                    return false;
+            res.ops.push_back({node, true});
+            if (!ggml_backend_supports_op(ctx_clip.backend, node)) {
+                res.ops.back().is_accel = false;
+                if (node->op == GGML_OP_FLASH_ATTN_EXT) {
+                    res.fattn = false;
                 }
             }
         }
-        return true;
+
+        return res;
     }

-    void get_bool(const std::string & key, bool & output, bool required = true) {
+    void get_bool(const std::string & key, bool & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
             return;
         }
         output = gguf_get_val_bool(ctx_gguf.get(), i);
     }

-    void get_i32(const std::string & key, int & output, bool required = true) {
+    void get_i32(const std::string & key, int & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
            return;
         }
         output = gguf_get_val_i32(ctx_gguf.get(), i);
     }

-    void get_u32(const std::string & key, int & output, bool required = true) {
+    void get_u32(const std::string & key, int & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
             return;
         }
         output = gguf_get_val_u32(ctx_gguf.get(), i);
     }

-    void get_f32(const std::string & key, float & output, bool required = true) {
+    void get_f32(const std::string & key, float & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
             return;
         }
         output = gguf_get_val_f32(ctx_gguf.get(), i);
     }

-    void get_string(const std::string & key, std::string & output, bool required = true) {
+    void get_string(const std::string & key, std::string & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
             return;
         }
         output = std::string(gguf_get_val_str(ctx_gguf.get(), i));
     }

-    void get_arr_int(const std::string & key, std::vector & output, bool required = true) {
+    void get_arr_int(const std::string & key, std::vector & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
             return;
         }
         int n = gguf_get_arr_n(ctx_gguf.get(), i);
@@ -3335,7 +3378,7 @@ struct clip_model_loader {
         }
     }

-    void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) {
+    static void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) {
         auto & hparams = model.hparams;
         for (int x = 1; x <= max_patches_per_side; x++) {
             for (int y = 1; y <= max_patches_per_side; y++) {
@@ -3375,12 +3418,10 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params

     } catch (const std::exception & e) {
         LOG_ERR("%s: failed to load model '%s': %s\n", __func__, fname, e.what());
-        if (ctx_vision) {
-            delete ctx_vision;
-        }
-        if (ctx_audio) {
-            delete ctx_audio;
-        }
+
+        delete ctx_vision;
+        delete ctx_audio;
+
         return {nullptr, nullptr};
     }

@@ -3418,10 +3459,10 @@ void clip_image_size_free(struct clip_image_size * load_image_size) {
     }
     delete load_image_size;
 }
-void clip_image_u8_free(struct clip_image_u8 * img) { if (img) delete img; }
-void clip_image_f32_free(struct clip_image_f32 * img) { if (img) delete img; }
-void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { if (batch) delete batch; }
-void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { if (batch) delete batch; }
+void clip_image_u8_free(struct clip_image_u8 * img) { delete img; }
+void clip_image_f32_free(struct clip_image_f32 * img) { delete img; }
+void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { delete batch; }
+void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { delete batch; }

 size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch) {
     return batch->entries.size();
@@ -4539,6 +4580,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     if (ggml_backend_sched_get_n_splits(ctx->sched.get()) > 1) {
         LOG_WRN("%s: *****************************************************************\n", __func__);
         LOG_WRN("%s: WARNING: the CLIP graph uses unsupported operators by the backend\n", __func__);
+        LOG_WRN("%s: use GGML_SCHED_DEBUG=2 to determine which ops \n", __func__);
         LOG_WRN("%s: the performance will be suboptimal \n", __func__);
         LOG_WRN("%s: \n", __func__);
         LOG_WRN("%s: ref: https://github.com/ggml-org/llama.cpp/pull/16837#issuecomment-3461676118\n", __func__);
diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h
index afef15205e20e..6384e2adaf775 100644
--- a/tools/mtmd/clip.h
+++ b/tools/mtmd/clip.h
@@ -1,7 +1,7 @@
 #pragma once

 #include "ggml.h"
-#include "mtmd.h"
+
 #include
 #include
@@ -23,10 +23,16 @@ enum clip_modality {
     CLIP_MODALITY_AUDIO,
 };

+enum clip_flash_attn_type {
+    CLIP_FLASH_ATTN_TYPE_AUTO = -1,
+    CLIP_FLASH_ATTN_TYPE_DISABLED = 0,
+    CLIP_FLASH_ATTN_TYPE_ENABLED = 1,
+};
+
 struct clip_context_params {
     bool use_gpu;
     enum ggml_log_level verbosity;
-    llama_flash_attn_type flash_attn_type;
+    enum clip_flash_attn_type flash_attn_type;
 };

 struct clip_init_result {
diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
index 88930bbb5e2f4..297eef437ab91 100644
--- a/tools/mtmd/mtmd.cpp
+++ b/tools/mtmd/mtmd.cpp
@@ -19,7 +19,6 @@
 #include
 #include
 #include
-#include
 #include

 // represents raw image data, layout is RGBRGBRGB...
@@ -92,6 +91,15 @@ const char * mtmd_default_marker() {
     return "<__media__>";
 }

+static clip_flash_attn_type mtmd_get_clip_flash_attn_type(enum llama_flash_attn_type flash_attn_type) {
+    switch (flash_attn_type) {
+        case LLAMA_FLASH_ATTN_TYPE_AUTO:     return CLIP_FLASH_ATTN_TYPE_AUTO;
+        case LLAMA_FLASH_ATTN_TYPE_DISABLED: return CLIP_FLASH_ATTN_TYPE_DISABLED;
+        case LLAMA_FLASH_ATTN_TYPE_ENABLED:  return CLIP_FLASH_ATTN_TYPE_ENABLED;
+    }
+    return CLIP_FLASH_ATTN_TYPE_AUTO;
+}
+
 mtmd_context_params mtmd_context_params_default() {
     mtmd_context_params params;
     params.use_gpu = true;
@@ -165,7 +173,7 @@ struct mtmd_context {
         clip_context_params ctx_clip_params;
         ctx_clip_params.use_gpu = ctx_params.use_gpu;
         ctx_clip_params.verbosity = ctx_params.verbosity;
-        ctx_clip_params.flash_attn_type = ctx_params.flash_attn_type;
+        ctx_clip_params.flash_attn_type = mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type);
         auto res = clip_init(mmproj_fname, ctx_clip_params);
         ctx_v = res.ctx_v;
         ctx_a = res.ctx_a;
@@ -380,9 +388,7 @@ mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
 }

 void mtmd_free(mtmd_context * ctx) {
-    if (ctx) {
-        delete ctx;
-    }
+    delete ctx;
 }

 struct mtmd_tokenizer {
diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h
index 91f4f0e4d705d..4ae1925bcdfb6 100644
--- a/tools/mtmd/mtmd.h
+++ b/tools/mtmd/mtmd.h
@@ -82,7 +82,7 @@ struct mtmd_context_params {
     enum ggml_log_level verbosity;
     const char * image_marker; // deprecated, use media_marker instead
     const char * media_marker;
-    llama_flash_attn_type flash_attn_type;
+    enum llama_flash_attn_type flash_attn_type;
 };

 MTMD_API const char * mtmd_default_marker(void);

From 29330dcb5583b55f59f5f4867cb474fba388b86e Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 2 Nov 2025 10:16:17 +0200
Subject: [PATCH 5/8] cont : remove obsolete comment [no ci]

---
 tools/mtmd/clip.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index 524b370efaaf7..5116d3c094435 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -3249,7 +3249,6 @@ struct clip_model_loader {
         }
     }

-    // return false if at least one op is not supported by the backend
     static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip) {
         const auto & hparams = ctx_clip.model.hparams;
         ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
From b67a168f10d412a887da6a8d874dea8d194819b8 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sun, 2 Nov 2025 12:08:49 +0100
Subject: [PATCH 6/8] improve debugging message

---
 ggml/src/ggml-metal/ggml-metal-device.m |  1 +
 tools/mtmd/clip.cpp                     | 56 +++++++++++++++++--------
 2 files changed, 40 insertions(+), 17 deletions(-)

diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index 0cadd19a30fe9..23e0f90ab537b 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -703,6 +703,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_ARANGE:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
+            return false;
             // for new head sizes, add checks here
             if (op->src[0]->ne[0] != 32 &&
                 op->src[0]->ne[0] != 40 &&
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index 5116d3c094435..1aa5a9dedfef7 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -3209,6 +3209,7 @@ struct clip_model_loader {
     struct support_info_graph {
         // whether the clip_ctx.backend supports flash attention
         bool fattn = true;
+        ggml_tensor * fattn_op = nullptr; // for debugging

         std::vector ops;
     };
@@ -3220,9 +3221,23 @@ struct clip_model_loader {
             // try to enable flash attention to see if it's supported
             ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_ENABLED;
             info = alloc_compute_meta(ctx_clip);
-            if (!info.fattn) {
-                LOG_WRN("%s: flash attention not supported, memory usage will increase\n", __func__);
-                // TODO: maybe log more details about why flash attention is not supported
+            if (!info.fattn && info.fattn_op) {
+                auto op = info.fattn_op;
+                LOG_WRN("%s: *****************************************************************\n", __func__);
+                LOG_WRN("%s: WARNING: flash attention not supported by %s, memory usage will increase\n", __func__, ggml_backend_name(ctx_clip.backend));
+                LOG_WRN("%s: op params: \n", __func__);
+                static auto print_shape = [](const char * fn, const char * name, ggml_tensor * t) {
+                    LOG_WRN("%s: %s: type = %s, ne = [%d %d %d %d], nb = [%d %d %d %d]\n", fn,
+                            name, ggml_type_name(t->type),
+                            t->ne[0], t->ne[1], t->ne[2], t->ne[3],
+                            t->nb[0], t->nb[1], t->nb[2], t->nb[3]);
+                };
+                print_shape(__func__, " dst", op);
+                print_shape(__func__, "src0", op->src[0]);
+                print_shape(__func__, "src1", op->src[1]);
+                print_shape(__func__, "src2", op->src[2]);
+                LOG_WRN("%s: please report this on github as an issue\n", __func__);
+                LOG_WRN("%s: *****************************************************************\n", __func__);
                 ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_DISABLED;
                 alloc_compute_meta(ctx_clip);
             }
@@ -3238,13 +3253,28 @@ struct clip_model_loader {

         // print ops that are not supported by the GPU backend (if there is one)
         if (ctx_clip.backend && ctx_clip.backend != ctx_clip.backend_cpu) {
+            std::vector unsupported_ops;
             for (const auto & op : info.ops) {
                 if (!op.is_accel) {
-                    LOG_WRN("%s: op %16s is not supported by the CLIP backend: type = %s, ne = [%d %d %d %d]\n", __func__,
+                    unsupported_ops.push_back(op);
+                }
+            }
+            if (!unsupported_ops.empty()) {
+                LOG_WRN("%s: *****************************************************************\n", __func__);
+                LOG_WRN("%s: WARNING: the CLIP graph uses unsupported operators by the backend\n", __func__);
+                LOG_WRN("%s: the performance will be suboptimal \n", __func__);
+                LOG_WRN("%s: list of unsupported ops (backend=%s):\n", __func__, ggml_backend_name(ctx_clip.backend));
+                for (const auto & op : unsupported_ops) {
+                    LOG_WRN("%s: %16s: type = %s, ne = [%d %d %d %d]\n", __func__,
                             ggml_op_name(op.op->op),
                             ggml_type_name(op.op->type),
                             op.op->ne[0], op.op->ne[1], op.op->ne[2], op.op->ne[3]);
                 }
+                LOG_WRN("%s: flash attention is %s\n", __func__, 
+                    (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
+                LOG_WRN("%s: please report this on github as an issue\n", __func__);
+                LOG_WRN("%s: ref: https://github.com/ggml-org/llama.cpp/pull/16837#issuecomment-3461676118\n", __func__);
+                LOG_WRN("%s: *****************************************************************\n", __func__);
             }
         }
     }
@@ -3287,8 +3317,9 @@ struct clip_model_loader {
         LOG_INF("%s: graph splits = %d, nodes = %d\n", __func__, n_splits, n_nodes);

         support_info_graph res {
-            /*.fattn = */ true,
-            /*.ops   = */ {},
+            /*.fattn    = */ true,
+            /*.fattn_op = */ nullptr,
+            /*.ops      = */ {},
         };

         // check op support
@@ -3298,7 +3329,8 @@ struct clip_model_loader {
             if (!ggml_backend_supports_op(ctx_clip.backend, node)) {
                 res.ops.back().is_accel = false;
                 if (node->op == GGML_OP_FLASH_ATTN_EXT) {
-                    res.fattn = false;
+                    res.fattn    = false;
+                    res.fattn_op = node;
                 }
             }
         }
@@ -4576,16 +4608,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         return false; // only support batch size of 1
     }

-    if (ggml_backend_sched_get_n_splits(ctx->sched.get()) > 1) {
-        LOG_WRN("%s: *****************************************************************\n", __func__);
-        LOG_WRN("%s: WARNING: the CLIP graph uses unsupported operators by the backend\n", __func__);
-        LOG_WRN("%s: use GGML_SCHED_DEBUG=2 to determine which ops \n", __func__);
-        LOG_WRN("%s: the performance will be suboptimal \n", __func__);
-        LOG_WRN("%s: \n", __func__);
-        LOG_WRN("%s: ref: https://github.com/ggml-org/llama.cpp/pull/16837#issuecomment-3461676118\n", __func__);
-        LOG_WRN("%s: *****************************************************************\n", __func__);
-    }
-
     // build the inference graph
     ctx->debug_print_tensors.clear();
     ggml_backend_sched_reset(ctx->sched.get());

From cdb3deae7674e239a93d589e780de4a03452835a Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sun, 2 Nov 2025 12:12:09 +0100
Subject: [PATCH 7/8] trailing space

---
 tools/mtmd/clip.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index 1aa5a9dedfef7..a7e1799e93d45 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -3270,7 +3270,7 @@ struct clip_model_loader {
                             ggml_type_name(op.op->type),
                             op.op->ne[0], op.op->ne[1], op.op->ne[2], op.op->ne[3]);
                 }
-                LOG_WRN("%s: flash attention is %s\n", __func__, 
+                LOG_WRN("%s: flash attention is %s\n", __func__,
                     (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
                 LOG_WRN("%s: please report this on github as an issue\n", __func__);
                 LOG_WRN("%s: ref: https://github.com/ggml-org/llama.cpp/pull/16837#issuecomment-3461676118\n", __func__);

From d441c31b194503640b03fdd239d0e81a562121b0 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 2 Nov 2025 18:24:00 +0200
Subject: [PATCH 8/8] metal : remove stray return

---
 ggml/src/ggml-metal/ggml-metal-device.m | 1 -
 1 file changed, 1 deletion(-)

diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index 23e0f90ab537b..0cadd19a30fe9 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -703,7 +703,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_ARANGE:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
-            return false;
             // for new head sizes, add checks here
             if (op->src[0]->ne[0] != 32 &&
                 op->src[0]->ne[0] != 40 &&