Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 21 additions & 33 deletions src/strands/session/file_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import os
import shutil
import tempfile
from dataclasses import asdict
from typing import Any, Optional, cast

from ..types.exceptions import SessionException
Expand Down Expand Up @@ -57,22 +56,18 @@ def _get_agent_path(self, session_id: str, agent_id: str) -> str:
session_path = self._get_session_path(session_id)
return os.path.join(session_path, "agents", f"{AGENT_PREFIX}{agent_id}")

def _get_message_path(self, session_id: str, agent_id: str, message_id: str, timestamp: str) -> str:
def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> str:
"""Get message file path.

Args:
session_id: ID of the session
agent_id: ID of the agent
message_id: ID of the message
timestamp: ISO format timestamp to include in filename for sorting
message_id: Index of the message
Returns:
The filename for the message
"""
agent_path = self._get_agent_path(session_id, agent_id)
# Use timestamp for sortable filenames
# Replace colons and periods in ISO format with underscores for filesystem compatibility
filename_timestamp = timestamp.replace(":", "_").replace(".", "_")
return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{filename_timestamp}_{message_id}.json")
return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{message_id}.json")

def _read_file(self, path: str) -> dict[str, Any]:
"""Read JSON file."""
Expand Down Expand Up @@ -100,7 +95,7 @@ def create_session(self, session: Session) -> Session:

# Write session file
session_file = os.path.join(session_dir, "session.json")
session_dict = asdict(session)
session_dict = session.to_dict()
self._write_file(session_file, session_dict)

return session
Expand All @@ -123,7 +118,7 @@ def create_agent(self, session_id: str, session_agent: SessionAgent) -> None:
os.makedirs(os.path.join(agent_dir, "messages"), exist_ok=True)

agent_file = os.path.join(agent_dir, "agent.json")
session_data = asdict(session_agent)
session_data = session_agent.to_dict()
self._write_file(agent_file, session_data)

def delete_session(self, session_id: str) -> None:
Expand Down Expand Up @@ -152,34 +147,25 @@ def update_agent(self, session_id: str, session_agent: SessionAgent) -> None:

session_agent.created_at = previous_agent.created_at
agent_file = os.path.join(self._get_agent_path(session_id, agent_id), "agent.json")
self._write_file(agent_file, asdict(session_agent))
self._write_file(agent_file, session_agent.to_dict())

def create_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None:
"""Create a new message for the agent."""
message_file = self._get_message_path(
session_id,
agent_id,
session_message.message_id,
session_message.created_at,
)
session_dict = asdict(session_message)
session_dict = session_message.to_dict()
self._write_file(message_file, session_dict)

def read_message(self, session_id: str, agent_id: str, message_id: str) -> Optional[SessionMessage]:
def read_message(self, session_id: str, agent_id: str, message_id: int) -> Optional[SessionMessage]:
"""Read message data."""
# Get the messages directory
messages_dir = os.path.join(self._get_agent_path(session_id, agent_id), "messages")
if not os.path.exists(messages_dir):
message_path = self._get_message_path(session_id, agent_id, message_id)
if not os.path.exists(message_path):
return None

# List files in messages directory, and check if the filename ends with the message id
for filename in os.listdir(messages_dir):
if filename.endswith(f"{message_id}.json"):
file_path = os.path.join(messages_dir, filename)
message_data = self._read_file(file_path)
return SessionMessage.from_dict(message_data)

return None
message_data = self._read_file(message_path)
return SessionMessage.from_dict(message_data)

def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None:
"""Update message data."""
Expand All @@ -190,8 +176,8 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio

# Preserve the original created_at timestamp
session_message.created_at = previous_message.created_at
message_file = self._get_message_path(session_id, agent_id, message_id, session_message.created_at)
self._write_file(message_file, asdict(session_message))
message_file = self._get_message_path(session_id, agent_id, message_id)
self._write_file(message_file, session_message.to_dict())

