Skip to content

Commit 90d6ea8

Browse files
committed
feat: Session persistence
1 parent da9153a commit 90d6ea8

17 files changed

+1947
-104
lines changed

src/strands/agent/agent.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import random
1616
from concurrent.futures import ThreadPoolExecutor
1717
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast
18-
from uuid import uuid4
1918

2019
from opentelemetry import trace
2120
from pydantic import BaseModel
@@ -32,6 +31,7 @@
3231
)
3332
from ..models.bedrock import BedrockModel
3433
from ..models.model import Model
34+
from ..session.session_manager import SessionManager
3535
from ..telemetry.metrics import EventLoopMetrics
3636
from ..telemetry.tracer import get_tracer
3737
from ..tools.registry import ToolRegistry
@@ -62,6 +62,7 @@ class _DefaultCallbackHandlerSentinel:
6262

6363
_DEFAULT_CALLBACK_HANDLER = _DefaultCallbackHandlerSentinel()
6464
_DEFAULT_AGENT_NAME = "Strands Agents"
65+
_DEFAULT_AGENT_ID = "default"
6566

6667

6768
class Agent:
@@ -207,6 +208,7 @@ def __init__(
207208
description: Optional[str] = None,
208209
state: Optional[Union[AgentState, dict]] = None,
209210
hooks: Optional[list[HookProvider]] = None,
211+
session_manager: Optional[SessionManager] = None,
210212
):
211213
"""Initialize the Agent with the specified configuration.
212214
@@ -237,22 +239,24 @@ def __init__(
237239
load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory.
238240
Defaults to False.
239241
trace_attributes: Custom trace attributes to apply to the agent's trace span.
240-
agent_id: Optional ID for the agent, useful for multi-agent scenarios.
241-
If None, a UUID is generated.
242+
agent_id: Optional ID for the agent, useful for session management and multi-agent scenarios.
243+
Defaults to "default".
242244
name: name of the Agent
243-
Defaults to None.
245+
Defaults to "Strands Agents".
244246
description: description of what the Agent does
245247
Defaults to None.
246248
state: stateful information for the agent. Can be either an AgentState object, or a json serializable dict.
247249
Defaults to an empty AgentState object.
248250
hooks: hooks to be added to the agent hook registry
249251
Defaults to None.
252+
session_manager: Manager for handling agent sessions including conversation history and state.
253+
If provided, enables session-based persistence and state management.
250254
"""
251255
self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model
252256
self.messages = messages if messages is not None else []
253257

254258
self.system_prompt = system_prompt
255-
self.agent_id = agent_id or str(uuid4())
259+
self.agent_id = agent_id or _DEFAULT_AGENT_ID
256260
self.name = name or _DEFAULT_AGENT_NAME
257261
self.description = description
258262

