diff --git a/src/strands_tools/memory.py b/src/strands_tools/memory.py index f57e6d95..8c5deb26 100644 --- a/src/strands_tools/memory.py +++ b/src/strands_tools/memory.py @@ -1,5 +1,4 @@ -""" -Tool for managing data in Bedrock Knowledge Base (store, delete, list, get, and retrieve) +"""Tool for managing data in Bedrock Knowledge Base (store, delete, list, get, and retrieve) This module provides comprehensive Knowledge Base management capabilities for Amazon Bedrock Knowledge Bases. It handles all aspects of document management with @@ -14,22 +13,30 @@ • get: Retrieve specific documents by document ID • retrieve: Perform semantic search across all documents -2. Safety Features: +2. Data Source Support: + • Detects CUSTOM data source types + • Falls back to first available data source if no CUSTOM found + • Provides clear error messages for unsupported data source types + • Currently supports CUSTOM data sources for direct ingestion + • S3 and other data source types show clear error messages + +3. Safety Features: • User confirmation for mutative operations • Content previews before storage • Warning messages before deletion • BYPASS_TOOL_CONSENT mode for bypassing confirmations in tests -3. Advanced Capabilities: +4. Advanced Capabilities: • Automatic document ID generation • Structured content storage with metadata • Semantic search with relevance filtering • Rich output formatting • Pagination support -4. Error Handling: +5. Error Handling: • Knowledge Base ID validation • Parameter validation + • Data source type detection and validation • Graceful API error handling • Clear error messages @@ -148,6 +155,57 @@ def runtime_client(self): self._runtime_client = self.session.client("bedrock-agent-runtime", region_name=self.region) return self._runtime_client + def _detect_data_source_type(self, kb_id: str): + """ + Helper method to detect data source type for a knowledge base. + + This method implements the same logic as store_in_kb tool: + 1. Look for CUSTOM data source first (preferred) + 2. Fall back to first available data source if no CUSTOM found + 3. Log appropriate messages for debugging + + Args: + kb_id: Knowledge Base ID + + Returns: + Tuple of (data_source_id, source_type) + + Raises: + ValueError: If no data sources are found + """ + # Get data source details to determine the type + data_sources = self.agent_client.list_data_sources(knowledgeBaseId=kb_id) + + if data_sources and not data_sources.get("dataSourceSummaries"): + raise ValueError(f"No data sources found for knowledge base {kb_id}") + + # Look for CUSTOM data source first, then fallback + data_source_id = None + source_type = None + + for ds in data_sources["dataSourceSummaries"]: + # Get the data source details to check its type + ds_detail = self.agent_client.get_data_source(knowledgeBaseId=kb_id, dataSourceId=ds["dataSourceId"]) + + # Check if this is a CUSTOM type data source + if ds_detail["dataSource"]["dataSourceConfiguration"]["type"] == "CUSTOM": + data_source_id = ds["dataSourceId"] + source_type = "CUSTOM" + logger.debug(f"Found CUSTOM data source: {data_source_id}") + break + + # If no CUSTOM data source found, use the first available one but log a warning + if not data_source_id and data_sources["dataSourceSummaries"]: + data_source_id = data_sources["dataSourceSummaries"][0]["dataSourceId"] + ds_detail = self.agent_client.get_data_source(knowledgeBaseId=kb_id, dataSourceId=data_source_id) + source_type = ds_detail["dataSource"]["dataSourceConfiguration"]["type"] + logger.debug(f"No CUSTOM data source found. Using {source_type} data source: {data_source_id}") + + if not data_source_id: + raise ValueError(f"No suitable data source found for knowledge base {kb_id}") + + return data_source_id, source_type + def get_data_source_id(self, kb_id: str) -> str: """ Get the data source ID for a knowledge base. @@ -212,16 +270,22 @@ def get_document(self, kb_id: str, data_source_id: str = None, document_id: str Returns: Response from the get_knowledge_base_documents API call """ - # Get the data source ID if not provided - if not data_source_id: - data_source_id = self.get_data_source_id(kb_id) - - # Use the get_knowledge_base_documents method - get_request = { - "knowledgeBaseId": kb_id, - "dataSourceId": data_source_id, - "documentIdentifiers": [{"dataSourceType": "CUSTOM", "custom": {"id": document_id}}], - } + # Use helper method to detect data source type + data_source_id, source_type = self._detect_data_source_type(kb_id) + + # Prepare get request based on the data source type + if source_type == "CUSTOM": + get_request = { + "knowledgeBaseId": kb_id, + "dataSourceId": data_source_id, + "documentIdentifiers": [{"dataSourceType": "CUSTOM", "custom": {"id": document_id}}], + } + elif source_type == "S3": + # For S3, we would need to construct the S3 URI identifier + # This is more complex and may require additional logic + raise ValueError("S3 data source type is not fully supported for document retrieval.") + else: + raise ValueError(f"Unsupported data source type: {source_type}") return self.agent_client.get_knowledge_base_documents(**get_request) @@ -238,10 +302,6 @@ def store_document(self, kb_id: str, data_source_id: str = None, content: str = Returns: Tuple of (response, document_id, document_title) """ - # Get the data source ID if not provided - if not data_source_id: - data_source_id = self.get_data_source_id(kb_id) - # Generate document ID with timestamp for traceability timestamp = time.strftime("%Y%m%d_%H%M%S") doc_id = f"memory_{timestamp}_{str(uuid.uuid4())[:8]}" @@ -256,34 +316,48 @@ def store_document(self, kb_id: str, data_source_id: str = None, content: str = "content": content, } - # Prepare document for ingestion - ingest_request = { - "knowledgeBaseId": kb_id, - "dataSourceId": data_source_id, - "documents": [ - { - "content": { - "dataSourceType": "CUSTOM", - "custom": { - "customDocumentIdentifier": {"id": doc_id}, - "inlineContent": { - "textContent": {"data": json.dumps(content_with_metadata)}, - "type": "TEXT", + # Use helper method to detect data source type + data_source_id, source_type = self._detect_data_source_type(kb_id) + + # Prepare document for ingestion based on the data source type + if source_type == "CUSTOM": + ingest_request = { + "knowledgeBaseId": kb_id, + "dataSourceId": data_source_id, + "documents": [ + { + "content": { + "dataSourceType": "CUSTOM", + "custom": { + "customDocumentIdentifier": {"id": doc_id}, + "inlineContent": { + "textContent": {"data": json.dumps(content_with_metadata)}, + "type": "TEXT", + }, + "sourceType": "IN_LINE", }, - "sourceType": "IN_LINE", - }, + } } - } - ], - } + ], + } + elif source_type == "S3": + # S3 source types need a different ingestion approach + raise ValueError("S3 data source type is not supported for direct ingestion with this tool.") + else: + raise ValueError(f"Unsupported data source type: {source_type}") # Ingest document into knowledge base response = self.agent_client.ingest_knowledge_base_documents(**ingest_request) + + # Log success + logger.debug(f"Successfully ingested document into knowledge base {kb_id}: {doc_id}") + return response, doc_id, doc_title def delete_document(self, kb_id: str, data_source_id: str = None, document_id: str = None): """ Delete a document from the knowledge base. + FIXED: Now handles multiple data source types like store_in_kb tool. Args: kb_id: Knowledge Base ID @@ -293,16 +367,22 @@ def delete_document(self, kb_id: str, data_source_id: str = None, document_id: s Returns: Response from the delete_knowledge_base_documents API call """ - # Get the data source ID if not provided - if not data_source_id: - data_source_id = self.get_data_source_id(kb_id) - - # Prepare delete request - delete_request = { - "knowledgeBaseId": kb_id, - "dataSourceId": data_source_id, - "documentIdentifiers": [{"dataSourceType": "CUSTOM", "custom": {"id": document_id}}], - } + # Use helper method to detect data source type + data_source_id, source_type = self._detect_data_source_type(kb_id) + + # Prepare delete request based on the data source type + if source_type == "CUSTOM": + delete_request = { + "knowledgeBaseId": kb_id, + "dataSourceId": data_source_id, + "documentIdentifiers": [{"dataSourceType": "CUSTOM", "custom": {"id": document_id}}], + } + elif source_type == "S3": + # For S3, we would need to construct the S3 URI identifier + # This is more complex and may require additional logic + raise ValueError("S3 data source type is not fully supported for document deletion.") + else: + raise ValueError(f"Unsupported data source type: {source_type}") # Delete document from knowledge base return self.agent_client.delete_knowledge_base_documents(**delete_request) diff --git a/tests/test_memory/test_memory_client.py b/tests/test_memory/test_memory_client.py index d0c3c27b..c8131477 100644 --- a/tests/test_memory/test_memory_client.py +++ b/tests/test_memory/test_memory_client.py @@ -1,6 +1,4 @@ -""" -Tests for the MemoryServiceClient class in memory.py. -""" +"""Tests for the MemoryServiceClient class in memory.py.""" import json import os @@ -185,6 +183,9 @@ def test_list_documents_with_defaults(mock_session): # Mock get_data_source_id agent_client.list_data_sources.return_value = {"dataSourceSummaries": [{"dataSourceId": "ds123"}]} + # Mock get_data_source for _detect_data_source_type method + agent_client.get_data_source.return_value = {"dataSource": {"dataSourceConfiguration": {"type": "CUSTOM"}}} + # Initialize client client = MemoryServiceClient() @@ -228,6 +229,9 @@ def test_get_document(mock_session): # Mock get_data_source_id agent_client.list_data_sources.return_value = {"dataSourceSummaries": [{"dataSourceId": "ds123"}]} + # Mock get_data_source for _detect_data_source_type method + agent_client.get_data_source.return_value = {"dataSource": {"dataSourceConfiguration": {"type": "CUSTOM"}}} + # Initialize client client = MemoryServiceClient() @@ -254,6 +258,12 @@ def test_store_document(mock_session): # Mock get_data_source_id agent_client.list_data_sources.return_value = {"dataSourceSummaries": [{"dataSourceId": "ds123"}]} + # Mock get_data_source for _detect_data_source_type method + agent_client.get_data_source.return_value = {"dataSource": {"dataSourceConfiguration": {"type": "CUSTOM"}}} + + # Mock get_data_source for _detect_data_source_type method + agent_client.get_data_source.return_value = {"dataSource": {"dataSourceConfiguration": {"type": "CUSTOM"}}} + # Mock ingest response agent_client.ingest_knowledge_base_documents.return_value = {"status": "success"} @@ -299,6 +309,12 @@ def test_store_document_no_title(mock_session): # Mock get_data_source_id agent_client.list_data_sources.return_value = {"dataSourceSummaries": [{"dataSourceId": "ds123"}]} + # Mock get_data_source for _detect_data_source_type method + agent_client.get_data_source.return_value = {"dataSource": {"dataSourceConfiguration": {"type": "CUSTOM"}}} + + # Mock get_data_source for _detect_data_source_type method + agent_client.get_data_source.return_value = {"dataSource": {"dataSourceConfiguration": {"type": "CUSTOM"}}} + # Mock ingest response agent_client.ingest_knowledge_base_documents.return_value = {"status": "success"} @@ -335,6 +351,12 @@ def test_delete_document(mock_session): # Mock get_data_source_id agent_client.list_data_sources.return_value = {"dataSourceSummaries": [{"dataSourceId": "ds123"}]} + # Mock get_data_source for _detect_data_source_type method + agent_client.get_data_source.return_value = {"dataSource": {"dataSourceConfiguration": {"type": "CUSTOM"}}} + + # Mock get_data_source for _detect_data_source_type method + agent_client.get_data_source.return_value = {"dataSource": {"dataSourceConfiguration": {"type": "CUSTOM"}}} + # Mock delete response agent_client.delete_knowledge_base_documents.return_value = {"status": "success"}