From f1ec2a0ff0cde1982db180d42f09965d823e6050 Mon Sep 17 00:00:00 2001 From: Ryadh DAHIMENE Date: Mon, 8 Sep 2025 17:48:23 +0200 Subject: [PATCH] Add a memory mode backed by ClickHouse --- README.md | 36 +++++ mcp_clickhouse/mcp_env.py | 8 + mcp_clickhouse/mcp_server.py | 289 +++++++++++++++++++++++++++++++++++ pyproject.toml | 5 + tests/test_tool.py | 172 +++++++++++++++++++++ uv.lock | 10 ++ 6 files changed, 520 insertions(+) diff --git a/README.md b/README.md index d83294f..97c63aa 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,39 @@ An MCP server for ClickHouse. * Input: `sql` (string): The SQL query to execute. * Query data directly from various sources (files, URLs, databases) without ETL processes. +### Memory Tools (Experimental) + +> [!WARNING] +> Memory tools are an experimental feature and may change in future versions. + +When enabled via the `CLICKHOUSE_MEMORY=true` environment variable, the following memory management tools become available: + +* `save_memory` + * Store user-provided information as key-value pairs for later retrieval and reference. + * Input: `key` (string): A concise, descriptive key that summarizes the content. + * Input: `value` (string): The information to store. + +* `get_memories_titles` + * Retrieve all memory keys/titles to see what information has been stored. + * Returns a list of all stored memory keys with timestamps. + +* `get_memory` + * Retrieve all memory entries matching a specific key. + * Input: `key` (string): The key to search for. + * Returns all memories associated with that key, ordered by most recent first. + +* `get_all_memories` + * Retrieve all saved memories from the memory table. + * Input: None + * **Warning**: Should only be used when explicitly requested, as it may return large amounts of data. + +* `delete_memory` + * Delete all memory entries matching a specific key. + * Input: `key` (string): The key of memories to delete. + * **Warning**: Should only be used when explicitly requested by the user. + +These tools use ClickHouse to store memories in a `user_memory` table, allowing information to persist across sessions. + ### Health Check Endpoint When running with HTTP or SSE transport, a health check endpoint is available at `/health`. This endpoint: @@ -317,6 +350,9 @@ The following environment variables are used to configure the ClickHouse and chD * `CLICKHOUSE_ENABLED`: Enable/disable ClickHouse functionality * Default: `"true"` * Set to `"false"` to disable ClickHouse tools when using chDB only +* `CLICKHOUSE_MEMORY`: Enable/disable memory tools (experimental) + * Default: `"false"` + * Set to `"true"` to enable memory management tools for storing key-value data #### chDB Variables diff --git a/mcp_clickhouse/mcp_env.py b/mcp_clickhouse/mcp_env.py index 40c0424..3009b93 100644 --- a/mcp_clickhouse/mcp_env.py +++ b/mcp_clickhouse/mcp_env.py @@ -162,6 +162,14 @@ def mcp_bind_port(self) -> int: """ return int(os.getenv("CLICKHOUSE_MCP_BIND_PORT", "8000")) + @property + def memory_enabled(self) -> bool: + """Get whether memory tools are enabled. + + Default: False + """ + return os.getenv("CLICKHOUSE_MEMORY", "false").lower() == "true" + def get_client_config(self) -> dict: """Get the configuration dictionary for clickhouse_connect client. 
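For reviewers, a minimal usage sketch (not part of the patch) of the memory helpers added in mcp_server.py below. It assumes a reachable ClickHouse instance configured through the usual CLICKHOUSE_* connection variables; note that the CLICKHOUSE_MEMORY flag only gates MCP tool registration, not direct calls. The import path matches the one used by the new tests:

```python
# Hedged example: exercising the experimental memory helpers directly.
# Requires CLICKHOUSE_HOST/CLICKHOUSE_USER/CLICKHOUSE_PASSWORD (etc.) to point
# at a running server; save_memory creates the `user_memory` table on first use.
from mcp_clickhouse.mcp_server import (
    save_memory,
    get_memories_titles,
    get_memory,
    delete_memory,
)

save_memory("preferred_timezone", "The user reports in Europe/Paris")
print(get_memories_titles()["titles"])                            # e.g. ['preferred_timezone']
print(get_memory("preferred_timezone")["memories"][0]["value"])   # stored value
print(delete_memory("preferred_timezone")["deleted_count"])       # e.g. 1
```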
diff --git a/mcp_clickhouse/mcp_server.py b/mcp_clickhouse/mcp_server.py index 589ff2e..493e593 100644 --- a/mcp_clickhouse/mcp_server.py +++ b/mcp_clickhouse/mcp_server.py @@ -267,6 +267,283 @@ def get_readonly_setting(client) -> str: return "1" # Default to basic read-only mode if setting isn't present +def execute_write_query(query: str): + """This function bypasses the read-only mode and allows write queries to be executed. + + TODO: Find a sustainable way to execute write queries. + + Args: + query: The write query to execute + + Returns: + The result of the write query + """ + client = create_clickhouse_client() + try: + res = client.command(query) + logger.info("Write query executed successfully") + return res + except Exception as err: + logger.error(f"Error executing write query: {err}") + return {"error": str(err)} + + +def save_memory(key: str, value: str): + """Store user-provided information as key-value pairs for later retrieval and reference. Generate a concise, descriptive key that summarizes the value content provided.""" + logger.info(f"Saving memory with key: {key}") + + try: + # Create table if it doesn't exist + logger.info("Ensuring user_memory table exists") + create_table_query = """ + CREATE TABLE IF NOT EXISTS user_memory ( + key String, + value String, + created_at DateTime DEFAULT now(), + updated_at DateTime DEFAULT now() + ) ENGINE = MergeTree() + ORDER BY key + """ + + create_result = execute_write_query(create_table_query) + if isinstance(create_result, dict) and "error" in create_result: + return { + "status": "error", + "message": f"Failed to create user_memory table: {create_result['error']}" + } + + # Insert the memory entry; saving the same key again adds a new row rather than overwriting it + insert_query = f""" + INSERT INTO user_memory (key, value, updated_at) + VALUES ({format_query_value(key)}, {format_query_value(value)}, now()) + """ + + insert_result = execute_write_query(insert_query) + if isinstance(insert_result, dict) and "error" in insert_result: + return { + "status": "error", + "message": f"Failed to save memory: {insert_result['error']}" + } + + logger.info(f"Successfully saved memory with key: {key}") + return { + "status": "success", + "message": f"Memory '{key}' saved successfully", + "key": key + } + + except Exception as e: + logger.error(f"Unexpected error in save_memory: {str(e)}") + return {"status": "error", "message": f"Unexpected error: {str(e)}"} + + +def get_memories_titles(): + """Retrieve all memory keys/titles from the user memory table to see what information has been stored.""" + logger.info("Retrieving all memory titles") + + try: + # Query to get all keys with their timestamps + query = """ + SELECT key, created_at, updated_at + FROM user_memory + ORDER BY updated_at DESC + """ + + result = execute_query(query) + + # Check if we received an error structure from execute_query + if isinstance(result, dict) and "error" in result: + return { + "status": "error", + "message": f"Failed to retrieve memory titles: {result['error']}" + } + + # Extract just the keys for the response from the new result format + rows = result.get("rows", []) + titles = [row[0] for row in rows] if rows else [] + + # Convert rows to dict format for details + columns = result.get("columns", []) + details = [] + for row in rows: + row_dict = {} + for i, col_name in enumerate(columns): + row_dict[col_name] = row[i] + details.append(row_dict) + + logger.info(f"Retrieved {len(titles)} memory titles") + return { + "status": "success", + "titles": titles, + "count": 
len(titles), + "details": details # Include full details with timestamps + } + + except Exception as e: + logger.error(f"Unexpected error in get_memories_titles: {str(e)}") + return {"status": "error", "message": f"Unexpected error: {str(e)}"} + + +def get_memory(key: str): + """Retrieve all memory entries matching the specified key from the user memory table.""" + logger.info(f"Retrieving memory for key: {key}") + + try: + # Query to get all memories matching the key, ordered by most recent first + query = f""" + SELECT key, value, created_at, updated_at + FROM user_memory + WHERE key = {format_query_value(key)} + ORDER BY updated_at DESC + """ + + result = execute_query(query) + + # Check if we received an error structure from execute_query + if isinstance(result, dict) and "error" in result: + return { + "status": "error", + "message": f"Failed to retrieve memory: {result['error']}" + } + + # Convert to dict format + columns = result.get("columns", []) + rows = result.get("rows", []) + memories = [] + for row in rows: + row_dict = {} + for i, col_name in enumerate(columns): + row_dict[col_name] = row[i] + memories.append(row_dict) + + # Check if memory exists + if not memories: + logger.info(f"No memory found for key: {key}") + return { + "status": "not_found", + "message": f"No memory found with key '{key}'", + "key": key + } + + # Return all matching memories + logger.info(f"Successfully retrieved {len(memories)} memories for key: {key}") + + return { + "status": "success", + "key": key, + "count": len(memories), + "memories": memories + } + + except Exception as e: + logger.error(f"Unexpected error in get_memory: {str(e)}") + return {"status": "error", "message": f"Unexpected error: {str(e)}"} + + +def get_all_memories(): + """Retrieve all saved memories from the user memory table; do not list them back, just report the number of memories retrieved. WARNING: This tool should only be used when explicitly requested by the user, as it may return large amounts of data.""" + logger.info("Retrieving all memories") + + try: + # Query to get all memories ordered by most recent first + query = """ + SELECT key, value, created_at, updated_at + FROM user_memory + ORDER BY updated_at DESC + """ + + result = execute_query(query) + + # Check if we received an error structure from execute_query + if isinstance(result, dict) and "error" in result: + return { + "status": "error", + "message": f"Failed to retrieve all memories: {result['error']}" + } + + # Convert to dict format + columns = result.get("columns", []) + rows = result.get("rows", []) + memories = [] + for row in rows: + row_dict = {} + for i, col_name in enumerate(columns): + row_dict[col_name] = row[i] + memories.append(row_dict) + + # Return all memories + logger.info(f"Successfully retrieved {len(memories)} total memories") + + return { + "status": "success", + "count": len(memories), + "memories": memories + } + + except Exception as e: + logger.error(f"Unexpected error in get_all_memories: {str(e)}") + return {"status": "error", "message": f"Unexpected error: {str(e)}"} + + +def delete_memory(key: str): + """Delete all memory entries matching the specified key from the user memory table. 
WARNING: This tool should only be used when explicitly requested by the user.""" + logger.info(f"Deleting memory for key: {key}") + + try: + # First check if memories exist for this key + check_query = f""" + SELECT count() + FROM user_memory + WHERE key = {format_query_value(key)} + """ + + check_result = execute_query(check_query) + + # Check if we received an error structure from execute_query + if isinstance(check_result, dict) and "error" in check_result: + return { + "status": "error", + "message": f"Failed to check memory existence: {check_result['error']}" + } + + # Check if any memories exist - handle new result format + rows = check_result.get("rows", []) + if not rows or len(rows) == 0 or rows[0][0] == 0: + logger.info(f"No memories found for key: {key}") + return { + "status": "not_found", + "message": f"No memories found with key '{key}'", + "key": key + } + + memories_count = rows[0][0] + + # Delete the memories + delete_query = f""" + ALTER TABLE user_memory + DELETE WHERE key = {format_query_value(key)} + """ + + delete_result = execute_write_query(delete_query) + if isinstance(delete_result, dict) and "error" in delete_result: + return { + "status": "error", + "message": f"Failed to delete memories: {delete_result['error']}" + } + + logger.info(f"Successfully deleted {memories_count} memories for key: {key}") + return { + "status": "success", + "message": f"Deleted {memories_count} memory entries with key '{key}'", + "key": key, + "deleted_count": memories_count + } + + except Exception as e: + logger.error(f"Unexpected error in delete_memory: {str(e)}") + return {"status": "error", "message": f"Unexpected error: {str(e)}"} + + def create_chdb_client(): """Create a chDB client connection.""" if not get_chdb_config().enabled: @@ -370,3 +647,15 @@ def _init_chdb_client(): ) mcp.add_prompt(chdb_prompt) logger.info("chDB tools and prompts registered") + +# Conditionally register memory tools based on CLICKHOUSE_MEMORY flag +config = get_config() +if config.memory_enabled: + logger.info("Memory tools enabled - registering memory management tools") + mcp.add_tool(Tool.from_function(save_memory)) + mcp.add_tool(Tool.from_function(get_memories_titles)) + mcp.add_tool(Tool.from_function(get_memory)) + mcp.add_tool(Tool.from_function(get_all_memories)) + mcp.add_tool(Tool.from_function(delete_memory)) +else: + logger.info("Memory tools disabled - set CLICKHOUSE_MEMORY=true to enable") diff --git a/pyproject.toml b/pyproject.toml index 5c81e63..ec1c0b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,3 +36,8 @@ line-length = 100 [build-system] requires = ["hatchling"] build-backend = "hatchling.build" + +[dependency-groups] +dev = [ + "pytest>=8.4.1", +] diff --git a/tests/test_tool.py b/tests/test_tool.py index 50878c4..f0e8bb0 100644 --- a/tests/test_tool.py +++ b/tests/test_tool.py @@ -5,6 +5,7 @@ from fastmcp.exceptions import ToolError from mcp_clickhouse import create_clickhouse_client, list_databases, list_tables, run_select_query +from mcp_clickhouse.mcp_server import save_memory, get_memories_titles, get_memory, get_all_memories, delete_memory load_dotenv() @@ -36,11 +37,35 @@ def setUpClass(cls): INSERT INTO {cls.test_db}.{cls.test_table} (id, name) VALUES (1, 'Alice'), (2, 'Bob') """) + # Set up memory table for memory tests + cls.client.command(f"DROP TABLE IF EXISTS {cls.test_db}.user_memory") + cls.client.command(f""" + CREATE TABLE {cls.test_db}.user_memory ( + key String, + value String, + created_at DateTime DEFAULT now(), + updated_at DateTime DEFAULT now() + 
) ENGINE = MergeTree() + ORDER BY key + """) + @classmethod def tearDownClass(cls): """Clean up the environment after tests.""" cls.client.command(f"DROP DATABASE IF EXISTS {cls.test_db}") + def setUp(self): + """Set up clean state before each memory test.""" + # Clear memory table before each test - need to delete from default database too + try: + self.client.command("TRUNCATE TABLE user_memory") + except Exception: + pass # Table may not exist in default database + try: + self.client.command(f"TRUNCATE TABLE {self.test_db}.user_memory") + except Exception: + pass # Table may not exist in test database + def test_list_databases(self): """Test listing databases.""" result = list_databases() @@ -98,6 +123,153 @@ def test_table_and_column_comments(self): self.assertEqual(columns["id"]["comment"], "Primary identifier") self.assertEqual(columns["name"]["comment"], "User name field") + def test_save_memory_success(self): + """Test saving memory successfully.""" + result = save_memory("Test Key", "Test Value") + self.assertEqual(result["status"], "success") + self.assertEqual(result["key"], "Test Key") + self.assertIn("saved successfully", result["message"]) + + def test_save_memory_multiple_same_key(self): + """Test saving multiple memories with the same key.""" + save_memory("Duplicate Key", "First Value") + result = save_memory("Duplicate Key", "Second Value") + self.assertEqual(result["status"], "success") + self.assertEqual(result["key"], "Duplicate Key") + + def test_get_memories_titles_empty(self): + """Test getting titles from empty memory table.""" + result = get_memories_titles() + self.assertEqual(result["status"], "success") + self.assertEqual(result["count"], 0) + self.assertEqual(result["titles"], []) + + def test_get_memories_titles_with_data(self): + """Test getting titles with data in memory table.""" + save_memory("Key 1", "Value 1") + save_memory("Key 2", "Value 2") + + result = get_memories_titles() + self.assertEqual(result["status"], "success") + self.assertEqual(result["count"], 2) + self.assertIn("Key 1", result["titles"]) + self.assertIn("Key 2", result["titles"]) + + def test_get_memory_success(self): + """Test retrieving an existing memory.""" + save_memory("Retrieve Key", "Retrieve Value") + + result = get_memory("Retrieve Key") + self.assertEqual(result["status"], "success") + self.assertEqual(result["key"], "Retrieve Key") + self.assertEqual(result["count"], 1) + self.assertEqual(result["memories"][0]["value"], "Retrieve Value") + + def test_get_memory_not_found(self): + """Test retrieving a non-existent memory.""" + result = get_memory("Non-existent Key") + self.assertEqual(result["status"], "not_found") + self.assertEqual(result["key"], "Non-existent Key") + + def test_get_memory_multiple_entries(self): + """Test retrieving multiple memories with the same key.""" + save_memory("Multi Key", "First Value") + save_memory("Multi Key", "Second Value") + + result = get_memory("Multi Key") + self.assertEqual(result["status"], "success") + self.assertEqual(result["count"], 2) + self.assertEqual(len(result["memories"]), 2) + + def test_get_all_memories_empty(self): + """Test getting all memories from empty table.""" + result = get_all_memories() + self.assertEqual(result["status"], "success") + self.assertEqual(result["count"], 0) + + def test_get_all_memories_with_data(self): + """Test getting all memories with data.""" + save_memory("All Key 1", "All Value 1") + save_memory("All Key 2", "All Value 2") + + result = get_all_memories() + self.assertEqual(result["status"], 
"success") + self.assertEqual(result["count"], 2) + self.assertEqual(len(result["memories"]), 2) + + def test_delete_memory_success(self): + """Test deleting an existing memory.""" + save_memory("Delete Key", "Delete Value") + + result = delete_memory("Delete Key") + self.assertEqual(result["status"], "success") + self.assertEqual(result["key"], "Delete Key") + self.assertEqual(result["deleted_count"], 1) + + # Verify it's actually deleted + get_result = get_memory("Delete Key") + self.assertEqual(get_result["status"], "not_found") + + def test_delete_memory_not_found(self): + """Test deleting a non-existent memory.""" + result = delete_memory("Non-existent Delete Key") + self.assertEqual(result["status"], "not_found") + self.assertEqual(result["key"], "Non-existent Delete Key") + + def test_memory_workflow_integration(self): + """Test complete memory workflow integration.""" + # Save memory + save_result = save_memory("Workflow Key", "Workflow Value") + self.assertEqual(save_result["status"], "success") + + # Check it appears in titles + titles_result = get_memories_titles() + self.assertIn("Workflow Key", titles_result["titles"]) + + # Retrieve it + get_result = get_memory("Workflow Key") + self.assertEqual(get_result["status"], "success") + self.assertEqual(get_result["memories"][0]["value"], "Workflow Value") + + # Update it (save again) + update_result = save_memory("Workflow Key", "Updated Value") + self.assertEqual(update_result["status"], "success") + + # Verify update + get_updated = get_memory("Workflow Key") + self.assertEqual(get_updated["count"], 2) # Now has 2 entries + + # Delete it + delete_result = delete_memory("Workflow Key") + self.assertEqual(delete_result["status"], "success") + self.assertEqual(delete_result["deleted_count"], 2) + + # Verify deletion + final_get = get_memory("Workflow Key") + self.assertEqual(final_get["status"], "not_found") + + +class TestMemoryFlag(unittest.TestCase): + """Test CLICKHOUSE_MEMORY flag functionality.""" + + def test_memory_flag_controls_tool_availability(self): + """Test that CLICKHOUSE_MEMORY flag controls whether memory tools are available.""" + # This test documents the expected behavior but cannot directly test + # the conditional registration since it happens at module import time + # In practice, this would be tested by running the server with different + # CLICKHOUSE_MEMORY flag values and checking available tools + + # For now, we just verify the memory functions exist and work when flag is enabled + # (which it must be for this test to run) + from mcp_clickhouse.mcp_server import save_memory, get_memories_titles, get_memory, get_all_memories, delete_memory + + # These should be callable functions when CLICKHOUSE_MEMORY=true + self.assertTrue(callable(save_memory)) + self.assertTrue(callable(get_memories_titles)) + self.assertTrue(callable(get_memory)) + self.assertTrue(callable(get_all_memories)) + self.assertTrue(callable(delete_memory)) + if __name__ == "__main__": unittest.main() diff --git a/uv.lock b/uv.lock index 91eb32a..eb50fb9 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.12'", @@ -450,6 +451,11 @@ dev = [ { name = "ruff" }, ] +[package.dev-dependencies] +dev = [ + { name = "pytest" }, +] + [package.metadata] requires-dist = [ { name = "chdb", specifier = ">=3.3.0" }, @@ -461,6 +467,10 @@ requires-dist = [ { name = "python-dotenv", specifier = ">=1.0.1" }, { name = "ruff", marker = "extra == 'dev'" }, ] 
+provides-extras = ["dev"] + +[package.metadata.requires-dev] +dev = [{ name = "pytest", specifier = ">=8.4.1" }] [[package]] name = "mdurl"
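Following up on the TestMemoryFlag comment above, one possible way to exercise the import-time registration under both flag values is to reload the server module with a patched environment. This is a hedged sketch, not part of the patch: the module names match the code above, but whether the flag takes effect after a reload depends on how get_config() caches its result, and enumerating the registered tools depends on the installed fastmcp version, so that step is left as a comment.

```python
# Hedged sketch: re-run the import-time registration block under a chosen
# CLICKHOUSE_MEMORY value by reloading the modules. Assumes the usual
# CLICKHOUSE_* connection variables are set; if get_config() caches its
# result, reloading mcp_env first lets the new flag value be picked up.
import importlib
from unittest import mock

import mcp_clickhouse.mcp_env as mcp_env
import mcp_clickhouse.mcp_server as mcp_server


def reload_server_with_memory_flag(enabled: bool):
    value = "true" if enabled else "false"
    with mock.patch.dict("os.environ", {"CLICKHOUSE_MEMORY": value}):
        importlib.reload(mcp_env)
        return importlib.reload(mcp_server)


# The returned module's `mcp` object could then be inspected with whatever
# tool-listing API the installed fastmcp version provides.
```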