diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index 9df86e17a..14e71d07c 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -86,7 +86,13 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> message_id: Index of the message Returns: The filename for the message + + Raises: + ValueError: If message_id is not an integer. """ + if not isinstance(message_id, int): + raise ValueError(f"message_id=<{message_id}> | message id must be an integer") + agent_path = self._get_agent_path(session_id, agent_id) return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{message_id}.json") diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index d15e6e3bd..da1735e35 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -113,11 +113,16 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> session_id: ID of the session agent_id: ID of the agent message_id: Index of the message - **kwargs: Additional keyword arguments for future extensibility. Returns: The key for the message + + Raises: + ValueError: If message_id is not an integer. """ + if not isinstance(message_id, int): + raise ValueError(f"message_id=<{message_id}> | message id must be an integer") + agent_path = self._get_agent_path(session_id, agent_id) return f"{agent_path}messages/{MESSAGE_PREFIX}{message_id}.json" diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py index a89222b7e..036591924 100644 --- a/tests/strands/session/test_file_session_manager.py +++ b/tests/strands/session/test_file_session_manager.py @@ -224,14 +224,14 @@ def test_read_messages_with_new_agent(file_manager, sample_session, sample_agent file_manager.create_session(sample_session) file_manager.create_agent(sample_session.session_id, sample_agent) - result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message") + result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, 999) assert result is None def test_read_nonexistent_message(file_manager, sample_session, sample_agent): """Test reading a message that doesnt exist.""" - result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message") + result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, 999) assert result is None @@ -390,3 +390,21 @@ def test__get_session_path_invalid_session_id(session_id, file_manager): def test__get_agent_path_invalid_agent_id(agent_id, file_manager): with pytest.raises(ValueError, match=f"agent_id={agent_id} | id cannot contain path separators"): file_manager._get_agent_path("session1", agent_id) + + +@pytest.mark.parametrize( + "message_id", + [ + "../../../secret", + "../../attack", + "../escape", + "path/traversal", + "not_an_int", + None, + [], + ], +) +def test__get_message_path_invalid_message_id(message_id, file_manager): + """Test that message_id that is not an integer raises ValueError.""" + with pytest.raises(ValueError, match=r"message_id=<.*> \| message id must be an integer"): + file_manager._get_message_path("session1", "agent1", message_id) diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py index 71bff3050..50fb303f7 100644 --- a/tests/strands/session/test_s3_session_manager.py +++ b/tests/strands/session/test_s3_session_manager.py @@ -251,7 +251,7 @@ def test_read_nonexistent_message(s3_manager, sample_session, sample_agent, samp s3_manager.create_agent(sample_session.session_id, sample_agent) # Read message - result = s3_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message") + result = s3_manager.read_message(sample_session.session_id, sample_agent.agent_id, 999) assert result is None @@ -356,3 +356,21 @@ def test__get_session_path_invalid_session_id(session_id, s3_manager): def test__get_agent_path_invalid_agent_id(agent_id, s3_manager): with pytest.raises(ValueError, match=f"agent_id={agent_id} | id cannot contain path separators"): s3_manager._get_agent_path("session1", agent_id) + + +@pytest.mark.parametrize( + "message_id", + [ + "../../../secret", + "../../attack", + "../escape", + "path/traversal", + "not_an_int", + None, + [], + ], +) +def test__get_message_path_invalid_message_id(message_id, s3_manager): + """Test that message_id that is not an integer raises ValueError.""" + with pytest.raises(ValueError, match=r"message_id=<.*> \| message id must be an integer"): + s3_manager._get_message_path("session1", "agent1", message_id)