Skip to content

Commit b4efc9d

Browse files
authored
swarm - switch to handoff node only after current node stops (#1147)
1 parent 77cb23f commit b4efc9d

File tree

2 files changed

+51
-27
lines changed

2 files changed

+51
-27
lines changed

src/strands/multiagent/swarm.py

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ class SwarmState:
156156
# Total metrics across all agents
157157
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
158158
execution_time: int = 0 # Total execution time in milliseconds
159+
handoff_node: SwarmNode | None = None # The agent to execute next
159160
handoff_message: str | None = None # Message passed during agent handoff
160161

161162
def should_continue(
@@ -537,7 +538,7 @@ def handoff_to_agent(agent_name: str, message: str, context: dict[str, Any] | No
537538
# Execute handoff
538539
swarm_ref._handle_handoff(target_node, message, context)
539540

540-
return {"status": "success", "content": [{"text": f"Handed off to {agent_name}: {message}"}]}
541+
return {"status": "success", "content": [{"text": f"Handing off to {agent_name}: {message}"}]}
541542
except Exception as e:
542543
return {"status": "error", "content": [{"text": f"Error in handoff: {str(e)}"}]}
543544

@@ -553,21 +554,19 @@ def _handle_handoff(self, target_node: SwarmNode, message: str, context: dict[st
553554
)
554555
return
555556

556-
# Update swarm state
557-
previous_agent = cast(SwarmNode, self.state.current_node)
558-
self.state.current_node = target_node
557+
current_node = cast(SwarmNode, self.state.current_node)
559558

560-
# Store handoff message for the target agent
559+
self.state.handoff_node = target_node
561560
self.state.handoff_message = message
562561

563562
# Store handoff context as shared context
564563
if context:
565564
for key, value in context.items():
566-
self.shared_context.add_context(previous_agent, key, value)
565+
self.shared_context.add_context(current_node, key, value)
567566

568567
logger.debug(
569-
"from_node=<%s>, to_node=<%s> | handed off from agent to agent",
570-
previous_agent.node_id,
568+
"from_node=<%s>, to_node=<%s> | handing off from agent to agent",
569+
current_node.node_id,
571570
target_node.node_id,
572571
)
573572

@@ -667,7 +666,6 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
667666
logger.debug("reason=<%s> | stopping execution", reason)
668667
break
669668

670-
# Get current node
671669
current_node = self.state.current_node
672670
if not current_node or current_node.node_id not in self.nodes:
673671
logger.error("node=<%s> | node not found", current_node.node_id if current_node else "None")
@@ -680,13 +678,8 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
680678
len(self.state.node_history) + 1,
681679
)
682680

683-
# Store the current node before execution to detect handoffs
684-
previous_node = current_node
685-
686-
# Execute node with timeout protection
687681
# TODO: Implement cancellation token to stop _execute_node from continuing
688682
try:
689-
# Execute with timeout wrapper for async generator streaming
690683
await self.hooks.invoke_callbacks_async(
691684
BeforeNodeCallEvent(self, current_node.node_id, invocation_state)
692685
)
@@ -699,30 +692,33 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
699692
yield event
700693

701694
self.state.node_history.append(current_node)
702-
703-
# After self.state add current node, swarm state finish updating, we persist here
704695
await self.hooks.invoke_callbacks_async(
705696
AfterNodeCallEvent(self, current_node.node_id, invocation_state)
706697
)
707698

708699
logger.debug("node=<%s> | node execution completed", current_node.node_id)
709700

710-
# Check if handoff occurred during execution
711-
if self.state.current_node is not None and self.state.current_node != previous_node:
712-
# Emit handoff event (single node transition in Swarm)
701+
# Check if handoff requested during execution
702+
if self.state.handoff_node:
703+
previous_node = current_node
704+
current_node = self.state.handoff_node
705+
706+
self.state.handoff_node = None
707+
self.state.current_node = current_node
708+
713709
handoff_event = MultiAgentHandoffEvent(
714710
from_node_ids=[previous_node.node_id],
715-
to_node_ids=[self.state.current_node.node_id],
711+
to_node_ids=[current_node.node_id],
716712
message=self.state.handoff_message or "Agent handoff occurred",
717713
)
718714
yield handoff_event
719715
logger.debug(
720716
"from_node=<%s>, to_node=<%s> | handoff detected",
721717
previous_node.node_id,
722-
self.state.current_node.node_id,
718+
current_node.node_id,
723719
)
720+
724721
else:
725-
# No handoff occurred, mark swarm as complete
726722
logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id)
727723
self.state.completion_status = Status.COMPLETED
728724
break
@@ -866,11 +862,12 @@ def _build_result(self) -> SwarmResult:
866862
def serialize_state(self) -> dict[str, Any]:
867863
"""Serialize the current swarm state to a dictionary."""
868864
status_str = self.state.completion_status.value
869-
next_nodes = (
870-
[self.state.current_node.node_id]
871-
if self.state.completion_status == Status.EXECUTING and self.state.current_node
872-
else []
873-
)
865+
if self.state.handoff_node:
866+
next_nodes = [self.state.handoff_node.node_id]
867+
elif self.state.completion_status == Status.EXECUTING and self.state.current_node:
868+
next_nodes = [self.state.current_node.node_id]
869+
else:
870+
next_nodes = []
874871

875872
return {
876873
"type": "swarm",

tests/strands/multiagent/test_swarm.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,3 +1149,30 @@ async def test_swarm_persistence(mock_strands_tracer, mock_use_span):
11491149
assert final_state["status"] == "completed"
11501150
assert len(final_state["node_history"]) == 1
11511151
assert "test_agent" in final_state["node_results"]
1152+
1153+
1154+
@pytest.mark.asyncio
1155+
async def test_swarm_handle_handoff():
1156+
first_agent = create_mock_agent("first")
1157+
second_agent = create_mock_agent("second")
1158+
1159+
swarm = Swarm([first_agent, second_agent])
1160+
1161+
async def handoff_stream(*args, **kwargs):
1162+
yield {"agent_start": True}
1163+
1164+
swarm._handle_handoff(swarm.nodes["second"], "test message", {})
1165+
1166+
assert swarm.state.current_node.node_id == "first"
1167+
assert swarm.state.handoff_node.node_id == "second"
1168+
1169+
yield {"result": first_agent.return_value}
1170+
1171+
first_agent.stream_async = Mock(side_effect=handoff_stream)
1172+
1173+
result = await swarm.invoke_async("test")
1174+
assert result.status == Status.COMPLETED
1175+
1176+
tru_node_order = [node.node_id for node in result.node_history]
1177+
exp_node_order = ["first", "second"]
1178+
assert tru_node_order == exp_node_order

0 commit comments

Comments
 (0)