diff --git a/samples/AspNetCoreSseServer/Attributes/LimitCalls.cs b/samples/AspNetCoreSseServer/Attributes/LimitCalls.cs new file mode 100644 index 000000000..8ca7ecace --- /dev/null +++ b/samples/AspNetCoreSseServer/Attributes/LimitCalls.cs @@ -0,0 +1,40 @@ +using ModelContextProtocol.Core; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; + +namespace AspNetCoreSseServer.Attributes; + +public class LimitCallsAttribute(int maxCalls) : ToolFilterAttribute +{ + private int _callCount; + + public override ValueTask? OnToolCalling(Tool tool, RequestContext context) + { + //Thread-safe increment + var currentCount = Interlocked.Add(ref _callCount, 1); + + //Log count + Console.Out.WriteLine($"Tool: {tool.Name} called {currentCount} time(s)"); + + //If under threshold, do nothing + if (currentCount <= maxCalls) + return null; //do nothing + + //If above threshold, return error message + return new ValueTask(new CallToolResult + { + Content = [new TextContentBlock { Text = $"This tool can only be called {maxCalls} time(s)" }] + }); + } + + public override bool OnToolListed(Tool tool, RequestContext context) + { + //With the provided request context, you can access the dependency injection + var configuration = context.Services?.GetService(); + var hide = configuration?["hide-tools-above-limit"] == "True"; + + //Prevent the tool being listed (return false) + //if the hide flag is true and the call count is above the threshold + return _callCount <= maxCalls || !hide; + } +} diff --git a/samples/AspNetCoreSseServer/Tools/EchoTool.cs b/samples/AspNetCoreSseServer/Tools/EchoTool.cs index 7913b73e4..d4fb23139 100644 --- a/samples/AspNetCoreSseServer/Tools/EchoTool.cs +++ b/samples/AspNetCoreSseServer/Tools/EchoTool.cs @@ -1,5 +1,6 @@ using ModelContextProtocol.Server; using System.ComponentModel; +using AspNetCoreSseServer.Attributes; namespace TestServerWithHosting.Tools; @@ -7,6 +8,7 @@ namespace TestServerWithHosting.Tools; public sealed class EchoTool { [McpServerTool, Description("Echoes the input back to the client.")] + [LimitCalls(maxCalls: 10)] public static string Echo(string message) { return "hello " + message; diff --git a/samples/AspNetCoreSseServer/appsettings.json b/samples/AspNetCoreSseServer/appsettings.json index 10f68b8c8..8552e10d6 100644 --- a/samples/AspNetCoreSseServer/appsettings.json +++ b/samples/AspNetCoreSseServer/appsettings.json @@ -5,5 +5,6 @@ "Microsoft.AspNetCore": "Warning" } }, - "AllowedHosts": "*" + "AllowedHosts": "*", + "hide-tools-above-limit": true } diff --git a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs index afd3912b6..92ee071fa 100644 --- a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs @@ -9,6 +9,7 @@ using System.Text.Json; using System.Text.Json.Nodes; using System.Text.RegularExpressions; +using ModelContextProtocol.Core; namespace ModelContextProtocol.Server; @@ -146,7 +147,7 @@ options.OpenWorld is not null || } } - return new AIFunctionMcpServerTool(function, tool, options?.Services, structuredOutputRequiresWrapping); + return new AIFunctionMcpServerTool(function, tool, options?.Services, structuredOutputRequiresWrapping, options?.Filters ?? []); } private static McpServerToolCreateOptions DeriveOptions(MethodInfo method, McpServerToolCreateOptions? options) @@ -185,6 +186,9 @@ private static McpServerToolCreateOptions DeriveOptions(MethodInfo method, McpSe { newOptions.Description ??= descAttr.Description; } + + var filters = method.GetCustomAttributes().OrderBy(f => f.Order).ToArray(); + newOptions.Filters = filters; return newOptions; } @@ -193,10 +197,11 @@ private static McpServerToolCreateOptions DeriveOptions(MethodInfo method, McpSe internal AIFunction AIFunction { get; } /// Initializes a new instance of the class. - private AIFunctionMcpServerTool(AIFunction function, Tool tool, IServiceProvider? serviceProvider, bool structuredOutputRequiresWrapping) + private AIFunctionMcpServerTool(AIFunction function, Tool tool, IServiceProvider? serviceProvider, bool structuredOutputRequiresWrapping, IToolFilter[] filters) { AIFunction = function; ProtocolTool = tool; + Filters = filters; _logger = serviceProvider?.GetService()?.CreateLogger() ?? (ILogger)NullLogger.Instance; _structuredOutputRequiresWrapping = structuredOutputRequiresWrapping; } @@ -204,6 +209,9 @@ private AIFunctionMcpServerTool(AIFunction function, Tool tool, IServiceProvider /// public override Tool ProtocolTool { get; } + /// + public override IToolFilter[] Filters { get; } + /// public override async ValueTask InvokeAsync( RequestContext request, CancellationToken cancellationToken = default) diff --git a/src/ModelContextProtocol.Core/Server/McpServer.cs b/src/ModelContextProtocol.Core/Server/McpServer.cs index 6c5858f91..4484a77db 100644 --- a/src/ModelContextProtocol.Core/Server/McpServer.cs +++ b/src/ModelContextProtocol.Core/Server/McpServer.cs @@ -448,6 +448,11 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) { foreach (var t in tools) { + if (t.Filters.Any(f => !f.OnToolListed(t.ProtocolTool,request))) + { + continue; + } + result.Tools.Add(t.ProtocolTool); } } @@ -461,7 +466,22 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) if (request.Params is not null && tools.TryGetPrimitive(request.Params.Name, out var tool)) { - return tool.InvokeAsync(request, cancellationToken); + foreach (var filter in tool.Filters) + { + var filterResult = filter.OnToolCalling(tool.ProtocolTool, request); + if(filterResult != null) + return filterResult.Value; + } + + var result = tool.InvokeAsync(request, cancellationToken); + + foreach (var filter in tool.Filters) + { + var filterResult = filter.OnToolCalled(tool.ProtocolTool, request, result); + if(filterResult != null) + return filterResult.Value; + } + return result; } return originalCallToolHandler(request, cancellationToken); diff --git a/src/ModelContextProtocol.Core/Server/McpServerTool.cs b/src/ModelContextProtocol.Core/Server/McpServerTool.cs index e3958271b..0b6d0e10c 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerTool.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerTool.cs @@ -4,6 +4,7 @@ using ModelContextProtocol.Protocol; using System.Reflection; using System.Text.Json; +using ModelContextProtocol.Core; namespace ModelContextProtocol.Server; @@ -140,6 +141,9 @@ protected McpServerTool() /// Gets the protocol type for this instance. public abstract Tool ProtocolTool { get; } + + /// Gets the filters () associated to this tool. + public abstract IToolFilter[] Filters { get; } /// Invokes the . /// The request information resulting in the invocation of this tool. diff --git a/src/ModelContextProtocol.Core/Server/McpServerToolCreateOptions.cs b/src/ModelContextProtocol.Core/Server/McpServerToolCreateOptions.cs index bdb4ecb8d..d747edc35 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerToolCreateOptions.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerToolCreateOptions.cs @@ -2,6 +2,7 @@ using ModelContextProtocol.Protocol; using System.ComponentModel; using System.Text.Json; +using ModelContextProtocol.Core; namespace ModelContextProtocol.Server; @@ -155,6 +156,9 @@ public sealed class McpServerToolCreateOptions /// public AIJsonSchemaCreateOptions? SchemaCreateOptions { get; set; } + /// TODO + public IToolFilter[] Filters { get; set; } = []; + /// /// Creates a shallow clone of the current instance. /// diff --git a/src/ModelContextProtocol.Core/ToolFilter.cs b/src/ModelContextProtocol.Core/ToolFilter.cs new file mode 100644 index 000000000..579055b44 --- /dev/null +++ b/src/ModelContextProtocol.Core/ToolFilter.cs @@ -0,0 +1,39 @@ +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; + +namespace ModelContextProtocol.Core; + +/// TODO: +public interface IToolFilter +{ + /// TODO: + bool OnToolListed(Tool tool, RequestContext context); + + /// TODO: + ValueTask? OnToolCalling(Tool tool, RequestContext context); + + /// TODO: + ValueTask? OnToolCalled(Tool tool, RequestContext context, ValueTask callResult); +} + +/// TODO: +[AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] +public abstract class ToolFilterAttribute(int order = 0) : Attribute, IToolFilter +{ + /// + /// Gets the order value for determining the order of execution of filters. Filters execute in + /// ascending numeric value of the property. + /// + public int Order { get; } = order; + + /// + public virtual bool OnToolListed(Tool tool, RequestContext context) => true; + + /// + public virtual ValueTask? OnToolCalling(Tool tool, RequestContext context) => + null; + + /// + public virtual ValueTask? OnToolCalled(Tool tool, RequestContext context, + ValueTask callResult) => null; +} \ No newline at end of file