Skip to content

Commit a8d56ff

Browse files
authored
Merge pull request #81 from codelion/fix-litellm-wrapper-for-claude
Fix litellm wrapper for claude
2 parents c74902d + 0ebae20 commit a8d56ff

File tree

3 files changed

+25
-9
lines changed

3 files changed

+25
-9
lines changed

optillm/entropy_decoding.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,30 @@ def calculate_varentropy_logsoftmax(logits: torch.Tensor, axis: int = -1) -> Tup
2626
varentropy = torch.sum(probs * (log_probs / LN_2 + entropy.unsqueeze(-1))**2, dim=axis)
2727
return entropy, varentropy
2828

29-
def calculate_attention_metrics(attention_scores: torch.Tensor) -> Dict[str, torch.Tensor]:
30-
attention_probs = F.softmax(attention_scores, dim=-1)
29+
def calculate_attention_metrics(attention_weights: torch.Tensor) -> Dict[str, torch.Tensor]:
30+
attention_probs = attention_weights
31+
32+
# Calculate entropy
3133
attn_entropy = -torch.sum(attention_probs * torch.log2(torch.clamp(attention_probs, 1e-10, 1.0)), dim=-1)
32-
attn_varentropy = torch.var(attn_entropy, dim=-1)
3334

34-
attn_varentropy = torch.where(torch.isnan(attn_varentropy), torch.zeros_like(attn_varentropy), attn_varentropy)
35+
# Calculate variance of entropy with unbiased=False to avoid df issues
36+
# Also add a check for singleton dimensions
37+
if attn_entropy.size(-1) > 1:
38+
attn_varentropy = torch.var(attn_entropy, dim=-1, unbiased=False)
39+
else:
40+
attn_varentropy = torch.zeros_like(attn_entropy)
41+
42+
attn_varentropy = torch.where(torch.isnan(attn_varentropy),
43+
torch.zeros_like(attn_varentropy),
44+
attn_varentropy)
45+
46+
# Rest remains the same
3547
mean_attention = torch.mean(attention_probs, dim=1)
3648
agreement = torch.mean(torch.abs(attention_probs - mean_attention.unsqueeze(1)), dim=(1, 2))
37-
38-
interaction_strength = torch.mean(torch.abs(attention_scores), dim=(1, 2, 3))
39-
49+
50+
attention_scores_proxy = torch.log(torch.clamp(attention_probs, 1e-10, 1.0))
51+
interaction_strength = torch.mean(torch.abs(attention_scores_proxy), dim=(1, 2, 3))
52+
4053
return {
4154
"attn_entropy": torch.mean(attn_entropy),
4255
"attn_varentropy": torch.mean(attn_varentropy),

optillm/litellm_wrapper.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ class Chat:
2424
class Completions:
2525
@staticmethod
2626
def create(model: str, messages: List[Dict[str, str]], **kwargs):
27-
response = completion(model=model, messages=messages, **kwargs, safety_settings=SAFETY_SETTINGS)
27+
if model.startswith("gemini"):
28+
response = completion(model=model, messages=messages, **kwargs, safety_settings=SAFETY_SETTINGS)
29+
else:
30+
response = completion(model=model, messages=messages, **kwargs)
2831
# Convert LiteLLM response to match OpenAI response structure
2932
return response
3033

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name="optillm",
5-
version="0.0.6",
5+
version="0.0.7",
66
packages=find_packages(),
77
py_modules=['optillm'],
88
package_data={

0 commit comments

Comments (0)