13 changes: 13 additions & 0 deletions src/strands/models/litellm.py
@@ -50,6 +50,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config:
"""
self.client_args = client_args or {}
self.config = dict(model_config)
self._apply_proxy_prefix()

logger.debug("config=<%s> | initializing", self.config)

@@ -61,6 +62,7 @@ def update_config(self, **model_config: Unpack[LiteLLMConfig]) -> None:  # type:
**model_config: Configuration overrides.
"""
self.config.update(model_config)
self._apply_proxy_prefix()

@override
def get_config(self) -> LiteLLMConfig:
@@ -223,3 +225,14 @@ async def structured_output(

# If no tool_calls found, raise an error
raise ValueError("No tool_calls found in response")

def _apply_proxy_prefix(self) -> None:
"""Apply litellm_proxy/ prefix to model_id when use_litellm_proxy is True.

This is a workaround for https://github.com/BerriAI/litellm/issues/13454
where the use_litellm_proxy parameter is not honored.
"""
if self.client_args.get("use_litellm_proxy") and "model_id" in self.config:
model_id = self.get_config()["model_id"]
if not model_id.startswith("litellm_proxy/"):
self.config["model_id"] = f"litellm_proxy/{model_id}"
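For reference, a minimal usage sketch of the prefixing behavior, mirroring the parametrized cases in the tests below. The import path is assumed from the file layout, and the model IDs are purely illustrative:

```python
from strands.models.litellm import LiteLLMModel

# With use_litellm_proxy enabled, the prefix is applied at construction time.
model = LiteLLMModel(client_args={"use_litellm_proxy": True}, model_id="openai/gpt-4")
assert model.get_config()["model_id"] == "litellm_proxy/openai/gpt-4"

# update_config re-applies the prefix to a newly supplied model_id.
model.update_config(model_id="anthropic/claude-3")
assert model.get_config()["model_id"] == "litellm_proxy/anthropic/claude-3"

# An already-prefixed model_id is left unchanged.
model = LiteLLMModel(client_args={"use_litellm_proxy": True}, model_id="litellm_proxy/openai/gpt-4")
assert model.get_config()["model_id"] == "litellm_proxy/openai/gpt-4"
```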
33 changes: 33 additions & 0 deletions tests/strands/models/test_litellm.py
@@ -58,6 +58,39 @@ def test_update_config(model, model_id):
assert tru_model_id == exp_model_id


@pytest.mark.parametrize(
"client_args, model_id, expected_model_id",
[
({"use_litellm_proxy": True}, "openai/gpt-4", "litellm_proxy/openai/gpt-4"),
({"use_litellm_proxy": False}, "openai/gpt-4", "openai/gpt-4"),
({"use_litellm_proxy": None}, "openai/gpt-4", "openai/gpt-4"),
({}, "openai/gpt-4", "openai/gpt-4"),
(None, "openai/gpt-4", "openai/gpt-4"),
({"use_litellm_proxy": True}, "litellm_proxy/openai/gpt-4", "litellm_proxy/openai/gpt-4"),
({"use_litellm_proxy": False}, "litellm_proxy/openai/gpt-4", "litellm_proxy/openai/gpt-4"),
],
)
def test__init__use_litellm_proxy_prefix(client_args, model_id, expected_model_id):
"""Test litellm_proxy prefix behavior for various configurations."""
model = LiteLLMModel(client_args=client_args, model_id=model_id)
assert model.get_config()["model_id"] == expected_model_id


@pytest.mark.parametrize(
"client_args, initial_model_id, new_model_id, expected_model_id",
[
({"use_litellm_proxy": True}, "openai/gpt-4", "anthropic/claude-3", "litellm_proxy/anthropic/claude-3"),
({"use_litellm_proxy": False}, "openai/gpt-4", "anthropic/claude-3", "anthropic/claude-3"),
(None, "openai/gpt-4", "anthropic/claude-3", "anthropic/claude-3"),
],
)
def test_update_config_proxy_prefix(client_args, initial_model_id, new_model_id, expected_model_id):
"""Test that update_config applies proxy prefix correctly."""
model = LiteLLMModel(client_args=client_args, model_id=initial_model_id)
model.update_config(model_id=new_model_id)
assert model.get_config()["model_id"] == expected_model_id


@pytest.mark.parametrize(
"content, exp_result",
[