Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ dependencies = [
"python-dotenv>=0.23.0",
"httpx>=0.28.1",
"mcp[cli]>=1.9.3",
"requests"
"pydantic>=2",
"requests>=2"
]

[tool.pytest.ini_options]
pythonpath = "src tests"
pythonpath = ["src"]
27 changes: 27 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,33 @@
for f in registration_functions:
f(mcp)

# Temporary shim to normalize tool input schemas for Codex/Anthropic
async def _list_tools_with_shim():
tools = await original_list_tools()
for tool in tools:
schema = tool.inputSchema or {"type": "object", "properties": {}}
if schema.get("properties") and list(schema["properties"].keys()) == ["args"]:
schema = schema["properties"]["args"]
schema.setdefault("type", "object")
schema.setdefault("additionalProperties", False)

def _fix(node):
if isinstance(node, dict):
if node.get("type") == "integer":
node["type"] = "number"
for v in node.values():
_fix(v)
elif isinstance(node, list):
for v in node:
_fix(v)

_fix(schema)
tool.inputSchema = schema
return tools

original_list_tools = mcp.list_tools
mcp.list_tools = _list_tools_with_shim

if __name__ == "__main__":
# Load the organization workspace.
OrganizationWorkspace.load()
Expand Down
118 changes: 118 additions & 0 deletions src/tool_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from __future__ import annotations

from typing import Any, Dict, List, Type, Union, get_args, get_origin

from pydantic import BaseModel, Field, create_model, field_validator


class ArgsBaseModel(BaseModel):
"""Base class for tool argument models with strict object schema."""

model_config = {
"extra": "forbid",
"json_schema_extra": {"additionalProperties": False},
}


_MODEL_CACHE: Dict[type[BaseModel], type[BaseModel]] = {}


def _convert_annotation(
annotation: Any,
field_name: str,
validators: Dict[str, classmethod],
) -> Any:
origin = get_origin(annotation)
if origin is None:
if isinstance(annotation, type) and issubclass(annotation, BaseModel):
return get_args_model(annotation)
if annotation is int:
# Replace ints with floats and create validator to coerce to int
@field_validator(field_name, mode="before")
def _coerce(cls, v):
return None if v is None else int(v)

validators[f"coerce_{field_name}"] = _coerce
return float
return annotation
if origin in (list, List):
inner = _convert_annotation(get_args(annotation)[0], field_name, validators)
return List[inner]
if origin is Union:
converted = [
_convert_annotation(arg, field_name, validators) for arg in get_args(annotation)
]
return Union[tuple(converted)]
return annotation


def get_args_model(model_cls: Type[BaseModel]) -> Type[BaseModel]:
"""Create an Args model for the given request model."""
if model_cls in _MODEL_CACHE:
return _MODEL_CACHE[model_cls]

fields: Dict[str, tuple[Any, Any]] = {}
validators: Dict[str, classmethod] = {}

for name, field in model_cls.model_fields.items():
ann = _convert_annotation(field.annotation, name, validators)
default = field.default if not field.is_required() else ...
ge = le = None
for meta in getattr(field, "metadata", []):
if hasattr(meta, "ge"):
ge = meta.ge
if hasattr(meta, "le"):
le = meta.le
fields[name] = (
ann,
Field(
default,
description=getattr(field, "description", None),
ge=ge,
le=le,
),
)

ArgsModel = create_model(
f"{model_cls.__name__}Args",
__base__=ArgsBaseModel,
__validators__=validators,
**fields,
)

_MODEL_CACHE[model_cls] = ArgsModel
return ArgsModel


from typing import Callable


def tool_with_args(
mcp,
request_model: Type[BaseModel] | None = None,
**decorator_kwargs: Any,
) -> Callable:
"""Decorator to wrap MCP tools with generated Args models."""

def decorator(func: Callable) -> Callable:
if request_model is None:
ArgsModel = ArgsBaseModel

async def inner(args: ArgsModel):
return await func()
else:
ArgsModel = get_args_model(request_model)

async def inner(args: ArgsModel):
model = request_model(**args.model_dump())
return await func(model)

inner.__name__ = func.__name__
inner.__doc__ = func.__doc__
inner.__annotations__ = {
"args": ArgsModel,
"return": func.__annotations__.get("return", Any),
}
return mcp.tool(**decorator_kwargs)(inner)

return decorator
4 changes: 3 additions & 1 deletion src/tools/account.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from api_connection import post
from models import AccountResponse
from tool_args import tool_with_args

