diff --git a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py index 939313a..c4d8dc3 100644 --- a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py +++ b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py @@ -712,9 +712,9 @@ def messages_to_oci_params( ) else: oci_message = self.oci_chat_message[role](content=tool_content) - elif isinstance(message, AIMessage) and message.additional_kwargs.get( - "tool_calls" - ): + elif isinstance(message, AIMessage) and ( + message.tool_calls or + message.additional_kwargs.get("tool_calls")): # Process content and tool calls for assistant messages content = self._process_message_content(message.content) tool_calls = [] diff --git a/libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py b/libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py index d33adde..1a6649a 100644 --- a/libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py +++ b/libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py @@ -6,9 +6,9 @@ from unittest.mock import MagicMock import pytest -from langchain_core.messages import HumanMessage from pytest import MonkeyPatch +from langchain_core.messages import HumanMessage, AIMessage from langchain_oci.chat_models.oci_generative_ai import ChatOCIGenAI @@ -575,6 +575,165 @@ def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def] assert response["parsed"].conditions == "Sunny" +@pytest.mark.requires("oci") +def test_ai_message_tool_calls_direct_field(monkeypatch: MonkeyPatch) -> None: + """Test AIMessage with tool_calls in the direct tool_calls field.""" + + oci_gen_ai_client = MagicMock() + llm = ChatOCIGenAI(model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client) + + # Track if the tool_calls processing branch is executed + tool_calls_processed = False + + def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def] + nonlocal tool_calls_processed + # Check if the request contains tool_calls in the message + request = args[0] + if hasattr(request, 'chat_request') and hasattr(request.chat_request, 'messages'): + for msg in request.chat_request.messages: + if hasattr(msg, 'tool_calls') and msg.tool_calls: + tool_calls_processed = True + break + return MockResponseDict( + { + "status": 200, + "data": MockResponseDict( + { + "chat_response": MockResponseDict( + { + "api_format": "GENERIC", + "choices": [ + MockResponseDict( + { + "message": MockResponseDict( + { + "role": "ASSISTANT", + "name": None, + "content": [ + MockResponseDict( + { + "text": ( + "I'll help you." + ), + "type": "TEXT", + } + ) + ], + "tool_calls": [], + } + ), + "finish_reason": "completed", + } + ) + ], + "time_created": "2025-08-14T10:00:01.100000+00:00", + } + ), + "model_id": "meta.llama-3.3-70b-instruct", + "model_version": "1.0.0", + } + ), + "request_id": "1234567890", + "headers": MockResponseDict({"content-length": "123"}), + } + ) + + monkeypatch.setattr(llm.client, "chat", mocked_response) + + # Create AIMessage with tool_calls in the direct tool_calls field + ai_message = AIMessage( + content="I need to call a function", + tool_calls=[ + { + "id": "call_123", + "name": "get_weather", + "args": {"location": "San Francisco"}, + } + ] + ) + + messages = [ai_message] + + # This should not raise an error and should process the tool_calls correctly + response = llm.invoke(messages) + assert response.content == "I'll help you." + + +@pytest.mark.requires("oci") +def test_ai_message_tool_calls_additional_kwargs(monkeypatch: MonkeyPatch) -> None: + """Test AIMessage with tool_calls in additional_kwargs field.""" + + oci_gen_ai_client = MagicMock() + llm = ChatOCIGenAI(model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client) + + def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def] + return MockResponseDict( + { + "status": 200, + "data": MockResponseDict( + { + "chat_response": MockResponseDict( + { + "api_format": "GENERIC", + "choices": [ + MockResponseDict( + { + "message": MockResponseDict( + { + "role": "ASSISTANT", + "name": None, + "content": [ + MockResponseDict( + { + "text": ( + "I'll help you." + ), + "type": "TEXT", + } + ) + ], + "tool_calls": [], + } + ), + "finish_reason": "completed", + } + ) + ], + "time_created": "2025-08-14T10:00:01.100000+00:00", + } + ), + "model_id": "meta.llama-3.3-70b-instruct", + "model_version": "1.0.0", + } + ), + "request_id": "1234567890", + "headers": MockResponseDict({"content-length": "123"}), + } + ) + + monkeypatch.setattr(llm.client, "chat", mocked_response) + + # Create AIMessage with tool_calls in additional_kwargs + ai_message = AIMessage( + content="I need to call a function", + additional_kwargs={ + "tool_calls": [ + { + "id": "call_456", + "name": "get_weather", + "args": {"location": "New York"}, + } + ] + } + ) + + messages = [ai_message] + + # This should not raise an error and should process the tool_calls correctly + response = llm.invoke(messages) + assert response.content == "I'll help you." + + def test_get_provider(): """Test determining the provider based on the model_id.""" model_provider_map = {