diff --git a/src/bedrock_agentcore/memory/models/__init__.py b/src/bedrock_agentcore/memory/models/__init__.py index 0322483..1e08d8a 100644 --- a/src/bedrock_agentcore/memory/models/__init__.py +++ b/src/bedrock_agentcore/memory/models/__init__.py @@ -3,7 +3,15 @@ from typing import Any, Dict from .DictWrapper import DictWrapper - +from .filters import ( + StringValue, + MetadataValue, + MetadataKey, + LeftExpression, + OperatorType, + RightExpression, + EventMetadataFilter, +) class ActorSummary(DictWrapper): """A class representing an actor summary.""" @@ -75,3 +83,20 @@ def __init__(self, session_summary: Dict[str, Any]): session_summary: Dictionary containing session summary data. """ super().__init__(session_summary) + +__all__ = [ + "DictWrapper", + "ActorSummary", + "Branch", + "Event", + "EventMessage", + "MemoryRecord", + "SessionSummary", + "StringValue", + "MetadataValue", + "MetadataKey", + "LeftExpression", + "OperatorType", + "RightExpression", + "EventMetadataFilter", +] diff --git a/src/bedrock_agentcore/memory/models/filters.py b/src/bedrock_agentcore/memory/models/filters.py new file mode 100644 index 0000000..58377a0 --- /dev/null +++ b/src/bedrock_agentcore/memory/models/filters.py @@ -0,0 +1,118 @@ +from enum import Enum +from typing import Optional, TypedDict, Union, NotRequired + +class StringValue(TypedDict): + """Value associated with the `eventMetadata` key.""" + stringValue: str + + @staticmethod + def build(value: str) -> 'StringValue': + return { + "stringValue": value + } + +MetadataValue = Union[StringValue] +""" +Union type representing metadata values. + +Variants: +- StringValue: {"stringValue": str} - String metadata value +""" + +MetadataKey = Union[str] +""" +Union type representing metadata key. +""" + +class LeftExpression(TypedDict): + """ + Left operand of the event metadata filter expression. + """ + metadataKey: MetadataKey + + @staticmethod + def build(key: str) -> 'LeftExpression': + """Builds the `metadataKey` for `LeftExpression`""" + return { + "metadataKey": key + } + +class OperatorType(Enum): + """ + Operator applied to the event metadata filter expression. + + Currently supports: + - `EQUALS_TO` + - `EXISTS` + - `NOT_EXISTS` + """ + EQUALS_TO = "EQUALS_TO" + EXISTS = "EXISTS" + NOT_EXISTS = "NOT_EXISTS" + +class RightExpression(TypedDict): + """ + Right operand of the event metadata filter expression. + + Variants: + - StringValue: {"metadataValue": {"stringValue": str}} + """ + metadataValue: MetadataValue + + @staticmethod + def build(value: str) -> 'RightExpression': + """Builds the `RightExpression` for `stringValue` type""" + return {"metadataValue": StringValue.build(value)} + +class EventMetadataFilter(TypedDict): + """ + Filter expression for retrieving events based on metadata associated with an event. + + Args: + left: `LeftExpression` of the event metadata filter expression. + operator: `OperatorType` applied to the event metadata filter expression. + right: Optional `RightExpression` of the event metadata filter expression. + """ + left: LeftExpression + operator: OperatorType + right: NotRequired[RightExpression] + + def build_expression(left_operand: LeftExpression, operator: OperatorType, right_operand: Optional[RightExpression] = None) -> 'EventMetadataFilter': + """ + This method builds the required event metadata filter expression into the `EventMetadataFilterExpression` type when querying listEvents. + + Args: + left_operand: Left operand of the event metadata filter expression + operator: Operator applied to the event metadata filter expression + right_operand: Optional right_operand of the event metadata filter expression. + + Example: + ``` + left_operand = LeftExpression.build_key(key='location') + operator = OperatorType.EQUALS_TO + right_operand = RightExpression.build_string_value(value='NYC') + ``` + + #### Response Object: + ``` + { + 'left': { + 'metadataKey': 'location' + }, + 'operator': 'EQUALS_TO', + 'right': { + 'metadataValue': { + 'stringValue': 'NYC' + } + } + } + ``` + """ + filter = { + 'left': left_operand, + 'operator': operator.value + } + + if right_operand: + filter['right'] = right_operand + return filter \ No newline at end of file diff --git a/src/bedrock_agentcore/memory/session.py b/src/bedrock_agentcore/memory/session.py index d6d09fd..a7945ed 100644 --- a/src/bedrock_agentcore/memory/session.py +++ b/src/bedrock_agentcore/memory/session.py @@ -18,6 +18,8 @@ EventMessage, MemoryRecord, SessionSummary, + MetadataValue, + EventMetadataFilter ) logger = logging.getLogger(__name__) @@ -246,6 +248,7 @@ def process_turn_with_llm( user_input: str, llm_callback: Callable[[str, List[Dict[str, Any]]], str], retrieval_config: Optional[Dict[str, RetrievalConfig]], + metadata: Optional[Dict[str, MetadataValue]] = None, event_timestamp: Optional[datetime] = None, ) -> Tuple[List[Dict[str, Any]], str, Dict[str, Any]]: r"""Complete conversation turn with LLM callback integration. @@ -263,6 +266,7 @@ def process_turn_with_llm( retrieval_config: Optional dictionary mapping namespaces to RetrievalConfig objects. Each namespace can contain template variables like {actorId}, {sessionId}, {memoryStrategyId} that will be resolved at runtime. + metadata: Optional custom key-value metadata to attach to an event. event_timestamp: Optional timestamp for the event Returns: @@ -340,6 +344,7 @@ def my_llm(user_input: str, memories: List[Dict]) -> str: ConversationalMessage(user_input, MessageRole.USER), ConversationalMessage(agent_response, MessageRole.ASSISTANT), ], + metadata=metadata, event_timestamp=event_timestamp, ) @@ -352,6 +357,7 @@ def add_turns( session_id: str, messages: List[Union[ConversationalMessage, BlobMessage]], branch: Optional[Dict[str, str]] = None, + metadata: Optional[Dict[str, MetadataValue]] = None, event_timestamp: Optional[datetime] = None, ) -> Event: """Adds conversational turns or blob objects to short-term memory. @@ -365,12 +371,14 @@ def add_turns( - ConversationalMessage objects for conversational messages - BlobMessage objects for blob data branch: Optional branch info + metadata: Optional custom key-value metadata to attach to an event. event_timestamp: Optional timestamp for the event Returns: Created event Example: + ``` manager.add_turns( actor_id="user-123", session_id="session-456", @@ -378,8 +386,16 @@ def add_turns( ConversationalMessage("Hello", USER), BlobMessage({"file_data": "base64_content"}), ConversationalMessage("How can I help?", ASSISTANT) + ], + metadata=[ + { + 'location': { + 'stringValue': 'NYC' + } + } ] ) + ``` """ logger.info(" -> Storing %d messages in short-term memory...", len(messages)) @@ -412,6 +428,10 @@ def add_turns( if branch: params["branch"] = branch + + if metadata: + params["metadata"] = metadata + try: response = self._data_plane_client.create_event(**params) logger.info(" ✅ Turn stored successfully with Event ID: %s", response.get("eventId")) @@ -427,6 +447,7 @@ def fork_conversation( root_event_id: str, branch_name: str, messages: List[Union[ConversationalMessage, BlobMessage]], + metadata: Optional[Dict[str, MetadataValue]] = None, event_timestamp: Optional[datetime] = None, ) -> Dict[str, Any]: """Fork a conversation from a specific event to create a new branch.""" @@ -439,6 +460,7 @@ def fork_conversation( messages=messages, event_timestamp=event_timestamp, branch=branch, + metadata=metadata, ) logger.info("Created branch '%s' from event %s", branch_name, root_event_id) @@ -454,6 +476,7 @@ def list_events( session_id: str, branch_name: Optional[str] = None, include_parent_branches: bool = False, + eventMetadata: Optional[List[EventMetadataFilter]] = None, max_results: int = 100, include_payload: bool = True, ) -> List[Event]: @@ -482,6 +505,49 @@ def list_events( # Get events from a specific branch branch_events = client.list_events(actor_id, session_id, branch_name="test-branch") + + #### Get events with event metadata filter + ``` + filtered_events_with_metadata = client.list_events( + actor_id=actor_id, + session_id=session_id, + eventMetadata=[ + { + 'left': { + 'metadataKey': 'location' + }, + 'operator': 'EQUALS_TO', + 'right': { + 'metadataValue': { + 'stringValue': 'NYC' + } + } + } + ] + ) + ``` + + #### Get events with event metadata filter + specific branch filter + ``` + branch_with_metadata_filtered_events = client.list_events( + actor_id=actor_id, + session_id=session_id, + branch_name="test-branch", + eventMetadata=[ + { + 'left': { + 'metadataKey': 'location' + }, + 'operator': 'EQUALS_TO', + 'right': { + 'metadataValue': { + 'stringValue': 'NYC' + } + } + } + ] + ) + ``` """ try: all_events: List[Event] = [] @@ -509,6 +575,12 @@ def list_events( "branch": {"name": branch_name, "includeParentBranches": include_parent_branches} } + # Add eventMetadata filter if specified + if eventMetadata: + params["filter"] = { + "eventMetadata": eventMetadata + } + response = self._data_plane_client.list_events(**params) events = response.get("events", []) @@ -888,21 +960,23 @@ def add_turns( self, messages: List[Union[ConversationalMessage, BlobMessage]], branch: Optional[Dict[str, str]] = None, + metadata: Optional[Dict[str, MetadataValue]] = None, event_timestamp: Optional[datetime] = None, ) -> Event: """Delegates to manager.add_turns.""" - return self._manager.add_turns(self._actor_id, self._session_id, messages, branch, event_timestamp) + return self._manager.add_turns(self._actor_id, self._session_id, messages, branch, metadata, event_timestamp) def fork_conversation( self, messages: List[Union[ConversationalMessage, BlobMessage]], root_event_id: str, branch_name: str, + metadata: Optional[Dict[str, MetadataValue]] = None, event_timestamp: Optional[datetime] = None, ) -> Event: """Delegates to manager.fork_conversation.""" return self._manager.fork_conversation( - self._actor_id, self._session_id, root_event_id, branch_name, messages, event_timestamp + self._actor_id, self._session_id, root_event_id, branch_name, messages, metadata, event_timestamp ) def process_turn_with_llm( @@ -910,6 +984,7 @@ def process_turn_with_llm( user_input: str, llm_callback: Callable[[str, List[Dict[str, Any]]], str], retrieval_config: Optional[Dict[str, RetrievalConfig]], + metadata: Optional[Dict[str, MetadataValue]] = None, event_timestamp: Optional[datetime] = None, ) -> Tuple[List[Dict[str, Any]], str, Dict[str, Any]]: """Delegates to manager.process_turn_with_llm.""" @@ -919,6 +994,7 @@ def process_turn_with_llm( user_input, llm_callback, retrieval_config, + metadata, event_timestamp, ) @@ -975,6 +1051,7 @@ def list_events( self, branch_name: Optional[str] = None, include_parent_branches: bool = False, + eventMetadata: Optional[List[EventMetadataFilter]] = None, max_results: int = 100, include_payload: bool = True, ) -> List[Event]: @@ -984,6 +1061,7 @@ def list_events( session_id=self._session_id, branch_name=branch_name, include_parent_branches=include_parent_branches, + eventMetadata=eventMetadata, include_payload=include_payload, max_results=max_results, ) diff --git a/tests/bedrock_agentcore/memory/test_session.py b/tests/bedrock_agentcore/memory/test_session.py index 5de0f7f..6f29abc 100644 --- a/tests/bedrock_agentcore/memory/test_session.py +++ b/tests/bedrock_agentcore/memory/test_session.py @@ -1523,9 +1523,9 @@ def test_session_add_turns_delegation(self): result = session.add_turns(messages=[ConversationalMessage("Hello", MessageRole.USER)]) assert result == mock_event - mock_add_turns.assert_called_once_with( - "user-123", "session-456", [ConversationalMessage("Hello", MessageRole.USER)], None, None - ) + mock_add_turns.assert_called_once_with( + "user-123", "session-456", [ConversationalMessage("Hello", MessageRole.USER)], None, None, None + ) def test_session_fork_conversation_delegation(self): """Test MemorySession.fork_conversation delegates to manager.""" @@ -1552,6 +1552,7 @@ def test_session_fork_conversation_delegation(self): "test-branch", [ConversationalMessage("Fork message", MessageRole.USER)], None, + None, ) def test_session_create_blob_event_delegation(self): @@ -1569,7 +1570,7 @@ def test_session_create_blob_event_delegation(self): result = session.add_turns(messages=[BlobMessage(blob_data)]) assert result == mock_event - mock_add_turns.assert_called_once_with("user-123", "session-456", [BlobMessage(blob_data)], None, None) + mock_add_turns.assert_called_once_with("user-123", "session-456", [BlobMessage(blob_data)], None, None, None) def test_session_process_turn_with_llm_delegation(self): """Test MemorySession.process_turn_with_llm delegates to manager.""" @@ -1597,7 +1598,7 @@ def mock_llm(user_input: str, memories: List[Dict[str, Any]]) -> str: assert memories == mock_memories assert response == mock_response assert event == mock_event - mock_process.assert_called_once_with("user-123", "session-456", "Hello", mock_llm, None, None) + mock_process.assert_called_once_with("user-123", "session-456", "Hello", mock_llm, None, None, None) def test_session_get_last_k_turns_delegation(self): """Test MemorySession.get_last_k_turns delegates to manager.""" @@ -1743,6 +1744,7 @@ def test_session_list_events_delegation(self): session_id="session-456", branch_name="test-branch", include_parent_branches=False, + eventMetadata=None, include_payload=True, max_results=100, ) @@ -1948,6 +1950,7 @@ def test_session_delegation_with_optional_parameters(self): "session-456", [ConversationalMessage("Hello", MessageRole.USER)], branch, + None, custom_timestamp, ) @@ -2196,9 +2199,358 @@ def test_session_add_turns_parameter_order(self): "session-456", [ConversationalMessage("Hello", MessageRole.USER)], branch, + None, custom_timestamp, ) +class TestEventMetadataFlow: + """Test cases for metadata support for STM in MemorySessionManager.""" + + def test_fork_conversation_with_metadata_parameter(self): + """Test fork_conversation with new metadata parameter.""" + with patch("boto3.Session") as mock_session_class: + mock_session = MagicMock() + mock_session.region_name = "us-west-2" + mock_client_instance = MagicMock() + mock_session.client.return_value = mock_client_instance + mock_session_class.return_value = mock_session + + manager = MemorySessionManager(memory_id="testMemory-1234567890", region_name="us-west-2") + + # Mock add_turns + mock_event = {"eventId": "fork-event-123", "memoryId": "testMemory-1234567890"} + with patch.object(manager, "add_turns", return_value=Event(mock_event)) as mock_add_turns: + metadata = {"location": {"stringValue": "NYC"}} + + result = manager.fork_conversation( + actor_id="user-123", + session_id="session-456", + root_event_id="event-root-123", + branch_name="test-branch", + messages=[ConversationalMessage("Fork message", MessageRole.USER)], + metadata=metadata, + ) + + assert result["eventId"] == "fork-event-123" + + # Verify add_turns was called with metadata + mock_add_turns.assert_called_once() + call_args = mock_add_turns.call_args[1] + assert call_args["metadata"] == metadata + assert call_args["branch"]["rootEventId"] == "event-root-123" + assert call_args["branch"]["name"] == "test-branch" + + def test_list_events_with_event_metadata_filter(self): + """Test list_events with eventMetadata filter parameter.""" + with patch("boto3.Session") as mock_session_class: + mock_session = MagicMock() + mock_session.region_name = "us-west-2" + mock_client_instance = MagicMock() + mock_session.client.return_value = mock_client_instance + mock_session_class.return_value = mock_session + + manager = MemorySessionManager(memory_id="testMemory-1234567890", region_name="us-west-2") + + # Mock response + mock_events = [{"eventId": "filtered-event-1", "eventTimestamp": datetime.now()}] + mock_client_instance.list_events.return_value = {"events": mock_events, "nextToken": None} + + # Test with eventMetadata filter + event_metadata_filter = [ + { + 'left': { + 'metadataKey': 'location' + }, + 'operator': 'EQUALS_TO', + 'right': { + 'metadataValue': { + 'stringValue': 'NYC' + } + } + } + ] + + result = manager.list_events( + actor_id="user-123", + session_id="session-456", + eventMetadata=event_metadata_filter + ) + + assert len(result) == 1 + assert result[0]["eventId"] == "filtered-event-1" + + # Verify filter was applied + call_args = mock_client_instance.list_events.call_args[1] + assert "filter" in call_args + assert call_args["filter"]["eventMetadata"] == event_metadata_filter + + def test_list_events_with_both_branch_and_metadata_filters(self): + """Test list_events with both branch and eventMetadata filters.""" + with patch("boto3.Session") as mock_session_class: + mock_session = MagicMock() + mock_session.region_name = "us-west-2" + mock_client_instance = MagicMock() + mock_session.client.return_value = mock_client_instance + mock_session_class.return_value = mock_session + + manager = MemorySessionManager(memory_id="testMemory-1234567890", region_name="us-west-2") + + # Mock response + mock_events = [{"eventId": "filtered-event-1", "eventTimestamp": datetime.now()}] + mock_client_instance.list_events.return_value = {"events": mock_events, "nextToken": None} + + # Test with both branch and eventMetadata filters + event_metadata_filter = [ + { + 'left': { + 'metadataKey': 'location' + }, + 'operator': 'EQUALS_TO', + 'right': { + 'metadataValue': { + 'stringValue': 'NYC' + } + } + } + ] + + result = manager.list_events( + actor_id="user-123", + session_id="session-456", + branch_name="test-branch", + include_parent_branches=True, + eventMetadata=event_metadata_filter + ) + + assert len(result) == 1 + + # Verify both filters were applied - eventMetadata should override branch filter + call_args = mock_client_instance.list_events.call_args[1] + assert "filter" in call_args + assert call_args["filter"]["eventMetadata"] == event_metadata_filter + # Branch filter should not be present when eventMetadata is specified + assert "branch" not in call_args["filter"] + + def test_memory_session_list_events_with_event_metadata(self): + """Test MemorySession.list_events with eventMetadata parameter.""" + with patch("boto3.Session"): + manager = MemorySessionManager(memory_id="testMemory-1234567890", region_name="us-west-2") + session = MemorySession( + memory_id="testMemory-1234567890", actor_id="user-123", session_id="session-456", manager=manager + ) + + # Mock manager method + mock_events = [Event({"eventId": "event-1"})] + event_metadata_filter = [ + { + 'left': { + 'metadataKey': 'location' + }, + 'operator': 'EQUALS_TO', + 'right': { + 'metadataValue': { + 'stringValue': 'NYC' + } + } + } + ] + + with patch.object(manager, "list_events", return_value=mock_events) as mock_list_events: + result = session.list_events( + branch_name="test-branch", + eventMetadata=event_metadata_filter + ) + + assert result == mock_events + mock_list_events.assert_called_once_with( + actor_id="user-123", + session_id="session-456", + branch_name="test-branch", + include_parent_branches=False, + eventMetadata=event_metadata_filter, + include_payload=True, + max_results=100, + ) + + def test_memory_session_fork_conversation_with_metadata(self): + """Test MemorySession.fork_conversation with metadata parameter.""" + with patch("boto3.Session"): + manager = MemorySessionManager(memory_id="testMemory-1234567890", region_name="us-west-2") + session = MemorySession( + memory_id="testMemory-1234567890", actor_id="user-123", session_id="session-456", manager=manager + ) + + # Mock manager method + mock_event = Event({"eventId": "fork-event-123"}) + metadata = {"location": {"stringValue": "NYC"}} + + with patch.object(manager, "fork_conversation", return_value=mock_event) as mock_fork: + result = session.fork_conversation( + messages=[ConversationalMessage("Fork message", MessageRole.USER)], + root_event_id="event-root-123", + branch_name="test-branch", + metadata=metadata, + ) + + assert result == mock_event + mock_fork.assert_called_once_with( + "user-123", + "session-456", + "event-root-123", + "test-branch", + [ConversationalMessage("Fork message", MessageRole.USER)], + metadata, + None, + ) + + def test_memory_session_process_turn_with_llm_with_metadata(self): + """Test MemorySession.process_turn_with_llm with metadata parameter.""" + with patch("boto3.Session"): + manager = MemorySessionManager(memory_id="testMemory-1234567890", region_name="us-west-2") + session = MemorySession( + memory_id="testMemory-1234567890", actor_id="user-123", session_id="session-456", manager=manager + ) + + # Mock manager method + mock_memories = [{"content": {"text": "Memory"}}] + mock_response = "LLM response" + mock_event = {"eventId": "event-123"} + metadata = {"location": {"stringValue": "NYC"}} + + with patch.object( + manager, "process_turn_with_llm", return_value=(mock_memories, mock_response, mock_event) + ) as mock_process: + + def mock_llm(user_input: str, memories: List[Dict[str, Any]]) -> str: + return "Response" + + memories, response, event = session.process_turn_with_llm( + user_input="Hello", + llm_callback=mock_llm, + retrieval_config=None, + metadata=metadata + ) + + assert memories == mock_memories + assert response == mock_response + assert event == mock_event + mock_process.assert_called_once_with( + "user-123", + "session-456", + "Hello", + mock_llm, + None, + metadata, + None + ) + + def test_process_turn_with_llm_with_metadata_parameter(self): + """Test process_turn_with_llm with metadata parameter.""" + with patch("boto3.Session") as mock_session_class: + mock_session = MagicMock() + mock_session.region_name = "us-west-2" + mock_client_instance = MagicMock() + mock_session.client.return_value = mock_client_instance + mock_session_class.return_value = mock_session + + manager = MemorySessionManager(memory_id="testMemory-1234567890", region_name="us-west-2") + + # Mock search_long_term_memories + mock_memories = [{"content": {"text": "Previous context"}, "memoryRecordId": "rec-123"}] + with patch.object(manager, "search_long_term_memories", return_value=mock_memories): + # Mock add_turns + mock_event = {"eventId": "event-123", "memoryId": "testMemory-1234567890"} + with patch.object(manager, "add_turns", return_value=Event(mock_event)) as mock_add_turns: + # Define LLM callback + def mock_llm_callback(user_input: str, memories: List[Dict[str, Any]]) -> str: + return f"Response to: {user_input} with {len(memories)} memories" + + # Test process_turn_with_llm with metadata + retrieval_config = {"test/namespace": RetrievalConfig(top_k=5)} + metadata = {"location": {"stringValue": "NYC"}} + + memories, response, event = manager.process_turn_with_llm( + actor_id="user-123", + session_id="session-456", + user_input="Hello", + llm_callback=mock_llm_callback, + retrieval_config=retrieval_config, + metadata=metadata, + ) + + assert len(memories) == 1 + assert "Response to: Hello with 1 memories" in response + assert event["eventId"] == "event-123" + + # Verify add_turns was called with metadata + mock_add_turns.assert_called_once() + call_args = mock_add_turns.call_args[1] + assert call_args["metadata"] == metadata + + def test_add_turns_with_metadata_parameter(self): + """Test add_turns with metadata parameter.""" + with patch("boto3.Session") as mock_session_class: + mock_session = MagicMock() + mock_session.region_name = "us-west-2" + mock_client_instance = MagicMock() + mock_session.client.return_value = mock_client_instance + mock_session_class.return_value = mock_session + + manager = MemorySessionManager(memory_id="testMemory-1234567890", region_name="us-west-2") + + # Mock create_event response + mock_response = {"event": {"eventId": "turn-event-123", "memoryId": "testMemory-1234567890"}} + mock_client_instance.create_event.return_value = mock_response + + messages = [ + ConversationalMessage("Hello", MessageRole.USER), + ConversationalMessage("Hi there", MessageRole.ASSISTANT), + ] + metadata = {"location": {"stringValue": "NYC"}} + + result = manager.add_turns( + actor_id="user-123", + session_id="session-456", + messages=messages, + metadata=metadata + ) + + assert isinstance(result, Event) + assert result["eventId"] == "turn-event-123" + + # Verify metadata was passed to create_event + call_args = mock_client_instance.create_event.call_args[1] + assert call_args["metadata"] == metadata + assert len(call_args["payload"]) == 2 + + def test_memory_session_add_turns_with_metadata(self): + """Test MemorySession.add_turns with metadata parameter.""" + with patch("boto3.Session"): + manager = MemorySessionManager(memory_id="testMemory-1234567890", region_name="us-west-2") + session = MemorySession( + memory_id="testMemory-1234567890", actor_id="user-123", session_id="session-456", manager=manager + ) + + # Mock manager method + mock_event = Event({"eventId": "event-123"}) + metadata = {"location": {"stringValue": "NYC"}} + + with patch.object(manager, "add_turns", return_value=mock_event) as mock_add_turns: + result = session.add_turns( + messages=[ConversationalMessage("Hello", MessageRole.USER)], + metadata=metadata + ) + + assert result == mock_event + mock_add_turns.assert_called_once_with( + "user-123", + "session-456", + [ConversationalMessage("Hello", MessageRole.USER)], + None, + metadata, + None + ) + class TestAdditionalCoverage: """Additional tests to reach 99% coverage.""" @@ -2572,7 +2924,7 @@ def test_memory_session_add_turns_parameter_order(self): session.add_turns(messages=messages, branch=branch, event_timestamp=custom_timestamp) # Verify the exact parameter order: actor_id, session_id, messages, branch, event_timestamp - mock_add_turns.assert_called_once_with("user-123", "session-456", messages, branch, custom_timestamp) + mock_add_turns.assert_called_once_with("user-123", "session-456", messages, branch, None, custom_timestamp) def test_process_turn_with_llm_no_relevance_score_config(self): """Test process_turn_with_llm when RetrievalConfig has no relevance_score.""" @@ -2640,7 +2992,7 @@ def test_memory_session_add_turns_branch_parameter_order(self): session.add_turns(messages=messages, branch=branch) # Verify the exact parameter order: actor_id, session_id, messages, branch, event_timestamp - mock_add_turns.assert_called_once_with("user-123", "session-456", messages, branch, None) + mock_add_turns.assert_called_once_with("user-123", "session-456", messages, branch, None, None) def test_list_long_term_memory_records_memoryRecordSummaries_fallback(self): """Test list_long_term_memory_records fallback to memoryRecordSummaries."""