diff --git a/examples/mtmd/src/mtmd.rs b/examples/mtmd/src/mtmd.rs
index 6a704d3c..e0f52de9 100644
--- a/examples/mtmd/src/mtmd.rs
+++ b/examples/mtmd/src/mtmd.rs
@@ -50,6 +50,9 @@ pub struct MtmdCliParams {
     /// Number of threads
     #[arg(short = 't', long = "threads", value_name = "N", default_value = "4")]
     pub n_threads: i32,
+    /// Number of tokens to process in a batch during eval chunks
+    #[arg(long = "batch-size", value_name = "b", default_value = "1")]
+    pub batch_size: i32,
     /// Maximum number of tokens in context
     #[arg(long = "n-tokens", value_name = "N", default_value = "4096")]
     pub n_tokens: NonZeroU32,
@@ -140,6 +143,7 @@ impl MtmdCliContext {
         context: &mut LlamaContext,
         msg: LlamaChatMessage,
         add_bos: bool,
+        batch_size: i32,
     ) -> Result<(), Box<dyn std::error::Error>> {
         self.chat.push(msg);

@@ -168,7 +172,7 @@ impl MtmdCliContext {
         // Clear bitmaps after tokenization
         self.bitmaps.clear();

-        self.n_past = chunks.eval_chunks(&self.mtmd_ctx, context, 0, 0, 1, true)?;
+        self.n_past = chunks.eval_chunks(&self.mtmd_ctx, context, 0, 0, batch_size, true)?;
         Ok(())
     }

@@ -186,7 +190,7 @@ impl MtmdCliContext {

         for _i in 0..max_predict {
             // Sample next token
-            let token = sampler.sample(context, 0);
+            let token = sampler.sample(context, -1);
             generated_tokens.push(token);
             sampler.accept(token);

@@ -244,7 +248,7 @@ fn run_single_turn(
     println!("Evaluating message: {msg:?}");

     // Evaluate the message (prefill)
-    ctx.eval_message(model, context, msg, true)?;
+    ctx.eval_message(model, context, msg, true, params.batch_size)?;

     // Generate response (decode)
     ctx.generate_response(model, context, sampler, params.n_predict)?;
@@ -286,7 +290,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     // Create context
     let context_params = LlamaContextParams::default()
         .with_n_threads(params.n_threads)
-        .with_n_batch(1)
+        .with_n_batch(params.batch_size.try_into()?)
         .with_n_ctx(Some(params.n_tokens));
     let mut context = model.new_context(&backend, context_params)?;

diff --git a/llama-cpp-2/src/mtmd.rs b/llama-cpp-2/src/mtmd.rs
index e1f712ad..71b21ef8 100644
--- a/llama-cpp-2/src/mtmd.rs
+++ b/llama-cpp-2/src/mtmd.rs
@@ -25,16 +25,18 @@ use crate::token::LlamaToken;
 /// let audio_chunk = MtmdInputChunkType::Audio;
 ///
 /// assert_eq!(text_chunk, MtmdInputChunkType::Text);
+/// assert_eq!(text_chunk, llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_TEXT.into());
 /// assert_ne!(text_chunk, image_chunk);
 /// ```
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
+#[repr(u32)]
 pub enum MtmdInputChunkType {
     /// Text input chunk
-    Text = llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_TEXT as isize,
+    Text = llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_TEXT as _,
     /// Image input chunk
-    Image = llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_IMAGE as isize,
+    Image = llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_IMAGE as _,
     /// Audio input chunk
-    Audio = llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_AUDIO as isize,
+    Audio = llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_AUDIO as _,
 }

 impl From<llama_cpp_sys_2::mtmd_input_chunk_type> for MtmdInputChunkType {
@@ -43,7 +45,7 @@ impl From<llama_cpp_sys_2::mtmd_input_chunk_type> for MtmdInputChunkType {
             llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_TEXT => MtmdInputChunkType::Text,
             llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_IMAGE => MtmdInputChunkType::Image,
             llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_AUDIO => MtmdInputChunkType::Audio,
-            _ => panic!("Unknown MTMD input chunk type"),
+            _ => panic!("Unknown MTMD input chunk type: {chunk_type}"),
         }
     }
 }
@@ -106,9 +108,7 @@ impl From<llama_cpp_sys_2::mtmd_context_params> for MtmdContextParams {
             use_gpu: params.use_gpu,
             print_timings: params.print_timings,
             n_threads: params.n_threads,
-            media_marker: unsafe { CStr::from_ptr(params.media_marker) }
-                .to_owned()
-                .into(),
+            media_marker: unsafe { CStr::from_ptr(params.media_marker) }.to_owned(),
         }
     }
 }
@@ -211,10 +211,11 @@ impl MtmdContext {
     }

     /// Get audio bitrate in Hz (e.g., 16000 for Whisper).
-    /// Returns -1 if audio is not supported.
+    /// Returns `None` if audio is not supported.
     #[must_use]
-    pub fn get_audio_bitrate(&self) -> i32 {
-        unsafe { llama_cpp_sys_2::mtmd_get_audio_bitrate(self.context.as_ptr()) }
+    pub fn get_audio_bitrate(&self) -> Option<u32> {
+        let rate = unsafe { llama_cpp_sys_2::mtmd_get_audio_bitrate(self.context.as_ptr()) };
+        (rate > 0).then_some(rate.unsigned_abs())
     }

     /// Tokenize input text and bitmaps into chunks.
@@ -275,7 +276,7 @@ impl MtmdContext {
             llama_cpp_sys_2::mtmd_tokenize(
                 self.context.as_ptr(),
                 chunks.chunks.as_ptr(),
-                &input_text,
+                &raw const input_text,
                 bitmap_ptrs.as_ptr().cast_mut(),
                 bitmaps.len(),
             )
@@ -626,15 +627,11 @@ impl MtmdInputChunks {
         let chunk_ptr =
             unsafe { llama_cpp_sys_2::mtmd_input_chunks_get(self.chunks.as_ptr(), index) };

-        if chunk_ptr.is_null() {
-            None
-        } else {
-            // Note: We don't own this chunk, it's owned by the chunks collection
-            Some(MtmdInputChunk {
-                chunk: NonNull::new(chunk_ptr.cast_mut()).unwrap(),
-                owned: false,
-            })
-        }
+        // Note: We don't own this chunk, it's owned by the chunks collection
+        NonNull::new(chunk_ptr.cast_mut()).map(|ptr| MtmdInputChunk {
+            chunk: ptr,
+            owned: false,
+        })
     }

     /// Get total number of tokens across all chunks.
@@ -701,7 +698,7 @@ impl MtmdInputChunks {
                 seq_id,
                 n_batch,
                 logits_last,
-                &mut new_n_past,
+                &raw mut new_n_past,
             )
         };

@@ -753,7 +750,10 @@ impl MtmdInputChunk {

         let mut n_tokens = 0usize;
         let tokens_ptr = unsafe {
-            llama_cpp_sys_2::mtmd_input_chunk_get_tokens_text(self.chunk.as_ptr(), &mut n_tokens)
+            llama_cpp_sys_2::mtmd_input_chunk_get_tokens_text(
+                self.chunk.as_ptr(),
+                &raw mut n_tokens,
+            )
         };

         if tokens_ptr.is_null() || n_tokens == 0 {
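
Usage note (a sketch, not part of the patch): the new `--batch-size` flag feeds two places that must agree, the context's logical batch size set via `with_n_batch` and the `n_batch` argument of `eval_chunks`. The hypothetical helper below only illustrates that coupling, using the types and calls visible in this diff; the function name `prefill_with_batch_size` is invented for illustration. The companion change from `sampler.sample(context, 0)` to `sampler.sample(context, -1)` follows llama.cpp's convention that index -1 selects the logits of the last token in the batch, which is the correct row once a prefill batch holds more than one token.

    use std::num::NonZeroU32;

    use llama_cpp_2::context::params::LlamaContextParams;
    use llama_cpp_2::llama_backend::LlamaBackend;
    use llama_cpp_2::model::LlamaModel;
    use llama_cpp_2::mtmd::{MtmdContext, MtmdInputChunks};

    /// Hypothetical helper: prefill already-tokenized chunks, sizing both the
    /// llama context and the eval loop from the same `batch_size` value so a
    /// chunk evaluation never submits a batch larger than the context accepts.
    fn prefill_with_batch_size(
        backend: &LlamaBackend,
        model: &LlamaModel,
        mtmd_ctx: &MtmdContext,
        chunks: &MtmdInputChunks,
        batch_size: i32,
        n_tokens: NonZeroU32,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let context_params = LlamaContextParams::default()
            .with_n_batch(batch_size.try_into()?) // u32, so a negative value fails here
            .with_n_ctx(Some(n_tokens));
        let mut context = model.new_context(backend, context_params)?;

        // Evaluate all chunks `batch_size` tokens at a time; `true` requests
        // logits for the last token so decoding can start immediately.
        let n_past = chunks.eval_chunks(mtmd_ctx, &mut context, 0, 0, batch_size, true)?;
        println!("prefill advanced n_past to {n_past}");
        Ok(())
    }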
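
API note: `MtmdContext::get_audio_bitrate` now returns `Option<u32>` instead of an `i32` with a -1 sentinel, so callers can no longer forget the unsupported case. A minimal caller-side sketch (the function name is invented for illustration):

    use llama_cpp_2::mtmd::MtmdContext;

    /// Hypothetical caller: the "no audio support" case is now expressed in
    /// the return type rather than as a -1 sentinel.
    fn describe_audio_support(ctx: &MtmdContext) {
        match ctx.get_audio_bitrate() {
            Some(rate_hz) => println!("audio input is resampled to {rate_hz} Hz"),
            None => println!("this model does not accept audio input"),
        }
    }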