From 8b267d21972f72bc0b4cdd59c132c282852dd9ca Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Wed, 25 Jun 2025 21:11:15 +0000 Subject: [PATCH 1/2] feat: Session persistence --- src/strands/agent/agent.py | 20 +- src/strands/session/agent_session_manager.py | 102 ++++++ src/strands/session/file_session_manager.py | 187 ++++++++++ src/strands/session/s3_session_manager.py | 237 +++++++++++++ src/strands/session/session_manager.py | 51 +++ src/strands/session/session_repository.py | 48 +++ src/strands/types/exceptions.py | 6 + src/strands/types/session.py | 118 +++++++ tests/fixtures/mock_session_repository.py | 98 ++++++ tests/strands/agent/test_agent.py | 36 +- tests/strands/session/__init__.py | 1 + .../session/test_agent_session_manager.py | 148 ++++++++ .../session/test_file_session_manager.py | 317 +++++++++++++++++ .../session/test_s3_session_manager.py | 329 ++++++++++++++++++ tests/strands/types/test_session.py | 91 +++++ tests_integ/test_session.py | 123 +++++++ 16 files changed, 1906 insertions(+), 6 deletions(-) create mode 100644 src/strands/session/agent_session_manager.py create mode 100644 src/strands/session/file_session_manager.py create mode 100644 src/strands/session/s3_session_manager.py create mode 100644 src/strands/session/session_manager.py create mode 100644 src/strands/session/session_repository.py create mode 100644 src/strands/types/session.py create mode 100644 tests/fixtures/mock_session_repository.py create mode 100644 tests/strands/session/__init__.py create mode 100644 tests/strands/session/test_agent_session_manager.py create mode 100644 tests/strands/session/test_file_session_manager.py create mode 100644 tests/strands/session/test_s3_session_manager.py create mode 100644 tests/strands/types/test_session.py create mode 100644 tests_integ/test_session.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index ab3c6d143..4925abfca 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -15,7 +15,6 @@ import random from concurrent.futures import ThreadPoolExecutor from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast -from uuid import uuid4 from opentelemetry import trace from pydantic import BaseModel @@ -32,6 +31,7 @@ ) from ..models.bedrock import BedrockModel from ..models.model import Model +from ..session.session_manager import SessionManager from ..telemetry.metrics import EventLoopMetrics from ..telemetry.tracer import get_tracer from ..tools.registry import ToolRegistry @@ -62,6 +62,7 @@ class _DefaultCallbackHandlerSentinel: _DEFAULT_CALLBACK_HANDLER = _DefaultCallbackHandlerSentinel() _DEFAULT_AGENT_NAME = "Strands Agents" +_DEFAULT_AGENT_ID = "default" class Agent: @@ -207,6 +208,7 @@ def __init__( description: Optional[str] = None, state: Optional[Union[AgentState, dict]] = None, hooks: Optional[list[HookProvider]] = None, + session_manager: Optional[SessionManager] = None, ): """Initialize the Agent with the specified configuration. @@ -237,22 +239,24 @@ def __init__( load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory. Defaults to False. trace_attributes: Custom trace attributes to apply to the agent's trace span. - agent_id: Optional ID for the agent, useful for multi-agent scenarios. - If None, a UUID is generated. + agent_id: Optional ID for the agent, useful for session management and multi-agent scenarios. + Defaults to "default". name: name of the Agent - Defaults to None. + Defaults to "Strands Agents". description: description of what the Agent does Defaults to None. state: stateful information for the agent. Can be either an AgentState object, or a json serializable dict. Defaults to an empty AgentState object. hooks: hooks to be added to the agent hook registry Defaults to None. + session_manager: Manager for handling agent sessions including conversation history and state. + If provided, enables session-based persistence and state management. """ self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model self.messages = messages if messages is not None else [] self.system_prompt = system_prompt - self.agent_id = agent_id or str(uuid4()) + self.agent_id = agent_id or _DEFAULT_AGENT_ID self.name = name or _DEFAULT_AGENT_NAME self.description = description @@ -312,6 +316,12 @@ def __init__( self.tool_caller = Agent.ToolCaller(self) self.hooks = HookRegistry() + + # Initialize session management functionality + self.session_manager = session_manager + if self.session_manager: + self.hooks.add_hook(self.session_manager) + if hooks: for hook in hooks: self.hooks.add_hook(hook) diff --git a/src/strands/session/agent_session_manager.py b/src/strands/session/agent_session_manager.py new file mode 100644 index 000000000..4e0d86dec --- /dev/null +++ b/src/strands/session/agent_session_manager.py @@ -0,0 +1,102 @@ +"""Agent session manager implementation.""" + +import logging + +from ..agent.agent import _DEFAULT_AGENT_ID, Agent +from ..agent.state import AgentState +from ..types.content import Message +from ..types.exceptions import SessionException +from ..types.session import ( + Session, + SessionAgent, + SessionMessage, + SessionType, +) +from .session_manager import SessionManager +from .session_repository import SessionRepository + +logger = logging.getLogger(__name__) + +DEFAULT_SESSION_AGENT_ID = "default" + + +class AgentSessionManager(SessionManager): + """Session manager for persisting agent's in a Session.""" + + def __init__( + self, + session_id: str, + session_repository: SessionRepository, + ): + """Initialize the AgentSessionManager.""" + self.session_repository = session_repository + self.session_id = session_id + session = session_repository.read_session(session_id) + # Create a session if it does not exist yet + if session is None: + logger.debug("session_id=<%s> | Session not found, creating new session.", self.session_id) + session = Session(session_id=session_id, session_type=SessionType.AGENT) + session_repository.create_session(session) + + self.session = session + self._default_agent_initialized = False + + def append_message(self, message: Message, agent: Agent) -> None: + """Append a message to the agent's session. + + Args: + message: Message to add to the agent in the session + agent: Agent to append the message to + """ + session_message = SessionMessage.from_message(message) + if agent.agent_id is None: + raise ValueError("`agent.agent_id` must be set before appending message to session.") + self.session_repository.create_message(self.session_id, agent.agent_id, session_message) + + def sync_agent(self, agent: Agent) -> None: + """Sync agent to the session. + + Args: + agent: Agent to sync to the session. + """ + self.session_repository.update_agent( + self.session_id, + SessionAgent.from_agent(agent), + ) + + def initialize(self, agent: Agent) -> None: + """Initialize an agent with a session. + + Args: + agent: Agent to initialize from the session + """ + if agent.agent_id is _DEFAULT_AGENT_ID: + if self._default_agent_initialized: + raise SessionException("Set `agent_id` to support more than one agent in a session.") + self._default_agent_initialized = True + + session_agent = self.session_repository.read_agent(self.session_id, agent.agent_id) + + if session_agent is None: + logger.debug( + "agent_id=<%s> | session_id=<%s> | Creating agent.", + agent.agent_id, + self.session_id, + ) + + session_agent = SessionAgent.from_agent(agent) + self.session_repository.create_agent(self.session_id, session_agent) + for message in agent.messages: + session_message = SessionMessage.from_message(message) + self.session_repository.create_message(self.session_id, agent.agent_id, session_message) + else: + logger.debug( + "agent_id=<%s> | session_id=<%s> | Restoring agent.", + agent.agent_id, + self.session_id, + ) + agent.messages = [ + session_message.to_message() + for session_message in self.session_repository.list_messages(self.session_id, agent.agent_id) + ] + agent.state = AgentState(session_agent.state) diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py new file mode 100644 index 000000000..615fad7fc --- /dev/null +++ b/src/strands/session/file_session_manager.py @@ -0,0 +1,187 @@ +"""File-based session manager for local filesystem storage.""" + +import json +import logging +import os +import shutil +import tempfile +from dataclasses import asdict +from typing import Any, Optional, cast + +from ..types.exceptions import SessionException +from ..types.session import Session, SessionAgent, SessionMessage +from .agent_session_manager import AgentSessionManager +from .session_repository import SessionRepository + +logger = logging.getLogger(__name__) + +SESSION_PREFIX = "session_" +AGENT_PREFIX = "agent_" +MESSAGE_PREFIX = "message_" + + +class FileSessionManager(AgentSessionManager, SessionRepository): + """File-based session manager for local filesystem storage.""" + + def __init__(self, session_id: str, storage_dir: Optional[str] = None): + """Initialize FileSession with filesystem storage. + + Args: + session_id: ID for the session + storage_dir: Directory for local filesystem storage (defaults to temp dir) + """ + self.storage_dir = storage_dir or os.path.join(tempfile.gettempdir(), "strands/sessions") + os.makedirs(self.storage_dir, exist_ok=True) + + super().__init__(session_id=session_id, session_repository=self) + + def _get_session_path(self, session_id: str) -> str: + """Get session directory path.""" + return os.path.join(self.storage_dir, f"{SESSION_PREFIX}{session_id}") + + def _get_agent_path(self, session_id: str, agent_id: str) -> str: + """Get agent directory path.""" + session_path = self._get_session_path(session_id) + return os.path.join(session_path, "agents", f"{AGENT_PREFIX}{agent_id}") + + def _get_message_path(self, session_id: str, agent_id: str, message_id: str) -> str: + """Get message file path.""" + agent_path = self._get_agent_path(session_id, agent_id) + return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{message_id}.json") + + def _read_file(self, path: str) -> dict[str, Any]: + """Read JSON file.""" + try: + with open(path, "r", encoding="utf-8") as f: + return cast(dict[str, Any], json.load(f)) + except json.JSONDecodeError as e: + raise SessionException(f"Invalid JSON in file {path}: {str(e)}") from e + + def _write_file(self, path: str, data: dict[str, Any]) -> None: + """Write JSON file.""" + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2, ensure_ascii=False) + + def create_session(self, session: Session) -> Session: + """Create a new session.""" + session_dir = self._get_session_path(session.session_id) + if os.path.exists(session_dir): + raise SessionException(f"Session {session.session_id} already exists") + + # Create directory structure + os.makedirs(session_dir, exist_ok=True) + os.makedirs(os.path.join(session_dir, "agents"), exist_ok=True) + + # Write session file + session_file = os.path.join(session_dir, "session.json") + session_dict = asdict(session) + self._write_file(session_file, session_dict) + + return session + + def read_session(self, session_id: str) -> Optional[Session]: + """Read session data.""" + session_file = os.path.join(self._get_session_path(session_id), "session.json") + if not os.path.exists(session_file): + return None + + session_data = self._read_file(session_file) + return Session.from_dict(session_data) + + def create_agent(self, session_id: str, session_agent: SessionAgent) -> None: + """Create a new agent in the session.""" + agent_id = session_agent.agent_id + + agent_dir = self._get_agent_path(session_id, agent_id) + os.makedirs(agent_dir, exist_ok=True) + os.makedirs(os.path.join(agent_dir, "messages"), exist_ok=True) + + agent_file = os.path.join(agent_dir, "agent.json") + session_data = asdict(session_agent) + self._write_file(agent_file, session_data) + + def delete_session(self, session_id: str) -> None: + """Delete session and all associated data.""" + session_dir = self._get_session_path(session_id) + if not os.path.exists(session_dir): + raise SessionException(f"Session {session_id} does not exist") + + shutil.rmtree(session_dir) + + def read_agent(self, session_id: str, agent_id: str) -> Optional[SessionAgent]: + """Read agent data.""" + agent_file = os.path.join(self._get_agent_path(session_id, agent_id), "agent.json") + if not os.path.exists(agent_file): + return None + + agent_data = self._read_file(agent_file) + return SessionAgent.from_dict(agent_data) + + def update_agent(self, session_id: str, session_agent: SessionAgent) -> None: + """Update agent data.""" + agent_id = session_agent.agent_id + previous_agent = self.read_agent(session_id=session_id, agent_id=agent_id) + if previous_agent is None: + raise SessionException(f"Agent {agent_id} in session {session_id} does not exist") + + session_agent.created_at = previous_agent.created_at + agent_file = os.path.join(self._get_agent_path(session_id, agent_id), "agent.json") + self._write_file(agent_file, asdict(session_agent)) + + def create_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None: + """Create a new message for the agent.""" + message_file = self._get_message_path( + session_id, + agent_id, + session_message.message_id, + ) + session_dict = asdict(session_message) + self._write_file(message_file, session_dict) + + def read_message(self, session_id: str, agent_id: str, message_id: str) -> Optional[SessionMessage]: + """Read message data.""" + message_file = self._get_message_path(session_id, agent_id, message_id) + if not os.path.exists(message_file): + return None + message_data = self._read_file(message_file) + return SessionMessage.from_dict(message_data) + + def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None: + """Update message data.""" + message_id = session_message.message_id + previous_message = self.read_message(session_id=session_id, agent_id=agent_id, message_id=message_id) + if previous_message is None: + raise SessionException(f"Message {message_id} does not exist") + + # Preserve the original created_at timestamp + session_message.created_at = previous_message.created_at + message_file = self._get_message_path(session_id, agent_id, message_id) + self._write_file(message_file, asdict(session_message)) + + def list_messages( + self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0 + ) -> list[SessionMessage]: + """List messages for an agent with pagination.""" + messages_dir = os.path.join(self._get_agent_path(session_id, agent_id), "messages") + if not os.path.exists(messages_dir): + raise SessionException(f"Messages directory missing from agent: {agent_id} in session {session_id}") + + # Read all message files + messages: list[SessionMessage] = [] + for filename in os.listdir(messages_dir): + if filename.startswith(MESSAGE_PREFIX) and filename.endswith(".json"): + file_path = os.path.join(messages_dir, filename) + message_data = self._read_file(file_path) + messages.append(SessionMessage.from_dict(message_data)) + + # Sort by created_at timestamp (oldest first) + messages.sort(key=lambda x: x.created_at) + + # Apply pagination + if limit is not None: + messages = messages[offset : offset + limit] + else: + messages = messages[offset:] + + return messages diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py new file mode 100644 index 000000000..013339e22 --- /dev/null +++ b/src/strands/session/s3_session_manager.py @@ -0,0 +1,237 @@ +"""S3-based session manager for cloud storage.""" + +import json +import logging +from dataclasses import asdict +from typing import Any, Dict, List, Optional, cast + +import boto3 +from botocore.config import Config as BotocoreConfig +from botocore.exceptions import ClientError + +from ..types.exceptions import SessionException +from ..types.session import Session, SessionAgent, SessionMessage +from .agent_session_manager import AgentSessionManager +from .session_repository import SessionRepository + +logger = logging.getLogger(__name__) + +SESSION_PREFIX = "session_" +AGENT_PREFIX = "agent_" +MESSAGE_PREFIX = "message_" + + +class S3SessionManager(AgentSessionManager, SessionRepository): + """S3-based session manager for cloud storage.""" + + def __init__( + self, + session_id: str, + bucket: str, + prefix: str = "", + boto_session: Optional[boto3.Session] = None, + boto_client_config: Optional[BotocoreConfig] = None, + region_name: Optional[str] = None, + ): + """Initialize S3SessionManager with S3 storage. + + Args: + session_id: ID for the session + bucket: S3 bucket name (required) + prefix: S3 key prefix for storage organization + boto_session: Optional boto3 session + boto_client_config: Optional boto3 client configuration + region_name: AWS region for S3 storage + """ + self.bucket = bucket + self.prefix = prefix + + session = boto_session or boto3.Session(region_name=region_name) + + # Add strands-agents to the request user agent + if boto_client_config: + existing_user_agent = getattr(boto_client_config, "user_agent_extra", None) + # Append 'strands-agents' to existing user_agent_extra or set it if not present + if existing_user_agent: + new_user_agent = f"{existing_user_agent} strands-agents" + else: + new_user_agent = "strands-agents" + client_config = boto_client_config.merge(BotocoreConfig(user_agent_extra=new_user_agent)) + else: + client_config = BotocoreConfig(user_agent_extra="strands-agents") + + self.client = session.client(service_name="s3", config=client_config) + super().__init__(session_id=session_id, session_repository=self) + + def _get_session_path(self, session_id: str) -> str: + """Get session S3 prefix.""" + return f"{self.prefix}{SESSION_PREFIX}{session_id}/" + + def _get_agent_path(self, session_id: str, agent_id: str) -> str: + """Get agent S3 prefix.""" + session_path = self._get_session_path(session_id) + return f"{session_path}agents/{AGENT_PREFIX}{agent_id}/" + + def _get_message_path(self, session_id: str, agent_id: str, message_id: str) -> str: + """Get message S3 key.""" + agent_path = self._get_agent_path(session_id, agent_id) + return f"{agent_path}messages/{MESSAGE_PREFIX}{message_id}.json" + + def _read_s3_object(self, key: str) -> Optional[Dict[str, Any]]: + """Read JSON object from S3.""" + try: + response = self.client.get_object(Bucket=self.bucket, Key=key) + content = response["Body"].read().decode("utf-8") + return cast(dict[str, Any], json.loads(content)) + except ClientError as e: + if e.response["Error"]["Code"] == "NoSuchKey": + return None + else: + raise SessionException(f"S3 error reading {key}: {e}") from e + except json.JSONDecodeError as e: + raise SessionException(f"Invalid JSON in S3 object {key}: {e}") from e + + def _write_s3_object(self, key: str, data: Dict[str, Any]) -> None: + """Write JSON object to S3.""" + try: + content = json.dumps(data, indent=2, ensure_ascii=False) + self.client.put_object( + Bucket=self.bucket, Key=key, Body=content.encode("utf-8"), ContentType="application/json" + ) + except ClientError as e: + raise SessionException(f"Failed to write S3 object {key}: {e}") from e + + def create_session(self, session: Session) -> Session: + """Create a new session in S3.""" + session_key = f"{self._get_session_path(session.session_id)}session.json" + + # Check if session already exists + try: + self.client.head_object(Bucket=self.bucket, Key=session_key) + raise SessionException(f"Session {session.session_id} already exists") + except ClientError as e: + if e.response["Error"]["Code"] != "404": + raise SessionException(f"S3 error checking session existence: {e}") from e + + # Write session object + session_dict = asdict(session) + self._write_s3_object(session_key, session_dict) + return session + + def read_session(self, session_id: str) -> Optional[Session]: + """Read session data from S3.""" + session_key = f"{self._get_session_path(session_id)}session.json" + session_data = self._read_s3_object(session_key) + if session_data is None: + return None + return Session.from_dict(session_data) + + def delete_session(self, session_id: str) -> None: + """Delete session and all associated data from S3.""" + session_prefix = self._get_session_path(session_id) + try: + paginator = self.client.get_paginator("list_objects_v2") + pages = paginator.paginate(Bucket=self.bucket, Prefix=session_prefix) + + objects_to_delete = [] + for page in pages: + if "Contents" in page: + objects_to_delete.extend([{"Key": obj["Key"]} for obj in page["Contents"]]) + + if not objects_to_delete: + raise SessionException(f"Session {session_id} does not exist") + + # Delete objects in batches + for i in range(0, len(objects_to_delete), 1000): + batch = objects_to_delete[i : i + 1000] + self.client.delete_objects(Bucket=self.bucket, Delete={"Objects": batch}) + + except ClientError as e: + raise SessionException(f"S3 error deleting session {session_id}: {e}") from e + + def create_agent(self, session_id: str, session_agent: SessionAgent) -> None: + """Create a new agent in S3.""" + agent_id = session_agent.agent_id + agent_dict = asdict(session_agent) + agent_key = f"{self._get_agent_path(session_id, agent_id)}agent.json" + self._write_s3_object(agent_key, agent_dict) + + def read_agent(self, session_id: str, agent_id: str) -> Optional[SessionAgent]: + """Read agent data from S3.""" + agent_key = f"{self._get_agent_path(session_id, agent_id)}agent.json" + agent_data = self._read_s3_object(agent_key) + if agent_data is None: + return None + return SessionAgent.from_dict(agent_data) + + def update_agent(self, session_id: str, session_agent: SessionAgent) -> None: + """Update agent data in S3.""" + agent_id = session_agent.agent_id + previous_agent = self.read_agent(session_id=session_id, agent_id=agent_id) + if previous_agent is None: + raise SessionException(f"Agent {agent_id} in session {session_id} does not exist") + + # Preserve creation timestamp + session_agent.created_at = previous_agent.created_at + agent_key = f"{self._get_agent_path(session_id, agent_id)}agent.json" + self._write_s3_object(agent_key, asdict(session_agent)) + + def create_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None: + """Create a new message in S3.""" + message_id = session_message.message_id + message_dict = asdict(session_message) + message_key = self._get_message_path(session_id, agent_id, message_id) + self._write_s3_object(message_key, message_dict) + + def read_message(self, session_id: str, agent_id: str, message_id: str) -> Optional[SessionMessage]: + """Read message data from S3.""" + message_key = self._get_message_path(session_id, agent_id, message_id) + message_data = self._read_s3_object(message_key) + if message_data is None: + return None + return SessionMessage.from_dict(message_data) + + def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None: + """Update message data in S3.""" + message_id = session_message.message_id + previous_message = self.read_message(session_id=session_id, agent_id=agent_id, message_id=message_id) + if previous_message is None: + raise SessionException(f"Message {message_id} does not exist") + + # Preserve creation timestamp + session_message.created_at = previous_message.created_at + message_key = self._get_message_path(session_id, agent_id, message_id) + self._write_s3_object(message_key, asdict(session_message)) + + def list_messages( + self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0 + ) -> List[SessionMessage]: + """List messages for an agent with pagination from S3.""" + messages_prefix = f"{self._get_agent_path(session_id, agent_id)}messages/" + try: + paginator = self.client.get_paginator("list_objects_v2") + pages = paginator.paginate(Bucket=self.bucket, Prefix=messages_prefix) + + # Read all message objects + messages: List[SessionMessage] = [] + for page in pages: + if "Contents" in page: + for obj in page["Contents"]: + if obj["Key"].endswith(".json"): + message_data = self._read_s3_object(obj["Key"]) + if message_data: + messages.append(SessionMessage.from_dict(message_data)) + + # Sort by created_at timestamp (oldest first) + messages.sort(key=lambda x: x.created_at) + + # Apply pagination + if limit is not None: + messages = messages[offset : offset + limit] + else: + messages = messages[offset:] + + return messages + + except ClientError as e: + raise SessionException(f"S3 error reading messages: {e}") from e diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py new file mode 100644 index 000000000..984ae6f81 --- /dev/null +++ b/src/strands/session/session_manager.py @@ -0,0 +1,51 @@ +"""Session manager interface for agent session management.""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +from ..hooks.events import AgentInitializedEvent, MessageAddedEvent +from ..hooks.registry import HookProvider, HookRegistry +from ..types.content import Message + +if TYPE_CHECKING: + from ..agent.agent import Agent + + +class SessionManager(HookProvider, ABC): + """Abstract interface for managing sessions. + + A session represents a complete interaction context including conversation + history, user information, agent state, and metadata. This interface provides + methods to manage sessions and their associated data. + """ + + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: + """Register initialize and append_message as hooks for the Agent.""" + registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent)) + registry.add_callback(MessageAddedEvent, lambda event: self.append_message(event.message, event.agent)) + registry.add_callback(MessageAddedEvent, lambda event: self.sync_agent(event.agent)) + + @abstractmethod + def append_message(self, message: Message, agent: "Agent") -> None: + """Append a message to the agent's session. + + Args: + message: Message to add to the agent in the session + agent: Agent to append the message to + """ + + @abstractmethod + def sync_agent(self, agent: "Agent") -> None: + """Sync the agent to the session. + + Args: + agent: Agent to sync to the session + """ + + @abstractmethod + def initialize(self, agent: "Agent") -> None: + """Initialize an agent with a session. + + Args: + agent: Agent to initialize + """ diff --git a/src/strands/session/session_repository.py b/src/strands/session/session_repository.py new file mode 100644 index 000000000..9b6465f28 --- /dev/null +++ b/src/strands/session/session_repository.py @@ -0,0 +1,48 @@ +"""Session repository interface for agent session management.""" + +from abc import ABC, abstractmethod +from typing import Optional + +from ..types.session import Session, SessionAgent, SessionMessage + + +class SessionRepository(ABC): + """Abstract repository for creating, reading, and updating Sessions, AgentSessions, and AgentMessages.""" + + @abstractmethod + def create_session(self, session: Session) -> Session: + """Create a new Session.""" + + @abstractmethod + def read_session(self, session_id: str) -> Optional[Session]: + """Read a Session.""" + + @abstractmethod + def create_agent(self, session_id: str, session_agent: SessionAgent) -> None: + """Create a new Agent in a Session.""" + + @abstractmethod + def read_agent(self, session_id: str, agent_id: str) -> Optional[SessionAgent]: + """Read an Agent.""" + + @abstractmethod + def update_agent(self, session_id: str, session_agent: SessionAgent) -> None: + """Update an Agent.""" + + @abstractmethod + def create_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None: + """Create a new Message for the Agent.""" + + @abstractmethod + def read_message(self, session_id: str, agent_id: str, message_id: str) -> Optional[SessionMessage]: + """Read a Message.""" + + @abstractmethod + def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None: + """Update a Message.""" + + @abstractmethod + def list_messages( + self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0 + ) -> list[SessionMessage]: + """List Messages from an Agent with pagination.""" diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 1ffeba4ec..4bd3fd88e 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -52,3 +52,9 @@ def __init__(self, message: str) -> None: super().__init__(message) pass + + +class SessionException(Exception): + """Exception raised when session operations fail.""" + + pass diff --git a/src/strands/types/session.py b/src/strands/types/session.py new file mode 100644 index 000000000..2779b2866 --- /dev/null +++ b/src/strands/types/session.py @@ -0,0 +1,118 @@ +"""Data models for session management.""" + +import base64 +import inspect +from dataclasses import dataclass, field +from datetime import datetime, timezone +from enum import Enum +from typing import Any, Dict, cast +from uuid import uuid4 + +from ..agent.agent import Agent +from .content import Message + + +class SessionType(str, Enum): + """Enumeration of session types.""" + + AGENT = "AGENT" + + +def encode_bytes_values(obj: Any) -> Any: + """Recursively encode any bytes values in an object to base64. + + Handles dictionaries, lists, and nested structures. + """ + if isinstance(obj, bytes): + return {"__bytes_encoded__": True, "data": base64.b64encode(obj).decode()} + elif isinstance(obj, dict): + return {k: encode_bytes_values(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [encode_bytes_values(item) for item in obj] + else: + return obj + + +def decode_bytes_values(obj: Any) -> Any: + """Recursively decode any base64-encoded bytes values in an object. + + Handles dictionaries, lists, and nested structures. + """ + if isinstance(obj, dict): + if obj.get("__bytes_encoded__") is True and "data" in obj: + return base64.b64decode(obj["data"]) + return {k: decode_bytes_values(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [decode_bytes_values(item) for item in obj] + else: + return obj + + +@dataclass +class SessionMessage: + """Message within a SessionAgent.""" + + message: Message + message_id: str = field(default_factory=lambda: str(uuid4())) + created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + + @classmethod + def from_message(cls, message: Message) -> "SessionMessage": + """Convert from a Message, base64 encoding bytes values.""" + bytes_encoded_dict = encode_bytes_values(message) + return cls( + message=bytes_encoded_dict, + message_id=str(uuid4()), + created_at=datetime.now(timezone.utc).isoformat(), + updated_at=datetime.now(timezone.utc).isoformat(), + ) + + def to_message(self) -> Message: + """Convert SessionMessage back to a Message, decoding any bytes values.""" + return cast(Message, decode_bytes_values(self.message)) + + @classmethod + def from_dict(cls, env: dict[str, Any]) -> "SessionMessage": + """Initialize a SessionMessage from a dictionary, ignoring keys that are not calss parameters.""" + return cls(**{k: v for k, v in env.items() if k in inspect.signature(cls).parameters}) + + +@dataclass +class SessionAgent: + """Agent within a Session.""" + + agent_id: str + state: Dict[str, Any] + created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + + @classmethod + def from_agent(cls, agent: Agent) -> "SessionAgent": + """Convert an Agent to a SessionAgent.""" + if agent.agent_id is None: + raise ValueError("agent_id needs to be defined.") + return cls( + agent_id=agent.agent_id, + state=agent.state.get(), + ) + + @classmethod + def from_dict(cls, env: dict[str, Any]) -> "SessionAgent": + """Initialize a SessionAgent from a dictionary, ignoring keys that are not calss parameters.""" + return cls(**{k: v for k, v in env.items() if k in inspect.signature(cls).parameters}) + + +@dataclass +class Session: + """Session data model.""" + + session_id: str + session_type: SessionType + created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + + @classmethod + def from_dict(cls, env: dict[str, Any]) -> "Session": + """Initialize a Session from a dictionary, ignoring keys that are not calss parameters.""" + return cls(**{k: v for k, v in env.items() if k in inspect.signature(cls).parameters}) diff --git a/tests/fixtures/mock_session_repository.py b/tests/fixtures/mock_session_repository.py new file mode 100644 index 000000000..8e25691d0 --- /dev/null +++ b/tests/fixtures/mock_session_repository.py @@ -0,0 +1,98 @@ +from strands.session.session_repository import SessionRepository +from strands.types.exceptions import SessionException + + +class MockedSessionRepository(SessionRepository): + """Mock repository for testing.""" + + def __init__(self): + """Initialize with empty storage.""" + self.sessions = {} + self.agents = {} + self.messages = {} + + def create_session(self, session): + """Create a session.""" + session_id = session.session_id + if session_id in self.sessions: + raise SessionException(f"Session {session_id} already exists") + self.sessions[session_id] = session + self.agents[session_id] = {} + self.messages[session_id] = {} + return session + + def read_session(self, session_id): + """Read a session.""" + return self.sessions.get(session_id) + + def create_agent(self, session_id, session_agent): + """Create an agent.""" + agent_id = session_agent.agent_id + if session_id not in self.sessions: + raise SessionException(f"Session {session_id} does not exist") + if agent_id in self.agents.get(session_id, {}): + raise SessionException(f"Agent {agent_id} already exists in session {session_id}") + self.agents.setdefault(session_id, {})[agent_id] = session_agent + self.messages.setdefault(session_id, {}).setdefault(agent_id, []) + return session_agent + + def read_agent(self, session_id, agent_id): + """Read an agent.""" + if session_id not in self.sessions: + return None + return self.agents.get(session_id, {}).get(agent_id) + + def update_agent(self, session_id, session_agent): + """Update an agent.""" + agent_id = session_agent.agent_id + if session_id not in self.sessions: + raise SessionException(f"Session {session_id} does not exist") + if agent_id not in self.agents.get(session_id, {}): + raise SessionException(f"Agent {agent_id} does not exist in session {session_id}") + self.agents[session_id][agent_id] = session_agent + + def create_message(self, session_id, agent_id, session_message): + """Create a message.""" + if session_id not in self.sessions: + raise SessionException(f"Session {session_id} does not exist") + if agent_id not in self.agents.get(session_id, {}): + raise SessionException(f"Agent {agent_id} does not exist in session {session_id}") + self.messages.setdefault(session_id, {}).setdefault(agent_id, []).append(session_message) + + def read_message(self, session_id, agent_id, message_id): + """Read a message.""" + if session_id not in self.sessions: + return None + if agent_id not in self.agents.get(session_id, {}): + return None + for message in self.messages.get(session_id, {}).get(agent_id, []): + if message.message_id == message_id: + return message + return None + + def update_message(self, session_id, agent_id, session_message): + """Update a message.""" + message_id = session_message.message_id + if session_id not in self.sessions: + raise SessionException(f"Session {session_id} does not exist") + if agent_id not in self.agents.get(session_id, {}): + raise SessionException(f"Agent {agent_id} does not exist in session {session_id}") + + for i, message in enumerate(self.messages.get(session_id, {}).get(agent_id, [])): + if message.message_id == message_id: + self.messages[session_id][agent_id][i] = session_message + return + + raise SessionException(f"Message {message_id} does not exist") + + def list_messages(self, session_id, agent_id, limit=None, offset=0): + """List messages.""" + if session_id not in self.sessions: + return [] + if agent_id not in self.agents.get(session_id, {}): + return [] + + messages = self.messages.get(session_id, {}).get(agent_id, []) + if limit is not None: + return messages[offset : offset + limit] + return messages[offset:] diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 988e08919..b9aa15c91 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -13,10 +13,15 @@ from strands.agent import AgentResult from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager +from strands.agent.state import AgentState from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel +from strands.session.agent_session_manager import DEFAULT_SESSION_AGENT_ID, AgentSessionManager from strands.types.content import Messages from strands.types.exceptions import ContextWindowOverflowException, EventLoopException +from strands.types.session import Session, SessionAgent, SessionType +from tests.fixtures.mock_session_repository import MockedSessionRepository +from tests.fixtures.mocked_model_provider import MockedModelProvider @pytest.fixture @@ -636,7 +641,6 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator): ) agent("test") - callback_handler.assert_has_calls( [ unittest.mock.call(init_event_loop=True), @@ -1338,6 +1342,11 @@ async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tr mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception) +def test_agent_init_with_state_object(): + agent = Agent(state=AgentState({"foo": "bar"})) + assert agent.state.get("foo") == "bar" + + def test_non_dict_throws_error(): with pytest.raises(ValueError, match="state must be an AgentState object or a dict"): agent = Agent(state={"object", object()}) @@ -1391,3 +1400,28 @@ def test_agent_state_get_breaks_deep_dict_reference(): # This will fail if AgentState reflects the updated reference json.dumps(agent.state.get()) + + +def test_agent_session_management(): + mock_session_repository = MockedSessionRepository() + session_manager = AgentSessionManager(session_id="123", session_repository=mock_session_repository) + model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello!"}]}]) + agent = Agent(session_manager=session_manager, model=model) + agent("Hello!") + + +def test_agent_restored_from_session_management(): + mock_session_repository = MockedSessionRepository() + mock_session_repository.create_session(Session(session_id="123", session_type=SessionType.AGENT)) + mock_session_repository.create_agent( + "123", + SessionAgent( + agent_id=DEFAULT_SESSION_AGENT_ID, + state={"foo": "bar"}, + ), + ) + session_manager = AgentSessionManager(session_id="123", session_repository=mock_session_repository) + + agent = Agent(session_manager=session_manager) + + assert agent.state.get("foo") == "bar" diff --git a/tests/strands/session/__init__.py b/tests/strands/session/__init__.py new file mode 100644 index 000000000..601ac7006 --- /dev/null +++ b/tests/strands/session/__init__.py @@ -0,0 +1 @@ +"""Tests for session management.""" diff --git a/tests/strands/session/test_agent_session_manager.py b/tests/strands/session/test_agent_session_manager.py new file mode 100644 index 000000000..11de4fb2d --- /dev/null +++ b/tests/strands/session/test_agent_session_manager.py @@ -0,0 +1,148 @@ +"""Tests for AgentSessionManager.""" + +import pytest + +from strands.agent.agent import Agent +from strands.session.agent_session_manager import AgentSessionManager +from strands.types.content import ContentBlock +from strands.types.exceptions import SessionException +from strands.types.session import Session, SessionAgent, SessionMessage, SessionType +from tests.fixtures.mock_session_repository import MockedSessionRepository + + +@pytest.fixture +def mock_repository(): + """Create a mock repository.""" + return MockedSessionRepository() + + +@pytest.fixture +def session_manager(mock_repository): + """Create a session manager with mock repository.""" + return AgentSessionManager(session_id="test-session", session_repository=mock_repository) + + +@pytest.fixture +def agent(): + """Create a mock agent.""" + return Agent(messages=[{"role": "user", "content": [{"text": "Hello!"}]}]) + + +def test_init_creates_session_if_not_exists(mock_repository): + """Test that init creates a session if it doesn't exist.""" + # Session doesn't exist yet + assert mock_repository.read_session("test-session") is None + + # Creating manager should create session + AgentSessionManager(session_id="test-session", session_repository=mock_repository) + + # Verify session created + session = mock_repository.read_session("test-session") + assert session is not None + assert session.session_id == "test-session" + assert session.session_type == SessionType.AGENT + + +def test_init_uses_existing_session(mock_repository): + """Test that init uses existing session if it exists.""" + # Create session first + session = Session(session_id="test-session", session_type=SessionType.AGENT) + mock_repository.create_session(session) + + # Creating manager should use existing session + manager = AgentSessionManager(session_id="test-session", session_repository=mock_repository) + + # Verify session used + assert manager.session == session + + +def test_initialize_with_existing_agent_id(session_manager, agent): + """Test initializing an agent with existing agent_id.""" + # Set agent ID + agent.agent_id = "custom-agent" + + # Initialize agent + session_manager.initialize(agent) + + # Verify agent created in repository + agent_data = session_manager.session_repository.read_agent("test-session", "custom-agent") + assert agent_data is not None + assert agent_data.agent_id == "custom-agent" + + +def test_initialize_multiple_agents_without_id(session_manager, agent): + """Test initializing multiple agents without IDs.""" + # First agent initialization works + session_manager.initialize(agent) + + # Second agent with no set agent_id should fail + agent2 = Agent() + + with pytest.raises(SessionException, match="Set `agent_id` to support more than one agent in a session."): + session_manager.initialize(agent2) + + +def test_initialize_restores_existing_agent(session_manager, agent): + """Test that initializing an existing agent restores its state.""" + # Set agent ID + agent.agent_id = "existing-agent" + + # Create agent in repository first + session_agent = SessionAgent(agent_id="existing-agent", state={"key": "value"}) + session_manager.session_repository.create_agent("test-session", session_agent) + + # Create some messages + message = SessionMessage( + message={ + "role": "user", + "content": [ContentBlock(text="Hello")], + } + ) + session_manager.session_repository.create_message("test-session", "existing-agent", message) + + # Initialize agent + session_manager.initialize(agent) + + # Verify agent state restored + assert agent.state.get("key") == "value" + assert len(agent.messages) == 1 + assert agent.messages[0]["role"] == "user" + assert agent.messages[0]["content"][0]["text"] == "Hello" + + +def test_append_message(session_manager, agent): + """Test appending a message to an agent's session.""" + # Set agent ID + agent.agent_id = "test-agent" + + # Create agent in repository + session_agent = SessionAgent( + agent_id="test-agent", + state={}, + ) + session_manager.session_repository.create_agent("test-session", session_agent) + + # Create message + message = {"role": "user", "content": [{"type": "text", "text": "Hello"}]} + + # Append message + session_manager.append_message(message, agent) + + # Verify message created in repository + messages = session_manager.session_repository.list_messages("test-session", "test-agent") + assert len(messages) == 1 + assert messages[0].message["role"] == "user" + assert messages[0].message["content"][0]["text"] == "Hello" + + +def test_append_message_without_agent_id(session_manager, agent): + """Test appending a message to an agent without ID.""" + # Agent has no ID + agent.agent_id = None + + # Create message + message = {"role": "user", "content": [{"type": "text", "text": "Hello"}]} + + # Append message should fail + with pytest.raises(ValueError, match="`agent.agent_id` must be set"): + session_manager.append_message(message, agent) diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py new file mode 100644 index 000000000..f6acccc7b --- /dev/null +++ b/tests/strands/session/test_file_session_manager.py @@ -0,0 +1,317 @@ +"""Tests for FileSessionManager.""" + +import json +import os +import tempfile +from unittest.mock import patch + +import pytest + +from strands.session.file_session_manager import FileSessionManager +from strands.types.content import ContentBlock +from strands.types.exceptions import SessionException +from strands.types.session import Session, SessionAgent, SessionMessage, SessionType + + +@pytest.fixture +def temp_dir(): + """Create a temporary directory for testing.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield temp_dir + + +@pytest.fixture +def file_manager(temp_dir): + """Create FileSessionManager for testing.""" + return FileSessionManager(session_id="test", storage_dir=temp_dir) + + +@pytest.fixture +def sample_session(): + """Create sample session for testing.""" + return Session(session_id="test-session", session_type=SessionType.AGENT) + + +@pytest.fixture +def sample_agent(): + """Create sample agent for testing.""" + return SessionAgent( + agent_id="test-agent", + state={"key": "value"}, + ) + + +@pytest.fixture +def sample_message(): + """Create sample message for testing.""" + return SessionMessage( + message={ + "role": "user", + "content": [ContentBlock(text="Hello world")], + } + ) + + +class TestFileSessionManagerSessionOperations: + """Tests for session operations.""" + + def test_create_session(self, file_manager, sample_session): + """Test creating a session.""" + file_manager.create_session(sample_session) + + # Verify directory structure created + session_path = file_manager._get_session_path(sample_session.session_id) + assert os.path.exists(session_path) + + # Verify session file created + session_file = os.path.join(session_path, "session.json") + assert os.path.exists(session_file) + + # Verify content + with open(session_file, "r") as f: + data = json.load(f) + assert data["session_id"] == sample_session.session_id + assert data["session_type"] == sample_session.session_type + + def test_read_session(self, file_manager, sample_session): + """Test reading an existing session.""" + # Create session first + file_manager.create_session(sample_session) + + # Read it back + result = file_manager.read_session(sample_session.session_id) + + assert result.session_id == sample_session.session_id + assert result.session_type == sample_session.session_type + + def test_read_nonexistent_session(self, file_manager): + """Test reading a session that doesn't exist.""" + result = file_manager.read_session("nonexistent-session") + assert result is None + + def test_delete_session(self, file_manager, sample_session): + """Test deleting a session.""" + # Create session first + file_manager.create_session(sample_session) + session_path = file_manager._get_session_path(sample_session.session_id) + assert os.path.exists(session_path) + + # Delete session + file_manager.delete_session(sample_session.session_id) + + # Verify deletion + assert not os.path.exists(session_path) + + def test_delete_nonexistent_session(self, file_manager): + """Test deleting a session that doesn't exist.""" + # Should raise an error according to the implementation + with pytest.raises(SessionException, match="does not exist"): + file_manager.delete_session("nonexistent-session") + + +class TestFileSessionManagerAgentOperations: + """Tests for agent operations.""" + + def test_create_agent(self, file_manager, sample_session, sample_agent): + """Test creating an agent in a session.""" + # Create session first + file_manager.create_session(sample_session) + + # Create agent + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Verify directory structure + agent_path = file_manager._get_agent_path(sample_session.session_id, sample_agent.agent_id) + assert os.path.exists(agent_path) + + # Verify agent file + agent_file = os.path.join(agent_path, "agent.json") + assert os.path.exists(agent_file) + + # Verify content + with open(agent_file, "r") as f: + data = json.load(f) + assert data["agent_id"] == sample_agent.agent_id + assert data["state"] == sample_agent.state + + def test_read_agent(self, file_manager, sample_session, sample_agent): + """Test reading an agent from a session.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Read agent + result = file_manager.read_agent(sample_session.session_id, sample_agent.agent_id) + + assert result.agent_id == sample_agent.agent_id + assert result.state == sample_agent.state + + def test_read_nonexistent_agent(self, file_manager, sample_session): + """Test reading an agent that doesn't exist.""" + result = file_manager.read_agent(sample_session.session_id, "nonexistent_agent") + assert result is None + + def test_update_agent(self, file_manager, sample_session, sample_agent): + """Test updating an agent.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Update agent + sample_agent.state = {"updated": "value"} + file_manager.update_agent(sample_session.session_id, sample_agent) + + # Verify update + result = file_manager.read_agent(sample_session.session_id, sample_agent.agent_id) + assert result.state == {"updated": "value"} + + +class TestFileSessionManagerMessageOperations: + """Tests for message operations.""" + + def test_create_message(self, file_manager, sample_session, sample_agent, sample_message): + """Test creating a message for an agent.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Create message + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + # Verify message file + message_path = file_manager._get_message_path( + sample_session.session_id, sample_agent.agent_id, sample_message.message_id + ) + assert os.path.exists(message_path) + + # Verify content + with open(message_path, "r") as f: + data = json.load(f) + assert data["message_id"] == sample_message.message_id + + def test_read_message(self, file_manager, sample_session, sample_agent, sample_message): + """Test reading a message.""" + # Create session, agent, and message + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + # Read message + result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id) + + assert result.message_id == sample_message.message_id + assert result.message["role"] == sample_message.message["role"] + assert result.message["content"] == sample_message.message["content"] + + def test_read_nonexistent_message(self, file_manager, sample_session, sample_agent): + """Test reading a message that doesnt exist.""" + result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message") + assert result is None + + def test_list_messages_all(self, file_manager, sample_session, sample_agent): + """Test listing all messages for an agent.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Create multiple messages + messages = [] + for i in range(5): + message = SessionMessage( + message={ + "role": "user", + "content": [ContentBlock(text=f"Message {i}")], + } + ) + messages.append(message) + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # List all messages + result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id) + + assert len(result) == 5 + + def test_list_messages_with_limit(self, file_manager, sample_session, sample_agent): + """Test listing messages with limit.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Create multiple messages + for i in range(10): + message = SessionMessage( + message={ + "role": "user", + "content": [ContentBlock(text=f"Message {i}")], + } + ) + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # List with limit + result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id, limit=3) + + assert len(result) == 3 + + def test_list_messages_with_offset(self, file_manager, sample_session, sample_agent): + """Test listing messages with offset.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Create multiple messages + for i in range(10): + message = SessionMessage( + message={ + "role": "user", + "content": [ContentBlock(text=f"Message {i}")], + } + ) + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # List with offset + result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id, offset=5) + + assert len(result) == 5 + + def test_update_message(self, file_manager, sample_session, sample_agent, sample_message): + """Test updating a message.""" + # Create session, agent, and message + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + # Update message + sample_message.message["content"] = [ContentBlock(text="Updated content")] + file_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + # Verify update + result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id) + assert result.message["content"][0]["text"] == "Updated content" + + # Note: delete_message is not implemented in FileSessionManager + + +class TestFileSessionManagerErrorHandling: + """Tests for error handling scenarios.""" + + def test_corrupted_json_file(self, file_manager, temp_dir): + """Test handling of corrupted JSON files.""" + # Create a corrupted session file + session_path = os.path.join(temp_dir, "session_test") + os.makedirs(session_path, exist_ok=True) + session_file = os.path.join(session_path, "session.json") + + with open(session_file, "w") as f: + f.write("invalid json content") + + # Should raise SessionException + with pytest.raises(SessionException, match="Invalid JSON"): + file_manager._read_file(session_file) + + def test_permission_error_handling(self, file_manager): + """Test handling of permission errors.""" + with patch("builtins.open", side_effect=PermissionError("Access denied")): + session = Session(session_id="test", session_type=SessionType.AGENT) + + with pytest.raises(SessionException): + file_manager.create_session(session) diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py new file mode 100644 index 000000000..59a92a6ec --- /dev/null +++ b/tests/strands/session/test_s3_session_manager.py @@ -0,0 +1,329 @@ +"""Tests for S3SessionManager.""" + +import json + +import boto3 +import pytest +from botocore.config import Config as BotocoreConfig +from botocore.exceptions import ClientError +from moto import mock_aws + +from strands.session.s3_session_manager import S3SessionManager +from strands.types.content import ContentBlock +from strands.types.exceptions import SessionException +from strands.types.session import Session, SessionAgent, SessionMessage, SessionType + + +@pytest.fixture +def mocked_aws(): + """ + Mock all AWS interactions + Requires you to create your own boto3 clients + """ + with mock_aws(): + yield + + +@pytest.fixture(scope="function") +def s3_bucket(mocked_aws): + """S3 bucket name for testing.""" + # Create the bucket + s3_client = boto3.client("s3", region_name="us-west-2") + s3_client.create_bucket(Bucket="test-session-bucket", CreateBucketConfiguration={"LocationConstraint": "us-west-2"}) + return "test-session-bucket" + + +@pytest.fixture +def s3_manager(mocked_aws, s3_bucket): + """Create S3SessionManager with mocked S3.""" + yield S3SessionManager(session_id="test", bucket=s3_bucket, prefix="sessions/", region_name="us-west-2") + + +@pytest.fixture +def sample_session(): + """Create sample session for testing.""" + return Session( + session_id="test-session-123", + session_type=SessionType.AGENT, + ) + + +@pytest.fixture +def sample_agent(): + """Create sample agent for testing.""" + return SessionAgent( + agent_id="test-agent-456", + state={"key": "value"}, + ) + + +@pytest.fixture +def sample_message(): + """Create sample message for testing.""" + return SessionMessage( + message={ + "role": "user", + "content": [ContentBlock(text="test_message")], + } + ) + + +def test_init_s3_session_manager(mocked_aws, s3_bucket): + session_manager = S3SessionManager(session_id="test", bucket=s3_bucket) + assert "strands-agents" in session_manager.client.meta.config.user_agent_extra + + +def test_init_s3_session_manager_with_config(mocked_aws, s3_bucket): + session_manager = S3SessionManager(session_id="test", bucket=s3_bucket, boto_client_config=BotocoreConfig()) + assert "strands-agents" in session_manager.client.meta.config.user_agent_extra + + +def test_init_s3_session_manager_with_existing_user_agent(mocked_aws, s3_bucket): + session_manager = S3SessionManager( + session_id="test", bucket=s3_bucket, boto_client_config=BotocoreConfig(user_agent_extra="test") + ) + assert "strands-agents" in session_manager.client.meta.config.user_agent_extra + + +def test_create_session(s3_manager, sample_session): + """Test creating a session in S3.""" + result = s3_manager.create_session(sample_session) + + assert result == sample_session + + # Verify S3 object created + key = f"{s3_manager._get_session_path(sample_session.session_id)}session.json" + response = s3_manager.client.get_object(Bucket=s3_manager.bucket, Key=key) + data = json.loads(response["Body"].read().decode("utf-8")) + + assert data["session_id"] == sample_session.session_id + assert data["session_type"] == sample_session.session_type + + +def test_create_session_already_exists(s3_manager, sample_session): + """Test creating a session in S3.""" + s3_manager.create_session(sample_session) + + with pytest.raises(SessionException): + s3_manager.create_session(sample_session) + + +def test_read_session(s3_manager, sample_session): + """Test reading a session from S3.""" + # Create session first + s3_manager.create_session(sample_session) + + # Read it back + result = s3_manager.read_session(sample_session.session_id) + + assert result.session_id == sample_session.session_id + assert result.session_type == sample_session.session_type + + +def test_read_nonexistent_session(s3_manager): + """Test reading a session that doesn't exist in S3.""" + with mock_aws(): + result = s3_manager.read_session("nonexistent-session") + assert result is None + + +def test_delete_session(s3_manager, sample_session): + """Test deleting a session from S3.""" + # Create session first + s3_manager.create_session(sample_session) + + # Verify session exists + key = f"{s3_manager._get_session_path(sample_session.session_id)}session.json" + s3_manager.client.head_object(Bucket=s3_manager.bucket, Key=key) + + # Delete session + s3_manager.delete_session(sample_session.session_id) + + # Verify deletion + with pytest.raises(ClientError) as excinfo: + s3_manager.client.head_object(Bucket=s3_manager.bucket, Key=key) + assert excinfo.value.response["Error"]["Code"] == "404" + + +def test_create_agent(s3_manager, sample_session, sample_agent): + """Test creating an agent in S3.""" + # Create session first + s3_manager.create_session(sample_session) + + # Create agent + s3_manager.create_agent(sample_session.session_id, sample_agent) + + # Verify S3 object created + key = f"{s3_manager._get_agent_path(sample_session.session_id, sample_agent.agent_id)}agent.json" + response = s3_manager.client.get_object(Bucket=s3_manager.bucket, Key=key) + data = json.loads(response["Body"].read().decode("utf-8")) + + assert data["agent_id"] == sample_agent.agent_id + assert data["state"] == sample_agent.state + + +def test_read_agent(s3_manager, sample_session, sample_agent): + """Test reading an agent from S3.""" + # Create session and agent + s3_manager.create_session(sample_session) + s3_manager.create_agent(sample_session.session_id, sample_agent) + + # Read agent + result = s3_manager.read_agent(sample_session.session_id, sample_agent.agent_id) + + assert result.agent_id == sample_agent.agent_id + assert result.state == sample_agent.state + + +def test_read_nonexistent_agent(s3_manager, sample_session, sample_agent): + """Test reading an agent from S3.""" + # Create session and agent + s3_manager.create_session(sample_session) + # Read agent + result = s3_manager.read_agent(sample_session.session_id, "nonexistent_agent") + + assert result is None + + +def test_update_agent(s3_manager, sample_session, sample_agent): + """Test updating an agent in S3.""" + # Create session and agent + s3_manager.create_session(sample_session) + s3_manager.create_agent(sample_session.session_id, sample_agent) + + # Update agent + sample_agent.state = {"updated": "value"} + s3_manager.update_agent(sample_session.session_id, sample_agent) + + # Verify update + result = s3_manager.read_agent(sample_session.session_id, sample_agent.agent_id) + assert result.state == {"updated": "value"} + + +def test_update_nonexistent_agent(s3_manager, sample_session, sample_agent): + """Test updating an agent in S3.""" + # Create session and agent + s3_manager.create_session(sample_session) + + with pytest.raises(SessionException): + s3_manager.update_agent(sample_session.session_id, sample_agent) + + +def test_create_message(s3_manager, sample_session, sample_agent, sample_message): + """Test creating a message in S3.""" + # Create session and agent + s3_manager.create_session(sample_session) + s3_manager.create_agent(sample_session.session_id, sample_agent) + + # Create message + s3_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + # Verify S3 object created + key = s3_manager._get_message_path(sample_session.session_id, sample_agent.agent_id, sample_message.message_id) + response = s3_manager.client.get_object(Bucket=s3_manager.bucket, Key=key) + data = json.loads(response["Body"].read().decode("utf-8")) + + assert data["message_id"] == sample_message.message_id + + +def test_read_message(s3_manager, sample_session, sample_agent, sample_message): + """Test reading a message from S3.""" + # Create session, agent, and message + s3_manager.create_session(sample_session) + s3_manager.create_agent(sample_session.session_id, sample_agent) + s3_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + # Read message + result = s3_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id) + + assert result.message_id == sample_message.message_id + assert result.message["role"] == sample_message.message["role"] + assert result.message["content"] == sample_message.message["content"] + + +def test_read_nonexistent_message(s3_manager, sample_session, sample_agent, sample_message): + """Test reading a message from S3.""" + # Create session, agent, and message + s3_manager.create_session(sample_session) + s3_manager.create_agent(sample_session.session_id, sample_agent) + + # Read message + result = s3_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message") + + assert result is None + + +def test_list_messages_all(s3_manager, sample_session, sample_agent): + """Test listing all messages from S3.""" + # Create session and agent + s3_manager.create_session(sample_session) + s3_manager.create_agent(sample_session.session_id, sample_agent) + + # Create multiple messages + messages = [] + for i in range(5): + message = SessionMessage( + { + "role": "user", + "content": [ContentBlock(text=f"Message {i}")], + } + ) + messages.append(message) + s3_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # List all messages + result = s3_manager.list_messages(sample_session.session_id, sample_agent.agent_id) + + assert len(result) == 5 + + +def test_list_messages_with_pagination(s3_manager, sample_session, sample_agent): + """Test listing messages with pagination in S3.""" + # Create session and agent + s3_manager.create_session(sample_session) + s3_manager.create_agent(sample_session.session_id, sample_agent) + + # Create multiple messages + for _ in range(10): + message = SessionMessage( + message={ + "role": "user", + "content": [ContentBlock(text="test_message")], + } + ) + s3_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # List with limit + result = s3_manager.list_messages(sample_session.session_id, sample_agent.agent_id, limit=3) + assert len(result) == 3 + + # List with offset + result = s3_manager.list_messages(sample_session.session_id, sample_agent.agent_id, offset=5) + assert len(result) == 5 + + +def test_update_message(s3_manager, sample_session, sample_agent, sample_message): + """Test updating a message in S3.""" + # Create session, agent, and message + s3_manager.create_session(sample_session) + s3_manager.create_agent(sample_session.session_id, sample_agent) + s3_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + # Update message + sample_message.message["content"] = [ContentBlock(text="Updated content")] + s3_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + # Verify update + result = s3_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id) + assert result.message["content"][0]["text"] == "Updated content" + + +def test_update_nonexistent_message(s3_manager, sample_session, sample_agent, sample_message): + """Test updating a message in S3.""" + # Create session, agent, and message + s3_manager.create_session(sample_session) + s3_manager.create_agent(sample_session.session_id, sample_agent) + + # Update message + with pytest.raises(SessionException): + s3_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) diff --git a/tests/strands/types/test_session.py b/tests/strands/types/test_session.py new file mode 100644 index 000000000..596fe337c --- /dev/null +++ b/tests/strands/types/test_session.py @@ -0,0 +1,91 @@ +import json +from dataclasses import asdict +from uuid import uuid4 + +from strands.types.session import ( + Session, + SessionAgent, + SessionMessage, + SessionType, + decode_bytes_values, + encode_bytes_values, +) + + +def test_session_json_serializable(): + session = Session(session_id=str(uuid4()), session_type=SessionType.AGENT) + # json dumps will fail if its not json serializable + session_json_string = json.dumps(asdict(session)) + loaded_session = Session.from_dict(json.loads(session_json_string)) + assert loaded_session is not None + + +def test_agent_json_serializable(): + agent = SessionAgent(agent_id=str(uuid4()), state={"foo": "bar"}) + # json dumps will fail if its not json serializable + agent_json_string = json.dumps(asdict(agent)) + loaded_agent = SessionAgent.from_dict(json.loads(agent_json_string)) + assert loaded_agent is not None + + +def test_message_json_serializable(): + message = SessionMessage(message={"role": "user", "content": [{"text": "Hello!"}]}) + # json dumps will fail if its not json serializable + message_json_string = json.dumps(asdict(message)) + loaded_message = SessionMessage.from_dict(json.loads(message_json_string)) + assert loaded_message is not None + + +def test_bytes_encoding_decoding(): + # Test simple bytes + test_bytes = b"Hello, world!" + encoded = encode_bytes_values(test_bytes) + assert isinstance(encoded, dict) + assert encoded["__bytes_encoded__"] is True + decoded = decode_bytes_values(encoded) + assert decoded == test_bytes + + # Test nested structure with bytes + test_data = { + "text": "Hello", + "binary": b"Binary data", + "nested": {"more_binary": b"More binary data", "list_with_binary": [b"Item 1", "Text item", b"Item 3"]}, + } + + encoded = encode_bytes_values(test_data) + # Verify it's JSON serializable + json_str = json.dumps(encoded) + # Deserialize and decode + decoded = decode_bytes_values(json.loads(json_str)) + + # Verify the decoded data matches the original + assert decoded["text"] == test_data["text"] + assert decoded["binary"] == test_data["binary"] + assert decoded["nested"]["more_binary"] == test_data["nested"]["more_binary"] + assert decoded["nested"]["list_with_binary"][0] == test_data["nested"]["list_with_binary"][0] + assert decoded["nested"]["list_with_binary"][1] == test_data["nested"]["list_with_binary"][1] + assert decoded["nested"]["list_with_binary"][2] == test_data["nested"]["list_with_binary"][2] + + +def test_session_message_with_bytes(): + # Create a message with bytes content + message = { + "role": "user", + "content": [{"text": "Here is some binary data"}, {"binary_data": b"This is binary data"}], + } + + # Create a SessionMessage + session_message = SessionMessage.from_message(message) + + # Verify it's JSON serializable + message_json_string = json.dumps(asdict(session_message)) + + # Load it back + loaded_message = SessionMessage.from_dict(json.loads(message_json_string)) + + # Convert back to original message and verify + original_message = loaded_message.to_message() + + assert original_message["role"] == message["role"] + assert original_message["content"][0]["text"] == message["content"][0]["text"] + assert original_message["content"][1]["binary_data"] == message["content"][1]["binary_data"] diff --git a/tests_integ/test_session.py b/tests_integ/test_session.py new file mode 100644 index 000000000..fbfd54384 --- /dev/null +++ b/tests_integ/test_session.py @@ -0,0 +1,123 @@ +"""Integration tests for session management.""" + +import tempfile +from uuid import uuid4 + +import boto3 +import pytest +from botocore.client import ClientError + +from strands import Agent +from strands.session.file_session_manager import FileSessionManager +from strands.session.s3_session_manager import S3SessionManager + + +@pytest.fixture +def yellow_img(pytestconfig): + path = pytestconfig.rootdir / "tests_integ/yellow.png" + with open(path, "rb") as fp: + return fp.read() + + +@pytest.fixture +def temp_dir(): + """Create a temporary directory for testing.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield temp_dir + + +@pytest.fixture +def bucket_name(): + bucket_name = f"test-strands-session-bucket-{boto3.client('sts').get_caller_identity()['Account']}" + s3_client = boto3.resource("s3", region_name="us-west-2") + try: + s3_client.create_bucket(Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": "us-west-2"}) + except ClientError as e: + if "BucketAlreadyOwnedByYou" not in str(e): + raise e + yield bucket_name + + +def test_agent_with_file_session(temp_dir): + # Set up the session manager and add an agent + test_session_id = str(uuid4()) + # Create a session + session_manager = FileSessionManager(session_id=test_session_id, storage_dir=temp_dir) + try: + agent = Agent(session_manager=session_manager) + agent("Hello!") + assert len(session_manager.list_messages(test_session_id, agent.agent_id)) == 2 + + # After agent is persisted and run, restore the agent and run it again + session_manager_2 = FileSessionManager(session_id=test_session_id, storage_dir=temp_dir) + agent_2 = Agent(session_manager=session_manager_2) + assert len(agent_2.messages) == 2 + agent_2("Hello!") + assert len(agent_2.messages) == 4 + assert len(session_manager_2.list_messages(test_session_id, agent_2.agent_id)) == 4 + finally: + # Delete the session + session_manager.delete_session(test_session_id) + assert session_manager.read_session(test_session_id) is None + + +def test_agent_with_file_session_with_image(temp_dir, yellow_img): + test_session_id = str(uuid4()) + # Create a session + session_manager = FileSessionManager(session_id=test_session_id, storage_dir=temp_dir) + try: + agent = Agent(session_manager=session_manager) + agent([{"image": {"format": "png", "source": {"bytes": yellow_img}}}]) + assert len(session_manager.list_messages(test_session_id, agent.agent_id)) == 2 + + # After agent is persisted and run, restore the agent and run it again + session_manager_2 = FileSessionManager(session_id=test_session_id, storage_dir=temp_dir) + agent_2 = Agent(session_manager=session_manager_2) + assert len(agent_2.messages) == 2 + agent_2("Hello!") + assert len(agent_2.messages) == 4 + assert len(session_manager_2.list_messages(test_session_id, agent_2.agent_id)) == 4 + finally: + # Delete the session + session_manager.delete_session(test_session_id) + assert session_manager.read_session(test_session_id) is None + + +def test_agent_with_s3_session(bucket_name): + test_session_id = str(uuid4()) + session_manager = S3SessionManager(session_id=test_session_id, bucket=bucket_name, region_name="us-west-2") + try: + agent = Agent(session_manager=session_manager) + agent("Hello!") + assert len(session_manager.list_messages(test_session_id, agent.agent_id)) == 2 + + # After agent is persisted and run, restore the agent and run it again + session_manager_2 = S3SessionManager(session_id=test_session_id, bucket=bucket_name, region_name="us-west-2") + agent_2 = Agent(session_manager=session_manager_2) + assert len(agent_2.messages) == 2 + agent_2("Hello!") + assert len(agent_2.messages) == 4 + assert len(session_manager_2.list_messages(test_session_id, agent_2.agent_id)) == 4 + finally: + session_manager.delete_session(test_session_id) + assert session_manager.read_session(test_session_id) is None + + +def test_agent_with_s3_session_with_image(yellow_img, bucket_name): + test_session_id = str(uuid4()) + session_manager = S3SessionManager(session_id=test_session_id, bucket=bucket_name, region_name="us-west-2") + try: + agent = Agent(session_manager=session_manager) + agent([{"image": {"format": "png", "source": {"bytes": yellow_img}}}]) + assert len(session_manager.list_messages(test_session_id, agent.agent_id)) == 2 + + # After agent is persisted and run, restore the agent and run it again + session_manager_2 = S3SessionManager(session_id=test_session_id, bucket=bucket_name, region_name="us-west-2") + agent_2 = Agent(session_manager=session_manager_2) + assert len(agent_2.messages) == 2 + agent_2("Hello!") + assert len(agent_2.messages) == 4 + assert len(session_manager_2.list_messages(test_session_id, agent_2.agent_id)) == 4 + finally: + session_manager.delete_session(test_session_id) + assert session_manager.read_session(test_session_id) is None From 0bc4cd0168fbdcfc973d83405abdd1542a6dfa3f Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Sun, 13 Jul 2025 16:20:26 +0000 Subject: [PATCH 2/2] refactor: add pr feedback --- src/strands/agent/agent.py | 6 +- src/strands/session/file_session_manager.py | 79 +++++++++++----- ...nager.py => repository_session_manager.py} | 42 +++++---- src/strands/session/s3_session_manager.py | 89 ++++++++++++++----- src/strands/session/session_manager.py | 13 +-- src/strands/session/session_repository.py | 5 +- src/strands/types/session.py | 10 ++- tests/strands/agent/test_agent.py | 8 +- .../session/test_file_session_manager.py | 45 +++++++++- ....py => test_repository_session_manager.py} | 28 ++---- .../session/test_s3_session_manager.py | 4 +- 11 files changed, 228 insertions(+), 101 deletions(-) rename src/strands/session/{agent_session_manager.py => repository_session_manager.py} (67%) rename tests/strands/session/{test_agent_session_manager.py => test_repository_session_manager.py} (80%) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 4925abfca..54e7a58ec 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -318,9 +318,9 @@ def __init__( self.hooks = HookRegistry() # Initialize session management functionality - self.session_manager = session_manager - if self.session_manager: - self.hooks.add_hook(self.session_manager) + self._session_manager = session_manager + if self._session_manager: + self.hooks.add_hook(self._session_manager) if hooks: for hook in hooks: diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index 615fad7fc..2748f2e20 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -10,7 +10,7 @@ from ..types.exceptions import SessionException from ..types.session import Session, SessionAgent, SessionMessage -from .agent_session_manager import AgentSessionManager +from .repository_session_manager import RepositorySessionManager from .session_repository import SessionRepository logger = logging.getLogger(__name__) @@ -20,8 +20,21 @@ MESSAGE_PREFIX = "message_" -class FileSessionManager(AgentSessionManager, SessionRepository): - """File-based session manager for local filesystem storage.""" +class FileSessionManager(RepositorySessionManager, SessionRepository): + """File-based session manager for local filesystem storage. + + Creates the following filesystem structure for the session storage: + // + └── session_/ + ├── session.json # Session metadata + └── agents/ + └── agent_/ + ├── agent.json # Agent metadata + └── messages/ + ├── message__.json + └── message__.json + + """ def __init__(self, session_id: str, storage_dir: Optional[str] = None): """Initialize FileSession with filesystem storage. @@ -44,10 +57,22 @@ def _get_agent_path(self, session_id: str, agent_id: str) -> str: session_path = self._get_session_path(session_id) return os.path.join(session_path, "agents", f"{AGENT_PREFIX}{agent_id}") - def _get_message_path(self, session_id: str, agent_id: str, message_id: str) -> str: - """Get message file path.""" + def _get_message_path(self, session_id: str, agent_id: str, message_id: str, timestamp: str) -> str: + """Get message file path. + + Args: + session_id: ID of the session + agent_id: ID of the agent + message_id: ID of the message + timestamp: ISO format timestamp to include in filename for sorting + Returns: + The filename for the message + """ agent_path = self._get_agent_path(session_id, agent_id) - return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{message_id}.json") + # Use timestamp for sortable filenames + # Replace colons and periods in ISO format with underscores for filesystem compatibility + filename_timestamp = timestamp.replace(":", "_").replace(".", "_") + return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{filename_timestamp}_{message_id}.json") def _read_file(self, path: str) -> dict[str, Any]: """Read JSON file.""" @@ -135,17 +160,26 @@ def create_message(self, session_id: str, agent_id: str, session_message: Sessio session_id, agent_id, session_message.message_id, + session_message.created_at, ) session_dict = asdict(session_message) self._write_file(message_file, session_dict) def read_message(self, session_id: str, agent_id: str, message_id: str) -> Optional[SessionMessage]: """Read message data.""" - message_file = self._get_message_path(session_id, agent_id, message_id) - if not os.path.exists(message_file): + # Get the messages directory + messages_dir = os.path.join(self._get_agent_path(session_id, agent_id), "messages") + if not os.path.exists(messages_dir): return None - message_data = self._read_file(message_file) - return SessionMessage.from_dict(message_data) + + # List files in messages directory, and check if the filename ends with the message id + for filename in os.listdir(messages_dir): + if filename.endswith(f"{message_id}.json"): + file_path = os.path.join(messages_dir, filename) + message_data = self._read_file(file_path) + return SessionMessage.from_dict(message_data) + + return None def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None: """Update message data.""" @@ -156,7 +190,7 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio # Preserve the original created_at timestamp session_message.created_at = previous_message.created_at - message_file = self._get_message_path(session_id, agent_id, message_id) + message_file = self._get_message_path(session_id, agent_id, message_id, session_message.created_at) self._write_file(message_file, asdict(session_message)) def list_messages( @@ -168,20 +202,25 @@ def list_messages( raise SessionException(f"Messages directory missing from agent: {agent_id} in session {session_id}") # Read all message files - messages: list[SessionMessage] = [] + message_files: list[str] = [] for filename in os.listdir(messages_dir): if filename.startswith(MESSAGE_PREFIX) and filename.endswith(".json"): - file_path = os.path.join(messages_dir, filename) - message_data = self._read_file(file_path) - messages.append(SessionMessage.from_dict(message_data)) + message_files.append(filename) - # Sort by created_at timestamp (oldest first) - messages.sort(key=lambda x: x.created_at) + # Sort filenames - the timestamp in the file's name will sort chronologically + message_files.sort() - # Apply pagination + # Apply pagination to filenames if limit is not None: - messages = messages[offset : offset + limit] + message_files = message_files[offset : offset + limit] else: - messages = messages[offset:] + message_files = message_files[offset:] + + # Load only the message files + messages: list[SessionMessage] = [] + for filename in message_files: + file_path = os.path.join(messages_dir, filename) + message_data = self._read_file(file_path) + messages.append(SessionMessage.from_dict(message_data)) return messages diff --git a/src/strands/session/agent_session_manager.py b/src/strands/session/repository_session_manager.py similarity index 67% rename from src/strands/session/agent_session_manager.py rename to src/strands/session/repository_session_manager.py index 4e0d86dec..fd31d9671 100644 --- a/src/strands/session/agent_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -1,8 +1,8 @@ -"""Agent session manager implementation.""" +"""Repository session manager implementation.""" import logging -from ..agent.agent import _DEFAULT_AGENT_ID, Agent +from ..agent.agent import Agent from ..agent.state import AgentState from ..types.content import Message from ..types.exceptions import SessionException @@ -17,29 +17,38 @@ logger = logging.getLogger(__name__) -DEFAULT_SESSION_AGENT_ID = "default" - -class AgentSessionManager(SessionManager): - """Session manager for persisting agent's in a Session.""" +class RepositorySessionManager(SessionManager): + """Session manager for persisting agents in a SessionRepository.""" def __init__( self, session_id: str, session_repository: SessionRepository, ): - """Initialize the AgentSessionManager.""" + """Initialize the RepositorySessionManager. + + If no session with the specified session_id exists yet, it will be created + in the session_repository. + + Args: + session_id: ID to use for the session. A new session with this id will be created if it does + not exist in the reposiory yet + session_repository: Underlying session repository to use to store the sessions state. + """ self.session_repository = session_repository self.session_id = session_id session = session_repository.read_session(session_id) # Create a session if it does not exist yet if session is None: - logger.debug("session_id=<%s> | Session not found, creating new session.", self.session_id) + logger.debug("session_id=<%s> | session not found, creating new session", self.session_id) session = Session(session_id=session_id, session_type=SessionType.AGENT) session_repository.create_session(session) self.session = session - self._default_agent_initialized = False + + # Keep track of the initialized agent id's so that two agents in a session cannot share an id + self._initialized_agent_ids: set[str] = set() def append_message(self, message: Message, agent: Agent) -> None: """Append a message to the agent's session. @@ -49,12 +58,10 @@ def append_message(self, message: Message, agent: Agent) -> None: agent: Agent to append the message to """ session_message = SessionMessage.from_message(message) - if agent.agent_id is None: - raise ValueError("`agent.agent_id` must be set before appending message to session.") self.session_repository.create_message(self.session_id, agent.agent_id, session_message) def sync_agent(self, agent: Agent) -> None: - """Sync agent to the session. + """Serialize and update the agent into the session repository. Args: agent: Agent to sync to the session. @@ -70,16 +77,15 @@ def initialize(self, agent: Agent) -> None: Args: agent: Agent to initialize from the session """ - if agent.agent_id is _DEFAULT_AGENT_ID: - if self._default_agent_initialized: - raise SessionException("Set `agent_id` to support more than one agent in a session.") - self._default_agent_initialized = True + if agent.agent_id in self._initialized_agent_ids: + raise SessionException("The `agent_id` of an agent must be unique in a session.") + self._initialized_agent_ids.add(agent.agent_id) session_agent = self.session_repository.read_agent(self.session_id, agent.agent_id) if session_agent is None: logger.debug( - "agent_id=<%s> | session_id=<%s> | Creating agent.", + "agent_id=<%s> | session_id=<%s> | creating agent", agent.agent_id, self.session_id, ) @@ -91,7 +97,7 @@ def initialize(self, agent: Agent) -> None: self.session_repository.create_message(self.session_id, agent.agent_id, session_message) else: logger.debug( - "agent_id=<%s> | session_id=<%s> | Restoring agent.", + "agent_id=<%s> | session_id=<%s> | restoring agent", agent.agent_id, self.session_id, ) diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index 013339e22..af14c5386 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -11,7 +11,7 @@ from ..types.exceptions import SessionException from ..types.session import Session, SessionAgent, SessionMessage -from .agent_session_manager import AgentSessionManager +from .repository_session_manager import RepositorySessionManager from .session_repository import SessionRepository logger = logging.getLogger(__name__) @@ -21,8 +21,21 @@ MESSAGE_PREFIX = "message_" -class S3SessionManager(AgentSessionManager, SessionRepository): - """S3-based session manager for cloud storage.""" +class S3SessionManager(RepositorySessionManager, SessionRepository): + """S3-based session manager for cloud storage. + + Creates the following filesystem structure for the session storage: + // + └── session_/ + ├── session.json # Session metadata + └── agents/ + └── agent_/ + ├── agent.json # Agent metadata + └── messages/ + ├── message__.json + └── message__.json + + """ def __init__( self, @@ -72,10 +85,22 @@ def _get_agent_path(self, session_id: str, agent_id: str) -> str: session_path = self._get_session_path(session_id) return f"{session_path}agents/{AGENT_PREFIX}{agent_id}/" - def _get_message_path(self, session_id: str, agent_id: str, message_id: str) -> str: - """Get message S3 key.""" + def _get_message_path(self, session_id: str, agent_id: str, message_id: str, timestamp: str) -> str: + """Get message S3 key. + + Args: + session_id: ID of the session + agent_id: ID of the agent + message_id: ID of the message + timestamp: ISO format timestamp to include in key for sorting + Returns: + The key for the message + """ agent_path = self._get_agent_path(session_id, agent_id) - return f"{agent_path}messages/{MESSAGE_PREFIX}{message_id}.json" + # Use timestamp for sortable keys + # Replace colons and periods in ISO format with underscores for filesystem compatibility + filename_timestamp = timestamp.replace(":", "_").replace(".", "_") + return f"{agent_path}messages/{MESSAGE_PREFIX}{filename_timestamp}_{message_id}.json" def _read_s3_object(self, key: str) -> Optional[Dict[str, Any]]: """Read JSON object from S3.""" @@ -180,16 +205,29 @@ def create_message(self, session_id: str, agent_id: str, session_message: Sessio """Create a new message in S3.""" message_id = session_message.message_id message_dict = asdict(session_message) - message_key = self._get_message_path(session_id, agent_id, message_id) + message_key = self._get_message_path(session_id, agent_id, message_id, session_message.created_at) self._write_s3_object(message_key, message_dict) def read_message(self, session_id: str, agent_id: str, message_id: str) -> Optional[SessionMessage]: """Read message data from S3.""" - message_key = self._get_message_path(session_id, agent_id, message_id) - message_data = self._read_s3_object(message_key) - if message_data is None: + # Get the messages prefix + messages_prefix = f"{self._get_agent_path(session_id, agent_id)}messages/" + try: + paginator = self.client.get_paginator("list_objects_v2") + pages = paginator.paginate(Bucket=self.bucket, Prefix=messages_prefix) + + for page in pages: + if "Contents" in page: + for obj in page["Contents"]: + if obj["Key"].endswith(f"{message_id}.json"): + message_data = self._read_s3_object(obj["Key"]) + if message_data: + return SessionMessage.from_dict(message_data) + return None - return SessionMessage.from_dict(message_data) + + except ClientError as e: + raise SessionException(f"S3 error reading message: {e}") from e def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None: """Update message data in S3.""" @@ -200,7 +238,7 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio # Preserve creation timestamp session_message.created_at = previous_message.created_at - message_key = self._get_message_path(session_id, agent_id, message_id) + message_key = self._get_message_path(session_id, agent_id, message_id, session_message.created_at) self._write_s3_object(message_key, asdict(session_message)) def list_messages( @@ -212,24 +250,29 @@ def list_messages( paginator = self.client.get_paginator("list_objects_v2") pages = paginator.paginate(Bucket=self.bucket, Prefix=messages_prefix) - # Read all message objects - messages: List[SessionMessage] = [] + # Collect all message keys first + message_keys = [] for page in pages: if "Contents" in page: for obj in page["Contents"]: - if obj["Key"].endswith(".json"): - message_data = self._read_s3_object(obj["Key"]) - if message_data: - messages.append(SessionMessage.from_dict(message_data)) + if obj["Key"].endswith(".json") and MESSAGE_PREFIX in obj["Key"]: + message_keys.append(obj["Key"]) - # Sort by created_at timestamp (oldest first) - messages.sort(key=lambda x: x.created_at) + # Sort keys - timestamp prefixed keys will sort chronologically + message_keys.sort() - # Apply pagination + # Apply pagination to keys before loading content if limit is not None: - messages = messages[offset : offset + limit] + message_keys = message_keys[offset : offset + limit] else: - messages = messages[offset:] + message_keys = message_keys[offset:] + + # Load only the required message objects + messages: List[SessionMessage] = [] + for key in message_keys: + message_data = self._read_s3_object(key) + if message_data: + messages.append(SessionMessage.from_dict(message_data)) return messages diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index 984ae6f81..6f071f929 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -14,13 +14,14 @@ class SessionManager(HookProvider, ABC): """Abstract interface for managing sessions. - A session represents a complete interaction context including conversation - history, user information, agent state, and metadata. This interface provides - methods to manage sessions and their associated data. + A session manager is in charge of persisting the conversation and state of an agent across its interaction. + Changes made to the agents conversation, state, or other attributes should be persisted immediately after + they are changed. The different methods introduced in this class are called at important lifecycle events + for an agent, and should be persisted in the session. """ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: - """Register initialize and append_message as hooks for the Agent.""" + """Register hooks for persisting the agent to the session.""" registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent)) registry.add_callback(MessageAddedEvent, lambda event: self.append_message(event.message, event.agent)) registry.add_callback(MessageAddedEvent, lambda event: self.sync_agent(event.agent)) @@ -36,10 +37,10 @@ def append_message(self, message: Message, agent: "Agent") -> None: @abstractmethod def sync_agent(self, agent: "Agent") -> None: - """Sync the agent to the session. + """Serialize and sync the agent with the session storage. Args: - agent: Agent to sync to the session + agent: Agent who should be synchronized with the session storage """ @abstractmethod diff --git a/src/strands/session/session_repository.py b/src/strands/session/session_repository.py index 9b6465f28..b9735e05f 100644 --- a/src/strands/session/session_repository.py +++ b/src/strands/session/session_repository.py @@ -39,7 +39,10 @@ def read_message(self, session_id: str, agent_id: str, message_id: str) -> Optio @abstractmethod def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None: - """Update a Message.""" + """Update a Message. + + A message is usually only updated when some content is redacted due to a guardrail. + """ @abstractmethod def list_messages( diff --git a/src/strands/types/session.py b/src/strands/types/session.py index 2779b2866..50d82b368 100644 --- a/src/strands/types/session.py +++ b/src/strands/types/session.py @@ -13,7 +13,11 @@ class SessionType(str, Enum): - """Enumeration of session types.""" + """Enumeration of session types. + + As sessions are expanded to support new usecases like multi-agent patterns, + new types will be added here. + """ AGENT = "AGENT" @@ -74,13 +78,13 @@ def to_message(self) -> Message: @classmethod def from_dict(cls, env: dict[str, Any]) -> "SessionMessage": - """Initialize a SessionMessage from a dictionary, ignoring keys that are not calss parameters.""" + """Initialize a SessionMessage from a dictionary, ignoring keys that are not class parameters.""" return cls(**{k: v for k, v in env.items() if k in inspect.signature(cls).parameters}) @dataclass class SessionAgent: - """Agent within a Session.""" + """Agent that belongs to a Session.""" agent_id: str state: Dict[str, Any] diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index b9aa15c91..c5453c5fe 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -16,7 +16,7 @@ from strands.agent.state import AgentState from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel -from strands.session.agent_session_manager import DEFAULT_SESSION_AGENT_ID, AgentSessionManager +from strands.session.repository_session_manager import RepositorySessionManager from strands.types.content import Messages from strands.types.exceptions import ContextWindowOverflowException, EventLoopException from strands.types.session import Session, SessionAgent, SessionType @@ -1404,7 +1404,7 @@ def test_agent_state_get_breaks_deep_dict_reference(): def test_agent_session_management(): mock_session_repository = MockedSessionRepository() - session_manager = AgentSessionManager(session_id="123", session_repository=mock_session_repository) + session_manager = RepositorySessionManager(session_id="123", session_repository=mock_session_repository) model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello!"}]}]) agent = Agent(session_manager=session_manager, model=model) agent("Hello!") @@ -1416,11 +1416,11 @@ def test_agent_restored_from_session_management(): mock_session_repository.create_agent( "123", SessionAgent( - agent_id=DEFAULT_SESSION_AGENT_ID, + agent_id="default", state={"foo": "bar"}, ), ) - session_manager = AgentSessionManager(session_id="123", session_repository=mock_session_repository) + session_manager = RepositorySessionManager(session_id="123", session_repository=mock_session_repository) agent = Agent(session_manager=session_manager) diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py index f6acccc7b..3153c611d 100644 --- a/tests/strands/session/test_file_session_manager.py +++ b/tests/strands/session/test_file_session_manager.py @@ -165,6 +165,15 @@ def test_update_agent(self, file_manager, sample_session, sample_agent): result = file_manager.read_agent(sample_session.session_id, sample_agent.agent_id) assert result.state == {"updated": "value"} + def test_update_nonexistent_agent(self, file_manager, sample_session, sample_agent): + """Test updating an agent.""" + # Create session and agent + file_manager.create_session(sample_session) + + # Update agent + with pytest.raises(SessionException): + file_manager.update_agent(sample_session.session_id, sample_agent) + class TestFileSessionManagerMessageOperations: """Tests for message operations.""" @@ -180,7 +189,7 @@ def test_create_message(self, file_manager, sample_session, sample_agent, sample # Verify message file message_path = file_manager._get_message_path( - sample_session.session_id, sample_agent.agent_id, sample_message.message_id + sample_session.session_id, sample_agent.agent_id, sample_message.message_id, sample_message.created_at ) assert os.path.exists(message_path) @@ -196,6 +205,10 @@ def test_read_message(self, file_manager, sample_session, sample_agent, sample_m file_manager.create_agent(sample_session.session_id, sample_agent) file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) + # Create multiple messages when reading + sample_message.message_id = sample_message.message_id + "_2" + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) + # Read message result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id) @@ -203,6 +216,16 @@ def test_read_message(self, file_manager, sample_session, sample_agent, sample_m assert result.message["role"] == sample_message.message["role"] assert result.message["content"] == sample_message.message["content"] + def test_read_messages_with_new_agent(self, file_manager, sample_session, sample_agent): + """Test reading a message with with a new agent.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message") + + assert result is None + def test_read_nonexistent_message(self, file_manager, sample_session, sample_agent): """Test reading a message that doesnt exist.""" result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message") @@ -273,6 +296,16 @@ def test_list_messages_with_offset(self, file_manager, sample_session, sample_ag assert len(result) == 5 + def test_list_messages_with_new_agent(self, file_manager, sample_session, sample_agent): + """Test listing messages with new agent.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id) + + assert len(result) == 0 + def test_update_message(self, file_manager, sample_session, sample_agent, sample_message): """Test updating a message.""" # Create session, agent, and message @@ -288,7 +321,15 @@ def test_update_message(self, file_manager, sample_session, sample_agent, sample result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id) assert result.message["content"][0]["text"] == "Updated content" - # Note: delete_message is not implemented in FileSessionManager + def test_update_nonexistent_message(self, file_manager, sample_session, sample_agent, sample_message): + """Test updating a message.""" + # Create session, agent, and message + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Update nonexistent message + with pytest.raises(SessionException): + file_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) class TestFileSessionManagerErrorHandling: diff --git a/tests/strands/session/test_agent_session_manager.py b/tests/strands/session/test_repository_session_manager.py similarity index 80% rename from tests/strands/session/test_agent_session_manager.py rename to tests/strands/session/test_repository_session_manager.py index 11de4fb2d..10901d30b 100644 --- a/tests/strands/session/test_agent_session_manager.py +++ b/tests/strands/session/test_repository_session_manager.py @@ -3,7 +3,7 @@ import pytest from strands.agent.agent import Agent -from strands.session.agent_session_manager import AgentSessionManager +from strands.session.repository_session_manager import RepositorySessionManager from strands.types.content import ContentBlock from strands.types.exceptions import SessionException from strands.types.session import Session, SessionAgent, SessionMessage, SessionType @@ -19,7 +19,7 @@ def mock_repository(): @pytest.fixture def session_manager(mock_repository): """Create a session manager with mock repository.""" - return AgentSessionManager(session_id="test-session", session_repository=mock_repository) + return RepositorySessionManager(session_id="test-session", session_repository=mock_repository) @pytest.fixture @@ -34,7 +34,7 @@ def test_init_creates_session_if_not_exists(mock_repository): assert mock_repository.read_session("test-session") is None # Creating manager should create session - AgentSessionManager(session_id="test-session", session_repository=mock_repository) + RepositorySessionManager(session_id="test-session", session_repository=mock_repository) # Verify session created session = mock_repository.read_session("test-session") @@ -50,7 +50,7 @@ def test_init_uses_existing_session(mock_repository): mock_repository.create_session(session) # Creating manager should use existing session - manager = AgentSessionManager(session_id="test-session", session_repository=mock_repository) + manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository) # Verify session used assert manager.session == session @@ -71,14 +71,15 @@ def test_initialize_with_existing_agent_id(session_manager, agent): def test_initialize_multiple_agents_without_id(session_manager, agent): - """Test initializing multiple agents without IDs.""" + """Test initializing multiple agents with same ID.""" # First agent initialization works + agent.agent_id = "custom-agent" session_manager.initialize(agent) # Second agent with no set agent_id should fail - agent2 = Agent() + agent2 = Agent(agent_id="custom-agent") - with pytest.raises(SessionException, match="Set `agent_id` to support more than one agent in a session."): + with pytest.raises(SessionException, match="The `agent_id` of an agent must be unique in a session."): session_manager.initialize(agent2) @@ -133,16 +134,3 @@ def test_append_message(session_manager, agent): assert len(messages) == 1 assert messages[0].message["role"] == "user" assert messages[0].message["content"][0]["text"] == "Hello" - - -def test_append_message_without_agent_id(session_manager, agent): - """Test appending a message to an agent without ID.""" - # Agent has no ID - agent.agent_id = None - - # Create message - message = {"role": "user", "content": [{"type": "text", "text": "Hello"}]} - - # Append message should fail - with pytest.raises(ValueError, match="`agent.agent_id` must be set"): - session_manager.append_message(message, agent) diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py index 59a92a6ec..ffc05e53e 100644 --- a/tests/strands/session/test_s3_session_manager.py +++ b/tests/strands/session/test_s3_session_manager.py @@ -219,7 +219,9 @@ def test_create_message(s3_manager, sample_session, sample_agent, sample_message s3_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) # Verify S3 object created - key = s3_manager._get_message_path(sample_session.session_id, sample_agent.agent_id, sample_message.message_id) + key = s3_manager._get_message_path( + sample_session.session_id, sample_agent.agent_id, sample_message.message_id, sample_message.created_at + ) response = s3_manager.client.get_object(Bucket=s3_manager.bucket, Key=key) data = json.loads(response["Body"].read().decode("utf-8"))