Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Any, Callable, Tuple, cast
from typing import Any, Callable, Tuple

from opentelemetry import trace as trace_api

from ..agent import Agent, AgentResult
from ..agent import Agent
from ..telemetry import get_tracer
from ..types.content import ContentBlock
from ..types.event_loop import Metrics, Usage
Expand Down Expand Up @@ -379,15 +379,7 @@ async def _execute_node(self, node: GraphNode) -> None:
)

elif isinstance(node.executor, Agent):
agent_response: AgentResult | None = (
None # Initialize with None to handle case where no result is yielded
)
async for event in node.executor.stream_async(node_input):
if "result" in event:
agent_response = cast(AgentResult, event["result"])

if not agent_response:
raise ValueError(f"Node '{node.node_id}' did not return a result")
agent_response = await node.executor.invoke_async(node_input)

# Extract metrics from agent response
usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0)
Expand Down
9 changes: 2 additions & 7 deletions src/strands/multiagent/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Any, Callable, Tuple, cast
from typing import Any, Callable, Tuple

from ..agent import Agent, AgentResult
from ..agent.state import AgentState
Expand Down Expand Up @@ -601,12 +601,7 @@ async def _execute_node(self, node: SwarmNode, task: str | list[ContentBlock]) -
# Execute node
result = None
node.reset_executor_state()
async for event in node.executor.stream_async(node_input):
if "result" in event:
result = cast(AgentResult, event["result"])

if not result:
raise ValueError(f"Node '{node_name}' did not return a result")
result = await node.executor.invoke_async(node_input)

execution_time = round((time.time() - start_time) * 1000)

Expand Down
32 changes: 15 additions & 17 deletions tests/strands/multiagent/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ def create_mock_agent(name, response_text="Default response", metrics=None, agen
agent.return_value = mock_result
agent.__call__ = Mock(return_value=mock_result)

async def mock_stream_async(*args, **kwargs):
yield {"result": mock_result}
async def mock_invoke_async(*args, **kwargs):
return mock_result

agent.stream_async = MagicMock(side_effect=mock_stream_async)
agent.invoke_async = MagicMock(side_effect=mock_invoke_async)

return agent

Expand Down Expand Up @@ -194,14 +194,14 @@ async def test_graph_execution(mock_strands_tracer, mock_use_span, mock_graph, m
assert result.execution_order[0].node_id == "start_agent"

# Verify agent calls
mock_agents["start_agent"].stream_async.assert_called_once()
mock_agents["start_agent"].invoke_async.assert_called_once()
mock_agents["multi_agent"].invoke_async.assert_called_once()
mock_agents["conditional_agent"].stream_async.assert_called_once()
mock_agents["final_agent"].stream_async.assert_called_once()
mock_agents["no_metrics_agent"].stream_async.assert_called_once()
mock_agents["partial_metrics_agent"].stream_async.assert_called_once()
string_content_agent.stream_async.assert_called_once()
mock_agents["blocked_agent"].stream_async.assert_not_called()
mock_agents["conditional_agent"].invoke_async.assert_called_once()
mock_agents["final_agent"].invoke_async.assert_called_once()
mock_agents["no_metrics_agent"].invoke_async.assert_called_once()
mock_agents["partial_metrics_agent"].invoke_async.assert_called_once()
string_content_agent.invoke_async.assert_called_once()
mock_agents["blocked_agent"].invoke_async.assert_not_called()

# Verify metrics aggregation
assert result.accumulated_usage["totalTokens"] > 0
Expand Down Expand Up @@ -261,12 +261,10 @@ async def test_graph_execution_with_failures(mock_strands_tracer, mock_use_span)
failing_agent.id = "fail_node"
failing_agent.__call__ = Mock(side_effect=Exception("Simulated failure"))

# Create a proper failing async generator for stream_async
async def mock_stream_failure(*args, **kwargs):
async def mock_invoke_failure(*args, **kwargs):
raise Exception("Simulated failure")
yield # This will never be reached

failing_agent.stream_async = mock_stream_failure
failing_agent.invoke_async = mock_invoke_failure

success_agent = create_mock_agent("success_agent", "Success")

Expand Down Expand Up @@ -301,7 +299,7 @@ async def test_graph_edge_cases(mock_strands_tracer, mock_use_span):
result = await graph.invoke_async([{"text": "Original task"}])

# Verify entry node was called with original task
entry_agent.stream_async.assert_called_once_with([{"text": "Original task"}])
entry_agent.invoke_async.assert_called_once_with([{"text": "Original task"}])
assert result.status == Status.COMPLETED
mock_strands_tracer.start_multiagent_span.assert_called()
mock_use_span.assert_called_once()
Expand Down Expand Up @@ -482,8 +480,8 @@ def test_graph_synchronous_execution(mock_strands_tracer, mock_use_span, mock_ag
assert result.execution_order[1].node_id == "final_agent"

# Verify agent calls
mock_agents["start_agent"].stream_async.assert_called_once()
mock_agents["final_agent"].stream_async.assert_called_once()
mock_agents["start_agent"].invoke_async.assert_called_once()
mock_agents["final_agent"].invoke_async.assert_called_once()

# Verify return type is GraphResult
assert isinstance(result, GraphResult)
Expand Down
22 changes: 10 additions & 12 deletions tests/strands/multiagent/test_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,10 @@ def create_mock_result():
agent.return_value = create_mock_result()
agent.__call__ = Mock(side_effect=create_mock_result)

async def mock_stream_async(*args, **kwargs):
result = create_mock_result()
yield {"result": result}
async def mock_invoke_async(*args, **kwargs):
return create_mock_result()

agent.stream_async = MagicMock(side_effect=mock_stream_async)
agent.invoke_async = MagicMock(side_effect=mock_invoke_async)

return agent

Expand Down Expand Up @@ -227,7 +226,7 @@ async def test_swarm_execution_async(mock_swarm, mock_agents):
assert len(result.results) == 1

# Verify agent was called
mock_agents["coordinator"].stream_async.assert_called()
mock_agents["coordinator"].invoke_async.assert_called()

# Verify metrics aggregation
assert result.accumulated_usage["totalTokens"] >= 0
Expand Down Expand Up @@ -264,7 +263,7 @@ def test_swarm_synchronous_execution(mock_agents):
assert result.execution_time >= 0

# Verify agent was called
mock_agents["coordinator"].stream_async.assert_called()
mock_agents["coordinator"].invoke_async.assert_called()

# Verify return type is SwarmResult
assert isinstance(result, SwarmResult)
Expand Down Expand Up @@ -350,11 +349,10 @@ def create_handoff_result():
agent.return_value = create_handoff_result()
agent.__call__ = Mock(side_effect=create_handoff_result)

async def mock_stream_async(*args, **kwargs):
result = create_handoff_result()
yield {"result": result}
async def mock_invoke_async(*args, **kwargs):
return create_handoff_result()

agent.stream_async = MagicMock(side_effect=mock_stream_async)
agent.invoke_async = MagicMock(side_effect=mock_invoke_async)
return agent

# Create agents - first one hands off, second one completes
Expand All @@ -381,8 +379,8 @@ async def mock_stream_async(*args, **kwargs):
assert result.node_history[1].node_id == "completion_agent"

# Verify both agents were called
handoff_agent.stream_async.assert_called()
completion_agent.stream_async.assert_called()
handoff_agent.invoke_async.assert_called()
completion_agent.invoke_async.assert_called()

# Test handoff when task is already completed
completed_swarm = Swarm(nodes=[handoff_agent, completion_agent])
Expand Down
Loading