diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 3bde5c832..b48664b6c 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -19,11 +19,11 @@ import time from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field -from typing import Any, Callable, Tuple, cast +from typing import Any, Callable, Tuple from opentelemetry import trace as trace_api -from ..agent import Agent, AgentResult +from ..agent import Agent from ..telemetry import get_tracer from ..types.content import ContentBlock from ..types.event_loop import Metrics, Usage @@ -379,15 +379,7 @@ async def _execute_node(self, node: GraphNode) -> None: ) elif isinstance(node.executor, Agent): - agent_response: AgentResult | None = ( - None # Initialize with None to handle case where no result is yielded - ) - async for event in node.executor.stream_async(node_input): - if "result" in event: - agent_response = cast(AgentResult, event["result"]) - - if not agent_response: - raise ValueError(f"Node '{node.node_id}' did not return a result") + agent_response = await node.executor.invoke_async(node_input) # Extract metrics from agent response usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 824e08819..c4f8fcdb5 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -19,7 +19,7 @@ import time from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field -from typing import Any, Callable, Tuple, cast +from typing import Any, Callable, Tuple from ..agent import Agent, AgentResult from ..agent.state import AgentState @@ -601,12 +601,7 @@ async def _execute_node(self, node: SwarmNode, task: str | list[ContentBlock]) - # Execute node result = None node.reset_executor_state() - async for event in node.executor.stream_async(node_input): - if "result" in event: - result = cast(AgentResult, event["result"]) - - if not result: - raise ValueError(f"Node '{node_name}' did not return a result") + result = await node.executor.invoke_async(node_input) execution_time = round((time.time() - start_time) * 1000) diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index b3f2e7020..76aeb6c70 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -29,10 +29,10 @@ def create_mock_agent(name, response_text="Default response", metrics=None, agen agent.return_value = mock_result agent.__call__ = Mock(return_value=mock_result) - async def mock_stream_async(*args, **kwargs): - yield {"result": mock_result} + async def mock_invoke_async(*args, **kwargs): + return mock_result - agent.stream_async = MagicMock(side_effect=mock_stream_async) + agent.invoke_async = MagicMock(side_effect=mock_invoke_async) return agent @@ -194,14 +194,14 @@ async def test_graph_execution(mock_strands_tracer, mock_use_span, mock_graph, m assert result.execution_order[0].node_id == "start_agent" # Verify agent calls - mock_agents["start_agent"].stream_async.assert_called_once() + mock_agents["start_agent"].invoke_async.assert_called_once() mock_agents["multi_agent"].invoke_async.assert_called_once() - mock_agents["conditional_agent"].stream_async.assert_called_once() - mock_agents["final_agent"].stream_async.assert_called_once() - mock_agents["no_metrics_agent"].stream_async.assert_called_once() - mock_agents["partial_metrics_agent"].stream_async.assert_called_once() - string_content_agent.stream_async.assert_called_once() - mock_agents["blocked_agent"].stream_async.assert_not_called() + mock_agents["conditional_agent"].invoke_async.assert_called_once() + mock_agents["final_agent"].invoke_async.assert_called_once() + mock_agents["no_metrics_agent"].invoke_async.assert_called_once() + mock_agents["partial_metrics_agent"].invoke_async.assert_called_once() + string_content_agent.invoke_async.assert_called_once() + mock_agents["blocked_agent"].invoke_async.assert_not_called() # Verify metrics aggregation assert result.accumulated_usage["totalTokens"] > 0 @@ -261,12 +261,10 @@ async def test_graph_execution_with_failures(mock_strands_tracer, mock_use_span) failing_agent.id = "fail_node" failing_agent.__call__ = Mock(side_effect=Exception("Simulated failure")) - # Create a proper failing async generator for stream_async - async def mock_stream_failure(*args, **kwargs): + async def mock_invoke_failure(*args, **kwargs): raise Exception("Simulated failure") - yield # This will never be reached - failing_agent.stream_async = mock_stream_failure + failing_agent.invoke_async = mock_invoke_failure success_agent = create_mock_agent("success_agent", "Success") @@ -301,7 +299,7 @@ async def test_graph_edge_cases(mock_strands_tracer, mock_use_span): result = await graph.invoke_async([{"text": "Original task"}]) # Verify entry node was called with original task - entry_agent.stream_async.assert_called_once_with([{"text": "Original task"}]) + entry_agent.invoke_async.assert_called_once_with([{"text": "Original task"}]) assert result.status == Status.COMPLETED mock_strands_tracer.start_multiagent_span.assert_called() mock_use_span.assert_called_once() @@ -482,8 +480,8 @@ def test_graph_synchronous_execution(mock_strands_tracer, mock_use_span, mock_ag assert result.execution_order[1].node_id == "final_agent" # Verify agent calls - mock_agents["start_agent"].stream_async.assert_called_once() - mock_agents["final_agent"].stream_async.assert_called_once() + mock_agents["start_agent"].invoke_async.assert_called_once() + mock_agents["final_agent"].invoke_async.assert_called_once() # Verify return type is GraphResult assert isinstance(result, GraphResult) diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index ffb0343b2..69dd5273b 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -56,11 +56,10 @@ def create_mock_result(): agent.return_value = create_mock_result() agent.__call__ = Mock(side_effect=create_mock_result) - async def mock_stream_async(*args, **kwargs): - result = create_mock_result() - yield {"result": result} + async def mock_invoke_async(*args, **kwargs): + return create_mock_result() - agent.stream_async = MagicMock(side_effect=mock_stream_async) + agent.invoke_async = MagicMock(side_effect=mock_invoke_async) return agent @@ -227,7 +226,7 @@ async def test_swarm_execution_async(mock_swarm, mock_agents): assert len(result.results) == 1 # Verify agent was called - mock_agents["coordinator"].stream_async.assert_called() + mock_agents["coordinator"].invoke_async.assert_called() # Verify metrics aggregation assert result.accumulated_usage["totalTokens"] >= 0 @@ -264,7 +263,7 @@ def test_swarm_synchronous_execution(mock_agents): assert result.execution_time >= 0 # Verify agent was called - mock_agents["coordinator"].stream_async.assert_called() + mock_agents["coordinator"].invoke_async.assert_called() # Verify return type is SwarmResult assert isinstance(result, SwarmResult) @@ -350,11 +349,10 @@ def create_handoff_result(): agent.return_value = create_handoff_result() agent.__call__ = Mock(side_effect=create_handoff_result) - async def mock_stream_async(*args, **kwargs): - result = create_handoff_result() - yield {"result": result} + async def mock_invoke_async(*args, **kwargs): + return create_handoff_result() - agent.stream_async = MagicMock(side_effect=mock_stream_async) + agent.invoke_async = MagicMock(side_effect=mock_invoke_async) return agent # Create agents - first one hands off, second one completes @@ -381,8 +379,8 @@ async def mock_stream_async(*args, **kwargs): assert result.node_history[1].node_id == "completion_agent" # Verify both agents were called - handoff_agent.stream_async.assert_called() - completion_agent.stream_async.assert_called() + handoff_agent.invoke_async.assert_called() + completion_agent.invoke_async.assert_called() # Test handoff when task is already completed completed_swarm = Swarm(nodes=[handoff_agent, completion_agent])