1111"""
1212
1313import threading
14- from typing import Optional
1514
1615from ...hooks .registry import HookProvider , HookRegistry
1716from ...multiagent .base import MultiAgentBase
1817from ...session import SessionManager
1918from .multiagent_events import (
20- AfterGraphInvocationEvent ,
19+ AfterMultiAgentInvocationEvent ,
2120 AfterNodeInvocationEvent ,
22- BeforeGraphInvocationEvent ,
21+ BeforeMultiAgentInvocationEvent ,
2322 BeforeNodeInvocationEvent ,
2423 MultiAgentInitializationEvent ,
25- MultiAgentState ,
2624)
27- from .multiagent_state_adapter import MultiAgentAdapter
2825
2926
30- def _get_multiagent_state (
31- multiagent_state : Optional [MultiAgentState ],
32- orchestrator : MultiAgentBase ,
33- ) -> MultiAgentState :
34- if multiagent_state is not None :
35- return multiagent_state
36-
37- return MultiAgentAdapter .create_multi_agent_state (orchestrator = orchestrator )
38-
39-
40- class MultiAgentHook (HookProvider ):
27+ class PersistentHook (HookProvider ):
4128 """Hook provider for automatic multi-agent session persistence.
4229
4330 This hook automatically persists multi-agent orchestrator state at key
@@ -48,15 +35,14 @@ class MultiAgentHook(HookProvider):
4835 session_id: Unique identifier for the session
4936 """
5037
51- def __init__ (self , session_manager : SessionManager , session_id : str ):
38+ def __init__ (self , session_manager : SessionManager ):
5239 """Initialize the multi-agent persistence hook.
5340
5441 Args:
5542 session_manager: SessionManager instance for state persistence
5643 session_id: Unique identifier for the session
5744 """
5845 self ._session_manager = session_manager
59- self ._session_id = session_id
6046 self ._lock = threading .RLock ()
6147
6248 def register_hooks (self , registry : HookRegistry , ** kwargs : object ) -> None :
@@ -67,17 +53,18 @@ def register_hooks(self, registry: HookRegistry, **kwargs: object) -> None:
6753 **kwargs: Additional keyword arguments (unused)
6854 """
6955 registry .add_callback (MultiAgentInitializationEvent , self ._on_initialization )
70- registry .add_callback (BeforeGraphInvocationEvent , self ._on_before_graph )
56+ registry .add_callback (BeforeMultiAgentInvocationEvent , self ._on_initialization )
7157 registry .add_callback (BeforeNodeInvocationEvent , self ._on_before_node )
7258 registry .add_callback (AfterNodeInvocationEvent , self ._on_after_node )
73- registry .add_callback (AfterGraphInvocationEvent , self ._on_after_graph )
59+ registry .add_callback (AfterMultiAgentInvocationEvent , self ._on_after_Execution )
7460
7561 def _on_initialization (self , event : MultiAgentInitializationEvent ):
7662 """Persist state when multi-agent orchestrator initializes."""
77- self ._persist (_get_multiagent_state (event .state , event .orchestrator ))
63+ self ._persist (event .orchestrator )
64+ pass
7865
79- def _on_before_graph (self , event : BeforeGraphInvocationEvent ):
80- """Hook called before graph execution starts ."""
66+ def _on_before_Invocation (self , event : MultiAgentInitializationEvent ):
67+ """Persist state when multi-agent orchestrator initializes ."""
8168 pass
8269
8370 def _on_before_node (self , event : BeforeNodeInvocationEvent ):
@@ -86,21 +73,20 @@ def _on_before_node(self, event: BeforeNodeInvocationEvent):
8673
8774 def _on_after_node (self , event : AfterNodeInvocationEvent ):
8875 """Persist state after each node completes execution."""
89- multi_agent_state = _get_multiagent_state (multiagent_state = event .state , orchestrator = event .orchestrator )
90- self ._persist (multi_agent_state )
76+ self ._persist (event .orchestrator )
9177
92- def _on_after_graph (self , event : AfterGraphInvocationEvent ):
78+ def _on_after_Execution (self , event : AfterMultiAgentInvocationEvent ):
9379 """Persist final state after graph execution completes."""
94- multiagent_state = _get_multiagent_state (multiagent_state = event .state , orchestrator = event .orchestrator )
95- self ._persist (multiagent_state )
80+ self ._persist (event .orchestrator )
9681
97- def _persist (self , multiagent_state : MultiAgentState ) -> None :
82+ def _persist (self , orchestrator : MultiAgentBase ) -> None :
9883 """Persist the provided MultiAgentState using the configured SessionManager.
9984
10085 This method is synchronized across threads/tasks to avoid write races.
10186
10287 Args:
103- multiagent_state : State to persist
88+ orchestrator : State to persist
10489 """
90+ current_state = orchestrator .get_state_from_orchestrator ()
10591 with self ._lock :
106- self ._session_manager .write_multi_agent_state (multiagent_state )
92+ self ._session_manager .write_multi_agent_state (current_state )
0 commit comments