From c61b93e499563a976163be0718cf7323fba26e1c Mon Sep 17 00:00:00 2001 From: Pavindu Lakshan Date: Sun, 11 May 2025 18:59:48 +0530 Subject: [PATCH] feat: introduce secureTool implementation --- examples/mcp-server/src/index.ts | 23 ++- packages/mcp-express/package.json | 4 +- .../src/middlewares/protected-route.ts | 10 +- packages/mcp-express/src/public-api.ts | 1 + .../src/utils/create-secure-tool.ts | 165 +++++++++++++----- 5 files changed, 154 insertions(+), 49 deletions(-) diff --git a/examples/mcp-server/src/index.ts b/examples/mcp-server/src/index.ts index 1513de6..ba91692 100644 --- a/examples/mcp-server/src/index.ts +++ b/examples/mcp-server/src/index.ts @@ -4,7 +4,7 @@ import {McpServer} from '@modelcontextprotocol/sdk/server/mcp.js'; import {StreamableHTTPServerTransport} from '@modelcontextprotocol/sdk/server/streamableHttp.js'; import {isInitializeRequest} from '@modelcontextprotocol/sdk/types.js'; import {z} from 'zod'; -import {McpAuthServer, protectedRoute} from '@brionmario-experimental/mcp-express'; +import {McpAuthServer, protectedRoute, secureTool} from '@brionmario-experimental/mcp-express'; import {config} from 'dotenv'; config(); @@ -162,6 +162,27 @@ app.post( }, ); + // Example usage of secure tool + secureTool( + server, + 'securedVetAppointment', + 'secured tool that enables authenticated users to book appointments', + { + name: z.string(), + age: z.number().optional(), + }, + async ({name, age, authContext}) => { + return { + content: [ + { + type: 'text', + text: `Booked vet appointment for pet ID: ${name} with age ${age} using token ${authContext?.token}`, + }, + ], + }; + }, + ); + try { // Connect to the MCP server await server.connect(transport); diff --git a/packages/mcp-express/package.json b/packages/mcp-express/package.json index 11efd0c..07b909a 100644 --- a/packages/mcp-express/package.json +++ b/packages/mcp-express/package.json @@ -55,7 +55,9 @@ "typescript": "~5.7.2" }, "peerDependencies": { - "express": ">=4" + "express": ">=4", + "@modelcontextprotocol/sdk": ">1", + "zod": ">3" }, "publishConfig": { "access": "public" diff --git a/packages/mcp-express/src/middlewares/protected-route.ts b/packages/mcp-express/src/middlewares/protected-route.ts index 74c8c09..8d59eea 100644 --- a/packages/mcp-express/src/middlewares/protected-route.ts +++ b/packages/mcp-express/src/middlewares/protected-route.ts @@ -25,7 +25,7 @@ export default function protectedRoute(provider?: McpAuthProvider) { res: Response, next: NextFunction, ): Promise> | undefined> { - const authHeader: string | undefined = req.headers.authorization; + const authHeader: string | undefined = req.headers['authorization'] as string; if (!authHeader) { res.setHeader( @@ -68,6 +68,14 @@ export default function protectedRoute(provider?: McpAuthProvider) { try { await validateAccessToken(token, TOKEN_VALIDATION_CONFIG.jwksUri, TOKEN_VALIDATION_CONFIG.options); + + // Insert authContext into request params when the request is a tool call. + if (req?.body?.method === 'tools/call' && req.body.params.arguments) { + req.body.params.arguments['authContext'] = { + token, + }; + } + next(); return undefined; } catch (error: any) { diff --git a/packages/mcp-express/src/public-api.ts b/packages/mcp-express/src/public-api.ts index ee9a1d1..e10cb71 100644 --- a/packages/mcp-express/src/public-api.ts +++ b/packages/mcp-express/src/public-api.ts @@ -18,3 +18,4 @@ export {default as McpAuthServer} from './routes/auth'; export {default as protectedRoute} from './middlewares/protected-route'; +export {default as secureTool} from './utils/create-secure-tool'; diff --git a/packages/mcp-express/src/utils/create-secure-tool.ts b/packages/mcp-express/src/utils/create-secure-tool.ts index 8974d94..6abd2b9 100644 --- a/packages/mcp-express/src/utils/create-secure-tool.ts +++ b/packages/mcp-express/src/utils/create-secure-tool.ts @@ -1,51 +1,124 @@ -import { z, ZodRawShape } from 'zod'; -import { decodeAccessToken } from '@brionmario-experimental/mcp-node'; -import { StrictDecodedIDTokenPayload } from './types'; - -export async function createSecureTool( - mcpServer: McpServer, - toolName: string, - toolDescription: string, - paramsSchema: Args, - // @ts-ignore - secureCallback: ( - args: z.infer>, - context: StrictDecodedIDTokenPayload - ) => Promise -) { - // biome-ignore lint/suspicious/noExplicitAny: tool interface requirement - const callback = async (args: any, extra: any): Promise => { - try { - const authHeader = extra?.headers?.authorization || ''; - const token = authHeader.replace(/^Bearer\s+/i, ''); - - if (!token) { - throw new Error('Missing Authorization token.'); - } - - const context = decodeAccessToken(token) as StrictDecodedIDTokenPayload; - - return await secureCallback(args, context); - } catch (error) { - console.error('Secure tool authorization error:', error); - return { - content: [ - { - type: 'text', - text: 'Unauthorized: Invalid or missing access token.', - }, - ], - isError: true, - }; - } +/** + * Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import {ServerRequest, ServerNotification, CallToolResult, ToolAnnotations} from '@modelcontextprotocol/sdk//types'; +import {McpServer, ToolCallback} from '@modelcontextprotocol/sdk/server/mcp'; // Adjust import path as needed +import {RequestHandlerExtra} from '@modelcontextprotocol/sdk/shared/protocol'; +import {z, ZodRawShape, ZodString, ZodObject, ZodTypeAny} from 'zod'; + +// The type of authContextSchema +type AuthContextSchemaType = { + authContext: ZodObject<{ + token: ZodString; + }>; +}; + +/** + * Auth context shape that will be added to all secured tools + */ +const authContextSchema: AuthContextSchemaType = { + authContext: z.object({ + token: z.string(), + }), +}; + +/** + * Implementation for a tool callback function that processes arguments based on Zod schema + * @param schema Optional Zod schema for validating arguments + * @param handler Function to handle the validated arguments and extra context + * @returns A function that satisfies the ToolCallback type + */ +function createToolCallback( + schema: Args, + handler: Args extends ZodRawShape + ? ( + args: z.objectOutputType, + extra: RequestHandlerExtra, + ) => CallToolResult | Promise + : (extra: RequestHandlerExtra) => CallToolResult | Promise, +): ToolCallback { + if (schema) { + // Case when Args extends ZodRawShape + return (( + args: z.objectOutputType, + extra: RequestHandlerExtra, + ) => (handler as any)(args, extra)) as ToolCallback; + } + // Case when Args is undefined + return ((extra: RequestHandlerExtra) => + (handler as any)(extra)) as ToolCallback; +} + +/** + * Secures a tool with specified input schema and handler that expects named parameters + * @param server The server instance + * @param name Tool name + * @param description Tool description + * @param annotations Tool annotations + * @param inputSchema Zod schema for input validation + * @param handler The callback that handles the tool's execution with named parameters + */ +export default function secureTool( + server: McpServer, + name: string, + description: string, + inputSchema: S, + handler: ToolCallback, + annotations?: ToolAnnotations, +): void { + // Enhance the schema with authContext + const enhancedSchema: S & AuthContextSchemaType = { + ...inputSchema, + ...authContextSchema, }; - mcpServer.tool( - toolName, - toolDescription, - paramsSchema, - callback as ToolCallback + const toolImpl: ToolCallback = createToolCallback( + enhancedSchema, + // Use the correct type for the args parameter based on the inputSchema + (( + args: z.objectOutputType, + // eslint-disable-next-line @typescript-eslint/no-unused-vars + _extra: RequestHandlerExtra, + ) => { + // Extract values from args in the order of inputSchema keys + // eslint-disable-next-line @typescript-eslint/typedef + const paramValues = {} as Record; + + Object.keys(enhancedSchema).forEach((key: string) => { + const typedKey: keyof S | 'authContext' = key as keyof typeof enhancedSchema; + paramValues[typedKey] = args[typedKey]; + }); + + const toolArgs: Record[] = [paramValues]; + + // Call the handler with all parameters + // eslint-disable-next-line @typescript-eslint/typedef, prefer-spread + const result = (handler as Function).apply(null, toolArgs); + + // Make sure we return a value + return result || {data: args, success: true}; + }) as any, ); - await Promise.resolve(); + // Use the original secureTool with our wrapper handler + if (annotations) { + server.tool(name, description, enhancedSchema, annotations, toolImpl); + } else { + server.tool(name, description, enhancedSchema, toolImpl); + } }