1111"""
1212
1313import threading
14- from typing import Optional
14+ from typing import TYPE_CHECKING
1515
1616from ...hooks .registry import HookProvider , HookRegistry
17- from ...multiagent .base import MultiAgentBase
1817from ...session import SessionManager
1918from .multiagent_events import (
20- AfterGraphInvocationEvent ,
19+ AfterMultiAgentInvocationEvent ,
2120 AfterNodeInvocationEvent ,
22- BeforeGraphInvocationEvent ,
21+ BeforeMultiAgentInvocationEvent ,
2322 BeforeNodeInvocationEvent ,
2423 MultiAgentInitializationEvent ,
25- MultiAgentState ,
2624)
27- from .multiagent_state_adapter import MultiAgentAdapter
2825
26+ if TYPE_CHECKING :
27+ from ...multiagent .base import MultiAgentBase
2928
30- def _get_multiagent_state (
31- multiagent_state : Optional [MultiAgentState ],
32- orchestrator : MultiAgentBase ,
33- ) -> MultiAgentState :
34- if multiagent_state is not None :
35- return multiagent_state
3629
37- return MultiAgentAdapter .create_multi_agent_state (orchestrator = orchestrator )
38-
39-
40- class MultiAgentHook (HookProvider ):
30+ class PersistentHook (HookProvider ):
4131 """Hook provider for automatic multi-agent session persistence.
4232
4333 This hook automatically persists multi-agent orchestrator state at key
4434 execution points to enable resumable execution after interruptions.
4535
46- Args:
47- session_manager: SessionManager instance for state persistence
48- session_id: Unique identifier for the session
4936 """
5037
51- def __init__ (self , session_manager : SessionManager , session_id : str ):
38+ def __init__ (self , session_manager : SessionManager ):
5239 """Initialize the multi-agent persistence hook.
5340
5441 Args:
5542 session_manager: SessionManager instance for state persistence
56- session_id: Unique identifier for the session
5743 """
5844 self ._session_manager = session_manager
59- self ._session_id = session_id
6045 self ._lock = threading .RLock ()
6146
6247 def register_hooks (self , registry : HookRegistry , ** kwargs : object ) -> None :
@@ -67,40 +52,40 @@ def register_hooks(self, registry: HookRegistry, **kwargs: object) -> None:
6752 **kwargs: Additional keyword arguments (unused)
6853 """
6954 registry .add_callback (MultiAgentInitializationEvent , self ._on_initialization )
70- registry .add_callback (BeforeGraphInvocationEvent , self ._on_before_graph )
55+ registry .add_callback (BeforeMultiAgentInvocationEvent , self ._on_before_multiagent )
7156 registry .add_callback (BeforeNodeInvocationEvent , self ._on_before_node )
7257 registry .add_callback (AfterNodeInvocationEvent , self ._on_after_node )
73- registry .add_callback (AfterGraphInvocationEvent , self ._on_after_graph )
58+ registry .add_callback (AfterMultiAgentInvocationEvent , self ._on_after_multiagent )
7459
75- def _on_initialization (self , event : MultiAgentInitializationEvent ):
60+ # TODO: We can add **kwarg or invocation_state later if we need to persist
61+ def _on_initialization (self , event : MultiAgentInitializationEvent ) -> None :
7662 """Persist state when multi-agent orchestrator initializes."""
77- self ._persist (_get_multiagent_state ( event .state , event . orchestrator ) )
63+ self ._persist (event .orchestrator )
7864
79- def _on_before_graph (self , event : BeforeGraphInvocationEvent ) :
80- """Hook called before graph execution starts ."""
65+ def _on_before_multiagent (self , event : BeforeMultiAgentInvocationEvent ) -> None :
66+ """Persist state when multi-agent orchestrator initializes ."""
8167 pass
8268
83- def _on_before_node (self , event : BeforeNodeInvocationEvent ):
69+ def _on_before_node (self , event : BeforeNodeInvocationEvent ) -> None :
8470 """Hook called before individual node execution."""
8571 pass
8672
87- def _on_after_node (self , event : AfterNodeInvocationEvent ):
73+ def _on_after_node (self , event : AfterNodeInvocationEvent ) -> None :
8874 """Persist state after each node completes execution."""
89- multi_agent_state = _get_multiagent_state (multiagent_state = event .state , orchestrator = event .orchestrator )
90- self ._persist (multi_agent_state )
75+ self ._persist (event .orchestrator )
9176
92- def _on_after_graph (self , event : AfterGraphInvocationEvent ) :
77+ def _on_after_multiagent (self , event : AfterMultiAgentInvocationEvent ) -> None :
9378 """Persist final state after graph execution completes."""
94- multiagent_state = _get_multiagent_state (multiagent_state = event .state , orchestrator = event .orchestrator )
95- self ._persist (multiagent_state )
79+ self ._persist (event .orchestrator )
9680
97- def _persist (self , multiagent_state : MultiAgentState ) -> None :
81+ def _persist (self , orchestrator : "MultiAgentBase" ) -> None :
9882 """Persist the provided MultiAgentState using the configured SessionManager.
9983
10084 This method is synchronized across threads/tasks to avoid write races.
10185
10286 Args:
103- multiagent_state : State to persist
87+ orchestrator : State to persist
10488 """
89+ current_state = orchestrator .get_state_from_orchestrator ()
10590 with self ._lock :
106- self ._session_manager .write_multi_agent_state ( multiagent_state )
91+ self ._session_manager .write_multi_agent_json ( current_state )
0 commit comments