#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::{Parser, ValueEnum};
use std::io::Write;
use tokenizers::Tokenizer;

use candle::quantized::gguf_file;
use candle::Tensor;
use candle_transformers::generation::{LogitsProcessor, Sampling};

use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_gemma3::ModelWeights;

const DEFAULT_PROMPT: &str = "Write a function to calculate fibonacci num";

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
    #[value(name = "gemma3-4b-it")]
    Gemma3_4bIt,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// GGUF file to load, typically a .gguf file generated by quantization
    #[arg(long)]
    model: Option<String>,

    /// The initial prompt. Use 'interactive' to enter multiple prompts interactively,
    /// and 'chat' for an interactive mode where the history of previous prompts and
    /// generated tokens is preserved.
    #[arg(long)]
    prompt: Option<String>,

    /// The length of the sample to generate (in tokens).
    #[arg(short = 'n', long, default_value_t = 1000)]
    sample_len: usize,

    /// The tokenizer config in json format.
    #[arg(long)]
    tokenizer: Option<String>,

    /// The temperature used to generate samples, use 0 for greedy sampling.
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Only sample among the top K samples.
    #[arg(long)]
    top_k: Option<usize>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// Process prompt elements separately.
    #[arg(long)]
    split_prompt: bool,

    /// Run on CPU rather than GPU even if a GPU is available.
    #[arg(long)]
    cpu: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    /// The model size to use.
    #[arg(long, default_value = "gemma3-4b-it")]
    which: Which,
}

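// Resolve the tokenizer and the GGUF weights, downloading them from the
// Hugging Face Hub when no local path is provided on the command line.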
impl Args {
    fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
        let tokenizer_path = match &self.tokenizer {
            Some(config) => std::path::PathBuf::from(config),
            None => {
                let api = hf_hub::api::sync::Api::new()?;
                let repo = "google/gemma-3-4b-it";
                println!("DEBUG: Downloading tokenizer from {}", repo);
                let api = api.model(repo.to_string());
                api.get("tokenizer.json")?
            }
        };
        println!("DEBUG: Loading tokenizer from {:?}", tokenizer_path);
        let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;

        Ok(tokenizer)
    }

    fn model(&self) -> anyhow::Result<std::path::PathBuf> {
        let model_path = match &self.model {
            Some(config) => std::path::PathBuf::from(config),
            None => {
                let (repo, filename) = match self.which {
                    Which::Gemma3_4bIt => (
                        "google/gemma-3-4b-it-qat-q4_0-gguf",
                        "gemma-3-4b-it-q4_0.gguf",
                    ),
                };
                let api = hf_hub::api::sync::Api::new()?;
                api.repo(hf_hub::Repo::with_revision(
                    repo.to_string(),
                    hf_hub::RepoType::Model,
                    "main".to_string(),
                ))
                .get(filename)?
            }
        };
        Ok(model_path)
    }
}

fn format_size(size_in_bytes: usize) -> String {
    if size_in_bytes < 1_000 {
        format!("{}B", size_in_bytes)
    } else if size_in_bytes < 1_000_000 {
        format!("{:.2}KB", size_in_bytes as f64 / 1e3)
    } else if size_in_bytes < 1_000_000_000 {
        format!("{:.2}MB", size_in_bytes as f64 / 1e6)
    } else {
        format!("{:.2}GB", size_in_bytes as f64 / 1e9)
    }
}

#[derive(Debug)]
enum Prompt {
    Interactive,
    Chat,
    One(String),
}

fn main() -> anyhow::Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature, args.repeat_penalty, args.repeat_last_n
    );

    let model_path = args.model()?;
    let mut file = std::fs::File::open(&model_path)?;
    let start = std::time::Instant::now();
    let device = candle_examples::device(args.cpu)?;

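    // Parse the GGUF container first to report the tensor count and total
    // quantized size, then build the model weights on the selected device.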
    let mut model = {
        let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(&model_path))?;
        let mut total_size_in_bytes = 0;
        for (_, tensor) in model.tensor_infos.iter() {
            let elem_count = tensor.shape.elem_count();
            total_size_in_bytes +=
                elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
        }
        println!(
            "loaded {:?} tensors ({}) in {:.2}s",
            model.tensor_infos.len(),
            &format_size(total_size_in_bytes),
            start.elapsed().as_secs_f32(),
        );
        ModelWeights::from_gguf(model, &mut file, &device)?
    };
    println!("model built");

    let tokenizer = args.tokenizer()?;

    let mut tos = TokenOutputStream::new(tokenizer);
    println!(
        "DEBUG: Tokenizer vocabulary size: {}",
        tos.tokenizer().get_vocab(true).len()
    );

    let prompt = match args.prompt.as_deref() {
        Some("chat") => Prompt::Chat,
        Some("interactive") => Prompt::Interactive,
        Some(s) => Prompt::One(s.to_string()),
        None => Prompt::One(DEFAULT_PROMPT.to_string()),
    };

    let mut pre_prompt_tokens = vec![];
    for _ in 0.. {
        let prompt_str = match &prompt {
            Prompt::One(prompt) => prompt.clone(),
            Prompt::Interactive | Prompt::Chat => {
                print!("> ");
                std::io::stdout().flush()?;
                let mut prompt = String::new();
                std::io::stdin().read_line(&mut prompt)?;
                if prompt.ends_with('\n') {
                    prompt.pop();
                    if prompt.ends_with('\r') {
                        prompt.pop();
                    }
                }
                // Format the prompt with the Gemma 3 chat/instruction template.
                format!("<start_of_turn>user\n{prompt}\n<end_of_turn>\n<start_of_turn>model\n")
            }
        };
        print!("{}", &prompt_str);

        let tokens = tos
            .tokenizer()
            .encode(prompt_str, true)
            .map_err(anyhow::Error::msg)?;
        let prompt_tokens = [&pre_prompt_tokens, tokens.get_ids()].concat();

        let to_sample = args.sample_len.saturating_sub(1);
        let max_seq_len = 8192; // context length budget used by this example
        let prompt_tokens = if prompt_tokens.len() + to_sample > max_seq_len - 10 {
            let to_remove = prompt_tokens.len() + to_sample + 10 - max_seq_len;
            prompt_tokens[prompt_tokens.len().saturating_sub(to_remove)..].to_vec()
        } else {
            prompt_tokens
        };
        let mut all_tokens = vec![];
        let mut logits_processor = {
            let temperature = args.temperature;
            let sampling = if temperature <= 0. {
                Sampling::ArgMax
            } else {
                match (args.top_k, args.top_p) {
                    (None, None) => Sampling::All { temperature },
                    (Some(k), None) => Sampling::TopK { k, temperature },
                    (None, Some(p)) => Sampling::TopP { p, temperature },
                    (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
                }
            };
            LogitsProcessor::from_sampling(args.seed, sampling)
        };

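        // Process the prompt either in a single forward pass, or one token at a
        // time when --split-prompt is set.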
        let start_prompt_processing = std::time::Instant::now();
        let mut next_token = if !args.split_prompt {
            let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
            let logits = model.forward(&input, 0)?;
            let logits = logits.squeeze(0)?;
            logits_processor.sample(&logits)?
        } else {
            let mut next_token = 0;
            for (pos, token) in prompt_tokens.iter().enumerate() {
                let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
                let logits = model.forward(&input, pos)?;
                let logits = logits.squeeze(0)?;
                next_token = logits_processor.sample(&logits)?
            }
            next_token
        };
        let prompt_dt = start_prompt_processing.elapsed();
        all_tokens.push(next_token);
        if let Some(t) = tos.next_token(next_token)? {
            print!("{t}");
            std::io::stdout().flush()?;
        }

        // Gemma 3 ends each model turn with <end_of_turn>, so use it as the stop token.
        let eos_token = *tos
            .tokenizer()
            .get_vocab(true)
            .get("<end_of_turn>")
            .unwrap();

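        // Autoregressive generation: feed the last sampled token back into the model,
        // applying the repeat penalty over the last `repeat_last_n` tokens before sampling.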
        let start_post_prompt = std::time::Instant::now();
        let mut sampled = 0;
        for index in 0..to_sample {
            let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
            let logits = model.forward(&input, prompt_tokens.len() + index)?;
            let logits = logits.squeeze(0)?;
            let logits = if args.repeat_penalty == 1. {
                logits
            } else {
                let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    args.repeat_penalty,
                    &all_tokens[start_at..],
                )?
            };
            next_token = logits_processor.sample(&logits)?;
            all_tokens.push(next_token);
            if let Some(t) = tos.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
            sampled += 1;
            if next_token == eos_token {
                break;
            };
        }
        if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        let dt = start_post_prompt.elapsed();
        println!(
            "\n\n{:4} prompt tokens processed: {:.2} token/s",
            prompt_tokens.len(),
            prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
        );
        println!(
            "{sampled:4} tokens generated: {:.2} token/s",
            sampled as f64 / dt.as_secs_f64(),
        );

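        // In chat mode, carry the whole exchange (prompt plus generated tokens)
        // forward so the next turn sees the full conversation history.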
        match prompt {
            Prompt::One(_) => break,
            Prompt::Interactive => {}
            Prompt::Chat => {
                pre_prompt_tokens = [prompt_tokens.as_slice(), all_tokens.as_slice()].concat()
            }
        }
    }

    Ok(())
}