Conversation

@yujonglee yujonglee commented Aug 28, 2025

Related

Summary by cubic

Add custom vocabulary biasing to Whisper transcription to boost accuracy on domain terms and jargon. Exposes a vocabulary field end-to-end (client → API → model) and biases logits using a trie-based prefix matcher.

  • New Features

    • Add WhisperBuilder::vocabulary(Vec<String>) and integrate a trie-based logit bias (predictive prefix search) alongside token_beg suppression.
    • Forward vocabulary from API: ListenParams includes vocabulary: Vec<String> (defaults to empty).
    • Listener sends jargon from user config (general.jargons) as vocabulary; falls back to ["Hyprnote"] when missing.
    • Add Criterion benchmark comparing runs with and without vocabulary.
  • Dependencies

    • Add trie-rs for a fast LOUDS-based trie; pin whisper-rs to a specific git rev; add criterion for benches.

coderabbitai bot commented Aug 28, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Adds vocabulary-driven runtime biasing (BiasTrie) to the Whisper runtime, threads vocabulary from the listener config through the interface into WhisperBuilder, installs a logits filter using the BiasTrie, adds a Criterion benchmark, updates Cargo dependencies, and performs a small streaming builder call-site refactor.

Changes

Cohort / File(s) Summary
Biasing core & integration
crates/whisper-local/src/bias.rs, crates/whisper-local/src/model.rs, crates/whisper-local/src/lib.rs
New BiasTrie type that builds prefix biases from vocabulary and applies them to logits; WhisperBuilder gains vocabulary(Vec<String>); Whisper stores a BiasTrie and installs a logits-filter callback (set_logit_filter) using a boxed Context and LogitFilterGuard.
Benchmarks & deps
crates/whisper-local/Cargo.toml, crates/whisper-local/benches/whisper_transcription.rs
Adds criterion dev-dependency and bench target; pins whisper-rs to a git revision and adds trie-rs; adds a Criterion benchmark comparing transcription with and without vocabulary.
ListenParams interface
owhisper/owhisper-interface/src/lib.rs
Adds public vocabulary: Vec<String> to ListenParams with #[serde(default)] and updates Default to initialize it to an empty vec.
Listener wiring
plugins/listener/src/fsm.rs
Extracts vocabulary from config (fallback ["Hyprnote"]), threads it through resource setup, expands setup_listen_client signature to accept vocabulary, and forwards it into ListenParams when building the Listen client.
Streaming service minor refactor
crates/transcribe-whisper-local/src/service/streaming.rs
Precomputes languages Vec and clones vocabulary into locals before calling the Whisper builder; purely a call-site refactor with unchanged behavior.
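The biasing core described above can be sketched std-only. This is a hypothetical stand-in for the trie-rs-backed BiasTrie (a HashMap replaces the trie, i32 stands in for WhisperTokenId); the quadratic schedule mirrors the one in bias.rs, and, like the PR code, duplicate prefixes are last-write-wins:

```rust
use std::collections::HashMap;

type TokenId = i32;

/// Map every token prefix of each vocabulary phrase to a bias that grows
/// quadratically as more of the phrase is matched (10.0 + 90.0 * progress^2,
/// as in bias.rs). A HashMap stands in for the LOUDS trie.
fn build_prefix_biases(sequences: &[Vec<TokenId>]) -> HashMap<Vec<TokenId>, f32> {
    let mut map = HashMap::new();
    for seq in sequences {
        for i in 1..=seq.len() {
            let progress = i as f32 / seq.len() as f32;
            let bias = 10.0 + 90.0 * progress.powi(2);
            map.insert(seq[..i].to_vec(), bias); // last write wins, as in the PR
        }
    }
    map
}

fn main() {
    // Pretend "Hyprnote" tokenizes to [5, 17, 42].
    let biases = build_prefix_biases(&[vec![5, 17, 42]]);
    // Full match (progress = 1.0) gets the maximum bias of 100.0.
    assert_eq!(biases[&vec![5, 17, 42]], 100.0);
    // A one-token prefix (progress = 1/3) gets 10 + 90/9 = 20.0.
    assert!((biases[&vec![5]] - 20.0).abs() < 1e-3);
}
```

The quadratic ramp means a nearly-complete phrase is biased far harder than its first token, which keeps one-token prefixes from dominating decoding.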

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant App as App/Plugin
  participant FSM as Listener FSM
  participant IF as owhisper-interface
  participant Svc as Listen Service
  participant W as Whisper Model
  participant Bias as BiasTrie

  App->>FSM: Start listening (config)
  FSM->>FSM: derive languages, vocabulary (fallback "Hyprnote")
  FSM->>IF: ListenParams{ languages, vocabulary, ... }
  IF->>Svc: setup_listen_client(params)
  Svc->>W: Whisper::builder().languages(...).vocabulary(vocab).build()
  App-->>W: audio frames
  W->>W: decode step -> logits ready
  W->>Bias: apply_bias_to_logits(tokens, n, logits)
  Bias-->>W: adjusted logits
  W-->>App: transcription segments
sequenceDiagram
  autonumber
  participant Decoder
  participant Callback as logits_filter_callback
  participant BiasTrie

  Decoder->>Callback: on_logits(tokens*, n, logits*)
  Callback->>BiasTrie: apply_bias_to_logits(tokens*, n, logits*)
  BiasTrie-->>Callback: modify logits (adds ln(bias)*2)
  Callback-->>Decoder: return (may negate initial token)
  Decoder->>Decoder: continue decoding
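The callback step in the diagram above can be sketched std-only (a HashMap of phrase-token sequences stands in for the trie's predictive search; the ln(bias) * 2 boost matches the description, safe slices replace the raw logits pointer):

```rust
use std::collections::HashMap;

type TokenId = usize;

/// Sketch of the logits-filter step: for entries whose key extends the current
/// token suffix, boost the logit of the next expected token by ln(bias) * 2.0.
fn apply_bias(current: &[TokenId], entries: &HashMap<Vec<TokenId>, f32>, logits: &mut [f32]) {
    for (seq, bias) in entries {
        if seq.len() > current.len() && seq[..current.len()] == *current {
            let next = seq[current.len()];
            if let Some(slot) = logits.get_mut(next) {
                *slot += bias.ln() * 2.0;
            }
        }
    }
}

fn main() {
    let mut entries = HashMap::new();
    entries.insert(vec![5, 17, 42], 100.0_f32); // full phrase with max bias
    let mut logits = vec![0.0_f32; 50];
    apply_bias(&[5, 17], &entries, &mut logits);
    // Token 42 is the expected continuation: boosted by ln(100) * 2 ≈ 9.21.
    assert!((logits[42] - 100.0_f32.ln() * 2.0).abs() < 1e-5);
    assert_eq!(logits[0], 0.0); // unrelated tokens untouched
}
```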

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • Add language detect constraint #1212 — Overlaps the change that converts params.languages into a filtered Vec<hypr_whisper::Language> and calls .languages(...) on the Whisper builder; strongly related to the builder-call refactor.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 10

🧹 Nitpick comments (5)
crates/whisper-local/Cargo.toml (1)

45-45: New dependency: trie-rs.

Looks appropriate for prefix biasing. Keep an eye on crate maintenance and MSRV; pin minor versions if API churn appears.

plugins/listener/src/fsm.rs (1)

616-636: Avoid unwrap() panic risks in plugins/listener/src

  • Detected numerous unwrap() calls (e.g. in fsm.rs, ext.rs, events.rs, manager.rs, lib.rs); replace with ? or explicit error handling to prevent runtime panics.
  • Rather than conn.api_key.unwrap_or_default(), propagate a missing‐key error to fail fast on misconfiguration.
  • Optionally log large vocabulary.len() and any truncation when building parameters to aid debugging.
crates/whisper-local/src/bias.rs (1)

18-27: Potential bias override on duplicate prefixes across phrases.

When multiple phrases share a prefix, the last write wins in TrieBuilder::push. Consider pre-aggregating prefixes (e.g., max or sum) before build to keep the bias deterministic.

Example approach:

use std::collections::HashMap;
let mut agg: HashMap<Vec<WhisperTokenId>, f32> = HashMap::new();
for sequence in sequences {
    for i in 1..=sequence.len() {
        let progress = i as f32 / sequence.len() as f32;
        let bias = 1.0 + 10.0 * progress.powi(2);
        let key = sequence[..i].to_vec();
        agg.entry(key).and_modify(|v| *v = v.max(bias)).or_insert(bias);
    }
}
let mut builder = TrieBuilder::new();
for (k, v) in agg { builder.push(&k, v); }
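A runnable, std-only version of that aggregation idea (i32 stands in for WhisperTokenId, and the map is kept as the final structure rather than feeding a TrieBuilder, purely for illustration):

```rust
use std::collections::HashMap;

type TokenId = i32;

/// Pre-aggregate prefix biases so duplicates across phrases are deterministic:
/// keep the maximum bias ever computed for a given prefix.
fn aggregate(sequences: &[Vec<TokenId>]) -> HashMap<Vec<TokenId>, f32> {
    let mut agg: HashMap<Vec<TokenId>, f32> = HashMap::new();
    for seq in sequences {
        for i in 1..=seq.len() {
            let progress = i as f32 / seq.len() as f32;
            let bias = 1.0 + 10.0 * progress.powi(2); // schedule from the example above
            agg.entry(seq[..i].to_vec())
                .and_modify(|v| *v = v.max(bias))
                .or_insert(bias);
        }
    }
    agg
}

fn main() {
    // Two phrases share the prefix [7]: a 2-token and a 4-token phrase.
    let agg = aggregate(&[vec![7, 8], vec![7, 9, 10, 11]]);
    // [7] as half of a 2-token phrase (progress 0.5) scores 1 + 10 * 0.25 = 3.5;
    // as a quarter of a 4-token phrase (progress 0.25) it scores only 1.625.
    // Max aggregation keeps 3.5 regardless of iteration order.
    assert!((agg[&vec![7]] - 3.5).abs() < 1e-6);
}
```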
crates/whisper-local/src/model.rs (2)

89-95: Avoid cloning BiasTrie on every transcription

BiasTrie is likely large. Store it behind Arc in Whisper and clone the Arc cheaply when building the callback context.

Proposed changes (requires adding use std::sync::Arc; at top):

-    bias_trie: BiasTrie,
+    bias_trie: Arc<BiasTrie>,

And in build():

-            bias_trie,
+            bias_trie: Arc::new(bias_trie),

And in set_logit_filter context construction:

-            bias_trie: bias_trie.clone(),
+            bias_trie: Arc::clone(bias_trie),

This also requires bias_trie: Arc<BiasTrie> in LogitsFilterContext and BiasTrie::apply_bias_to_logits to accept &BiasTrie.


339-374: Avoid println! in tests

Printing in tests adds noise. Prefer tracing at debug level or remove entirely.

-        println!(
-            "{}",
-            segments
-                .iter()
-                .map(|s| s.text.clone())
-                .collect::<Vec<String>>()
-                .join(" ")
-        );
+        // Intentionally no output; assertions above validate behavior.
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between bd4bccb and 607a66d.

⛔ Files ignored due to path filters (1)
  • Cargo.lock is excluded by !**/*.lock
📒 Files selected for processing (8)
  • crates/transcribe-whisper-local/src/service/streaming.rs (1 hunks)
  • crates/whisper-local/Cargo.toml (2 hunks)
  • crates/whisper-local/benches/whisper_transcription.rs (1 hunks)
  • crates/whisper-local/src/bias.rs (1 hunks)
  • crates/whisper-local/src/lib.rs (1 hunks)
  • crates/whisper-local/src/model.rs (9 hunks)
  • owhisper/owhisper-interface/src/lib.rs (2 hunks)
  • plugins/listener/src/fsm.rs (5 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.{js,ts,tsx,rs}

⚙️ CodeRabbit configuration file

**/*.{js,ts,tsx,rs}: 1. No error handling.
2. No unused imports, variables, or functions.
3. For comments, keep it minimal. It should be about "Why", not "What".

Files:

  • crates/whisper-local/src/lib.rs
  • crates/transcribe-whisper-local/src/service/streaming.rs
  • crates/whisper-local/benches/whisper_transcription.rs
  • plugins/listener/src/fsm.rs
  • owhisper/owhisper-interface/src/lib.rs
  • crates/whisper-local/src/bias.rs
  • crates/whisper-local/src/model.rs
🧬 Code graph analysis (5)
crates/transcribe-whisper-local/src/service/streaming.rs (2)
crates/whisper-local/src/model.rs (4)
  • languages (31-34)
  • vocabulary (36-39)
  • builder (98-100)
  • model_path (26-29)
plugins/local-stt/js/bindings.gen.ts (1)
  • Language (60-60)
crates/whisper-local/benches/whisper_transcription.rs (1)
crates/whisper-local/src/model.rs (3)
  • model_path (26-29)
  • builder (98-100)
  • segments (183-186)
plugins/listener/src/fsm.rs (3)
crates/whisper-local/src/model.rs (2)
  • languages (31-34)
  • vocabulary (36-39)
crates/db-user/src/sessions_ops.rs (1)
  • onboarding_session_id (8-10)
plugins/db/src/commands/sessions.rs (1)
  • onboarding_session_id (22-35)
owhisper/owhisper-interface/src/lib.rs (1)
crates/whisper-local/src/model.rs (1)
  • vocabulary (36-39)
crates/whisper-local/src/model.rs (1)
crates/whisper-local/src/bias.rs (2)
  • custom_vocab (11-14)
  • new (10-29)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: cubic · AI code reviewer
  • GitHub Check: ci (macos, macos-latest)
  • GitHub Check: ci (windows, windows-latest)
🔇 Additional comments (5)
crates/whisper-local/Cargo.toml (3)

20-20: Criterion bench setup looks fine.

Workspace-scoped dev-dep is appropriate for benches.


25-27: Bench target correctly configured.

harness = false is expected for Criterion.


36-36: Add safeguards for git-pinned whisper-rs dependency

  • Update and commit Cargo.lock so the exact git rev is recorded.
  • Document in README.md or CHANGELOG.md why this Codeberg fork and rev (3e6d3da) are required.
  • Verify that the raw-api and tracing_backend features are enabled and that the correct git source is used:
    cargo tree -e features -i whisper-rs
    cargo metadata --format-version=1 | jq -r '.packages[] | select(.name=="whisper-rs") | .source'
  • Consider moving this override into [patch.crates-io] or pinning to a semantically versioned tag once upstream support is available.
owhisper/owhisper-interface/src/lib.rs (1)

135-141: Default initialization is correct.

Empty vocabulary by default avoids accidental biasing.

plugins/listener/src/fsm.rs (1)

251-258: Threading vocabulary into setup is correct.

Good propagation into the client builder.

@cubic-dev-ai cubic-dev-ai bot left a comment

5 issues found across 9 files

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (2)
crates/whisper-local/src/model.rs (2)

62-68: Remove hardcoded “Hyprnote” fallback; default to no bias.
Biasing by default can degrade general accuracy and surprises consumers. Use empty default instead.

Apply this diff:

-            let custom_vocab = self.vocabulary.unwrap_or(vec!["Hyprnote".to_string()]);
+            let custom_vocab = self.vocabulary.unwrap_or_default();

Note: keep installing the logits filter regardless, since it also enforces token_beg suppression.


253-261: Minor API/Safety tweaks in set_logit_filter.

  • Pass token_beg by value (Copy), avoid reference + deref.
  • If not adopting Arc, at least avoid heavy clones in hot path.

Apply this minimal diff if not switching to Arc:

-    unsafe fn set_logit_filter(
-        params: &mut FullParams,
-        token_beg: &WhisperTokenId,
-        bias_trie: &BiasTrie,
-    ) -> LogitFilterGuard {
+    unsafe fn set_logit_filter(
+        params: &mut FullParams,
+        token_beg: WhisperTokenId,
+        bias_trie: &BiasTrie,
+    ) -> LogitFilterGuard {
         let context = Box::new(Context {
-            token_beg: *token_beg,
+            token_beg,
             bias_trie: bias_trie.clone(),
         });
🧹 Nitpick comments (4)
crates/whisper-local/src/model.rs (4)

74-74: Consider Arc to avoid cloning the trie per decode.
Cloning a potentially large trie on every transcription is wasteful; share it.

Apply these diffs:

@@
-    bias_trie: BiasTrie,
+    bias_trie: std::sync::Arc<BiasTrie>,
@@
-        let bias_trie = {
+        let bias_trie = {
             let custom_vocab = self.vocabulary.unwrap_or_default();
             let custom_vocab_refs: Vec<&str> = custom_vocab.iter().map(|s| s.as_str()).collect();
-            BiasTrie::new(&ctx, &custom_vocab_refs)?
+            std::sync::Arc::new(BiasTrie::new(&ctx, &custom_vocab_refs)?)
         };
@@
-        let _guard = unsafe { Self::set_logit_filter(&mut params, &token_beg, &self.bias_trie) };
+        let _guard = unsafe { Self::set_logit_filter(&mut params, token_beg, &self.bias_trie) };
@@
-    unsafe fn set_logit_filter(
-        params: &mut FullParams,
-        token_beg: &WhisperTokenId,
-        bias_trie: &BiasTrie,
-    ) -> LogitFilterGuard {
+    unsafe fn set_logit_filter(
+        params: &mut FullParams,
+        token_beg: WhisperTokenId,
+        bias_trie: &std::sync::Arc<BiasTrie>,
+    ) -> LogitFilterGuard {
         let context = Box::new(Context {
-            token_beg: *token_beg,
-            bias_trie: bias_trie.clone(),
+            token_beg,
+            bias_trie: bias_trie.clone(),
         });
@@
-struct Context {
+struct Context {
     token_beg: WhisperTokenId,
-    bias_trie: BiasTrie,
+    bias_trie: std::sync::Arc<BiasTrie>,
 }

293-296: Rename Context to LogitsFilterContext for clarity.
Avoid overly generic names in a module with multiple contexts.

Apply this diff (and update references):

-struct Context {
+struct LogitsFilterContext {
     token_beg: WhisperTokenId,
-    bias_trie: BiasTrie,
+    bias_trie: BiasTrie,
 }

And update uses in set_logit_filter and Drop:

-        let context = Box::new(Context {
+        let context = Box::new(LogitsFilterContext {
@@
-            let context = &*(user_data as *const Context);
+            let context = &*(user_data as *const LogitsFilterContext);
-                let _ = Box::from_raw(self.context_ptr as *mut Context);
+                let _ = Box::from_raw(self.context_ptr as *mut LogitsFilterContext);

361-369: Test vocabulary is fine, but keep tests hermetic.
This integration test depends on a local model file; consider marking it #[ignore] or gating behind a feature to avoid CI failures.

I can provide a patch to add #[ignore] or a feature flag if desired.


381-388: Avoid println! in tests.
Use tracing at debug level or remove to keep test output clean.

Apply this diff to drop the print:

-        println!(
-            "{}",
-            segments
-                .iter()
-                .map(|s| s.text.clone())
-                .collect::<Vec<String>>()
-                .join(" ")
-        );
+        // Debug output intentionally omitted to keep tests quiet.
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 607a66d and ff2bc1a.

📒 Files selected for processing (1)
  • crates/whisper-local/src/model.rs (10 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.{js,ts,tsx,rs}

⚙️ CodeRabbit configuration file

**/*.{js,ts,tsx,rs}: 1. No error handling.
2. No unused imports, variables, or functions.
3. For comments, keep it minimal. It should be about "Why", not "What".

Files:

  • crates/whisper-local/src/model.rs
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: ci (windows, windows-latest)
  • GitHub Check: ci (macos, macos-latest)
🔇 Additional comments (7)
crates/whisper-local/src/model.rs (7)

11-11: Import is correct and used.
No action needed.


22-23: Builder gains optional vocabulary — OK.
Keeps default opt-in semantics at the type level.


36-39: Builder API for vocabulary — OK.
Fluent API remains consistent with existing builders.


112-143: Params block — OK.
Reasonable defaults and logging; no issues spotted.


145-146: Good: guard keeps callback context alive and fixes leak.
Lifetime cleanly spans the decode call.


302-309: Drop frees the boxed context — OK. Verify callback lifetime.
We rely on whisper-rs not invoking the callback after full() returns.

Please double-check the pinned whisper-rs guarantees the callback/user_data are only used during that full() invocation. If not, we should null the callback post-decode.
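The ownership pattern under review here, boxing a context, leaking it to a C-style callback as a raw pointer, and reclaiming it in a guard's Drop, can be illustrated std-only. Ctx, Guard, and callback are hypothetical stand-ins for the PR's LogitsFilterContext / LogitFilterGuard, shown only to demonstrate why the callback must not outlive the guard:

```rust
/// Hypothetical context handed to a C-style callback as a raw pointer.
struct Ctx {
    token_beg: i32,
}

/// Guard that owns the leaked context and reclaims it exactly once on Drop.
struct Guard {
    ptr: *mut Ctx,
}

impl Drop for Guard {
    fn drop(&mut self) {
        if !self.ptr.is_null() {
            // Safety: ptr came from Box::into_raw and is freed exactly once here.
            unsafe { drop(Box::from_raw(self.ptr)) };
        }
    }
}

/// Stand-in for the C callback: reads the context through the raw pointer.
/// Safety: caller must guarantee the pointer is still live (guard not dropped).
unsafe fn callback(user_data: *mut Ctx) -> i32 {
    (*user_data).token_beg
}

fn main() {
    let guard = Guard { ptr: Box::into_raw(Box::new(Ctx { token_beg: 42 })) };
    // The callback may fire any number of times while the guard is alive...
    let v = unsafe { callback(guard.ptr) };
    assert_eq!(v, 42);
    // ...but must never fire after `guard` drops here and frees the context --
    // exactly the lifetime guarantee being questioned above.
}
```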


263-282: Harden callback: guard against null tokens and zero-length input

         ) {
-            if logits.is_null() || user_data.is_null() {
+            if logits.is_null() || user_data.is_null() || tokens.is_null() || n_tokens <= 0 {
             return;
         }

Verify that in the pinned whisper-rs revision the C callback contract guarantees non-null tokens and n_tokens > 0; otherwise these guards prevent UB.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (1)
crates/whisper-local/src/bias.rs (1)

58-92: Unsafe OOB risk on logits; prefer longest-suffix first and stop after first hit.

  • Potential UB: pointer arithmetic on logits without knowing n_vocab.
  • Current loop biases multiple suffix lengths, double-counting. Prefer longest-context first, break once any continuation is found.

Apply:

-    pub unsafe fn apply_bias_to_logits(
+    /// Safety: `tokens` must point to `n_tokens` valid whisper_token_data.
+    /// `logits` must point to `n_vocab` contiguous f32s (the vocab size).
+    pub unsafe fn apply_bias_to_logits(
         &self,
         tokens: *const whisper_rs::whisper_rs_sys::whisper_token_data,
         n_tokens: std::os::raw::c_int,
-        logits: *mut f32,
+        logits: *mut f32,
+        n_vocab: std::os::raw::c_int,
     ) {
-        if tokens.is_null() || n_tokens <= 0 {
+        if tokens.is_null() || logits.is_null() || n_tokens <= 0 || n_vocab <= 0 {
             return;
         }
@@
-        for suffix_len in 1..=std::cmp::min(10, current_tokens.len()) {
+        let n_vocab = n_vocab as usize;
+        let logits_slice = std::slice::from_raw_parts_mut(logits, n_vocab);
+        for suffix_len in (1..=std::cmp::min(10, current_tokens.len())).rev() {
             let suffix = &current_tokens[current_tokens.len() - suffix_len..];
-
-            for (full_sequence, bias_value_ref) in self.trie.predictive_search(suffix) {
-                let bias_value = *bias_value_ref;
-                let full_sequence: Vec<WhisperTokenId> = full_sequence;
-
+            let mut found = false;
+            for (full_sequence, bias_value_ref) in self.trie.predictive_search(suffix) {
+                let log_bias = *bias_value_ref; // if new() stores ln(bias); else: (*bias_value_ref).max(1e-6).ln()
                 if full_sequence.len() > suffix.len() {
                     let next_token = full_sequence[suffix.len()];
-                    let current_logit = *logits.offset(next_token as isize);
-
-                    let boost = bias_value.ln() * 2.0;
-                    let new_logit = current_logit + boost;
-
-                    *logits.offset(next_token as isize) = new_logit;
+                    if let Some(slot) = logits_slice.get_mut(next_token as usize) {
+                        // cap boost magnitude for numerical stability
+                        let boost = log_bias.clamp(-6.0, 6.0);
+                        *slot += boost;
+                        found = true;
+                    }
                 }
             }
-        }
+            if found { break; }
+        }
     }
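The longest-suffix-first strategy in the diff above can be sketched std-only (a HashMap stands in for predictive_search; log-domain biases and the ±6.0 clamp follow the diff, constants are illustrative):

```rust
use std::collections::HashMap;

type TokenId = usize;

/// Try suffixes from longest to shortest and stop at the first suffix length
/// that yields any continuation, so shorter matches cannot double-count.
fn bias_longest_suffix(
    current: &[TokenId],
    entries: &HashMap<Vec<TokenId>, f32>, // values stored as log-bias
    logits: &mut [f32],
) {
    let max_len = current.len().min(10);
    for suffix_len in (1..=max_len).rev() {
        let suffix = &current[current.len() - suffix_len..];
        let mut found = false;
        for (seq, log_bias) in entries {
            if seq.len() > suffix.len() && seq[..suffix.len()] == *suffix {
                if let Some(slot) = logits.get_mut(seq[suffix.len()]) {
                    *slot += log_bias.clamp(-6.0, 6.0);
                    found = true;
                }
            }
        }
        if found {
            break; // longest matching context wins; skip shorter suffixes
        }
    }
}

fn main() {
    let mut entries = HashMap::new();
    entries.insert(vec![1, 2, 3], 4.0_f32); // extends the 2-token suffix [1, 2]
    entries.insert(vec![2, 9], 4.0_f32);    // would match the 1-token suffix [2]
    let mut logits = vec![0.0_f32; 16];
    bias_longest_suffix(&[1, 2], &entries, &mut logits);
    // The 2-token suffix matched first, so only token 3 is boosted; the
    // shorter suffix [2] never runs and token 9 stays untouched.
    assert_eq!(logits[3], 4.0);
    assert_eq!(logits[9], 0.0);
}
```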
🧹 Nitpick comments (5)
crates/whisper-local/src/bias.rs (5)

1-7: Derive Debug for BiasTrie (useful for tracing) and prep for dedup imports.

 use trie_rs::map::{Trie, TrieBuilder};
 use whisper_rs::{WhisperContext, WhisperTokenId};
+use std::collections::{HashMap, HashSet};
 
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub struct BiasTrie {
     trie: Trie<WhisperTokenId, f32>,
 }

20-24: Store biases in log-domain to avoid per-step ln and clamp magnitude.

-                    let prefix_bias = 10.0 + 90.0 * progress.powi(2);
+                    // store log-bias; clamp to a sane range for stability
+                    let prefix_bias = (10.0 + 90.0 * progress.powi(2)).ln().clamp(-6.0, 6.0);

Then in apply_bias_to_logits, use the stored value directly (see separate comment).


32-56: Variant set can contain duplicates; dedup and consider trailing-space variant.

-        let mut variants = Vec::new();
+        let mut variants = Vec::new();
 
         variants.push(ctx.tokenize(word, 99)?);
         variants.push(ctx.tokenize(&format!(" {}", word), 99)?);
+        variants.push(ctx.tokenize(&format!("{word} "), 99)?);
@@
         variants.push(ctx.tokenize(&format!("'{}", word), 99)?);
         variants.push(ctx.tokenize(&format!("\"{}", word), 99)?);
 
-        Ok(variants)
+        // remove duplicate tokenizations
+        variants.sort();
+        variants.dedup();
+        Ok(variants)

79-80: Remove redundant type rebind.

-                let bias_value = *bias_value_ref;
-                let full_sequence: Vec<WhisperTokenId> = full_sequence;
+                let bias_value = *bias_value_ref;

74-91: Consider per-token max instead of additive stacking across multiple matches.

If you intentionally keep multi-suffix accumulation, guard against over-bias by aggregating the maximum boost per next_token and applying once. I can provide a small HashMap-based accumulator if you want it.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between ff2bc1a and 8adcf40.

📒 Files selected for processing (1)
  • crates/whisper-local/src/bias.rs (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.{js,ts,tsx,rs}

⚙️ CodeRabbit configuration file

**/*.{js,ts,tsx,rs}: 1. No error handling.
2. No unused imports, variables, or functions.
3. For comments, keep it minimal. It should be about "Why", not "What".

Files:

  • crates/whisper-local/src/bias.rs
🔇 Additional comments (1)
crates/whisper-local/src/bias.rs (1)

38-40: tokenize() max_tokens semantics – 0 disables limit
tokenize(&self, text: &str, max_tokens: usize) returns up to max_tokens tokens; passing 0 disables the cap. Since you’re feeding single words (≤2 subword tokens), 99 can’t truncate here. To make it explicit, you may replace 99 with 0 in all ctx.tokenize(..., 99) calls.

Comment on lines +13 to +26
for word in custom_vocab {
let variants = Self::generate_tokenization_variants(ctx, word)?;

for tokens in variants {
for i in 1..=tokens.len() {
let progress = i as f32 / tokens.len() as f32;

let prefix_bias = 10.0 + 90.0 * progress.powi(2);

let prefix = &tokens[..i];
builder.push(prefix, prefix_bias);
}
}
}

🛠️ Refactor suggestion

Deduplicate overlapping prefixes across variants; skip empty tokenizations.

Multiple variants can yield identical prefixes; pushing each inflates/overwrites bias unpredictably. Aggregate by prefix and keep the max bias.

-        for word in custom_vocab {
-            let variants = Self::generate_tokenization_variants(ctx, word)?;
-            for tokens in variants {
-                for i in 1..=tokens.len() {
-                    let progress = i as f32 / tokens.len() as f32;
-
-                    let prefix_bias = 10.0 + 90.0 * progress.powi(2);
-
-                    let prefix = &tokens[..i];
-                    builder.push(prefix, prefix_bias);
-                }
-            }
-        }
+        let mut acc: HashMap<Vec<WhisperTokenId>, f32> = HashMap::new();
+        for word in custom_vocab {
+            let variants = Self::generate_tokenization_variants(ctx, word)?;
+            for tokens in variants {
+                if tokens.is_empty() { continue; }
+                for i in 1..=tokens.len() {
+                    let progress = i as f32 / tokens.len() as f32;
+                    let bias = 10.0 + 90.0 * progress.powi(2);
+                    let key = tokens[..i].to_vec();
+                    acc.entry(key).and_modify(|v| *v = v.max(bias)).or_insert(bias);
+                }
+            }
+        }
+        for (k, v) in acc {
+            builder.push(&k, v);
+        }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
for word in custom_vocab {
let variants = Self::generate_tokenization_variants(ctx, word)?;
for tokens in variants {
for i in 1..=tokens.len() {
let progress = i as f32 / tokens.len() as f32;
let prefix_bias = 10.0 + 90.0 * progress.powi(2);
let prefix = &tokens[..i];
builder.push(prefix, prefix_bias);
}
}
}
// Deduplicate overlapping prefixes across all variants, keeping the max bias per prefix.
let mut acc: HashMap<Vec<WhisperTokenId>, f32> = HashMap::new();
for word in custom_vocab {
let variants = Self::generate_tokenization_variants(ctx, word)?;
for tokens in variants {
if tokens.is_empty() { continue; }
for i in 1..=tokens.len() {
let progress = i as f32 / tokens.len() as f32;
let bias = 10.0 + 90.0 * progress.powi(2);
let key = tokens[..i].to_vec();
acc.entry(key)
.and_modify(|v| *v = v.max(bias))
.or_insert(bias);
}
}
}
for (k, v) in acc {
builder.push(&k, v);
}
🤖 Prompt for AI Agents
In crates/whisper-local/src/bias.rs around lines 13 to 26, the loop currently
pushes every prefix from every tokenization variant, which duplicates
overlapping prefixes and allows later variants to overwrite earlier bias values
unpredictably and also doesn't skip empty tokenizations; change the logic so you
skip any variant with zero tokens, then collect prefixes from all variants into
a temporary map (or hashmap) keyed by the prefix sequence and store the maximum
computed bias for that prefix (compare existing entry and keep the larger bias),
and after processing all variants push each unique prefix once to builder with
the aggregated max bias.

Comment on lines +58 to +92
pub unsafe fn apply_bias_to_logits(
&self,
tokens: *const whisper_rs::whisper_rs_sys::whisper_token_data,
n_tokens: std::os::raw::c_int,
logits: *mut f32,
) {
if tokens.is_null() || n_tokens <= 0 {
return;
}

let current_tokens: Vec<WhisperTokenId> =
std::slice::from_raw_parts(tokens, n_tokens as usize)
.iter()
.map(|t| t.id)
.collect();

for suffix_len in 1..=std::cmp::min(10, current_tokens.len()) {
let suffix = &current_tokens[current_tokens.len() - suffix_len..];

for (full_sequence, bias_value_ref) in self.trie.predictive_search(suffix) {
let bias_value = *bias_value_ref;
let full_sequence: Vec<WhisperTokenId> = full_sequence;

if full_sequence.len() > suffix.len() {
let next_token = full_sequence[suffix.len()];
let current_logit = *logits.offset(next_token as isize);

let boost = bias_value.ln() * 2.0;
let new_logit = current_logit + boost;

*logits.offset(next_token as isize) = new_logit;
}
}
}
}

💡 Verification agent

🧩 Analysis chain

Update call sites to pass n_vocab and validate usage.

Use this script to find and fix all call sites:


🏁 Script executed:

#!/bin/bash
# Find apply_bias_to_logits call sites and show context
rg -n -C3 --type=rust '\.apply_bias_to_logits\s*\('

Length of output: 518


Update all call sites to pass n_vocab
In crates/whisper-local/src/model.rs:281, change

.apply_bias_to_logits(tokens, n_tokens, logits)

to

.apply_bias_to_logits(tokens, n_tokens, n_vocab, logits)

and ensure the method signature in bias.rs is updated accordingly.

🤖 Prompt for AI Agents
In crates/whisper-local/src/bias.rs around lines 58 to 92, update the unsafe fn
signature to accept n_vocab (std::os::raw::c_int) as a new parameter (e.g.,
tokens, n_tokens, n_vocab, logits), then use n_vocab to validate token indices
before indexing into the logits buffer: ensure next_token is within 0..n_vocab
(cast and compare as usize/int) and only read/write logits if in-bounds to avoid
OOB access; adjust pointer arithmetic to cast next_token safely to isize after
bounds checking; also update call sites (as noted in the review) to pass n_vocab
when invoking this method.
