diff --git a/Cargo.lock b/Cargo.lock
index bbc08ed0..5a44f42b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -445,6 +445,16 @@ version = "1.9.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07"
 
+[[package]]
+name = "embeddings"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "clap",
+ "hf-hub",
+ "llama-cpp-2",
+]
+
 [[package]]
 name = "encode_unicode"
 version = "0.3.6"
diff --git a/Cargo.toml b/Cargo.toml
index ec70b251..4cc4588f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -3,7 +3,7 @@ resolver = "2"
 members = [
     "llama-cpp-sys-2",
     "llama-cpp-2",
-    "simple",
+    "simple", "embeddings",
 ]
 
 [workspace.dependencies]
diff --git a/embeddings/Cargo.toml b/embeddings/Cargo.toml
new file mode 100644
index 00000000..86f4e5aa
--- /dev/null
+++ b/embeddings/Cargo.toml
@@ -0,0 +1,15 @@
+[package]
+name = "embeddings"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+llama-cpp-2 = { path = "../llama-cpp-2", version = "0.1.34" }
+hf-hub = { workspace = true }
+clap = { workspace = true , features = ["derive"] }
+anyhow = { workspace = true }
+
+[lints]
+workspace = true
diff --git a/embeddings/src/main.rs b/embeddings/src/main.rs
new file mode 100644
index 00000000..de9d33ce
--- /dev/null
+++ b/embeddings/src/main.rs
@@ -0,0 +1,215 @@
+//! This is a translation of embedding.cpp in llama.cpp using llama-cpp-2.
+#![allow(
+clippy::cast_possible_wrap,
+clippy::cast_possible_truncation,
+clippy::cast_precision_loss,
+clippy::cast_sign_loss
+)]
+
+use std::io::Write;
+use std::path::PathBuf;
+use std::str::FromStr;
+use std::time::Duration;
+
+use anyhow::{bail, Context, Result};
+use clap::Parser;
+use hf_hub::api::sync::ApiBuilder;
+
+use llama_cpp_2::context::LlamaContext;
+use llama_cpp_2::context::params::LlamaContextParams;
+use llama_cpp_2::ggml_time_us;
+use llama_cpp_2::llama_backend::LlamaBackend;
+use llama_cpp_2::llama_batch::LlamaBatch;
+use llama_cpp_2::model::AddBos;
+use llama_cpp_2::model::LlamaModel;
+use llama_cpp_2::model::params::LlamaModelParams;
+
+#[derive(clap::Parser, Debug, Clone)]
+struct Args {
+    /// The path to the model
+    #[command(subcommand)]
+    model: Model,
+    /// The prompt
+    #[clap(default_value = "Hello my name is")]
+    prompt: String,
+    /// Whether to normalise the produced embeddings
+    #[clap(short)]
+    normalise: bool,
+    /// Disable offloading layers to the gpu
+    #[cfg(feature = "cublas")]
+    #[clap(long)]
+    disable_gpu: bool,
+}
+
+
+#[derive(clap::Subcommand, Debug, Clone)]
+enum Model {
+    /// Use an already downloaded model
+    Local {
+        /// The path to the model. e.g. `/home/marcus/.cache/huggingface/hub/models--TheBloke--Llama-2-7B-Chat-GGUF/blobs/08a5566d61d7cb6b420c3e4387a39e0078e1f2fe5f055f3a03887385304d4bfa`
+        path: PathBuf,
+    },
+    /// Download a model from huggingface (or use a cached version)
+    #[clap(name = "hf-model")]
+    HuggingFace {
+        /// the repo containing the model. e.g. `BAAI/bge-small-en-v1.5`
+        repo: String,
+        /// the model name. e.g. `BAAI-bge-small-v1.5.Q4_K_M.gguf`
+        model: String,
+    },
+}
+
+impl Model {
+    /// Convert the model to a path - may download from huggingface
+    fn get_or_load(self) -> Result<PathBuf> {
+        match self {
+            Model::Local { path } => Ok(path),
+            Model::HuggingFace { model, repo } => ApiBuilder::new()
+                .with_progress(true)
+                .build()
+                .with_context(|| "unable to create huggingface api")?
+                .model(repo)
+                .get(&model)
+                .with_context(|| "unable to download model"),
+        }
+    }
+}
+
+fn main() -> Result<()> {
+    let Args {
+        model,
+        prompt,
+        normalise,
+        #[cfg(feature = "cublas")]
+        disable_gpu,
+    } = Args::parse();
+
+    // init LLM
+    let backend = LlamaBackend::init()?;
+
+    // offload all layers to the gpu
+    let model_params = {
+        #[cfg(feature = "cublas")]
+        if !disable_gpu {
+            LlamaModelParams::default().with_n_gpu_layers(1000)
+        } else {
+            LlamaModelParams::default()
+        }
+        #[cfg(not(feature = "cublas"))]
+        LlamaModelParams::default()
+    };
+
+    let model_path = model
+        .get_or_load()
+        .with_context(|| "failed to get model from args")?;
+
+    let model = LlamaModel::load_from_file(&backend, model_path, &model_params)
+        .with_context(|| "unable to load model")?;
+
+    // initialize the context
+    let ctx_params = LlamaContextParams::default()
+        .with_n_threads_batch(std::thread::available_parallelism()?.get() as u32)
+        .with_embeddings(true);
+
+    let mut ctx = model
+        .new_context(&backend, ctx_params)
+        .with_context(|| "unable to create the llama_context")?;
+
+    // Split the prompt to display the batching functionality
+    let prompt_lines = prompt.lines();
+
+    // tokenize the prompt
+    let tokens_lines_list = prompt_lines.map(|line| model.str_to_token(&line, AddBos::Always))
+        .collect::<Result<Vec<_>, _>>()
+        .with_context(|| format!("failed to tokenize {prompt}"))?;
+
+    let n_ctx = ctx.n_ctx() as usize;
+    let n_ctx_train = model.n_ctx_train();
+
+    eprintln!("n_ctx = {n_ctx}, n_ctx_train = {n_ctx_train}");
+
+    if tokens_lines_list.iter().any(|tok| n_ctx < tok.len()) {
+        bail!("One of the provided prompts exceeds the size of the context window");
+    }
+
+    // print the prompt token-by-token
+    eprintln!();
+
+    for (i, token_line) in tokens_lines_list.iter().enumerate() {
+        eprintln!("Prompt {i}");
+        for token in token_line {
+            eprintln!(" {} --> {}", token, model.token_to_str(*token)?);
+        }
+        eprintln!()
+    }
+
+    std::io::stderr().flush()?;
+
+    // create a llama_batch with the size of the context
+    // we use this object to submit token data for decoding
+    let mut batch = LlamaBatch::new(n_ctx, 1);
+
+    let mut max_seq_id_batch = 0;
+    let mut output = Vec::with_capacity(tokens_lines_list.len());
+
+    let t_main_start = ggml_time_us();
+
+    for tokens in &tokens_lines_list {
+        // Flush the batch if the next prompt would exceed our batch size
+        if (batch.n_tokens() as usize + tokens.len()) > n_ctx {
+            batch_decode(&mut ctx, &mut batch, max_seq_id_batch, &mut output, normalise)?;
+            max_seq_id_batch = 0;
+        }
+
+        batch.add_sequence(&tokens, max_seq_id_batch, false)?;
+        max_seq_id_batch += 1;
+    }
+    // Handle final batch
+    batch_decode(&mut ctx, &mut batch, max_seq_id_batch, &mut output, normalise)?;
+
+    let t_main_end = ggml_time_us();
+
+    for (i, embeddings) in output.iter().enumerate() {
+        eprintln!("Embeddings {i}: {embeddings:?}");
+        eprintln!();
+    }
+
+    let duration = Duration::from_micros((t_main_end - t_main_start) as u64);
+    let total_tokens: usize = tokens_lines_list.iter().map(|v| v.len()).sum();
+    eprintln!(
+        "Created embeddings for {} tokens in {:.2} s, speed {:.2} t/s\n",
+        total_tokens,
+        duration.as_secs_f32(),
+        total_tokens as f32 / duration.as_secs_f32()
+    );
+
+    println!("{}", ctx.timings());
+
+    Ok(())
+}
+
+fn batch_decode(ctx: &mut LlamaContext, batch: &mut LlamaBatch, s_batch: i32, output: &mut Vec<Vec<f32>>, normalise: bool) -> Result<()> {
+    ctx.clear_kv_cache();
+    ctx.decode(batch).with_context(|| "llama_decode() failed")?;
+
+    for i in 0..s_batch {
+        let embedding = ctx.embeddings_seq_ith(i).with_context(|| "Failed to get embeddings")?;
+        let output_embeddings = if normalise {
+            normalize(embedding)
+        } else {
+            embedding.to_vec()
+        };
+
+        output.push(output_embeddings);
+    }
+
+    batch.clear();
+
+    Ok(())
+}
+
+fn normalize(input: &[f32]) -> Vec<f32> {
+    let magnitude = input.iter().fold(0.0, |acc, &val| val.mul_add(val, acc)).sqrt();
+
+    input.iter().map(|&val| val / magnitude).collect()
+}
diff --git a/llama-cpp-2/src/context.rs b/llama-cpp-2/src/context.rs
index 2cccb13a..85066718 100644
--- a/llama-cpp-2/src/context.rs
+++ b/llama-cpp-2/src/context.rs
@@ -2,15 +2,15 @@
 use std::fmt::{Debug, Formatter};
 use std::num::NonZeroI32;
+use std::ptr::NonNull;
+use std::slice;
 
 use crate::llama_batch::LlamaBatch;
 use crate::model::LlamaModel;
 use crate::timing::LlamaTimings;
 use crate::token::data::LlamaTokenData;
 use crate::token::LlamaToken;
-use crate::DecodeError;
-use std::ptr::NonNull;
-use std::slice;
+use crate::{DecodeError, EmbeddingsError};
 
 pub mod kv_cache;
 pub mod params;
@@ -24,6 +24,7 @@ pub struct LlamaContext<'a> {
     /// a reference to the contexts model.
     pub model: &'a LlamaModel,
     initialized_logits: Vec<i32>,
+    embeddings_enabled: bool,
 }
 
 impl Debug for LlamaContext<'_> {
@@ -38,11 +39,13 @@ impl<'model> LlamaContext<'model> {
     pub(crate) fn new(
         llama_model: &'model LlamaModel,
         llama_context: NonNull<llama_cpp_sys_2::llama_context>,
+        embeddings_enabled: bool,
     ) -> Self {
         Self {
            context: llama_context,
            model: llama_model,
            initialized_logits: Vec::new(),
+           embeddings_enabled,
        }
     }
 
@@ -80,6 +83,63 @@ impl<'model> LlamaContext<'model> {
         }
     }
 
+    /// Get the embeddings for the `i`th sequence in the current context.
+    ///
+    /// # Returns
+    ///
+    /// A slice containing the embeddings for the last decoded batch.
+    /// The size corresponds to the `n_embd` parameter of the context's model.
+    ///
+    /// # Errors
+    ///
+    /// - When the current context was constructed without enabling embeddings.
+    /// - If the current model had a pooling type of [`llama_cpp_sys_2::LLAMA_POOLING_TYPE_NONE`]
+    /// - If the given sequence index exceeds the max sequence id.
+    pub fn embeddings_seq_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
+        if !self.embeddings_enabled {
+            return Err(EmbeddingsError::NotEnabled);
+        }
+
+        unsafe {
+            let embedding = llama_cpp_sys_2::llama_get_embeddings_seq(self.context.as_ptr(), i);
+
+            // Technically also possible whenever `i >= max(batch.n_seq)`, but can't check that here.
+            if embedding.is_null() {
+                Err(EmbeddingsError::NonePoolType)
+            } else {
+                Ok(std::slice::from_raw_parts(embedding, self.model.n_embd() as usize))
+            }
+        }
+    }
+
+    /// Get the embeddings for the `i`th token in the current context.
+    ///
+    /// # Returns
+    ///
+    /// A slice containing the embeddings for the last decoded batch of the given token.
+    /// The size corresponds to the `n_embd` parameter of the context's model.
+    ///
+    /// # Errors
+    ///
+    /// - When the current context was constructed without enabling embeddings.
+    /// - When the given token didn't have logits enabled when it was passed.
+    /// - If the given token index exceeds the max token id.
+    pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
+        if !self.embeddings_enabled {
+            return Err(EmbeddingsError::NotEnabled);
+        }
+
+        unsafe {
+            let embedding = llama_cpp_sys_2::llama_get_embeddings_ith(self.context.as_ptr(), i);
+            // Technically also possible whenever `i >= batch.n_tokens`, but no good way of checking `n_tokens` here.
+            if embedding.is_null() {
+                Err(EmbeddingsError::LogitsNotEnabled)
+            } else {
+                Ok(std::slice::from_raw_parts(embedding, self.model.n_embd() as usize))
+            }
+        }
+    }
+
     /// Get the logits for the ith token in the context.
     ///
     /// # Panics
diff --git a/llama-cpp-2/src/context/params.rs b/llama-cpp-2/src/context/params.rs
index fba41c2d..edf3b709 100644
--- a/llama-cpp-2/src/context/params.rs
+++ b/llama-cpp-2/src/context/params.rs
@@ -1,8 +1,9 @@
 //! A safe wrapper around `llama_context_params`.
-use llama_cpp_sys_2;
 use std::fmt::Debug;
 use std::num::NonZeroU32;
 
+use llama_cpp_sys_2;
+
 /// A rusty wrapper around `rope_scaling_type`.
 #[repr(i8)]
 #[derive(Copy, Clone, Debug, PartialEq, Eq)]
@@ -267,6 +268,19 @@ impl LlamaContextParams {
         self.context_params.n_threads
     }
 
+    /// Get the number of threads allocated for batches.
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
+    /// assert_eq!(params.n_threads_batch(), 4);
+    /// ```
+    #[must_use]
+    pub fn n_threads_batch(&self) -> u32 {
+        self.context_params.n_threads_batch
+    }
+
     /// Set the number of threads.
     ///
     /// # Examples
@@ -282,6 +296,51 @@ impl LlamaContextParams {
         self.context_params.n_threads = n_threads;
         self
     }
+
+    /// Set the number of threads allocated for batches.
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default()
+    ///     .with_n_threads_batch(8);
+    /// assert_eq!(params.n_threads_batch(), 8);
+    /// ```
+    #[must_use]
+    pub fn with_n_threads_batch(mut self, n_threads: u32) -> Self {
+        self.context_params.n_threads_batch = n_threads;
+        self
+    }
+
+    /// Check whether embeddings are enabled
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
+    /// assert!(!params.embeddings());
+    /// ```
+    #[must_use]
+    pub fn embeddings(&self) -> bool {
+        self.context_params.embeddings
+    }
+
+    /// Enable the use of embeddings
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default()
+    ///     .with_embeddings(true);
+    /// assert!(params.embeddings());
+    /// ```
+    #[must_use]
+    pub fn with_embeddings(mut self, embedding: bool) -> Self {
+        self.context_params.embeddings = embedding;
+        self
+    }
 }
 
 /// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
diff --git a/llama-cpp-2/src/lib.rs b/llama-cpp-2/src/lib.rs
index ab9efb64..b5c1b7d9 100644
--- a/llama-cpp-2/src/lib.rs
+++ b/llama-cpp-2/src/lib.rs
@@ -52,6 +52,8 @@ pub enum LLamaCppError {
     /// There was an error adding a token to a batch.
     #[error["{0}"]]
     BatchAddError(#[from] BatchAddError),
+    #[error(transparent)]
+    EmbeddingError(#[from] EmbeddingsError),
 }
 
 /// Failed to Load context
@@ -76,6 +78,17 @@ pub enum DecodeError {
     Unknown(c_int),
 }
 
+/// When embedding related functions fail
+#[derive(Debug, Eq, PartialEq, thiserror::Error)]
+pub enum EmbeddingsError {
+    #[error("Embeddings weren't enabled in the context options")]
+    NotEnabled,
+    #[error("Logits were not enabled for the given token")]
+    LogitsNotEnabled,
+    #[error("Can't use sequence embeddings with a model supporting only LLAMA_POOLING_TYPE_NONE")]
+    NonePoolType,
+}
+
 /// Decode a error from llama.cpp into a [`DecodeError`].
 impl From<NonZeroI32> for DecodeError {
     fn from(value: NonZeroI32) -> Self {
diff --git a/llama-cpp-2/src/llama_backend.rs b/llama-cpp-2/src/llama_backend.rs
index 59b4a39a..cfd45e10 100644
--- a/llama-cpp-2/src/llama_backend.rs
+++ b/llama-cpp-2/src/llama_backend.rs
@@ -3,6 +3,7 @@
 use crate::LLamaCppError;
 use std::sync::atomic::AtomicBool;
 use std::sync::atomic::Ordering::SeqCst;
+use llama_cpp_sys_2::ggml_log_level;
 
 /// Representation of an initialized llama backend
 /// This is required as a parameter for most llama functions as the backend must be initialized
@@ -68,6 +69,19 @@ impl LlamaBackend {
         }
         Ok(LlamaBackend {})
     }
+
+    /// Change the output of llama.cpp's logging to be voided instead of pushed to `stderr`.
+    pub fn void_logs(&mut self) {
+        unsafe extern "C" fn void_log(
+            _level: ggml_log_level,
+            _text: *const ::std::os::raw::c_char,
+            _user_data: *mut ::std::os::raw::c_void,
+        ) {}
+
+        unsafe {
+            llama_cpp_sys_2::llama_log_set(Some(void_log), std::ptr::null_mut())
+        }
+    }
 }
 
 /// A rusty wrapper around `numa_strategy`.
diff --git a/llama-cpp-2/src/llama_batch.rs b/llama-cpp-2/src/llama_batch.rs
index 0748dd85..1f1ecc54 100644
--- a/llama-cpp-2/src/llama_batch.rs
+++ b/llama-cpp-2/src/llama_batch.rs
@@ -6,11 +6,11 @@ use llama_cpp_sys_2::{llama_batch, llama_batch_free, llama_batch_init, llama_pos
 /// A safe wrapper around `llama_batch`.
 #[derive(Debug)]
 pub struct LlamaBatch {
-    /// The number of tokens the batch was allocated with. they are safe to write to - but not necessarily read from as they are not necessarily initilized
+    /// The number of tokens the batch was allocated with. they are safe to write to - but not necessarily read from as they are not necessarily initialized
     allocated: usize,
-    /// The logits that are initilized. Used by [`LlamaContext`] to ensure that only initilized logits are accessed.
+    /// The logits that are initialized. Used by [`LlamaContext`] to ensure that only initialized logits are accessed.
     pub(crate) initialized_logits: Vec<i32>,
-    /// The llama_cpp batch. always initilize by `llama_cpp_sys_2::llama_batch_init(allocated, <embd>, <n_seq_max>)`
+    /// The llama_cpp batch. always initialize by `llama_cpp_sys_2::llama_batch_init(allocated, <embd>, <n_seq_max>)`
     pub(crate) llama_batch: llama_batch,
 }
 
@@ -31,7 +31,7 @@ impl LlamaBatch {
     }
 
     /// add a token to the batch for sequences [`seq_ids`] at position [pos]. If [logits] is true, the
-    /// token will be initilized and can be read from after the next decode.
+    /// token will be initialized and can be read from after the next decode.
     ///
     /// # Panics
     ///
@@ -90,7 +90,33 @@ impl LlamaBatch {
         Ok(())
     }
 
-    /// Create a new `LlamaBatch` that cab contain up to `n_tokens` tokens.
+
+    /// Add a sequence of tokens to the batch for the given sequence id. If [logits_all] is true, the
+    /// tokens will be initialized and can be read from after the next decode.
+    ///
+    /// Either way the last token in the sequence will have its logits set to `true`.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if there is insufficient space in the buffer
+    pub fn add_sequence(&mut self, tokens: &[LlamaToken],
+                        seq_id: i32,
+                        logits_all: bool) -> Result<(), BatchAddError> {
+        let n_tokens_0 = self.llama_batch.n_tokens;
+        let n_tokens = tokens.len();
+
+        if self.allocated < n_tokens_0 as usize + n_tokens {
+            return Err(BatchAddError::InsufficientSpace(self.allocated));
+        }
+
+        for (i, token) in tokens.iter().enumerate() {
+            self.add(*token, i as llama_pos, &[seq_id], logits_all || i == n_tokens - 1)?;
+        }
+
+        Ok(())
+    }
+
+    /// Create a new `LlamaBatch` that can contain up to `n_tokens` tokens.
     ///
     /// # Arguments
     ///
diff --git a/llama-cpp-2/src/model.rs b/llama-cpp-2/src/model.rs
index 6d2242dc..e9709e95 100644
--- a/llama-cpp-2/src/model.rs
+++ b/llama-cpp-2/src/model.rs
@@ -1,15 +1,16 @@
 //! A safe wrapper around `llama_model`.
-use crate::context::params::LlamaContextParams;
+use std::ffi::CString;
+use std::os::raw::c_int;
+use std::path::Path;
+use std::ptr::NonNull;
+
+use crate::{LlamaContextLoadError, LlamaModelLoadError, StringToTokenError, TokenToStringError};
 use crate::context::LlamaContext;
+use crate::context::params::LlamaContextParams;
 use crate::llama_backend::LlamaBackend;
 use crate::model::params::LlamaModelParams;
 use crate::token::LlamaToken;
 use crate::token_type::LlamaTokenType;
-use crate::{LlamaContextLoadError, LlamaModelLoadError, StringToTokenError, TokenToStringError};
-use std::ffi::CString;
-use std::os::raw::c_int;
-use std::path::Path;
-use std::ptr::NonNull;
 
 pub mod params;
 
@@ -29,6 +30,7 @@ pub enum AddBos {
     /// Do not add the beginning of stream token to the start of the string.
     Never,
 }
+
 unsafe impl Send for LlamaModel {}
 
 unsafe impl Sync for LlamaModel {}
@@ -38,12 +40,12 @@ impl LlamaModel {
     ///
     /// # Panics
     ///
-    /// If the number of tokens the model was trained on does not fit into an `u16`. This should be impossible on most
+    /// If the number of tokens the model was trained on does not fit into an `u32`. This should be impossible on most
     /// platforms due to llama.cpp returning a `c_int` (i32 on most platforms) which is almost certainly positive.
     #[must_use]
-    pub fn n_ctx_train(&self) -> u16 {
+    pub fn n_ctx_train(&self) -> u32 {
         let n_ctx_train = unsafe { llama_cpp_sys_2::llama_n_ctx_train(self.model.as_ptr()) };
-        u16::try_from(n_ctx_train).expect("n_ctx_train fits into an u16")
+        u32::try_from(n_ctx_train).expect("n_ctx_train fits into an u32")
     }
 
     /// Get all tokens in the model.
@@ -54,6 +56,7 @@ impl LlamaModel {
             .map(LlamaToken::new)
             .map(|llama_token| (llama_token, self.token_to_str(llama_token)))
     }
+
     /// Get the beginning of stream token.
     #[must_use]
     pub fn token_bos(&self) -> LlamaToken {
@@ -276,7 +279,7 @@ impl LlamaModel {
     /// # Errors
     ///
     /// See [`LlamaModelLoadError`] for more information.
-    #[tracing::instrument(skip_all)]
+    #[tracing::instrument(skip_all, fields(params))]
     pub fn load_from_file(
         _: &LlamaBackend,
         path: impl AsRef<Path>,
@@ -290,13 +293,12 @@ impl LlamaModel {
         let cstr = CString::new(path)?;
 
         let llama_model = unsafe {
-            println!("{:?}", params.params);
             llama_cpp_sys_2::llama_load_model_from_file(cstr.as_ptr(), params.params)
         };
 
         let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
 
-        println!("Loaded {path:?}");
+        tracing::debug!(?path, "Loaded model");
 
         Ok(LlamaModel { model })
     }
@@ -318,7 +320,7 @@ impl LlamaModel {
         };
 
         let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;
 
-        Ok(LlamaContext::new(self, context))
+        Ok(LlamaContext::new(self, context, params.embeddings()))
     }
 }
diff --git a/llama-cpp-2/src/token_type.rs b/llama-cpp-2/src/token_type.rs
index 44c2dbd3..35c441a9 100644
--- a/llama-cpp-2/src/token_type.rs
+++ b/llama-cpp-2/src/token_type.rs
@@ -28,15 +28,15 @@ pub enum LlamaTokenType {
 ///
 /// ```
 /// # use std::convert::TryFrom;
-/// # use std::ffi::c_uint;
+/// # use std::ffi::c_int;
 /// # use std::num::TryFromIntError;
 /// # use std::result::Result;
 /// # use llama_cpp_2::token_type::{LlamaTokenTypeFromIntError, LlamaTokenType};
 /// # fn main() -> Result<(), LlamaTokenTypeFromIntError> {
-/// let llama_token_type = LlamaTokenType::try_from(0 as c_uint)?;
+/// let llama_token_type = LlamaTokenType::try_from(0 as llama_cpp_sys_2::llama_token_type)?;
 /// assert_eq!(llama_token_type, LlamaTokenType::Undefined);
 ///
-/// let bad_llama_token_type = LlamaTokenType::try_from(100 as c_uint);
+/// let bad_llama_token_type = LlamaTokenType::try_from(100 as llama_cpp_sys_2::llama_token_type);
 /// assert_eq!(Err(LlamaTokenTypeFromIntError::UnknownValue(100)), bad_llama_token_type);
 /// # Ok(())
 /// # }
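
Reviewer note: the sketch below is not part of the change set; it only shows how the pieces added in this diff are intended to fit together end to end. Embeddings are switched on via `LlamaContextParams::with_embeddings`, a prompt is pushed as one sequence with `LlamaBatch::add_sequence`, and the pooled vector is read back with `LlamaContext::embeddings_seq_ith`. The `embed` helper, the model path, and the single-sequence setup are placeholders for illustration, and error handling is collapsed to `?` via anyhow.

    use anyhow::Result;
    use llama_cpp_2::context::params::LlamaContextParams;
    use llama_cpp_2::llama_backend::LlamaBackend;
    use llama_cpp_2::llama_batch::LlamaBatch;
    use llama_cpp_2::model::params::LlamaModelParams;
    use llama_cpp_2::model::{AddBos, LlamaModel};

    fn embed(prompt: &str) -> Result<Vec<f32>> {
        let backend = LlamaBackend::init()?;
        let model = LlamaModel::load_from_file(
            &backend,
            "path/to/model.gguf", // placeholder path
            &LlamaModelParams::default(),
        )?;

        // Embeddings must be enabled up front, otherwise `embeddings_seq_ith`
        // returns `EmbeddingsError::NotEnabled`.
        let ctx_params = LlamaContextParams::default().with_embeddings(true);
        let mut ctx = model.new_context(&backend, ctx_params)?;

        // One prompt as sequence id 0; `add_sequence` always enables logits on the
        // last token, so decoding the batch yields an embedding for the sequence.
        let tokens = model.str_to_token(prompt, AddBos::Always)?;
        let mut batch = LlamaBatch::new(ctx.n_ctx() as usize, 1);
        batch.add_sequence(&tokens, 0, false)?;
        ctx.decode(&mut batch)?;

        Ok(ctx.embeddings_seq_ith(0)?.to_vec())
    }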