def register_account_tools(mcp):
# Read
@mcp.tool(
@tool_with_args(
mcp,
annotations={
'title': 'Read account',
'readOnlyHint': True,
Expand Down
31 changes: 25 additions & 6 deletions src/tools/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@
SyntaxCheckResponse,
SearchResponse
)
from tool_args import tool_with_args

def register_ai_tools(mcp):
# Get backtest initialization errors
@mcp.tool(
@tool_with_args(
mcp,
BasicFilesRequest,
annotations={
'title': 'Check initialization errors', 'readOnlyHint': True
}
Expand All @@ -29,14 +32,20 @@ async def check_initialization_errors(
return await post('/ai/tools/backtest-init', model)

# Complete code
@mcp.tool(annotations={'title': 'Complete code', 'readOnlyHint': True})
@tool_with_args(
mcp,
CodeCompletionRequest,
annotations={'title': 'Complete code', 'readOnlyHint': True}
)
async def complete_code(
model: CodeCompletionRequest) -> CodeCompletionResponse:
"""Show the code completion for a specific text input."""
return await post('/ai/tools/complete', model)

# Enchance error message
@mcp.tool(
@tool_with_args(
mcp,
ErrorEnhanceRequest,
annotations={'title': 'Enhance error message', 'readOnlyHint': True}
)
async def enhance_error_message(
Expand All @@ -45,7 +54,9 @@ async def enhance_error_message(
return await post('/ai/tools/error-enhance', model)

# Update code to PEP8
@mcp.tool(
@tool_with_args(
mcp,
PEP8ConvertRequest,
annotations={'title': 'Update code to PEP8', 'readOnlyHint': True}
)
async def update_code_to_pep8(
Expand All @@ -54,13 +65,21 @@ async def update_code_to_pep8(
return await post('/ai/tools/pep8-convert', model)

# Check syntax
@mcp.tool(annotations={'title': 'Check syntax', 'readOnlyHint': True})
@tool_with_args(
mcp,
BasicFilesRequest,
annotations={'title': 'Check syntax', 'readOnlyHint': True}
)
async def check_syntax(model: BasicFilesRequest) -> SyntaxCheckResponse:
"""Check the syntax of a code."""
return await post('/ai/tools/syntax-check', model)

# Search
@mcp.tool(annotations={'title': 'Search QuantConnect', 'readOnlyHint': True})
@tool_with_args(
mcp,
SearchRequest,
annotations={'title': 'Search QuantConnect', 'readOnlyHint': True}
)
async def search_quantconnect(model: SearchRequest) -> SearchResponse:
"""Search for content in QuantConnect."""
return await post('/ai/tools/search', model)
37 changes: 29 additions & 8 deletions src/tools/backtests.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@
BacktestReportGeneratingResponse,
RestResponse
)
from tool_args import tool_with_args

def register_backtest_tools(mcp):
# Create
@mcp.tool(
@tool_with_args(
mcp,
CreateBacktestRequest,
annotations={
'title': 'Create backtest',
'destructiveHint': False
Expand All @@ -34,20 +37,30 @@ async def create_backtest(
return await post('/backtests/create', model)

# Read statistics for a single backtest.
@mcp.tool(annotations={'title': 'Read backtest', 'readOnlyHint': True})
@tool_with_args(
mcp,
ReadBacktestRequest,
annotations={'title': 'Read backtest', 'readOnlyHint': True}
)
async def read_backtest(model: ReadBacktestRequest) -> BacktestResponse:
"""Read the results of a backtest."""
return await post('/backtests/read', model)

# Read a summary of all the backtests.
@mcp.tool(annotations={'title': 'List backtests', 'readOnlyHint': True})
@tool_with_args(
mcp,
ListBacktestRequest,
annotations={'title': 'List backtests', 'readOnlyHint': True}
)
async def list_backtests(
model: ListBacktestRequest) -> BacktestSummaryResponse:
"""List all the backtests for the project."""
return await post('/backtests/list', model)

# Read the chart of a single backtest.
@mcp.tool(
@tool_with_args(
mcp,
ReadBacktestChartRequest,
annotations={'title': 'Read backtest chart', 'readOnlyHint': True}
)
async def read_backtest_chart(
Expand All @@ -56,7 +69,9 @@ async def read_backtest_chart(
return await post('/backtests/chart/read', model)

# Read the orders of a single backtest.
@mcp.tool(
@tool_with_args(
mcp,
ReadBacktestOrdersRequest,
annotations={'title': 'Read backtest orders', 'readOnlyHint': True}
)
async def read_backtest_orders(
Expand All @@ -65,7 +80,9 @@ async def read_backtest_orders(
return await post('/backtests/orders/read', model)

# Read the insights of a single backtest.
@mcp.tool(
@tool_with_args(
mcp,
ReadBacktestInsightsRequest,
annotations={'title': 'Read backtest insights', 'readOnlyHint': True}
)
async def read_backtest_insights(
Expand All @@ -84,15 +101,19 @@ async def read_backtest_insights(
# return await post('/backtests/read/report', model)

# Update
@mcp.tool(
@tool_with_args(
mcp,
UpdateBacktestRequest,
annotations={'title': 'Update backtest', 'idempotentHint': True}
)
async def update_backtest(model: UpdateBacktestRequest) -> RestResponse:
"""Update the name or note of a backtest."""
return await post('/backtests/update', model)

# Delete
@mcp.tool(
@tool_with_args(
mcp,
DeleteBacktestRequest,
annotations={'title': 'Delete backtest', 'idempotentHint': True}
)
async def delete_backtest(model: DeleteBacktestRequest) -> RestResponse:
Expand Down
11 changes: 9 additions & 2 deletions src/tools/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
CreateCompileResponse,
ReadCompileResponse
)
from tool_args import tool_with_args

def register_compile_tools(mcp):
# Create
@mcp.tool(
@tool_with_args(
mcp,
CreateCompileRequest,
annotations={'title': 'Create compile', 'destructiveHint': False}
)
async def create_compile(
Expand All @@ -17,7 +20,11 @@ async def create_compile(
return await post('/compile/create', model)

# Read
@mcp.tool(annotations={'title': 'Read compile', 'readOnlyHint': True})
@tool_with_args(
mcp,
ReadCompileRequest,
annotations={'title': 'Read compile', 'readOnlyHint': True}
)
async def read_compile(model: ReadCompileRequest) -> ReadCompileResponse:
"""Read a compile packet job result."""
return await post('/compile/read', model)
Loading