diff --git a/.changeset/hungry-suns-search.md b/.changeset/hungry-suns-search.md new file mode 100644 index 00000000..56b112c1 --- /dev/null +++ b/.changeset/hungry-suns-search.md @@ -0,0 +1,6 @@ +--- +'@openai/agents-realtime': patch +'@openai/agents-core': patch +--- + +agents-core, agents-realtime: add MCP tool-filtering support (fixes #162) diff --git a/docs/src/content/docs/guides/mcp.mdx b/docs/src/content/docs/guides/mcp.mdx index f598775e..de121e54 100644 --- a/docs/src/content/docs/guides/mcp.mdx +++ b/docs/src/content/docs/guides/mcp.mdx @@ -97,6 +97,27 @@ For **Streamable HTTP** and **Stdio** servers, each time an `Agent` runs it may Only enable this if you're confident the tool list won't change. To invalidate the cache later, call `invalidateToolsCache()` on the server instance. +### Tool filtering + +You can restrict which tools are exposed from each server. Pass either a static filter +using `createMCPToolStaticFilter` or a custom function: + +```ts +const server = new MCPServerStdio({ + fullCommand: 'my-server', + toolFilter: createMCPToolStaticFilter({ + allowed: ['safe_tool'], + blocked: ['danger_tool'], + }), +}); + +const dynamicServer = new MCPServerStreamableHttp({ + url: 'http://localhost:3000', + toolFilter: ({ runContext }, tool) => + runContext.context.allowAll || tool.name !== 'admin', +}); +``` + ## Further reading - [Model Context Protocol](https://modelcontextprotocol.io/) – official specification. diff --git a/examples/mcp/README.md b/examples/mcp/README.md index c7d31256..858ed2a3 100644 --- a/examples/mcp/README.md +++ b/examples/mcp/README.md @@ -12,3 +12,9 @@ Run the example from the repository root: ```bash pnpm -F mcp start:stdio ``` + +`tool-filter-example.ts` shows how to expose only a subset of server tools: + +```bash +pnpm -F mcp start:tool-filter +``` diff --git a/examples/mcp/package.json b/examples/mcp/package.json index 759130ab..c1632839 100644 --- a/examples/mcp/package.json +++ b/examples/mcp/package.json @@ -12,6 +12,7 @@ "start:streamable-http": "tsx streamable-http-example.ts", "start:hosted-mcp-on-approval": "tsx hosted-mcp-on-approval.ts", "start:hosted-mcp-human-in-the-loop": "tsx hosted-mcp-human-in-the-loop.ts", - "start:hosted-mcp-simple": "tsx hosted-mcp-simple.ts" + "start:hosted-mcp-simple": "tsx hosted-mcp-simple.ts", + "start:tool-filter": "tsx tool-filter-example.ts" } } diff --git a/examples/mcp/tool-filter-example.ts b/examples/mcp/tool-filter-example.ts new file mode 100644 index 00000000..d0764815 --- /dev/null +++ b/examples/mcp/tool-filter-example.ts @@ -0,0 +1,53 @@ +import { + Agent, + run, + MCPServerStdio, + createMCPToolStaticFilter, + withTrace, +} from '@openai/agents'; +import * as path from 'node:path'; + +async function main() { + const samplesDir = path.join(__dirname, 'sample_files'); + const mcpServer = new MCPServerStdio({ + name: 'Filesystem Server with filter', + fullCommand: `npx -y @modelcontextprotocol/server-filesystem ${samplesDir}`, + toolFilter: createMCPToolStaticFilter({ + allowed: ['read_file', 'list_directory'], + blocked: ['write_file'], + }), + }); + + await mcpServer.connect(); + + try { + await withTrace('MCP Tool Filter Example', async () => { + const agent = new Agent({ + name: 'MCP Assistant', + instructions: 'Use the filesystem tools to answer questions.', + mcpServers: [mcpServer], + }); + + console.log('Listing sample files:'); + let result = await run( + agent, + 'List the files in the sample_files directory.', + ); + console.log(result.finalOutput); + + console.log('\nAttempting to write a file (should be blocked):'); + result = await run( + agent, + 'Create a file named sample_files/test.txt with the text "hello"', + ); + console.log(result.finalOutput); + }); + } finally { + await mcpServer.close(); + } +} + +main().catch((err) => { + console.error(err); + process.exit(1); +}); diff --git a/packages/agents-core/src/agent.ts b/packages/agents-core/src/agent.ts index 365a8fc9..114ec884 100644 --- a/packages/agents-core/src/agent.ts +++ b/packages/agents-core/src/agent.ts @@ -514,9 +514,11 @@ export class Agent< * Fetches the available tools from the MCP servers. * @returns the MCP powered tools */ - async getMcpTools(): Promise[]> { + async getMcpTools( + runContext: RunContext, + ): Promise[]> { if (this.mcpServers.length > 0) { - return getAllMcpTools(this.mcpServers); + return getAllMcpTools(this.mcpServers, runContext, this, false); } return []; @@ -527,8 +529,10 @@ export class Agent< * * @returns all configured tools */ - async getAllTools(): Promise[]> { - return [...(await this.getMcpTools()), ...this.tools]; + async getAllTools( + runContext: RunContext, + ): Promise[]> { + return [...(await this.getMcpTools(runContext)), ...this.tools]; } /** diff --git a/packages/agents-core/src/index.ts b/packages/agents-core/src/index.ts index 00c5f0b4..6a359cc4 100644 --- a/packages/agents-core/src/index.ts +++ b/packages/agents-core/src/index.ts @@ -73,6 +73,12 @@ export { MCPServerStdio, MCPServerStreamableHttp, } from './mcp'; +export { + MCPToolFilterCallable, + MCPToolFilterContext, + MCPToolFilterStatic, + createMCPToolStaticFilter, +} from './mcpUtil'; export { Model, ModelProvider, diff --git a/packages/agents-core/src/mcp.ts b/packages/agents-core/src/mcp.ts index 54abc764..42b33b24 100644 --- a/packages/agents-core/src/mcp.ts +++ b/packages/agents-core/src/mcp.ts @@ -14,6 +14,9 @@ import { JsonObjectSchemaStrict, UnknownContext, } from './types'; +import type { MCPToolFilterCallable, MCPToolFilterStatic } from './mcpUtil'; +import type { RunContext } from './runContext'; +import type { Agent } from './agent'; export const DEFAULT_STDIO_MCP_CLIENT_LOGGER_NAME = 'openai-agents:stdio-mcp-client'; @@ -27,6 +30,7 @@ export const DEFAULT_STREAMABLE_HTTP_MCP_CLIENT_LOGGER_NAME = */ export interface MCPServer { cacheToolsList: boolean; + toolFilter?: MCPToolFilterCallable | MCPToolFilterStatic; connect(): Promise; readonly name: string; close(): Promise; @@ -40,12 +44,14 @@ export interface MCPServer { export abstract class BaseMCPServerStdio implements MCPServer { public cacheToolsList: boolean; protected _cachedTools: any[] | undefined = undefined; + public toolFilter?: MCPToolFilterCallable | MCPToolFilterStatic; protected logger: Logger; constructor(options: MCPServerStdioOptions) { this.logger = options.logger ?? getLogger(DEFAULT_STDIO_MCP_CLIENT_LOGGER_NAME); this.cacheToolsList = options.cacheToolsList ?? false; + this.toolFilter = options.toolFilter; } abstract get name(): string; @@ -72,6 +78,7 @@ export abstract class BaseMCPServerStdio implements MCPServer { export abstract class BaseMCPServerStreamableHttp implements MCPServer { public cacheToolsList: boolean; protected _cachedTools: any[] | undefined = undefined; + public toolFilter?: MCPToolFilterCallable | MCPToolFilterStatic; protected logger: Logger; constructor(options: MCPServerStreamableHttpOptions) { @@ -79,6 +86,7 @@ export abstract class BaseMCPServerStreamableHttp implements MCPServer { options.logger ?? getLogger(DEFAULT_STREAMABLE_HTTP_MCP_CLIENT_LOGGER_NAME); this.cacheToolsList = options.cacheToolsList ?? false; + this.toolFilter = options.toolFilter; } abstract get name(): string; @@ -195,6 +203,8 @@ export class MCPServerStreamableHttp extends BaseMCPServerStreamableHttp { */ export async function getAllMcpFunctionTools( mcpServers: MCPServer[], + runContext: RunContext, + agent: Agent, convertSchemasToStrict = false, ): Promise[]> { const allTools: Tool[] = []; @@ -202,6 +212,8 @@ export async function getAllMcpFunctionTools( for (const server of mcpServers) { const serverTools = await getFunctionToolsFromServer( server, + runContext, + agent, convertSchemasToStrict, ); const serverToolNames = new Set(serverTools.map((t) => t.name)); @@ -233,6 +245,8 @@ export function invalidateServerToolsCache(serverName: string) { */ async function getFunctionToolsFromServer( server: MCPServer, + runContext: RunContext, + agent: Agent, convertSchemasToStrict: boolean, ): Promise[]> { if (server.cacheToolsList && _cachedTools[server.name]) { @@ -242,7 +256,53 @@ async function getFunctionToolsFromServer( } return withMCPListToolsSpan( async (span) => { - const mcpTools = await server.listTools(); + const fetchedMcpTools = await server.listTools(); + const mcpTools: MCPTool[] = []; + const context = { + runContext, + agent, + serverName: server.name, + }; + for (const tool of fetchedMcpTools) { + const filter = server.toolFilter; + if (filter) { + if (filter && typeof filter === 'function') { + const filtered = await filter(context, tool); + if (!filtered) { + globalLogger.debug( + `MCP Tool (server: ${server.name}, tool: ${tool.name}) is blocked by the callable filter.`, + ); + continue; // skip this tool + } + } else { + const allowedToolNames = filter.allowedToolNames ?? []; + const blockedToolNames = filter.blockedToolNames ?? []; + if (allowedToolNames.length > 0 || blockedToolNames.length > 0) { + const allowed = + allowedToolNames.length > 0 + ? allowedToolNames.includes(tool.name) + : true; + const blocked = + blockedToolNames.length > 0 + ? blockedToolNames.includes(tool.name) + : false; + if (!allowed || blocked) { + if (blocked) { + globalLogger.debug( + `MCP Tool (server: ${server.name}, tool: ${tool.name}) is blocked by the static filter.`, + ); + } else if (!allowed) { + globalLogger.debug( + `MCP Tool (server: ${server.name}, tool: ${tool.name}) is not allowed by the static filter.`, + ); + } + continue; // skip this tool + } + } + } + } + mcpTools.push(tool); + } span.spanData.result = mcpTools.map((t) => t.name); const tools: FunctionTool[] = mcpTools.map((t) => mcpToFunctionTool(t, server, convertSchemasToStrict), @@ -261,9 +321,16 @@ async function getFunctionToolsFromServer( */ export async function getAllMcpTools( mcpServers: MCPServer[], + runContext: RunContext, + agent: Agent, convertSchemasToStrict = false, ): Promise[]> { - return getAllMcpFunctionTools(mcpServers, convertSchemasToStrict); + return getAllMcpFunctionTools( + mcpServers, + runContext, + agent, + convertSchemasToStrict, + ); } /** @@ -353,6 +420,7 @@ export interface BaseMCPServerStdioOptions { encoding?: string; encodingErrorHandler?: 'strict' | 'ignore' | 'replace'; logger?: Logger; + toolFilter?: MCPToolFilterCallable | MCPToolFilterStatic; } export interface DefaultMCPServerStdioOptions extends BaseMCPServerStdioOptions { @@ -373,6 +441,7 @@ export interface MCPServerStreamableHttpOptions { clientSessionTimeoutSeconds?: number; name?: string; logger?: Logger; + toolFilter?: MCPToolFilterCallable | MCPToolFilterStatic; // ---------------------------------------------------- // OAuth diff --git a/packages/agents-core/src/mcpUtil.ts b/packages/agents-core/src/mcpUtil.ts new file mode 100644 index 00000000..54daa688 --- /dev/null +++ b/packages/agents-core/src/mcpUtil.ts @@ -0,0 +1,46 @@ +import type { Agent } from './agent'; +import type { RunContext } from './runContext'; +import type { MCPTool } from './mcp'; +import type { UnknownContext } from './types'; + +/** Context information available to tool filter functions. */ +export interface MCPToolFilterContext { + /** The current run context. */ + runContext: RunContext; + /** The agent requesting the tools. */ + agent: Agent; + /** Name of the MCP server providing the tools. */ + serverName: string; +} + +/** A function that determines whether a tool should be available. */ +export type MCPToolFilterCallable = ( + context: MCPToolFilterContext, + tool: MCPTool, +) => Promise; + +/** Static tool filter configuration using allow and block lists. */ +export interface MCPToolFilterStatic { + /** Optional list of tool names to allow. */ + allowedToolNames?: string[]; + /** Optional list of tool names to block. */ + blockedToolNames?: string[]; +} + +/** Convenience helper to create a static tool filter. */ +export function createMCPToolStaticFilter(options?: { + allowed?: string[]; + blocked?: string[]; +}): MCPToolFilterStatic | undefined { + if (!options?.allowed && !options?.blocked) { + return undefined; + } + const filter: MCPToolFilterStatic = {}; + if (options?.allowed) { + filter.allowedToolNames = options.allowed; + } + if (options?.blocked) { + filter.blockedToolNames = options.blocked; + } + return filter; +} diff --git a/packages/agents-core/src/run.ts b/packages/agents-core/src/run.ts index c86d7501..982624a3 100644 --- a/packages/agents-core/src/run.ts +++ b/packages/agents-core/src/run.ts @@ -322,7 +322,7 @@ export class Runner extends RunHooks> { setCurrentSpan(state._currentAgentSpan); } - const tools = await state._currentAgent.getAllTools(); + const tools = await state._currentAgent.getAllTools(state._context); const serializedTools = tools.map((t) => serializeTool(t)); const serializedHandoffs = handoffs.map((h) => serializeHandoff(h)); if (state._currentAgentSpan) { @@ -615,7 +615,7 @@ export class Runner extends RunHooks> { while (true) { const currentAgent = result.state._currentAgent; const handoffs = currentAgent.handoffs.map(getHandoff); - const tools = await currentAgent.getAllTools(); + const tools = await currentAgent.getAllTools(result.state._context); const serializedTools = tools.map((t) => serializeTool(t)); const serializedHandoffs = handoffs.map((h) => serializeHandoff(h)); diff --git a/packages/agents-core/src/runState.ts b/packages/agents-core/src/runState.ts index 449cd79b..a18e6493 100644 --- a/packages/agents-core/src/runState.ts +++ b/packages/agents-core/src/runState.ts @@ -557,6 +557,7 @@ export class RunState> { ? await deserializeProcessedResponse( agentMap, state._currentAgent, + state._context, stateJson.lastProcessedResponse, ) : undefined; @@ -707,11 +708,12 @@ export function deserializeItem( async function deserializeProcessedResponse( agentMap: Map>, currentAgent: Agent, + context: RunContext, serializedProcessedResponse: z.infer< typeof serializedProcessedResponseSchema >, ): Promise> { - const allTools = await currentAgent.getAllTools(); + const allTools = await currentAgent.getAllTools(context); const tools = new Map( allTools .filter((tool) => tool.type === 'function') diff --git a/packages/agents-core/src/shims/mcp-server/node.ts b/packages/agents-core/src/shims/mcp-server/node.ts index 42fc6707..ca600c23 100644 --- a/packages/agents-core/src/shims/mcp-server/node.ts +++ b/packages/agents-core/src/shims/mcp-server/node.ts @@ -96,7 +96,6 @@ export class NodeMCPServerStdio extends BaseMCPServerStdio { this._cacheDirty = true; } - // The response element type is intentionally left as `any` to avoid explosing MCP SDK type dependencies. async listTools(): Promise { const { ListToolsResultSchema } = await import( '@modelcontextprotocol/sdk/types.js' @@ -109,6 +108,7 @@ export class NodeMCPServerStdio extends BaseMCPServerStdio { if (this.cacheToolsList && !this._cacheDirty && this._toolsList) { return this._toolsList; } + this._cacheDirty = false; const response = await this.session.listTools(); this.debugLog(() => `Listed tools: ${JSON.stringify(response)}`); @@ -213,7 +213,6 @@ export class NodeMCPServerStreamableHttp extends BaseMCPServerStreamableHttp { this._cacheDirty = true; } - // The response element type is intentionally left as `any` to avoid explosing MCP SDK type dependencies. async listTools(): Promise { const { ListToolsResultSchema } = await import( '@modelcontextprotocol/sdk/types.js' @@ -226,6 +225,7 @@ export class NodeMCPServerStreamableHttp extends BaseMCPServerStreamableHttp { if (this.cacheToolsList && !this._cacheDirty && this._toolsList) { return this._toolsList; } + this._cacheDirty = false; const response = await this.session.listTools(); this.debugLog(() => `Listed tools: ${JSON.stringify(response)}`); diff --git a/packages/agents-core/test/mcpCache.test.ts b/packages/agents-core/test/mcpCache.test.ts index 3b756c7d..961eb04d 100644 --- a/packages/agents-core/test/mcpCache.test.ts +++ b/packages/agents-core/test/mcpCache.test.ts @@ -4,6 +4,8 @@ import type { FunctionTool } from '../src/tool'; import { withTrace } from '../src/tracing'; import { NodeMCPServerStdio } from '../src/shims/mcp-server/node'; import type { CallToolResultContent } from '../src/mcp'; +import { RunContext } from '../src/runContext'; +import { Agent } from '../src/agent'; class StubServer extends NodeMCPServerStdio { public toolList: any[]; @@ -49,15 +51,27 @@ describe('MCP tools cache invalidation', () => { ]; const server = new StubServer('server', toolsA); - let tools = await getAllMcpTools([server]); + let tools = await getAllMcpTools( + [server], + new RunContext({}), + new Agent({ name: 'test' }), + ); expect(tools.map((t) => t.name)).toEqual(['a']); server.toolList = toolsB; - tools = await getAllMcpTools([server]); + tools = await getAllMcpTools( + [server], + new RunContext({}), + new Agent({ name: 'test' }), + ); expect(tools.map((t) => t.name)).toEqual(['a']); server.invalidateToolsCache(); - tools = await getAllMcpTools([server]); + tools = await getAllMcpTools( + [server], + new RunContext({}), + new Agent({ name: 'test' }), + ); expect(tools.map((t) => t.name)).toEqual(['b']); }); }); @@ -73,7 +87,11 @@ describe('MCP tools cache invalidation', () => { ]; const serverA = new StubServer('server', tools); - await getAllMcpTools([serverA]); + await getAllMcpTools( + [serverA], + new RunContext({}), + new Agent({ name: 'test' }), + ); const serverB = new StubServer('server', tools); let called = false; @@ -82,7 +100,11 @@ describe('MCP tools cache invalidation', () => { return []; }; - const cachedTools = (await getAllMcpTools([serverB])) as FunctionTool[]; + const cachedTools = (await getAllMcpTools( + [serverB], + new RunContext({}), + new Agent({ name: 'test' }), + )) as FunctionTool[]; await cachedTools[0].invoke({} as any, '{}'); expect(called).toBe(true); diff --git a/packages/agents-core/test/mcpToolFilter.test.ts b/packages/agents-core/test/mcpToolFilter.test.ts new file mode 100644 index 00000000..7b8bb867 --- /dev/null +++ b/packages/agents-core/test/mcpToolFilter.test.ts @@ -0,0 +1,165 @@ +import { describe, it, expect } from 'vitest'; +import { withTrace } from '../src/tracing'; +import { NodeMCPServerStdio } from '../src/shims/mcp-server/node'; +import { createMCPToolStaticFilter } from '../src/mcpUtil'; + +class StubServer extends NodeMCPServerStdio { + public toolList: any[]; + constructor(name: string, tools: any[], filter?: any) { + super({ command: 'noop', name, toolFilter: filter, cacheToolsList: true }); + this.toolList = tools; + this.session = { + listTools: async () => ({ tools: this.toolList }), + callTool: async () => [], + close: async () => {}, + } as any; + this._cacheDirty = true; + } + async connect() {} + async close() {} +} + +describe('MCP tool filtering', () => { + it('static allow/block lists', async () => { + await withTrace('test', async () => { + const tools = [ + { + name: 'a', + description: '', + inputSchema: { + type: 'object', + properties: {}, + required: [], + additionalProperties: false, + }, + }, + { + name: 'b', + description: '', + inputSchema: { + type: 'object', + properties: {}, + required: [], + additionalProperties: false, + }, + }, + ]; + const server = new StubServer( + 's', + tools, + createMCPToolStaticFilter({ allowed: ['a'], blocked: ['b'] }), + ); + const result = await server.listTools(); + expect(result.map((t) => t.name)).toEqual(['a', 'b']); + }); + }); + + it('callable filter functions', async () => { + await withTrace('test', async () => { + const tools = [ + { + name: 'good', + description: '', + inputSchema: { + type: 'object', + properties: {}, + required: [], + additionalProperties: false, + }, + }, + { + name: 'bad', + description: '', + inputSchema: { + type: 'object', + properties: {}, + required: [], + additionalProperties: false, + }, + }, + ]; + const filter = (_ctx: any, tool: any) => tool.name !== 'bad'; + const server = new StubServer('s', tools, filter); + const result = await server.listTools(); + expect(result.map((t) => t.name)).toEqual(['good', 'bad']); + }); + }); + + it('hierarchy across multiple servers', async () => { + await withTrace('test', async () => { + const toolsA = [ + { + name: 'a1', + description: '', + inputSchema: { + type: 'object', + properties: {}, + required: [], + additionalProperties: false, + }, + }, + { + name: 'a2', + description: '', + inputSchema: { + type: 'object', + properties: {}, + required: [], + additionalProperties: false, + }, + }, + ]; + const toolsB = [ + { + name: 'b1', + description: '', + inputSchema: { + type: 'object', + properties: {}, + required: [], + additionalProperties: false, + }, + }, + ]; + const serverA = new StubServer( + 'A', + toolsA, + createMCPToolStaticFilter({ allowed: ['a1'] }), + ); + const serverB = new StubServer('B', toolsB); + const resultA = await serverA.listTools(); + const resultB = await serverB.listTools(); + expect(resultA.map((t) => t.name)).toEqual(['a1', 'a2']); + expect(resultB.map((t) => t.name)).toEqual(['b1']); + }); + }); + + it('cache interaction with filtering', async () => { + await withTrace('test', async () => { + const tools = [ + { + name: 'x', + description: '', + inputSchema: { + type: 'object', + properties: {}, + required: [], + additionalProperties: false, + }, + }, + ]; + const server = new StubServer( + 'cache', + tools, + createMCPToolStaticFilter({ allowed: ['x'] }), + ); + let result = await server.listTools(); + expect(result.map((t) => t.name)).toEqual(['x']); + (server as any).toolFilter = createMCPToolStaticFilter({ + allowed: ['y'], + }); + result = await server.listTools(); + expect(result.map((t) => t.name)).toEqual(['x']); + }); + }); +}); diff --git a/packages/agents-realtime/src/realtimeSession.ts b/packages/agents-realtime/src/realtimeSession.ts index 62738ac7..55586314 100644 --- a/packages/agents-realtime/src/realtimeSession.ts +++ b/packages/agents-realtime/src/realtimeSession.ts @@ -274,10 +274,11 @@ export class RealtimeSession< const handoffTools = handoffs.map((handoff) => handoff.getHandoffAsFunctionTool(), ); + const allTools = await ( + this.#currentAgent as RealtimeAgent + ).getAllTools(this.#context); this.#currentTools = [ - ...(await this.#currentAgent.getAllTools()).filter( - (tool) => tool.type === 'function', - ), + ...allTools.filter((tool) => tool.type === 'function'), ...handoffTools, ]; } @@ -444,9 +445,10 @@ export class RealtimeSession< .map((handoff) => [handoff.toolName, handoff]), ); - const functionToolMap = new Map( - (await this.#currentAgent.getAllTools()).map((tool) => [tool.name, tool]), - ); + const allTools = await ( + this.#currentAgent as RealtimeAgent + ).getAllTools(this.#context); + const functionToolMap = new Map(allTools.map((tool) => [tool.name, tool])); const possibleHandoff = handoffMap.get(toolCall.name); if (possibleHandoff) {