Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,8 +638,11 @@ def _record_tool_execution(
tool_result: The result returned by the tool.
user_message_override: Optional custom message to include.
"""
# Filter tool input parameters to only include those defined in tool spec
filtered_input = self._filter_tool_parameters_for_recording(tool["name"], tool["input"])

# Create user message describing the tool call
input_parameters = json.dumps(tool["input"], default=lambda o: f"<<non-serializable: {type(o).__qualname__}>>")
input_parameters = json.dumps(filtered_input, default=lambda o: f"<<non-serializable: {type(o).__qualname__}>>")

user_msg_content: list[ContentBlock] = [
{"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {input_parameters}\n")}
Expand All @@ -649,14 +652,21 @@ def _record_tool_execution(
if user_message_override:
user_msg_content.insert(0, {"text": f"{user_message_override}\n"})

# Create filtered tool use for message history
filtered_tool: ToolUse = {
"toolUseId": tool["toolUseId"],
"name": tool["name"],
"input": filtered_input,
}

# Create the message sequence
user_msg: Message = {
"role": "user",
"content": user_msg_content,
}
tool_use_msg: Message = {
"role": "assistant",
"content": [{"toolUse": tool}],
"content": [{"toolUse": filtered_tool}],
}
tool_result_msg: Message = {
"role": "user",
Expand Down Expand Up @@ -713,6 +723,25 @@ def _end_agent_trace_span(

self.tracer.end_agent_span(**trace_attributes)

def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: dict[str, Any]) -> dict[str, Any]:
"""Filter input parameters to only include those defined in the tool specification.

Args:
tool_name: Name of the tool to get specification for
input_params: Original input parameters

Returns:
Filtered parameters containing only those defined in tool spec
"""
all_tools_config = self.tool_registry.get_all_tools_config()
tool_spec = all_tools_config.get(tool_name)

if not tool_spec or "inputSchema" not in tool_spec:
return input_params.copy()

properties = tool_spec["inputSchema"]["json"]["properties"]
return {k: v for k, v in input_params.items() if k in properties}

def _append_message(self, message: Message) -> None:
"""Appends a message to the agent's list of messages and invokes the callbacks for the MessageCreatedEvent."""
self.messages.append(message)
Expand Down
127 changes: 34 additions & 93 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1726,99 +1726,7 @@ def test_agent_tool_non_serializable_parameter_filtering(agent, mock_randint):
tool_call_text = user_message["content"][1]["text"]
assert "agent.tool.tool_decorated direct tool call." in tool_call_text
assert '"random_string": "test_value"' in tool_call_text
assert '"non_serializable_agent": "<<non-serializable: Agent>>"' in tool_call_text


def test_agent_tool_multiple_non_serializable_types(agent, mock_randint):
"""Test filtering of various non-serializable object types."""
mock_randint.return_value = 123

# Create various non-serializable objects
class CustomClass:
def __init__(self, value):
self.value = value

non_serializable_objects = {
"agent": Agent(),
"custom_object": CustomClass("test"),
"function": lambda x: x,
"set_object": {1, 2, 3},
"complex_number": 3 + 4j,
"serializable_string": "this_should_remain",
"serializable_number": 42,
"serializable_list": [1, 2, 3],
"serializable_dict": {"key": "value"},
}

# This should not crash
result = agent.tool.tool_decorated(random_string="test_filtering", **non_serializable_objects)

# Verify tool executed successfully
expected_result = {
"content": [{"text": "test_filtering"}],
"status": "success",
"toolUseId": "tooluse_tool_decorated_123",
}
assert result == expected_result

# Check the recorded message for proper parameter filtering
assert len(agent.messages) > 0
user_message = agent.messages[0]
tool_call_text = user_message["content"][0]["text"]

# Verify serializable objects remain unchanged
assert '"serializable_string": "this_should_remain"' in tool_call_text
assert '"serializable_number": 42' in tool_call_text
assert '"serializable_list": [1, 2, 3]' in tool_call_text
assert '"serializable_dict": {"key": "value"}' in tool_call_text

# Verify non-serializable objects are replaced with descriptive strings
assert '"agent": "<<non-serializable: Agent>>"' in tool_call_text
assert (
'"custom_object": "<<non-serializable: test_agent_tool_multiple_non_serializable_types.<locals>.CustomClass>>"'
in tool_call_text
)
assert '"function": "<<non-serializable: function>>"' in tool_call_text
assert '"set_object": "<<non-serializable: set>>"' in tool_call_text
assert '"complex_number": "<<non-serializable: complex>>"' in tool_call_text


def test_agent_tool_serialization_edge_cases(agent, mock_randint):
"""Test edge cases in parameter serialization filtering."""
mock_randint.return_value = 999

# Test with None values, empty containers, and nested structures
edge_case_params = {
"none_value": None,
"empty_list": [],
"empty_dict": {},
"nested_list_with_non_serializable": [1, 2, Agent()], # This should be filtered out
"nested_dict_serializable": {"nested": {"key": "value"}}, # This should remain
}

result = agent.tool.tool_decorated(random_string="edge_cases", **edge_case_params)

# Verify successful execution
expected_result = {
"content": [{"text": "edge_cases"}],
"status": "success",
"toolUseId": "tooluse_tool_decorated_999",
}
assert result == expected_result

# Check parameter filtering in recorded message
assert len(agent.messages) > 0
user_message = agent.messages[0]
tool_call_text = user_message["content"][0]["text"]

# Verify serializable values remain
assert '"none_value": null' in tool_call_text
assert '"empty_list": []' in tool_call_text
assert '"empty_dict": {}' in tool_call_text
assert '"nested_dict_serializable": {"nested": {"key": "value"}}' in tool_call_text

# Verify non-serializable nested structure is replaced
assert '"nested_list_with_non_serializable": [1, 2, "<<non-serializable: Agent>>"]' in tool_call_text
assert '"non_serializable_agent": "<<non-serializable: Agent>>"' not in tool_call_text


def test_agent_tool_no_non_serializable_parameters(agent, mock_randint):
Expand Down Expand Up @@ -1870,3 +1778,36 @@ def test_agent_tool_record_direct_tool_call_disabled_with_non_serializable(agent

# Verify no messages were recorded
assert len(agent.messages) == 0


def test_agent_tool_call_parameter_filtering_integration(mock_randint):
"""Test that tool calls properly filter parameters in message recording."""
mock_randint.return_value = 42

@strands.tool
def test_tool(action: str) -> str:
"""Test tool with single parameter."""
return action

agent = Agent(tools=[test_tool])

# Call tool with extra non-spec parameters
result = agent.tool.test_tool(
action="test_value",
agent=agent, # Should be filtered out
extra_param="filtered", # Should be filtered out
)

# Verify tool executed successfully
assert result["status"] == "success"
assert result["content"] == [{"text": "test_value"}]

# Check that only spec parameters are recorded in message history
assert len(agent.messages) > 0
user_message = agent.messages[0]
tool_call_text = user_message["content"][0]["text"]

# Should only contain the 'action' parameter
assert '"action": "test_value"' in tool_call_text
assert '"agent"' not in tool_call_text
assert '"extra_param"' not in tool_call_text
Loading