Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 122 additions & 4 deletions src/postgres_mcp/server.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
# ruff: noqa: B008
import argparse
import asyncio
import csv
import io
import json
import logging
import os
import signal
import sys
from datetime import date, datetime, timedelta
from decimal import Decimal
from enum import Enum
from typing import Any
from typing import List
Expand Down Expand Up @@ -72,7 +77,108 @@ async def get_sql_driver() -> Union[SqlDriver, SafeSqlDriver]:

def format_text_response(text: Any) -> ResponseType:
"""Format a text response."""
return [types.TextContent(type="text", text=str(text))]
def json_serializer(obj):
"""Custom JSON serializer for PostgreSQL/Python types."""
# Handle datetime types
if isinstance(obj, (datetime, date)):
return obj.isoformat()
# Handle time (without date)
elif hasattr(obj, 'isoformat') and callable(obj.isoformat):
return obj.isoformat()
# Handle Decimal
elif isinstance(obj, Decimal):
return float(obj)
# Handle timedelta (PostgreSQL INTERVAL type)
elif isinstance(obj, timedelta):
return obj.total_seconds() # Return as seconds (number)
# Handle bytes (PostgreSQL BYTEA type)
elif isinstance(obj, (bytes, bytearray)):
import base64
return base64.b64encode(obj).decode('ascii') # Base64 encode binary data
# Handle memoryview
elif isinstance(obj, memoryview):
import base64
return base64.b64encode(obj.tobytes()).decode('ascii')
# Handle UUID
elif hasattr(obj, 'hex'): # UUID objects have a hex attribute
return str(obj)
# Default: convert to string
return str(obj)

# Convert lists and dicts to JSON, everything else to string
if isinstance(text, (list, dict)):
text = json.dumps(text, default=json_serializer, ensure_ascii=False)
else:
text = str(text)

return [types.TextContent(type="text", text=text)]


def format_csv_response(data: Any) -> ResponseType:
"""Format a response as CSV."""
def csv_value_converter(obj):
"""Convert PostgreSQL/Python types to CSV-friendly strings."""
# Handle datetime types
if isinstance(obj, (datetime, date)):
return obj.isoformat()
# Handle time (without date)
elif hasattr(obj, 'isoformat') and callable(obj.isoformat):
return obj.isoformat()
# Handle Decimal
elif isinstance(obj, Decimal):
return str(obj) # Keep full precision for CSV
# Handle timedelta (PostgreSQL INTERVAL type)
elif isinstance(obj, timedelta):
return str(obj.total_seconds()) # Return as seconds string
# Handle bytes (PostgreSQL BYTEA type)
elif isinstance(obj, (bytes, bytearray)):
import base64
return base64.b64encode(obj).decode('ascii')
# Handle memoryview
elif isinstance(obj, memoryview):
import base64
return base64.b64encode(obj.tobytes()).decode('ascii')
# Handle UUID
elif hasattr(obj, 'hex'): # UUID objects have a hex attribute
return str(obj)
# Handle None
elif obj is None:
return ""
# Handle lists/arrays - convert to pipe-separated string
elif isinstance(obj, list):
return "|".join(str(csv_value_converter(item)) for item in obj)
# Handle dicts - convert to JSON string
elif isinstance(obj, dict):
return json.dumps(obj)
# Default: convert to string
return str(obj)

if not isinstance(data, list) or not data:
return [types.TextContent(type="text", text="")]

# Create CSV output
output = io.StringIO()
writer = csv.writer(output)

# Write header row (column names)
if isinstance(data[0], dict):
headers = list(data[0].keys())
writer.writerow(headers)

# Write data rows
for row in data:
converted_row = [csv_value_converter(row.get(header)) for header in headers]
writer.writerow(converted_row)
else:
# Handle non-dict data
writer.writerow(["value"])
for item in data:
writer.writerow([csv_value_converter(item)])

csv_text = output.getvalue()
output.close()

return [types.TextContent(type="text", text=csv_text)]


def format_error_response(error: str) -> ResponseType:
Expand Down Expand Up @@ -389,14 +495,26 @@ async def explain_query(
# Query function declaration without the decorator - we'll add it dynamically based on access mode
async def execute_sql(
sql: str = Field(description="SQL to run", default="all"),
output_format: Literal["json", "csv"] = Field(description="Output format: 'json' (default) or 'csv'", default="json"),
) -> ResponseType:
"""Executes a SQL query against the database."""
"""Executes a SQL query against the database and returns results in JSON or CSV format."""
try:
sql_driver = await get_sql_driver()
rows = await sql_driver.execute_query(sql) # type: ignore
if rows is None:
return format_text_response("No results")
return format_text_response(list([r.cells for r in rows]))
if output_format == "csv":
return format_csv_response([])
else:
return format_text_response("No results")

# Convert rows to list of dictionaries
result_data = list([r.cells for r in rows])

# Format based on requested output format
if output_format == "csv":
return format_csv_response(result_data)
else:
return format_text_response(result_data)
except Exception as e:
logger.error(f"Error executing query: {e}")
return format_error_response(str(e))
Expand Down
49 changes: 49 additions & 0 deletions tests/unit/test_execute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Tests for execute_sql function with JSON and CSV output formats."""

import json
from datetime import datetime
from decimal import Decimal
from unittest.mock import AsyncMock, patch

import pytest

from postgres_mcp.server import execute_sql


class MockRow:
def __init__(self, cells):
self.cells = cells


@pytest.mark.asyncio
async def test_execute_sql_json_output():
"""Test execute_sql outputs valid JSON (not Python repr format)."""
mock_driver = AsyncMock()
mock_driver.execute_query.return_value = [
MockRow({"id": 1, "salary": Decimal('50000.00'), "created_at": datetime(2023, 1, 1)})
]

with patch('postgres_mcp.server.get_sql_driver', return_value=mock_driver):
result = await execute_sql("SELECT * FROM users")

# Should return valid JSON
parsed = json.loads(result[0].text)
assert parsed[0]["salary"] == 50000.0 # Decimal -> float, not repr
assert parsed[0]["created_at"] == "2023-01-01T00:00:00" # ISO format


@pytest.mark.asyncio
async def test_execute_sql_csv_output():
"""Test execute_sql outputs CSV format."""
mock_driver = AsyncMock()
mock_driver.execute_query.return_value = [
MockRow({"id": 1, "name": "John", "salary": Decimal('50000.00')})
]

with patch('postgres_mcp.server.get_sql_driver', return_value=mock_driver):
result = await execute_sql("SELECT * FROM users", output_format="csv")

lines = result[0].text.strip().split('\n')
assert len(lines) == 2 # Header + data
assert "id" in lines[0]
assert "50000.00" in lines[1] # Decimal precision preserved