Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions examples/mtmd/src/mtmd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ pub struct MtmdCliParams {
/// Number of threads
#[arg(short = 't', long = "threads", value_name = "N", default_value = "4")]
pub n_threads: i32,
/// Number of tokens to process in a batch during eval chunks
#[arg(long = "batch-size", value_name = "b", default_value = "1")]
pub batch_size: i32,
/// Maximum number of tokens in context
#[arg(long = "n-tokens", value_name = "N", default_value = "4096")]
pub n_tokens: NonZeroU32,
Expand Down Expand Up @@ -140,6 +143,7 @@ impl MtmdCliContext {
context: &mut LlamaContext,
msg: LlamaChatMessage,
add_bos: bool,
batch_size: i32,
) -> Result<(), Box<dyn std::error::Error>> {
self.chat.push(msg);

Expand Down Expand Up @@ -168,7 +172,7 @@ impl MtmdCliContext {
// Clear bitmaps after tokenization
self.bitmaps.clear();

self.n_past = chunks.eval_chunks(&self.mtmd_ctx, context, 0, 0, 1, true)?;
self.n_past = chunks.eval_chunks(&self.mtmd_ctx, context, 0, 0, batch_size, true)?;
Ok(())
}

Expand All @@ -186,7 +190,7 @@ impl MtmdCliContext {

for _i in 0..max_predict {
// Sample next token
let token = sampler.sample(context, 0);
let token = sampler.sample(context, -1);
generated_tokens.push(token);
sampler.accept(token);

Expand Down Expand Up @@ -244,7 +248,7 @@ fn run_single_turn(
println!("Evaluating message: {msg:?}");

// Evaluate the message (prefill)
ctx.eval_message(model, context, msg, true)?;
ctx.eval_message(model, context, msg, true, params.batch_size)?;

// Generate response (decode)
ctx.generate_response(model, context, sampler, params.n_predict)?;
Expand Down Expand Up @@ -286,7 +290,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// Create context
let context_params = LlamaContextParams::default()
.with_n_threads(params.n_threads)
.with_n_batch(1)
.with_n_batch(params.batch_size.try_into()?)
.with_n_ctx(Some(params.n_tokens));
let mut context = model.new_context(&backend, context_params)?;

Expand Down
44 changes: 22 additions & 22 deletions llama-cpp-2/src/mtmd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,18 @@ use crate::token::LlamaToken;
/// let audio_chunk = MtmdInputChunkType::Audio;
///
/// assert_eq!(text_chunk, MtmdInputChunkType::Text);
/// assert_eq!(text_chunk, llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_TEXT.into());
/// assert_ne!(text_chunk, image_chunk);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum MtmdInputChunkType {
/// Text input chunk
Text = llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_TEXT as isize,
Text = llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_TEXT as _,
/// Image input chunk
Image = llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_IMAGE as isize,
Image = llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_IMAGE as _,
/// Audio input chunk
Audio = llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_AUDIO as isize,
Audio = llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_AUDIO as _,
}

impl From<llama_cpp_sys_2::mtmd_input_chunk_type> for MtmdInputChunkType {
Expand All @@ -43,7 +45,7 @@ impl From<llama_cpp_sys_2::mtmd_input_chunk_type> for MtmdInputChunkType {
llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_TEXT => MtmdInputChunkType::Text,
llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_IMAGE => MtmdInputChunkType::Image,
llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_AUDIO => MtmdInputChunkType::Audio,
_ => panic!("Unknown MTMD input chunk type"),
_ => panic!("Unknown MTMD input chunk type: {chunk_type}"),
}
}
}
Expand Down Expand Up @@ -106,9 +108,7 @@ impl From<llama_cpp_sys_2::mtmd_context_params> for MtmdContextParams {
use_gpu: params.use_gpu,
print_timings: params.print_timings,
n_threads: params.n_threads,
media_marker: unsafe { CStr::from_ptr(params.media_marker) }
.to_owned()
.into(),
media_marker: unsafe { CStr::from_ptr(params.media_marker) }.to_owned(),
}
}
}
Expand Down Expand Up @@ -211,10 +211,11 @@ impl MtmdContext {
}

/// Get audio bitrate in Hz (e.g., 16000 for Whisper).
/// Returns -1 if audio is not supported.
/// Returns None if audio is not supported.
#[must_use]
pub fn get_audio_bitrate(&self) -> i32 {
unsafe { llama_cpp_sys_2::mtmd_get_audio_bitrate(self.context.as_ptr()) }
pub fn get_audio_bitrate(&self) -> Option<u32> {
let rate = unsafe { llama_cpp_sys_2::mtmd_get_audio_bitrate(self.context.as_ptr()) };
(rate > 0).then_some(rate.unsigned_abs())
}

/// Tokenize input text and bitmaps into chunks.
Expand Down Expand Up @@ -275,7 +276,7 @@ impl MtmdContext {
llama_cpp_sys_2::mtmd_tokenize(
self.context.as_ptr(),
chunks.chunks.as_ptr(),
&input_text,
&raw const input_text,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

til

bitmap_ptrs.as_ptr().cast_mut(),
bitmaps.len(),
)
Expand Down Expand Up @@ -626,15 +627,11 @@ impl MtmdInputChunks {
let chunk_ptr =
unsafe { llama_cpp_sys_2::mtmd_input_chunks_get(self.chunks.as_ptr(), index) };

if chunk_ptr.is_null() {
None
} else {
// Note: We don't own this chunk, it's owned by the chunks collection
Some(MtmdInputChunk {
chunk: NonNull::new(chunk_ptr.cast_mut()).unwrap(),
owned: false,
})
}
// Note: We don't own this chunk, it's owned by the chunks collection
NonNull::new(chunk_ptr.cast_mut()).map(|ptr| MtmdInputChunk {
chunk: ptr,
owned: false,
})
}

/// Get total number of tokens across all chunks.
Expand Down Expand Up @@ -701,7 +698,7 @@ impl MtmdInputChunks {
seq_id,
n_batch,
logits_last,
&mut new_n_past,
&raw mut new_n_past,
)
};

Expand Down Expand Up @@ -753,7 +750,10 @@ impl MtmdInputChunk {

let mut n_tokens = 0usize;
let tokens_ptr = unsafe {
llama_cpp_sys_2::mtmd_input_chunk_get_tokens_text(self.chunk.as_ptr(), &mut n_tokens)
llama_cpp_sys_2::mtmd_input_chunk_get_tokens_text(
self.chunk.as_ptr(),
&raw mut n_tokens,
)
};

if tokens_ptr.is_null() || n_tokens == 0 {
Expand Down
Loading