@@ -312,6 +316,12 @@ def __init__(
312316
self.tool_caller = Agent.ToolCaller(self)
313317

314318
self.hooks = HookRegistry()
319+
320+
# Initialize session management functionality
321+
self.session_manager = session_manager
322+
if self.session_manager:
323+
self.hooks.add_hook(self.session_manager)
324+
315325
if hooks:
316326
for hook in hooks:
317327
self.hooks.add_hook(hook)
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
"""Agent session manager implementation."""
2+
3+
import logging
4+
5+
from ..agent.agent import _DEFAULT_AGENT_ID, Agent
6+
from ..agent.state import AgentState
7+
from ..types.content import Message
8+
from ..types.exceptions import SessionException
9+
from ..types.session import (
10+
Session,
11+
SessionAgent,
12+
SessionMessage,
13+
SessionType,
14+
)
15+
from .session_manager import SessionManager
16+
from .session_repository import SessionRepository
17+
18+
logger = logging.getLogger(__name__)
19+
20+
DEFAULT_SESSION_AGENT_ID = "default"
21+
22+
23+
class AgentSessionManager(SessionManager):
24+
"""Session manager for persisting agent's in a Session."""
25+
26+
def __init__(
27+
self,
28+
session_id: str,
29+
session_repository: SessionRepository,
30+
):
31+
"""Initialize the AgentSessionManager."""
32+
self.session_repository = session_repository
33+
self.session_id = session_id
34+
session = session_repository.read_session(session_id)
35+
# Create a session if it does not exist yet
36+
if session is None:
37+
logger.debug("session_id=<%s> | Session not found, creating new session.", self.session_id)
38+
session = Session(session_id=session_id, session_type=SessionType.AGENT)
39+
session_repository.create_session(session)
40+
41+
self.session = session
42+
self._default_agent_initialized = False
43+
44+
def append_message(self, message: Message, agent: Agent) -> None:
45+
"""Append a message to the agent's session.
46+
47+
Args:
48+
message: Message to add to the agent in the session
49+
agent: Agent to append the message to
50+
"""
51+
session_message = SessionMessage.from_message(message)
52+
if agent.agent_id is None:
53+
raise ValueError("`agent.agent_id` must be set before appending message to session.")
54+
self.session_repository.create_message(self.session_id, agent.agent_id, session_message)
55+
56+
def sync_agent(self, agent: Agent) -> None:
57+
"""Sync agent to the session.
58+
59+
Args:
60+
agent: Agent to sync to the session.
61+
"""
62+
self.session_repository.update_agent(
63+
self.session_id,
64+
SessionAgent.from_agent(agent),
65+
)
66+
67+
def initialize(self, agent: Agent) -> None:
68+
"""Initialize an agent with a session.
69+
70+
Args:
71+
agent: Agent to initialize from the session
72+
"""
73+
if agent.agent_id is _DEFAULT_AGENT_ID:
74+
if self._default_agent_initialized:
75+
raise SessionException("Set `agent_id` to support more than one agent in a session.")
76+
self._default_agent_initialized = True
77+
78+
session_agent = self.session_repository.read_agent(self.session_id, agent.agent_id)
79+
80+
if session_agent is None:
81+
logger.debug(
82+
"agent_id=<%s> | session_id=<%s> | Creating agent.",
83+
agent.agent_id,
84+
self.session_id,
85+
)
86+
87+
session_agent = SessionAgent.from_agent(agent)
88+
self.session_repository.create_agent(self.session_id, session_agent)
89+
for message in agent.messages:
90+
session_message = SessionMessage.from_message(message)
91+
self.session_repository.create_message(self.session_id, agent.agent_id, session_message)
92+
else:
93+
logger.debug(
94+
"agent_id=<%s> | session_id=<%s> | Restoring agent.",
95+
agent.agent_id,
96+
self.session_id,
97+
)
98+
agent.messages = [
99+
session_message.to_message()
100+
for session_message in self.session_repository.list_messages(self.session_id, agent.agent_id)
101+
]
102+
agent.state = AgentState(session_agent.state)
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
"""File-based session manager for local filesystem storage."""
2+
3+
import json
4+
import logging
5+
import os
6+
import shutil
7+
import tempfile
8+
from dataclasses import asdict
9+
from typing import Any, Optional, cast
10+
11+
from ..types.exceptions import SessionException
12+
from ..types.session import Session, SessionAgent, SessionMessage
13+
from .agent_session_manager import AgentSessionManager
14+
from .session_repository import SessionRepository
15+
16+
logger = logging.getLogger(__name__)
17+
18+
SESSION_PREFIX = "session_"
19+
AGENT_PREFIX = "agent_"
20+
MESSAGE_PREFIX = "message_"
21+
22+
23+
class FileSessionManager(AgentSessionManager, SessionRepository):
24+
"""File-based session manager for local filesystem storage."""
25+
26+
def __init__(self, session_id: str, storage_dir: Optional[str] = None):
27+
"""Initialize FileSession with filesystem storage.
28+
29+
Args:
30+
session_id: ID for the session
31+
storage_dir: Directory for local filesystem storage (defaults to temp dir)
32+
"""
33+
self.storage_dir = storage_dir or os.path.join(tempfile.gettempdir(), "strands/sessions")
34+
os.makedirs(self.storage_dir, exist_ok=True)
35+
36+
super().__init__(session_id=session_id, session_repository=self)
37+
38+
def _get_session_path(self, session_id: str) -> str:
39+
"""Get session directory path."""
40+
return os.path.join(self.storage_dir, f"{SESSION_PREFIX}{session_id}")
41+
42+
def _get_agent_path(self, session_id: str, agent_id: str) -> str:
43+
"""Get agent directory path."""
44+
session_path = self._get_session_path(session_id)
45+
return os.path.join(session_path, "agents", f"{AGENT_PREFIX}{agent_id}")
46+
47+
def _get_message_path(self, session_id: str, agent_id: str, message_id: str) -> str:
48+
"""Get message file path."""
49+
agent_path = self._get_agent_path(session_id, agent_id)
50+
return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{message_id}.json")
51+
52+
def _read_file(self, path: str) -> dict[str, Any]:
53+
"""Read JSON file."""
54+
try:
55+
with open(path, "r", encoding="utf-8") as f:
56+
return cast(dict[str, Any], json.load(f))
57+
except json.JSONDecodeError as e:
58+
raise SessionException(f"Invalid JSON in file {path}: {str(e)}") from e
59+
60+
def _write_file(self, path: str, data: dict[str, Any]) -> None:
61+
"""Write JSON file."""
62+
os.makedirs(os.path.dirname(path), exist_ok=True)
63+
with open(path, "w", encoding="utf-8") as f:
64+
json.dump(data, f, indent=2, ensure_ascii=False)
65+
66+
def create_session(self, session: Session) -> Session:
67+
"""Create a new session."""
68+
session_dir = self._get_session_path(session.session_id)
69+
if os.path.exists(session_dir):
70+
raise SessionException(f"Session {session.session_id} already exists")
71+
72+
# Create directory structure
73+
os.makedirs(session_dir, exist_ok=True)
74+
os.makedirs(os.path.join(session_dir, "agents"), exist_ok=True)
75+
76+
# Write session file
77+
session_file = os.path.join(session_dir, "session.json")
78+
session_dict = asdict(session)
79+
self._write_file(session_file, session_dict)
80+
81+
return session
82+
83+
def read_session(self, session_id: str) -> Optional[Session]:
84+
"""Read session data."""
85+
session_file = os.path.join(self._get_session_path(session_id), "session.json")
86+
if not os.path.exists(session_file):
87+
return None
88+
89+
session_data = self._read_file(session_file)
90+
return Session.from_dict(session_data)
91+
92+
def create_agent(self, session_id: str, session_agent: SessionAgent) -> None:
93+
"""Create a new agent in the session."""
94+
agent_id = session_agent.agent_id
95+
96+
agent_dir = self._get_agent_path(session_id, agent_id)
97+
os.makedirs(agent_dir, exist_ok=True)
98+
os.makedirs(os.path.join(agent_dir, "messages"), exist_ok=True)
99+
100+
agent_file = os.path.join(agent_dir, "agent.json")
101+
session_data = asdict(session_agent)
102+
self._write_file(agent_file, session_data)
103+
104+
def delete_session(self, session_id: str) -> None:
105+
"""Delete session and all associated data."""
106+
session_dir = self._get_session_path(session_id)
107+
if not os.path.exists(session_dir):
108+
raise SessionException(f"Session {session_id} does not exist")
109+
110+
shutil.rmtree(session_dir)
111+
112+
def read_agent(self, session_id: str, agent_id: str) -> Optional[SessionAgent]:
113+
"""Read agent data."""
114+
agent_file = os.path.join(self._get_agent_path(session_id, agent_id), "agent.json")
115+
if not os.path.exists(agent_file):
116+
return None
117+
118+
agent_data = self._read_file(agent_file)
119+
return SessionAgent.from_dict(agent_data)
120+
121+
def update_agent(self, session_id: str, session_agent: SessionAgent) -> None:
122+
"""Update agent data."""
123+
agent_id = session_agent.agent_id
124+
previous_agent = self.read_agent(session_id=session_id, agent_id=agent_id)
125+
if previous_agent is None:
126+
raise SessionException(f"Agent {agent_id} in session {session_id} does not exist")
127+
128+
session_agent.created_at = previous_agent.created_at
129+
agent_file = os.path.join(self._get_agent_path(session_id, agent_id), "agent.json")
130+
self._write_file(agent_file, asdict(session_agent))
131+
132+
def create_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None:
133+
"""Create a new message for the agent."""
134+
message_file = self._get_message_path(
135+
session_id,
136+
agent_id,
137+
session_message.message_id,
138+
)
139+
session_dict = asdict(session_message)
140+
self._write_file(message_file, session_dict)
141+
142+
def read_message(self, session_id: str, agent_id: str, message_id: str) -> Optional[SessionMessage]:
143+
"""Read message data."""
144+
message_file = self._get_message_path(session_id, agent_id, message_id)
145+
if not os.path.exists(message_file):
146+
return None
147+
message_data = self._read_file(message_file)
148+
return SessionMessage.from_dict(message_data)
149+
150+
def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None:
151+
"""Update message data."""
152+
message_id = session_message.message_id
153+
previous_message = self.read_message(session_id=session_id, agent_id=agent_id, message_id=message_id)
154+
if previous_message is None:
155+
raise SessionException(f"Message {message_id} does not exist")
156+
157+
# Preserve the original created_at timestamp
158+
session_message.created_at = previous_message.created_at
159+
message_file = self._get_message_path(session_id, agent_id, message_id)
160+
self._write_file(message_file, asdict(session_message))
161+
162+
def list_messages(
163+
self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0
164+
) -> list[SessionMessage]:
165+
"""List messages for an agent with pagination."""
166+
messages_dir = os.path.join(self._get_agent_path(session_id, agent_id), "messages")
167+
if not os.path.exists(messages_dir):
168+
raise SessionException(f"Messages directory missing from agent: {agent_id} in session {session_id}")
169+
170+
# Read all message files
171+
messages: list[SessionMessage] = []
172+
for filename in os.listdir(messages_dir):
173+
if filename.startswith(MESSAGE_PREFIX) and filename.endswith(".json"):
174+
file_path = os.path.join(messages_dir, filename)
175+
message_data = self._read_file(file_path)
176+
messages.append(SessionMessage.from_dict(message_data))
177+
178+
# Sort by created_at timestamp (oldest first)
179+
messages.sort(key=lambda x: x.created_at)
180+
181+
# Apply pagination
182+
if limit is not None:
183+
messages = messages[offset : offset + limit]
184+
else:
185+
messages = messages[offset:]
186+
187+
return messages

0 commit comments

Comments
 (0)