8 changes: 8 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -6,6 +6,7 @@ members = [
"examples/embeddings",
"examples/simple",
"examples/reranker",
"examples/mtmd",
]

[workspace.dependencies]
21 changes: 21 additions & 0 deletions examples/mtmd/Cargo.toml
@@ -0,0 +1,21 @@
[package]
name = "mtmd"
version = "0.1.86"
edition = "2021"

[dependencies]
llama-cpp-2 = { path = "../../llama-cpp-2", version = "0.1.86", features = ["mtmd"] }
clap = { workspace = true, features = ["derive"] }

[features]
cuda = ["llama-cpp-2/cuda"]
metal = ["llama-cpp-2/metal"]
native = ["llama-cpp-2/native"]
vulkan = ["llama-cpp-2/vulkan"]

[lints]
workspace = true

[[example]]
name = "mtmd"
path = "src/mtmd.rs"
27 changes: 27 additions & 0 deletions examples/mtmd/README.md
@@ -0,0 +1,27 @@
# Rust mtmd-cli implementation

Partial port of the `mtmd-cli.cpp` example from the llama.cpp repository.

## Usage

### Command Line Interface

To run the mtmd example, you first need to download the model GGUF file and the multimodal projection (mmproj) file. For Gemma 3, for example, you can use:

```sh
wget https://huggingface.co/unsloth/gemma-3-4b-it-GGUF/resolve/main/gemma-3-4b-it-Q4_K_M.gguf \
https://huggingface.co/unsloth/gemma-3-4b-it-GGUF/resolve/main/mmproj-F16.gguf
```

Then, to run the example on the CPU, provide an image file `my_image.jpg` and run:

```sh
cargo run --release --example mtmd -- \
--model ./gemma-3-4b-it-Q4_K_M.gguf \
--mmproj ./mmproj-F16.gguf \
--image my_image.jpg \
--prompt "What is in the picture?" \
--no-gpu \
--no-mmproj-offload \
--marker "<start_of_image>"
```
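
The example's `Cargo.toml` also exposes GPU backend features (`cuda`, `metal`, `native`, `vulkan`). As a rough sketch (the exact invocation depends on your setup, and you may need to select the package explicitly with `-p mtmd`), building with one of these features and dropping `--no-gpu` and `--no-mmproj-offload` should offload the model and the projection to the GPU:

```sh
cargo run --release --example mtmd --features cuda -- \
    --model ./gemma-3-4b-it-Q4_K_M.gguf \
    --mmproj ./mmproj-F16.gguf \
    --image my_image.jpg \
    --prompt "What is in the picture?"
```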
307 changes: 307 additions & 0 deletions examples/mtmd/src/mtmd.rs
@@ -0,0 +1,307 @@
//! Based on the mtmd cli example from llama.cpp.

use std::ffi::CString;
use std::io::{self, Write};
use std::num::NonZeroU32;
use std::path::Path;

use clap::Parser;

use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::mtmd::{
MtmdBitmap, MtmdBitmapError, MtmdContext, MtmdContextParams, MtmdInputText,
};

use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::model::{LlamaChatMessage, LlamaChatTemplate, LlamaModel, Special};
use llama_cpp_2::sampling::LlamaSampler;

/// Command line parameters for the MTMD CLI application
#[derive(clap::Parser, Debug)]
#[command(name = "mtmd-cli")]
#[command(about = "Experimental CLI for multimodal llama.cpp")]
pub struct MtmdCliParams {
/// Path to the model file
#[arg(short = 'm', long = "model", value_name = "PATH")]
pub model_path: String,
/// Path to the multimodal projection file
#[arg(long = "mmproj", value_name = "PATH")]
pub mmproj_path: String,
/// Path to image file(s)
#[arg(long = "image", value_name = "PATH")]
pub images: Vec<String>,
/// Path to audio file(s)
#[arg(long = "audio", value_name = "PATH")]
pub audio: Vec<String>,
/// Text prompt to use as input to the model. May include media markers; if none are present, they will be appended automatically.
#[arg(short = 'p', long = "prompt", value_name = "TEXT")]
pub prompt: String,
/// Number of tokens to predict (-1 for unlimited)
#[arg(
short = 'n',
long = "n-predict",
value_name = "N",
default_value = "-1"
)]
pub n_predict: i32,
/// Number of threads
#[arg(short = 't', long = "threads", value_name = "N", default_value = "4")]
pub n_threads: i32,
/// Maximum number of tokens in context
#[arg(long = "n-tokens", value_name = "N", default_value = "4096")]
pub n_tokens: NonZeroU32,
/// Chat template to use. If not provided, the model's default template is used.
#[arg(long = "chat-template", value_name = "TEMPLATE")]
pub chat_template: Option<String>,
/// Disable GPU acceleration
#[arg(long = "no-gpu")]
pub no_gpu: bool,
/// Disable GPU offload for multimodal projection
#[arg(long = "no-mmproj-offload")]
pub no_mmproj_offload: bool,
/// Media marker. If not provided, the default marker will be used.
#[arg(long = "marker", value_name = "TEXT")]
pub media_marker: Option<String>,
}

/// State of the MTMD CLI application.
#[allow(missing_debug_implementations)]
pub struct MtmdCliContext {
/// The MTMD context for multimodal processing.
pub mtmd_ctx: MtmdContext,
/// The batch used for processing tokens.
pub batch: LlamaBatch,
/// The list of loaded bitmaps (images/audio).
pub bitmaps: Vec<MtmdBitmap>,
/// The number of past tokens processed.
pub n_past: i32,
/// The chat template used for formatting messages.
pub chat_template: LlamaChatTemplate,
/// The current chat messages history.
pub chat: Vec<LlamaChatMessage>,
}

impl MtmdCliContext {
/// Creates a new MTMD CLI context
///
/// # Errors
pub fn new(
params: &MtmdCliParams,
model: &LlamaModel,
) -> Result<Self, Box<dyn std::error::Error>> {
// Initialize MTMD context
let mtmd_params = MtmdContextParams {
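// Only offload the multimodal projector to the GPU when GPU use is enabled
// and mmproj offload has not been disabled.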
use_gpu: !params.no_gpu && !params.no_mmproj_offload,
print_timings: true,
n_threads: params.n_threads,
media_marker: CString::new(
params
.media_marker
.as_ref()
.unwrap_or(&llama_cpp_2::mtmd::mtmd_default_marker().to_string())
.clone(),
)?,
};

let mtmd_ctx = MtmdContext::init_from_file(&params.mmproj_path, model, &mtmd_params)?;

let chat_template = model
.chat_template(params.chat_template.as_deref())
.map_err(|e| format!("Failed to get chat template: {e}"))?;

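// Allocate a token batch sized to the context length, for a single sequence.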
let batch = LlamaBatch::new(params.n_tokens.get() as usize, 1);

Ok(Self {
mtmd_ctx,
batch,
chat: Vec::new(),
bitmaps: Vec::new(),
n_past: 0,
chat_template,
})
}

/// Loads media (image or audio) from the specified file path
/// # Errors
pub fn load_media(&mut self, path: &str) -> Result<(), MtmdBitmapError> {
let bitmap = MtmdBitmap::from_file(&self.mtmd_ctx, path)?;
self.bitmaps.push(bitmap);
Ok(())
}

/// Evaluates a chat message, tokenizing and processing it through the model
/// # Errors
pub fn eval_message(
&mut self,
model: &LlamaModel,
context: &mut LlamaContext,
msg: LlamaChatMessage,
add_bos: bool,
) -> Result<(), Box<dyn std::error::Error>> {
self.chat.push(msg);

// Format the message using chat template (simplified)
let formatted_prompt = model.apply_chat_template(&self.chat_template, &self.chat, true)?;

let input_text = MtmdInputText {
text: formatted_prompt,
add_special: add_bos,
parse_special: true,
};

let bitmap_refs: Vec<&MtmdBitmap> = self.bitmaps.iter().collect();

if bitmap_refs.is_empty() {
println!("No bitmaps provided, only tokenizing text");
} else {
println!("Tokenizing with {} bitmaps", bitmap_refs.len());
}

// Tokenize the input
let chunks = self.mtmd_ctx.tokenize(input_text, &bitmap_refs)?;

println!("Tokenization complete, {} chunks created", chunks.len());

// Clear bitmaps after tokenization
self.bitmaps.clear();

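// Evaluate all chunks; the trailing arguments are assumed to mirror llama.cpp's
// mtmd helper: start position 0, sequence id 0, batch size 1, and logits
// requested for the last chunk.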
self.n_past = chunks.eval_chunks(&self.mtmd_ctx, context, 0, 0, 1, true)?;
Ok(())
}

/// Generates a response by sampling tokens from the model
/// # Errors
pub fn generate_response(
&mut self,
model: &LlamaModel,
context: &mut LlamaContext,
sampler: &mut LlamaSampler,
n_predict: i32,
) -> Result<(), Box<dyn std::error::Error>> {
let mut generated_tokens = Vec::new();
let max_predict = if n_predict < 0 { i32::MAX } else { n_predict };

for _i in 0..max_predict {
// Sample next token
let token = sampler.sample(context, 0);
generated_tokens.push(token);
sampler.accept(token);

// Check for end of generation
if model.is_eog_token(token) {
println!();
break;
}

// Print token
let piece = model.token_to_str(token, Special::Tokenize)?;
print!("{piece}");
io::stdout().flush()?;

// Prepare next batch
self.batch.clear();
self.batch.add(token, self.n_past, &[0], true)?;
self.n_past += 1;

// Decode
context.decode(&mut self.batch)?;
}

Ok(())
}
}

fn run_single_turn(
ctx: &mut MtmdCliContext,
model: &LlamaModel,
context: &mut LlamaContext,
sampler: &mut LlamaSampler,
params: &MtmdCliParams,
) -> Result<(), Box<dyn std::error::Error>> {
// Add media marker if not present
let mut prompt = params.prompt.clone();
let default_marker = llama_cpp_2::mtmd::mtmd_default_marker().to_string();
let media_marker = params.media_marker.as_ref().unwrap_or(&default_marker);
if !prompt.contains(media_marker) {
prompt.push_str(media_marker);
}

// Load media files
for image_path in &params.images {
println!("Loading image: {image_path}");
ctx.load_media(image_path)?;
}
for audio_path in &params.audio {
ctx.load_media(audio_path)?;
}

// Create user message
let msg = LlamaChatMessage::new("user".to_string(), prompt)?;

println!("Evaluating message: {msg:?}");

// Evaluate the message (prefill)
ctx.eval_message(model, context, msg, true)?;

// Generate response (decode)
ctx.generate_response(model, context, sampler, params.n_predict)?;

Ok(())
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
let params = MtmdCliParams::parse();

// Validate required parameters
if !Path::new(&params.model_path).exists() {
eprintln!("Error: Model file not found: {}", params.model_path);
return Err("Model file not found".into());
}

if !Path::new(&params.mmproj_path).exists() {
eprintln!(
"Error: Multimodal projection file not found: {}",
params.mmproj_path
);
return Err("Multimodal projection file not found".into());
}

println!("Loading model: {}", params.model_path);

// Initialize backend
let backend = LlamaBackend::init()?;

// Setup model parameters
let mut model_params = LlamaModelParams::default();
if !params.no_gpu {
model_params = model_params.with_n_gpu_layers(1_000_000); // Use all layers on GPU
}

// Load model
let model = LlamaModel::load_from_file(&backend, &params.model_path, &model_params)?;

// Create context
let context_params = LlamaContextParams::default()
.with_n_threads(params.n_threads)
.with_n_batch(1)
.with_n_ctx(Some(params.n_tokens));
let mut context = model.new_context(&backend, context_params)?;

// Create sampler
let mut sampler = LlamaSampler::chain_simple([LlamaSampler::greedy()]);

println!("Model loaded successfully");
println!("Loading mtmd projection: {}", params.mmproj_path);

// Create the MTMD context
let mut ctx = MtmdCliContext::new(&params, &model)?;

run_single_turn(&mut ctx, &model, &mut context, &mut sampler, &params)?;

println!("\n");

Ok(())
}