diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 43b5cbf8c..5d4a89866 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -638,8 +638,11 @@ def _record_tool_execution( tool_result: The result returned by the tool. user_message_override: Optional custom message to include. """ + # Filter tool input parameters to only include those defined in tool spec + filtered_input = self._filter_tool_parameters_for_recording(tool["name"], tool["input"]) + # Create user message describing the tool call - input_parameters = json.dumps(tool["input"], default=lambda o: f"<>") + input_parameters = json.dumps(filtered_input, default=lambda o: f"<>") user_msg_content: list[ContentBlock] = [ {"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {input_parameters}\n")} @@ -649,6 +652,13 @@ def _record_tool_execution( if user_message_override: user_msg_content.insert(0, {"text": f"{user_message_override}\n"}) + # Create filtered tool use for message history + filtered_tool: ToolUse = { + "toolUseId": tool["toolUseId"], + "name": tool["name"], + "input": filtered_input, + } + # Create the message sequence user_msg: Message = { "role": "user", @@ -656,7 +666,7 @@ def _record_tool_execution( } tool_use_msg: Message = { "role": "assistant", - "content": [{"toolUse": tool}], + "content": [{"toolUse": filtered_tool}], } tool_result_msg: Message = { "role": "user", @@ -713,6 +723,25 @@ def _end_agent_trace_span( self.tracer.end_agent_span(**trace_attributes) + def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: dict[str, Any]) -> dict[str, Any]: + """Filter input parameters to only include those defined in the tool specification. + + Args: + tool_name: Name of the tool to get specification for + input_params: Original input parameters + + Returns: + Filtered parameters containing only those defined in tool spec + """ + all_tools_config = self.tool_registry.get_all_tools_config() + tool_spec = all_tools_config.get(tool_name) + + if not tool_spec or "inputSchema" not in tool_spec: + return input_params.copy() + + properties = tool_spec["inputSchema"]["json"]["properties"] + return {k: v for k, v in input_params.items() if k in properties} + def _append_message(self, message: Message) -> None: """Appends a message to the agent's list of messages and invokes the callbacks for the MessageCreatedEvent.""" self.messages.append(message) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index fdce7c368..07529d28f 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1726,99 +1726,7 @@ def test_agent_tool_non_serializable_parameter_filtering(agent, mock_randint): tool_call_text = user_message["content"][1]["text"] assert "agent.tool.tool_decorated direct tool call." in tool_call_text assert '"random_string": "test_value"' in tool_call_text - assert '"non_serializable_agent": "<>"' in tool_call_text - - -def test_agent_tool_multiple_non_serializable_types(agent, mock_randint): - """Test filtering of various non-serializable object types.""" - mock_randint.return_value = 123 - - # Create various non-serializable objects - class CustomClass: - def __init__(self, value): - self.value = value - - non_serializable_objects = { - "agent": Agent(), - "custom_object": CustomClass("test"), - "function": lambda x: x, - "set_object": {1, 2, 3}, - "complex_number": 3 + 4j, - "serializable_string": "this_should_remain", - "serializable_number": 42, - "serializable_list": [1, 2, 3], - "serializable_dict": {"key": "value"}, - } - - # This should not crash - result = agent.tool.tool_decorated(random_string="test_filtering", **non_serializable_objects) - - # Verify tool executed successfully - expected_result = { - "content": [{"text": "test_filtering"}], - "status": "success", - "toolUseId": "tooluse_tool_decorated_123", - } - assert result == expected_result - - # Check the recorded message for proper parameter filtering - assert len(agent.messages) > 0 - user_message = agent.messages[0] - tool_call_text = user_message["content"][0]["text"] - - # Verify serializable objects remain unchanged - assert '"serializable_string": "this_should_remain"' in tool_call_text - assert '"serializable_number": 42' in tool_call_text - assert '"serializable_list": [1, 2, 3]' in tool_call_text - assert '"serializable_dict": {"key": "value"}' in tool_call_text - - # Verify non-serializable objects are replaced with descriptive strings - assert '"agent": "<>"' in tool_call_text - assert ( - '"custom_object": "<.CustomClass>>"' - in tool_call_text - ) - assert '"function": "<>"' in tool_call_text - assert '"set_object": "<>"' in tool_call_text - assert '"complex_number": "<>"' in tool_call_text - - -def test_agent_tool_serialization_edge_cases(agent, mock_randint): - """Test edge cases in parameter serialization filtering.""" - mock_randint.return_value = 999 - - # Test with None values, empty containers, and nested structures - edge_case_params = { - "none_value": None, - "empty_list": [], - "empty_dict": {}, - "nested_list_with_non_serializable": [1, 2, Agent()], # This should be filtered out - "nested_dict_serializable": {"nested": {"key": "value"}}, # This should remain - } - - result = agent.tool.tool_decorated(random_string="edge_cases", **edge_case_params) - - # Verify successful execution - expected_result = { - "content": [{"text": "edge_cases"}], - "status": "success", - "toolUseId": "tooluse_tool_decorated_999", - } - assert result == expected_result - - # Check parameter filtering in recorded message - assert len(agent.messages) > 0 - user_message = agent.messages[0] - tool_call_text = user_message["content"][0]["text"] - - # Verify serializable values remain - assert '"none_value": null' in tool_call_text - assert '"empty_list": []' in tool_call_text - assert '"empty_dict": {}' in tool_call_text - assert '"nested_dict_serializable": {"nested": {"key": "value"}}' in tool_call_text - - # Verify non-serializable nested structure is replaced - assert '"nested_list_with_non_serializable": [1, 2, "<>"]' in tool_call_text + assert '"non_serializable_agent": "<>"' not in tool_call_text def test_agent_tool_no_non_serializable_parameters(agent, mock_randint): @@ -1870,3 +1778,36 @@ def test_agent_tool_record_direct_tool_call_disabled_with_non_serializable(agent # Verify no messages were recorded assert len(agent.messages) == 0 + + +def test_agent_tool_call_parameter_filtering_integration(mock_randint): + """Test that tool calls properly filter parameters in message recording.""" + mock_randint.return_value = 42 + + @strands.tool + def test_tool(action: str) -> str: + """Test tool with single parameter.""" + return action + + agent = Agent(tools=[test_tool]) + + # Call tool with extra non-spec parameters + result = agent.tool.test_tool( + action="test_value", + agent=agent, # Should be filtered out + extra_param="filtered", # Should be filtered out + ) + + # Verify tool executed successfully + assert result["status"] == "success" + assert result["content"] == [{"text": "test_value"}] + + # Check that only spec parameters are recorded in message history + assert len(agent.messages) > 0 + user_message = agent.messages[0] + tool_call_text = user_message["content"][0]["text"] + + # Should only contain the 'action' parameter + assert '"action": "test_value"' in tool_call_text + assert '"agent"' not in tool_call_text + assert '"extra_param"' not in tool_call_text