Whisper custom vocab #1417
base: main
Conversation
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Adds vocabulary-driven runtime biasing (BiasTrie) to the Whisper runtime: threads the vocabulary from the listener config through the interface into WhisperBuilder, installs a logits filter backed by the BiasTrie, adds a Criterion benchmark, updates Cargo dependencies, and performs a small call-site refactor in the streaming builder.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant App as App/Plugin
    participant FSM as Listener FSM
    participant IF as owhisper-interface
    participant Svc as Listen Service
    participant W as Whisper Model
    participant Bias as BiasTrie
    App->>FSM: Start listening (config)
    FSM->>FSM: derive languages, vocabulary (fallback "Hyprnote")
    FSM->>IF: ListenParams{ languages, vocabulary, ... }
    IF->>Svc: setup_listen_client(params)
    Svc->>W: Whisper::builder().languages(...).vocabulary(vocab).build()
    App-->>W: audio frames
    W->>W: decode step -> logits ready
    W->>Bias: apply_bias_to_logits(tokens, n, logits)
    Bias-->>W: adjusted logits
    W-->>App: transcription segments
```
```mermaid
sequenceDiagram
    autonumber
    participant Decoder
    participant Callback as logits_filter_callback
    participant BiasTrie
    Decoder->>Callback: on_logits(tokens*, n, logits*)
    Callback->>BiasTrie: apply_bias_to_logits(tokens*, n, logits*)
    BiasTrie-->>Callback: modify logits (adds ln(bias)*2)
    Callback-->>Decoder: return (may negate initial token)
    Decoder->>Decoder: continue decoding
```
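For intuition on the progressive prefix bias referenced throughout the reviews (a quadratic ramp from partial to full phrase match, `10.0 + 90.0 * progress^2`), here is a minimal std-only sketch; `prefix_biases` is illustrative, not part of the PR:

```rust
// Sketch: per-prefix bias for one tokenized phrase of `token_len` tokens,
// using the quadratic ramp quoted in the review (10.0 + 90.0 * progress^2).
fn prefix_biases(token_len: usize) -> Vec<f32> {
    (1..=token_len)
        .map(|i| {
            let progress = i as f32 / token_len as f32;
            10.0 + 90.0 * progress.powi(2)
        })
        .collect()
}

fn main() {
    let biases = prefix_biases(4);
    // The full-phrase prefix gets the maximum boost of 100.0.
    println!("{:?}", biases);
}
```

Short prefixes get a modest boost (15.625 for the first of four tokens) so partial matches are only gently encouraged, while the full phrase is strongly favored.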
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Actionable comments posted: 10
🧹 Nitpick comments (5)
crates/whisper-local/Cargo.toml (1)

45-45: New dependency: trie-rs. Looks appropriate for prefix biasing. Keep an eye on crate maintenance and MSRV; pin minor versions if API churn appears.
plugins/listener/src/fsm.rs (1)

616-636: Avoid unwrap() panic risks in plugins/listener/src.
- Numerous unwrap() calls detected (e.g. in fsm.rs, ext.rs, events.rs, manager.rs, lib.rs); replace with ? or explicit error handling to prevent runtime panics.
- Rather than conn.api_key.unwrap_or_default(), propagate a missing-key error to fail fast on misconfiguration.
- Optionally log large vocabulary.len() and any truncation when building parameters to aid debugging.

crates/whisper-local/src/bias.rs (1)

18-27: Potential bias override on duplicate prefixes across phrases. When multiple phrases share a prefix, the last write wins in TrieBuilder::push. Consider pre-aggregating prefixes (e.g., max or sum) before build to keep the bias deterministic.

Example approach:

```rust
use std::collections::HashMap;

let mut agg: HashMap<Vec<WhisperTokenId>, f32> = HashMap::new();
for sequence in sequences {
    for i in 1..=sequence.len() {
        let progress = i as f32 / sequence.len() as f32;
        let bias = 1.0 + 10.0 * progress.powi(2);
        let key = sequence[..i].to_vec();
        agg.entry(key).and_modify(|v| *v = v.max(bias)).or_insert(bias);
    }
}

let mut builder = TrieBuilder::new();
for (k, v) in agg {
    builder.push(&k, v);
}
```

crates/whisper-local/src/model.rs (2)
89-95: Avoid cloning BiasTrie on every transcription. BiasTrie is likely large. Store it behind Arc in Whisper and clone the Arc cheaply when building the callback context.

Proposed changes (requires adding `use std::sync::Arc;` at the top):

```diff
- bias_trie: BiasTrie,
+ bias_trie: Arc<BiasTrie>,
```

And in build():

```diff
- bias_trie,
+ bias_trie: Arc::new(bias_trie),
```

And in the set_logit_filter context construction:

```diff
- bias_trie: bias_trie.clone(),
+ bias_trie: Arc::clone(bias_trie),
```

This also requires bias_trie: Arc in LogitsFilterContext, and BiasTrie::apply_bias_to_logits to accept &BiasTrie.
339-374: Avoid println! in tests. Printing in tests adds noise; prefer tracing at debug level or remove it entirely.

```diff
- println!(
-     "{}",
-     segments
-         .iter()
-         .map(|s| s.text.clone())
-         .collect::<Vec<String>>()
-         .join(" ")
- );
+ // Intentionally no output; assertions above validate behavior.
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
⛔ Files ignored due to path filters (1)
- Cargo.lock is excluded by !**/*.lock

📒 Files selected for processing (8)
- crates/transcribe-whisper-local/src/service/streaming.rs (1 hunks)
- crates/whisper-local/Cargo.toml (2 hunks)
- crates/whisper-local/benches/whisper_transcription.rs (1 hunks)
- crates/whisper-local/src/bias.rs (1 hunks)
- crates/whisper-local/src/lib.rs (1 hunks)
- crates/whisper-local/src/model.rs (9 hunks)
- owhisper/owhisper-interface/src/lib.rs (2 hunks)
- plugins/listener/src/fsm.rs (5 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.{js,ts,tsx,rs} (⚙️ CodeRabbit configuration file):
1. No error handling.
2. No unused imports, variables, or functions.
3. For comments, keep it minimal. It should be about "Why", not "What".
Files:
crates/whisper-local/src/lib.rs
crates/transcribe-whisper-local/src/service/streaming.rs
crates/whisper-local/benches/whisper_transcription.rs
plugins/listener/src/fsm.rs
owhisper/owhisper-interface/src/lib.rs
crates/whisper-local/src/bias.rs
crates/whisper-local/src/model.rs
🧬 Code graph analysis (5)

crates/transcribe-whisper-local/src/service/streaming.rs (2)
- crates/whisper-local/src/model.rs (4): languages (31-34), vocabulary (36-39), builder (98-100), model_path (26-29)
- plugins/local-stt/js/bindings.gen.ts (1): Language (60-60)

crates/whisper-local/benches/whisper_transcription.rs (1)
- crates/whisper-local/src/model.rs (3): model_path (26-29), builder (98-100), segments (183-186)

plugins/listener/src/fsm.rs (3)
- crates/whisper-local/src/model.rs (2): languages (31-34), vocabulary (36-39)
- crates/db-user/src/sessions_ops.rs (1): onboarding_session_id (8-10)
- plugins/db/src/commands/sessions.rs (1): onboarding_session_id (22-35)

owhisper/owhisper-interface/src/lib.rs (1)
- crates/whisper-local/src/model.rs (1): vocabulary (36-39)

crates/whisper-local/src/model.rs (1)
- crates/whisper-local/src/bias.rs (2): custom_vocab (11-14), new (10-29)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: cubic · AI code reviewer
- GitHub Check: ci (macos, macos-latest)
- GitHub Check: ci (windows, windows-latest)
🔇 Additional comments (5)
crates/whisper-local/Cargo.toml (3)

20-20: Criterion bench setup looks fine. Workspace-scoped dev-dep is appropriate for benches.

25-27: Bench target correctly configured. harness = false is expected for Criterion.

36-36: Add safeguards for the git-pinned whisper-rs dependency.
- Update and commit Cargo.lock so the exact git rev is recorded.
- Document in README.md or CHANGELOG.md why this Codeberg fork and rev (3e6d3da) are required.
- Verify that the raw-api and tracing_backend features are enabled and that the correct git source is used:

```shell
cargo tree -e features -i whisper-rs
cargo metadata --format-version=1 | jq -r '.packages[] | select(.name=="whisper-rs") | .source'
```

- Consider moving this override into [patch.crates-io] or pinning to a semantically versioned tag once upstream support is available.

owhisper/owhisper-interface/src/lib.rs (1)

135-141: Default initialization is correct. Empty vocabulary by default avoids accidental biasing.

plugins/listener/src/fsm.rs (1)

251-258: Threading vocabulary into setup is correct. Good propagation into the client builder.
5 issues found across 9 files.
Actionable comments posted: 0
♻️ Duplicate comments (2)
crates/whisper-local/src/model.rs (2)

62-68: Remove the hardcoded "Hyprnote" fallback; default to no bias. Biasing by default can degrade general accuracy and surprise consumers. Use an empty default instead.

Apply this diff:

```diff
- let custom_vocab = self.vocabulary.unwrap_or(vec!["Hyprnote".to_string()]);
+ let custom_vocab = self.vocabulary.unwrap_or_default();
```

Note: keep installing the logits filter regardless, since it also enforces token_beg suppression.

253-261: Minor API/safety tweaks in set_logit_filter.
- Pass token_beg by value (it is Copy); avoid reference + deref.
- If not adopting Arc, at least avoid heavy clones in the hot path.

Apply this minimal diff if not switching to Arc:

```diff
-    unsafe fn set_logit_filter(
-        params: &mut FullParams,
-        token_beg: &WhisperTokenId,
-        bias_trie: &BiasTrie,
-    ) -> LogitFilterGuard {
+    unsafe fn set_logit_filter(
+        params: &mut FullParams,
+        token_beg: WhisperTokenId,
+        bias_trie: &BiasTrie,
+    ) -> LogitFilterGuard {
         let context = Box::new(Context {
-            token_beg: *token_beg,
+            token_beg,
             bias_trie: bias_trie.clone(),
         });
```
🧹 Nitpick comments (4)
crates/whisper-local/src/model.rs (4)

74-74: Consider Arc to avoid cloning the trie per decode. Cloning a potentially large trie on every transcription is wasteful; share it.

Apply these diffs:

```diff
- bias_trie: BiasTrie,
+ bias_trie: std::sync::Arc<BiasTrie>,
```

```diff
  let bias_trie = {
      let custom_vocab = self.vocabulary.unwrap_or_default();
      let custom_vocab_refs: Vec<&str> = custom_vocab.iter().map(|s| s.as_str()).collect();
-     BiasTrie::new(&ctx, &custom_vocab_refs)?
+     std::sync::Arc::new(BiasTrie::new(&ctx, &custom_vocab_refs)?)
  };
```

```diff
- let _guard = unsafe { Self::set_logit_filter(&mut params, &token_beg, &self.bias_trie) };
+ let _guard = unsafe { Self::set_logit_filter(&mut params, token_beg, &self.bias_trie) };
```

```diff
- unsafe fn set_logit_filter(
-     params: &mut FullParams,
-     token_beg: &WhisperTokenId,
-     bias_trie: &BiasTrie,
- ) -> LogitFilterGuard {
+ unsafe fn set_logit_filter(
+     params: &mut FullParams,
+     token_beg: WhisperTokenId,
+     bias_trie: &std::sync::Arc<BiasTrie>,
+ ) -> LogitFilterGuard {
      let context = Box::new(Context {
-         token_beg: *token_beg,
+         token_beg,
          bias_trie: bias_trie.clone(),
      });
```

```diff
 struct Context {
     token_beg: WhisperTokenId,
-    bias_trie: BiasTrie,
+    bias_trie: std::sync::Arc<BiasTrie>,
 }
```

293-296: Rename Context to LogitsFilterContext for clarity. Avoid overly generic names in a module with multiple contexts.

Apply this diff (and update references):

```diff
-struct Context {
+struct LogitsFilterContext {
     token_beg: WhisperTokenId,
     bias_trie: BiasTrie,
 }
```

And update uses in set_logit_filter and Drop:

```diff
- let context = Box::new(Context {
+ let context = Box::new(LogitsFilterContext {
```

```diff
- let context = &*(user_data as *const Context);
+ let context = &*(user_data as *const LogitsFilterContext);
```

```diff
- let _ = Box::from_raw(self.context_ptr as *mut Context);
+ let _ = Box::from_raw(self.context_ptr as *mut LogitsFilterContext);
```

361-369: Test vocabulary is fine, but keep tests hermetic. This integration test depends on a local model file; consider marking it #[ignore] or gating it behind a feature to avoid CI failures. I can provide a patch to add #[ignore] or a feature flag if desired.

381-388: Avoid println! in tests. Use tracing at debug level or remove the print to keep test output clean.

Apply this diff to drop the print:

```diff
- println!(
-     "{}",
-     segments
-         .iter()
-         .map(|s| s.text.clone())
-         .collect::<Vec<String>>()
-         .join(" ")
- );
+ // Debug output intentionally omitted to keep tests quiet.
```
📜 Review details

📒 Files selected for processing (1)
- crates/whisper-local/src/model.rs (10 hunks)
🔇 Additional comments (7)
crates/whisper-local/src/model.rs (7)

11-11: Import is correct and used. No action needed.

22-23: Builder gains optional vocabulary — OK. Keeps default opt-in semantics at the type level.

36-39: Builder API for vocabulary — OK. Fluent API remains consistent with existing builders.

112-143: Params block — OK. Reasonable defaults and logging; no issues spotted.

145-146: Good: guard keeps callback context alive and fixes leak. Lifetime cleanly spans the decode call.

302-309: Drop frees the boxed context — OK. Verify callback lifetime: we rely on whisper-rs not invoking the callback after full() returns. Please double-check that the pinned whisper-rs guarantees the callback/user_data are only used during that full() invocation. If not, we should null the callback post-decode.

263-282: Harden callback: guard against null tokens and zero-length input.

```diff
-        if logits.is_null() || user_data.is_null() {
+        if logits.is_null() || user_data.is_null() || tokens.is_null() || n_tokens <= 0 {
             return;
         }
```

Verify that in the pinned whisper-rs revision the C callback contract guarantees non-null tokens and n_tokens > 0; otherwise these guards prevent UB.
Actionable comments posted: 2
♻️ Duplicate comments (1)
crates/whisper-local/src/bias.rs (1)

58-92: Unsafe OOB risk on logits; prefer longest-suffix first and stop after the first hit.
- Potential UB: pointer arithmetic on logits without knowing n_vocab.
- The current loop biases multiple suffix lengths, double-counting. Prefer longest context first, and break once any continuation is found.

Apply:

```diff
-    pub unsafe fn apply_bias_to_logits(
+    /// Safety: `tokens` must point to `n_tokens` valid whisper_token_data.
+    /// `logits` must point to `n_vocab` contiguous f32s (the vocab size).
+    pub unsafe fn apply_bias_to_logits(
         &self,
         tokens: *const whisper_rs::whisper_rs_sys::whisper_token_data,
         n_tokens: std::os::raw::c_int,
-        logits: *mut f32,
+        logits: *mut f32,
+        n_vocab: std::os::raw::c_int,
     ) {
-        if tokens.is_null() || n_tokens <= 0 {
+        if tokens.is_null() || logits.is_null() || n_tokens <= 0 || n_vocab <= 0 {
             return;
         }
@@
-        for suffix_len in 1..=std::cmp::min(10, current_tokens.len()) {
+        let n_vocab = n_vocab as usize;
+        let logits_slice = std::slice::from_raw_parts_mut(logits, n_vocab);
+        for suffix_len in (1..=std::cmp::min(10, current_tokens.len())).rev() {
             let suffix = &current_tokens[current_tokens.len() - suffix_len..];
-
-            for (full_sequence, bias_value_ref) in self.trie.predictive_search(suffix) {
-                let bias_value = *bias_value_ref;
-                let full_sequence: Vec<WhisperTokenId> = full_sequence;
-
+            let mut found = false;
+            for (full_sequence, bias_value_ref) in self.trie.predictive_search(suffix) {
+                let log_bias = *bias_value_ref; // if new() stores ln(bias); else: (*bias_value_ref).max(1e-6).ln()
                 if full_sequence.len() > suffix.len() {
                     let next_token = full_sequence[suffix.len()];
-                    let current_logit = *logits.offset(next_token as isize);
-
-                    let boost = bias_value.ln() * 2.0;
-                    let new_logit = current_logit + boost;
-
-                    *logits.offset(next_token as isize) = new_logit;
+                    if let Some(slot) = logits_slice.get_mut(next_token as usize) {
+                        // cap boost magnitude for numerical stability
+                        let boost = log_bias.clamp(-6.0, 6.0);
+                        *slot += boost;
+                        found = true;
+                    }
                 }
             }
-        }
+            if found { break; }
+        }
     }
```
🧹 Nitpick comments (5)
crates/whisper-local/src/bias.rs (5)

1-7: Derive Debug for BiasTrie (useful for tracing) and prep for dedup imports.

```diff
 use trie_rs::map::{Trie, TrieBuilder};
 use whisper_rs::{WhisperContext, WhisperTokenId};
+use std::collections::{HashMap, HashSet};

-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub struct BiasTrie {
     trie: Trie<WhisperTokenId, f32>,
 }
```

20-24: Store biases in the log domain to avoid a per-step ln and to clamp magnitude.

```diff
- let prefix_bias = 10.0 + 90.0 * progress.powi(2);
+ // store log-bias; clamp to a sane range for stability
+ let prefix_bias = (10.0 + 90.0 * progress.powi(2)).ln().clamp(-6.0, 6.0);
```

Then in apply_bias_to_logits, use the stored value directly (see the separate comment).

32-56: The variant set can contain duplicates; dedup, and consider a trailing-space variant.

```diff
- let mut variants = Vec::new();
+ let mut variants = Vec::new();
  variants.push(ctx.tokenize(word, 99)?);
  variants.push(ctx.tokenize(&format!(" {}", word), 99)?);
+ variants.push(ctx.tokenize(&format!("{word} "), 99)?);
@@
  variants.push(ctx.tokenize(&format!("'{}", word), 99)?);
  variants.push(ctx.tokenize(&format!("\"{}", word), 99)?);
- Ok(variants)
+ // remove duplicate tokenizations
+ variants.sort();
+ variants.dedup();
+ Ok(variants)
```

79-80: Remove the redundant type rebind.

```diff
- let bias_value = *bias_value_ref;
- let full_sequence: Vec<WhisperTokenId> = full_sequence;
+ let bias_value = *bias_value_ref;
```

74-91: Consider a per-token max instead of additive stacking across multiple matches. If you intentionally keep multi-suffix accumulation, guard against over-bias by aggregating the maximum boost per next_token and applying it once. I can provide a small HashMap-based accumulator if you want it.
📜 Review details

📒 Files selected for processing (1)
- crates/whisper-local/src/bias.rs (1 hunks)
🔇 Additional comments (1)
crates/whisper-local/src/bias.rs (1)
38-40: tokenize() max_tokens semantics — 0 disables the limit. tokenize(&self, text: &str, max_tokens: usize) returns up to max_tokens tokens; passing 0 disables the cap. Since you're feeding single words (≤2 subword tokens), 99 can't truncate here. To make it explicit, you may replace 99 with 0 in all ctx.tokenize(..., 99) calls.
```rust
for word in custom_vocab {
    let variants = Self::generate_tokenization_variants(ctx, word)?;

    for tokens in variants {
        for i in 1..=tokens.len() {
            let progress = i as f32 / tokens.len() as f32;

            let prefix_bias = 10.0 + 90.0 * progress.powi(2);

            let prefix = &tokens[..i];
            builder.push(prefix, prefix_bias);
        }
    }
}
```
🛠️ Refactor suggestion
Deduplicate overlapping prefixes across variants; skip empty tokenizations.
Multiple variants can yield identical prefixes; pushing each inflates/overwrites bias unpredictably. Aggregate by prefix and keep the max bias.
```diff
-    for word in custom_vocab {
-        let variants = Self::generate_tokenization_variants(ctx, word)?;
-        for tokens in variants {
-            for i in 1..=tokens.len() {
-                let progress = i as f32 / tokens.len() as f32;
-
-                let prefix_bias = 10.0 + 90.0 * progress.powi(2);
-
-                let prefix = &tokens[..i];
-                builder.push(prefix, prefix_bias);
-            }
-        }
-    }
+    let mut acc: HashMap<Vec<WhisperTokenId>, f32> = HashMap::new();
+    for word in custom_vocab {
+        let variants = Self::generate_tokenization_variants(ctx, word)?;
+        for tokens in variants {
+            if tokens.is_empty() { continue; }
+            for i in 1..=tokens.len() {
+                let progress = i as f32 / tokens.len() as f32;
+                let bias = 10.0 + 90.0 * progress.powi(2);
+                let key = tokens[..i].to_vec();
+                acc.entry(key).and_modify(|v| *v = v.max(bias)).or_insert(bias);
+            }
+        }
+    }
+    for (k, v) in acc {
+        builder.push(&k, v);
+    }
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```rust
// Deduplicate overlapping prefixes across all variants, keeping the max bias per prefix.
let mut acc: HashMap<Vec<WhisperTokenId>, f32> = HashMap::new();
for word in custom_vocab {
    let variants = Self::generate_tokenization_variants(ctx, word)?;
    for tokens in variants {
        if tokens.is_empty() { continue; }
        for i in 1..=tokens.len() {
            let progress = i as f32 / tokens.len() as f32;
            let bias = 10.0 + 90.0 * progress.powi(2);
            let key = tokens[..i].to_vec();
            acc.entry(key)
                .and_modify(|v| *v = v.max(bias))
                .or_insert(bias);
        }
    }
}
for (k, v) in acc {
    builder.push(&k, v);
}
```
🤖 Prompt for AI Agents
In crates/whisper-local/src/bias.rs around lines 13 to 26, the loop currently
pushes every prefix from every tokenization variant, which duplicates
overlapping prefixes and allows later variants to overwrite earlier bias values
unpredictably and also doesn't skip empty tokenizations; change the logic so you
skip any variant with zero tokens, then collect prefixes from all variants into
a temporary map (or hashmap) keyed by the prefix sequence and store the maximum
computed bias for that prefix (compare existing entry and keep the larger bias),
and after processing all variants push each unique prefix once to builder with
the aggregated max bias.
```rust
pub unsafe fn apply_bias_to_logits(
    &self,
    tokens: *const whisper_rs::whisper_rs_sys::whisper_token_data,
    n_tokens: std::os::raw::c_int,
    logits: *mut f32,
) {
    if tokens.is_null() || n_tokens <= 0 {
        return;
    }

    let current_tokens: Vec<WhisperTokenId> =
        std::slice::from_raw_parts(tokens, n_tokens as usize)
            .iter()
            .map(|t| t.id)
            .collect();

    for suffix_len in 1..=std::cmp::min(10, current_tokens.len()) {
        let suffix = &current_tokens[current_tokens.len() - suffix_len..];

        for (full_sequence, bias_value_ref) in self.trie.predictive_search(suffix) {
            let bias_value = *bias_value_ref;
            let full_sequence: Vec<WhisperTokenId> = full_sequence;

            if full_sequence.len() > suffix.len() {
                let next_token = full_sequence[suffix.len()];
                let current_logit = *logits.offset(next_token as isize);

                let boost = bias_value.ln() * 2.0;
                let new_logit = current_logit + boost;

                *logits.offset(next_token as isize) = new_logit;
            }
        }
    }
}
```
💡 Verification agent
🧩 Analysis chain
Update call sites to pass n_vocab and validate usage. Use this script to find and fix all call sites:

🏁 Script executed:

```shell
#!/bin/bash
# Find apply_bias_to_logits call sites and show context
rg -n -C3 --type=rust '\.apply_bias_to_logits\s*\('
```

Length of output: 518
Update all call sites to pass n_vocab. In crates/whisper-local/src/model.rs:281, change `.apply_bias_to_logits(tokens, n_tokens, logits)` to `.apply_bias_to_logits(tokens, n_tokens, n_vocab, logits)`, and ensure the method signature in bias.rs is updated accordingly.
🤖 Prompt for AI Agents
In crates/whisper-local/src/bias.rs around lines 58 to 92, update the unsafe fn
signature to accept n_vocab (std::os::raw::c_int) as a new parameter (e.g.,
tokens, n_tokens, n_vocab, logits), then use n_vocab to validate token indices
before indexing into the logits buffer: ensure next_token is within 0..n_vocab
(cast and compare as usize/int) and only read/write logits if in-bounds to avoid
OOB access; adjust pointer arithmetic to cast next_token safely to isize after
bounds checking; also update call sites (as noted in the review) to pass n_vocab
when invoking this method.
Related
Summary by cubic
Add custom vocabulary biasing to Whisper transcription to boost accuracy on domain terms and jargon. Exposes a vocabulary field end-to-end (client → API → model) and biases logits using a trie-based prefix matcher.
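As a toy illustration of why this kind of logit biasing changes decoding (independent of the PR's actual callback): adding a boost on the order of ln(bias) to one token's logit can flip the greedy argmax toward the vocabulary term.

```rust
// Toy illustration: a ln(bias) boost added to one logit flips the argmax.
fn argmax(logits: &[f32]) -> usize {
    logits
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
        .map(|(i, _)| i)
        .unwrap()
}

fn main() {
    let mut logits = vec![1.0f32, 2.0, 0.5];
    assert_eq!(argmax(&logits), 1);
    // Boost token 2 by ln(10) ≈ 2.30, as if it continued a vocabulary phrase.
    logits[2] += (10.0f32).ln();
    assert_eq!(argmax(&logits), 2); // 0.5 + 2.30 > 2.0
}
```

Because logits feed a softmax, an additive shift of ln(b) multiplies that token's probability by roughly b relative to the others, which is what makes rare domain terms competitive with common words.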
New Features
Dependencies