11from strands .session .session_repository import SessionRepository
22from strands .types .exceptions import SessionException
3+ from strands .types .session import SessionAgent , SessionMessage
34
45
56class MockedSessionRepository (SessionRepository ):
@@ -11,21 +12,20 @@ def __init__(self):
1112 self .agents = {}
1213 self .messages = {}
1314
14- def create_session (self , session ):
15+ def create_session (self , session ) -> None :
1516 """Create a session."""
1617 session_id = session .session_id
1718 if session_id in self .sessions :
1819 raise SessionException (f"Session { session_id } already exists" )
1920 self .sessions [session_id ] = session
2021 self .agents [session_id ] = {}
2122 self .messages [session_id ] = {}
22- return session
2323
24- def read_session (self , session_id ):
24+ def read_session (self , session_id ) -> SessionAgent :
2525 """Read a session."""
2626 return self .sessions .get (session_id )
2727
28- def create_agent (self , session_id , session_agent ):
28+ def create_agent (self , session_id , session_agent ) -> None :
2929 """Create an agent."""
3030 agent_id = session_agent .agent_id
3131 if session_id not in self .sessions :
@@ -36,13 +36,13 @@ def create_agent(self, session_id, session_agent):
3636 self .messages .setdefault (session_id , {}).setdefault (agent_id , {})
3737 return session_agent
3838
39- def read_agent (self , session_id , agent_id ):
39+ def read_agent (self , session_id , agent_id ) -> SessionAgent :
4040 """Read an agent."""
4141 if session_id not in self .sessions :
4242 return None
4343 return self .agents .get (session_id , {}).get (agent_id )
4444
45- def update_agent (self , session_id , session_agent ):
45+ def update_agent (self , session_id , session_agent ) -> None :
4646 """Update an agent."""
4747 agent_id = session_agent .agent_id
4848 if session_id not in self .sessions :
@@ -51,7 +51,7 @@ def update_agent(self, session_id, session_agent):
5151 raise SessionException (f"Agent { agent_id } does not exist in session { session_id } " )
5252 self .agents [session_id ][agent_id ] = session_agent
5353
54- def create_message (self , session_id , agent_id , session_message ):
54+ def create_message (self , session_id , agent_id , session_message ) -> None :
5555 """Create a message."""
5656 message_id = session_message .message_id
5757 if session_id not in self .sessions :
@@ -62,15 +62,15 @@ def create_message(self, session_id, agent_id, session_message):
6262 raise SessionException (f"Message { message_id } already exists in agent { agent_id } in session { session_id } " )
6363 self .messages .setdefault (session_id , {}).setdefault (agent_id , {})[message_id ] = session_message
6464
65- def read_message (self , session_id , agent_id , message_id ):
65+ def read_message (self , session_id , agent_id , message_id ) -> SessionMessage :
6666 """Read a message."""
6767 if session_id not in self .sessions :
6868 return None
6969 if agent_id not in self .agents .get (session_id , {}):
7070 return None
7171 return self .messages .get (session_id , {}).get (agent_id , {}).get (message_id )
7272
73- def update_message (self , session_id , agent_id , session_message ):
73+ def update_message (self , session_id , agent_id , session_message ) -> None :
7474 """Update a message."""
7575
7676 message_id = session_message .message_id
@@ -82,7 +82,7 @@ def update_message(self, session_id, agent_id, session_message):
8282 raise SessionException (f"Message { message_id } does not exist in session { session_id } " )
8383 self .messages [session_id ][agent_id ][message_id ] = session_message
8484
85- def list_messages (self , session_id , agent_id , limit = None , offset = 0 ):
85+ def list_messages (self , session_id , agent_id , limit = None , offset = 0 ) -> list [ SessionMessage ] :
8686 """List messages."""
8787 if session_id not in self .sessions :
8888 return []
0 commit comments