Skip to content

Commit e5c4865

Browse files
committed
fix dsv3 debug mode
Signed-off-by: Lizhi Zhou <[email protected]>
1 parent 24f5cd7 commit e5c4865

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -640,18 +640,18 @@ def __init__(
640640
def get_scores(logits, e_score_correction_bias):
641641
scores = F.sigmoid(logits)
642642
scores_with_bias = scores + e_score_correction_bias
643-
return scores, scores_with_bias
644-
645-
def noaux_tc(self, logits, e_score_correction_bias):
646-
n_group = self.n_group
647-
648643
if enable_llm_debug():
649644
has_nan = torch.isnan(scores_with_bias).any()
650645
if has_nan:
651646
warnings.warn(
652647
"Detected NAN in the tensor scores_with_bias. Please check if it matches the expectation."
653648
)
654649

650+
return scores, scores_with_bias
651+
652+
def noaux_tc(self, logits, e_score_correction_bias):
653+
n_group = self.n_group
654+
655655
_, num_experts = logits.shape
656656
if self.n_group > 1:
657657
if self.top_k > 8 or (num_experts / n_group) > 32 or (
@@ -672,6 +672,13 @@ def noaux_tc(self, logits, e_score_correction_bias):
672672
if not self.is_fused:
673673
scores, scores_with_bias = Deepseekv3RoutingImpl.get_scores(
674674
logits, e_score_correction_bias)
675+
if enable_llm_debug():
676+
has_nan = torch.isnan(scores_with_bias).any()
677+
if has_nan:
678+
warnings.warn(
679+
"Detected NAN in the tensor scores_with_bias. Please check if it matches the expectation."
680+
)
681+
675682
scores_shape = list(scores_with_bias.shape)
676683
group_scores = torch.sum(torch.topk(
677684
scores_with_bias.view(scores_shape[:-1] +

0 commit comments

Comments
 (0)