
Commit 79a6c97

[None][fix] Use fp32 for indexer weight_proj GEMM (#9243)
Signed-off-by: Chang Liu (Enterprise Products) <[email protected]>
1 parent 028fc87 commit 79a6c97

File tree: 4 files changed, +87 -107 lines


tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 39 additions & 53 deletions
@@ -674,6 +674,11 @@ def _scale(weights: torch.Tensor, q_scale: torch.Tensor,
     return weights * q_scale.squeeze(-1) * s


+@maybe_compile(dynamic=True)
+def _to_float(hidden_states: torch.Tensor) -> torch.Tensor:
+    return hidden_states.float()
+
+
 class Indexer(nn.Module):

     def __init__(self,
@@ -715,7 +720,7 @@ def __init__(self,
             self.hidden_size,
             self.n_heads,
             bias=False,
-            dtype=dtype,
+            dtype=torch.float32,
             quant_config=None,
             skip_create_weights_in_init=skip_create_weights_in_init,
             use_custom_cublas_mm=True)
@@ -1233,82 +1238,63 @@ def sparse_attn_indexer(
                                 dtype=torch.int32)
         return topk_indices_buffer

-    def weight_scale(self, hidden_states: torch.Tensor,
-                     indexer_weights: Optional[torch.Tensor],
-                     q_scale: torch.Tensor) -> torch.Tensor:
-        weights = indexer_weights if indexer_weights is not None else self.weights_proj(
-            hidden_states)
+    def _weight_scale(self, weights: torch.Tensor,
+                      q_scale: torch.Tensor) -> torch.Tensor:
         weights = _scale(weights, q_scale, self.weight_scale_factor)
         return weights

+    def _qk_projection_and_rope(self, qr: torch.Tensor, indexer_k: torch.Tensor,
+                                position_ids: torch.Tensor):
+        """Project Q/K and apply RoPE"""
+        q = self.wq_b(qr)
+        k = self.k_norm(indexer_k)
+        q = q.view(-1, self.n_heads, self.head_dim)
+        q_pe, q_nope = q.split([self.rope_dim, self.head_dim - self.rope_dim],
+                               dim=-1)
+        k_pe, k_nope = k.split([self.rope_dim, self.head_dim - self.rope_dim],
+                               dim=-1)
+        q_pe, k_pe = self.rotary_emb(position_ids, [q_pe, k_pe.unsqueeze(1)])
+        k_pe = k_pe[:, 0, :]
+        return q_pe, q_nope, k_pe, k_nope
+
+    def _prep_q_or_k(self, qk_pe: torch.Tensor, qk_nope: torch.Tensor):
+        """Concatenate, rotate, and FP8 quantize for Q or K"""
+        q_or_k = torch.cat([qk_pe, qk_nope], dim=-1)
+        q_or_k = rotate_activation(q_or_k)
+        q_or_k = q_or_k.view(-1, self.head_dim)
+        q_or_k = fp8_utils.fp8_quantize_1x128_sf_transpose(
+            q_or_k, use_ue8m0=self.scale_fmt == "ue8m0")
+        return q_or_k
+
     @torch.inference_mode()
     def forward(self, qr: torch.Tensor, hidden_states: torch.Tensor,
                 metadata: DSAtrtllmAttentionMetadata,
-                position_ids: torch.Tensor, indexer_k: Optional[torch.Tensor],
-                indexer_weights: Optional[torch.Tensor]):
+                position_ids: torch.Tensor, indexer_k: torch.Tensor):
         quant_block_size = metadata.kv_cache_manager.quant_block_size
         assert quant_block_size == 128, "Only support quant_block_size = 128 for now"

-        if indexer_k is not None:
-            q, k = maybe_execute_in_parallel(
-                lambda: self.wq_b(
-                    qr),  # TODO: fuse wq_b and move this outside of the indexer
-                lambda: self.k_norm(indexer_k),
-                self.ln_events[0],
-                self.ln_events[1],
-                self.aux_stream,
-            )
-        else:
-            q, k = maybe_execute_in_parallel(
-                lambda: self.wq_b(qr),
-                lambda: self.k_norm(self.wk(hidden_states)),
-                self.ln_events[0],
-                self.ln_events[1],
-                self.aux_stream,
-            )
-
-        # q/k rope + possible fast_hadamard_transform
-        q = q.view(-1, self.n_heads, self.head_dim)
-
-        q, k = maybe_execute_in_parallel(
-            lambda: torch.split(
-                q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1),
-            lambda: torch.split(
-                k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1),
+        q_and_k, weights = maybe_execute_in_parallel(
+            lambda: self._qk_projection_and_rope(qr, indexer_k, position_ids),
+            lambda: self.weights_proj(_to_float(hidden_states)),
             self.ln_events[0],
             self.ln_events[1],
             self.aux_stream,
         )
-
-        q_pe, q_nope = q
-        k_pe, k_nope = k
-        q_pe, k_pe = self.rotary_emb(position_ids, [q_pe, k_pe.unsqueeze(1)])
-
-        k_pe = k_pe[:, 0, :]
-
-        def _prep_q_or_k(qk_pe, qk_nope):
-            q_or_k = torch.cat([qk_pe, qk_nope], dim=-1)
-            q_or_k = rotate_activation(q_or_k)
-            q_or_k = q_or_k.view(-1, self.head_dim)
-            q_or_k = fp8_utils.fp8_quantize_1x128_sf_transpose(
-                q_or_k, use_ue8m0=self.scale_fmt == "ue8m0")
-            return q_or_k
-
+        q_pe, q_nope, k_pe, k_nope = q_and_k
         q, k = maybe_execute_in_parallel(
-            lambda: _prep_q_or_k(q_pe, q_nope),
-            lambda: _prep_q_or_k(k_pe, k_nope),
+            lambda: self._prep_q_or_k(q_pe, q_nope),
+            lambda: self._prep_q_or_k(k_pe, k_nope),
            self.ln_events[0],
            self.ln_events[1],
            self.aux_stream,
        )
-
        q_fp8, q_scale = q
        k_fp8, k_scale = k
        q_fp8 = q_fp8.view(-1, self.n_heads, self.head_dim)
        q_scale = q_scale.view(-1, self.n_heads, 1)

        weights, _ = maybe_execute_in_parallel(
-            lambda: self.weight_scale(hidden_states, indexer_weights, q_scale),
+            lambda: self._weight_scale(weights, q_scale),
            lambda: self._update_k_cache(
                k_fp8, k_scale, metadata),  # store k_fp8 and k_scale in k cache
            self.ln_events[0],

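Note on the change above: the indexer's weights_proj GEMM now runs in fp32. The layer is created with dtype=torch.float32, the bf16 hidden states are upcast through the compiled _to_float helper, and the projection is overlapped with the Q/K projection plus RoPE via maybe_execute_in_parallel. A minimal sketch of the precision change, using illustrative sizes and a plain nn.Linear in place of the repo's custom cuBLAS Linear:

# Illustrative sketch only: example sizes, plain nn.Linear instead of the
# repo's custom cuBLAS Linear, and no aux-stream overlap.
import torch
import torch.nn as nn

hidden_size, n_heads = 7168, 64                      # example dimensions
weights_proj = nn.Linear(hidden_size, n_heads, bias=False, dtype=torch.float32)

hidden_states = torch.randn(8, hidden_size, dtype=torch.bfloat16)

# Upcast the bf16 activations so the per-token, per-head indexer weights come
# out of the GEMM in full fp32 rather than bf16.
weights = weights_proj(hidden_states.float())        # shape [8, n_heads], dtype fp32
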
tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 42 additions & 38 deletions
@@ -362,9 +362,10 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
                 fused_a_scale = torch.cat(
                     [q_a_proj_scale, fused_a_scale], dim=0)

-                module.weight_scale.data.copy_(fused_a_scale)
-            # For DeepseekV32 with fuse_a_indexer_k_weight=True: kv_a_proj_with_mqa is oversized
-            # to include indexer weights, which is filled in post_load_weights.
+                module.weight_scale.data[0:fused_a_scale.
+                                         shape[0]].copy_(fused_a_scale)
+            # For DeepseekV32: kv_a_proj_with_mqa is oversized
+            # to include indexer k weights, which is filled in post_load_weights.
             module.weight.data[0:fused_a.shape[0]].copy_(fused_a)
         elif names[-1] in params_map:
             module_weights = []
@@ -556,13 +557,6 @@ def __init__(
         config = model_config.pretrained_config
         predicted_tokens_per_seq = model_config.spec_config.max_total_draft_tokens + 1 if model_config.spec_config is not None else 1

-        # DSV3.2 nvfp4 ckpt has kv_a_proj_with_mqa module in bfloat16
-        # TODO: check it more directly/robustly, e.g., indexer_weight_quant == fuseA_quant == indexer_quant
-        if model_config.get_quant_config().quant_algo == QuantAlgo.NVFP4:
-            self.fuse_a_indexer_k_weight = True
-        else:
-            self.fuse_a_indexer_k_weight = False
-
         super().__init__(hidden_size=config.hidden_size,
                          num_attention_heads=config.num_attention_heads,
                          num_key_value_heads=config.num_key_value_heads,
@@ -586,36 +580,46 @@

         self.indexer = self.mqa.indexer

-        if self.fuse_a_indexer_k_weight:
-            # For DeepseekV32, the kv_a_proj_with_mqa includes:
-            # q_a_proj + kv_a_proj_with_mqa + indexer.wk + indexer.weights_proj
-            self.kv_a_proj_with_mqa = DeepseekV3Linear(
-                config.hidden_size,
-                self.kv_lora_rank + self.qk_rope_head_dim + self.q_lora_rank +
-                self.indexer.head_dim + self.indexer.n_heads,
-                bias=False,
-                dtype=config.torch_dtype,
-                quant_config=model_config.get_quant_config(),
-                skip_create_weights_in_init=model_config.
-                skip_create_weights_in_init,
-                use_custom_cublas_mm=True)
+        # For DeepseekV32, the kv_a_proj_with_mqa includes:
+        # q_a_proj + kv_a_proj_with_mqa + indexer.wk
+        self.kv_a_proj_with_mqa = DeepseekV3Linear(
+            config.hidden_size,
+            self.kv_lora_rank + self.qk_rope_head_dim + self.q_lora_rank +
+            self.indexer.head_dim,
+            bias=False,
+            dtype=config.torch_dtype,
+            quant_config=model_config.get_quant_config(),
+            skip_create_weights_in_init=model_config.
+            skip_create_weights_in_init,
+            use_custom_cublas_mm=True)

     def post_load_weights(self):
-        if self.fuse_a_indexer_k_weight:
-            assert self.kv_a_proj_with_mqa.weight.data.dtype == self.indexer.wk.weight.data.dtype == self.indexer.weights_proj.weight.data.dtype, "all weights in kv_a_proj_with_mqa module must have matching dtype"
-            # Copy indexer weights into the fused kv_a_proj_with_mqa module
-            indexer_wk_weight = self.indexer.wk.weight.data
-            offset = self.kv_lora_rank + self.qk_rope_head_dim + self.q_lora_rank
-            self.kv_a_proj_with_mqa.weight.data[offset:offset +
-                                                self.indexer.head_dim].copy_(
-                                                    indexer_wk_weight)
-            offset += self.indexer.head_dim
-            indexer_weights_proj_weight = self.indexer.weights_proj.weight.data
-            self.kv_a_proj_with_mqa.weight.data[offset:offset +
-                                                self.indexer.n_heads].copy_(
-                                                    indexer_weights_proj_weight)
-            self.indexer.wk = None
-            self.indexer.weights_proj = None
+        """
+        Concatenate indexer.wk weights into kv_a_proj_with_mqa's last dimension, to fuse indexer.wk projection with kv_a_proj_with_mqa GEMM.
+        """
+        assert self.kv_a_proj_with_mqa.weight.data.dtype == self.indexer.wk.weight.data.dtype, "all weights in kv_a_proj_with_mqa module must have matching dtype"
+        # Copy indexer weights into the fused kv_a_proj_with_mqa module
+        indexer_wk_weight = self.indexer.wk.weight.data
+        offset = self.kv_lora_rank + self.qk_rope_head_dim + self.q_lora_rank
+        self.kv_a_proj_with_mqa.weight.data[offset:offset +
+                                            self.indexer.head_dim].copy_(
+                                                indexer_wk_weight)

+        # Copy indexer scale data if it exists
+        if hasattr(self.indexer.wk,
+                   'weight_scale') and self.indexer.wk.weight_scale is not None:
+            indexer_wk_scale = self.indexer.wk.weight_scale.data
+            assert self.kv_a_proj_with_mqa.weight_scale.dim(
+            ) == 2, "weight_scale must be a 2D tensor"
+            group_size = self.kv_a_proj_with_mqa.weight.shape[
+                0] // self.kv_a_proj_with_mqa.weight_scale.shape[0]
+            scale_offset = offset // group_size
+            scale_size = indexer_wk_scale.shape[0]
+            # Copy indexer scale to the corresponding position in the fused module
+            self.kv_a_proj_with_mqa.weight_scale.data[
+                scale_offset:scale_offset + scale_size].copy_(indexer_wk_scale)
+
+        self.indexer.wk = None


 class Deepseekv3RoutingImpl():

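Note on the scale handling in post_load_weights above: with block-quantized weights, the fused kv_a_proj_with_mqa keeps one weight_scale row per group of weight rows, so the row offset of indexer.wk inside the fused weight is mapped to a scale-row offset by integer division with the group size. A toy sketch of that index arithmetic, with made-up shapes:

# Toy shapes only; the real sizes come from the model config and quantization.
import torch

fused_weight = torch.zeros(1152, 64)    # rows 0..1023: q_a/kv_a/rope, rows 1024..1151: indexer.wk
fused_weight_scale = torch.zeros(9, 1)  # one scale row per block of weight rows
indexer_wk_scale = torch.ones(1, 1)     # scale rows of the standalone indexer.wk

offset = 1024                                                        # weight rows before indexer.wk
group_size = fused_weight.shape[0] // fused_weight_scale.shape[0]    # -> 128 rows per scale row
scale_offset = offset // group_size                                  # -> scale row 8
fused_weight_scale[scale_offset:scale_offset +
                   indexer_wk_scale.shape[0]].copy_(indexer_wk_scale)
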
tensorrt_llm/_torch/modules/attention.py

Lines changed: 5 additions & 14 deletions
@@ -1221,19 +1221,11 @@ def forward_impl_with_dsa(self, position_ids: Optional[torch.Tensor],
         if position_ids is not None:
             position_ids = position_ids[..., :num_tokens]

-        if self.fuse_a_indexer_k_weight:
-            q, compressed_kv, k_pe, indexer_k, indexer_weights = self.kv_a_proj_with_mqa(
-                hidden_states).split([
-                    self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim,
-                    self.indexer.head_dim, self.indexer.n_heads
-                ], -1)
-        else:
-            q, compressed_kv, k_pe = self.kv_a_proj_with_mqa(
-                hidden_states).split([
-                    self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim
-                ], -1)
-            indexer_k = None
-            indexer_weights = None
+        q, compressed_kv, k_pe, indexer_k = self.kv_a_proj_with_mqa(
+            hidden_states).split([
+                self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim,
+                self.indexer.head_dim
+            ], -1)

         # TODO: possibly overlap/fuse q_a_rmsnorm + kv_a_rmsnorm + indexer.k_layernorm?
         q, compressed_kv = maybe_execute_in_parallel(
@@ -1255,7 +1247,6 @@ def forward_impl_with_dsa(self, position_ids: Optional[torch.Tensor],
             attn_metadata,
             position_ids,
             indexer_k=indexer_k,  # indexer K proj
-            indexer_weights=indexer_weights,  # indexer weights proj
         )

         assert q.shape[

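Note on the change above: indexer.wk is now always fused into kv_a_proj_with_mqa, so the DSA path performs a single GEMM followed by one split, and the separate indexer_weights output is gone (the indexer computes its own weights in fp32). A standalone sketch of the split, with illustrative dimensions:

# Illustrative dimensions; the real values come from the model config.
import torch

q_lora_rank, kv_lora_rank, qk_rope_head_dim, indexer_head_dim = 1536, 512, 64, 128
num_tokens = 4

fused_out = torch.randn(
    num_tokens, q_lora_rank + kv_lora_rank + qk_rope_head_dim + indexer_head_dim)

# One fused projection output -> four slices, in the same order as the diff above.
q, compressed_kv, k_pe, indexer_k = fused_out.split(
    [q_lora_rank, kv_lora_rank, qk_rope_head_dim, indexer_head_dim], dim=-1)
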
tests/unittest/_torch/attention/sparse/test_sparse_mla_forward.py

Lines changed: 1 addition & 2 deletions
@@ -681,8 +681,7 @@ def yarn_get_mscale(scale=1, mscale=1):
             hidden_states,
             attn_metadata,
             position_ids,
-            None,  # indexer_k
-            None,  # indexer_weights
+            indexer_k=mla.mqa.indexer.wk(hidden_states),  # indexer_k
         )

         # Validate indexer output against expected causal indices (since seq_len < topk=2048)
