From c07bda55c66333efb8cbb5209fd20fb52ed60312 Mon Sep 17 00:00:00 2001 From: Dennis Keck <26092524+fellhorn@users.noreply.github.com> Date: Mon, 26 May 2025 19:27:51 +0200 Subject: [PATCH 01/13] Update llama.cpp Signed-off-by: Dennis Keck <26092524+fellhorn@users.noreply.github.com> --- llama-cpp-2/src/context/kv_cache.rs | 136 +++------------------------- llama-cpp-sys-2/llama.cpp | 2 +- 2 files changed, 13 insertions(+), 125 deletions(-) diff --git a/llama-cpp-2/src/context/kv_cache.rs b/llama-cpp-2/src/context/kv_cache.rs index d90a6b8a..7fd641bc 100644 --- a/llama-cpp-2/src/context/kv_cache.rs +++ b/llama-cpp-2/src/context/kv_cache.rs @@ -28,7 +28,7 @@ impl LlamaContext<'_> { /// * `dest` - The sequence id to copy the cache to. /// * `size` - The size of the cache to copy. pub fn copy_cache(&mut self, src: i32, dest: i32, size: i32) { - unsafe { llama_cpp_sys_2::llama_kv_cache_seq_cp(self.context.as_ptr(), src, dest, 0, size) } + unsafe { llama_cpp_sys_2::llama_kv_self_seq_cp(self.context.as_ptr(), src, dest, 0, size) } } /// Copy the cache from one sequence to another. @@ -58,7 +58,7 @@ impl LlamaContext<'_> { .map_or(Ok(-1), i32::try_from) .map_err(KvCacheConversionError::P1TooLarge)?; unsafe { - llama_cpp_sys_2::llama_kv_cache_seq_cp(self.context.as_ptr(), src, dest, p0, p1); + llama_cpp_sys_2::llama_kv_self_seq_cp(self.context.as_ptr(), src, dest, p0, p1); } Ok(()) } @@ -92,18 +92,18 @@ impl LlamaContext<'_> { let p1 = p1 .map_or(Ok(-1), i32::try_from) .map_err(KvCacheConversionError::P1TooLarge)?; - Ok(unsafe { llama_cpp_sys_2::llama_kv_cache_seq_rm(self.context.as_ptr(), src, p0, p1) }) + Ok(unsafe { llama_cpp_sys_2::llama_kv_self_seq_rm(self.context.as_ptr(), src, p0, p1) }) } /// Returns the number of used KV cells (i.e. have at least one sequence assigned to them) #[must_use] pub fn get_kv_cache_used_cells(&self) -> i32 { - unsafe { llama_cpp_sys_2::llama_get_kv_cache_used_cells(self.context.as_ptr()) } + unsafe { llama_cpp_sys_2::llama_kv_self_used_cells(self.context.as_ptr()) } } /// Clear the KV cache pub fn clear_kv_cache(&mut self) { - unsafe { llama_cpp_sys_2::llama_kv_cache_clear(self.context.as_ptr()) } + unsafe { llama_cpp_sys_2::llama_kv_self_clear(self.context.as_ptr()) } } /// Removes all tokens that do not belong to the specified sequence @@ -112,7 +112,7 @@ impl LlamaContext<'_> { /// /// * `seq_id` - The sequence id to keep pub fn llama_kv_cache_seq_keep(&mut self, seq_id: i32) { - unsafe { llama_cpp_sys_2::llama_kv_cache_seq_keep(self.context.as_ptr(), seq_id) } + unsafe { llama_cpp_sys_2::llama_kv_self_seq_keep(self.context.as_ptr(), seq_id) } } #[allow(clippy::doc_markdown)] @@ -147,7 +147,7 @@ impl LlamaContext<'_> { .map_or(Ok(-1), i32::try_from) .map_err(KvCacheConversionError::P1TooLarge)?; unsafe { - llama_cpp_sys_2::llama_kv_cache_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta); + llama_cpp_sys_2::llama_kv_self_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta); } Ok(()) } @@ -183,7 +183,7 @@ impl LlamaContext<'_> { .map_or(Ok(-1), i32::try_from) .map_err(KvCacheConversionError::P1TooLarge)?; let d = c_int::from(d.get()); - unsafe { llama_cpp_sys_2::llama_kv_cache_seq_div(self.context.as_ptr(), seq_id, p0, p1, d) } + unsafe { llama_cpp_sys_2::llama_kv_self_seq_div(self.context.as_ptr(), seq_id, p0, p1, d) } Ok(()) } @@ -194,7 +194,7 @@ impl LlamaContext<'_> { /// * `seq_id` - The sequence id to get the max position for #[must_use] pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 { - unsafe { 
llama_cpp_sys_2::llama_kv_cache_seq_pos_max(self.context.as_ptr(), seq_id) } + unsafe { llama_cpp_sys_2::llama_kv_self_seq_pos_max(self.context.as_ptr(), seq_id) } } /// Defragment the KV cache @@ -202,130 +202,18 @@ impl LlamaContext<'_> { /// - lazily on next [`LlamaContext::decode`] /// - explicitly with [`Self::kv_cache_update`] pub fn kv_cache_defrag(&mut self) { - unsafe { llama_cpp_sys_2::llama_kv_cache_defrag(self.context.as_ptr()) } + unsafe { llama_cpp_sys_2::llama_kv_self_defrag(self.context.as_ptr()) } } /// Apply the KV cache updates (such as K-shifts, defragmentation, etc.) pub fn kv_cache_update(&mut self) { - unsafe { llama_cpp_sys_2::llama_kv_cache_update(self.context.as_ptr()) } + unsafe { llama_cpp_sys_2::llama_kv_self_update(self.context.as_ptr()) } } /// Returns the number of tokens in the KV cache (slow, use only for debug) /// If a KV cell has multiple sequences assigned to it, it will be counted multiple times #[must_use] pub fn get_kv_cache_token_count(&self) -> i32 { - unsafe { llama_cpp_sys_2::llama_get_kv_cache_token_count(self.context.as_ptr()) } - } - - /// Create an empty KV cache view. (use only for debugging purposes) - /// - /// # Parameters - /// - /// * `n_max_seq` - Maximum number of sequences that can exist in a cell. It's not an error - /// if there are more sequences in a cell than this value, however they will - /// not be visible in the view `cells_sequences`. - #[must_use] - pub fn new_kv_cache_view(&self, n_max_seq: i32) -> KVCacheView { - let view = - unsafe { llama_cpp_sys_2::llama_kv_cache_view_init(self.context.as_ptr(), n_max_seq) }; - KVCacheView { view, ctx: self } - } -} - -/// Information associated with an individual cell in the KV cache view. -#[derive(Debug)] -pub struct KVCacheViewCell { - /// The position for this cell. Takes KV cache shifts into account. - /// May be negative if the cell is not populated. - pub pos: llama_cpp_sys_2::llama_pos, -} - -/// An updateable view of the KV cache. (use only for debugging purposes) -#[derive(Debug)] -pub struct KVCacheView<'a> { - ctx: &'a LlamaContext<'a>, - view: llama_cpp_sys_2::llama_kv_cache_view, -} - -impl KVCacheView<'_> { - /// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes) - pub fn update(&mut self) { - unsafe { - llama_cpp_sys_2::llama_kv_cache_view_update(self.ctx.context.as_ptr(), &mut self.view); - } - } - - /// Number of KV cache cells. This will be the same as the context size. - #[must_use] - pub fn n_cells(&self) -> i32 { - self.view.n_cells - } - - /// Number of tokens in the cache. For example, if there are two populated - /// cells, the first with 1 sequence id in it and the second with 2 sequence - /// ids then you'll have 3 tokens. - #[must_use] - pub fn token_count(&self) -> i32 { - self.view.token_count - } - - /// Number of populated cache cells. - #[must_use] - pub fn used_cells(&self) -> i32 { - self.view.used_cells - } - - /// Maximum contiguous empty slots in the cache. - #[must_use] - pub fn max_contiguous(&self) -> i32 { - self.view.max_contiguous - } - - /// Index to the start of the `max_contiguous` slot range. Can be negative - /// when cache is full. - #[must_use] - pub fn max_contiguous_idx(&self) -> i32 { - self.view.max_contiguous_idx - } - - /// Information for individual cells. - /// - /// # Panics - /// - /// - if `n_cells` does not fit into usize. 
- pub fn cells(&self) -> impl Iterator { - unsafe { - std::slice::from_raw_parts( - self.view.cells, - usize::try_from(self.view.n_cells).expect("failed to fit n_cells into usize"), - ) - } - .iter() - .map(|&cell| KVCacheViewCell { pos: cell.pos }) - } - - /// The sequences for each cell. There will be `n_max_seq` items per cell. - /// - /// # Panics - /// - /// - if `n_cells * n_max_seq` does not fit into usize. - /// - if `n_max_seq` does not fit into usize. - pub fn cells_sequences(&self) -> impl Iterator { - unsafe { - std::slice::from_raw_parts( - self.view.cells_sequences, - usize::try_from(self.view.n_cells * self.view.n_seq_max) - .expect("failed to fit n_cells * n_max_seq into usize"), - ) - } - .chunks(usize::try_from(self.view.n_seq_max).expect("failed to fit n_max_seq into usize")) - } -} - -impl Drop for KVCacheView<'_> { - fn drop(&mut self) { - unsafe { - llama_cpp_sys_2::llama_kv_cache_view_free(&mut self.view); - } + unsafe { llama_cpp_sys_2::llama_kv_self_n_tokens(self.context.as_ptr()) } } } diff --git a/llama-cpp-sys-2/llama.cpp b/llama-cpp-sys-2/llama.cpp index ceda28ef..79c137f7 160000 --- a/llama-cpp-sys-2/llama.cpp +++ b/llama-cpp-sys-2/llama.cpp @@ -1 +1 @@ -Subproject commit ceda28ef8e310a8dee60bf275077a3eedae8e36c +Subproject commit 79c137f77677b3c8ee3c60a7da033721b938399a From 945489787a3dec05d557ef4c2406428bfc110cd0 Mon Sep 17 00:00:00 2001 From: Dennis Keck <26092524+fellhorn@users.noreply.github.com> Date: Mon, 21 Jul 2025 23:25:22 +0200 Subject: [PATCH 02/13] WIP Signed-off-by: Dennis Keck <26092524+fellhorn@users.noreply.github.com> --- Cargo.lock | 8 + Cargo.toml | 1 + examples/mtmd/Cargo.toml | 24 ++ examples/mtmd/src/mtmd.rs | 298 ++++++++++++++++ examples/reranker/Cargo.toml | 2 +- llama-cpp-2/src/lib.rs | 1 + llama-cpp-2/src/model.rs | 9 +- llama-cpp-2/src/mtmd.rs | 530 ++++++++++++++++++++++++++++ llama-cpp-2/src/sampling.rs | 40 +-- llama-cpp-2/src/timing.rs | 2 + llama-cpp-2/src/token/logit_bias.rs | 17 +- llama-cpp-sys-2/Cargo.toml | 15 +- llama-cpp-sys-2/build.rs | 77 +++- llama-cpp-sys-2/llama.cpp | 2 +- llama-cpp-sys-2/wrapper.h | 4 +- 15 files changed, 965 insertions(+), 65 deletions(-) create mode 100644 examples/mtmd/Cargo.toml create mode 100644 examples/mtmd/src/mtmd.rs create mode 100644 llama-cpp-2/src/mtmd.rs diff --git a/Cargo.lock b/Cargo.lock index 3f7f3918..820bb162 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -711,6 +711,14 @@ dependencies = [ "adler", ] +[[package]] +name = "mtmd" +version = "0.1.86" +dependencies = [ + "clap", + "llama-cpp-2", +] + [[package]] name = "native-tls" version = "0.2.12" diff --git a/Cargo.toml b/Cargo.toml index f1b63014..bbd4bd8d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ members = [ "examples/embeddings", "examples/simple", "examples/reranker", + "examples/mtmd", ] [workspace.dependencies] diff --git a/examples/mtmd/Cargo.toml b/examples/mtmd/Cargo.toml new file mode 100644 index 00000000..24aed021 --- /dev/null +++ b/examples/mtmd/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "mtmd" +version = "0.1.86" +edition = "2021" + +[dependencies] +llama-cpp-2 = { path = "../../llama-cpp-2", version = "0.1.86" } +clap = { workspace = true, features = ["derive"] } +# hf-hub = { workspace = true } +# anyhow = { workspace = true } +# encoding_rs = { workspace = true } + +[features] +cuda = ["llama-cpp-2/cuda"] +metal = ["llama-cpp-2/metal"] +native = ["llama-cpp-2/native"] +vulkan = ["llama-cpp-2/vulkan"] + +[lints] +workspace = true + +[[example]] +name = "mtmd" +path = "src/mtmd.rs" diff --git 
a/examples/mtmd/src/mtmd.rs b/examples/mtmd/src/mtmd.rs
new file mode 100644
index 00000000..ba4b30ae
--- /dev/null
+++ b/examples/mtmd/src/mtmd.rs
@@ -0,0 +1,298 @@
+//! Based on the mtmd cli example from llama.cpp.
+
+use std::ffi::CString;
+use std::io::{self, Write};
+use std::path::Path;
+
+use clap::Parser;
+
+use llama_cpp_2::context::params::LlamaContextParams;
+use llama_cpp_2::context::LlamaContext;
+use llama_cpp_2::llama_batch::LlamaBatch;
+use llama_cpp_2::model::params::LlamaModelParams;
+use llama_cpp_2::mtmd::*;
+
+use llama_cpp_2::llama_backend::LlamaBackend;
+use llama_cpp_2::model::{LlamaChatMessage, LlamaChatTemplate, LlamaModel, Special};
+use llama_cpp_2::sampling::LlamaSampler;
+
+/// Command line parameters for the MTMD CLI application
+#[derive(clap::Parser, Debug)]
+#[command(name = "mtmd-cli")]
+#[command(about = "Experimental CLI for multimodal llama.cpp")]
+pub struct MtmdCliParams {
+    /// Path to the model file
+    #[arg(short = 'm', long = "model", value_name = "PATH")]
+    pub model_path: String,
+    /// Path to the multimodal projection file
+    #[arg(long = "mmproj", value_name = "PATH")]
+    pub mmproj_path: String,
+    /// Path to image file(s)
+    #[arg(long = "image", value_name = "PATH")]
+    pub images: Vec<String>,
+    /// Path to audio file(s)
+    #[arg(long = "audio", value_name = "PATH")]
+    pub audio: Vec<String>,
+    /// Text prompt to use as input to the model. May include media markers;
+    /// if none are present, a marker is appended automatically.
+    #[arg(short = 'p', long = "prompt", value_name = "TEXT")]
+    pub prompt: String,
+    /// Number of tokens to predict (-1 for unlimited)
+    #[arg(
+        short = 'n',
+        long = "n-predict",
+        value_name = "N",
+        default_value = "-1"
+    )]
+    pub n_predict: i32,
+    /// Number of threads
+    #[arg(short = 't', long = "threads", value_name = "N", default_value = "4")]
+    pub n_threads: i32,
+    /// Maximum number of tokens in context
+    #[arg(long = "n-tokens", value_name = "N", default_value = "2048")]
+    pub n_tokens: usize,
+    /// Chat template to use; falls back to the model's default template if not provided
+    #[arg(long = "chat-template", value_name = "TEMPLATE")]
+    pub chat_template: Option<String>,
+    /// Disable GPU acceleration
+    #[arg(long = "no-gpu")]
+    pub no_gpu: bool,
+    /// Disable GPU offload for multimodal projection
+    #[arg(long = "no-mmproj-offload")]
+    pub no_mmproj_offload: bool,
+    /// Media marker. If not provided, the default marker will be used.
+    #[arg(long = "marker", value_name = "TEXT")]
+    pub media_marker: Option<String>,
+}
+
+/// State of the MTMD CLI application.
+#[allow(missing_debug_implementations)]
+pub struct MtmdCliContext {
+    /// The MTMD context for multimodal processing.
+    pub mtmd_ctx: MtmdContext,
+    /// The batch used for processing tokens.
+    pub batch: LlamaBatch,
+    /// The list of loaded bitmaps (images/audio).
+    pub bitmaps: Vec<MtmdBitmap>,
+    /// The number of past tokens processed.
+    pub n_past: i32,
+    /// The chat template used for formatting messages.
+    pub chat_template: LlamaChatTemplate,
+    /// The current chat messages history.
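+    /// Messages accumulate across turns; `eval_message` re-renders the whole
+    /// history through the chat template on every call.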
+    pub chat: Vec<LlamaChatMessage>,
+}
+
+impl MtmdCliContext {
+    /// Creates a new MTMD CLI context
+    pub fn new(
+        params: &MtmdCliParams,
+        model: &LlamaModel,
+    ) -> Result<Self, Box<dyn std::error::Error>> {
+        // Initialize MTMD context
+        let mtmd_params = MtmdContextParams {
+            use_gpu: !params.no_gpu && !params.no_mmproj_offload,
+            print_timings: true,
+            n_threads: params.n_threads,
+            media_marker: CString::new(
+                params
+                    .media_marker
+                    .as_ref()
+                    .unwrap_or(&llama_cpp_2::mtmd::mtmd_default_marker().to_string())
+                    .clone(),
+            )?,
+        };
+
+        let mtmd_ctx = MtmdContext::init_from_file(&params.mmproj_path, model, mtmd_params)?;
+
+        let chat_template = model
+            .chat_template(params.chat_template.as_deref())
+            .map_err(|e| format!("Failed to get chat template: {}", e))?;
+
+        let batch = LlamaBatch::new(params.n_tokens, 1);
+
+        Ok(Self {
+            mtmd_ctx,
+            batch,
+            chat: Vec::new(),
+            bitmaps: Vec::new(),
+            n_past: 0,
+            chat_template,
+        })
+    }
+
+    /// Loads media (image or audio) from the specified file path
+    pub fn load_media(&mut self, path: &str) -> Result<(), MtmdBitmapError> {
+        let bitmap = MtmdBitmap::from_file(&self.mtmd_ctx, path)?;
+        self.bitmaps.push(bitmap);
+        Ok(())
+    }
+
+    /// Evaluates a chat message, tokenizing and processing it through the model
+    pub fn eval_message(
+        &mut self,
+        model: &LlamaModel,
+        context: &mut LlamaContext,
+        msg: LlamaChatMessage,
+        add_bos: bool,
+    ) -> Result<(), Box<dyn std::error::Error>> {
+        self.chat.push(msg);
+
+        // Format the message using chat template (simplified)
+        let formatted_prompt = model.apply_chat_template(&self.chat_template, &self.chat, true)?;
+
+        let input_text = MtmdInputText {
+            text: formatted_prompt,
+            add_special: add_bos,
+            parse_special: true,
+        };
+
+        let bitmap_refs: Vec<&MtmdBitmap> = self.bitmaps.iter().collect();
+
+        if bitmap_refs.is_empty() {
+            println!("No bitmaps provided, only tokenizing text");
+        } else {
+            println!("Tokenizing with {} bitmaps", bitmap_refs.len());
+        }
+
+        // Tokenize the input
+        let chunks = self.mtmd_ctx.tokenize(input_text, &bitmap_refs)?;
+
+        println!("Tokenization complete, {} chunks created", chunks.len());
+
+        // Clear bitmaps after tokenization
+        self.bitmaps.clear();
+
+        self.n_past = chunks.eval_chunks(&self.mtmd_ctx, &context, 0, 0, 1, true)?;
+        Ok(())
+    }
+
+    /// Generates a response by sampling tokens from the model
+    pub fn generate_response(
+        &mut self,
+        model: &LlamaModel,
+        context: &mut LlamaContext,
+        sampler: &mut LlamaSampler,
+        n_predict: i32,
+    ) -> Result<(), Box<dyn std::error::Error>> {
+        let mut generated_tokens = Vec::new();
+        let max_predict = if n_predict < 0 { i32::MAX } else { n_predict };
+
+        for _i in 0..max_predict {
+            // Sample next token
+            let token = sampler.sample(context, 0);
+            generated_tokens.push(token);
+            sampler.accept(token);
+
+            // Check for end of generation
+            if model.is_eog_token(token) {
+                println!();
+                break;
+            }
+
+            // Print token
+            let piece = model.token_to_str(token, Special::Tokenize)?;
+            print!("{}", piece);
+            io::stdout().flush()?;
+
+            // Prepare next batch
+            self.batch.clear();
+            self.batch.add(token, self.n_past, &[0], true)?;
+            self.n_past += 1;
+
+            // Decode
+            context.decode(&mut self.batch)?;
+        }
+
+        Ok(())
+    }
+}
+
+fn run_single_turn(
+    ctx: &mut MtmdCliContext,
+    model: &LlamaModel,
+    context: &mut LlamaContext,
+    sampler: &mut LlamaSampler,
+    params: &MtmdCliParams,
+) -> Result<(), Box<dyn std::error::Error>> {
+    // Add media marker if not present
+    let mut prompt = params.prompt.clone();
+    let default_marker = llama_cpp_2::mtmd::mtmd_default_marker().to_string();
+    let media_marker = params.media_marker.as_ref().unwrap_or(&default_marker);
+    if
!prompt.contains(media_marker) {
+        prompt.push_str(media_marker);
+    }
+
+    // Load media files
+    for image_path in &params.images {
+        println!("Loading image: {}", image_path);
+        ctx.load_media(image_path)?;
+    }
+    for audio_path in &params.audio {
+        ctx.load_media(audio_path)?;
+    }
+
+    // Create user message
+    let msg = LlamaChatMessage::new("user".to_string(), prompt)?;
+
+    println!("Evaluating message: {:?}", msg);
+
+    // Evaluate the message (prefill)
+    ctx.eval_message(model, context, msg, true)?;
+
+    // Generate response (decode)
+    ctx.generate_response(model, context, sampler, params.n_predict)?;
+
+    Ok(())
+}
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let params = MtmdCliParams::parse();
+
+    // Validate required parameters
+    if !Path::new(&params.model_path).exists() {
+        eprintln!("Error: Model file not found: {}", params.model_path);
+        return Err("Model file not found".into());
+    }
+
+    if !Path::new(&params.mmproj_path).exists() {
+        eprintln!(
+            "Error: Multimodal projection file not found: {}",
+            params.mmproj_path
+        );
+        return Err("Multimodal projection file not found".into());
+    }
+
+    println!("Loading model: {}", params.model_path);
+
+    // Initialize backend
+    let backend = LlamaBackend::init()?;
+
+    // Setup model parameters
+    let mut model_params = LlamaModelParams::default();
+    if !params.no_gpu {
+        model_params = model_params.with_n_gpu_layers(1000000); // Use all layers on GPU
+    }
+
+    // Load model
+    let model = LlamaModel::load_from_file(&backend, &params.model_path, &model_params)?;
+
+    // Create context
+    let context_params = LlamaContextParams::default()
+        .with_n_threads(params.n_threads)
+        .with_n_batch(1);
+    let mut context = model.new_context(&backend, context_params)?;
+
+    // Create sampler
+    let mut sampler = LlamaSampler::chain_simple([LlamaSampler::greedy()]);
+
+    println!("Model loaded successfully");
+    println!("Loading mtmd projection: {}", params.mmproj_path);
+
+    // Create the MTMD context
+    let mut ctx = MtmdCliContext::new(&params, &model)?;
+
+    run_single_turn(&mut ctx, &model, &mut context, &mut sampler, &params)?;
+
+    println!("\n");
+
+    Ok(())
+}
diff --git a/examples/reranker/Cargo.toml b/examples/reranker/Cargo.toml
index fa32c2d3..753a7dd1 100644
--- a/examples/reranker/Cargo.toml
+++ b/examples/reranker/Cargo.toml
@@ -17,4 +17,4 @@ native = ["llama-cpp-2/native"]
 vulkan = ["llama-cpp-2/vulkan"]
 
 [lints]
-workspace = true
\ No newline at end of file
+workspace = true
diff --git a/llama-cpp-2/src/lib.rs b/llama-cpp-2/src/lib.rs
index f2ac5313..2c79ba46 100644
--- a/llama-cpp-2/src/lib.rs
+++ b/llama-cpp-2/src/lib.rs
@@ -27,6 +27,7 @@ pub mod llama_backend;
 pub mod llama_batch;
 mod log;
 pub mod model;
+pub mod mtmd;
 pub mod sampling;
 pub mod timing;
 pub mod token;
diff --git a/llama-cpp-2/src/model.rs b/llama-cpp-2/src/model.rs
index b8cd26bb..d2df9990 100644
--- a/llama-cpp-2/src/model.rs
+++ b/llama-cpp-2/src/model.rs
@@ -13,9 +13,9 @@ use crate::model::params::LlamaModelParams;
 use crate::token::LlamaToken;
 use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
 use crate::{
-    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError,
-    LlamaLoraAdapterInitError, LlamaModelLoadError, MetaValError, NewLlamaChatMessageError,
-    StringToTokenError, TokenToStringError,
+    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
+    LlamaModelLoadError, MetaValError, NewLlamaChatMessageError, StringToTokenError,
+    TokenToStringError,
 };
 
 pub mod params;
@@ -488,7 +488,8 @@ impl LlamaModel {
     pub fn n_head_kv(&self) -> u32 {
         // It's never possible
for this to panic because while the API interface is defined as an int32_t,
         // the field it's accessing is a uint32_t.
-        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head_kv(self.model.as_ptr()) }).unwrap()
+        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head_kv(self.model.as_ptr()) })
+            .unwrap()
     }
 
     /// Get metadata value as a string by key name
diff --git a/llama-cpp-2/src/mtmd.rs b/llama-cpp-2/src/mtmd.rs
new file mode 100644
index 00000000..616b2d48
--- /dev/null
+++ b/llama-cpp-2/src/mtmd.rs
@@ -0,0 +1,530 @@
+use std::ffi::{CStr, CString};
+use std::ptr::NonNull;
+use std::slice;
+
+use crate::context::LlamaContext;
+use crate::model::LlamaModel;
+use crate::token::LlamaToken;
+
+/// Input chunk types for multimodal data
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum MtmdInputChunkType {
+    Text = llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_TEXT as isize,
+    Image = llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_IMAGE as isize,
+    Audio = llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_AUDIO as isize,
+}
+
+impl From<llama_cpp_sys_2::mtmd_input_chunk_type> for MtmdInputChunkType {
+    fn from(chunk_type: llama_cpp_sys_2::mtmd_input_chunk_type) -> Self {
+        match chunk_type {
+            llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_TEXT => MtmdInputChunkType::Text,
+            llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_IMAGE => MtmdInputChunkType::Image,
+            llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_AUDIO => MtmdInputChunkType::Audio,
+            _ => panic!("Unknown MTMD input chunk type"),
+        }
+    }
+}
+
+/// Configuration parameters for MTMD context
+#[derive(Debug, Clone)]
+pub struct MtmdContextParams {
+    pub use_gpu: bool,
+    pub print_timings: bool,
+    pub n_threads: i32,
+    pub media_marker: CString,
+}
+
+impl Default for MtmdContextParams {
+    fn default() -> Self {
+        Self {
+            use_gpu: false,
+            print_timings: true,
+            n_threads: 4,
+            media_marker: CString::new(mtmd_default_marker()).unwrap_or_default(),
+        }
+    }
+}
+
+impl From<&MtmdContextParams> for llama_cpp_sys_2::mtmd_context_params {
+    fn from(params: &MtmdContextParams) -> Self {
+        let mut context = unsafe { llama_cpp_sys_2::mtmd_context_params_default() };
+
+        context.use_gpu = params.use_gpu;
+        context.print_timings = params.print_timings;
+        context.n_threads = params.n_threads;
+        context.media_marker = params.media_marker.as_ptr();
+
+        context
+    }
+}
+
+/// Text input configuration
+#[derive(Debug, Clone)]
+pub struct MtmdInputText {
+    pub text: String,
+    pub add_special: bool,
+    pub parse_special: bool,
+}
+
+/// Safe wrapper around `mtmd_context`
+pub struct MtmdContext {
+    pub(crate) context: NonNull<llama_cpp_sys_2::mtmd_context>,
+}
+
+impl MtmdContext {
+    /// Initialize MTMD context from a multimodal projection file
+    pub fn init_from_file(
+        mmproj_path: &str,
+        text_model: &LlamaModel,
+        params: MtmdContextParams,
+    ) -> Result<Self, MtmdInitError> {
+        let path_cstr = CString::new(mmproj_path)?;
+        let ctx_params = llama_cpp_sys_2::mtmd_context_params::from(&params);
+
+        let context = unsafe {
+            llama_cpp_sys_2::mtmd_init_from_file(
+                path_cstr.as_ptr(),
+                text_model.model.as_ptr(),
+                ctx_params,
+            )
+        };
+
+        if context.is_null() {
+            return Err(MtmdInitError::NullResult);
+        }
+
+        let context = NonNull::new(context).ok_or(MtmdInitError::NullResult)?;
+        Ok(Self { context })
+    }
+
+    /// Check if non-causal mask is needed before llama_decode
+    pub fn decode_use_non_causal(&self) -> bool {
+        unsafe { llama_cpp_sys_2::mtmd_decode_use_non_causal(self.context.as_ptr()) }
+    }
+
+    /// Check if the model uses M-RoPE for llama_decode
+    pub fn decode_use_mrope(&self) -> bool {
+        unsafe { llama_cpp_sys_2::mtmd_decode_use_mrope(self.context.as_ptr()) }
+    }
+
+    /// Check
if the model supports vision input
+    pub fn support_vision(&self) -> bool {
+        unsafe { llama_cpp_sys_2::mtmd_support_vision(self.context.as_ptr()) }
+    }
+
+    /// Check if the model supports audio input
+    pub fn support_audio(&self) -> bool {
+        unsafe { llama_cpp_sys_2::mtmd_support_audio(self.context.as_ptr()) }
+    }
+
+    /// Tokenize input text and bitmaps into chunks
+    pub fn tokenize(
+        &self,
+        text: MtmdInputText,
+        bitmaps: &[&MtmdBitmap],
+    ) -> Result<MtmdInputChunks, MtmdTokenizeError> {
+        let chunks = MtmdInputChunks::new();
+        let text_cstring = CString::new(text.text).unwrap_or_default();
+        let input_text = llama_cpp_sys_2::mtmd_input_text {
+            text: text_cstring.as_ptr(),
+            add_special: text.add_special,
+            parse_special: text.parse_special,
+        };
+
+        // Create bitmap pointers
+        let bitmap_ptrs: Vec<*const llama_cpp_sys_2::mtmd_bitmap> = bitmaps
+            .iter()
+            .map(|b| b.bitmap.as_ptr() as *const _)
+            .collect();
+
+        let result = unsafe {
+            llama_cpp_sys_2::mtmd_tokenize(
+                self.context.as_ptr(),
+                chunks.chunks.as_ptr(),
+                &input_text,
+                bitmap_ptrs.as_ptr() as *mut *const llama_cpp_sys_2::mtmd_bitmap,
+                bitmaps.len(),
+            )
+        };
+
+        match result {
+            0 => Ok(chunks),
+            1 => Err(MtmdTokenizeError::BitmapCountMismatch),
+            2 => Err(MtmdTokenizeError::ImagePreprocessingError),
+            _ => Err(MtmdTokenizeError::UnknownError(result)),
+        }
+    }
+
+    /// Encode a chunk (for image/audio processing)
+    pub fn encode_chunk(&self, chunk: &MtmdInputChunk) -> Result<(), MtmdEncodeError> {
+        let result = unsafe {
+            llama_cpp_sys_2::mtmd_encode_chunk(self.context.as_ptr(), chunk.chunk.as_ptr())
+        };
+
+        if result == 0 {
+            Ok(())
+        } else {
+            Err(MtmdEncodeError::EncodeFailure(result))
+        }
+    }
+
+    /// Get output embeddings from the last encode pass
+    pub fn get_output_embeddings(&self) -> Option<&[f32]> {
+        let ptr = unsafe { llama_cpp_sys_2::mtmd_get_output_embd(self.context.as_ptr()) };
+        if ptr.is_null() {
+            None
+        } else {
+            // Note: The size calculation would need context about the model and chunk
+            // For now, returning None when we can't determine size safely
+            None
+        }
+    }
+}
+
+unsafe impl Send for MtmdContext {}
+unsafe impl Sync for MtmdContext {}
+
+impl Drop for MtmdContext {
+    fn drop(&mut self) {
+        unsafe { llama_cpp_sys_2::mtmd_free(self.context.as_ptr()) }
+    }
+}
+
+/// Safe wrapper around `mtmd_bitmap`
+#[derive(Debug, Clone)]
+pub struct MtmdBitmap {
+    pub(crate) bitmap: NonNull<llama_cpp_sys_2::mtmd_bitmap>,
+}
+
+impl MtmdBitmap {
+    /// Create a bitmap from image data (RGB format)
+    pub fn from_image_data(nx: u32, ny: u32, data: &[u8]) -> Result<Self, MtmdBitmapError> {
+        if data.len() != (nx * ny * 3) as usize {
+            return Err(MtmdBitmapError::InvalidDataSize);
+        }
+
+        let bitmap = unsafe { llama_cpp_sys_2::mtmd_bitmap_init(nx, ny, data.as_ptr()) };
+
+        let bitmap = NonNull::new(bitmap).ok_or(MtmdBitmapError::NullResult)?;
+        Ok(Self { bitmap })
+    }
+
+    /// Create a bitmap from audio data (PCM F32 format)
+    pub fn from_audio_data(data: &[f32]) -> Result<Self, MtmdBitmapError> {
+        let bitmap =
+            unsafe { llama_cpp_sys_2::mtmd_bitmap_init_from_audio(data.len(), data.as_ptr()) };
+
+        let bitmap = NonNull::new(bitmap).ok_or(MtmdBitmapError::NullResult)?;
+        Ok(Self { bitmap })
+    }
+
+    /// Create a bitmap from a file
+    pub fn from_file(ctx: &MtmdContext, path: &str) -> Result<Self, MtmdBitmapError> {
+        let path_cstr = CString::new(path)?;
+        let bitmap = unsafe {
+            llama_cpp_sys_2::mtmd_helper_bitmap_init_from_file(
+                ctx.context.as_ptr(),
+                path_cstr.as_ptr(),
+            )
+        };
+
+        let bitmap = NonNull::new(bitmap).ok_or(MtmdBitmapError::NullResult)?;
+        Ok(Self { bitmap })
+    }
+
+    /// Create a bitmap from a buffer containing file data
+    pub fn from_buffer(ctx: &MtmdContext, data: &[u8]) -> Result<Self, MtmdBitmapError> {
+        let bitmap = unsafe {
+            llama_cpp_sys_2::mtmd_helper_bitmap_init_from_buf(
+                ctx.context.as_ptr(),
+                data.as_ptr(),
+                data.len(),
+            )
+        };
+
+        let bitmap = NonNull::new(bitmap).ok_or(MtmdBitmapError::NullResult)?;
+        Ok(Self { bitmap })
+    }
+
+    /// Get bitmap width
+    pub fn nx(&self) -> u32 {
+        unsafe { llama_cpp_sys_2::mtmd_bitmap_get_nx(self.bitmap.as_ptr()) }
+    }
+
+    /// Get bitmap height
+    pub fn ny(&self) -> u32 {
+        unsafe { llama_cpp_sys_2::mtmd_bitmap_get_ny(self.bitmap.as_ptr()) }
+    }
+
+    /// Get bitmap data as bytes
+    pub fn data(&self) -> &[u8] {
+        let ptr = unsafe { llama_cpp_sys_2::mtmd_bitmap_get_data(self.bitmap.as_ptr()) };
+        let len = unsafe { llama_cpp_sys_2::mtmd_bitmap_get_n_bytes(self.bitmap.as_ptr()) };
+        unsafe { slice::from_raw_parts(ptr, len) }
+    }
+
+    /// Check if this is an audio bitmap
+    pub fn is_audio(&self) -> bool {
+        unsafe { llama_cpp_sys_2::mtmd_bitmap_is_audio(self.bitmap.as_ptr()) }
+    }
+
+    /// Get bitmap ID (if set)
+    pub fn id(&self) -> Option<String> {
+        let ptr = unsafe { llama_cpp_sys_2::mtmd_bitmap_get_id(self.bitmap.as_ptr()) };
+        if ptr.is_null() {
+            None
+        } else {
+            unsafe { CStr::from_ptr(ptr) }
+                .to_string_lossy()
+                .into_owned()
+                .into()
+        }
+    }
+
+    /// Set bitmap ID
+    pub fn set_id(&self, id: &str) -> Result<(), std::ffi::NulError> {
+        let id_cstr = CString::new(id)?;
+        unsafe {
+            llama_cpp_sys_2::mtmd_bitmap_set_id(self.bitmap.as_ptr(), id_cstr.as_ptr());
+        }
+        Ok(())
+    }
+}
+
+unsafe impl Send for MtmdBitmap {}
+unsafe impl Sync for MtmdBitmap {}
+
+impl Drop for MtmdBitmap {
+    fn drop(&mut self) {
+        unsafe { llama_cpp_sys_2::mtmd_bitmap_free(self.bitmap.as_ptr()) }
+    }
+}
+
+/// Safe wrapper around `mtmd_input_chunks`
+pub struct MtmdInputChunks {
+    pub(crate) chunks: NonNull<llama_cpp_sys_2::mtmd_input_chunks>,
+}
+
+impl MtmdInputChunks {
+    /// Create a new empty input chunks collection
+    pub fn new() -> Self {
+        let chunks = unsafe { llama_cpp_sys_2::mtmd_input_chunks_init() };
+        let chunks = NonNull::new(chunks).unwrap();
+        Self { chunks }
+    }
+
+    /// Get the number of chunks
+    pub fn len(&self) -> usize {
+        unsafe { llama_cpp_sys_2::mtmd_input_chunks_size(self.chunks.as_ptr()) }
+    }
+
+    /// Check if chunks collection is empty
+    pub fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+
+    /// Get a chunk by index
+    pub fn get(&self, index: usize) -> Option<MtmdInputChunk> {
+        if index >= self.len() {
+            return None;
+        }
+
+        let chunk_ptr =
+            unsafe { llama_cpp_sys_2::mtmd_input_chunks_get(self.chunks.as_ptr(), index) };
+
+        if chunk_ptr.is_null() {
+            None
+        } else {
+            // Note: We don't own this chunk, it's owned by the chunks collection
+            Some(MtmdInputChunk {
+                chunk: NonNull::new(chunk_ptr as *mut _).unwrap(),
+                owned: false,
+            })
+        }
+    }
+
+    /// Get total number of tokens across all chunks
+    pub fn total_tokens(&self) -> usize {
+        unsafe { llama_cpp_sys_2::mtmd_helper_get_n_tokens(self.chunks.as_ptr()) }
+    }
+
+    /// Get total position count across all chunks
+    pub fn total_positions(&self) -> i32 {
+        unsafe { llama_cpp_sys_2::mtmd_helper_get_n_pos(self.chunks.as_ptr()) }
+    }
+
+    /// Evaluate chunks using the multimodal context and LLAMA context
+    /// Returns the new n_past value on success
+    pub fn eval_chunks(
+        &self,
+        mtmd_ctx: &MtmdContext,
+        llama_ctx: &LlamaContext,
+        n_past: llama_cpp_sys_2::llama_pos,
+        seq_id: llama_cpp_sys_2::llama_seq_id,
+        n_batch: i32,
+        logits_last: bool,
+    ) -> Result<llama_cpp_sys_2::llama_pos, MtmdEvalError> {
+        let mut new_n_past: llama_cpp_sys_2::llama_pos = 0;
+
+        let result = unsafe {
+            llama_cpp_sys_2::mtmd_helper_eval_chunks(
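+                // Argument order mirrors the C helper: mtmd context, llama
+                // context, chunks, starting n_past, sequence id, batch size,
+                // logits-for-last-token flag, and an out-parameter that
+                // receives the updated n_past.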
+                mtmd_ctx.context.as_ptr(),
+                llama_ctx.context.as_ptr(),
+                self.chunks.as_ptr(),
+                n_past,
+                seq_id,
+                n_batch,
+                logits_last,
+                &mut new_n_past,
+            )
+        };
+
+        if result == 0 {
+            Ok(new_n_past)
+        } else {
+            Err(MtmdEvalError::EvalFailure(result))
+        }
+    }
+}
+
+impl Drop for MtmdInputChunks {
+    fn drop(&mut self) {
+        unsafe { llama_cpp_sys_2::mtmd_input_chunks_free(self.chunks.as_ptr()) }
+    }
+}
+
+/// Safe wrapper around `mtmd_input_chunk`
+pub struct MtmdInputChunk {
+    pub(crate) chunk: NonNull<llama_cpp_sys_2::mtmd_input_chunk>,
+    owned: bool,
+}
+
+impl MtmdInputChunk {
+    /// Get the type of this chunk
+    pub fn chunk_type(&self) -> MtmdInputChunkType {
+        let chunk_type = unsafe { llama_cpp_sys_2::mtmd_input_chunk_get_type(self.chunk.as_ptr()) };
+        MtmdInputChunkType::from(chunk_type)
+    }
+
+    /// Get text tokens (only valid for text chunks)
+    pub fn text_tokens(&self) -> Option<&[LlamaToken]> {
+        if self.chunk_type() != MtmdInputChunkType::Text {
+            return None;
+        }
+
+        let mut n_tokens = 0usize;
+        let tokens_ptr = unsafe {
+            llama_cpp_sys_2::mtmd_input_chunk_get_tokens_text(self.chunk.as_ptr(), &mut n_tokens)
+        };
+
+        if tokens_ptr.is_null() || n_tokens == 0 {
+            None
+        } else {
+            unsafe {
+                Some(slice::from_raw_parts(
+                    tokens_ptr as *const LlamaToken,
+                    n_tokens,
+                ))
+            }
+        }
+    }
+
+    /// Get the number of tokens in this chunk
+    pub fn n_tokens(&self) -> usize {
+        unsafe { llama_cpp_sys_2::mtmd_input_chunk_get_n_tokens(self.chunk.as_ptr()) }
+    }
+
+    /// Get the number of positions in this chunk
+    pub fn n_positions(&self) -> i32 {
+        unsafe { llama_cpp_sys_2::mtmd_input_chunk_get_n_pos(self.chunk.as_ptr()) }
+    }
+
+    /// Get chunk ID (if available)
+    pub fn id(&self) -> Option<String> {
+        let ptr = unsafe { llama_cpp_sys_2::mtmd_input_chunk_get_id(self.chunk.as_ptr()) };
+        if ptr.is_null() {
+            None
+        } else {
+            unsafe { CStr::from_ptr(ptr) }
+                .to_string_lossy()
+                .into_owned()
+                .into()
+        }
+    }
+
+    /// Create a copy of this chunk that you own
+    pub fn copy(&self) -> Result<Self, MtmdInputChunkError> {
+        let chunk = unsafe { llama_cpp_sys_2::mtmd_input_chunk_copy(self.chunk.as_ptr()) };
+        let chunk = NonNull::new(chunk).ok_or(MtmdInputChunkError::NullResult)?;
+        Ok(Self { chunk, owned: true })
+    }
+}
+
+impl Drop for MtmdInputChunk {
+    fn drop(&mut self) {
+        if self.owned {
+            unsafe { llama_cpp_sys_2::mtmd_input_chunk_free(self.chunk.as_ptr()) }
+        }
+    }
+}
+
+/// Get the default media marker
+pub fn mtmd_default_marker() -> &'static str {
+    unsafe {
+        let c_str = llama_cpp_sys_2::mtmd_default_marker();
+        CStr::from_ptr(c_str).to_str().unwrap_or("<__media__>")
+    }
+}
+
+// Error types
+#[derive(thiserror::Error, Debug)]
+pub enum MtmdInitError {
+    #[error("Failed to create CString: {0}")]
+    CStringError(#[from] std::ffi::NulError),
+    #[error("MTMD context initialization returned null")]
+    NullResult,
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum MtmdBitmapError {
+    #[error("Failed to create CString: {0}")]
+    CStringError(#[from] std::ffi::NulError),
+    #[error("Invalid data size for bitmap")]
+    InvalidDataSize,
+    #[error("Bitmap creation returned null")]
+    NullResult,
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum MtmdInputChunksError {
+    #[error("Input chunks creation returned null")]
+    NullResult,
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum MtmdInputChunkError {
+    #[error("Input chunk operation returned null")]
+    NullResult,
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum MtmdTokenizeError {
+    #[error("Number of bitmaps does not match number of markers")]
+    BitmapCountMismatch,
+    #[error("Image preprocessing error")]
+
ImagePreprocessingError, + #[error("Unknown error: {0}")] + UnknownError(i32), +} + +#[derive(thiserror::Error, Debug)] +pub enum MtmdEncodeError { + #[error("Encode failed with code: {0}")] + EncodeFailure(i32), +} + +#[derive(thiserror::Error, Debug)] +pub enum MtmdEvalError { + #[error("Eval failed with code: {0}")] + EvalFailure(i32), +} diff --git a/llama-cpp-2/src/sampling.rs b/llama-cpp-2/src/sampling.rs index 96feb402..f4ca88e1 100644 --- a/llama-cpp-2/src/sampling.rs +++ b/llama-cpp-2/src/sampling.rs @@ -63,7 +63,7 @@ impl LlamaSampler { } /// Resets the internal state of the sampler. - /// + /// /// This can be useful when you want to start fresh with a sampler without creating a new instance. pub fn reset(&mut self) { unsafe { @@ -72,7 +72,7 @@ impl LlamaSampler { } /// Gets the random seed used by this sampler. - /// + /// /// Returns: /// - For random samplers (dist, mirostat, mirostat_v2): returns their current seed /// - For sampler chains: returns the first non-default seed found in reverse order @@ -80,7 +80,7 @@ impl LlamaSampler { #[must_use] pub fn get_seed(&self) -> u32 { unsafe { llama_cpp_sys_2::llama_sampler_get_seed(self.sampler) } - } + } /// Combines a list of samplers into a single sampler that applies each component sampler one /// after another. @@ -213,11 +213,11 @@ impl LlamaSampler { Self { sampler } } - /// Top-nσ sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" + /// Top-nσ sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" /// /// /// This method filters logits by selecting only those within *n* standard deviations of the mean. - /// + /// /// # Parameters /// - `n`: Number of standard deviations from the mean to include in sampling /// @@ -232,7 +232,7 @@ impl LlamaSampler { /// /// let mut data_array = LlamaTokenDataArray::new(vec![ /// LlamaTokenData::new(LlamaToken(0), 0.0, 0.0), - /// LlamaTokenData::new(LlamaToken(1), 1.0, 0.0), + /// LlamaTokenData::new(LlamaToken(1), 1.0, 0.0), /// LlamaTokenData::new(LlamaToken(2), 2.0, 0.0), /// ], false); /// @@ -309,17 +309,15 @@ impl LlamaSampler { ) -> Self { let grammar_str = CString::new(grammar_str).unwrap(); let grammar_root = CString::new(grammar_root).unwrap(); - + let trigger_word_cstrings: Vec = trigger_words .into_iter() .map(|word| CString::new(word.as_ref()).unwrap()) .collect(); - - let mut trigger_word_ptrs: Vec<*const c_char> = trigger_word_cstrings - .iter() - .map(|cs| cs.as_ptr()) - .collect(); - + + let mut trigger_word_ptrs: Vec<*const c_char> = + trigger_word_cstrings.iter().map(|cs| cs.as_ptr()).collect(); + let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_grammar_lazy( model.vocab_ptr(), @@ -331,9 +329,9 @@ impl LlamaSampler { trigger_tokens.len(), ) }; - + Self { sampler } - } + } /// DRY sampler, designed by p-e-w, as described in: /// , porting Koboldcpp @@ -495,20 +493,14 @@ impl LlamaSampler { /// ``` #[must_use] pub fn logit_bias(n_vocab: i32, biases: &[LlamaLogitBias]) -> Self { - let data = biases.as_ptr().cast::(); - + let sampler = unsafe { - llama_cpp_sys_2::llama_sampler_init_logit_bias( - n_vocab, - biases.len() as i32, - data, - ) + llama_cpp_sys_2::llama_sampler_init_logit_bias(n_vocab, biases.len() as i32, data) }; - + Self { sampler } } - } impl Drop for LlamaSampler { diff --git a/llama-cpp-2/src/timing.rs b/llama-cpp-2/src/timing.rs index b45d9318..aff5d498 100644 --- a/llama-cpp-2/src/timing.rs +++ b/llama-cpp-2/src/timing.rs @@ -26,6 +26,7 @@ impl LlamaTimings { t_eval_ms: f64, n_p_eval: i32, n_eval: 
i32, + n_reused: i32, ) -> Self { Self { timings: llama_cpp_sys_2::llama_perf_context_data { @@ -35,6 +36,7 @@ impl LlamaTimings { t_eval_ms, n_p_eval, n_eval, + n_reused, }, } } diff --git a/llama-cpp-2/src/token/logit_bias.rs b/llama-cpp-2/src/token/logit_bias.rs index 631c9395..f2a8384d 100644 --- a/llama-cpp-2/src/token/logit_bias.rs +++ b/llama-cpp-2/src/token/logit_bias.rs @@ -17,7 +17,7 @@ pub struct LlamaLogitBias { impl LlamaLogitBias { /// Creates a new logit bias for a specific token with the given bias value. - /// + /// /// # Examples /// ``` /// # use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias}; @@ -27,15 +27,12 @@ impl LlamaLogitBias { #[must_use] pub fn new(LlamaToken(token): LlamaToken, bias: f32) -> Self { Self { - logit_bias: llama_cpp_sys_2::llama_logit_bias { - token, - bias, - }, + logit_bias: llama_cpp_sys_2::llama_logit_bias { token, bias }, } } /// Gets the token this bias applies to. - /// + /// /// # Examples /// ``` /// # use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias}; @@ -49,7 +46,7 @@ impl LlamaLogitBias { } /// Gets the bias value. - /// + /// /// # Examples /// ``` /// # use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias}; @@ -63,7 +60,7 @@ impl LlamaLogitBias { } /// Sets the token this bias applies to. - /// + /// /// # Examples /// ``` /// # use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias}; @@ -78,7 +75,7 @@ impl LlamaLogitBias { } /// Sets the bias value. - /// + /// /// # Examples /// ``` /// # use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias}; @@ -90,4 +87,4 @@ impl LlamaLogitBias { pub fn set_bias(&mut self, bias: f32) { self.logit_bias.bias = bias; } -} \ No newline at end of file +} diff --git a/llama-cpp-sys-2/Cargo.toml b/llama-cpp-sys-2/Cargo.toml index f998598d..ab994910 100644 --- a/llama-cpp-sys-2/Cargo.toml +++ b/llama-cpp-sys-2/Cargo.toml @@ -21,21 +21,10 @@ include = [ "/llama.cpp/ggml/src/*.cpp", "/llama.cpp/src/*.h", "/llama.cpp/src/*.cpp", + "/llama.cpp/tools/mtmd/*.h", + "/llama.cpp/tools/mtmd/*.cpp", "/llama.cpp/convert_hf_to_gguf.py", # Yes, it's required - - # Erroneously the llama.cpp code currently generates the build-info.cpp - # into the source directory of the build instead of into the target directory - # as it should. Will try submitting something upstream to clean this up as - # well but for now explictly exclude this from the build. Previously this was - # implicitly excluded because the llama.cpp code was copied wholesale into the - # target directory for building which is why this problem wasn't visible before - # (i.e. we'd package the llama.cpp source from the submodule & thus this build-info.cpp - # generated file would still be ignored because it would only exist in the separate - # copy within the target directory. An alternative, if we do want to capture build-info.cpp - # within the package would be to change the CI task to add `--allow-dirty` to the package - # command. 
- "!/llama.cpp/common/build-info.cpp", "/llama.cpp/common/build-info.cpp.in", "/llama.cpp/ggml/src/ggml-cuda.cu", diff --git a/llama-cpp-sys-2/build.rs b/llama-cpp-sys-2/build.rs index df654053..dbc2f445 100644 --- a/llama-cpp-sys-2/build.rs +++ b/llama-cpp-sys-2/build.rs @@ -244,6 +244,8 @@ fn main() { .allowlist_type("ggml_.*") .allowlist_function("llama_.*") .allowlist_type("llama_.*") + .allowlist_function("mtmd_.*") + .allowlist_type("mtmd_.*") .prepend_enum_name(false) .generate() .expect("Failed to generate bindings"); @@ -270,6 +272,10 @@ fn main() { config.define("LLAMA_BUILD_SERVER", "OFF"); config.define("LLAMA_CURL", "OFF"); + // Enable common and tools to build mtmd library for multimodal support + config.define("LLAMA_BUILD_COMMON", "ON"); + config.define("LLAMA_BUILD_TOOLS", "ON"); + config.define( "BUILD_SHARED_LIBS", if build_shared_libs { "ON" } else { "OFF" }, @@ -279,7 +285,11 @@ fn main() { config.define("GGML_BLAS", "OFF"); } - if (matches!(target_os, TargetOs::Windows(WindowsVariant::Msvc)) && matches!(profile.as_str(), "Release" | "RelWithDebInfo" | "MinSizeRel")) + if (matches!(target_os, TargetOs::Windows(WindowsVariant::Msvc)) + && matches!( + profile.as_str(), + "Release" | "RelWithDebInfo" | "MinSizeRel" + )) { // Debug Rust builds under MSVC turn off optimization even though we're ideally building the release profile of llama.cpp. // Looks like an upstream bug: @@ -373,18 +383,63 @@ fn main() { .always_configure(false); let build_dir = config.build(); - let build_info_src = llama_src.join("common/build-info.cpp"); - let build_info_target = build_dir.join("build-info.cpp"); - std::fs::rename(&build_info_src,&build_info_target).unwrap_or_else(|move_e| { - // Rename may fail if the target directory is on a different filesystem/disk from the source. - // Fall back to copy + delete to achieve the same effect in this case. - std::fs::copy(&build_info_src, &build_info_target).unwrap_or_else(|copy_e| { - panic!("Failed to rename {build_info_src:?} to {build_info_target:?}. 
Move failed with {move_e:?} and copy failed with {copy_e:?}"); + + let lib_dir = out_dir.join("lib"); + debug_log!("Lib directory: {}", lib_dir.display()); + + // Copy common library + let common_lib_src = build_dir.join("build/common/libcommon.a"); + debug_log!( + "Looking for common library at: {}", + common_lib_src.display() + ); + if common_lib_src.exists() { + let common_lib_dst = lib_dir.join("libcommon.a"); + std::fs::copy(&common_lib_src, &common_lib_dst).unwrap_or_else(|e| { + panic!("Failed to copy {common_lib_src:?} to {common_lib_dst:?}: {e:?}"); }); - std::fs::remove_file(&build_info_src).unwrap_or_else(|e| { - panic!("Failed to delete {build_info_src:?} after copying to {build_info_target:?}: {e:?} (move failed because {move_e:?})"); + debug_log!("Copied common library: {}", common_lib_dst.display()); + } else { + debug_log!("Common library not found at: {}", common_lib_src.display()); + } + + // Copy mtmd static library + let mtmd_static_src = build_dir.join("build/tools/mtmd/libmtmd_static.a"); + debug_log!( + "Looking for mtmd static library at: {}", + mtmd_static_src.display() + ); + if mtmd_static_src.exists() { + let mtmd_static_dst = lib_dir.join("libmtmd_static.a"); + std::fs::copy(&mtmd_static_src, &mtmd_static_dst).unwrap_or_else(|e| { + panic!("Failed to copy {mtmd_static_src:?} to {mtmd_static_dst:?}: {e:?}"); }); - }); + debug_log!("Copied mtmd static library: {}", mtmd_static_dst.display()); + } else { + debug_log!( + "Mtmd static library not found at: {}", + mtmd_static_src.display() + ); + } + + // Copy mtmd audio library + let mtmd_audio_src = build_dir.join("build/tools/mtmd/libmtmd_audio.a"); + debug_log!( + "Looking for mtmd audio library at: {}", + mtmd_audio_src.display() + ); + if mtmd_audio_src.exists() { + let mtmd_audio_dst = lib_dir.join("libmtmd_audio.a"); + std::fs::copy(&mtmd_audio_src, &mtmd_audio_dst).unwrap_or_else(|e| { + panic!("Failed to copy {mtmd_audio_src:?} to {mtmd_audio_dst:?}: {e:?}"); + }); + debug_log!("Copied mtmd audio library: {}", mtmd_audio_dst.display()); + } else { + debug_log!( + "Mtmd audio library not found at: {}", + mtmd_audio_src.display() + ); + } // Search paths println!("cargo:rustc-link-search={}", out_dir.join("lib").display()); diff --git a/llama-cpp-sys-2/llama.cpp b/llama-cpp-sys-2/llama.cpp index 79c137f7..b4efd77f 160000 --- a/llama-cpp-sys-2/llama.cpp +++ b/llama-cpp-sys-2/llama.cpp @@ -1 +1 @@ -Subproject commit 79c137f77677b3c8ee3c60a7da033721b938399a +Subproject commit b4efd77f8ab407836ca73a5176f041650c5b2411 diff --git a/llama-cpp-sys-2/wrapper.h b/llama-cpp-sys-2/wrapper.h index 33e6ab38..88eb9748 100644 --- a/llama-cpp-sys-2/wrapper.h +++ b/llama-cpp-sys-2/wrapper.h @@ -1 +1,3 @@ -#include "llama.cpp/include/llama.h" \ No newline at end of file +#include "llama.cpp/include/llama.h" +#include "llama.cpp/tools/mtmd/mtmd.h" +#include "llama.cpp/tools/mtmd/mtmd-helper.h" From a40ed28c54c86480beda3c50da5391ecf37c2e4e Mon Sep 17 00:00:00 2001 From: Dennis Keck <26092524+fellhorn@users.noreply.github.com> Date: Sun, 27 Jul 2025 23:07:05 +0200 Subject: [PATCH 03/13] Add documentation Signed-off-by: Dennis Keck <26092524+fellhorn@users.noreply.github.com> --- llama-cpp-2/src/mtmd.rs | 358 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 322 insertions(+), 36 deletions(-) diff --git a/llama-cpp-2/src/mtmd.rs b/llama-cpp-2/src/mtmd.rs index 616b2d48..70304361 100644 --- a/llama-cpp-2/src/mtmd.rs +++ b/llama-cpp-2/src/mtmd.rs @@ -1,3 +1,10 @@ +//! 
Safe wrapper around multimodal (MTMD) functionality in llama.cpp. +//! +//! This module provides Rust bindings for llama.cpp's multimodal support, +//! allowing processing of text, image, and audio inputs through a unified interface. +//! +//! # Warning +//! This API is experimental and subject to breaking changes. use std::ffi::{CStr, CString}; use std::ptr::NonNull; use std::slice; @@ -9,8 +16,11 @@ use crate::token::LlamaToken; /// Input chunk types for multimodal data #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum MtmdInputChunkType { + /// Text input chunk Text = llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_TEXT as isize, + /// Image input chunk Image = llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_IMAGE as isize, + /// Audio input chunk Audio = llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_AUDIO as isize, } @@ -28,9 +38,13 @@ impl From for MtmdInputChunkType { /// Configuration parameters for MTMD context #[derive(Debug, Clone)] pub struct MtmdContextParams { + /// Whether to use GPU acceleration pub use_gpu: bool, + /// Whether to print timing information pub print_timings: bool, + /// Number of threads to use for processing pub n_threads: i32, + /// Media marker string used to identify media positions in text pub media_marker: CString, } @@ -61,18 +75,41 @@ impl From<&MtmdContextParams> for llama_cpp_sys_2::mtmd_context_params { /// Text input configuration #[derive(Debug, Clone)] pub struct MtmdInputText { + /// The input text string pub text: String, + /// Whether to add special tokens pub add_special: bool, + /// Whether to parse special tokens pub parse_special: bool, } -/// Safe wrapper around `mtmd_context` +/// Safe wrapper around `mtmd_context`. +/// +/// This represents an initialized multimodal context that can process +/// text, images, and audio through llama.cpp's multimodal interface. +#[derive(Debug)] pub struct MtmdContext { pub(crate) context: NonNull, } impl MtmdContext { - /// Initialize MTMD context from a multimodal projection file + /// Initialize MTMD context from a multimodal projection file. + /// + /// # Arguments + /// + /// * `mmproj_path` - Path to the multimodal projection file + /// * `text_model` - Reference to the text model + /// * `params` - Configuration parameters for the MTMD context + /// + /// # Returns + /// + /// Returns `Ok(MtmdContext)` on success, or `Err(MtmdInitError)` on failure. + /// + /// # Errors + /// + /// This function will return an error if: + /// - The path cannot be converted to a C string + /// - The underlying C function returns null (indicating initialization failure) pub fn init_from_file( mmproj_path: &str, text_model: &LlamaModel, @@ -97,27 +134,70 @@ impl MtmdContext { Ok(Self { context }) } - /// Check if non-causal mask is needed before llama_decode + /// Check whether non-causal attention mask is needed before `llama_decode`. pub fn decode_use_non_causal(&self) -> bool { unsafe { llama_cpp_sys_2::mtmd_decode_use_non_causal(self.context.as_ptr()) } } - /// Check if the model uses M-RoPE for llama_decode + /// Check whether the current model uses M-RoPE for `llama_decode`. + /// + /// M-RoPE (Multimodal Rotary Position Embedding) affects how positions + /// are calculated for multimodal inputs. pub fn decode_use_mrope(&self) -> bool { unsafe { llama_cpp_sys_2::mtmd_decode_use_mrope(self.context.as_ptr()) } } - /// Check if the model supports vision input + /// Check whether the current model supports vision input. 
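+    ///
+    /// A minimal capability-gating sketch (hypothetical usage; assumes an
+    /// already-initialized context):
+    ///
+    /// ```no_run
+    /// # use llama_cpp_2::mtmd::MtmdContext;
+    /// # fn gate(ctx: &MtmdContext) {
+    /// if !ctx.support_vision() {
+    ///     eprintln!("this mmproj reports no vision encoder; skipping image inputs");
+    /// }
+    /// # }
+    /// ```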
pub fn support_vision(&self) -> bool { unsafe { llama_cpp_sys_2::mtmd_support_vision(self.context.as_ptr()) } } - /// Check if the model supports audio input + /// Check whether the current model supports audio input. pub fn support_audio(&self) -> bool { unsafe { llama_cpp_sys_2::mtmd_support_audio(self.context.as_ptr()) } } - /// Tokenize input text and bitmaps into chunks + /// Get audio bitrate in Hz (e.g., 16000 for Whisper). + /// Returns -1 if audio is not supported. + pub fn get_audio_bitrate(&self) -> i32 { + unsafe { llama_cpp_sys_2::mtmd_get_audio_bitrate(self.context.as_ptr()) } + } + + /// Tokenize input text and bitmaps into chunks. + /// + /// The input text must contain media markers (default: `<__media__>`) that will be + /// replaced with the corresponding bitmap data from the `bitmaps` array. + /// The number of bitmaps must equal the number of markers in the text. + /// + /// # Arguments + /// + /// * `text` - Text input configuration containing the text and tokenization options + /// * `bitmaps` - Array of bitmaps (images/audio) to replace markers with + /// + /// # Returns + /// + /// Returns `Ok(MtmdInputChunks)` containing the tokenized chunks on success. + /// + /// # Errors + /// + /// * `BitmapCountMismatch` - Number of bitmaps doesn't match number of markers + /// * `ImagePreprocessingError` - Error occurred during image preprocessing + /// * `UnknownError` - Other tokenization error occurred + /// + /// # Example + /// + /// ```no_run + /// # use llama_cpp_2::mtmd::*; + /// # fn example(ctx: &MtmdContext, bitmap: &MtmdBitmap) -> Result<(), Box> { + /// let text = MtmdInputText { + /// text: "Here is an image: <__media__>\nDescribe it.".to_string(), + /// add_special: true, + /// parse_special: true, + /// }; + /// let chunks = ctx.tokenize(text, &[bitmap])?; + /// # Ok(()) + /// # } + /// ``` pub fn tokenize( &self, text: MtmdInputText, @@ -155,7 +235,23 @@ impl MtmdContext { } } - /// Encode a chunk (for image/audio processing) + /// Encode a chunk for image/audio processing. + /// + /// This function processes image or audio chunks by encoding them into + /// embeddings that can be used by the language model. The embeddings + /// can be retrieved using `get_output_embeddings()`. + /// + /// # Arguments + /// + /// * `chunk` - The input chunk to encode (should be image or audio type) + /// + /// # Returns + /// + /// Returns `Ok(())` on success. + /// + /// # Errors + /// + /// Returns `MtmdEncodeError::EncodeFailure` if encoding fails. pub fn encode_chunk(&self, chunk: &MtmdInputChunk) -> Result<(), MtmdEncodeError> { let result = unsafe { llama_cpp_sys_2::mtmd_encode_chunk(self.context.as_ptr(), chunk.chunk.as_ptr()) @@ -168,14 +264,19 @@ impl MtmdContext { } } - /// Get output embeddings from the last encode pass + /// Get output embeddings from the last encode pass. + /// + /// The embeddings are available after calling `encode_chunk()` on an image + /// or audio chunk. The size of the embeddings depends on the model and chunk. + /// + /// # Returns + /// + /// Returns `Some(&[f32])` if embeddings are available, `None` otherwise. 
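+    ///
+    /// Note: the current binding conservatively returns `None` even when the
+    /// underlying pointer is non-null, because the embedding length cannot be
+    /// determined safely without knowing the originating chunk.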
pub fn get_output_embeddings(&self) -> Option<&[f32]> { let ptr = unsafe { llama_cpp_sys_2::mtmd_get_output_embd(self.context.as_ptr()) }; if ptr.is_null() { None } else { - // Note: The size calculation would need context about the model and chunk - // For now, returning None when we can't determine size safely None } } @@ -190,14 +291,33 @@ impl Drop for MtmdContext { } } -/// Safe wrapper around `mtmd_bitmap` +/// Safe wrapper around `mtmd_bitmap`. +/// +/// Represents bitmap data for images or audio that can be processed +/// by the multimodal system. For images, data is stored in RGB format. +/// For audio, data is stored as PCM F32 samples. #[derive(Debug, Clone)] pub struct MtmdBitmap { pub(crate) bitmap: NonNull, } impl MtmdBitmap { - /// Create a bitmap from image data (RGB format) + /// Create a bitmap from image data in RGB format. + /// + /// # Arguments + /// + /// * `nx` - Width of the image in pixels + /// * `ny` - Height of the image in pixels + /// * `data` - Image data in RGBRGBRGB... format (must be exactly `nx * ny * 3` bytes) + /// + /// # Returns + /// + /// Returns `Ok(MtmdBitmap)` on success. + /// + /// # Errors + /// + /// * `InvalidDataSize` - Data length doesn't match `nx * ny * 3` + /// * `NullResult` - Underlying C function returned null pub fn from_image_data(nx: u32, ny: u32, data: &[u8]) -> Result { if data.len() != (nx * ny * 3) as usize { return Err(MtmdBitmapError::InvalidDataSize); @@ -209,7 +329,19 @@ impl MtmdBitmap { Ok(Self { bitmap }) } - /// Create a bitmap from audio data (PCM F32 format) + /// Create a bitmap from audio data in PCM F32 format. + /// + /// # Arguments + /// + /// * `data` - Audio samples as 32-bit floating point values + /// + /// # Returns + /// + /// Returns `Ok(MtmdBitmap)` on success. + /// + /// # Errors + /// + /// * `NullResult` - Underlying C function returned null pub fn from_audio_data(data: &[f32]) -> Result { let bitmap = unsafe { llama_cpp_sys_2::mtmd_bitmap_init_from_audio(data.len(), data.as_ptr()) }; @@ -218,7 +350,29 @@ impl MtmdBitmap { Ok(Self { bitmap }) } - /// Create a bitmap from a file + /// Create a bitmap from a file. + /// + /// Supported formats: + /// - Images: formats supported by stb_image (jpg, png, bmp, gif, etc.) + /// - Audio: formats supported by miniaudio (wav, mp3, flac) + /// + /// Audio files are auto-detected based on magic bytes. + /// + /// # Arguments + /// + /// * `ctx` - MTMD context for processing + /// * `path` - Path to the image or audio file + /// + /// # Returns + /// + /// Returns `Ok(MtmdBitmap)` on success. + /// + /// # Errors + /// + /// * `CStringError` - Path contains null bytes + /// * `NullResult` - File could not be loaded or processed + /// + /// This function is thread-safe. pub fn from_file(ctx: &MtmdContext, path: &str) -> Result { let path_cstr = CString::new(path)?; let bitmap = unsafe { @@ -232,7 +386,28 @@ impl MtmdBitmap { Ok(Self { bitmap }) } - /// Create a bitmap from a buffer containing file data + /// Create a bitmap from a buffer containing file data. + /// + /// Supported formats: + /// - Images: formats supported by stb_image (jpg, png, bmp, gif, etc.) + /// - Audio: formats supported by miniaudio (wav, mp3, flac) + /// + /// Audio files are auto-detected based on magic bytes. + /// + /// # Arguments + /// + /// * `ctx` - MTMD context for processing + /// * `data` - Buffer containing the file data + /// + /// # Returns + /// + /// Returns `Ok(MtmdBitmap)` on success. 
+ /// + /// # Errors + /// + /// * `NullResult` - Buffer could not be processed + /// + /// This function is thread-safe. pub fn from_buffer(ctx: &MtmdContext, data: &[u8]) -> Result { let bitmap = unsafe { llama_cpp_sys_2::mtmd_helper_bitmap_init_from_buf( @@ -246,29 +421,35 @@ impl MtmdBitmap { Ok(Self { bitmap }) } - /// Get bitmap width + /// Get bitmap width in pixels. pub fn nx(&self) -> u32 { unsafe { llama_cpp_sys_2::mtmd_bitmap_get_nx(self.bitmap.as_ptr()) } } - /// Get bitmap height + /// Get bitmap height in pixels. pub fn ny(&self) -> u32 { unsafe { llama_cpp_sys_2::mtmd_bitmap_get_ny(self.bitmap.as_ptr()) } } - /// Get bitmap data as bytes + /// Get bitmap data as a byte slice. + /// + /// For images: RGB format with length `nx * ny * 3` + /// For audio: PCM F32 format with length `n_samples * 4` pub fn data(&self) -> &[u8] { let ptr = unsafe { llama_cpp_sys_2::mtmd_bitmap_get_data(self.bitmap.as_ptr()) }; let len = unsafe { llama_cpp_sys_2::mtmd_bitmap_get_n_bytes(self.bitmap.as_ptr()) }; unsafe { slice::from_raw_parts(ptr, len) } } - /// Check if this is an audio bitmap + /// Check if this bitmap contains audio data (vs image data). pub fn is_audio(&self) -> bool { unsafe { llama_cpp_sys_2::mtmd_bitmap_is_audio(self.bitmap.as_ptr()) } } - /// Get bitmap ID (if set) + /// Get the bitmap's optional ID string. + /// + /// Bitmap ID is useful for KV cache tracking and can e.g. be calculated + /// based on a hash of the bitmap data. pub fn id(&self) -> Option { let ptr = unsafe { llama_cpp_sys_2::mtmd_bitmap_get_id(self.bitmap.as_ptr()) }; if ptr.is_null() { @@ -281,7 +462,18 @@ impl MtmdBitmap { } } - /// Set bitmap ID + /// Set the bitmap's ID string. + /// + /// Bitmap ID is useful for KV cache tracking and can e.g. be calculated + /// based on a hash of the bitmap data. + /// + /// # Arguments + /// + /// * `id` - The ID string to set + /// + /// # Errors + /// + /// Returns an error if the ID string contains null bytes. pub fn set_id(&self, id: &str) -> Result<(), std::ffi::NulError> { let id_cstr = CString::new(id)?; unsafe { @@ -291,16 +483,18 @@ impl MtmdBitmap { } } -unsafe impl Send for MtmdBitmap {} -unsafe impl Sync for MtmdBitmap {} - impl Drop for MtmdBitmap { fn drop(&mut self) { unsafe { llama_cpp_sys_2::mtmd_bitmap_free(self.bitmap.as_ptr()) } } } -/// Safe wrapper around `mtmd_input_chunks` +/// Safe wrapper around `mtmd_input_chunks`. +/// +/// This is a collection of input chunks created from tokenizing text and media. +/// The chunks represent the tokenized input that can be processed by the model, +/// with text chunks containing tokens and media chunks containing embeddings. +#[derive(Debug)] pub struct MtmdInputChunks { pub(crate) chunks: NonNull, } @@ -343,18 +537,48 @@ impl MtmdInputChunks { } } - /// Get total number of tokens across all chunks + /// Get total number of tokens across all chunks. + /// + /// This is useful for keeping track of KV cache size. pub fn total_tokens(&self) -> usize { unsafe { llama_cpp_sys_2::mtmd_helper_get_n_tokens(self.chunks.as_ptr()) } } - /// Get total position count across all chunks + /// Get total position count across all chunks. + /// + /// This is useful to keep track of n_past. Normally n_pos equals n_tokens, + /// but for M-RoPE it is different. 
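+    /// Under M-RoPE an image chunk occupies a single temporal position
+    /// regardless of how many tokens it expands to, so this value can be
+    /// smaller than the total token count.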
 
-    /// Evaluate chunks using the multimodal context and LLAMA context
-    /// Returns the new n_past value on success
+    /// Evaluate chunks using the multimodal context and LLAMA context.
+    ///
+    /// This helper function automatically:
+    /// 1. Runs `llama_decode()` on text chunks
+    /// 2. Runs `mtmd_encode()` on image chunks, then `mtmd_get_output_embd()` and then `llama_decode()`
+    ///
+    /// If any of the `mtmd_encode()` or `llama_decode()` calls return non-zero, the function
+    /// stops and forwards the error.
+    ///
+    /// # Arguments
+    ///
+    /// * `mtmd_ctx` - The multimodal context
+    /// * `llama_ctx` - The LLAMA context
+    /// * `n_past` - Current position in the sequence
+    /// * `seq_id` - Sequence ID for the batch
+    /// * `n_batch` - Batch size for processing
+    /// * `logits_last` - Whether to compute logits for the last token only
+    ///
+    /// # Returns
+    ///
+    /// Returns the new n_past value on success.
+    ///
+    /// # Errors
+    ///
+    /// Returns `MtmdEvalError::EvalFailure` if any encoding or decoding operation fails.
+    ///
+    /// This function is NOT thread-safe.
     pub fn eval_chunks(
         &self,
         mtmd_ctx: &MtmdContext,
@@ -393,7 +617,12 @@ impl Drop for MtmdInputChunks {
     }
 }
 
-/// Safe wrapper around `mtmd_input_chunk`
+/// Safe wrapper around `mtmd_input_chunk`.
+///
+/// Represents a single chunk of input data, which can be either text tokens,
+/// image tokens, or audio tokens. The chunk type determines what kind of
+/// data and operations are available.
+#[derive(Debug)]
 pub struct MtmdInputChunk {
     pub(crate) chunk: NonNull<llama_cpp_sys_2::mtmd_input_chunk>,
     owned: bool,
@@ -406,7 +635,13 @@ impl MtmdInputChunk {
         MtmdInputChunkType::from(chunk_type)
     }
 
-    /// Get text tokens (only valid for text chunks)
+    /// Get text tokens from this chunk.
+    ///
+    /// Only valid for text chunks. Returns `None` for image or audio chunks.
+    ///
+    /// # Returns
+    ///
+    /// Returns `Some(&[LlamaToken])` for text chunks, `None` otherwise.
     pub fn text_tokens(&self) -> Option<&[LlamaToken]> {
         if self.chunk_type() != MtmdInputChunkType::Text {
             return None;
         }
@@ -434,12 +669,16 @@ impl MtmdInputChunk {
         unsafe { llama_cpp_sys_2::mtmd_input_chunk_get_n_tokens(self.chunk.as_ptr()) }
     }
 
-    /// Get the number of positions in this chunk
+    /// Get the number of positions in this chunk.
+    ///
+    /// Returns the number of temporal positions (always 1 for M-RoPE, n_tokens otherwise).
     pub fn n_positions(&self) -> i32 {
         unsafe { llama_cpp_sys_2::mtmd_input_chunk_get_n_pos(self.chunk.as_ptr()) }
     }
 
-    /// Get chunk ID (if available)
+    /// Get chunk ID if available.
+    ///
+    /// Returns `None` for text chunks, may return an ID for image/audio chunks.
     pub fn id(&self) -> Option<String> {
         let ptr = unsafe { llama_cpp_sys_2::mtmd_input_chunk_get_id(self.chunk.as_ptr()) };
         if ptr.is_null() {
@@ -452,7 +691,19 @@ impl MtmdInputChunk {
         }
     }
 
-    /// Create a copy of this chunk that you own
+    /// Create a copy of this chunk that you own.
+    ///
+    /// This is useful if you want to use custom logic to handle the chunk
+    /// (e.g., KV cache management) by moving the chunk ownership to your own code.
+    /// Remember to ensure the copied chunk is properly freed when you're done with it.
+    ///
+    /// # Returns
+    ///
+    /// Returns an owned copy of the chunk.
+    ///
+    /// # Errors
+    ///
+    /// Returns `MtmdInputChunkError::NullResult` if copying fails.
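To make that ownership hand-off concrete, a hedged sketch (the helper name is illustrative) of detaching the first chunk so it can outlive the collection:

```rust
use llama_cpp_2::mtmd::{MtmdInputChunk, MtmdInputChunkError, MtmdInputChunks};

// Sketch: get() hands out a borrowed, non-owned view; copy() turns it into
// an owned chunk that Drop frees independently of the collection.
fn detach_first(chunks: &MtmdInputChunks) -> Option<Result<MtmdInputChunk, MtmdInputChunkError>> {
    chunks.get(0).map(|chunk| chunk.copy())
}
```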
pub fn copy(&self) -> Result { let chunk = unsafe { llama_cpp_sys_2::mtmd_input_chunk_copy(self.chunk.as_ptr()) }; let chunk = NonNull::new(chunk).ok_or(MtmdInputChunkError::NullResult)?; @@ -468,7 +719,23 @@ impl Drop for MtmdInputChunk { } } -/// Get the default media marker +/// Get the default media marker string. +/// +/// Returns the default marker used to identify media positions in text +/// (typically `"<__media__>"`). This marker should be used in your input text +/// to indicate where media content should be inserted. +/// +/// # Returns +/// +/// Returns the default media marker as a string slice. +/// +/// # Example +/// +/// ``` +/// # use llama_cpp_2::mtmd::mtmd_default_marker; +/// let marker = mtmd_default_marker(); +/// let text = format!("Describe this image: {}", marker); +/// ``` pub fn mtmd_default_marker() -> &'static str { unsafe { let c_str = llama_cpp_sys_2::mtmd_default_marker(); @@ -477,54 +744,73 @@ pub fn mtmd_default_marker() -> &'static str { } // Error types +/// Errors that can occur when initializing MTMD context #[derive(thiserror::Error, Debug)] pub enum MtmdInitError { + /// Failed to create CString from input #[error("Failed to create CString: {0}")] CStringError(#[from] std::ffi::NulError), + /// MTMD context initialization returned null #[error("MTMD context initialization returned null")] NullResult, } +/// Errors that can occur when working with MTMD bitmaps #[derive(thiserror::Error, Debug)] pub enum MtmdBitmapError { + /// Failed to create CString from input #[error("Failed to create CString: {0}")] CStringError(#[from] std::ffi::NulError), + /// Invalid data size for bitmap #[error("Invalid data size for bitmap")] InvalidDataSize, + /// Bitmap creation returned null #[error("Bitmap creation returned null")] NullResult, } +/// Errors that can occur when working with MTMD input chunks collections #[derive(thiserror::Error, Debug)] pub enum MtmdInputChunksError { + /// Input chunks creation returned null #[error("Input chunks creation returned null")] NullResult, } +/// Errors that can occur when working with individual MTMD input chunks #[derive(thiserror::Error, Debug)] pub enum MtmdInputChunkError { + /// Input chunk operation returned null #[error("Input chunk operation returned null")] NullResult, } +/// Errors that can occur during tokenization #[derive(thiserror::Error, Debug)] pub enum MtmdTokenizeError { + /// Number of bitmaps does not match number of markers in text #[error("Number of bitmaps does not match number of markers")] BitmapCountMismatch, + /// Image preprocessing error occurred #[error("Image preprocessing error")] ImagePreprocessingError, + /// Unknown error occurred during tokenization #[error("Unknown error: {0}")] UnknownError(i32), } +/// Errors that can occur during encoding #[derive(thiserror::Error, Debug)] pub enum MtmdEncodeError { + /// Encode operation failed #[error("Encode failed with code: {0}")] EncodeFailure(i32), } +/// Errors that can occur during evaluation #[derive(thiserror::Error, Debug)] pub enum MtmdEvalError { + /// Evaluation operation failed #[error("Eval failed with code: {0}")] EvalFailure(i32), } From b5cb56b389c5ed5f57165ac6bd9ade444f880105 Mon Sep 17 00:00:00 2001 From: Dennis Keck <26092524+fellhorn@users.noreply.github.com> Date: Sun, 27 Jul 2025 23:38:27 +0200 Subject: [PATCH 04/13] Clippy Signed-off-by: Dennis Keck <26092524+fellhorn@users.noreply.github.com> --- examples/mtmd/src/mtmd.rs | 21 +++++---- llama-cpp-2/src/mtmd.rs | 96 ++++++++++++++++++--------------------- 2 files changed, 
57 insertions(+), 60 deletions(-) diff --git a/examples/mtmd/src/mtmd.rs b/examples/mtmd/src/mtmd.rs index ba4b30ae..06b78abd 100644 --- a/examples/mtmd/src/mtmd.rs +++ b/examples/mtmd/src/mtmd.rs @@ -10,7 +10,7 @@ use llama_cpp_2::context::params::LlamaContextParams; use llama_cpp_2::context::LlamaContext; use llama_cpp_2::llama_batch::LlamaBatch; use llama_cpp_2::model::params::LlamaModelParams; -use llama_cpp_2::mtmd::*; +use llama_cpp_2::mtmd::{MtmdBitmap, MtmdBitmapError, MtmdContext, MtmdContextParams, MtmdInputText}; use llama_cpp_2::llama_backend::LlamaBackend; use llama_cpp_2::model::{LlamaChatMessage, LlamaChatTemplate, LlamaModel, Special}; @@ -83,6 +83,8 @@ pub struct MtmdCliContext { impl MtmdCliContext { /// Creates a new MTMD CLI context + /// + /// # Errors pub fn new( params: &MtmdCliParams, model: &LlamaModel, @@ -101,11 +103,11 @@ impl MtmdCliContext { )?, }; - let mtmd_ctx = MtmdContext::init_from_file(¶ms.mmproj_path, model, mtmd_params)?; + let mtmd_ctx = MtmdContext::init_from_file(¶ms.mmproj_path, model, &mtmd_params)?; let chat_template = model .chat_template(params.chat_template.as_deref()) - .map_err(|e| format!("Failed to get chat template: {}", e))?; + .map_err(|e| format!("Failed to get chat template: {e}"))?; let batch = LlamaBatch::new(params.n_tokens, 1); @@ -120,6 +122,7 @@ impl MtmdCliContext { } /// Loads media (image or audio) from the specified file path + /// # Errors pub fn load_media(&mut self, path: &str) -> Result<(), MtmdBitmapError> { let bitmap = MtmdBitmap::from_file(&self.mtmd_ctx, path)?; self.bitmaps.push(bitmap); @@ -127,6 +130,7 @@ impl MtmdCliContext { } /// Evaluates a chat message, tokenizing and processing it through the model + /// # Errors pub fn eval_message( &mut self, model: &LlamaModel, @@ -161,11 +165,12 @@ impl MtmdCliContext { // Clear bitmaps after tokenization self.bitmaps.clear(); - self.n_past = chunks.eval_chunks(&self.mtmd_ctx, &context, 0, 0, 1, true)?; + self.n_past = chunks.eval_chunks(&self.mtmd_ctx, context, 0, 0, 1, true)?; Ok(()) } /// Generates a response by sampling tokens from the model + /// # Errors pub fn generate_response( &mut self, model: &LlamaModel, @@ -190,7 +195,7 @@ impl MtmdCliContext { // Print token let piece = model.token_to_str(token, Special::Tokenize)?; - print!("{}", piece); + print!("{piece}"); io::stdout().flush()?; // Prepare next batch @@ -223,7 +228,7 @@ fn run_single_turn( // Load media files for image_path in ¶ms.images { - println!("Loading image: {}", image_path); + println!("Loading image: {image_path}"); ctx.load_media(image_path)?; } for audio_path in ¶ms.audio { @@ -233,7 +238,7 @@ fn run_single_turn( // Create user message let msg = LlamaChatMessage::new("user".to_string(), prompt)?; - println!("Evaluating message: {:?}", msg); + println!("Evaluating message: {msg:?}"); // Evaluate the message (prefill) ctx.eval_message(model, context, msg, true)?; @@ -269,7 +274,7 @@ fn main() -> Result<(), Box> { // Setup model parameters let mut model_params = LlamaModelParams::default(); if !params.no_gpu { - model_params = model_params.with_n_gpu_layers(1000000); // Use all layers on GPU + model_params = model_params.with_n_gpu_layers(1_000_000); // Use all layers on GPU } // Load model diff --git a/llama-cpp-2/src/mtmd.rs b/llama-cpp-2/src/mtmd.rs index 70304361..bbe9ff71 100644 --- a/llama-cpp-2/src/mtmd.rs +++ b/llama-cpp-2/src/mtmd.rs @@ -113,10 +113,10 @@ impl MtmdContext { pub fn init_from_file( mmproj_path: &str, text_model: &LlamaModel, - params: MtmdContextParams, + params: 
&MtmdContextParams, ) -> Result { let path_cstr = CString::new(mmproj_path)?; - let ctx_params = llama_cpp_sys_2::mtmd_context_params::from(¶ms); + let ctx_params = llama_cpp_sys_2::mtmd_context_params::from(params); let context = unsafe { llama_cpp_sys_2::mtmd_init_from_file( @@ -135,7 +135,7 @@ impl MtmdContext { } /// Check whether non-causal attention mask is needed before `llama_decode`. - pub fn decode_use_non_causal(&self) -> bool { + #[must_use] pub fn decode_use_non_causal(&self) -> bool { unsafe { llama_cpp_sys_2::mtmd_decode_use_non_causal(self.context.as_ptr()) } } @@ -143,23 +143,23 @@ impl MtmdContext { /// /// M-RoPE (Multimodal Rotary Position Embedding) affects how positions /// are calculated for multimodal inputs. - pub fn decode_use_mrope(&self) -> bool { + #[must_use] pub fn decode_use_mrope(&self) -> bool { unsafe { llama_cpp_sys_2::mtmd_decode_use_mrope(self.context.as_ptr()) } } /// Check whether the current model supports vision input. - pub fn support_vision(&self) -> bool { + #[must_use] pub fn support_vision(&self) -> bool { unsafe { llama_cpp_sys_2::mtmd_support_vision(self.context.as_ptr()) } } /// Check whether the current model supports audio input. - pub fn support_audio(&self) -> bool { + #[must_use] pub fn support_audio(&self) -> bool { unsafe { llama_cpp_sys_2::mtmd_support_audio(self.context.as_ptr()) } } /// Get audio bitrate in Hz (e.g., 16000 for Whisper). /// Returns -1 if audio is not supported. - pub fn get_audio_bitrate(&self) -> i32 { + #[must_use] pub fn get_audio_bitrate(&self) -> i32 { unsafe { llama_cpp_sys_2::mtmd_get_audio_bitrate(self.context.as_ptr()) } } @@ -214,7 +214,7 @@ impl MtmdContext { // Create bitmap pointers let bitmap_ptrs: Vec<*const llama_cpp_sys_2::mtmd_bitmap> = bitmaps .iter() - .map(|b| b.bitmap.as_ptr() as *const _) + .map(|b| b.bitmap.as_ptr().cast_const()) .collect(); let result = unsafe { @@ -222,7 +222,7 @@ impl MtmdContext { self.context.as_ptr(), chunks.chunks.as_ptr(), &input_text, - bitmap_ptrs.as_ptr() as *mut *const llama_cpp_sys_2::mtmd_bitmap, + bitmap_ptrs.as_ptr().cast_mut(), bitmaps.len(), ) }; @@ -263,23 +263,6 @@ impl MtmdContext { Err(MtmdEncodeError::EncodeFailure(result)) } } - - /// Get output embeddings from the last encode pass. - /// - /// The embeddings are available after calling `encode_chunk()` on an image - /// or audio chunk. The size of the embeddings depends on the model and chunk. - /// - /// # Returns - /// - /// Returns `Some(&[f32])` if embeddings are available, `None` otherwise. - pub fn get_output_embeddings(&self) -> Option<&[f32]> { - let ptr = unsafe { llama_cpp_sys_2::mtmd_get_output_embd(self.context.as_ptr()) }; - if ptr.is_null() { - None - } else { - None - } - } } unsafe impl Send for MtmdContext {} @@ -353,7 +336,7 @@ impl MtmdBitmap { /// Create a bitmap from a file. /// /// Supported formats: - /// - Images: formats supported by stb_image (jpg, png, bmp, gif, etc.) + /// - Images: formats supported by `stb_image` (jpg, png, bmp, gif, etc.) /// - Audio: formats supported by miniaudio (wav, mp3, flac) /// /// Audio files are auto-detected based on magic bytes. @@ -389,7 +372,7 @@ impl MtmdBitmap { /// Create a bitmap from a buffer containing file data. /// /// Supported formats: - /// - Images: formats supported by stb_image (jpg, png, bmp, gif, etc.) + /// - Images: formats supported by `stb_image` (jpg, png, bmp, gif, etc.) /// - Audio: formats supported by miniaudio (wav, mp3, flac) /// /// Audio files are auto-detected based on magic bytes. 
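Since `from_buffer` performs the same magic-byte format detection as `from_file`, media already held in memory can skip a filesystem round-trip; a short sketch (the I/O handling is a placeholder):

```rust
use llama_cpp_2::mtmd::{MtmdBitmap, MtmdBitmapError, MtmdContext};

// Sketch: read a file into memory (stand-in for bytes from the network or a
// database) and let from_buffer detect whether it is an image or audio.
fn load_via_buffer(ctx: &MtmdContext, path: &str) -> Result<MtmdBitmap, MtmdBitmapError> {
    let bytes = std::fs::read(path).expect("placeholder I/O handling");
    MtmdBitmap::from_buffer(ctx, &bytes)
}
```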
@@ -422,12 +405,12 @@ impl MtmdBitmap { } /// Get bitmap width in pixels. - pub fn nx(&self) -> u32 { + #[must_use] pub fn nx(&self) -> u32 { unsafe { llama_cpp_sys_2::mtmd_bitmap_get_nx(self.bitmap.as_ptr()) } } /// Get bitmap height in pixels. - pub fn ny(&self) -> u32 { + #[must_use] pub fn ny(&self) -> u32 { unsafe { llama_cpp_sys_2::mtmd_bitmap_get_ny(self.bitmap.as_ptr()) } } @@ -435,14 +418,14 @@ impl MtmdBitmap { /// /// For images: RGB format with length `nx * ny * 3` /// For audio: PCM F32 format with length `n_samples * 4` - pub fn data(&self) -> &[u8] { + #[must_use] pub fn data(&self) -> &[u8] { let ptr = unsafe { llama_cpp_sys_2::mtmd_bitmap_get_data(self.bitmap.as_ptr()) }; let len = unsafe { llama_cpp_sys_2::mtmd_bitmap_get_n_bytes(self.bitmap.as_ptr()) }; unsafe { slice::from_raw_parts(ptr, len) } } /// Check if this bitmap contains audio data (vs image data). - pub fn is_audio(&self) -> bool { + #[must_use] pub fn is_audio(&self) -> bool { unsafe { llama_cpp_sys_2::mtmd_bitmap_is_audio(self.bitmap.as_ptr()) } } @@ -450,7 +433,7 @@ impl MtmdBitmap { /// /// Bitmap ID is useful for KV cache tracking and can e.g. be calculated /// based on a hash of the bitmap data. - pub fn id(&self) -> Option { + #[must_use] pub fn id(&self) -> Option { let ptr = unsafe { llama_cpp_sys_2::mtmd_bitmap_get_id(self.bitmap.as_ptr()) }; if ptr.is_null() { None @@ -499,26 +482,35 @@ pub struct MtmdInputChunks { pub(crate) chunks: NonNull, } +impl Default for MtmdInputChunks { + fn default() -> Self { + Self::new() + } +} + impl MtmdInputChunks { /// Create a new empty input chunks collection - pub fn new() -> Self { + /// # Panics + /// This function will panic if the underlying llama.cpp function returns null, + /// which should not happen. + #[must_use] pub fn new() -> Self { let chunks = unsafe { llama_cpp_sys_2::mtmd_input_chunks_init() }; let chunks = NonNull::new(chunks).unwrap(); Self { chunks } } /// Get the number of chunks - pub fn len(&self) -> usize { + #[must_use] pub fn len(&self) -> usize { unsafe { llama_cpp_sys_2::mtmd_input_chunks_size(self.chunks.as_ptr()) } } /// Check if chunks collection is empty - pub fn is_empty(&self) -> bool { + #[must_use] pub fn is_empty(&self) -> bool { self.len() == 0 } /// Get a chunk by index - pub fn get(&self, index: usize) -> Option { + #[must_use] pub fn get(&self, index: usize) -> Option { if index >= self.len() { return None; } @@ -531,7 +523,7 @@ impl MtmdInputChunks { } else { // Note: We don't own this chunk, it's owned by the chunks collection Some(MtmdInputChunk { - chunk: NonNull::new(chunk_ptr as *mut _).unwrap(), + chunk: NonNull::new(chunk_ptr.cast_mut()).unwrap(), owned: false, }) } @@ -540,15 +532,15 @@ impl MtmdInputChunks { /// Get total number of tokens across all chunks. /// /// This is useful for keeping track of KV cache size. - pub fn total_tokens(&self) -> usize { + #[must_use] pub fn total_tokens(&self) -> usize { unsafe { llama_cpp_sys_2::mtmd_helper_get_n_tokens(self.chunks.as_ptr()) } } /// Get total position count across all chunks. /// - /// This is useful to keep track of n_past. Normally n_pos equals n_tokens, + /// This is useful to keep track of `n_past`. Normally `n_pos` equals `n_tokens`, /// but for M-RoPE it is different. - pub fn total_positions(&self) -> i32 { + #[must_use] pub fn total_positions(&self) -> i32 { unsafe { llama_cpp_sys_2::mtmd_helper_get_n_pos(self.chunks.as_ptr()) } } @@ -572,7 +564,7 @@ impl MtmdInputChunks { /// /// # Returns /// - /// Returns the new n_past value on success. 
+ /// Returns the new `n_past` value on success. /// /// # Errors /// @@ -630,7 +622,7 @@ pub struct MtmdInputChunk { impl MtmdInputChunk { /// Get the type of this chunk - pub fn chunk_type(&self) -> MtmdInputChunkType { + #[must_use] pub fn chunk_type(&self) -> MtmdInputChunkType { let chunk_type = unsafe { llama_cpp_sys_2::mtmd_input_chunk_get_type(self.chunk.as_ptr()) }; MtmdInputChunkType::from(chunk_type) } @@ -642,7 +634,7 @@ impl MtmdInputChunk { /// # Returns /// /// Returns `Some(&[LlamaToken])` for text chunks, `None` otherwise. - pub fn text_tokens(&self) -> Option<&[LlamaToken]> { + #[must_use] pub fn text_tokens(&self) -> Option<&[LlamaToken]> { if self.chunk_type() != MtmdInputChunkType::Text { return None; } @@ -657,7 +649,7 @@ impl MtmdInputChunk { } else { unsafe { Some(slice::from_raw_parts( - tokens_ptr as *const LlamaToken, + tokens_ptr.cast::(), n_tokens, )) } @@ -665,21 +657,21 @@ impl MtmdInputChunk { } /// Get the number of tokens in this chunk - pub fn n_tokens(&self) -> usize { + #[must_use] pub fn n_tokens(&self) -> usize { unsafe { llama_cpp_sys_2::mtmd_input_chunk_get_n_tokens(self.chunk.as_ptr()) } } /// Get the number of positions in this chunk. /// - /// Returns the number of temporal positions (always 1 for M-RoPE, n_tokens otherwise). - pub fn n_positions(&self) -> i32 { + /// Returns the number of temporal positions (always 1 for M-RoPE, `n_tokens` otherwise). + #[must_use] pub fn n_positions(&self) -> i32 { unsafe { llama_cpp_sys_2::mtmd_input_chunk_get_n_pos(self.chunk.as_ptr()) } } /// Get chunk ID if available. /// /// Returns `None` for text chunks, may return an ID for image/audio chunks. - pub fn id(&self) -> Option { + #[must_use] pub fn id(&self) -> Option { let ptr = unsafe { llama_cpp_sys_2::mtmd_input_chunk_get_id(self.chunk.as_ptr()) }; if ptr.is_null() { None @@ -736,7 +728,7 @@ impl Drop for MtmdInputChunk { /// let marker = mtmd_default_marker(); /// let text = format!("Describe this image: {}", marker); /// ``` -pub fn mtmd_default_marker() -> &'static str { +#[must_use] pub fn mtmd_default_marker() -> &'static str { unsafe { let c_str = llama_cpp_sys_2::mtmd_default_marker(); CStr::from_ptr(c_str).to_str().unwrap_or("<__media__>") @@ -747,7 +739,7 @@ pub fn mtmd_default_marker() -> &'static str { /// Errors that can occur when initializing MTMD context #[derive(thiserror::Error, Debug)] pub enum MtmdInitError { - /// Failed to create CString from input + /// Failed to create `CString` from input #[error("Failed to create CString: {0}")] CStringError(#[from] std::ffi::NulError), /// MTMD context initialization returned null @@ -758,7 +750,7 @@ pub enum MtmdInitError { /// Errors that can occur when working with MTMD bitmaps #[derive(thiserror::Error, Debug)] pub enum MtmdBitmapError { - /// Failed to create CString from input + /// Failed to create `CString` from input #[error("Failed to create CString: {0}")] CStringError(#[from] std::ffi::NulError), /// Invalid data size for bitmap From dd52e55fb71b4edd803bf20d416acaffabafdade Mon Sep 17 00:00:00 2001 From: Dennis Keck <26092524+fellhorn@users.noreply.github.com> Date: Mon, 28 Jul 2025 00:48:42 +0200 Subject: [PATCH 05/13] Dep: Update llama.cpp b6002 Signed-off-by: Dennis Keck <26092524+fellhorn@users.noreply.github.com> --- llama-cpp-2/src/timing.rs | 1 + llama-cpp-sys-2/llama.cpp | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/llama-cpp-2/src/timing.rs b/llama-cpp-2/src/timing.rs index 38257fca..7eaf208b 100644 --- a/llama-cpp-2/src/timing.rs +++ 
b/llama-cpp-2/src/timing.rs @@ -36,6 +36,7 @@ impl LlamaTimings { t_eval_ms, n_p_eval, n_eval, + n_reused, }, } } diff --git a/llama-cpp-sys-2/llama.cpp b/llama-cpp-sys-2/llama.cpp index e434e691..89d10295 160000 --- a/llama-cpp-sys-2/llama.cpp +++ b/llama-cpp-sys-2/llama.cpp @@ -1 +1 @@ -Subproject commit e434e69183fd9e1031f4445002083178c331a28b +Subproject commit 89d1029559bd2968f76db854f9f113d73e34527c From 56ae827f48264921b06f057005d7091c20bdda65 Mon Sep 17 00:00:00 2001 From: Dennis Keck <26092524+fellhorn@users.noreply.github.com> Date: Mon, 28 Jul 2025 23:32:53 +0200 Subject: [PATCH 06/13] Bump versions Signed-off-by: Dennis Keck <26092524+fellhorn@users.noreply.github.com> --- llama-cpp-2/Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama-cpp-2/Cargo.toml b/llama-cpp-2/Cargo.toml index cbe2009a..0c68859d 100644 --- a/llama-cpp-2/Cargo.toml +++ b/llama-cpp-2/Cargo.toml @@ -10,7 +10,7 @@ repository = "https://github.com/utilityai/llama-cpp-rs" [dependencies] enumflags2 = "0.7.12" -llama-cpp-sys-2 = { path = "../llama-cpp-sys-2", version = "0.1.69" } +llama-cpp-sys-2 = { path = "../llama-cpp-sys-2", version = "0.1.113" } thiserror = { workspace = true } tracing = { workspace = true } tracing-core = { workspace = true } @@ -33,7 +33,7 @@ android-shared-stdcxx = ["llama-cpp-sys-2/shared-stdcxx"] [target.'cfg(all(target_os = "macos", any(target_arch = "aarch64", target_arch = "arm64")))'.dependencies] -llama-cpp-sys-2 = { path = "../llama-cpp-sys-2", version = "0.1.69", features = [ +llama-cpp-sys-2 = { path = "../llama-cpp-sys-2", version = "0.1.113", features = [ "metal", ] } From 222df6e55aecadfbbdae174b1c96ef71a3b60a1e Mon Sep 17 00:00:00 2001 From: Dennis Keck <26092524+fellhorn@users.noreply.github.com> Date: Fri, 1 Aug 2025 18:53:56 +0200 Subject: [PATCH 07/13] Fix building of mtmd Signed-off-by: Dennis Keck <26092524+fellhorn@users.noreply.github.com> --- llama-cpp-sys-2/build.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/llama-cpp-sys-2/build.rs b/llama-cpp-sys-2/build.rs index 7f312316..aa6f1dbd 100644 --- a/llama-cpp-sys-2/build.rs +++ b/llama-cpp-sys-2/build.rs @@ -271,6 +271,7 @@ fn main() { .allowlist_function("ggml_.*") .allowlist_type("ggml_.*") .allowlist_function("llama_.*") + .allowlist_type("llama_.*") .allowlist_function("mtmd_.*") .allowlist_type("mtmd_.*") .prepend_enum_name(false); From c45a061d78978d45835e7c5af5cc863b90b6ebfe Mon Sep 17 00:00:00 2001 From: Dennis Keck <26092524+fellhorn@users.noreply.github.com> Date: Fri, 1 Aug 2025 19:09:35 +0200 Subject: [PATCH 08/13] Feature guard mtmd + cleanup Signed-off-by: Dennis Keck <26092524+fellhorn@users.noreply.github.com> --- examples/mtmd/Cargo.toml | 5 +-- examples/reranker/Cargo.toml | 2 +- llama-cpp-2/Cargo.toml | 1 + llama-cpp-2/src/lib.rs | 1 + llama-cpp-sys-2/Cargo.toml | 1 + llama-cpp-sys-2/build.rs | 75 +++++++----------------------------- 6 files changed, 18 insertions(+), 67 deletions(-) diff --git a/examples/mtmd/Cargo.toml b/examples/mtmd/Cargo.toml index 24aed021..6dd28ad4 100644 --- a/examples/mtmd/Cargo.toml +++ b/examples/mtmd/Cargo.toml @@ -4,11 +4,8 @@ version = "0.1.86" edition = "2021" [dependencies] -llama-cpp-2 = { path = "../../llama-cpp-2", version = "0.1.86" } +llama-cpp-2 = { path = "../../llama-cpp-2", version = "0.1.86", features = ["mtmd"] } clap = { workspace = true, features = ["derive"] } -# hf-hub = { workspace = true } -# anyhow = { workspace = true } -# encoding_rs = { workspace = true } [features] cuda = ["llama-cpp-2/cuda"] diff --git 
a/examples/reranker/Cargo.toml b/examples/reranker/Cargo.toml index 753a7dd1..fa32c2d3 100644 --- a/examples/reranker/Cargo.toml +++ b/examples/reranker/Cargo.toml @@ -17,4 +17,4 @@ native = ["llama-cpp-2/native"] vulkan = ["llama-cpp-2/vulkan"] [lints] -workspace = true +workspace = true \ No newline at end of file diff --git a/llama-cpp-2/Cargo.toml b/llama-cpp-2/Cargo.toml index 8cfc9c04..b949944a 100644 --- a/llama-cpp-2/Cargo.toml +++ b/llama-cpp-2/Cargo.toml @@ -30,6 +30,7 @@ openmp = ["llama-cpp-sys-2/openmp"] sampler = [] # Only has an impact on Android. android-shared-stdcxx = ["llama-cpp-sys-2/shared-stdcxx"] +mtmd = ["llama-cpp-sys-2/mtmd"] [target.'cfg(all(target_os = "macos", any(target_arch = "aarch64", target_arch = "arm64")))'.dependencies] diff --git a/llama-cpp-2/src/lib.rs b/llama-cpp-2/src/lib.rs index c52af268..17bba45f 100644 --- a/llama-cpp-2/src/lib.rs +++ b/llama-cpp-2/src/lib.rs @@ -27,6 +27,7 @@ pub mod llama_backend; pub mod llama_batch; mod log; pub mod model; +#[cfg(feature = "mtmd")] pub mod mtmd; pub mod sampling; pub mod timing; diff --git a/llama-cpp-sys-2/Cargo.toml b/llama-cpp-sys-2/Cargo.toml index 9d2b1d98..2aea6c06 100644 --- a/llama-cpp-sys-2/Cargo.toml +++ b/llama-cpp-sys-2/Cargo.toml @@ -79,3 +79,4 @@ native = [] openmp = [] # Only has an impact on Android. shared-stdcxx = [] +mtmd = [] diff --git a/llama-cpp-sys-2/build.rs b/llama-cpp-sys-2/build.rs index aa6f1dbd..dc67ce70 100644 --- a/llama-cpp-sys-2/build.rs +++ b/llama-cpp-sys-2/build.rs @@ -272,10 +272,15 @@ fn main() { .allowlist_type("ggml_.*") .allowlist_function("llama_.*") .allowlist_type("llama_.*") - .allowlist_function("mtmd_.*") - .allowlist_type("mtmd_.*") .prepend_enum_name(false); + // Configure mtmd feature if enabled + if cfg!(feature = "mtmd") { + bindings_builder = bindings_builder + .allowlist_function("mtmd_.*") + .allowlist_type("mtmd_.*"); + } + // Configure Android-specific bindgen settings if matches!(target_os, TargetOs::Android) { // Detect Android NDK from environment variables @@ -440,11 +445,14 @@ fn main() { config.define("LLAMA_BUILD_TESTS", "OFF"); config.define("LLAMA_BUILD_EXAMPLES", "OFF"); config.define("LLAMA_BUILD_SERVER", "OFF"); + config.define("LLAMA_BUILD_TOOLS", "OFF"); config.define("LLAMA_CURL", "OFF"); - // Required for mtmd.rs - config.define("LLAMA_BUILD_COMMON", "ON"); - config.define("LLAMA_BUILD_TOOLS", "ON"); + if cfg!(feature = "mtmd") { + config.define("LLAMA_BUILD_COMMON", "ON"); + // mtmd support in llama-cpp is within the tools directory + config.define("LLAMA_BUILD_TOOLS", "ON"); + } // Pass CMAKE_ environment variables down to CMake for (key, value) in env::vars() { @@ -634,63 +642,6 @@ fn main() { let build_dir = config.build(); - let lib_dir = out_dir.join("lib"); - debug_log!("Lib directory: {}", lib_dir.display()); - - // Copy common library - let common_lib_src = build_dir.join("build/common/libcommon.a"); - debug_log!( - "Looking for common library at: {}", - common_lib_src.display() - ); - if common_lib_src.exists() { - let common_lib_dst = lib_dir.join("libcommon.a"); - std::fs::copy(&common_lib_src, &common_lib_dst).unwrap_or_else(|e| { - panic!("Failed to copy {common_lib_src:?} to {common_lib_dst:?}: {e:?}"); - }); - debug_log!("Copied common library: {}", common_lib_dst.display()); - } else { - debug_log!("Common library not found at: {}", common_lib_src.display()); - } - - // Copy mtmd static library - let mtmd_static_src = build_dir.join("build/tools/mtmd/libmtmd_static.a"); - debug_log!( - "Looking for mtmd static library 
at: {}", - mtmd_static_src.display() - ); - if mtmd_static_src.exists() { - let mtmd_static_dst = lib_dir.join("libmtmd_static.a"); - std::fs::copy(&mtmd_static_src, &mtmd_static_dst).unwrap_or_else(|e| { - panic!("Failed to copy {mtmd_static_src:?} to {mtmd_static_dst:?}: {e:?}"); - }); - debug_log!("Copied mtmd static library: {}", mtmd_static_dst.display()); - } else { - debug_log!( - "Mtmd static library not found at: {}", - mtmd_static_src.display() - ); - } - - // Copy mtmd audio library - let mtmd_audio_src = build_dir.join("build/tools/mtmd/libmtmd_audio.a"); - debug_log!( - "Looking for mtmd audio library at: {}", - mtmd_audio_src.display() - ); - if mtmd_audio_src.exists() { - let mtmd_audio_dst = lib_dir.join("libmtmd_audio.a"); - std::fs::copy(&mtmd_audio_src, &mtmd_audio_dst).unwrap_or_else(|e| { - panic!("Failed to copy {mtmd_audio_src:?} to {mtmd_audio_dst:?}: {e:?}"); - }); - debug_log!("Copied mtmd audio library: {}", mtmd_audio_dst.display()); - } else { - debug_log!( - "Mtmd audio library not found at: {}", - mtmd_audio_src.display() - ); - } - // Search paths println!("cargo:rustc-link-search={}", out_dir.join("lib").display()); println!( From eee15c39d316cd3501af7197d321551852e1445c Mon Sep 17 00:00:00 2001 From: Dennis Keck <26092524+fellhorn@users.noreply.github.com> Date: Fri, 1 Aug 2025 19:28:32 +0200 Subject: [PATCH 09/13] Add a small README Signed-off-by: Dennis Keck <26092524+fellhorn@users.noreply.github.com> --- examples/mtmd/README.md | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 examples/mtmd/README.md diff --git a/examples/mtmd/README.md b/examples/mtmd/README.md new file mode 100644 index 00000000..68c79027 --- /dev/null +++ b/examples/mtmd/README.md @@ -0,0 +1,27 @@ +# Rust mtmd-cli implementation + +Partial port of the mtmd-cli.cpp example in the llama-cpp repository. + +## Usage + +### Command Line Interface + +To run the mtmd example, you first need to download the model gguf file and the multimodal projection file, e.g. for Gemma3 you may use: + +```sh +wget https://huggingface.co/unsloth/gemma-3-4b-it-GGUF/resolve/main/gemma-3-4b-it-Q4_K_M.gguf \ +https://huggingface.co/unsloth/gemma-3-4b-it-GGUF/resolve/main/mmproj-F16.gguf +``` + +To then run the example on CPU, provide an image file `my_image.jpg` and run: + +```sh +cargo run --release --example mtmd -- \ + --model ./gemma-3-4b-it-Q4_K_M.gguf \ + --mmproj ./mmproj-F16.gguf \ + --image my_image.jpg \ + --prompt "What is in the picture?" 
\ + --no-gpu \ + --no-mmproj-offload \ + --marker "" +``` From e1f1e0435ddf6d0ad5eec29a0926150d640da07d Mon Sep 17 00:00:00 2001 From: Dennis Keck <26092524+fellhorn@users.noreply.github.com> Date: Fri, 1 Aug 2025 19:46:53 +0200 Subject: [PATCH 10/13] Add some doctests Signed-off-by: Dennis Keck <26092524+fellhorn@users.noreply.github.com> --- llama-cpp-2/src/mtmd.rs | 95 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 93 insertions(+), 2 deletions(-) diff --git a/llama-cpp-2/src/mtmd.rs b/llama-cpp-2/src/mtmd.rs index bbe9ff71..5b790c8b 100644 --- a/llama-cpp-2/src/mtmd.rs +++ b/llama-cpp-2/src/mtmd.rs @@ -14,6 +14,19 @@ use crate::model::LlamaModel; use crate::token::LlamaToken; /// Input chunk types for multimodal data +/// +/// # Examples +/// +/// ``` +/// use llama_cpp_2::mtmd::MtmdInputChunkType; +/// +/// let text_chunk = MtmdInputChunkType::Text; +/// let image_chunk = MtmdInputChunkType::Image; +/// let audio_chunk = MtmdInputChunkType::Audio; +/// +/// assert_eq!(text_chunk, MtmdInputChunkType::Text); +/// assert_ne!(text_chunk, image_chunk); +/// ``` #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum MtmdInputChunkType { /// Text input chunk @@ -36,6 +49,20 @@ impl From for MtmdInputChunkType { } /// Configuration parameters for MTMD context +/// +/// # Examples +/// +/// ``` +/// use llama_cpp_2::mtmd::{MtmdContextParams, mtmd_default_marker}; +/// use std::ffi::CString; +/// +/// let params = MtmdContextParams { +/// use_gpu: false, +/// print_timings: true, +/// n_threads: 4, +/// media_marker: CString::new(mtmd_default_marker()).unwrap(), +/// }; +/// ``` #[derive(Debug, Clone)] pub struct MtmdContextParams { /// Whether to use GPU acceleration @@ -73,6 +100,18 @@ impl From<&MtmdContextParams> for llama_cpp_sys_2::mtmd_context_params { } /// Text input configuration +/// +/// # Examples +/// +/// ``` +/// use llama_cpp_2::mtmd::MtmdInputText; +/// +/// let input = MtmdInputText { +/// text: "Describe this image.".to_string(), +/// add_special: true, +/// parse_special: true, +/// }; +/// ``` #[derive(Debug, Clone)] pub struct MtmdInputText { /// The input text string @@ -301,6 +340,19 @@ impl MtmdBitmap { /// /// * `InvalidDataSize` - Data length doesn't match `nx * ny * 3` /// * `NullResult` - Underlying C function returned null + /// + /// # Examples + /// + /// ``` + /// use llama_cpp_2::mtmd::MtmdBitmap; + /// + /// // Create a 2x2 red image + /// let red_pixel = [255, 0, 0]; // RGB values for red + /// let image_data = red_pixel.repeat(4); // 2x2 = 4 pixels + /// + /// let bitmap = MtmdBitmap::from_image_data(2, 2, &image_data); + /// assert!(bitmap.is_ok()); + /// ``` pub fn from_image_data(nx: u32, ny: u32, data: &[u8]) -> Result { if data.len() != (nx * ny * 3) as usize { return Err(MtmdBitmapError::InvalidDataSize); @@ -325,6 +377,20 @@ impl MtmdBitmap { /// # Errors /// /// * `NullResult` - Underlying C function returned null + /// + /// # Examples + /// + /// ``` + /// use llama_cpp_2::mtmd::MtmdBitmap; + /// + /// // Create a simple sine wave audio sample + /// let audio_data: Vec = (0..100) + /// .map(|i| (i as f32 * 0.1).sin()) + /// .collect(); + /// + /// let bitmap = MtmdBitmap::from_audio_data(&audio_data); + /// // Note: This will likely fail without proper MTMD context setup + /// ``` pub fn from_audio_data(data: &[f32]) -> Result { let bitmap = unsafe { llama_cpp_sys_2::mtmd_bitmap_init_from_audio(data.len(), data.as_ptr()) }; @@ -457,6 +523,17 @@ impl MtmdBitmap { /// # Errors /// /// Returns an error if the ID string contains null bytes. 
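As the `set_id` docs suggest deriving the ID from a hash of the payload, a hedged sketch using the standard library hasher (the ID format and helper name are arbitrary):

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

use llama_cpp_2::mtmd::MtmdBitmap;

// Sketch: content-derived IDs let identical media be recognized across calls,
// which is what the KV-cache-tracking note above is about.
fn tag_with_content_hash(bitmap: &MtmdBitmap) -> Result<(), std::ffi::NulError> {
    let mut hasher = DefaultHasher::new();
    bitmap.data().hash(&mut hasher);
    bitmap.set_id(&format!("{:016x}", hasher.finish()))
}
```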
+ /// + /// # Examples + /// + /// ```no_run + /// # use llama_cpp_2::mtmd::MtmdBitmap; + /// # fn example(bitmap: &MtmdBitmap) -> Result<(), Box> { + /// bitmap.set_id("image_001")?; + /// assert_eq!(bitmap.id(), Some("image_001".to_string())); + /// # Ok(()) + /// # } + /// ``` pub fn set_id(&self, id: &str) -> Result<(), std::ffi::NulError> { let id_cstr = CString::new(id)?; unsafe { @@ -493,6 +570,16 @@ impl MtmdInputChunks { /// # Panics /// This function will panic if the underlying llama.cpp function returns null, /// which should not happen. + /// + /// # Examples + /// + /// ``` + /// use llama_cpp_2::mtmd::MtmdInputChunks; + /// + /// let chunks = MtmdInputChunks::new(); + /// assert_eq!(chunks.len(), 0); + /// assert!(chunks.is_empty()); + /// ``` #[must_use] pub fn new() -> Self { let chunks = unsafe { llama_cpp_sys_2::mtmd_input_chunks_init() }; let chunks = NonNull::new(chunks).unwrap(); @@ -721,12 +808,16 @@ impl Drop for MtmdInputChunk { /// /// Returns the default media marker as a string slice. /// -/// # Example +/// # Examples /// /// ``` -/// # use llama_cpp_2::mtmd::mtmd_default_marker; +/// use llama_cpp_2::mtmd::mtmd_default_marker; +/// /// let marker = mtmd_default_marker(); +/// assert!(!marker.is_empty()); +/// /// let text = format!("Describe this image: {}", marker); +/// assert!(text.contains(marker)); /// ``` #[must_use] pub fn mtmd_default_marker() -> &'static str { unsafe { From d0254656235576803f9e41067d180381e37c89c5 Mon Sep 17 00:00:00 2001 From: Dennis Keck <26092524+fellhorn@users.noreply.github.com> Date: Wed, 13 Aug 2025 23:33:24 +0200 Subject: [PATCH 11/13] Review round 1 * Remove unsafe Send, Sync * Cleanup error handling * Use default mtmd_context directly Signed-off-by: Dennis Keck <26092524+fellhorn@users.noreply.github.com> --- examples/mtmd/src/mtmd.rs | 4 +- llama-cpp-2/src/mtmd.rs | 120 ++++++++++++++++++++++++-------------- llama-cpp-sys-2/build.rs | 4 +- 3 files changed, 81 insertions(+), 47 deletions(-) diff --git a/examples/mtmd/src/mtmd.rs b/examples/mtmd/src/mtmd.rs index 06b78abd..e6fd367e 100644 --- a/examples/mtmd/src/mtmd.rs +++ b/examples/mtmd/src/mtmd.rs @@ -10,7 +10,9 @@ use llama_cpp_2::context::params::LlamaContextParams; use llama_cpp_2::context::LlamaContext; use llama_cpp_2::llama_batch::LlamaBatch; use llama_cpp_2::model::params::LlamaModelParams; -use llama_cpp_2::mtmd::{MtmdBitmap, MtmdBitmapError, MtmdContext, MtmdContextParams, MtmdInputText}; +use llama_cpp_2::mtmd::{ + MtmdBitmap, MtmdBitmapError, MtmdContext, MtmdContextParams, MtmdInputText, +}; use llama_cpp_2::llama_backend::LlamaBackend; use llama_cpp_2::model::{LlamaChatMessage, LlamaChatTemplate, LlamaModel, Special}; diff --git a/llama-cpp-2/src/mtmd.rs b/llama-cpp-2/src/mtmd.rs index 5b790c8b..e1f712ad 100644 --- a/llama-cpp-2/src/mtmd.rs +++ b/llama-cpp-2/src/mtmd.rs @@ -77,28 +77,42 @@ pub struct MtmdContextParams { impl Default for MtmdContextParams { fn default() -> Self { - Self { - use_gpu: false, - print_timings: true, - n_threads: 4, - media_marker: CString::new(mtmd_default_marker()).unwrap_or_default(), - } + unsafe { llama_cpp_sys_2::mtmd_context_params_default() }.into() } } impl From<&MtmdContextParams> for llama_cpp_sys_2::mtmd_context_params { fn from(params: &MtmdContextParams) -> Self { let mut context = unsafe { llama_cpp_sys_2::mtmd_context_params_default() }; - - context.use_gpu = params.use_gpu; - context.print_timings = params.print_timings; - context.n_threads = params.n_threads; - context.media_marker = params.media_marker.as_ptr(); 
+        let MtmdContextParams {
+            use_gpu,
+            print_timings,
+            n_threads,
+            media_marker,
+        } = params;
+
+        context.use_gpu = *use_gpu;
+        context.print_timings = *print_timings;
+        context.n_threads = *n_threads;
+        context.media_marker = media_marker.as_ptr();
         context
     }
 }
 
+impl From<llama_cpp_sys_2::mtmd_context_params> for MtmdContextParams {
+    fn from(params: llama_cpp_sys_2::mtmd_context_params) -> Self {
+        Self {
+            use_gpu: params.use_gpu,
+            print_timings: params.print_timings,
+            n_threads: params.n_threads,
+            media_marker: unsafe { CStr::from_ptr(params.media_marker) }
+                .to_owned()
+                .into(),
+        }
+    }
+}
+
 /// Text input configuration
 ///
 /// # Examples
@@ -165,16 +179,13 @@ impl MtmdContext {
             )
         };
 
-        if context.is_null() {
-            return Err(MtmdInitError::NullResult);
-        }
-
         let context = NonNull::new(context).ok_or(MtmdInitError::NullResult)?;
         Ok(Self { context })
     }
 
     /// Check whether non-causal attention mask is needed before `llama_decode`.
-    #[must_use] pub fn decode_use_non_causal(&self) -> bool {
+    #[must_use]
+    pub fn decode_use_non_causal(&self) -> bool {
         unsafe { llama_cpp_sys_2::mtmd_decode_use_non_causal(self.context.as_ptr()) }
     }
 
@@ -182,23 +193,27 @@ impl MtmdContext {
     ///
     /// M-RoPE (Multimodal Rotary Position Embedding) affects how positions
     /// are calculated for multimodal inputs.
-    #[must_use] pub fn decode_use_mrope(&self) -> bool {
+    #[must_use]
+    pub fn decode_use_mrope(&self) -> bool {
         unsafe { llama_cpp_sys_2::mtmd_decode_use_mrope(self.context.as_ptr()) }
     }
 
     /// Check whether the current model supports vision input.
-    #[must_use] pub fn support_vision(&self) -> bool {
+    #[must_use]
+    pub fn support_vision(&self) -> bool {
         unsafe { llama_cpp_sys_2::mtmd_support_vision(self.context.as_ptr()) }
     }
 
     /// Check whether the current model supports audio input.
-    #[must_use] pub fn support_audio(&self) -> bool {
+    #[must_use]
+    pub fn support_audio(&self) -> bool {
         unsafe { llama_cpp_sys_2::mtmd_support_audio(self.context.as_ptr()) }
     }
 
     /// Get audio bitrate in Hz (e.g., 16000 for Whisper).
     /// Returns -1 if audio is not supported.
-    #[must_use] pub fn get_audio_bitrate(&self) -> i32 {
+    #[must_use]
+    pub fn get_audio_bitrate(&self) -> i32 {
         unsafe { llama_cpp_sys_2::mtmd_get_audio_bitrate(self.context.as_ptr()) }
     }
 
@@ -243,7 +258,7 @@ impl MtmdContext {
         bitmaps: &[&MtmdBitmap],
     ) -> Result<MtmdInputChunks, MtmdTokenizeError> {
         let chunks = MtmdInputChunks::new();
-        let text_cstring = CString::new(text.text).unwrap_or_default();
+        let text_cstring = CString::new(text.text)?;
         let input_text = llama_cpp_sys_2::mtmd_input_text {
             text: text_cstring.as_ptr(),
             add_special: text.add_special,
@@ -304,9 +319,6 @@ impl MtmdContext {
     }
 }
 
-unsafe impl Send for MtmdContext {}
-unsafe impl Sync for MtmdContext {}
-
 impl Drop for MtmdContext {
     fn drop(&mut self) {
         unsafe { llama_cpp_sys_2::mtmd_free(self.context.as_ptr()) }
@@ -471,12 +483,14 @@ impl MtmdBitmap {
     }
 
     /// Get bitmap width in pixels.
-    #[must_use] pub fn nx(&self) -> u32 {
+    #[must_use]
+    pub fn nx(&self) -> u32 {
         unsafe { llama_cpp_sys_2::mtmd_bitmap_get_nx(self.bitmap.as_ptr()) }
     }
 
     /// Get bitmap height in pixels.
- #[must_use] pub fn ny(&self) -> u32 { + #[must_use] + pub fn ny(&self) -> u32 { unsafe { llama_cpp_sys_2::mtmd_bitmap_get_ny(self.bitmap.as_ptr()) } } @@ -484,14 +498,16 @@ impl MtmdBitmap { /// /// For images: RGB format with length `nx * ny * 3` /// For audio: PCM F32 format with length `n_samples * 4` - #[must_use] pub fn data(&self) -> &[u8] { + #[must_use] + pub fn data(&self) -> &[u8] { let ptr = unsafe { llama_cpp_sys_2::mtmd_bitmap_get_data(self.bitmap.as_ptr()) }; let len = unsafe { llama_cpp_sys_2::mtmd_bitmap_get_n_bytes(self.bitmap.as_ptr()) }; unsafe { slice::from_raw_parts(ptr, len) } } /// Check if this bitmap contains audio data (vs image data). - #[must_use] pub fn is_audio(&self) -> bool { + #[must_use] + pub fn is_audio(&self) -> bool { unsafe { llama_cpp_sys_2::mtmd_bitmap_is_audio(self.bitmap.as_ptr()) } } @@ -499,15 +515,16 @@ impl MtmdBitmap { /// /// Bitmap ID is useful for KV cache tracking and can e.g. be calculated /// based on a hash of the bitmap data. - #[must_use] pub fn id(&self) -> Option { + #[must_use] + pub fn id(&self) -> Option { let ptr = unsafe { llama_cpp_sys_2::mtmd_bitmap_get_id(self.bitmap.as_ptr()) }; if ptr.is_null() { None } else { - unsafe { CStr::from_ptr(ptr) } + let id = unsafe { CStr::from_ptr(ptr) } .to_string_lossy() - .into_owned() - .into() + .into_owned(); + Some(id) } } @@ -580,24 +597,28 @@ impl MtmdInputChunks { /// assert_eq!(chunks.len(), 0); /// assert!(chunks.is_empty()); /// ``` - #[must_use] pub fn new() -> Self { + #[must_use] + pub fn new() -> Self { let chunks = unsafe { llama_cpp_sys_2::mtmd_input_chunks_init() }; let chunks = NonNull::new(chunks).unwrap(); Self { chunks } } /// Get the number of chunks - #[must_use] pub fn len(&self) -> usize { + #[must_use] + pub fn len(&self) -> usize { unsafe { llama_cpp_sys_2::mtmd_input_chunks_size(self.chunks.as_ptr()) } } /// Check if chunks collection is empty - #[must_use] pub fn is_empty(&self) -> bool { + #[must_use] + pub fn is_empty(&self) -> bool { self.len() == 0 } /// Get a chunk by index - #[must_use] pub fn get(&self, index: usize) -> Option { + #[must_use] + pub fn get(&self, index: usize) -> Option { if index >= self.len() { return None; } @@ -619,7 +640,8 @@ impl MtmdInputChunks { /// Get total number of tokens across all chunks. /// /// This is useful for keeping track of KV cache size. - #[must_use] pub fn total_tokens(&self) -> usize { + #[must_use] + pub fn total_tokens(&self) -> usize { unsafe { llama_cpp_sys_2::mtmd_helper_get_n_tokens(self.chunks.as_ptr()) } } @@ -627,7 +649,8 @@ impl MtmdInputChunks { /// /// This is useful to keep track of `n_past`. Normally `n_pos` equals `n_tokens`, /// but for M-RoPE it is different. - #[must_use] pub fn total_positions(&self) -> i32 { + #[must_use] + pub fn total_positions(&self) -> i32 { unsafe { llama_cpp_sys_2::mtmd_helper_get_n_pos(self.chunks.as_ptr()) } } @@ -709,7 +732,8 @@ pub struct MtmdInputChunk { impl MtmdInputChunk { /// Get the type of this chunk - #[must_use] pub fn chunk_type(&self) -> MtmdInputChunkType { + #[must_use] + pub fn chunk_type(&self) -> MtmdInputChunkType { let chunk_type = unsafe { llama_cpp_sys_2::mtmd_input_chunk_get_type(self.chunk.as_ptr()) }; MtmdInputChunkType::from(chunk_type) } @@ -721,7 +745,8 @@ impl MtmdInputChunk { /// # Returns /// /// Returns `Some(&[LlamaToken])` for text chunks, `None` otherwise. 
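Combining `chunk_type()` and `text_tokens()`, an illustrative tally of the text tokens in a tokenized result (the helper name is hypothetical):

```rust
use llama_cpp_2::mtmd::{MtmdInputChunkType, MtmdInputChunks};

// Sketch: media chunks report no tokens through text_tokens(); they are
// handled by the encode/eval helpers instead.
fn count_text_tokens(chunks: &MtmdInputChunks) -> usize {
    (0..chunks.len())
        .filter_map(|i| chunks.get(i))
        .filter(|chunk| chunk.chunk_type() == MtmdInputChunkType::Text)
        .filter_map(|chunk| chunk.text_tokens().map(|tokens| tokens.len()))
        .sum()
}
```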
- #[must_use] pub fn text_tokens(&self) -> Option<&[LlamaToken]> { + #[must_use] + pub fn text_tokens(&self) -> Option<&[LlamaToken]> { if self.chunk_type() != MtmdInputChunkType::Text { return None; } @@ -744,21 +769,24 @@ impl MtmdInputChunk { } /// Get the number of tokens in this chunk - #[must_use] pub fn n_tokens(&self) -> usize { + #[must_use] + pub fn n_tokens(&self) -> usize { unsafe { llama_cpp_sys_2::mtmd_input_chunk_get_n_tokens(self.chunk.as_ptr()) } } /// Get the number of positions in this chunk. /// /// Returns the number of temporal positions (always 1 for M-RoPE, `n_tokens` otherwise). - #[must_use] pub fn n_positions(&self) -> i32 { + #[must_use] + pub fn n_positions(&self) -> i32 { unsafe { llama_cpp_sys_2::mtmd_input_chunk_get_n_pos(self.chunk.as_ptr()) } } /// Get chunk ID if available. /// /// Returns `None` for text chunks, may return an ID for image/audio chunks. - #[must_use] pub fn id(&self) -> Option { + #[must_use] + pub fn id(&self) -> Option { let ptr = unsafe { llama_cpp_sys_2::mtmd_input_chunk_get_id(self.chunk.as_ptr()) }; if ptr.is_null() { None @@ -819,7 +847,8 @@ impl Drop for MtmdInputChunk { /// let text = format!("Describe this image: {}", marker); /// assert!(text.contains(marker)); /// ``` -#[must_use] pub fn mtmd_default_marker() -> &'static str { +#[must_use] +pub fn mtmd_default_marker() -> &'static str { unsafe { let c_str = llama_cpp_sys_2::mtmd_default_marker(); CStr::from_ptr(c_str).to_str().unwrap_or("<__media__>") @@ -877,6 +906,9 @@ pub enum MtmdTokenizeError { /// Image preprocessing error occurred #[error("Image preprocessing error")] ImagePreprocessingError, + /// Text contains characters that cannot be converted to C string + #[error("Failed to create CString from text: {0}")] + CStringError(#[from] std::ffi::NulError), /// Unknown error occurred during tokenization #[error("Unknown error: {0}")] UnknownError(i32), diff --git a/llama-cpp-sys-2/build.rs b/llama-cpp-sys-2/build.rs index dc67ce70..daac497a 100644 --- a/llama-cpp-sys-2/build.rs +++ b/llama-cpp-sys-2/build.rs @@ -277,8 +277,8 @@ fn main() { // Configure mtmd feature if enabled if cfg!(feature = "mtmd") { bindings_builder = bindings_builder - .allowlist_function("mtmd_.*") - .allowlist_type("mtmd_.*"); + .allowlist_function("mtmd_.*") + .allowlist_type("mtmd_.*"); } // Configure Android-specific bindgen settings From 62f151116e235d8125e57df105c595e890f4190c Mon Sep 17 00:00:00 2001 From: Dennis Keck <26092524+fellhorn@users.noreply.github.com> Date: Thu, 14 Aug 2025 14:11:57 +0200 Subject: [PATCH 12/13] Fix context length in mtmd example Signed-off-by: Dennis Keck <26092524+fellhorn@users.noreply.github.com> --- examples/mtmd/src/mtmd.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/mtmd/src/mtmd.rs b/examples/mtmd/src/mtmd.rs index e6fd367e..6a704d3c 100644 --- a/examples/mtmd/src/mtmd.rs +++ b/examples/mtmd/src/mtmd.rs @@ -2,6 +2,7 @@ use std::ffi::CString; use std::io::{self, Write}; +use std::num::NonZeroU32; use std::path::Path; use clap::Parser; @@ -50,8 +51,8 @@ pub struct MtmdCliParams { #[arg(short = 't', long = "threads", value_name = "N", default_value = "4")] pub n_threads: i32, /// Maximum number of tokens in context - #[arg(long = "n-tokens", value_name = "N", default_value = "2048")] - pub n_tokens: usize, + #[arg(long = "n-tokens", value_name = "N", default_value = "4096")] + pub n_tokens: NonZeroU32, /// Chat template to use, default template if not provided #[arg(long = "chat-template", value_name = "TEMPLATE")] pub 
chat_template: Option<String>,
@@ -111,7 +112,7 @@ impl MtmdCliContext {
             .chat_template(params.chat_template.as_deref())
             .map_err(|e| format!("Failed to get chat template: {e}"))?;
 
-        let batch = LlamaBatch::new(params.n_tokens, 1);
+        let batch = LlamaBatch::new(params.n_tokens.get() as usize, 1);
 
         Ok(Self {
             mtmd_ctx,
@@ -285,7 +286,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     // Create context
     let context_params = LlamaContextParams::default()
         .with_n_threads(params.n_threads)
-        .with_n_batch(1);
+        .with_n_batch(1)
+        .with_n_ctx(Some(params.n_tokens));
     let mut context = model.new_context(&backend, context_params)?;
 
     // Create sampler

From f149f11230dd814cbab0ac9d9a0d5f2f916fa6e0 Mon Sep 17 00:00:00 2001
From: Dennis Keck <26092524+fellhorn@users.noreply.github.com>
Date: Thu, 14 Aug 2025 15:06:34 +0200
Subject: [PATCH 13/13] Review round 1: wrapper_mtmd.h

Signed-off-by: Dennis Keck <26092524+fellhorn@users.noreply.github.com>
---
 llama-cpp-sys-2/Cargo.toml     | 1 +
 llama-cpp-sys-2/build.rs       | 2 ++
 llama-cpp-sys-2/wrapper.h      | 2 --
 llama-cpp-sys-2/wrapper_mtmd.h | 2 ++
 4 files changed, 5 insertions(+), 2 deletions(-)
 create mode 100644 llama-cpp-sys-2/wrapper_mtmd.h

diff --git a/llama-cpp-sys-2/Cargo.toml b/llama-cpp-sys-2/Cargo.toml
index 2aea6c06..205e0983 100644
--- a/llama-cpp-sys-2/Cargo.toml
+++ b/llama-cpp-sys-2/Cargo.toml
@@ -9,6 +9,7 @@ links = "llama"
 
 include = [
     "wrapper.h",
+    "wrapper_mtmd.h",
     "build.rs",
 
     "/src",
diff --git a/llama-cpp-sys-2/build.rs b/llama-cpp-sys-2/build.rs
index daac497a..4ab918e3 100644
--- a/llama-cpp-sys-2/build.rs
+++ b/llama-cpp-sys-2/build.rs
@@ -277,6 +277,7 @@ fn main() {
     // Configure mtmd feature if enabled
     if cfg!(feature = "mtmd") {
         bindings_builder = bindings_builder
+            .header("wrapper_mtmd.h")
             .allowlist_function("mtmd_.*")
             .allowlist_type("mtmd_.*");
     }
@@ -432,6 +433,7 @@ fn main() {
         .expect("Failed to write bindings");
 
     println!("cargo:rerun-if-changed=wrapper.h");
+    println!("cargo:rerun-if-changed=wrapper_mtmd.h");
 
     debug_log!("Bindings Created");
 
diff --git a/llama-cpp-sys-2/wrapper.h b/llama-cpp-sys-2/wrapper.h
index 88eb9748..7f0eeefc 100644
--- a/llama-cpp-sys-2/wrapper.h
+++ b/llama-cpp-sys-2/wrapper.h
@@ -1,3 +1 @@
 #include "llama.cpp/include/llama.h"
-#include "llama.cpp/tools/mtmd/mtmd.h"
-#include "llama.cpp/tools/mtmd/mtmd-helper.h"
diff --git a/llama-cpp-sys-2/wrapper_mtmd.h b/llama-cpp-sys-2/wrapper_mtmd.h
new file mode 100644
index 00000000..72fb2111
--- /dev/null
+++ b/llama-cpp-sys-2/wrapper_mtmd.h
@@ -0,0 +1,2 @@
+#include "llama.cpp/tools/mtmd/mtmd.h"
+#include "llama.cpp/tools/mtmd/mtmd-helper.h"
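Pulling the series together, a closing sketch of the whole flow with the `mtmd` cargo feature enabled (paths, prompt, and helper name are placeholders; the call sequence mirrors the mtmd example above):

```rust
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::mtmd::{
    mtmd_default_marker, MtmdBitmap, MtmdContext, MtmdContextParams, MtmdInputText,
};

// Sketch: prefill one prompt plus one image and return the new n_past.
// Requires llama-cpp-2 built with features = ["mtmd"].
fn prefill(
    model: &LlamaModel,
    llama_ctx: &mut LlamaContext,
    image_path: &str,
) -> Result<i32, Box<dyn std::error::Error>> {
    // The projector path is a placeholder, as in the README above.
    let mtmd_ctx =
        MtmdContext::init_from_file("./mmproj-F16.gguf", model, &MtmdContextParams::default())?;
    let bitmap = MtmdBitmap::from_file(&mtmd_ctx, image_path)?;
    let text = MtmdInputText {
        text: format!("Describe this image: {}", mtmd_default_marker()),
        add_special: true,
        parse_special: true,
    };
    let chunks = mtmd_ctx.tokenize(text, &[&bitmap])?;
    // eval_chunks drives llama_decode/mtmd_encode per chunk; n_past starts at
    // 0, sequence 0, batch size 1, logits for the last token only.
    Ok(chunks.eval_chunks(&mtmd_ctx, llama_ctx, 0, 0, 1, true)?)
}
```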