Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions backends/candle/src/models/flash_qwen3.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::flash_attn::flash_attn_varlen;
use crate::layers::{get_cos_sin, get_inv_freqs, HiddenAct, Linear, RMSNorm};
use crate::models::{Model, Qwen3Config, Qwen3ClassificationHead};
use crate::models::{Model, Qwen3Config};
use crate::models::qwen3::Qwen3ClassificationHead;
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Embedding, Module, VarBuilder};
use candle_rotary::apply_rotary_inplace;
Expand Down Expand Up @@ -309,7 +310,7 @@ impl FlashQwen3Model {
ModelType::Classifier => {
// Load classification head before the vb is modified
let classification_head = Some(Qwen3ClassificationHead::load(vb.clone(), config)?);
(Pool::Cls, classification_head) // Use CLS pooling for classification
(Pool::LastToken, classification_head) // Use LastToken pooling for classification
}
ModelType::Embedding(pool) => (pool, None),
};
Expand Down Expand Up @@ -404,8 +405,8 @@ impl FlashQwen3Model {
// CLS and LastToken pooling
Pool::Cls | Pool::LastToken => {
if batch_size > 1 {
// Get token indices form cu_seqlens
let mut indices = match self.pool {
// Get token indices for each sequence
let all_indices = match self.pool {
Pool::Cls => cu_seqlens.narrow(0, 0, batch_size)?,
Pool::LastToken => {
let end = cu_seqlens.narrow(0, 1, batch_size)?;
Expand All @@ -414,19 +415,21 @@ impl FlashQwen3Model {
_ => unreachable!(),
};

// If raw_indices is empty, we don't need to do anything with
// the pooled_indices
if has_raw_requests {
// We need the pooled indices to select the correct cls indices
// Select the appropriate indices based on pooled_indices
let indices = if has_raw_requests {
// Select only the sequences that need pooling
let pooled_indices_vec: Vec<i64> = batch.pooled_indices.iter()
.map(|&idx| idx as i64)
.collect();
let pooled_indices = Tensor::from_vec(
batch.pooled_indices.clone(),
pooled_indices_vec,
batch.pooled_indices.len(),
&self.device,
)?;

// Only select indices that requires pooling
indices = indices.index_select(&pooled_indices, 0)?
}
all_indices.index_select(&pooled_indices, 0)?
} else {
all_indices
};

// Select tokens
Some(outputs.index_select(&indices, 0)?)
Expand Down