Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 127 additions & 47 deletions src/strands_tools/memory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""
Tool for managing data in Bedrock Knowledge Base (store, delete, list, get, and retrieve)
"""Tool for managing data in Bedrock Knowledge Base (store, delete, list, get, and retrieve)

This module provides comprehensive Knowledge Base management capabilities for
Amazon Bedrock Knowledge Bases. It handles all aspects of document management with
Expand All @@ -14,22 +13,30 @@
• get: Retrieve specific documents by document ID
• retrieve: Perform semantic search across all documents

2. Safety Features:
2. Data Source Support:
• Detects CUSTOM data source types
• Falls back to first available data source if no CUSTOM found
• Provides clear error messages for unsupported data source types
• Currently supports CUSTOM data sources for direct ingestion
• S3 and other data source types show clear error messages

3. Safety Features:
• User confirmation for mutative operations
• Content previews before storage
• Warning messages before deletion
• BYPASS_TOOL_CONSENT mode for bypassing confirmations in tests

3. Advanced Capabilities:
4. Advanced Capabilities:
• Automatic document ID generation
• Structured content storage with metadata
• Semantic search with relevance filtering
• Rich output formatting
• Pagination support

4. Error Handling:
5. Error Handling:
• Knowledge Base ID validation
• Parameter validation
• Data source type detection and validation
• Graceful API error handling
• Clear error messages

Expand Down Expand Up @@ -148,6 +155,57 @@ def runtime_client(self):
self._runtime_client = self.session.client("bedrock-agent-runtime", region_name=self.region)
return self._runtime_client

def _detect_data_source_type(self, kb_id: str):
    """
    Select the data source to use for a knowledge base and report its type.

    Selection order:
    1. The first data source whose configuration type is CUSTOM (preferred)
    2. Otherwise the first data source listed, whatever its type

    Debug messages are logged for whichever choice is made.

    Args:
        kb_id: Knowledge Base ID

    Returns:
        Tuple of (data_source_id, source_type)

    Raises:
        ValueError: If the knowledge base has no data sources, or the
            fallback data source has no usable ID
    """
    listing = self.agent_client.list_data_sources(knowledgeBaseId=kb_id)

    # Treat a missing/empty response the same as an empty summary list so the
    # caller always gets a ValueError rather than a TypeError/KeyError.
    summaries = listing.get("dataSourceSummaries") if listing else None
    if not summaries:
        raise ValueError(f"No data sources found for knowledge base {kb_id}")

    # Type of the first listed data source, remembered during the scan so the
    # fallback does not have to fetch its details a second time.
    first_type = None

    for position, summary in enumerate(summaries):
        candidate_id = summary["dataSourceId"]
        detail = self.agent_client.get_data_source(knowledgeBaseId=kb_id, dataSourceId=candidate_id)
        candidate_type = detail["dataSource"]["dataSourceConfiguration"]["type"]

        if candidate_type == "CUSTOM":
            logger.debug(f"Found CUSTOM data source: {candidate_id}")
            return candidate_id, "CUSTOM"

        if position == 0:
            first_type = candidate_type

    # No CUSTOM data source: fall back to the first one listed.
    fallback_id = summaries[0]["dataSourceId"]
    if not fallback_id:
        raise ValueError(f"No suitable data source found for knowledge base {kb_id}")

    logger.debug(f"No CUSTOM data source found. Using {first_type} data source: {fallback_id}")
    return fallback_id, first_type

def get_data_source_id(self, kb_id: str) -> str:
"""
Get the data source ID for a knowledge base.
Expand Down Expand Up @@ -212,16 +270,22 @@ def get_document(self, kb_id: str, data_source_id: str = None, document_id: str
Returns:
Response from the get_knowledge_base_documents API call
"""
# Get the data source ID if not provided
if not data_source_id:
data_source_id = self.get_data_source_id(kb_id)

# Use the get_knowledge_base_documents method
get_request = {
"knowledgeBaseId": kb_id,
"dataSourceId": data_source_id,
"documentIdentifiers": [{"dataSourceType": "CUSTOM", "custom": {"id": document_id}}],
}
# Use helper method to detect data source type
data_source_id, source_type = self._detect_data_source_type(kb_id)

# Prepare get request based on the data source type
if source_type == "CUSTOM":
get_request = {
"knowledgeBaseId": kb_id,
"dataSourceId": data_source_id,
"documentIdentifiers": [{"dataSourceType": "CUSTOM", "custom": {"id": document_id}}],
}
elif source_type == "S3":
# For S3, we would need to construct the S3 URI identifier
# This is more complex and may require additional logic
raise ValueError("S3 data source type is not fully supported for document retrieval.")
else:
raise ValueError(f"Unsupported data source type: {source_type}")

return self.agent_client.get_knowledge_base_documents(**get_request)

Expand All @@ -238,10 +302,6 @@ def store_document(self, kb_id: str, data_source_id: str = None, content: str =
Returns:
Tuple of (response, document_id, document_title)
"""
# Get the data source ID if not provided
if not data_source_id:
data_source_id = self.get_data_source_id(kb_id)

# Generate document ID with timestamp for traceability
timestamp = time.strftime("%Y%m%d_%H%M%S")
doc_id = f"memory_{timestamp}_{str(uuid.uuid4())[:8]}"
Expand All @@ -256,34 +316,48 @@ def store_document(self, kb_id: str, data_source_id: str = None, content: str =
"content": content,
}

# Prepare document for ingestion
ingest_request = {
"knowledgeBaseId": kb_id,
"dataSourceId": data_source_id,
"documents": [
{
"content": {
"dataSourceType": "CUSTOM",
"custom": {
"customDocumentIdentifier": {"id": doc_id},
"inlineContent": {
"textContent": {"data": json.dumps(content_with_metadata)},
"type": "TEXT",
# Use helper method to detect data source type
data_source_id, source_type = self._detect_data_source_type(kb_id)

# Prepare document for ingestion based on the data source type
if source_type == "CUSTOM":
ingest_request = {
"knowledgeBaseId": kb_id,
"dataSourceId": data_source_id,
"documents": [
{
"content": {
"dataSourceType": "CUSTOM",
"custom": {
"customDocumentIdentifier": {"id": doc_id},
"inlineContent": {
"textContent": {"data": json.dumps(content_with_metadata)},
"type": "TEXT",
},
"sourceType": "IN_LINE",
},
"sourceType": "IN_LINE",
},
}
}
}
],
}
],
}
elif source_type == "S3":
# S3 source types need a different ingestion approach
raise ValueError("S3 data source type is not supported for direct ingestion with this tool.")
else:
raise ValueError(f"Unsupported data source type: {source_type}")

# Ingest document into knowledge base
response = self.agent_client.ingest_knowledge_base_documents(**ingest_request)

# Log success
logger.debug(f"Successfully ingested document into knowledge base {kb_id}: {doc_id}")

return response, doc_id, doc_title

def delete_document(self, kb_id: str, data_source_id: str = None, document_id: str = None):
"""
Delete a document from the knowledge base.
Detects the data source type the same way as the store_in_kb tool; only CUSTOM data sources are supported, and S3 or other types raise ValueError.

Args:
kb_id: Knowledge Base ID
Expand All @@ -293,16 +367,22 @@ def delete_document(self, kb_id: str, data_source_id: str = None, document_id: s
Returns:
Response from the delete_knowledge_base_documents API call
"""
# Get the data source ID if not provided
if not data_source_id:
data_source_id = self.get_data_source_id(kb_id)

# Prepare delete request
delete_request = {
"knowledgeBaseId": kb_id,
"dataSourceId": data_source_id,
"documentIdentifiers": [{"dataSourceType": "CUSTOM", "custom": {"id": document_id}}],
}
# Use helper method to detect data source type
data_source_id, source_type = self._detect_data_source_type(kb_id)

# Prepare delete request based on the data source type
if source_type == "CUSTOM":
delete_request = {
"knowledgeBaseId": kb_id,
"dataSourceId": data_source_id,
"documentIdentifiers": [{"dataSourceType": "CUSTOM", "custom": {"id": document_id}}],
}
elif source_type == "S3":
# For S3, we would need to construct the S3 URI identifier
# This is more complex and may require additional logic
raise ValueError("S3 data source type is not fully supported for document deletion.")
else:
raise ValueError(f"Unsupported data source type: {source_type}")

# Delete document from knowledge base
return self.agent_client.delete_knowledge_base_documents(**delete_request)
Expand Down
28 changes: 25 additions & 3 deletions tests/test_memory/test_memory_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""
Tests for the MemoryServiceClient class in memory.py.
"""
"""Tests for the MemoryServiceClient class in memory.py."""

import json
import os
Expand Down Expand Up @@ -185,6 +183,9 @@ def test_list_documents_with_defaults(mock_session):
# Mock get_data_source_id
agent_client.list_data_sources.return_value = {"dataSourceSummaries": [{"dataSourceId": "ds123"}]}

# Mock get_data_source for _detect_data_source_type method
agent_client.get_data_source.return_value = {"dataSource": {"dataSourceConfiguration": {"type": "CUSTOM"}}}

# Initialize client
client = MemoryServiceClient()

Expand Down Expand Up @@ -228,6 +229,9 @@ def test_get_document(mock_session):
# Mock get_data_source_id
agent_client.list_data_sources.return_value = {"dataSourceSummaries": [{"dataSourceId": "ds123"}]}

# Mock get_data_source for _detect_data_source_type method
agent_client.get_data_source.return_value = {"dataSource": {"dataSourceConfiguration": {"type": "CUSTOM"}}}

# Initialize client
client = MemoryServiceClient()

Expand All @@ -254,6 +258,12 @@ def test_store_document(mock_session):
# Mock get_data_source_id
agent_client.list_data_sources.return_value = {"dataSourceSummaries": [{"dataSourceId": "ds123"}]}

# Mock get_data_source for _detect_data_source_type method
agent_client.get_data_source.return_value = {"dataSource": {"dataSourceConfiguration": {"type": "CUSTOM"}}}

# Mock get_data_source for _detect_data_source_type method
agent_client.get_data_source.return_value = {"dataSource": {"dataSourceConfiguration": {"type": "CUSTOM"}}}

# Mock ingest response
agent_client.ingest_knowledge_base_documents.return_value = {"status": "success"}

Expand Down Expand Up @@ -299,6 +309,12 @@ def test_store_document_no_title(mock_session):
# Mock get_data_source_id
agent_client.list_data_sources.return_value = {"dataSourceSummaries": [{"dataSourceId": "ds123"}]}

# Mock get_data_source for _detect_data_source_type method
agent_client.get_data_source.return_value = {"dataSource": {"dataSourceConfiguration": {"type": "CUSTOM"}}}

# Mock get_data_source for _detect_data_source_type method
agent_client.get_data_source.return_value = {"dataSource": {"dataSourceConfiguration": {"type": "CUSTOM"}}}

# Mock ingest response
agent_client.ingest_knowledge_base_documents.return_value = {"status": "success"}

Expand Down Expand Up @@ -335,6 +351,12 @@ def test_delete_document(mock_session):
# Mock get_data_source_id
agent_client.list_data_sources.return_value = {"dataSourceSummaries": [{"dataSourceId": "ds123"}]}

# Mock get_data_source for _detect_data_source_type method
agent_client.get_data_source.return_value = {"dataSource": {"dataSourceConfiguration": {"type": "CUSTOM"}}}

# Mock get_data_source for _detect_data_source_type method
agent_client.get_data_source.return_value = {"dataSource": {"dataSourceConfiguration": {"type": "CUSTOM"}}}

# Mock delete response
agent_client.delete_knowledge_base_documents.return_value = {"status": "success"}

Expand Down
Loading