@@ -26,17 +26,30 @@ def calculate_varentropy_logsoftmax(logits: torch.Tensor, axis: int = -1) -> Tuple[torch.Tensor, torch.Tensor]:
     varentropy = torch.sum(probs * (log_probs / LN_2 + entropy.unsqueeze(-1))**2, dim=axis)
     return entropy, varentropy
 
-def calculate_attention_metrics(attention_scores: torch.Tensor) -> Dict[str, torch.Tensor]:
-    attention_probs = F.softmax(attention_scores, dim=-1)
+def calculate_attention_metrics(attention_weights: torch.Tensor) -> Dict[str, torch.Tensor]:
+    attention_probs = attention_weights
+
+    # Calculate entropy
     attn_entropy = -torch.sum(attention_probs * torch.log2(torch.clamp(attention_probs, 1e-10, 1.0)), dim=-1)
-    attn_varentropy = torch.var(attn_entropy, dim=-1)
 
-    attn_varentropy = torch.where(torch.isnan(attn_varentropy), torch.zeros_like(attn_varentropy), attn_varentropy)
+    # Calculate variance of entropy with unbiased=False to avoid df issues
+    # Also add a check for singleton dimensions
+    if attn_entropy.size(-1) > 1:
+        attn_varentropy = torch.var(attn_entropy, dim=-1, unbiased=False)
+    else:
+        attn_varentropy = torch.zeros_like(attn_entropy)
+
+    attn_varentropy = torch.where(torch.isnan(attn_varentropy),
+                                  torch.zeros_like(attn_varentropy),
+                                  attn_varentropy)
+
+    # Rest remains the same
     mean_attention = torch.mean(attention_probs, dim=1)
     agreement = torch.mean(torch.abs(attention_probs - mean_attention.unsqueeze(1)), dim=(1, 2))
-
-    interaction_strength = torch.mean(torch.abs(attention_scores), dim=(1, 2, 3))
-
+
+    attention_scores_proxy = torch.log(torch.clamp(attention_probs, 1e-10, 1.0))
+    interaction_strength = torch.mean(torch.abs(attention_scores_proxy), dim=(1, 2, 3))
+
     return {
         "attn_entropy": torch.mean(attn_entropy),
         "attn_varentropy": torch.mean(attn_varentropy),
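Note on the call-site implication of this change (not part of the diff itself): calculate_attention_metrics now takes pre-softmaxed attention weights instead of raw scores, so a caller that previously passed scores applies the softmax itself. A minimal usage sketch, assuming attention weights of shape (batch, heads, seq_len, seq_len); the shapes and tensor names below are illustrative, not taken from the commit:

import torch
import torch.nn.functional as F

# Illustrative shapes only: batch=2, heads=4, seq_len=16
scores = torch.randn(2, 4, 16, 16)

# The revised function no longer applies softmax internally,
# so the caller converts scores to weights first.
weights = F.softmax(scores, dim=-1)

metrics = calculate_attention_metrics(weights)
print(metrics["attn_entropy"], metrics["attn_varentropy"])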