From fabf6792c5a34f945d31ae36d3446a5fabc6054a Mon Sep 17 00:00:00 2001 From: Jerome Van Der Linden Date: Wed, 15 Oct 2025 16:46:14 +0200 Subject: [PATCH] create Valkey Session Manager --- pyproject.toml | 3 +- src/strands/session/__init__.py | 2 + src/strands/session/valkey_session_manager.py | 251 ++++++++++++++ .../session/test_valkey_session_manager.py | 312 ++++++++++++++++++ 4 files changed, 567 insertions(+), 1 deletion(-) create mode 100644 src/strands/session/valkey_session_manager.py create mode 100644 tests/strands/session/test_valkey_session_manager.py diff --git a/pyproject.toml b/pyproject.toml index af8e45ffc..7005316a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ mistral = ["mistralai>=1.8.2"] ollama = ["ollama>=0.4.8,<1.0.0"] openai = ["openai>=1.68.0,<2.0.0"] writer = ["writer-sdk>=2.2.0,<3.0.0"] +valkey = ["valkey>=6.0.0,<7.0.0"] sagemaker = [ "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0", "openai>=1.68.0,<2.0.0", # SageMaker uses OpenAI-compatible interface @@ -68,7 +69,7 @@ a2a = [ "fastapi>=0.115.12,<1.0.0", "starlette>=0.46.2,<1.0.0", ] -all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] +all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,valkey,sagemaker,otel]"] dev = [ "commitizen>=4.4.0,<5.0.0", diff --git a/src/strands/session/__init__.py b/src/strands/session/__init__.py index 7b5310190..8791c61fc 100644 --- a/src/strands/session/__init__.py +++ b/src/strands/session/__init__.py @@ -8,6 +8,7 @@ from .s3_session_manager import S3SessionManager from .session_manager import SessionManager from .session_repository import SessionRepository +from .valkey_session_manager import ValkeySessionManager __all__ = [ "FileSessionManager", @@ -15,4 +16,5 @@ "S3SessionManager", "SessionManager", "SessionRepository", + "ValkeySessionManager", ] diff --git a/src/strands/session/valkey_session_manager.py b/src/strands/session/valkey_session_manager.py new file mode 100644 index 000000000..26daf2267 --- /dev/null +++ b/src/strands/session/valkey_session_manager.py @@ -0,0 +1,251 @@ +"""Valkey-based session manager for Redis-compatible storage.""" + +import json +import logging +from typing import Any, Dict, List, Optional, Union, cast + +import valkey + +from ..types.exceptions import SessionException +from ..types.session import Session, SessionAgent, SessionMessage +from .repository_session_manager import RepositorySessionManager +from .session_repository import SessionRepository + +logger = logging.getLogger(__name__) + +SESSION_PREFIX = "session" +AGENT_PREFIX = "agent" +MESSAGE_PREFIX = "message" + + +class ValkeySessionManager(RepositorySessionManager, SessionRepository): + """Valkey-based session manager for Redis-compatible storage. + + Creates the following key structure for the session storage: + ``` + session: # Session metadata (JSON) + session::agent: # Agent metadata (JSON) + session::agent::message: # Message data (JSON) + ``` + """ + + def __init__(self, session_id: str, client: Union[valkey.Valkey, valkey.ValkeyCluster], **kwargs: Any): + """Initialize ValkeySessionManager with Valkey storage. + + Args: + session_id: ID for the session + client: Pre-configured Valkey client (Valkey or ValkeyCluster) + **kwargs: Additional keyword arguments for future extensibility. + """ + self.client = client + super().__init__(session_id=session_id, session_repository=self) + + def _get_session_key(self, session_id: str) -> str: + """Get session key. + + Args: + session_id: ID for the session. + + Raises: + ValueError: If session_id contains colon characters. + """ + if ":" in session_id: + raise ValueError(f"session_id cannot contain ':' characters: {session_id}") + return f"{SESSION_PREFIX}:{session_id}" + + def _get_agent_key(self, session_id: str, agent_id: str) -> str: + """Get agent key. + + Args: + session_id: ID for the session. + agent_id: ID for the agent. + + Raises: + ValueError: If agent_id contains colon characters. + """ + if ":" in agent_id: + raise ValueError(f"agent_id cannot contain ':' characters: {agent_id}") + session_key = self._get_session_key(session_id) + return f"{session_key}:{AGENT_PREFIX}:{agent_id}" + + def _get_message_key(self, session_id: str, agent_id: str, message_id: int) -> str: + """Get message key. + + Args: + session_id: ID of the session + agent_id: ID of the agent + message_id: Index of the message + + Returns: + The key for the message + + Raises: + ValueError: If message_id is not an integer. + """ + if not isinstance(message_id, int): + raise ValueError(f"message_id=<{message_id}> | message id must be an integer") + + agent_key = self._get_agent_key(session_id, agent_id) + return f"{agent_key}:{MESSAGE_PREFIX}:{message_id}" + + def _read_json_object(self, key: str) -> Optional[Dict[str, Any]]: + """Read JSON object from Valkey.""" + try: + data = self.client.execute_command("JSON.GET", key) + if data is None: + return None + return cast(dict[str, Any], json.loads(data)) + except Exception as e: + raise SessionException(f"Valkey error reading {key}: {e}") from e + + def _write_json_object(self, key: str, data: Dict[str, Any]) -> None: + """Write JSON object to Valkey.""" + try: + json_data = json.dumps(data, ensure_ascii=False) + self.client.execute_command("JSON.SET", key, ".", json_data) + except Exception as e: + raise SessionException(f"Failed to write Valkey object {key}: {e}") from e + + def create_session(self, session: Session, **kwargs: Any) -> Session: + """Create a new session in Valkey.""" + session_key = self._get_session_key(session.session_id) + + # Check if session already exists + if self.client.exists(session_key): + raise SessionException(f"Session {session.session_id} already exists") + + # Write session object + session_dict = session.to_dict() + self._write_json_object(session_key, session_dict) + return session + + def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]: + """Read session data from Valkey.""" + session_key = self._get_session_key(session_id) + session_data = self._read_json_object(session_key) + if session_data is None: + return None + return Session.from_dict(session_data) + + def delete_session(self, session_id: str, **kwargs: Any) -> None: + """Delete session and all associated data from Valkey.""" + session_key = self._get_session_key(session_id) + + # Find all keys related to this session using SCAN + pattern = f"{session_key}*" + keys = [] + cursor = 0 + while True: + cursor, batch = self.client.scan(cursor=cursor, match=pattern, count=100) # type: ignore[misc] + keys.extend(batch) + if cursor == 0: + break + + if not keys: + raise SessionException(f"Session {session_id} does not exist") + + # Delete keys individually to avoid CROSSSLOT errors in clustered mode + for key in keys: + key_str = key.decode() if isinstance(key, bytes) else key + self.client.delete(key_str) + + def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: Any) -> None: + """Create a new agent in Valkey.""" + agent_id = session_agent.agent_id + agent_dict = session_agent.to_dict() + agent_key = self._get_agent_key(session_id, agent_id) + self._write_json_object(agent_key, agent_dict) + + def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]: + """Read agent data from Valkey.""" + agent_key = self._get_agent_key(session_id, agent_id) + agent_data = self._read_json_object(agent_key) + if agent_data is None: + return None + return SessionAgent.from_dict(agent_data) + + def update_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: Any) -> None: + """Update agent data in Valkey.""" + agent_id = session_agent.agent_id + previous_agent = self.read_agent(session_id=session_id, agent_id=agent_id) + if previous_agent is None: + raise SessionException(f"Agent {agent_id} in session {session_id} does not exist") + + # Preserve creation timestamp + session_agent.created_at = previous_agent.created_at + agent_key = self._get_agent_key(session_id, agent_id) + self._write_json_object(agent_key, session_agent.to_dict()) + + def create_message(self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any) -> None: + """Create a new message in Valkey.""" + message_id = session_message.message_id + message_dict = session_message.to_dict() + message_key = self._get_message_key(session_id, agent_id, message_id) + self._write_json_object(message_key, message_dict) + + def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> Optional[SessionMessage]: + """Read message data from Valkey.""" + message_key = self._get_message_key(session_id, agent_id, message_id) + message_data = self._read_json_object(message_key) + if message_data is None: + return None + return SessionMessage.from_dict(message_data) + + def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any) -> None: + """Update message data in Valkey.""" + message_id = session_message.message_id + previous_message = self.read_message(session_id=session_id, agent_id=agent_id, message_id=message_id) + if previous_message is None: + raise SessionException(f"Message {message_id} does not exist") + + # Preserve creation timestamp + session_message.created_at = previous_message.created_at + message_key = self._get_message_key(session_id, agent_id, message_id) + self._write_json_object(message_key, session_message.to_dict()) + + def list_messages( + self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any + ) -> List[SessionMessage]: + """List messages for an agent with pagination from Valkey.""" + agent_key = self._get_agent_key(session_id, agent_id) + messages_pattern = f"{agent_key}:{MESSAGE_PREFIX}:*" + + try: + # Use SCAN instead of KEYS (KEYS not supported in ElastiCache Serverless) + message_keys = [] + cursor = 0 + while True: + cursor, keys = self.client.scan(cursor=cursor, match=messages_pattern, count=100) # type: ignore[misc] + message_keys.extend(keys) + if cursor == 0: + break + + # Extract message indices and sort + message_index_keys: list[tuple[int, str]] = [] + for key in message_keys: + # Decode bytes to string if needed + key_str = key.decode() if isinstance(key, bytes) else key + # Extract index from key format: session:id:agent:id:message:index + index = int(key_str.split(":")[-1]) + message_index_keys.append((index, key_str)) + + # Sort by index and extract just the keys + sorted_keys = [k for _, k in sorted(message_index_keys)] + + # Apply pagination to keys before loading content + if limit is not None: + sorted_keys = sorted_keys[offset : offset + limit] + else: + sorted_keys = sorted_keys[offset:] + + # Load only the required message objects + messages: List[SessionMessage] = [] + for key in sorted_keys: + message_data = self._read_json_object(key) + if message_data: + messages.append(SessionMessage.from_dict(message_data)) + + return messages + + except Exception as e: + raise SessionException(f"Valkey error reading messages: {e}") from e diff --git a/tests/strands/session/test_valkey_session_manager.py b/tests/strands/session/test_valkey_session_manager.py new file mode 100644 index 000000000..bfd51b464 --- /dev/null +++ b/tests/strands/session/test_valkey_session_manager.py @@ -0,0 +1,312 @@ +"""Tests for ValkeySessionManager.""" + +import json +from unittest.mock import MagicMock, patch + +import pytest +import valkey + +from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager +from strands.session.valkey_session_manager import ValkeySessionManager +from strands.types.content import ContentBlock +from strands.types.exceptions import SessionException +from strands.types.session import Session, SessionAgent, SessionMessage, SessionType + + +@pytest.fixture +def mock_valkey_client(): + """Mock Valkey client for testing.""" + client = MagicMock(spec=valkey.Valkey) + # Default behavior: return None for JSON.GET (non-existent keys) + client.execute_command.return_value = None + client.exists.return_value = False + client.scan.return_value = (0, []) + return client + + +@pytest.fixture +def valkey_manager(mock_valkey_client): + """Create ValkeySessionManager with mocked client.""" + # Mock the session repository methods during initialization + with ( + patch.object(ValkeySessionManager, "read_session", return_value=None), + patch.object(ValkeySessionManager, "create_session"), + ): + manager = ValkeySessionManager(session_id="test", client=mock_valkey_client) + return manager + + +@pytest.fixture +def sample_session(): + """Create sample session for testing.""" + return Session( + session_id="test-session-123", + session_type=SessionType.AGENT, + ) + + +@pytest.fixture +def sample_agent(): + """Create sample agent for testing.""" + return SessionAgent( + agent_id="test-agent-456", + state={"key": "value"}, + conversation_manager_state=NullConversationManager().get_state(), + ) + + +@pytest.fixture +def sample_message(): + """Create sample message for testing.""" + return SessionMessage.from_message( + message={ + "role": "user", + "content": [ContentBlock(text="test_message")], + }, + index=0, + ) + + +def test_create_session(valkey_manager, sample_session, mock_valkey_client): + """Test creating a session in Valkey.""" + mock_valkey_client.exists.return_value = False + + result = valkey_manager.create_session(sample_session) + + assert result == sample_session + mock_valkey_client.execute_command.assert_called_once() + args = mock_valkey_client.execute_command.call_args[0] + assert args[0] == "JSON.SET" + assert args[1] == "session:test-session-123" + assert args[2] == "." + + +def test_create_session_already_exists(valkey_manager, sample_session, mock_valkey_client): + """Test creating a session that already exists.""" + mock_valkey_client.exists.return_value = True + + with pytest.raises(SessionException, match="already exists"): + valkey_manager.create_session(sample_session) + + +def test_read_session(valkey_manager, sample_session, mock_valkey_client): + """Test reading a session from Valkey.""" + session_data = json.dumps(sample_session.to_dict()) + mock_valkey_client.execute_command.return_value = session_data + + result = valkey_manager.read_session(sample_session.session_id) + + assert result.session_id == sample_session.session_id + assert result.session_type == sample_session.session_type + + +def test_read_nonexistent_session(valkey_manager, mock_valkey_client): + """Test reading a session that doesn't exist.""" + mock_valkey_client.execute_command.return_value = None + + result = valkey_manager.read_session("nonexistent-session") + + assert result is None + + +def test_delete_session(valkey_manager, sample_session, mock_valkey_client): + """Test deleting a session from Valkey.""" + mock_valkey_client.scan.return_value = (0, [b"session:test-session-123", b"session:test-session-123:agent:test"]) + + valkey_manager.delete_session(sample_session.session_id) + + assert mock_valkey_client.delete.call_count == 2 + + +def test_delete_nonexistent_session(valkey_manager, mock_valkey_client): + """Test deleting a session that doesn't exist.""" + mock_valkey_client.scan.return_value = (0, []) + + with pytest.raises(SessionException, match="does not exist"): + valkey_manager.delete_session("nonexistent") + + +def test_create_agent(valkey_manager, sample_session, sample_agent, mock_valkey_client): + """Test creating an agent in Valkey.""" + valkey_manager.create_agent(sample_session.session_id, sample_agent) + + mock_valkey_client.execute_command.assert_called_once() + args = mock_valkey_client.execute_command.call_args[0] + assert args[0] == "JSON.SET" + assert args[1] == "session:test-session-123:agent:test-agent-456" + + +def test_read_agent(valkey_manager, sample_session, sample_agent, mock_valkey_client): + """Test reading an agent from Valkey.""" + agent_data = json.dumps(sample_agent.to_dict()) + mock_valkey_client.execute_command.return_value = agent_data + + result = valkey_manager.read_agent(sample_session.session_id, sample_agent.agent_id) + + assert result.agent_id == sample_agent.agent_id + assert result.state == sample_agent.state + + +def test_read_nonexistent_agent(valkey_manager, sample_session, mock_valkey_client): + """Test reading an agent that doesn't exist.""" + mock_valkey_client.execute_command.return_value = None + + result = valkey_manager.read_agent(sample_session.session_id, "nonexistent_agent") + + assert result is None + + +def test_update_agent(valkey_manager, sample_session, sample_agent, mock_valkey_client): + """Test updating an agent in Valkey.""" + # Mock reading existing agent + agent_data = json.dumps(sample_agent.to_dict()) + mock_valkey_client.execute_command.return_value = agent_data + + sample_agent.state = {"updated": "value"} + valkey_manager.update_agent(sample_session.session_id, sample_agent) + + # Should call JSON.GET then JSON.SET + assert mock_valkey_client.execute_command.call_count == 2 + + +def test_update_nonexistent_agent(valkey_manager, sample_session, sample_agent, mock_valkey_client): + """Test updating an agent that doesn't exist.""" + mock_valkey_client.execute_command.return_value = None + + with pytest.raises(SessionException, match="does not exist"): + valkey_manager.update_agent(sample_session.session_id, sample_agent) + + +def test_create_message(valkey_manager, sample_session, sample_agent, sample_message, mock_valkey_client): + """Test creating a message in Valkey.""" + valkey_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + mock_valkey_client.execute_command.assert_called_once() + args = mock_valkey_client.execute_command.call_args[0] + assert args[0] == "JSON.SET" + assert args[1] == "session:test-session-123:agent:test-agent-456:message:0" + + +def test_read_message(valkey_manager, sample_session, sample_agent, sample_message, mock_valkey_client): + """Test reading a message from Valkey.""" + message_data = json.dumps(sample_message.to_dict()) + mock_valkey_client.execute_command.return_value = message_data + + result = valkey_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id) + + assert result.message_id == sample_message.message_id + assert result.message["role"] == sample_message.message["role"] + + +def test_read_nonexistent_message(valkey_manager, sample_session, sample_agent, mock_valkey_client): + """Test reading a message that doesn't exist.""" + mock_valkey_client.execute_command.return_value = None + + result = valkey_manager.read_message(sample_session.session_id, sample_agent.agent_id, 999) + + assert result is None + + +def test_update_message(valkey_manager, sample_session, sample_agent, sample_message, mock_valkey_client): + """Test updating a message in Valkey.""" + # Mock reading existing message + message_data = json.dumps(sample_message.to_dict()) + mock_valkey_client.execute_command.return_value = message_data + + sample_message.message["content"] = [ContentBlock(text="Updated content")] + valkey_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + # Should call JSON.GET then JSON.SET + assert mock_valkey_client.execute_command.call_count == 2 + + +def test_update_nonexistent_message(valkey_manager, sample_session, sample_agent, sample_message, mock_valkey_client): + """Test updating a message that doesn't exist.""" + mock_valkey_client.execute_command.return_value = None + + with pytest.raises(SessionException, match="does not exist"): + valkey_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + +def test_list_messages_all(valkey_manager, sample_session, sample_agent, mock_valkey_client): + """Test listing all messages from Valkey.""" + # Mock scan response + mock_valkey_client.scan.return_value = ( + 0, + [ + b"session:test-session-123:agent:test-agent-456:message:0", + b"session:test-session-123:agent:test-agent-456:message:1", + b"session:test-session-123:agent:test-agent-456:message:2", + ], + ) + + # Mock JSON.GET responses + def mock_execute_command(cmd, key, *args): + if cmd == "JSON.GET": + message_id = int(key.split(":")[-1]) + return json.dumps( + { + "message_id": message_id, + "message": {"role": "user", "content": [{"text": f"Message {message_id}"}]}, + "created_at": "2023-01-01T00:00:00Z", + "updated_at": "2023-01-01T00:00:00Z", + } + ) + + mock_valkey_client.execute_command.side_effect = mock_execute_command + + result = valkey_manager.list_messages(sample_session.session_id, sample_agent.agent_id) + + assert len(result) == 3 + + +def test_list_messages_with_pagination(valkey_manager, sample_session, sample_agent, mock_valkey_client): + """Test listing messages with pagination.""" + # Mock 10 message keys + mock_valkey_client.scan.return_value = ( + 0, + [f"session:test-session-123:agent:test-agent-456:message:{i}".encode() for i in range(10)], + ) + + def mock_execute_command(cmd, key, *args): + if cmd == "JSON.GET": + message_id = int(key.split(":")[-1]) + return json.dumps( + { + "message_id": message_id, + "message": {"role": "user", "content": [{"text": f"Message {message_id}"}]}, + "created_at": "2023-01-01T00:00:00Z", + "updated_at": "2023-01-01T00:00:00Z", + } + ) + + mock_valkey_client.execute_command.side_effect = mock_execute_command + + # Test with limit + result = valkey_manager.list_messages(sample_session.session_id, sample_agent.agent_id, limit=3) + assert len(result) == 3 + + # Test with offset + result = valkey_manager.list_messages(sample_session.session_id, sample_agent.agent_id, offset=5) + assert len(result) == 5 + + +@pytest.mark.parametrize("session_id", ["session:with:colons", "another:colon"]) +def test_get_session_key_invalid_session_id(session_id, valkey_manager): + """Test that session_id with colons raises ValueError.""" + with pytest.raises(ValueError, match="session_id cannot contain ':' characters"): + valkey_manager._get_session_key(session_id) + + +@pytest.mark.parametrize("agent_id", ["agent:with:colons", "another:colon"]) +def test_get_agent_key_invalid_agent_id(agent_id, valkey_manager): + """Test that agent_id with colons raises ValueError.""" + with pytest.raises(ValueError, match="agent_id cannot contain ':' characters"): + valkey_manager._get_agent_key("session1", agent_id) + + +@pytest.mark.parametrize("message_id", ["not_an_int", None, [], 1.5]) +def test_get_message_key_invalid_message_id(message_id, valkey_manager): + """Test that message_id that is not an integer raises ValueError.""" + with pytest.raises(ValueError, match=r"message_id=<.*> \| message id must be an integer"): + valkey_manager._get_message_key("session1", "agent1", message_id)