113 changes: 1 addition & 112 deletions tests/unittests/models/test_litellm.py
@@ -19,7 +19,6 @@
import warnings

from google.adk.models.lite_llm import _content_to_message_param
from google.adk.models.lite_llm import _FINISH_REASON_MAPPING
from google.adk.models.lite_llm import _function_declaration_to_tool_param
from google.adk.models.lite_llm import _get_content
from google.adk.models.lite_llm import _message_to_generate_content_response
@@ -2019,114 +2018,4 @@ def test_non_gemini_litellm_no_warning():
warnings.simplefilter("always")
# Test with non-Gemini model
LiteLlm(model="openai/gpt-4o")
assert len(w) == 0


@pytest.mark.parametrize(
"finish_reason,response_content,expected_content,has_tool_calls",
[
("length", "Test response", "Test response", False),
("stop", "Complete response", "Complete response", False),
(
"tool_calls",
"",
"",
True,
),
("content_filter", "", "", False),
],
ids=["length", "stop", "tool_calls", "content_filter"],
)
@pytest.mark.asyncio
async def test_finish_reason_propagation(
mock_acompletion,
lite_llm_instance,
finish_reason,
response_content,
expected_content,
has_tool_calls,
):
"""Test that finish_reason is properly propagated from LiteLLM response."""
tool_calls = None
if has_tool_calls:
tool_calls = [
ChatCompletionMessageToolCall(
type="function",
id="test_id",
function=Function(
name="test_function",
arguments='{"arg": "value"}',
),
)
]

mock_response = ModelResponse(
choices=[
Choices(
message=ChatCompletionAssistantMessage(
role="assistant",
content=response_content,
tool_calls=tool_calls,
),
finish_reason=finish_reason,
)
]
)
mock_acompletion.return_value = mock_response

llm_request = LlmRequest(
contents=[
types.Content(
role="user", parts=[types.Part.from_text(text="Test prompt")]
)
],
)

async for response in lite_llm_instance.generate_content_async(llm_request):
assert response.content.role == "model"
# Verify finish_reason is mapped to FinishReason enum
assert isinstance(response.finish_reason, types.FinishReason)
# Verify correct enum mapping using the actual mapping from lite_llm
assert response.finish_reason == _FINISH_REASON_MAPPING[finish_reason]
if expected_content:
assert response.content.parts[0].text == expected_content
if has_tool_calls:
assert len(response.content.parts) > 0
assert response.content.parts[-1].function_call.name == "test_function"

mock_acompletion.assert_called_once()


@pytest.mark.asyncio
async def test_finish_reason_unknown_maps_to_other(
mock_acompletion, lite_llm_instance
):
"""Test that unknown finish_reason values map to FinishReason.OTHER."""
mock_response = ModelResponse(
choices=[
Choices(
message=ChatCompletionAssistantMessage(
role="assistant",
content="Test response",
),
finish_reason="unknown_reason_type",
)
]
)
mock_acompletion.return_value = mock_response

llm_request = LlmRequest(
contents=[
types.Content(
role="user", parts=[types.Part.from_text(text="Test prompt")]
)
],
)

async for response in lite_llm_instance.generate_content_async(llm_request):
assert response.content.role == "model"
# Unknown finish_reason should map to OTHER
assert isinstance(response.finish_reason, types.FinishReason)
assert response.finish_reason == types.FinishReason.OTHER

mock_acompletion.assert_called_once()
assert len(w) == 0
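
Note on the removed tests: they asserted that LiteLLM finish_reason strings are converted to types.FinishReason through a _FINISH_REASON_MAPPING dict in google.adk.models.lite_llm, with unknown values falling back to FinishReason.OTHER. The sketch below is a minimal illustration of that shape, not the module's actual code: the specific enum targets (other than the OTHER fallback asserted in the deleted test) are assumptions, and map_finish_reason is a hypothetical helper name.

from google.genai import types

# Hypothetical sketch of a finish-reason mapping like the one the removed
# tests imported. The enum targets chosen here are assumptions, except that
# any unknown reason falls back to FinishReason.OTHER, as the deleted
# test_finish_reason_unknown_maps_to_other asserted.
_FINISH_REASON_MAPPING = {
    "stop": types.FinishReason.STOP,
    "length": types.FinishReason.MAX_TOKENS,
    "tool_calls": types.FinishReason.STOP,
    "content_filter": types.FinishReason.SAFETY,
}


def map_finish_reason(finish_reason: str) -> types.FinishReason:
    """Maps a LiteLLM finish_reason string to a google.genai FinishReason."""
    return _FINISH_REASON_MAPPING.get(finish_reason, types.FinishReason.OTHER)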