diff --git a/README.md b/README.md index 4c87ba9bd..3099dfcd3 100644 --- a/README.md +++ b/README.md @@ -207,7 +207,7 @@ McpServerOptions options = new() { if (request.Params.Arguments?.TryGetValue("message", out var message) is not true) { - throw new McpException("Missing required argument 'message'"); + throw new McpProtocolException("Missing required argument 'message'", McpErrorCode.InvalidParams); } return ValueTask.FromResult(new CallToolResult @@ -216,7 +216,7 @@ McpServerOptions options = new() }); } - throw new McpException($"Unknown tool: '{request.Params?.Name}'"); + throw new McpProtocolException($"Unknown tool: '{request.Params?.Name}'", McpErrorCode.InvalidRequest); } } }; diff --git a/samples/EverythingServer/Program.cs b/samples/EverythingServer/Program.cs index b976bcc0a..a18a29461 100644 --- a/samples/EverythingServer/Program.cs +++ b/samples/EverythingServer/Program.cs @@ -123,7 +123,7 @@ await ctx.Server.SampleAsync([ { if (ctx.Params?.Level is null) { - throw new McpException("Missing required argument 'level'", McpErrorCode.InvalidParams); + throw new McpProtocolException("Missing required argument 'level'", McpErrorCode.InvalidParams); } _minimumLoggingLevel = ctx.Params.Level; diff --git a/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs b/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs index 2cfb74d09..dce594e2e 100644 --- a/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs +++ b/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs @@ -78,7 +78,7 @@ private void ConfigureCallToolFilter(McpServerOptions options) var authResult = await GetAuthorizationResultAsync(context.User, context.MatchedPrimitive, context.Services, context); if (!authResult.Succeeded) { - throw new McpException("Access forbidden: This tool requires authorization.", McpErrorCode.InvalidRequest); + throw new McpProtocolException("Access forbidden: This tool requires authorization.", McpErrorCode.InvalidRequest); } context.Items[AuthorizationFilterInvokedKey] = true; @@ -170,7 +170,7 @@ private void ConfigureReadResourceFilter(McpServerOptions options) var authResult = await GetAuthorizationResultAsync(context.User, context.MatchedPrimitive, context.Services, context); if (!authResult.Succeeded) { - throw new McpException("Access forbidden: This resource requires authorization.", McpErrorCode.InvalidRequest); + throw new McpProtocolException("Access forbidden: This resource requires authorization.", McpErrorCode.InvalidRequest); } return await next(context, cancellationToken); @@ -230,7 +230,7 @@ private void ConfigureGetPromptFilter(McpServerOptions options) var authResult = await GetAuthorizationResultAsync(context.User, context.MatchedPrimitive, context.Services, context); if (!authResult.Succeeded) { - throw new McpException("Access forbidden: This prompt requires authorization.", McpErrorCode.InvalidRequest); + throw new McpProtocolException("Access forbidden: This prompt requires authorization.", McpErrorCode.InvalidRequest); } return await next(context, cancellationToken); diff --git a/src/ModelContextProtocol.Core/McpException.cs b/src/ModelContextProtocol.Core/McpException.cs index 3831dd688..6498c662f 100644 --- a/src/ModelContextProtocol.Core/McpException.cs +++ b/src/ModelContextProtocol.Core/McpException.cs @@ -1,14 +1,21 @@ +using ModelContextProtocol.Protocol; + namespace ModelContextProtocol; /// /// Represents an exception that is thrown when an Model Context Protocol (MCP) error occurs. /// /// -/// This exception is used to represent failures to do with protocol-level concerns, such as invalid JSON-RPC requests, -/// invalid parameters, or internal errors. It is not intended to be used for application-level errors. -/// or from a may be -/// propagated to the remote endpoint; sensitive information should not be included. If sensitive details need -/// to be included, a different exception type should be used. +/// The from a may be propagated to the remote +/// endpoint; sensitive information should not be included. If sensitive details need to be included, +/// a different exception type should be used. +/// +/// This exception type can be thrown by MCP tools or tool call filters to propogate detailed error messages +/// from when a tool execution fails via a . +/// For non-tool calls, this exception controls the message propogated via a . +/// +/// is a derived type that can be used to also specify the +/// that should be used for the resulting . /// public class McpException : Exception { @@ -28,46 +35,13 @@ public McpException(string message) : base(message) } /// - /// Initializes a new instance of the class with a specified error message and a reference to the inner exception that is the cause of this exception. + /// Initializes a new instance of the class with a specified error message and + /// a reference to the inner exception that is the cause of this exception. /// /// The message that describes the error. - /// The exception that is the cause of the current exception, or a null reference if no inner exception is specified. + /// The exception that is the cause of the current exception, or a null + /// reference if no inner exception is specified. public McpException(string message, Exception? innerException) : base(message, innerException) { } - - /// - /// Initializes a new instance of the class with a specified error message and JSON-RPC error code. - /// - /// The message that describes the error. - /// A . - public McpException(string message, McpErrorCode errorCode) : this(message, null, errorCode) - { - } - - /// - /// Initializes a new instance of the class with a specified error message, inner exception, and JSON-RPC error code. - /// - /// The message that describes the error. - /// The exception that is the cause of the current exception, or a null reference if no inner exception is specified. - /// A . - public McpException(string message, Exception? innerException, McpErrorCode errorCode) : base(message, innerException) - { - ErrorCode = errorCode; - } - - /// - /// Gets the error code associated with this exception. - /// - /// - /// This property contains a standard JSON-RPC error code as defined in the MCP specification. Common error codes include: - /// - /// -32700: Parse error - Invalid JSON received - /// -32600: Invalid request - The JSON is not a valid Request object - /// -32601: Method not found - The method does not exist or is not available - /// -32602: Invalid params - Invalid method parameters - /// -32603: Internal error - Internal JSON-RPC error - /// - /// - public McpErrorCode ErrorCode { get; } = McpErrorCode.InternalError; -} \ No newline at end of file +} diff --git a/src/ModelContextProtocol.Core/McpProtocolException.cs b/src/ModelContextProtocol.Core/McpProtocolException.cs new file mode 100644 index 000000000..d6bc7b632 --- /dev/null +++ b/src/ModelContextProtocol.Core/McpProtocolException.cs @@ -0,0 +1,73 @@ +namespace ModelContextProtocol; + +/// +/// Represents an exception that is thrown when an Model Context Protocol (MCP) error occurs. +/// +/// +/// This exception is used to represent failures to do with protocol-level concerns, such as invalid JSON-RPC requests, +/// invalid parameters, or internal errors. It is not intended to be used for application-level errors. +/// or from a may be +/// propagated to the remote endpoint; sensitive information should not be included. If sensitive details need +/// to be included, a different exception type should be used. +/// +public sealed class McpProtocolException : McpException +{ + /// + /// Initializes a new instance of the class. + /// + public McpProtocolException() + { + } + + /// + /// Initializes a new instance of the class with a specified error message. + /// + /// The message that describes the error. + public McpProtocolException(string message) : base(message) + { + } + + /// + /// Initializes a new instance of the class with a specified error message and a reference to the inner exception that is the cause of this exception. + /// + /// The message that describes the error. + /// The exception that is the cause of the current exception, or a null reference if no inner exception is specified. + public McpProtocolException(string message, Exception? innerException) : base(message, innerException) + { + } + + /// + /// Initializes a new instance of the class with a specified error message and JSON-RPC error code. + /// + /// The message that describes the error. + /// A . + public McpProtocolException(string message, McpErrorCode errorCode) : this(message, null, errorCode) + { + } + + /// + /// Initializes a new instance of the class with a specified error message, inner exception, and JSON-RPC error code. + /// + /// The message that describes the error. + /// The exception that is the cause of the current exception, or a null reference if no inner exception is specified. + /// A . + public McpProtocolException(string message, Exception? innerException, McpErrorCode errorCode) : base(message, innerException) + { + ErrorCode = errorCode; + } + + /// + /// Gets the error code associated with this exception. + /// + /// + /// This property contains a standard JSON-RPC error code as defined in the MCP specification. Common error codes include: + /// + /// -32700: Parse error - Invalid JSON received + /// -32600: Invalid request - The JSON is not a valid Request object + /// -32601: Method not found - The method does not exist or is not available + /// -32602: Invalid params - Invalid method parameters + /// -32603: Internal error - Internal JSON-RPC error + /// + /// + public McpErrorCode ErrorCode { get; } = McpErrorCode.InternalError; +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/McpSessionHandler.cs b/src/ModelContextProtocol.Core/McpSessionHandler.cs index 749486e4b..a899f3d8e 100644 --- a/src/ModelContextProtocol.Core/McpSessionHandler.cs +++ b/src/ModelContextProtocol.Core/McpSessionHandler.cs @@ -181,11 +181,17 @@ ex is OperationCanceledException && { LogRequestHandlerException(EndpointName, request.Method, ex); - JsonRpcErrorDetail detail = ex is McpException mcpe ? + JsonRpcErrorDetail detail = ex is McpProtocolException mcpProtocolException ? new() { - Code = (int)mcpe.ErrorCode, - Message = mcpe.Message, + Code = (int)mcpProtocolException.ErrorCode, + Message = mcpProtocolException.Message, + } : ex is McpException mcpException ? + new() + { + + Code = (int)McpErrorCode.InternalError, + Message = mcpException.Message, } : new() { @@ -336,7 +342,7 @@ private void HandleMessageWithId(JsonRpcMessage message, JsonRpcMessageWithId me if (!_requestHandlers.TryGetValue(request.Method, out var handler)) { LogNoHandlerFoundForRequest(EndpointName, request.Method); - throw new McpException($"Method '{request.Method}' is not available.", McpErrorCode.MethodNotFound); + throw new McpProtocolException($"Method '{request.Method}' is not available.", McpErrorCode.MethodNotFound); } LogRequestHandlerCalled(EndpointName, request.Method); @@ -446,7 +452,7 @@ public async Task SendRequestAsync(JsonRpcRequest request, Canc if (response is JsonRpcError error) { LogSendingRequestFailed(EndpointName, request.Method, error.Error.Message, error.Error.Code); - throw new McpException($"Request failed (remote): {error.Error.Message}", (McpErrorCode)error.Error.Code); + throw new McpProtocolException($"Request failed (remote): {error.Error.Message}", (McpErrorCode)error.Error.Code); } if (response is JsonRpcResponse success) @@ -640,7 +646,7 @@ private static void AddExceptionTags(ref TagList tags, Activity? activity, Excep } int? intErrorCode = - (int?)((e as McpException)?.ErrorCode) is int errorCode ? errorCode : + (int?)((e as McpProtocolException)?.ErrorCode) is int errorCode ? errorCode : e is JsonException ? (int)McpErrorCode.ParseError : null; diff --git a/src/ModelContextProtocol.Core/Server/McpServer.Methods.cs b/src/ModelContextProtocol.Core/Server/McpServer.Methods.cs index 0bae663ba..0d255c5b0 100644 --- a/src/ModelContextProtocol.Core/Server/McpServer.Methods.cs +++ b/src/ModelContextProtocol.Core/Server/McpServer.Methods.cs @@ -291,7 +291,7 @@ public async ValueTask> ElicitAsync( /// The type of the schema being built. /// The serializer options to use. /// The built request schema. - /// + /// private static ElicitRequestParams.RequestSchema BuildRequestSchema(Type type, JsonSerializerOptions serializerOptions) { var schema = new ElicitRequestParams.RequestSchema(); @@ -301,7 +301,7 @@ private static ElicitRequestParams.RequestSchema BuildRequestSchema(Type type, J if (typeInfo.Kind != JsonTypeInfoKind.Object) { - throw new McpException($"Type '{type.FullName}' is not supported for elicitation requests."); + throw new McpProtocolException($"Type '{type.FullName}' is not supported for elicitation requests."); } foreach (JsonPropertyInfo pi in typeInfo.Properties) @@ -319,33 +319,33 @@ private static ElicitRequestParams.RequestSchema BuildRequestSchema(Type type, J /// The type to create the schema for. /// The serializer options to use. /// The created primitive schema definition. - /// Thrown when the type is not supported. + /// Thrown when the type is not supported. private static ElicitRequestParams.PrimitiveSchemaDefinition CreatePrimitiveSchema(Type type, JsonSerializerOptions serializerOptions) { if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)) { - throw new McpException($"Type '{type.FullName}' is not a supported property type for elicitation requests. Nullable types are not supported."); + throw new McpProtocolException($"Type '{type.FullName}' is not a supported property type for elicitation requests. Nullable types are not supported."); } var typeInfo = serializerOptions.GetTypeInfo(type); if (typeInfo.Kind != JsonTypeInfoKind.None) { - throw new McpException($"Type '{type.FullName}' is not a supported property type for elicitation requests."); + throw new McpProtocolException($"Type '{type.FullName}' is not a supported property type for elicitation requests."); } var jsonElement = AIJsonUtilities.CreateJsonSchema(type, serializerOptions: serializerOptions); if (!TryValidateElicitationPrimitiveSchema(jsonElement, type, out var error)) { - throw new McpException(error); + throw new McpProtocolException(error); } var primitiveSchemaDefinition = jsonElement.Deserialize(McpJsonUtilities.JsonContext.Default.PrimitiveSchemaDefinition); if (primitiveSchemaDefinition is null) - throw new McpException($"Type '{type.FullName}' is not a supported property type for elicitation requests."); + throw new McpProtocolException($"Type '{type.FullName}' is not a supported property type for elicitation requests."); return primitiveSchemaDefinition; } diff --git a/src/ModelContextProtocol.Core/Server/McpServerImpl.cs b/src/ModelContextProtocol.Core/Server/McpServerImpl.cs index c152d3a0a..8e264741b 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerImpl.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerImpl.cs @@ -288,7 +288,7 @@ subscribeHandler is null && unsubscribeHandler is null && resources is null && listResourcesHandler ??= (static async (_, __) => new ListResourcesResult()); listResourceTemplatesHandler ??= (static async (_, __) => new ListResourceTemplatesResult()); - readResourceHandler ??= (static async (request, _) => throw new McpException($"Unknown resource URI: '{request.Params?.Uri}'", McpErrorCode.InvalidParams)); + readResourceHandler ??= (static async (request, _) => throw new McpProtocolException($"Unknown resource URI: '{request.Params?.Uri}'", McpErrorCode.InvalidParams)); subscribeHandler ??= (static async (_, __) => new EmptyResult()); unsubscribeHandler ??= (static async (_, __) => new EmptyResult()); var listChanged = resourcesCapability?.ListChanged; @@ -452,7 +452,7 @@ private void ConfigurePrompts(McpServerOptions options) ServerCapabilities.Prompts = new(); listPromptsHandler ??= (static async (_, __) => new ListPromptsResult()); - getPromptHandler ??= (static async (request, _) => throw new McpException($"Unknown prompt: '{request.Params?.Name}'", McpErrorCode.InvalidParams)); + getPromptHandler ??= (static async (request, _) => throw new McpProtocolException($"Unknown prompt: '{request.Params?.Name}'", McpErrorCode.InvalidParams)); var listChanged = promptsCapability?.ListChanged; // Handle tools provided via DI by augmenting the handlers to incorporate them. @@ -540,7 +540,7 @@ private void ConfigureTools(McpServerOptions options) ServerCapabilities.Tools = new(); listToolsHandler ??= (static async (_, __) => new ListToolsResult()); - callToolHandler ??= (static async (request, _) => throw new McpException($"Unknown tool: '{request.Params?.Name}'", McpErrorCode.InvalidParams)); + callToolHandler ??= (static async (request, _) => throw new McpProtocolException($"Unknown tool: '{request.Params?.Name}'", McpErrorCode.InvalidParams)); var listChanged = toolsCapability?.ListChanged; // Handle tools provided via DI by augmenting the handlers to incorporate them. @@ -580,7 +580,7 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) listToolsHandler = BuildFilterPipeline(listToolsHandler, options.Filters.ListToolsFilters); callToolHandler = BuildFilterPipeline(callToolHandler, options.Filters.CallToolFilters, handler => - (request, cancellationToken) => + async (request, cancellationToken) => { // Initial handler that sets MatchedPrimitive if (request.Params?.Name is { } toolName && tools is not null && @@ -589,37 +589,23 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) request.MatchedPrimitive = tool; } - return handler(request, cancellationToken); - }, handler => - async (request, cancellationToken) => - { - // Final handler that provides exception handling only for tool execution - // Only wrap tool execution in try-catch, not tool resolution - if (request.MatchedPrimitive is McpServerTool) + try { - try - { - return await handler(request, cancellationToken).ConfigureAwait(false); - } - catch (Exception e) when (e is not OperationCanceledException) - { - ToolCallError(request.Params?.Name ?? string.Empty, e); - - string errorMessage = e is McpException ? - $"An error occurred invoking '{request.Params?.Name}': {e.Message}" : - $"An error occurred invoking '{request.Params?.Name}'."; - - return new() - { - IsError = true, - Content = [new TextContentBlock { Text = errorMessage }], - }; - } + return await handler(request, cancellationToken); } - else + catch (Exception e) when (e is not OperationCanceledException and not McpProtocolException) { - // For unmatched tools, let exceptions bubble up as protocol errors - return await handler(request, cancellationToken).ConfigureAwait(false); + ToolCallError(request.Params?.Name ?? string.Empty, e); + + string errorMessage = e is McpException ? + $"An error occurred invoking '{request.Params?.Name}': {e.Message}" : + $"An error occurred invoking '{request.Params?.Name}'."; + + return new() + { + IsError = true, + Content = [new TextContentBlock { Text = errorMessage }], + }; } }); @@ -735,16 +721,10 @@ private void SetHandler( private static McpRequestHandler BuildFilterPipeline( McpRequestHandler baseHandler, List> filters, - McpRequestFilter? initialHandler = null, - McpRequestFilter? finalHandler = null) + McpRequestFilter? initialHandler = null) { var current = baseHandler; - if (finalHandler is not null) - { - current = finalHandler(current); - } - for (int i = filters.Count - 1; i >= 0; i--) { current = filters[i](current); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs index 84d1c1a79..073b2fd18 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs @@ -36,7 +36,7 @@ public async Task Authorize_Tool_RequiresAuthentication() var client = await ConnectAsync(); - var exception = await Assert.ThrowsAsync(async () => + var exception = await Assert.ThrowsAsync(async () => await client.CallToolAsync( "authorized_tool", new Dictionary { ["message"] = "test" }, @@ -101,7 +101,7 @@ public async Task AuthorizeWithRoles_Tool_RequiresAdminRole() var client = await ConnectAsync(); - var exception = await Assert.ThrowsAsync(async () => + var exception = await Assert.ThrowsAsync(async () => await client.CallToolAsync( "admin_tool", new Dictionary { ["message"] = "test" }, @@ -188,7 +188,7 @@ public async Task Authorize_Prompt_RequiresAuthentication() var client = await ConnectAsync(); - var exception = await Assert.ThrowsAsync(async () => + var exception = await Assert.ThrowsAsync(async () => await client.GetPromptAsync( "authorized_prompt", new Dictionary { ["message"] = "test" }, @@ -235,7 +235,7 @@ public async Task Authorize_Resource_RequiresAuthentication() var client = await ConnectAsync(); - var exception = await Assert.ThrowsAsync(async () => + var exception = await Assert.ThrowsAsync(async () => await client.ReadResourceAsync( "resource://authorized", cancellationToken: TestContext.Current.CancellationToken)); @@ -277,7 +277,7 @@ public async Task ListTools_WithoutAuthFilters_ThrowsInvalidOperationException() await using var app = await StartServerWithoutAuthFilters(builder => builder.WithTools()); var client = await ConnectAsync(); - var exception = await Assert.ThrowsAsync(async () => + var exception = await Assert.ThrowsAsync(async () => await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken)); Assert.Equal("Request failed (remote): An error occurred.", exception.Message); @@ -289,21 +289,23 @@ log.Exception is InvalidOperationException && } [Fact] - public async Task CallTool_WithoutAuthFilters_ThrowsInvalidOperationException() + public async Task CallTool_WithoutAuthFilters_ReturnsError() { _mockLoggerProvider.LogMessages.Clear(); await using var app = await StartServerWithoutAuthFilters(builder => builder.WithTools()); var client = await ConnectAsync(); - var exception = await Assert.ThrowsAsync(async () => - await client.CallToolAsync( + var toolResult = await client.CallToolAsync( "authorized_tool", new Dictionary { ["message"] = "test" }, - cancellationToken: TestContext.Current.CancellationToken)); + cancellationToken: TestContext.Current.CancellationToken); - Assert.Equal("Request failed (remote): An error occurred.", exception.Message); + Assert.True(toolResult.IsError); + + var errorContent = Assert.IsType(Assert.Single(toolResult.Content)); + Assert.Equal("An error occurred invoking 'authorized_tool'.", errorContent.Text); Assert.Contains(_mockLoggerProvider.LogMessages, log => - log.LogLevel == LogLevel.Warning && + log.LogLevel == LogLevel.Error && log.Exception is InvalidOperationException && log.Exception.Message.Contains("Authorization filter was not invoked for tools/call operation") && log.Exception.Message.Contains("Ensure that AddAuthorizationFilters() is called")); @@ -316,7 +318,7 @@ public async Task ListPrompts_WithoutAuthFilters_ThrowsInvalidOperationException await using var app = await StartServerWithoutAuthFilters(builder => builder.WithPrompts()); var client = await ConnectAsync(); - var exception = await Assert.ThrowsAsync(async () => + var exception = await Assert.ThrowsAsync(async () => await client.ListPromptsAsync(cancellationToken: TestContext.Current.CancellationToken)); Assert.Equal("Request failed (remote): An error occurred.", exception.Message); @@ -334,7 +336,7 @@ public async Task GetPrompt_WithoutAuthFilters_ThrowsInvalidOperationException() await using var app = await StartServerWithoutAuthFilters(builder => builder.WithPrompts()); var client = await ConnectAsync(); - var exception = await Assert.ThrowsAsync(async () => + var exception = await Assert.ThrowsAsync(async () => await client.GetPromptAsync( "authorized_prompt", new Dictionary { ["message"] = "test" }, @@ -355,7 +357,7 @@ public async Task ListResources_WithoutAuthFilters_ThrowsInvalidOperationExcepti await using var app = await StartServerWithoutAuthFilters(builder => builder.WithResources()); var client = await ConnectAsync(); - var exception = await Assert.ThrowsAsync(async () => + var exception = await Assert.ThrowsAsync(async () => await client.ListResourcesAsync(cancellationToken: TestContext.Current.CancellationToken)); Assert.Equal("Request failed (remote): An error occurred.", exception.Message); @@ -373,7 +375,7 @@ public async Task ReadResource_WithoutAuthFilters_ThrowsInvalidOperationExceptio await using var app = await StartServerWithoutAuthFilters(builder => builder.WithResources()); var client = await ConnectAsync(); - var exception = await Assert.ThrowsAsync(async () => + var exception = await Assert.ThrowsAsync(async () => await client.ReadResourceAsync( "resource://authorized", cancellationToken: TestContext.Current.CancellationToken)); @@ -393,7 +395,7 @@ public async Task ListResourceTemplates_WithoutAuthFilters_ThrowsInvalidOperatio await using var app = await StartServerWithoutAuthFilters(builder => builder.WithResources()); var client = await ConnectAsync(); - var exception = await Assert.ThrowsAsync(async () => + var exception = await Assert.ThrowsAsync(async () => await client.ListResourceTemplatesAsync(cancellationToken: TestContext.Current.CancellationToken)); Assert.Equal("Request failed (remote): An error occurred.", exception.Message); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs index 5e3a654f9..78acaeb5e 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs @@ -240,7 +240,7 @@ public async Task GetPrompt_Sse_NonExistent_ThrowsException() // act await using var client = await GetClientAsync(); - await Assert.ThrowsAsync(async () => await client.GetPromptAsync("non_existent_prompt", null, cancellationToken: TestContext.Current.CancellationToken)); + await Assert.ThrowsAsync(async () => await client.GetPromptAsync("non_existent_prompt", null, cancellationToken: TestContext.Current.CancellationToken)); } [Fact] diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs index 0e953e4d7..29bb5e6a6 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs @@ -2,6 +2,7 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Primitives; using ModelContextProtocol.Client; +using System.Collections.Concurrent; namespace ModelContextProtocol.AspNetCore.Tests; @@ -148,7 +149,7 @@ public async Task SseMode_Works_WithSseEndpoint() [Fact] public async Task StreamableHttpClient_SendsMcpProtocolVersionHeader_AfterInitialization() { - var protocolVersionHeaderValues = new List(); + var protocolVersionHeaderValues = new ConcurrentQueue(); Builder.Services.AddMcpServer().WithHttpTransport(ConfigureStateless).WithTools(); @@ -160,7 +161,7 @@ public async Task StreamableHttpClient_SendsMcpProtocolVersionHeader_AfterInitia { if (!StringValues.IsNullOrEmpty(context.Request.Headers["mcp-protocol-version"])) { - protocolVersionHeaderValues.Add(context.Request.Headers["mcp-protocol-version"]); + protocolVersionHeaderValues.Enqueue(context.Request.Headers["mcp-protocol-version"]); } await next(context); diff --git a/tests/ModelContextProtocol.TestServer/Program.cs b/tests/ModelContextProtocol.TestServer/Program.cs index 9765ed928..9a54ed71d 100644 --- a/tests/ModelContextProtocol.TestServer/Program.cs +++ b/tests/ModelContextProtocol.TestServer/Program.cs @@ -170,7 +170,7 @@ private static void ConfigureTools(McpServerOptions options, string? cliArg) { if (request.Params?.Arguments is null || !request.Params.Arguments.TryGetValue("message", out var message)) { - throw new McpException("Missing required argument 'message'", McpErrorCode.InvalidParams); + throw new McpProtocolException("Missing required argument 'message'", McpErrorCode.InvalidParams); } return new CallToolResult { @@ -190,7 +190,7 @@ private static void ConfigureTools(McpServerOptions options, string? cliArg) !request.Params.Arguments.TryGetValue("prompt", out var prompt) || !request.Params.Arguments.TryGetValue("maxTokens", out var maxTokens)) { - throw new McpException("Missing required arguments 'prompt' and 'maxTokens'", McpErrorCode.InvalidParams); + throw new McpProtocolException("Missing required arguments 'prompt' and 'maxTokens'", McpErrorCode.InvalidParams); } var sampleResult = await request.Server.SampleAsync(CreateRequestSamplingParams(prompt.ToString(), "sampleLLM", Convert.ToInt32(maxTokens.GetRawText())), cancellationToken); @@ -209,7 +209,7 @@ private static void ConfigureTools(McpServerOptions options, string? cliArg) } else { - throw new McpException($"Unknown tool: {request.Params?.Name}", McpErrorCode.InvalidParams); + throw new McpProtocolException($"Unknown tool: {request.Params?.Name}", McpErrorCode.InvalidParams); } }; } @@ -287,7 +287,7 @@ private static void ConfigurePrompts(McpServerOptions options) } else { - throw new McpException($"Unknown prompt: {request.Params?.Name}", McpErrorCode.InvalidParams); + throw new McpProtocolException($"Unknown prompt: {request.Params?.Name}", McpErrorCode.InvalidParams); } return new GetPromptResult @@ -305,7 +305,7 @@ private static void ConfigureLogging(McpServerOptions options) { if (request.Params?.Level is null) { - throw new McpException("Missing required argument 'level'", McpErrorCode.InvalidParams); + throw new McpProtocolException("Missing required argument 'level'", McpErrorCode.InvalidParams); } _minimumLoggingLevel = request.Params.Level; @@ -387,7 +387,7 @@ private static void ConfigureResources(McpServerOptions options) } catch (Exception e) { - throw new McpException($"Invalid cursor: '{request.Params.Cursor}'", e, McpErrorCode.InvalidParams); + throw new McpProtocolException($"Invalid cursor: '{request.Params.Cursor}'", e, McpErrorCode.InvalidParams); } } @@ -409,7 +409,7 @@ private static void ConfigureResources(McpServerOptions options) { if (request.Params?.Uri is null) { - throw new McpException("Missing required argument 'uri'", McpErrorCode.InvalidParams); + throw new McpProtocolException("Missing required argument 'uri'", McpErrorCode.InvalidParams); } if (request.Params.Uri.StartsWith("test://dynamic/resource/")) @@ -417,7 +417,7 @@ private static void ConfigureResources(McpServerOptions options) var id = request.Params.Uri.Split('/').LastOrDefault(); if (string.IsNullOrEmpty(id)) { - throw new McpException($"Invalid resource URI: '{request.Params.Uri}'", McpErrorCode.InvalidParams); + throw new McpProtocolException($"Invalid resource URI: '{request.Params.Uri}'", McpErrorCode.InvalidParams); } return new ReadResourceResult @@ -434,7 +434,7 @@ private static void ConfigureResources(McpServerOptions options) } ResourceContents contents = resourceContents.FirstOrDefault(r => r.Uri == request.Params.Uri) - ?? throw new McpException($"Resource not found: '{request.Params.Uri}'", McpErrorCode.InvalidParams); + ?? throw new McpProtocolException($"Resource not found: '{request.Params.Uri}'", McpErrorCode.InvalidParams); return new ReadResourceResult { @@ -446,12 +446,12 @@ private static void ConfigureResources(McpServerOptions options) { if (request?.Params?.Uri is null) { - throw new McpException("Missing required argument 'uri'", McpErrorCode.InvalidParams); + throw new McpProtocolException("Missing required argument 'uri'", McpErrorCode.InvalidParams); } if (!request.Params.Uri.StartsWith("test://static/resource/") && !request.Params.Uri.StartsWith("test://dynamic/resource/")) { - throw new McpException($"Invalid resource URI: '{request.Params.Uri}'", McpErrorCode.InvalidParams); + throw new McpProtocolException($"Invalid resource URI: '{request.Params.Uri}'", McpErrorCode.InvalidParams); } _subscribedResources.TryAdd(request.Params.Uri, true); @@ -463,12 +463,12 @@ private static void ConfigureResources(McpServerOptions options) { if (request?.Params?.Uri is null) { - throw new McpException("Missing required argument 'uri'", McpErrorCode.InvalidParams); + throw new McpProtocolException("Missing required argument 'uri'", McpErrorCode.InvalidParams); } if (!request.Params.Uri.StartsWith("test://static/resource/") && !request.Params.Uri.StartsWith("test://dynamic/resource/")) { - throw new McpException($"Invalid resource URI: '{request.Params.Uri}'", McpErrorCode.InvalidParams); + throw new McpProtocolException($"Invalid resource URI: '{request.Params.Uri}'", McpErrorCode.InvalidParams); } _subscribedResources.TryRemove(request.Params.Uri, out _); @@ -509,7 +509,7 @@ private static void ConfigureCompletions(McpServerOptions options) return new CompleteResult { Completion = new() { Values = values, HasMore = false, Total = values.Length } }; default: - throw new McpException($"Unknown reference type: '{request.Params?.Ref.Type}'", McpErrorCode.InvalidParams); + throw new McpProtocolException($"Unknown reference type: '{request.Params?.Ref.Type}'", McpErrorCode.InvalidParams); } }; } diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index cf78c0896..be117dc21 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -158,13 +158,13 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st { if (request.Params is null) { - throw new McpException("Missing required parameter 'name'", McpErrorCode.InvalidParams); + throw new McpProtocolException("Missing required parameter 'name'", McpErrorCode.InvalidParams); } if (request.Params.Name == "echo") { if (request.Params.Arguments is null || !request.Params.Arguments.TryGetValue("message", out var message)) { - throw new McpException("Missing required argument 'message'", McpErrorCode.InvalidParams); + throw new McpProtocolException("Missing required argument 'message'", McpErrorCode.InvalidParams); } return new CallToolResult { @@ -184,7 +184,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st !request.Params.Arguments.TryGetValue("prompt", out var prompt) || !request.Params.Arguments.TryGetValue("maxTokens", out var maxTokens)) { - throw new McpException("Missing required arguments 'prompt' and 'maxTokens'", McpErrorCode.InvalidParams); + throw new McpProtocolException("Missing required arguments 'prompt' and 'maxTokens'", McpErrorCode.InvalidParams); } var sampleResult = await request.Server.SampleAsync(CreateRequestSamplingParams(prompt.ToString(), "sampleLLM", Convert.ToInt32(maxTokens.ToString())), cancellationToken); @@ -196,7 +196,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st } else { - throw new McpException($"Unknown tool: '{request.Params.Name}'", McpErrorCode.InvalidParams); + throw new McpProtocolException($"Unknown tool: '{request.Params.Name}'", McpErrorCode.InvalidParams); } }, ListResourceTemplatesHandler = async (request, cancellationToken) => @@ -226,7 +226,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st } catch (Exception e) { - throw new McpException($"Invalid cursor: '{requestParams.Cursor}'", e, McpErrorCode.InvalidParams); + throw new McpProtocolException($"Invalid cursor: '{requestParams.Cursor}'", e, McpErrorCode.InvalidParams); } } @@ -248,7 +248,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st { if (request.Params?.Uri is null) { - throw new McpException("Missing required argument 'uri'", McpErrorCode.InvalidParams); + throw new McpProtocolException("Missing required argument 'uri'", McpErrorCode.InvalidParams); } if (request.Params.Uri.StartsWith("test://dynamic/resource/")) @@ -256,7 +256,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st var id = request.Params.Uri.Split('/').LastOrDefault(); if (string.IsNullOrEmpty(id)) { - throw new McpException($"Invalid resource URI: '{request.Params.Uri}'", McpErrorCode.InvalidParams); + throw new McpProtocolException($"Invalid resource URI: '{request.Params.Uri}'", McpErrorCode.InvalidParams); } return new ReadResourceResult @@ -273,7 +273,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st } ResourceContents? contents = resourceContents.FirstOrDefault(r => r.Uri == request.Params.Uri) ?? - throw new McpException($"Resource not found: '{request.Params.Uri}'", McpErrorCode.InvalidParams); + throw new McpProtocolException($"Resource not found: '{request.Params.Uri}'", McpErrorCode.InvalidParams); return new ReadResourceResult { @@ -317,7 +317,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st { if (request.Params is null) { - throw new McpException("Missing required parameter 'name'", McpErrorCode.InvalidParams); + throw new McpProtocolException("Missing required parameter 'name'", McpErrorCode.InvalidParams); } List messages = new(); if (request.Params.Name == "simple_prompt") @@ -354,7 +354,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st } else { - throw new McpException($"Unknown prompt: {request.Params.Name}", McpErrorCode.InvalidParams); + throw new McpProtocolException($"Unknown prompt: {request.Params.Name}", McpErrorCode.InvalidParams); } return new GetPromptResult diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index 20c6f374b..16fad124a 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -193,7 +193,7 @@ public async Task GetPrompt_NonExistent_ThrowsException(string clientId) // act await using var client = await _fixture.CreateClientAsync(clientId); - await Assert.ThrowsAsync(async () => + await Assert.ThrowsAsync(async () => await client.GetPromptAsync("non_existent_prompt", null, cancellationToken: TestContext.Current.CancellationToken)); } diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsFilterTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsFilterTests.cs index 00e67c247..32b588d09 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsFilterTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsFilterTests.cs @@ -58,7 +58,18 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer var logger = GetLogger(request.Services, "CallToolFilter"); var primitiveId = request.MatchedPrimitive?.Id ?? "unknown"; logger.LogInformation($"CallToolFilter executed for tool: {primitiveId}"); - return await next(request, cancellationToken); + try + { + return await next(request, cancellationToken); + } + catch (Exception ex) + { + return new CallToolResult + { + Content = [new TextContentBlock { Type = "text", Text = $"Error from filter: {ex.Message}" }], + IsError = true + }; + } }) .AddListPromptsFilter((next) => async (request, cancellationToken) => { @@ -162,6 +173,20 @@ public async Task AddCallToolFilter_Logs_When_CallTool_Called() Assert.Equal("CallToolFilter", logMessage.Category); } + [Fact] + public async Task AddCallToolFilter_Catches_Exception_From_Tool() + { + await using McpClient client = await CreateMcpClientForServer(); + + var result = await client.CallToolAsync("throwing_tool_method", cancellationToken: TestContext.Current.CancellationToken); + + Assert.True(result.IsError); + Assert.NotNull(result.Content); + var textContent = Assert.Single(result.Content); + var textBlock = Assert.IsType(textContent); + Assert.Equal("Error from filter: This tool always throws an exception", textBlock.Text); + } + [Fact] public async Task AddListPromptsFilter_Logs_When_ListPrompts_Called() { @@ -286,6 +311,12 @@ public static string TestToolMethod() { return "test result"; } + + [McpServerTool] + public static string ThrowingToolMethod() + { + throw new InvalidOperationException("This tool always throws an exception"); + } } [McpServerPromptType] diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs index 2df57dbf3..8fdeacb9b 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs @@ -62,7 +62,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer }; default: - throw new McpException($"Unexpected cursor: '{cursor}'", McpErrorCode.InvalidParams); + throw new McpProtocolException($"Unexpected cursor: '{cursor}'", McpErrorCode.InvalidParams); } }) .WithGetPromptHandler(async (request, cancellationToken) => @@ -78,7 +78,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer }; default: - throw new McpException($"Unknown prompt '{request.Params?.Name}'", McpErrorCode.InvalidParams); + throw new McpProtocolException($"Unknown prompt '{request.Params?.Name}'", McpErrorCode.InvalidParams); } }) .WithPrompts(); @@ -190,7 +190,7 @@ public async Task Throws_When_Prompt_Fails() { await using McpClient client = await CreateMcpClientForServer(); - await Assert.ThrowsAsync(async () => await client.GetPromptAsync( + await Assert.ThrowsAsync(async () => await client.GetPromptAsync( nameof(SimplePrompts.ThrowsException), cancellationToken: TestContext.Current.CancellationToken)); } @@ -200,7 +200,7 @@ public async Task Throws_Exception_On_Unknown_Prompt() { await using McpClient client = await CreateMcpClientForServer(); - var e = await Assert.ThrowsAsync(async () => await client.GetPromptAsync( + var e = await Assert.ThrowsAsync(async () => await client.GetPromptAsync( "NotRegisteredPrompt", cancellationToken: TestContext.Current.CancellationToken)); @@ -212,7 +212,7 @@ public async Task Throws_Exception_Missing_Parameter() { await using McpClient client = await CreateMcpClientForServer(); - var e = await Assert.ThrowsAsync(async () => await client.GetPromptAsync( + var e = await Assert.ThrowsAsync(async () => await client.GetPromptAsync( "returns_chat_messages", cancellationToken: TestContext.Current.CancellationToken)); diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs index 3a8a63e8f..7d037fb2d 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs @@ -62,7 +62,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer }; default: - throw new McpException($"Unexpected cursor: '{cursor}'", McpErrorCode.InvalidParams); + throw new McpProtocolException($"Unexpected cursor: '{cursor}'", McpErrorCode.InvalidParams); } }) .WithListResourceTemplatesHandler(async (request, cancellationToken) => @@ -91,7 +91,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer }], }; default: - throw new McpException($"Unexpected cursor: '{cursor}'", McpErrorCode.InvalidParams); + throw new McpProtocolException($"Unexpected cursor: '{cursor}'", McpErrorCode.InvalidParams); } }) .WithReadResourceHandler(async (request, cancellationToken) => @@ -109,7 +109,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer }; } - throw new McpException($"Resource not found: {request.Params?.Uri}"); + throw new McpProtocolException($"Resource not found: {request.Params?.Uri}"); }) .WithResources(); } @@ -235,7 +235,7 @@ public async Task Throws_When_Resource_Fails() { await using McpClient client = await CreateMcpClientForServer(); - await Assert.ThrowsAsync(async () => await client.ReadResourceAsync( + await Assert.ThrowsAsync(async () => await client.ReadResourceAsync( $"resource://mcp/{nameof(SimpleResources.ThrowsException)}", cancellationToken: TestContext.Current.CancellationToken)); } @@ -245,7 +245,7 @@ public async Task Throws_Exception_On_Unknown_Resource() { await using McpClient client = await CreateMcpClientForServer(); - var e = await Assert.ThrowsAsync(async () => await client.ReadResourceAsync( + var e = await Assert.ThrowsAsync(async () => await client.ReadResourceAsync( "test:///NotRegisteredResource", cancellationToken: TestContext.Current.CancellationToken)); diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 97fd3e330..79bf6fe50 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -90,7 +90,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer }; default: - throw new McpException($"Unexpected cursor: '{cursor}'", McpErrorCode.InvalidParams); + throw new McpProtocolException($"Unexpected cursor: '{cursor}'", McpErrorCode.InvalidParams); } }) .WithCallToolHandler(async (request, cancellationToken) => @@ -106,7 +106,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer }; default: - throw new McpException($"Unknown tool: '{request.Params?.Name}'", McpErrorCode.InvalidParams); + throw new McpProtocolException($"Unknown tool: '{request.Params?.Name}'", McpErrorCode.InvalidParams); } }) .WithTools(serializerOptions: BuilderToolsJsonContext.Default.Options); @@ -388,7 +388,7 @@ public async Task Throws_Exception_On_Unknown_Tool() { await using McpClient client = await CreateMcpClientForServer(); - var e = await Assert.ThrowsAsync(async () => await client.CallToolAsync( + var e = await Assert.ThrowsAsync(async () => await client.CallToolAsync( "NotRegisteredTool", cancellationToken: TestContext.Current.CancellationToken)); diff --git a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs index 8996b9962..b0e7b6d07 100644 --- a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs +++ b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs @@ -87,7 +87,7 @@ public async Task Session_FailedToolCall() await RunConnected(async (client, server) => { await client.CallToolAsync("Throw", cancellationToken: TestContext.Current.CancellationToken); - await Assert.ThrowsAsync(async () => await client.CallToolAsync("does-not-exist", cancellationToken: TestContext.Current.CancellationToken)); + await Assert.ThrowsAsync(async () => await client.CallToolAsync("does-not-exist", cancellationToken: TestContext.Current.CancellationToken)); }, new List()); } diff --git a/tests/ModelContextProtocol.Tests/Protocol/ElicitationTypedTests.cs b/tests/ModelContextProtocol.Tests/Protocol/ElicitationTypedTests.cs index 47da166ca..55e32f4ae 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/ElicitationTypedTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/ElicitationTypedTests.cs @@ -248,7 +248,7 @@ public async Task Elicit_Typed_With_Unsupported_Property_Type_Throws() }, }); - var ex = await Assert.ThrowsAsync(async() => + var ex = await Assert.ThrowsAsync(async() => await client.CallToolAsync("TestElicitationUnsupportedType", cancellationToken: TestContext.Current.CancellationToken)); Assert.Contains(typeof(UnsupportedForm.Nested).FullName!, ex.Message); @@ -270,7 +270,7 @@ public async Task Elicit_Typed_With_Nullable_Property_Type_Throws() } }); - var ex = await Assert.ThrowsAsync(async () => + var ex = await Assert.ThrowsAsync(async () => await client.CallToolAsync("TestElicitationNullablePropertyForm", cancellationToken: TestContext.Current.CancellationToken)); } @@ -290,7 +290,7 @@ public async Task Elicit_Typed_With_NonObject_Generic_Type_Throws() } }); - var ex = await Assert.ThrowsAsync(async () => + var ex = await Assert.ThrowsAsync(async () => await client.CallToolAsync("TestElicitationNonObjectGenericType", cancellationToken: TestContext.Current.CancellationToken)); Assert.Contains(typeof(string).FullName!, ex.Message); diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 40461d415..4352570e7 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -532,6 +532,110 @@ public async Task Can_Handle_Call_Tool_Requests_Throws_Exception_If_No_Handler_A await Succeeds_Even_If_No_Handler_Assigned(new ServerCapabilities { Tools = new() }, RequestMethods.ToolsCall, "CallTool handler not configured"); } + [Fact] + public async Task Can_Handle_Call_Tool_Requests_With_McpException() + { + const string errorMessage = "Tool execution failed with detailed error"; + await Can_Handle_Requests( + new ServerCapabilities + { + Tools = new() + }, + method: RequestMethods.ToolsCall, + configureOptions: options => + { + options.Handlers.CallToolHandler = async (request, ct) => + { + throw new McpException(errorMessage); + }; + options.Handlers.ListToolsHandler = (request, ct) => throw new NotImplementedException(); + }, + assertResult: (_, response) => + { + var result = JsonSerializer.Deserialize(response, McpJsonUtilities.DefaultOptions); + Assert.NotNull(result); + Assert.True(result.IsError); + Assert.NotEmpty(result.Content); + var textContent = Assert.IsType(result.Content[0]); + Assert.Contains(errorMessage, textContent.Text); + }); + } + + [Fact] + public async Task Can_Handle_Call_Tool_Requests_With_Plain_Exception() + { + await Can_Handle_Requests( + new ServerCapabilities + { + Tools = new() + }, + method: RequestMethods.ToolsCall, + configureOptions: options => + { + options.Handlers.CallToolHandler = async (request, ct) => + { + throw new InvalidOperationException("This sensitive message should not be exposed"); + }; + options.Handlers.ListToolsHandler = (request, ct) => throw new NotImplementedException(); + }, + assertResult: (_, response) => + { + var result = JsonSerializer.Deserialize(response, McpJsonUtilities.DefaultOptions); + Assert.NotNull(result); + Assert.True(result.IsError); + Assert.NotEmpty(result.Content); + var textContent = Assert.IsType(result.Content[0]); + // Should be a generic error message, not the actual exception message + Assert.DoesNotContain("sensitive", textContent.Text, StringComparison.OrdinalIgnoreCase); + Assert.Contains("An error occurred", textContent.Text); + }); + } + + [Fact] + public async Task Can_Handle_Call_Tool_Requests_With_McpProtocolException() + { + const string errorMessage = "Invalid tool parameters"; + const McpErrorCode errorCode = McpErrorCode.InvalidParams; + + await using var transport = new TestServerTransport(); + var options = CreateOptions(new ServerCapabilities { Tools = new() }); + options.Handlers.CallToolHandler = async (request, ct) => + { + throw new McpProtocolException(errorMessage, errorCode); + }; + options.Handlers.ListToolsHandler = (request, ct) => throw new NotImplementedException(); + + await using var server = McpServer.Create(transport, options, LoggerFactory); + + var runTask = server.RunAsync(TestContext.Current.CancellationToken); + + var receivedMessage = new TaskCompletionSource(); + + transport.OnMessageSent = (message) => + { + if (message is JsonRpcError error && error.Id.ToString() == "55") + receivedMessage.SetResult(error); + }; + + await transport.SendMessageAsync( + new JsonRpcRequest + { + Method = RequestMethods.ToolsCall, + Id = new RequestId(55) + }, + TestContext.Current.CancellationToken + ); + + var error = await receivedMessage.Task.WaitAsync(TimeSpan.FromSeconds(5), TestContext.Current.CancellationToken); + Assert.NotNull(error); + Assert.NotNull(error.Error); + Assert.Equal((int)errorCode, error.Error.Code); + Assert.Equal(errorMessage, error.Error.Message); + + await transport.DisposeAsync(); + await runTask; + } + private async Task Can_Handle_Requests(ServerCapabilities? serverCapabilities, string method, Action? configureOptions, Action assertResult) { await using var transport = new TestServerTransport();