@@ -156,6 +156,7 @@ class SwarmState:
156156 # Total metrics across all agents
157157 accumulated_metrics : Metrics = field (default_factory = lambda : Metrics (latencyMs = 0 ))
158158 execution_time : int = 0 # Total execution time in milliseconds
159+ handoff_node : SwarmNode | None = None # The agent to execute next
159160 handoff_message : str | None = None # Message passed during agent handoff
160161
161162 def should_continue (
@@ -537,7 +538,7 @@ def handoff_to_agent(agent_name: str, message: str, context: dict[str, Any] | No
537538 # Execute handoff
538539 swarm_ref ._handle_handoff (target_node , message , context )
539540
540- return {"status" : "success" , "content" : [{"text" : f"Handed off to { agent_name } : { message } " }]}
541+ return {"status" : "success" , "content" : [{"text" : f"Handing off to { agent_name } : { message } " }]}
541542 except Exception as e :
542543 return {"status" : "error" , "content" : [{"text" : f"Error in handoff: { str (e )} " }]}
543544
@@ -553,21 +554,19 @@ def _handle_handoff(self, target_node: SwarmNode, message: str, context: dict[st
553554 )
554555 return
555556
556- # Update swarm state
557- previous_agent = cast (SwarmNode , self .state .current_node )
558- self .state .current_node = target_node
557+ current_node = cast (SwarmNode , self .state .current_node )
559558
560- # Store handoff message for the target agent
559+ self . state . handoff_node = target_node
561560 self .state .handoff_message = message
562561
563562 # Store handoff context as shared context
564563 if context :
565564 for key , value in context .items ():
566- self .shared_context .add_context (previous_agent , key , value )
565+ self .shared_context .add_context (current_node , key , value )
567566
568567 logger .debug (
569- "from_node=<%s>, to_node=<%s> | handed off from agent to agent" ,
570- previous_agent .node_id ,
568+ "from_node=<%s>, to_node=<%s> | handing off from agent to agent" ,
569+ current_node .node_id ,
571570 target_node .node_id ,
572571 )
573572
@@ -667,7 +666,6 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
667666 logger .debug ("reason=<%s> | stopping execution" , reason )
668667 break
669668
670- # Get current node
671669 current_node = self .state .current_node
672670 if not current_node or current_node .node_id not in self .nodes :
673671 logger .error ("node=<%s> | node not found" , current_node .node_id if current_node else "None" )
@@ -680,13 +678,8 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
680678 len (self .state .node_history ) + 1 ,
681679 )
682680
683- # Store the current node before execution to detect handoffs
684- previous_node = current_node
685-
686- # Execute node with timeout protection
687681 # TODO: Implement cancellation token to stop _execute_node from continuing
688682 try :
689- # Execute with timeout wrapper for async generator streaming
690683 await self .hooks .invoke_callbacks_async (
691684 BeforeNodeCallEvent (self , current_node .node_id , invocation_state )
692685 )
@@ -699,30 +692,33 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
699692 yield event
700693
701694 self .state .node_history .append (current_node )
702-
703- # After self.state add current node, swarm state finish updating, we persist here
704695 await self .hooks .invoke_callbacks_async (
705696 AfterNodeCallEvent (self , current_node .node_id , invocation_state )
706697 )
707698
708699 logger .debug ("node=<%s> | node execution completed" , current_node .node_id )
709700
710- # Check if handoff occurred during execution
711- if self .state .current_node is not None and self .state .current_node != previous_node :
712- # Emit handoff event (single node transition in Swarm)
701+ # Check if handoff requested during execution
702+ if self .state .handoff_node :
703+ previous_node = current_node
704+ current_node = self .state .handoff_node
705+
706+ self .state .handoff_node = None
707+ self .state .current_node = current_node
708+
713709 handoff_event = MultiAgentHandoffEvent (
714710 from_node_ids = [previous_node .node_id ],
715- to_node_ids = [self . state . current_node .node_id ],
711+ to_node_ids = [current_node .node_id ],
716712 message = self .state .handoff_message or "Agent handoff occurred" ,
717713 )
718714 yield handoff_event
719715 logger .debug (
720716 "from_node=<%s>, to_node=<%s> | handoff detected" ,
721717 previous_node .node_id ,
722- self . state . current_node .node_id ,
718+ current_node .node_id ,
723719 )
720+
724721 else :
725- # No handoff occurred, mark swarm as complete
726722 logger .debug ("node=<%s> | no handoff occurred, marking swarm as complete" , current_node .node_id )
727723 self .state .completion_status = Status .COMPLETED
728724 break
@@ -866,11 +862,12 @@ def _build_result(self) -> SwarmResult:
866862 def serialize_state (self ) -> dict [str , Any ]:
867863 """Serialize the current swarm state to a dictionary."""
868864 status_str = self .state .completion_status .value
869- next_nodes = (
870- [self .state .current_node .node_id ]
871- if self .state .completion_status == Status .EXECUTING and self .state .current_node
872- else []
873- )
865+ if self .state .handoff_node :
866+ next_nodes = [self .state .handoff_node .node_id ]
867+ elif self .state .completion_status == Status .EXECUTING and self .state .current_node :
868+ next_nodes = [self .state .current_node .node_id ]
869+ else :
870+ next_nodes = []
874871
875872 return {
876873 "type" : "swarm" ,
0 commit comments