@@ -640,18 +640,18 @@ def __init__(
640640 def get_scores (logits , e_score_correction_bias ):
641641 scores = F .sigmoid (logits )
642642 scores_with_bias = scores + e_score_correction_bias
643- return scores , scores_with_bias
644-
645- def noaux_tc (self , logits , e_score_correction_bias ):
646- n_group = self .n_group
647-
648643 if enable_llm_debug ():
649644 has_nan = torch .isnan (scores_with_bias ).any ()
650645 if has_nan :
651646 warnings .warn (
652647 "Detected NAN in the tensor scores_with_bias. Please check if it matches the expectation."
653648 )
654649
650+ return scores , scores_with_bias
651+
652+ def noaux_tc (self , logits , e_score_correction_bias ):
653+ n_group = self .n_group
654+
655655 _ , num_experts = logits .shape
656656 if self .n_group > 1 :
657657 if self .top_k > 8 or (num_experts / n_group ) > 32 or (
@@ -672,6 +672,13 @@ def noaux_tc(self, logits, e_score_correction_bias):
672672 if not self .is_fused :
673673 scores , scores_with_bias = Deepseekv3RoutingImpl .get_scores (
674674 logits , e_score_correction_bias )
675+ if enable_llm_debug ():
676+ has_nan = torch .isnan (scores_with_bias ).any ()
677+ if has_nan :
678+ warnings .warn (
679+ "Detected NAN in the tensor scores_with_bias. Please check if it matches the expectation."
680+ )
681+
675682 scores_shape = list (scores_with_bias .shape )
676683 group_scores = torch .sum (torch .topk (
677684 scores_with_bias .view (scores_shape [:- 1 ] +
0 commit comments