def list_messages(
self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0
Expand All @@ -201,14 +187,16 @@ def list_messages(
if not os.path.exists(messages_dir):
raise SessionException(f"Messages directory missing from agent: {agent_id} in session {session_id}")

# Read all message files
message_files: list[str] = []
# Read all message files, and record the index
message_index_files: list[tuple[int, str]] = []
for filename in os.listdir(messages_dir):
if filename.startswith(MESSAGE_PREFIX) and filename.endswith(".json"):
message_files.append(filename)
# Extract index from message_<index>.json format
index = int(filename[len(MESSAGE_PREFIX) : -5]) # Remove prefix and .json suffix
message_index_files.append((index, filename))

# Sort filenames - the timestamp in the file's name will sort chronologically
message_files.sort()
# Sort by index and extract just the filenames
message_files = [f for _, f in sorted(message_index_files)]

# Apply pagination to filenames
if limit is not None:
Expand Down
35 changes: 21 additions & 14 deletions src/strands/session/repository_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,8 @@ def __init__(

self.session = session

# Keep track of the initialized agent id's so that two agents in a session cannot share an id
self._initialized_agent_ids: set[str] = set()

# Keep track of the latest message stored in the session in case we need to redact its content.
self._latest_message: Optional[SessionMessage] = None
# Keep track of the latest message of each agent in case we need to redact it.
self._latest_agent_message: dict[str, Optional[SessionMessage]] = {}

def append_message(self, message: Message, agent: Agent) -> None:
"""Append a message to the agent's session.
Expand All @@ -61,8 +58,16 @@ def append_message(self, message: Message, agent: Agent) -> None:
message: Message to add to the agent in the session
agent: Agent to append the message to
"""
self._latest_message = SessionMessage.from_message(message)
self.session_repository.create_message(self.session_id, agent.agent_id, self._latest_message)
# Calculate the next index (0 if this is the first message, otherwise increment the previous index)
latest_agent_message = self._latest_agent_message[agent.agent_id]
if latest_agent_message:
next_index = latest_agent_message.message_id + 1
else:
next_index = 0

session_message = SessionMessage.from_message(message, next_index)
self._latest_agent_message[agent.agent_id] = session_message
self.session_repository.create_message(self.session_id, agent.agent_id, session_message)

def redact_latest_message(self, redact_message: Message, agent: Agent) -> None:
"""Redact the latest message appended to the session.
Expand All @@ -71,10 +76,11 @@ def redact_latest_message(self, redact_message: Message, agent: Agent) -> None:
redact_message: New message to use that contains the redact content
agent: Agent to apply the message redaction to
"""
if self._latest_message is None:
latest_agent_message = self._latest_agent_message[agent.agent_id]
if latest_agent_message is None:
raise SessionException("No message to redact.")
self._latest_message.redact_message = redact_message
return self.session_repository.update_message(self.session_id, agent.agent_id, self._latest_message)
latest_agent_message.redact_message = redact_message
return self.session_repository.update_message(self.session_id, agent.agent_id, latest_agent_message)

def sync_agent(self, agent: Agent) -> None:
"""Serialize and update the agent into the session repository.
Expand All @@ -93,9 +99,9 @@ def initialize(self, agent: Agent) -> None:
Args:
agent: Agent to initialize from the session
"""
if agent.agent_id in self._initialized_agent_ids:
if agent.agent_id in self._latest_agent_message:
raise SessionException("The `agent_id` of an agent must be unique in a session.")
self._initialized_agent_ids.add(agent.agent_id)
self._latest_agent_message[agent.agent_id] = None

session_agent = self.session_repository.read_agent(self.session_id, agent.agent_id)

Expand All @@ -108,8 +114,9 @@ def initialize(self, agent: Agent) -> None:

session_agent = SessionAgent.from_agent(agent)
self.session_repository.create_agent(self.session_id, session_agent)
for message in agent.messages:
session_message = SessionMessage.from_message(message)
# Initialize messages with sequential indices
for i, message in enumerate(agent.messages):
session_message = SessionMessage.from_message(message, i)
self.session_repository.create_message(self.session_id, agent.agent_id, session_message)
else:
logger.debug(
Expand Down
67 changes: 27 additions & 40 deletions src/strands/session/s3_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import json
import logging
from dataclasses import asdict
from typing import Any, Dict, List, Optional, cast

import boto3
Expand Down Expand Up @@ -85,22 +84,18 @@ def _get_agent_path(self, session_id: str, agent_id: str) -> str:
session_path = self._get_session_path(session_id)
return f"{session_path}agents/{AGENT_PREFIX}{agent_id}/"

def _get_message_path(self, session_id: str, agent_id: str, message_id: str, timestamp: str) -> str:
def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> str:
"""Get message S3 key.

Args:
session_id: ID of the session
agent_id: ID of the agent
message_id: ID of the message
timestamp: ISO format timestamp to include in key for sorting
message_id: Index of the message
Returns:
The key for the message
"""
agent_path = self._get_agent_path(session_id, agent_id)
# Use timestamp for sortable keys
# Replace colons and periods in ISO format with underscores for filesystem compatibility
filename_timestamp = timestamp.replace(":", "_").replace(".", "_")
return f"{agent_path}messages/{MESSAGE_PREFIX}{filename_timestamp}_{message_id}.json"
return f"{agent_path}messages/{MESSAGE_PREFIX}{message_id}.json"

def _read_s3_object(self, key: str) -> Optional[Dict[str, Any]]:
"""Read JSON object from S3."""
Expand Down Expand Up @@ -139,7 +134,7 @@ def create_session(self, session: Session) -> Session:
raise SessionException(f"S3 error checking session existence: {e}") from e

# Write session object
session_dict = asdict(session)
session_dict = session.to_dict()
self._write_s3_object(session_key, session_dict)
return session

Expand Down Expand Up @@ -177,7 +172,7 @@ def delete_session(self, session_id: str) -> None:
def create_agent(self, session_id: str, session_agent: SessionAgent) -> None:
"""Create a new agent in S3."""
agent_id = session_agent.agent_id
agent_dict = asdict(session_agent)
agent_dict = session_agent.to_dict()
agent_key = f"{self._get_agent_path(session_id, agent_id)}agent.json"
self._write_s3_object(agent_key, agent_dict)

Expand All @@ -199,35 +194,22 @@ def update_agent(self, session_id: str, session_agent: SessionAgent) -> None:
# Preserve creation timestamp
session_agent.created_at = previous_agent.created_at
agent_key = f"{self._get_agent_path(session_id, agent_id)}agent.json"
self._write_s3_object(agent_key, asdict(session_agent))
self._write_s3_object(agent_key, session_agent.to_dict())

def create_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None:
"""Create a new message in S3."""
message_id = session_message.message_id
message_dict = asdict(session_message)
message_key = self._get_message_path(session_id, agent_id, message_id, session_message.created_at)
message_dict = session_message.to_dict()
message_key = self._get_message_path(session_id, agent_id, message_id)
self._write_s3_object(message_key, message_dict)

def read_message(self, session_id: str, agent_id: str, message_id: str) -> Optional[SessionMessage]:
def read_message(self, session_id: str, agent_id: str, message_id: int) -> Optional[SessionMessage]:
"""Read message data from S3."""
# Get the messages prefix
messages_prefix = f"{self._get_agent_path(session_id, agent_id)}messages/"
try:
paginator = self.client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=self.bucket, Prefix=messages_prefix)

for page in pages:
if "Contents" in page:
for obj in page["Contents"]:
if obj["Key"].endswith(f"{message_id}.json"):
message_data = self._read_s3_object(obj["Key"])
if message_data:
return SessionMessage.from_dict(message_data)

message_key = self._get_message_path(session_id, agent_id, message_id)
message_data = self._read_s3_object(message_key)
if message_data is None:
return None

except ClientError as e:
raise SessionException(f"S3 error reading message: {e}") from e
return SessionMessage.from_dict(message_data)

def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None:
"""Update message data in S3."""
Expand All @@ -238,8 +220,8 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio

# Preserve creation timestamp
session_message.created_at = previous_message.created_at
message_key = self._get_message_path(session_id, agent_id, message_id, session_message.created_at)
self._write_s3_object(message_key, asdict(session_message))
message_key = self._get_message_path(session_id, agent_id, message_id)
self._write_s3_object(message_key, session_message.to_dict())

def list_messages(
self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0
Expand All @@ -250,16 +232,21 @@ def list_messages(
paginator = self.client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=self.bucket, Prefix=messages_prefix)

# Collect all message keys first
message_keys = []
# Collect all message keys and extract their indices
message_index_keys: list[tuple[int, str]] = []
for page in pages:
if "Contents" in page:
for obj in page["Contents"]:
if obj["Key"].endswith(".json") and MESSAGE_PREFIX in obj["Key"]:
message_keys.append(obj["Key"])

# Sort keys - timestamp prefixed keys will sort chronologically
message_keys.sort()
key = obj["Key"]
if key.endswith(".json") and MESSAGE_PREFIX in key:
# Extract the filename part from the full S3 key
filename = key.split("/")[-1]
# Extract index from message_<index>.json format
index = int(filename[len(MESSAGE_PREFIX) : -5]) # Remove prefix and .json suffix
message_index_keys.append((index, key))

# Sort by index and extract just the keys
message_keys = [k for _, k in sorted(message_index_keys)]

# Apply pagination to keys before loading content
if limit is not None:
Expand Down
2 changes: 1 addition & 1 deletion src/strands/session/session_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def create_message(self, session_id: str, agent_id: str, session_message: Sessio
"""Create a new Message for the Agent."""

@abstractmethod
def read_message(self, session_id: str, agent_id: str, message_id: str) -> Optional[SessionMessage]:
def read_message(self, session_id: str, agent_id: str, message_id: int) -> Optional[SessionMessage]:
"""Read a Message."""

@abstractmethod
Expand Down
Loading
Loading