From b3146056e2d55defd0fe2126d7105bd89cdf0741 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Mon, 28 Apr 2025 12:59:30 -0700 Subject: [PATCH 1/5] Add stateless Streamable HTTP support - This allows a single MCP session spanning multiple requests to be handled by different servers without sharing state - This does require the servers share data protection keys, but this is standard for ASP.NET Core cookies and antiforgery as well --- .../HttpMcpServerBuilderExtensions.cs | 1 + .../HttpMcpSession.cs | 30 +-- .../HttpServerTransportOptions.cs | 7 + .../McpEndpointRouteBuilderExtensions.cs | 35 ++-- .../SseHandler.cs | 5 +- .../StatelessSessionId.cs | 16 ++ .../StatelessSessionIdJsonContext.cs | 6 + .../StreamableHttpHandler.cs | 194 ++++++++++++++---- .../Protocol/Messages/JsonRpcRequest.cs | 3 +- .../StreamableHttpClientSessionTransport.cs | 2 +- .../Transport/StreamableHttpPostTransport.cs | 42 ++-- .../StreamableHttpServerTransport.cs | 35 +++- src/ModelContextProtocol/Server/McpServer.cs | 25 ++- .../Server/McpServerOptions.cs | 22 ++ .../HttpServerIntegrationTests.cs | 2 + .../MapMcpSseTests.cs | 24 +++ .../MapMcpStatelessTests.cs | 10 + .../MapMcpStreamableHttpTests.cs | 2 +- .../MapMcpTests.cs | 32 +-- .../SseIntegrationTests.cs | 2 +- .../StatelessServerIntegrationTests.cs | 16 ++ .../StatelessServerTests.cs | 69 +++++++ .../StreamableHttpServerConformanceTests.cs | 42 +++- .../StreamableHttpServerIntegrationTests.cs | 1 - .../Program.cs | 24 +++ 25 files changed, 501 insertions(+), 146 deletions(-) create mode 100644 src/ModelContextProtocol.AspNetCore/StatelessSessionId.cs create mode 100644 src/ModelContextProtocol.AspNetCore/StatelessSessionIdJsonContext.cs create mode 100644 tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStatelessTests.cs create mode 100644 tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs create mode 100644 tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs index 8bff45962..a8a63e49e 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs @@ -26,6 +26,7 @@ public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder builder.Services.TryAddSingleton(); builder.Services.TryAddSingleton(); builder.Services.AddHostedService(); + builder.Services.AddDataProtection(); if (configureOptions is not null) { diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs index 1b854b944..0903dda6e 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs @@ -4,7 +4,11 @@ namespace ModelContextProtocol.AspNetCore; -internal sealed class HttpMcpSession(string sessionId, TTransport transport, ClaimsPrincipal user, TimeProvider timeProvider) : IAsyncDisposable +internal sealed class HttpMcpSession( + string sessionId, + TTransport transport, + (string Type, string Value, string Issuer)? userIdClaim, + TimeProvider timeProvider) : IAsyncDisposable where TTransport : ITransport { private int _referenceCount; @@ -13,7 +17,7 @@ internal sealed class HttpMcpSession(string sessionId, TTransport tr public string Id { get; } = sessionId; public TTransport Transport { get; } = transport; - public (string Type, string Value, string Issuer)? UserIdClaim { get; } = GetUserIdClaim(user); + public (string Type, string Value, string Issuer)? UserIdClaim { get; } = userIdClaim; public CancellationToken SessionClosed => _disposeCts.Token; @@ -63,27 +67,7 @@ public async ValueTask DisposeAsync() } public bool HasSameUserId(ClaimsPrincipal user) - => UserIdClaim == GetUserIdClaim(user); - - // SignalR only checks for ClaimTypes.NameIdentifier in HttpConnectionDispatcher, but AspNetCore.Antiforgery checks that plus the sub and UPN claims. - // However, we short-circuit unlike antiforgery since we expect to call this to verify MCP messages a lot more frequently than - // verifying antiforgery tokens from
posts. - private static (string Type, string Value, string Issuer)? GetUserIdClaim(ClaimsPrincipal user) - { - if (user?.Identity?.IsAuthenticated != true) - { - return null; - } - - var claim = user.FindFirst(ClaimTypes.NameIdentifier) ?? user.FindFirst("sub") ?? user.FindFirst(ClaimTypes.Upn); - - if (claim is { } idClaim) - { - return (idClaim.Type, idClaim.Value, idClaim.Issuer); - } - - return null; - } + => UserIdClaim == StreamableHttpHandler.GetUserIdClaim(user); private sealed class UnreferenceDisposable(HttpMcpSession session, TimeProvider timeProvider) : IDisposable { diff --git a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs index 4880714c4..df83ff6d8 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs @@ -22,6 +22,13 @@ public class HttpServerTransportOptions /// public Func? RunSessionHandler { get; set; } + /// + /// Gets or sets whether the server should run in a stateless mode that does not require all requests for a given session + /// to arrive to the same ASP.NET Core application process. If true, the /sse endpoint will be disabled, and + /// client capabilities will be round-tripped as part of the mcp-session-id header instead of stored in memory. Defaults to false. + /// + public bool Stateless { get; set; } + /// /// Represents the duration of time the server will wait between any active requests before timing out an /// MCP session. This is checked in background every 5 seconds. A client trying to resume a session will diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs index 0eefa52fb..1e60d2aab 100644 --- a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs @@ -35,20 +35,27 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo .WithMetadata(new AcceptsMetadata(["application/json"])) .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"])) .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status202Accepted)); - streamableHttpGroup.MapGet("", streamableHttpHandler.HandleGetRequestAsync) - .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"])); - streamableHttpGroup.MapDelete("", streamableHttpHandler.HandleDeleteRequestAsync); - - // Map legacy HTTP with SSE endpoints. - var sseHandler = endpoints.ServiceProvider.GetRequiredService(); - var sseGroup = mcpGroup.MapGroup("") - .WithDisplayName(b => $"MCP HTTP with SSE | {b.DisplayName}"); - - sseGroup.MapGet("/sse", sseHandler.HandleSseRequestAsync) - .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"])); - sseGroup.MapPost("/message", sseHandler.HandleMessageRequestAsync) - .WithMetadata(new AcceptsMetadata(["application/json"])) - .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status202Accepted)); + + if (!streamableHttpHandler.HttpServerTransportOptions.Stateless) + { + // The GET and DELETE endpoints are not mapped in Stateless mode since there's no way to send unsolicited messages + // for the GET to handle, and there is no server-side state for the DELETE to clean up. + streamableHttpGroup.MapGet("", streamableHttpHandler.HandleGetRequestAsync) + .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"])); + streamableHttpGroup.MapDelete("", streamableHttpHandler.HandleDeleteRequestAsync); + + // Map legacy HTTP with SSE endpoints only if not in Stateless mode, because we cannot guarantee the /message requests + // will be handled by the same process as the /sse request. + var sseHandler = endpoints.ServiceProvider.GetRequiredService(); + var sseGroup = mcpGroup.MapGroup("") + .WithDisplayName(b => $"MCP HTTP with SSE | {b.DisplayName}"); + + sseGroup.MapGet("/sse", sseHandler.HandleSseRequestAsync) + .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"])); + sseGroup.MapPost("/message", sseHandler.HandleMessageRequestAsync) + .WithMetadata(new AcceptsMetadata(["application/json"])) + .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status202Accepted)); + } return mcpGroup; } diff --git a/src/ModelContextProtocol.AspNetCore/SseHandler.cs b/src/ModelContextProtocol.AspNetCore/SseHandler.cs index 36efadef4..cea6817ea 100644 --- a/src/ModelContextProtocol.AspNetCore/SseHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/SseHandler.cs @@ -34,7 +34,10 @@ public async Task HandleSseRequestAsync(HttpContext context) var requestPath = (context.Request.PathBase + context.Request.Path).ToString(); var endpointPattern = requestPath[..(requestPath.LastIndexOf('/') + 1)]; await using var transport = new SseResponseStreamTransport(context.Response.Body, $"{endpointPattern}message?sessionId={sessionId}"); - await using var httpMcpSession = new HttpMcpSession(sessionId, transport, context.User, httpMcpServerOptions.Value.TimeProvider); + + var userIdClaim = StreamableHttpHandler.GetUserIdClaim(context.User); + await using var httpMcpSession = new HttpMcpSession(sessionId, transport, userIdClaim, httpMcpServerOptions.Value.TimeProvider); + if (!_sessions.TryAdd(sessionId, httpMcpSession)) { throw new UnreachableException($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created."); diff --git a/src/ModelContextProtocol.AspNetCore/StatelessSessionId.cs b/src/ModelContextProtocol.AspNetCore/StatelessSessionId.cs new file mode 100644 index 000000000..73c206e94 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/StatelessSessionId.cs @@ -0,0 +1,16 @@ +using ModelContextProtocol.Protocol.Types; +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.AspNetCore; + +internal class StatelessSessionId +{ + [JsonPropertyName("capabilities")] + public ClientCapabilities? Capabilities { get; init; } + + [JsonPropertyName("clientInfo")] + public Implementation? ClientInfo { get; init; } + + [JsonPropertyName("userIdClaim")] + public (string Type, string Value, string Issuer)? UserIdClaim { get; init; } +} diff --git a/src/ModelContextProtocol.AspNetCore/StatelessSessionIdJsonContext.cs b/src/ModelContextProtocol.AspNetCore/StatelessSessionIdJsonContext.cs new file mode 100644 index 000000000..2690a3b15 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/StatelessSessionIdJsonContext.cs @@ -0,0 +1,6 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.AspNetCore; + +[JsonSerializable(typeof(StatelessSessionId))] +internal sealed partial class StatelessSessionIdJsonContext : JsonSerializerContext; diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index 64b10d6de..072ad8512 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -1,4 +1,5 @@ -using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.DataProtection; +using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.WebUtilities; using Microsoft.Extensions.Logging; @@ -11,7 +12,9 @@ using System.Collections.Concurrent; using System.Diagnostics; using System.IO.Pipelines; +using System.Security.Claims; using System.Security.Cryptography; +using System.Text.Json; using System.Text.Json.Serialization.Metadata; namespace ModelContextProtocol.AspNetCore; @@ -19,16 +22,24 @@ namespace ModelContextProtocol.AspNetCore; internal sealed class StreamableHttpHandler( IOptions mcpServerOptionsSnapshot, IOptionsFactory mcpServerOptionsFactory, - IOptions httpMcpServerOptions, + IOptions httpServerTransportOptions, + IDataProtectionProvider dataProtection, ILoggerFactory loggerFactory, IServiceProvider applicationServices) { + private const string StatelessSessionIdPurpose = "Microsoft.AspNetCore.StreamableHttpHandler.StatelessSessionId"; + private static readonly JsonTypeInfo s_errorTypeInfo = GetRequiredJsonTypeInfo(); + private static readonly MediaTypeHeaderValue s_applicationJsonMediaType = new("application/json"); private static readonly MediaTypeHeaderValue s_textEventStreamMediaType = new("text/event-stream"); public ConcurrentDictionary> Sessions { get; } = new(StringComparer.Ordinal); + public HttpServerTransportOptions HttpServerTransportOptions => httpServerTransportOptions.Value; + + private IDataProtector Protector { get; } = dataProtection.CreateProtector(StatelessSessionIdPurpose); + public async Task HandlePostRequestAsync(HttpContext context) { // The Streamable HTTP spec mandates the client MUST accept both application/json and text/event-stream. @@ -50,14 +61,28 @@ await WriteJsonRpcErrorAsync(context, return; } - using var _ = session.AcquireReference(); - InitializeSseResponse(context); - var wroteResponse = await session.Transport.HandlePostRequest(new HttpDuplexPipe(context), context.RequestAborted); - if (!wroteResponse) + try + { + using var _ = session.AcquireReference(); + + InitializeSseResponse(context); + var wroteResponse = await session.Transport.HandlePostRequest(new HttpDuplexPipe(context), context.RequestAborted); + if (!wroteResponse) + { + // We wound up writing nothing, so there should be no Content-Type response header. + context.Response.Headers.ContentType = (string?)null; + context.Response.StatusCode = StatusCodes.Status202Accepted; + } + } + finally { - // We wound up writing nothing, so there should be no Content-Type response header. - context.Response.Headers.ContentType = (string?)null; - context.Response.StatusCode = StatusCodes.Status202Accepted; + // Stateless sessions are 1:1 with HTTP requests and are outlived by the MCP session tracked by the mcp-session-id. + // Non-stateless sessions are 1:1 with the mcp-session-id and outlive the POST request. + // Non-stateless sessions get disposed by a DELETE request or the IdleTrackingBackgroundService. + if (HttpServerTransportOptions.Stateless) + { + await session.DisposeAsync(); + } } } @@ -108,27 +133,36 @@ public async Task HandleDeleteRequestAsync(HttpContext context) private async ValueTask?> GetSessionAsync(HttpContext context, string sessionId) { - if (Sessions.TryGetValue(sessionId, out var existingSession)) + HttpMcpSession? session; + + if (HttpServerTransportOptions.Stateless) { - if (!existingSession.HasSameUserId(context.User)) - { - await WriteJsonRpcErrorAsync(context, - "Forbidden: The currently authenticated user does not match the user who initiated the session.", - StatusCodes.Status403Forbidden); - return null; - } + var sessionJson = Protector.Unprotect(sessionId); + var statelessSessionId = JsonSerializer.Deserialize(sessionJson, StatelessSessionIdJsonContext.Default.StatelessSessionId); + var transport = new StreamableHttpServerTransport(); + session = await CreateSessionAsync(context, transport, sessionId, statelessSessionId); + } + else if (!Sessions.TryGetValue(sessionId, out session)) + { + // -32001 isn't part of the MCP standard, but this is what the typescript-sdk currently does. + // One of the few other usages I found was from some Ethereum JSON-RPC documentation and this + // JSON-RPC library from Microsoft called StreamJsonRpc where it's called JsonRpcErrorCode.NoMarshaledObjectFound + // https://learn.microsoft.com/dotnet/api/streamjsonrpc.protocol.jsonrpcerrorcode?view=streamjsonrpc-2.9#fields + await WriteJsonRpcErrorAsync(context, "Session not found", StatusCodes.Status404NotFound, 32001); + return null; + } - context.Response.Headers["mcp-session-id"] = existingSession.Id; - context.Features.Set(existingSession.Server); - return existingSession; + if (!session.HasSameUserId(context.User)) + { + await WriteJsonRpcErrorAsync(context, + "Forbidden: The currently authenticated user does not match the user who initiated the session.", + StatusCodes.Status403Forbidden); + return null; } - // -32001 isn't part of the MCP standard, but this is what the typescript-sdk currently does. - // One of the few other usages I found was from some Ethereum JSON-RPC documentation and this - // JSON-RPC library from Microsoft called StreamJsonRpc where it's called JsonRpcErrorCode.NoMarshaledObjectFound - // https://learn.microsoft.com/dotnet/api/streamjsonrpc.protocol.jsonrpcerrorcode?view=streamjsonrpc-2.9#fields - await WriteJsonRpcErrorAsync(context, "Session not found", StatusCodes.Status404NotFound, 32001); - return null; + context.Response.Headers["mcp-session-id"] = session.Id; + context.Features.Set(session.Server); + return session; } private async ValueTask?> GetOrCreateSessionAsync(HttpContext context) @@ -137,14 +171,7 @@ await WriteJsonRpcErrorAsync(context, if (string.IsNullOrEmpty(sessionId)) { - var session = await CreateSessionAsync(context); - - if (!Sessions.TryAdd(session.Id, session)) - { - throw new UnreachableException($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created."); - } - - return session; + return await StartNewSessionAsync(context); } else { @@ -152,29 +179,72 @@ await WriteJsonRpcErrorAsync(context, } } - private async ValueTask> CreateSessionAsync(HttpContext context) + private async ValueTask> StartNewSessionAsync(HttpContext context) { - var sessionId = MakeNewSessionId(); - context.Response.Headers["mcp-session-id"] = sessionId; + string sessionId; + var transport = new StreamableHttpServerTransport(); + + if (!HttpServerTransportOptions.Stateless) + { + sessionId = MakeNewSessionId(); + context.Response.Headers["mcp-session-id"] = sessionId; + } + else + { + // "(uninitialized stateless id)" is not written anywhere. We delay writing th mcp-session-id + // until after we receive the initialize request with the client info we need to serialize. + sessionId = "(uninitialized stateless id)"; + ScheduleStatelessSessionIdWrite(context, transport); + } + + var session = await CreateSessionAsync(context, transport, sessionId); + // The HttpMcpSession is not stored between requests in stateless mode. Instead the session is recreated from the mcp-session-id. + if (!HttpServerTransportOptions.Stateless) + { + if (!Sessions.TryAdd(sessionId, session)) + { + throw new UnreachableException($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created."); + } + } + + return session; + } + + private async ValueTask> CreateSessionAsync( + HttpContext context, + StreamableHttpServerTransport transport, + string sessionId, + StatelessSessionId? statelessId = null) + { var mcpServerOptions = mcpServerOptionsSnapshot.Value; - if (httpMcpServerOptions.Value.ConfigureSessionOptions is { } configureSessionOptions) + if (statelessId is not null || HttpServerTransportOptions.ConfigureSessionOptions is not null) { mcpServerOptions = mcpServerOptionsFactory.Create(Options.DefaultName); - await configureSessionOptions(context, mcpServerOptions, context.RequestAborted); + + if (statelessId is not null) + { + mcpServerOptions.KnownClientInfo = statelessId.ClientInfo; + mcpServerOptions.KnownClientCapabilities = statelessId.Capabilities; + } + + if (HttpServerTransportOptions.ConfigureSessionOptions is { } configureSessionOptions) + { + await configureSessionOptions(context, mcpServerOptions, context.RequestAborted); + } } - var transport = new StreamableHttpServerTransport(); // Use application instead of request services, because the session will likely outlive the first initialization request. var server = McpServerFactory.Create(transport, mcpServerOptions, loggerFactory, applicationServices); context.Features.Set(server); - var session = new HttpMcpSession(sessionId, transport, context.User, httpMcpServerOptions.Value.TimeProvider) + var userIdClaim = statelessId?.UserIdClaim ?? GetUserIdClaim(context.User); + var session = new HttpMcpSession(sessionId, transport, userIdClaim, HttpServerTransportOptions.TimeProvider) { Server = server, }; - var runSessionAsync = httpMcpServerOptions.Value.RunSessionHandler ?? RunSessionAsync; + var runSessionAsync = HttpServerTransportOptions.RunSessionHandler ?? RunSessionAsync; session.ServerRunTask = runSessionAsync(context, server, session.SessionClosed); return session; @@ -210,9 +280,49 @@ internal static string MakeNewSessionId() return WebEncoders.Base64UrlEncode(buffer); } + private void ScheduleStatelessSessionIdWrite(HttpContext context, StreamableHttpServerTransport transport) + { + context.Response.OnStarting(() => + { + var statelessId = new StatelessSessionId + { + ClientInfo = transport.ClientInfo, + Capabilities = transport.ClientCapabilities, + UserIdClaim = GetUserIdClaim(context.User), + }; + + var sessionJson = JsonSerializer.Serialize(statelessId, StatelessSessionIdJsonContext.Default.StatelessSessionId); + var sessionId = Protector.Protect(sessionJson); + + context.Response.Headers["mcp-session-id"] = sessionId; + + return Task.CompletedTask; + }); + } + internal static Task RunSessionAsync(HttpContext httpContext, IMcpServer session, CancellationToken requestAborted) => session.RunAsync(requestAborted); + // SignalR only checks for ClaimTypes.NameIdentifier in HttpConnectionDispatcher, but AspNetCore.Antiforgery checks that plus the sub and UPN claims. + // However, we short-circuit unlike antiforgery since we expect to call this to verify MCP messages a lot more frequently than + // verifying antiforgery tokens from posts. + internal static (string Type, string Value, string Issuer)? GetUserIdClaim(ClaimsPrincipal user) + { + if (user?.Identity?.IsAuthenticated != true) + { + return null; + } + + var claim = user.FindFirst(ClaimTypes.NameIdentifier) ?? user.FindFirst("sub") ?? user.FindFirst(ClaimTypes.Upn); + + if (claim is { } idClaim) + { + return (idClaim.Type, idClaim.Value, idClaim.Issuer); + } + + return null; + } + private static JsonTypeInfo GetRequiredJsonTypeInfo() => (JsonTypeInfo)McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(T)); private sealed class HttpDuplexPipe(HttpContext context) : IDuplexPipe diff --git a/src/ModelContextProtocol/Protocol/Messages/JsonRpcRequest.cs b/src/ModelContextProtocol/Protocol/Messages/JsonRpcRequest.cs index ff7a45044..6e356cf26 100644 --- a/src/ModelContextProtocol/Protocol/Messages/JsonRpcRequest.cs +++ b/src/ModelContextProtocol/Protocol/Messages/JsonRpcRequest.cs @@ -35,7 +35,8 @@ internal JsonRpcRequest WithId(RequestId id) JsonRpc = JsonRpc, Id = id, Method = Method, - Params = Params + Params = Params, + RelatedTransport = RelatedTransport, }; } } diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpClientSessionTransport.cs index 7697c28e0..a442d5b3d 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpClientSessionTransport.cs @@ -57,7 +57,7 @@ public override async Task SendMessageAsync( cancellationToken = sendCts.Token; #if NET - using var content = JsonContent.Create(message, McpJsonUtilities.DefaultOptions.GetTypeInfo()); + using var content = JsonContent.Create(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage); #else using var content = new StringContent( JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpPostTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpPostTransport.cs index 4cdb30b34..e3cdb4040 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpPostTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpPostTransport.cs @@ -1,7 +1,5 @@ using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Utils; using ModelContextProtocol.Utils.Json; -using System.Buffers; using System.IO.Pipelines; using System.Net.ServerSentEvents; using System.Runtime.CompilerServices; @@ -14,12 +12,11 @@ namespace ModelContextProtocol.Protocol.Transport; /// Handles processing the request/response body pairs for the Streamable HTTP transport. /// This is typically used via . /// -internal sealed class StreamableHttpPostTransport(ChannelWriter? incomingChannel, IDuplexPipe httpBodies) : ITransport +internal sealed class StreamableHttpPostTransport(StreamableHttpServerTransport parentTransport, IDuplexPipe httpBodies) : ITransport { private readonly SseWriter _sseWriter = new(); - private readonly HashSet _pendingRequests = []; + private RequestId _pendingRequest; - // REVIEW: Should we introduce a send-only interface for RelatedTransport? public ChannelReader MessageReader => throw new NotSupportedException("JsonRpcMessage.RelatedTransport should only be used for sending messages."); /// @@ -29,15 +26,11 @@ internal sealed class StreamableHttpPostTransport(ChannelWriter? /// public async ValueTask RunAsync(CancellationToken cancellationToken) { - // The incomingChannel is null to handle the potential client GET request to handle unsolicited JsonRpcMessages. - if (incomingChannel is not null) - { - var message = await JsonSerializer.DeserializeAsync(httpBodies.Input.AsStream(), - McpJsonUtilities.JsonContext.Default.JsonRpcMessage, cancellationToken).ConfigureAwait(false); - await OnMessageReceivedAsync(message, cancellationToken).ConfigureAwait(false); - } + var message = await JsonSerializer.DeserializeAsync(httpBodies.Input.AsStream(), + McpJsonUtilities.JsonContext.Default.JsonRpcMessage, cancellationToken).ConfigureAwait(false); + await OnMessageReceivedAsync(message, cancellationToken).ConfigureAwait(false); - if (_pendingRequests.Count == 0) + if (_pendingRequest.Id is null) { return false; } @@ -63,13 +56,10 @@ public async ValueTask DisposeAsync() { yield return message; - if (message.Data is JsonRpcMessageWithId response) + if (message.Data is JsonRpcMessageWithId response && response.Id == _pendingRequest) { - if (_pendingRequests.Remove(response.Id) && _pendingRequests.Count == 0) - { - // Complete the SSE response stream now that all pending requests have been processed. - break; - } + // Complete the SSE response stream now that all pending requests have been processed. + break; } } } @@ -83,13 +73,19 @@ private async ValueTask OnMessageReceivedAsync(JsonRpcMessage? message, Cancella if (message is JsonRpcRequest request) { - _pendingRequests.Add(request.Id); + _pendingRequest = request.Id; + + // Store client capabilities so they can be serialized by "stateless" callers for use in later requests. + if (request.Method == RequestMethods.Initialize) + { + var initializeRequestParams = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.JsonContext.Default.InitializeRequestParams); + parentTransport.ClientCapabilities = initializeRequestParams?.Capabilities; + parentTransport.ClientInfo = initializeRequestParams?.ClientInfo; + } } message.RelatedTransport = this; - // Really an assertion. This doesn't get called when incomingChannel is null for GET requests. - Throw.IfNull(incomingChannel); - await incomingChannel.WriteAsync(message, cancellationToken).ConfigureAwait(false); + await parentTransport.MessageWriter.WriteAsync(message, cancellationToken).ConfigureAwait(false); } } diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpServerTransport.cs index aa9e522da..9e8cb1d6b 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpServerTransport.cs @@ -1,4 +1,5 @@ using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Protocol.Types; using System.IO.Pipelines; using System.Threading.Channels; @@ -36,6 +37,21 @@ public sealed class StreamableHttpServerTransport : ITransport private int _getRequestStarted; + /// + /// Gets the capabilities supported by the client if it was received by . + /// + public ClientCapabilities? ClientCapabilities { get; internal set; } + + /// + /// Gets the version and implementation information of the connected client if it was received by . + /// + public Implementation? ClientInfo { get; internal set; } + + /// + public ChannelReader MessageReader => _incomingChannel.Reader; + + internal ChannelWriter MessageWriter => _incomingChannel.Writer; + /// /// Handles an optional SSE GET request a client using the Streamable HTTP transport might make by /// writing any unsolicited JSON-RPC messages sent via @@ -63,20 +79,17 @@ public async Task HandleGetRequest(Stream sseResponseStream, CancellationToken c /// The duplex pipe facilitates the reading and writing of HTTP request and response data. /// This token allows for the operation to be canceled if needed. /// - /// True, if data was written to the respond body. + /// True, if data was written to the response body. /// False, if nothing was written because the request body did not contain any messages to respond to. /// The HTTP application should typically respond with an empty "202 Accepted" response in this scenario. /// public async Task HandlePostRequest(IDuplexPipe httpBodies, CancellationToken cancellationToken) { using var postCts = CancellationTokenSource.CreateLinkedTokenSource(_disposeCts.Token, cancellationToken); - await using var postTransport = new StreamableHttpPostTransport(_incomingChannel.Writer, httpBodies); + await using var postTransport = new StreamableHttpPostTransport(this, httpBodies); return await postTransport.RunAsync(postCts.Token).ConfigureAwait(false); } - /// - public ChannelReader MessageReader => _incomingChannel.Reader; - /// public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { @@ -86,14 +99,20 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can /// public async ValueTask DisposeAsync() { - _disposeCts.Cancel(); try { - await _sseWriter.DisposeAsync().ConfigureAwait(false); + await _disposeCts.CancelAsync(); } finally { - _disposeCts.Dispose(); + try + { + await _sseWriter.DisposeAsync().ConfigureAwait(false); + } + finally + { + _disposeCts.Dispose(); + } } } } diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index ae0e7afc5..4a55a0af5 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -26,7 +26,8 @@ internal sealed class McpServer : McpEndpoint, IMcpServer private readonly EventHandler? _toolsChangedDelegate; private readonly EventHandler? _promptsChangedDelegate; - private string _endpointName; + private readonly string _serverOnlyEndpointName; + private string? _endpointName; private int _started; /// Holds a boxed value for the server. @@ -56,9 +57,13 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? _sessionTransport = transport; ServerOptions = options; Services = serviceProvider; - _endpointName = $"Server ({options.ServerInfo?.Name ?? DefaultImplementation.Name} {options.ServerInfo?.Version ?? DefaultImplementation.Version})"; + _serverOnlyEndpointName = $"Server ({options.ServerInfo?.Name ?? DefaultImplementation.Name} {options.ServerInfo?.Version ?? DefaultImplementation.Version})"; _servicesScopePerRequest = options.ScopeRequests; + ClientCapabilities = options.KnownClientCapabilities; + ClientInfo = options.KnownClientInfo; + UpdateEndpointNameWithClientInfo(); + // Configure all request handlers based on the supplied options. SetInitializeHandler(options); SetToolsHandler(options); @@ -114,7 +119,7 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? public IServiceProvider? Services { get; } /// - public override string EndpointName => _endpointName; + public override string EndpointName => _endpointName ?? _serverOnlyEndpointName; /// public LoggingLevel? LoggingLevel => _loggingLevel?.Value; @@ -172,8 +177,8 @@ private void SetInitializeHandler(McpServerOptions options) ClientInfo = request?.ClientInfo; // Use the ClientInfo to update the session EndpointName for logging. - _endpointName = $"{_endpointName}, Client ({ClientInfo?.Name} {ClientInfo?.Version})"; - GetSessionOrThrow().EndpointName = _endpointName; + UpdateEndpointNameWithClientInfo(); + GetSessionOrThrow().EndpointName = EndpointName; return new InitializeResult { @@ -551,6 +556,16 @@ private void SetHandler( requestTypeInfo, responseTypeInfo); } + private void UpdateEndpointNameWithClientInfo() + { + if (ClientInfo is null) + { + return; + } + + _endpointName = $"{_serverOnlyEndpointName}, Client ({ClientInfo.Name} {ClientInfo.Version})"; + } + /// Maps a to a . internal static LoggingLevel ToLoggingLevel(LogLevel level) => level switch diff --git a/src/ModelContextProtocol/Server/McpServerOptions.cs b/src/ModelContextProtocol/Server/McpServerOptions.cs index 6880d2f2b..4c820a496 100644 --- a/src/ModelContextProtocol/Server/McpServerOptions.cs +++ b/src/ModelContextProtocol/Server/McpServerOptions.cs @@ -65,4 +65,26 @@ public class McpServerOptions /// handler will be invoked within a new service scope. /// public bool ScopeRequests { get; set; } = true; + + /// + /// Gets or sets preexisting knowledge about the client including its name and version to help support + /// stateless Streamable HTTP servers that encode this knowledge in the mcp-session-id header. + /// + /// + /// + /// When not specified, this information sourced from the client's initialize request. + /// + /// + public Implementation? KnownClientInfo { get; set; } + + /// + /// Gets or sets preexisting knowledge about the client client capabilities to help support + /// stateless Streamable HTTP servers that encode this knowledge in the mcp-session-id header. + /// + /// + /// + /// When not specified, this information sourced from the client's initialize request. + /// + /// + public ClientCapabilities? KnownClientCapabilities { get; set; } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs index 57a6c6ad9..fe7c9d036 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs @@ -207,6 +207,8 @@ await Assert.ThrowsAsync(() => [Fact] public async Task Sampling_Sse_TestServer() { + Assert.SkipWhen(GetType() == typeof(StatelessServerIntegrationTests), "Sampling is not supported in stateless mode."); + // arrange // Set up the sampling handler int samplingHandlerCalls = 0; diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs index d385623a2..1d4917bae 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs @@ -8,6 +8,30 @@ public class MapMcpSseTests(ITestOutputHelper outputHelper) : MapMcpTests(output { protected override bool UseStreamableHttp => false; + [Theory] + [InlineData("/mcp")] + [InlineData("/mcp/secondary")] + public async Task Allows_Customizing_Route(string pattern) + { + Builder.Services.AddMcpServer().WithHttpTransport(ConfigureStateless); + await using var app = Builder.Build(); + + app.MapMcp(pattern); + + await app.StartAsync(TestContext.Current.CancellationToken); + + using var response = await HttpClient.GetAsync($"http://localhost{pattern}/sse", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); + response.EnsureSuccessStatusCode(); + using var sseStream = await response.Content.ReadAsStreamAsync(TestContext.Current.CancellationToken); + using var sseStreamReader = new StreamReader(sseStream, System.Text.Encoding.UTF8); + var eventLine = await sseStreamReader.ReadLineAsync(TestContext.Current.CancellationToken); + var dataLine = await sseStreamReader.ReadLineAsync(TestContext.Current.CancellationToken); + Assert.NotNull(eventLine); + Assert.Equal("event: endpoint", eventLine); + Assert.NotNull(dataLine); + Assert.Equal($"data: {pattern}/message", dataLine[..dataLine.IndexOf('?')]); + } + [Theory] [InlineData("/a", "/a/sse")] [InlineData("/a/", "/a/sse")] diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStatelessTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStatelessTests.cs new file mode 100644 index 000000000..030701c72 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStatelessTests.cs @@ -0,0 +1,10 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; + +namespace ModelContextProtocol.AspNetCore.Tests; + +public class MapMcpStatelessTests(ITestOutputHelper outputHelper) : MapMcpStreamableHttpTests(outputHelper) +{ + protected override bool UseStreamableHttp => true; + protected override bool Stateless => true; +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs index 30632a8e6..0b2f68bbf 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs @@ -22,7 +22,7 @@ public async Task CanConnect_WithMcpClient_AfterCustomizingRoute(string routePat Name = "TestCustomRouteServer", Version = "1.0.0", }; - }).WithHttpTransport(); + }).WithHttpTransport(ConfigureStateless); await using var app = Builder.Build(); app.MapMcp(routePattern); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs index 70b028e22..dd6540716 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs @@ -15,6 +15,13 @@ public abstract class MapMcpTests(ITestOutputHelper testOutputHelper) : KestrelI { protected abstract bool UseStreamableHttp { get; } + protected virtual bool Stateless => false; + + protected void ConfigureStateless(HttpServerTransportOptions options) + { + options.Stateless = Stateless; + } + protected async Task ConnectAsync(string? path = null) { path ??= UseStreamableHttp ? "/" : "/sse"; @@ -37,34 +44,11 @@ public async Task MapMcp_ThrowsInvalidOperationException_IfWithHttpTransportIsNo Assert.StartsWith("You must call WithHttpTransport()", exception.Message); } - [Theory] - [InlineData("/mcp")] - [InlineData("/mcp/secondary")] - public async Task Allows_Customizing_Route(string pattern) - { - Builder.Services.AddMcpServer().WithHttpTransport(); - await using var app = Builder.Build(); - - app.MapMcp(pattern); - - await app.StartAsync(TestContext.Current.CancellationToken); - - using var response = await HttpClient.GetAsync($"http://localhost{pattern}/sse", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); - response.EnsureSuccessStatusCode(); - using var sseStream = await response.Content.ReadAsStreamAsync(TestContext.Current.CancellationToken); - using var sseStreamReader = new StreamReader(sseStream, System.Text.Encoding.UTF8); - var eventLine = await sseStreamReader.ReadLineAsync(TestContext.Current.CancellationToken); - var dataLine = await sseStreamReader.ReadLineAsync(TestContext.Current.CancellationToken); - Assert.NotNull(eventLine); - Assert.Equal("event: endpoint", eventLine); - Assert.NotNull(dataLine); - Assert.Equal($"data: {pattern}/message", dataLine[..dataLine.IndexOf('?')]); - } [Fact] public async Task Messages_FromNewUser_AreRejected() { - Builder.Services.AddMcpServer().WithHttpTransport().WithTools(); + Builder.Services.AddMcpServer().WithHttpTransport(ConfigureStateless).WithTools(); // Add an authentication scheme that will send a 403 Forbidden response. Builder.Services.AddAuthentication().AddBearerToken(); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs index b659ff172..7733c836d 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs @@ -17,7 +17,7 @@ namespace ModelContextProtocol.AspNetCore.Tests; public partial class SseIntegrationTests(ITestOutputHelper outputHelper) : KestrelInMemoryTest(outputHelper) { - private SseClientTransportOptions DefaultTransportOptions = new() + private readonly SseClientTransportOptions DefaultTransportOptions = new() { Endpoint = new Uri("http://localhost/sse"), Name = "In-memory Test Server", diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs new file mode 100644 index 000000000..03ceacd71 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs @@ -0,0 +1,16 @@ +using ModelContextProtocol.Protocol.Transport; +using System.Text; + +namespace ModelContextProtocol.AspNetCore.Tests; + +public class StatelessServerIntegrationTests(SseServerIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper) + : StreamableHttpServerIntegrationTests(fixture, testOutputHelper) + +{ + protected override SseClientTransportOptions ClientTransportOptions => new() + { + Endpoint = new Uri("http://localhost/stateless"), + Name = "TestServer", + UseStreamableHttp = true, + }; +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs new file mode 100644 index 000000000..06bd35f96 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs @@ -0,0 +1,69 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Protocol.Types; +using System.Net; + +namespace ModelContextProtocol.AspNetCore.Tests; + +public class StatelessServerTests(ITestOutputHelper outputHelper) : KestrelInMemoryTest(outputHelper), IAsyncDisposable +{ + private WebApplication? _app; + + private async Task StartAsync() + { + Builder.Services.AddMcpServer(mcpServerOptions => + { + mcpServerOptions.ServerInfo = new Implementation + { + Name = nameof(StreamableHttpServerConformanceTests), + Version = "73", + }; + }).WithHttpTransport(httpServerTransportOptions => + { + httpServerTransportOptions.Stateless = true; + }); + + _app = Builder.Build(); + + _app.MapMcp(); + + await _app.StartAsync(TestContext.Current.CancellationToken); + + HttpClient.DefaultRequestHeaders.Accept.Add(new("application/json")); + HttpClient.DefaultRequestHeaders.Accept.Add(new("text/event-stream")); + } + + public async ValueTask DisposeAsync() + { + if (_app is not null) + { + await _app.DisposeAsync(); + } + base.Dispose(); + } + + [Fact] + public async Task EnablingStatelessMode_Disables_SseEndpoints() + { + await StartAsync(); + + using var sseResponse = await HttpClient.GetAsync("/sse", TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.NotFound, sseResponse.StatusCode); + + using var messageResponse = await HttpClient.PostAsync("/message", new StringContent(""), TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.NotFound, messageResponse.StatusCode); + } + + [Fact] + public async Task EnablingStatelessMode_Disables_GetAndDeleteEndpoints() + { + await StartAsync(); + + using var getResponse = await HttpClient.GetAsync("/", TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.MethodNotAllowed, getResponse.StatusCode); + + using var deleteResponse = await HttpClient.DeleteAsync("/", TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.MethodNotAllowed, deleteResponse.StatusCode); + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs index 9e5ce6fa5..196b5b612 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs @@ -71,7 +71,6 @@ public async Task NegativeNonInfiniteIdleTimeout_Throws_ArgumentOutOfRangeExcept Assert.Contains("IdleTimeout", ex.Message); } - [Fact] public async Task NegativeMaxIdleSessionCount_Throws_ArgumentOutOfRangeException() { @@ -360,6 +359,47 @@ public async Task Progress_IsReported_InSameSseResponseAsRpcResponse() Assert.Equal(11, currentSseItem); } + [Fact] + public async Task AsyncLocalSetInRunSessionHandlerCallback_Flows_ToAllToolCalls() + { + var asyncLocal = new AsyncLocal(); + var totalSessionCount = 0; + + Builder.Services.AddMcpServer() + .WithHttpTransport(options => + { + options.RunSessionHandler = async (httpContext, mcpServer, cancellationToken) => + { + asyncLocal.Value = $"RunSessionHandler ({totalSessionCount++})"; + await mcpServer.RunAsync(cancellationToken); + }; + }); + + Builder.Services.AddSingleton(McpServerTool.Create([McpServerTool(Name = "async-local-session")] () => asyncLocal.Value)); + + await StartAsync(); + + var firstSessionId = await CallInitializeAndValidateAsync(); + + async Task CallAsyncLocalToolAndValidateAsync(int expectedSessionIndex) + { + var response = await HttpClient.PostAsync("", JsonContent(CallTool("async-local-session")), TestContext.Current.CancellationToken); + var rpcResponse = await AssertSingleSseResponseAsync(response); + var callToolResponse = AssertType(rpcResponse.Result); + var callToolContent = Assert.Single(callToolResponse.Content); + Assert.Equal("text", callToolContent.Type); + Assert.Equal($"RunSessionHandler ({expectedSessionIndex})", callToolContent.Text); + } + + await CallAsyncLocalToolAndValidateAsync(expectedSessionIndex: 0); + + await CallInitializeAndValidateAsync(); + await CallAsyncLocalToolAndValidateAsync(expectedSessionIndex: 1); + + SetSessionId(firstSessionId); + await CallAsyncLocalToolAndValidateAsync(expectedSessionIndex: 0); + } + [Fact] public async Task IdleSessions_ArePruned_AfterIdleTimeout() { diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs index 9d3048929..3abb1aa31 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs @@ -1,5 +1,4 @@ using ModelContextProtocol.Protocol.Transport; -using System.Net; using System.Text; namespace ModelContextProtocol.AspNetCore.Tests; diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index 72a271cf9..88124f9d5 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -3,6 +3,7 @@ using ModelContextProtocol.Server; using ModelContextProtocol.Utils.Json; using Serilog; +using System.Diagnostics; using System.Text; using System.Text.Json; @@ -378,6 +379,26 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st }; } + private static void HandleStatelessMcp(IApplicationBuilder app) + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddLogging(); + serviceCollection.AddSingleton(app.ApplicationServices.GetRequiredService()); + serviceCollection.AddSingleton(app.ApplicationServices.GetRequiredService()); + serviceCollection.AddRoutingCore(); + + serviceCollection.AddMcpServer(ConfigureOptions).WithHttpTransport(options => options.Stateless = true); + + var appBuilder = new ApplicationBuilder(serviceCollection.BuildServiceProvider()); + appBuilder.UseRouting(); + appBuilder.UseEndpoints(innerEndpoints => + { + innerEndpoints.MapMcp("/stateless"); + }); + + app.Run(appBuilder.Build()); + } + public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvider = null, IConnectionListenerFactory? kestrelTransport = null, CancellationToken cancellationToken = default) { Console.WriteLine("Starting server..."); @@ -419,6 +440,9 @@ public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvide app.UseRouting(); app.UseEndpoints(_ => { }); + // Handle the /stateless endpoint if no other endpoints have been matched by the call to UseRouting above. + HandleStatelessMcp(app); + app.MapMcp(); await app.RunAsync(cancellationToken); From 8de5d4976c9313498f4c54adbda0bb3e73175d50 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Wed, 7 May 2025 14:41:06 -0700 Subject: [PATCH 2/5] Address PR feedback WIP --- src/ModelContextProtocol.AspNetCore/StatelessSessionId.cs | 2 +- .../StreamableHttpHandler.cs | 8 ++++---- src/ModelContextProtocol/Server/McpServerOptions.cs | 6 +++--- tests/ModelContextProtocol.TestSseServer/Program.cs | 1 - 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/ModelContextProtocol.AspNetCore/StatelessSessionId.cs b/src/ModelContextProtocol.AspNetCore/StatelessSessionId.cs index 73c206e94..fd958e6ba 100644 --- a/src/ModelContextProtocol.AspNetCore/StatelessSessionId.cs +++ b/src/ModelContextProtocol.AspNetCore/StatelessSessionId.cs @@ -3,7 +3,7 @@ namespace ModelContextProtocol.AspNetCore; -internal class StatelessSessionId +internal sealed class StatelessSessionId { [JsonPropertyName("capabilities")] public ClientCapabilities? Capabilities { get; init; } diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index 072ad8512..d917d4c4e 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -191,7 +191,7 @@ private async ValueTask> StartNewS } else { - // "(uninitialized stateless id)" is not written anywhere. We delay writing th mcp-session-id + // "(uninitialized stateless id)" is not written anywhere. We delay writing the mcp-session-id // until after we receive the initialize request with the client info we need to serialize. sessionId = "(uninitialized stateless id)"; ScheduleStatelessSessionIdWrite(context, transport); @@ -199,7 +199,7 @@ private async ValueTask> StartNewS var session = await CreateSessionAsync(context, transport, sessionId); - // The HttpMcpSession is not stored between requests in stateless mode. Instead the session is recreated from the mcp-session-id. + // The HttpMcpSession is not stored between requests in stateless mode. Instead, the session is recreated from the mcp-session-id. if (!HttpServerTransportOptions.Stateless) { if (!Sessions.TryAdd(sessionId, session)) @@ -224,8 +224,8 @@ private async ValueTask> CreateSes if (statelessId is not null) { - mcpServerOptions.KnownClientInfo = statelessId.ClientInfo; - mcpServerOptions.KnownClientCapabilities = statelessId.Capabilities; + mcpServerOptions.KnownClientInfo = statelessId.ClientInfo ?? mcpServerOptions.KnownClientInfo; + mcpServerOptions.KnownClientCapabilities = statelessId.Capabilities ?? mcpServerOptions.KnownClientCapabilities; } if (HttpServerTransportOptions.ConfigureSessionOptions is { } configureSessionOptions) diff --git a/src/ModelContextProtocol/Server/McpServerOptions.cs b/src/ModelContextProtocol/Server/McpServerOptions.cs index 4c820a496..dbc1aec93 100644 --- a/src/ModelContextProtocol/Server/McpServerOptions.cs +++ b/src/ModelContextProtocol/Server/McpServerOptions.cs @@ -72,18 +72,18 @@ public class McpServerOptions /// /// /// - /// When not specified, this information sourced from the client's initialize request. + /// When not specified, this information is sourced from the client's initialize request. /// /// public Implementation? KnownClientInfo { get; set; } /// - /// Gets or sets preexisting knowledge about the client client capabilities to help support + /// Gets or sets preexisting knowledge about the client's capabilities to help support /// stateless Streamable HTTP servers that encode this knowledge in the mcp-session-id header. /// /// /// - /// When not specified, this information sourced from the client's initialize request. + /// When not specified, this information is sourced from the client's initialize request. /// /// public ClientCapabilities? KnownClientCapabilities { get; set; } diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index 88124f9d5..5a530478c 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -3,7 +3,6 @@ using ModelContextProtocol.Server; using ModelContextProtocol.Utils.Json; using Serilog; -using System.Diagnostics; using System.Text; using System.Text.Json; From 7392032d816f5aad1b45e5fc7755dae26600e4f1 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Mon, 12 May 2025 19:36:41 -0700 Subject: [PATCH 3/5] Address PR feedback - Don't store client capabilities in stateless session ID since sampling and roots cannot be supported in stateless mode - Immediate throw for unsupported operations in stateless mode - Improve HttpTransportOptions doc comments --- .../HttpServerTransportOptions.cs | 29 ++- .../IdleTrackingBackgroundService.cs | 2 +- .../StatelessSessionId.cs | 3 - .../StreamableHttpHandler.cs | 32 ++-- .../Transport/StreamableHttpPostTransport.cs | 11 +- .../StreamableHttpServerTransport.cs | 22 ++- src/ModelContextProtocol/Server/McpServer.cs | 38 ++-- .../Server/McpServerExtensions.cs | 48 +++-- .../Server/McpServerOptions.cs | 11 -- .../MapMcpSseTests.cs | 34 +--- .../MapMcpStreamableHttpTests.cs | 1 + .../MapMcpTests.cs | 38 +++- .../SseIntegrationTests.cs | 20 +-- .../SseServerIntegrationTests.cs | 2 +- .../StatelessServerIntegrationTests.cs | 2 +- .../StatelessServerTests.cs | 165 +++++++++++++++++- .../StreamableHttpClientConformanceTests.cs | 2 +- .../StreamableHttpServerIntegrationTests.cs | 2 +- .../Program.cs | 4 +- .../Server/McpServerTests.cs | 3 +- 20 files changed, 330 insertions(+), 139 deletions(-) diff --git a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs index df83ff6d8..7741193ea 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs @@ -24,25 +24,36 @@ public class HttpServerTransportOptions /// /// Gets or sets whether the server should run in a stateless mode that does not require all requests for a given session - /// to arrive to the same ASP.NET Core application process. If true, the /sse endpoint will be disabled, and - /// client capabilities will be round-tripped as part of the mcp-session-id header instead of stored in memory. Defaults to false. + /// to arrive to the same ASP.NET Core application process. /// + /// + /// If , the "/sse" endpoint will be disabled, and client information will be round-tripped as part + /// of the "mcp-session-id" header instead of stored in memory. Unsolicited server-to-client messages and all server-to-client + /// requests are also unsupported, because any responses may arrive at another ASP.NET Core application process. + /// Client sampling and roots capabilities are also disabled in stateless mode, because the server cannot make requests. + /// Defaults to . + /// public bool Stateless { get; set; } /// - /// Represents the duration of time the server will wait between any active requests before timing out an - /// MCP session. This is checked in background every 5 seconds. A client trying to resume a session will - /// receive a 404 status code and should restart their session. A client can keep their session open by - /// keeping a GET request open. The default value is set to 2 hours. + /// Gets or sets the duration of time the server will wait between any active requests before timing out an MCP session. /// + /// + /// This is checked in background every 5 seconds. A client trying to resume a session will receive a 404 status code + /// and should restart their session. A client can keep their session open by keeping a GET request open. + /// Defaults to 2 hours. + /// public TimeSpan IdleTimeout { get; set; } = TimeSpan.FromHours(2); /// - /// The maximum number of idle sessions to track. This is used to limit the number of sessions that can be idle at once. + /// Gets or sets maximum number of idle sessions to track in memory. This is used to limit the number of sessions that can be idle at once. + /// + /// /// Past this limit, the server will log a critical error and terminate the oldest idle sessions even if they have not reached /// their until the idle session count is below this limit. Clients that keep their session open by - /// keeping a GET request open will not count towards this limit. The default value is set to 100,000 sessions. - /// + /// keeping a GET request open will not count towards this limit. + /// Defaults to 100,000 sessions. + /// public int MaxIdleSessionCount { get; set; } = 100_000; /// diff --git a/src/ModelContextProtocol.AspNetCore/IdleTrackingBackgroundService.cs b/src/ModelContextProtocol.AspNetCore/IdleTrackingBackgroundService.cs index d7c57735a..bb50c91c9 100644 --- a/src/ModelContextProtocol.AspNetCore/IdleTrackingBackgroundService.cs +++ b/src/ModelContextProtocol.AspNetCore/IdleTrackingBackgroundService.cs @@ -12,7 +12,7 @@ internal sealed partial class IdleTrackingBackgroundService( ILogger logger) : BackgroundService { // The compiler will complain about the parameter being unused otherwise despite the source generator. - private ILogger _logger = logger; + private readonly ILogger _logger = logger; protected override async Task ExecuteAsync(CancellationToken stoppingToken) { diff --git a/src/ModelContextProtocol.AspNetCore/StatelessSessionId.cs b/src/ModelContextProtocol.AspNetCore/StatelessSessionId.cs index fd958e6ba..7d8284c6d 100644 --- a/src/ModelContextProtocol.AspNetCore/StatelessSessionId.cs +++ b/src/ModelContextProtocol.AspNetCore/StatelessSessionId.cs @@ -5,9 +5,6 @@ namespace ModelContextProtocol.AspNetCore; internal sealed class StatelessSessionId { - [JsonPropertyName("capabilities")] - public ClientCapabilities? Capabilities { get; init; } - [JsonPropertyName("clientInfo")] public Implementation? ClientInfo { get; init; } diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index d917d4c4e..068288544 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -7,6 +7,7 @@ using Microsoft.Net.Http.Headers; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; using ModelContextProtocol.Utils.Json; using System.Collections.Concurrent; @@ -27,8 +28,6 @@ internal sealed class StreamableHttpHandler( ILoggerFactory loggerFactory, IServiceProvider applicationServices) { - private const string StatelessSessionIdPurpose = "Microsoft.AspNetCore.StreamableHttpHandler.StatelessSessionId"; - private static readonly JsonTypeInfo s_errorTypeInfo = GetRequiredJsonTypeInfo(); private static readonly MediaTypeHeaderValue s_applicationJsonMediaType = new("application/json"); @@ -38,7 +37,7 @@ internal sealed class StreamableHttpHandler( public HttpServerTransportOptions HttpServerTransportOptions => httpServerTransportOptions.Value; - private IDataProtector Protector { get; } = dataProtection.CreateProtector(StatelessSessionIdPurpose); + private IDataProtector Protector { get; } = dataProtection.CreateProtector("Microsoft.AspNetCore.StreamableHttpHandler.StatelessSessionId"); public async Task HandlePostRequestAsync(HttpContext context) { @@ -139,7 +138,10 @@ public async Task HandleDeleteRequestAsync(HttpContext context) { var sessionJson = Protector.Unprotect(sessionId); var statelessSessionId = JsonSerializer.Deserialize(sessionJson, StatelessSessionIdJsonContext.Default.StatelessSessionId); - var transport = new StreamableHttpServerTransport(); + var transport = new StreamableHttpServerTransport + { + Stateless = true, + }; session = await CreateSessionAsync(context, transport, sessionId, statelessSessionId); } else if (!Sessions.TryGetValue(sessionId, out session)) @@ -148,7 +150,7 @@ public async Task HandleDeleteRequestAsync(HttpContext context) // One of the few other usages I found was from some Ethereum JSON-RPC documentation and this // JSON-RPC library from Microsoft called StreamJsonRpc where it's called JsonRpcErrorCode.NoMarshaledObjectFound // https://learn.microsoft.com/dotnet/api/streamjsonrpc.protocol.jsonrpcerrorcode?view=streamjsonrpc-2.9#fields - await WriteJsonRpcErrorAsync(context, "Session not found", StatusCodes.Status404NotFound, 32001); + await WriteJsonRpcErrorAsync(context, "Session not found", StatusCodes.Status404NotFound, -32001); return null; } @@ -182,11 +184,12 @@ await WriteJsonRpcErrorAsync(context, private async ValueTask> StartNewSessionAsync(HttpContext context) { string sessionId; - var transport = new StreamableHttpServerTransport(); + StreamableHttpServerTransport transport; if (!HttpServerTransportOptions.Stateless) { sessionId = MakeNewSessionId(); + transport = new(); context.Response.Headers["mcp-session-id"] = sessionId; } else @@ -194,6 +197,10 @@ private async ValueTask> StartNewS // "(uninitialized stateless id)" is not written anywhere. We delay writing the mcp-session-id // until after we receive the initialize request with the client info we need to serialize. sessionId = "(uninitialized stateless id)"; + transport = new() + { + Stateless = true, + }; ScheduleStatelessSessionIdWrite(context, transport); } @@ -217,6 +224,7 @@ private async ValueTask> CreateSes string sessionId, StatelessSessionId? statelessId = null) { + var mcpServerServices = applicationServices; var mcpServerOptions = mcpServerOptionsSnapshot.Value; if (statelessId is not null || HttpServerTransportOptions.ConfigureSessionOptions is not null) { @@ -224,8 +232,10 @@ private async ValueTask> CreateSes if (statelessId is not null) { - mcpServerOptions.KnownClientInfo = statelessId.ClientInfo ?? mcpServerOptions.KnownClientInfo; - mcpServerOptions.KnownClientCapabilities = statelessId.Capabilities ?? mcpServerOptions.KnownClientCapabilities; + // The session does not outlive the request in stateless mode. + mcpServerServices = context.RequestServices; + mcpServerOptions.ScopeRequests = false; + mcpServerOptions.KnownClientInfo = statelessId.ClientInfo; } if (HttpServerTransportOptions.ConfigureSessionOptions is { } configureSessionOptions) @@ -234,8 +244,7 @@ private async ValueTask> CreateSes } } - // Use application instead of request services, because the session will likely outlive the first initialization request. - var server = McpServerFactory.Create(transport, mcpServerOptions, loggerFactory, applicationServices); + var server = McpServerFactory.Create(transport, mcpServerOptions, loggerFactory, mcpServerServices); context.Features.Set(server); var userIdClaim = statelessId?.UserIdClaim ?? GetUserIdClaim(context.User); @@ -286,8 +295,7 @@ private void ScheduleStatelessSessionIdWrite(HttpContext context, StreamableHttp { var statelessId = new StatelessSessionId { - ClientInfo = transport.ClientInfo, - Capabilities = transport.ClientCapabilities, + ClientInfo = transport?.InitializeRequest?.ClientInfo, UserIdClaim = GetUserIdClaim(context.User), }; diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpPostTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpPostTransport.cs index e3cdb4040..ce6e33ca1 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpPostTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpPostTransport.cs @@ -42,6 +42,11 @@ public async ValueTask RunAsync(CancellationToken cancellationToken) public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { + if (parentTransport.Stateless && message is JsonRpcRequest) + { + throw new InvalidOperationException("Server to client requests are not supported in stateless mode."); + } + await _sseWriter.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); } @@ -76,11 +81,9 @@ private async ValueTask OnMessageReceivedAsync(JsonRpcMessage? message, Cancella _pendingRequest = request.Id; // Store client capabilities so they can be serialized by "stateless" callers for use in later requests. - if (request.Method == RequestMethods.Initialize) + if (parentTransport.Stateless && request.Method == RequestMethods.Initialize) { - var initializeRequestParams = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.JsonContext.Default.InitializeRequestParams); - parentTransport.ClientCapabilities = initializeRequestParams?.Capabilities; - parentTransport.ClientInfo = initializeRequestParams?.ClientInfo; + parentTransport.InitializeRequest = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.JsonContext.Default.InitializeRequestParams); } } diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpServerTransport.cs index 9e8cb1d6b..9ec804536 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpServerTransport.cs @@ -38,14 +38,18 @@ public sealed class StreamableHttpServerTransport : ITransport private int _getRequestStarted; /// - /// Gets the capabilities supported by the client if it was received by . + /// Configures whether the transport should be in stateless mode that does not require all requests for a given session + /// to arrive to the same ASP.NET Core application process. Unsolicited server-to-client messages are not supported in this mode, + /// so calling results in an . + /// Server-to-client requests are also unsupported, because the responses may arrive at another ASP.NET Core application process. + /// Client sampling and roots capabilities are also disabled in stateless mode, because the server cannot make requests. /// - public ClientCapabilities? ClientCapabilities { get; internal set; } + public bool Stateless { get; init; } /// - /// Gets the version and implementation information of the connected client if it was received by . + /// Gets the initialize request if it was received by and is set to . /// - public Implementation? ClientInfo { get; internal set; } + public InitializeRequestParams? InitializeRequest { get; internal set; } /// public ChannelReader MessageReader => _incomingChannel.Reader; @@ -62,6 +66,11 @@ public sealed class StreamableHttpServerTransport : ITransport /// A task representing the send loop that writes JSON-RPC messages to the SSE response stream. public async Task HandleGetRequest(Stream sseResponseStream, CancellationToken cancellationToken) { + if (Stateless) + { + throw new InvalidOperationException("GET requests are not supported in stateless mode."); + } + if (Interlocked.Exchange(ref _getRequestStarted, 1) == 1) { throw new InvalidOperationException("Session resumption is not yet supported. Please start a new session."); @@ -93,6 +102,11 @@ public async Task HandlePostRequest(IDuplexPipe httpBodies, CancellationTo /// public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { + if (Stateless) + { + throw new InvalidOperationException("Unsolicited server to client messages are not supported in stateless mode."); + } + await _sseWriter.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); } diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index 619cfe6fb..d37fcf6f2 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -58,7 +58,6 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? _serverOnlyEndpointName = $"Server ({options.ServerInfo?.Name ?? DefaultImplementation.Name} {options.ServerInfo?.Version ?? DefaultImplementation.Version})"; _servicesScopePerRequest = options.ScopeRequests; - ClientCapabilities = options.KnownClientCapabilities; ClientInfo = options.KnownClientInfo; UpdateEndpointNameWithClientInfo(); @@ -80,26 +79,29 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? } // Now that everything has been configured, subscribe to any necessary notifications. - if (ServerOptions.Capabilities?.Tools?.ToolCollection is { } tools) + if (transport is not StreamableHttpServerTransport streamableHttpTransport || streamableHttpTransport.Stateless is false) { - EventHandler changed = (sender, e) => _ = this.SendNotificationAsync(NotificationMethods.ToolListChangedNotification); - tools.Changed += changed; - _disposables.Add(() => tools.Changed -= changed); - } + if (ServerOptions.Capabilities?.Tools?.ToolCollection is { } tools) + { + EventHandler changed = (sender, e) => _ = this.SendNotificationAsync(NotificationMethods.ToolListChangedNotification); + tools.Changed += changed; + _disposables.Add(() => tools.Changed -= changed); + } - if (ServerOptions.Capabilities?.Prompts?.PromptCollection is { } prompts) - { - EventHandler changed = (sender, e) => _ = this.SendNotificationAsync(NotificationMethods.PromptListChangedNotification); - prompts.Changed += changed; - _disposables.Add(() => prompts.Changed -= changed); - } + if (ServerOptions.Capabilities?.Prompts?.PromptCollection is { } prompts) + { + EventHandler changed = (sender, e) => _ = this.SendNotificationAsync(NotificationMethods.PromptListChangedNotification); + prompts.Changed += changed; + _disposables.Add(() => prompts.Changed -= changed); + } - var resources = ServerOptions.Capabilities?.Resources?.ResourceCollection; - if (resources is not null) - { - EventHandler changed = (sender, e) => _ = this.SendNotificationAsync(NotificationMethods.PromptListChangedNotification); - resources.Changed += changed; - _disposables.Add(() => resources.Changed -= changed); + var resources = ServerOptions.Capabilities?.Resources?.ResourceCollection; + if (resources is not null) + { + EventHandler changed = (sender, e) => _ = this.SendNotificationAsync(NotificationMethods.PromptListChangedNotification); + resources.Changed += changed; + _disposables.Add(() => resources.Changed -= changed); + } } // And initialize the session. diff --git a/src/ModelContextProtocol/Server/McpServerExtensions.cs b/src/ModelContextProtocol/Server/McpServerExtensions.cs index 9450517c8..be43f6f7a 100644 --- a/src/ModelContextProtocol/Server/McpServerExtensions.cs +++ b/src/ModelContextProtocol/Server/McpServerExtensions.cs @@ -30,14 +30,10 @@ public static class McpServerExtensions /// and token limits. /// public static ValueTask RequestSamplingAsync( - this IMcpServer server, CreateMessageRequestParams request, CancellationToken cancellationToken) + this IMcpServer server, CreateMessageRequestParams request, CancellationToken cancellationToken = default) { Throw.IfNull(server); - - if (server.ClientCapabilities?.Sampling is null) - { - throw new InvalidOperationException("Client does not support sampling."); - } + ThrowIfSamplingUnsupported(server); return server.SendRequestAsync( RequestMethods.SamplingCreateMessage, @@ -163,11 +159,7 @@ public static async Task RequestSamplingAsync( public static IChatClient AsSamplingChatClient(this IMcpServer server) { Throw.IfNull(server); - - if (server.ClientCapabilities?.Sampling is null) - { - throw new InvalidOperationException("Client does not support sampling."); - } + ThrowIfSamplingUnsupported(server); return new SamplingChatClient(server); } @@ -198,14 +190,10 @@ public static ILoggerProvider AsClientLoggerProvider(this IMcpServer server) /// or other structured data sources that the client makes available through the protocol. /// public static ValueTask RequestRootsAsync( - this IMcpServer server, ListRootsRequestParams request, CancellationToken cancellationToken) + this IMcpServer server, ListRootsRequestParams request, CancellationToken cancellationToken = default) { Throw.IfNull(server); - - if (server.ClientCapabilities?.Roots is null) - { - throw new InvalidOperationException("Client does not support roots."); - } + ThrowIfRootsUnsupported(server); return server.SendRequestAsync( RequestMethods.RootsList, @@ -215,6 +203,32 @@ public static ValueTask RequestRootsAsync( cancellationToken: cancellationToken); } + private static void ThrowIfSamplingUnsupported(IMcpServer server) + { + if (server.ClientCapabilities?.Sampling is null) + { + if (server.ServerOptions.KnownClientInfo is not null) + { + throw new InvalidOperationException("Sampling is not supported in stateless mode."); + } + + throw new InvalidOperationException("Client does not support sampling."); + } + } + + private static void ThrowIfRootsUnsupported(IMcpServer server) + { + if (server.ClientCapabilities?.Roots is null) + { + if (server.ServerOptions.KnownClientInfo is not null) + { + throw new InvalidOperationException("Roots are not supported in stateless mode."); + } + + throw new InvalidOperationException("Client does not support roots."); + } + } + /// Provides an implementation that's implemented via client sampling. private sealed class SamplingChatClient(IMcpServer server) : IChatClient { diff --git a/src/ModelContextProtocol/Server/McpServerOptions.cs b/src/ModelContextProtocol/Server/McpServerOptions.cs index dbc1aec93..bae26ca78 100644 --- a/src/ModelContextProtocol/Server/McpServerOptions.cs +++ b/src/ModelContextProtocol/Server/McpServerOptions.cs @@ -76,15 +76,4 @@ public class McpServerOptions /// /// public Implementation? KnownClientInfo { get; set; } - - /// - /// Gets or sets preexisting knowledge about the client's capabilities to help support - /// stateless Streamable HTTP servers that encode this knowledge in the mcp-session-id header. - /// - /// - /// - /// When not specified, this information is sourced from the client's initialize request. - /// - /// - public ClientCapabilities? KnownClientCapabilities { get; set; } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs index 1d4917bae..602aa0c3c 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs @@ -7,6 +7,7 @@ namespace ModelContextProtocol.AspNetCore.Tests; public class MapMcpSseTests(ITestOutputHelper outputHelper) : MapMcpTests(outputHelper) { protected override bool UseStreamableHttp => false; + protected override bool Stateless => false; [Theory] [InlineData("/mcp")] @@ -56,37 +57,4 @@ public async Task CanConnect_WithMcpClient_AfterCustomizingRoute(string routePat Assert.Equal("TestCustomRouteServer", mcpClient.ServerInfo.Name); } - - [Fact] - public async Task Can_UseHttpContextAccessor_InTool() - { - Builder.Services.AddMcpServer().WithHttpTransport().WithTools(); - - Builder.Services.AddHttpContextAccessor(); - - await using var app = Builder.Build(); - - app.Use(next => - { - return async context => - { - context.User = CreateUser("TestUser"); - await next(context); - }; - }); - - app.MapMcp(); - - await app.StartAsync(TestContext.Current.CancellationToken); - - var mcpClient = await ConnectAsync(); - - var response = await mcpClient.CallToolAsync( - "EchoWithUserName", - new Dictionary() { ["message"] = "Hello world!" }, - cancellationToken: TestContext.Current.CancellationToken); - - var content = Assert.Single(response.Content); - Assert.Equal("TestUser: Hello world!", content.Text); - } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs index 0b2f68bbf..c987bca90 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs @@ -6,6 +6,7 @@ namespace ModelContextProtocol.AspNetCore.Tests; public class MapMcpStreamableHttpTests(ITestOutputHelper outputHelper) : MapMcpTests(outputHelper) { protected override bool UseStreamableHttp => true; + protected override bool Stateless => false; [Theory] [InlineData("/a", "/a")] diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs index dd6540716..89bbcb025 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs @@ -14,8 +14,7 @@ namespace ModelContextProtocol.AspNetCore.Tests; public abstract class MapMcpTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper) { protected abstract bool UseStreamableHttp { get; } - - protected virtual bool Stateless => false; + protected abstract bool Stateless { get; } protected void ConfigureStateless(HttpServerTransportOptions options) { @@ -44,6 +43,41 @@ public async Task MapMcp_ThrowsInvalidOperationException_IfWithHttpTransportIsNo Assert.StartsWith("You must call WithHttpTransport()", exception.Message); } + [Fact] + public async Task Can_UseIHttpContextAccessor_InTool() + { + Assert.SkipWhen(UseStreamableHttp, "IHttpContextAccessor is not currently supported with Streamable HTTP." + + "TODO: Support it in stateless mode by manually capturing and flowing execution context."); + + Builder.Services.AddMcpServer().WithHttpTransport().WithTools(); + + Builder.Services.AddHttpContextAccessor(); + + await using var app = Builder.Build(); + + app.Use(next => + { + return async context => + { + context.User = CreateUser("TestUser"); + await next(context); + }; + }); + + app.MapMcp(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + var mcpClient = await ConnectAsync(); + + var response = await mcpClient.CallToolAsync( + "EchoWithUserName", + new Dictionary() { ["message"] = "Hello world!" }, + cancellationToken: TestContext.Current.CancellationToken); + + var content = Assert.Single(response.Content); + Assert.Equal("TestUser: Hello world!", content.Text); + } [Fact] public async Task Messages_FromNewUser_AreRejected() diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs index 7733c836d..340cfdac2 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs @@ -20,10 +20,10 @@ public partial class SseIntegrationTests(ITestOutputHelper outputHelper) : Kestr private readonly SseClientTransportOptions DefaultTransportOptions = new() { Endpoint = new Uri("http://localhost/sse"), - Name = "In-memory Test Server", + Name = "In-memory SSE Client", }; - private Task ConnectMcpClient(HttpClient? httpClient = null, SseClientTransportOptions? transportOptions = null) + private Task ConnectMcpClientAsync(HttpClient? httpClient = null, SseClientTransportOptions? transportOptions = null) => McpClientFactory.CreateAsync( new SseClientTransport(transportOptions ?? DefaultTransportOptions, httpClient ?? HttpClient, LoggerFactory), loggerFactory: LoggerFactory, @@ -37,7 +37,7 @@ public async Task ConnectAndReceiveMessage_InMemoryServer() app.MapMcp(); await app.StartAsync(TestContext.Current.CancellationToken); - await using var mcpClient = await ConnectMcpClient(); + await using var mcpClient = await ConnectMcpClientAsync(); // Send a test message through POST endpoint await mcpClient.SendNotificationAsync("test/message", new Envelope { Message = "Hello, SSE!" }, serializerOptions: JsonContext.Default.Options, cancellationToken: TestContext.Current.CancellationToken); @@ -52,7 +52,7 @@ public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventU MapAbsoluteEndpointUriMcp(app); await app.StartAsync(TestContext.Current.CancellationToken); - await using var mcpClient = await ConnectMcpClient(); + await using var mcpClient = await ConnectMcpClientAsync(); // Send a test message through POST endpoint await mcpClient.SendNotificationAsync("test/message", new Envelope { Message = "Hello, SSE!" }, serializerOptions: JsonContext.Default.Options, cancellationToken: TestContext.Current.CancellationToken); @@ -84,7 +84,7 @@ public async Task ConnectAndReceiveNotification_InMemoryServer() app.MapMcp(); await app.StartAsync(TestContext.Current.CancellationToken); - await using var mcpClient = await ConnectMcpClient(); + await using var mcpClient = await ConnectMcpClientAsync(); mcpClient.RegisterNotificationHandler("test/notification", (args, ca) => { @@ -124,7 +124,7 @@ public async Task AddMcpServer_CanBeCalled_MultipleTimes() app.MapMcp(); await app.StartAsync(TestContext.Current.CancellationToken); - await using var mcpClient = await ConnectMcpClient(); + await using var mcpClient = await ConnectMcpClientAsync(); // Options can be lazily initialized, but they must be instantiated by the time an MCP client can finish connecting. // Callbacks can be called multiple times if configureOptionsAsync is configured, because that uses the IOptionsFactory, @@ -184,14 +184,14 @@ public async Task AdditionalHeaders_AreSent_InGetAndPostRequests() var sseOptions = new SseClientTransportOptions() { Endpoint = new Uri("http://localhost/sse"), - Name = "In-memory Test Server", + Name = "In-memory SSE Client", AdditionalHeaders = new() { ["Authorize"] = "Bearer testToken" }, }; - await using var mcpClient = await ConnectMcpClient(transportOptions: sseOptions); + await using var mcpClient = await ConnectMcpClientAsync(transportOptions: sseOptions); Assert.True(wasGetRequest); Assert.True(wasPostRequest); @@ -211,14 +211,14 @@ public async Task EmptyAdditionalHeadersKey_Throws_InvalidOperationException() var sseOptions = new SseClientTransportOptions() { Endpoint = new Uri("http://localhost/sse"), - Name = "In-memory Test Server", + Name = "In-memory SSE Client", AdditionalHeaders = new() { [""] = "" }, }; - var ex = await Assert.ThrowsAsync(() => ConnectMcpClient(transportOptions: sseOptions)); + var ex = await Assert.ThrowsAsync(() => ConnectMcpClientAsync(transportOptions: sseOptions)); Assert.Equal("Failed to add header '' with value '' from AdditionalHeaders.", ex.Message); } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs index ee1834a67..7fbea2112 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs @@ -11,7 +11,7 @@ public class SseServerIntegrationTests(SseServerIntegrationTestFixture fixture, protected override SseClientTransportOptions ClientTransportOptions => new() { Endpoint = new Uri("http://localhost/sse"), - Name = "TestServer", + Name = "In-memory SSE Client", }; [Fact] diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs index 03ceacd71..21b9cac6e 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs @@ -10,7 +10,7 @@ public class StatelessServerIntegrationTests(SseServerIntegrationTestFixture fix protected override SseClientTransportOptions ClientTransportOptions => new() { Endpoint = new Uri("http://localhost/stateless"), - Name = "TestServer", + Name = "In-memory Streamable HTTP Client", UseStreamableHttp = true, }; } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs index 06bd35f96..5b398b9de 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs @@ -1,31 +1,57 @@ using Microsoft.AspNetCore.Builder; using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Server; +using System.Diagnostics; using System.Net; namespace ModelContextProtocol.AspNetCore.Tests; +[McpServerToolType] public class StatelessServerTests(ITestOutputHelper outputHelper) : KestrelInMemoryTest(outputHelper), IAsyncDisposable { private WebApplication? _app; + private readonly SseClientTransportOptions DefaultTransportOptions = new() + { + Endpoint = new Uri("http://localhost/"), + Name = "In-memory Streamable HTTP Client", + UseStreamableHttp = true, + }; + private async Task StartAsync() { Builder.Services.AddMcpServer(mcpServerOptions => + { + mcpServerOptions.ServerInfo = new Implementation + { + Name = nameof(StreamableHttpServerConformanceTests), + Version = "73", + }; + }) + .WithHttpTransport(httpServerTransportOptions => + { + httpServerTransportOptions.Stateless = true; + }) + .WithTools(); + + Builder.Services.AddScoped(); + + _app = Builder.Build(); + + _app.Use(next => { - mcpServerOptions.ServerInfo = new Implementation + return context => { - Name = nameof(StreamableHttpServerConformanceTests), - Version = "73", + context.RequestServices.GetRequiredService().State = "From request middleware!"; + return next(context); }; - }).WithHttpTransport(httpServerTransportOptions => - { - httpServerTransportOptions.Stateless = true; }); - _app = Builder.Build(); - _app.MapMcp(); await _app.StartAsync(TestContext.Current.CancellationToken); @@ -34,6 +60,11 @@ private async Task StartAsync() HttpClient.DefaultRequestHeaders.Accept.Add(new("text/event-stream")); } + private Task ConnectMcpClientAsync(McpClientOptions? clientOptions = null) + => McpClientFactory.CreateAsync( + new SseClientTransport(DefaultTransportOptions, HttpClient, LoggerFactory), + clientOptions, LoggerFactory, TestContext.Current.CancellationToken); + public async ValueTask DisposeAsync() { if (_app is not null) @@ -66,4 +97,122 @@ public async Task EnablingStatelessMode_Disables_GetAndDeleteEndpoints() using var deleteResponse = await HttpClient.DeleteAsync("/", TestContext.Current.CancellationToken); Assert.Equal(HttpStatusCode.MethodNotAllowed, deleteResponse.StatusCode); } + + [Fact] + public async Task SamplingRequest_Fails_WithInvalidOperationException() + { + await StartAsync(); + + var mcpClientOptions = new McpClientOptions(); + mcpClientOptions.Capabilities = new(); + mcpClientOptions.Capabilities.Sampling ??= new(); + mcpClientOptions.Capabilities.Sampling.SamplingHandler = (_, _, _) => + { + throw new UnreachableException(); + }; + + await using var client = await ConnectMcpClientAsync(mcpClientOptions); + + var toolResponse = await client.CallToolAsync("testSamplingErrors", cancellationToken: TestContext.Current.CancellationToken); + var toolContent = Assert.Single(toolResponse.Content); + Assert.Equal("Server to client requests are not supported in stateless mode.", toolContent.Text); + } + + [Fact] + public async Task RootsRequest_Fails_WithInvalidOperationException() + { + await StartAsync(); + + var mcpClientOptions = new McpClientOptions(); + mcpClientOptions.Capabilities = new(); + mcpClientOptions.Capabilities.Roots ??= new(); + mcpClientOptions.Capabilities.Roots.RootsHandler = (_, _) => + { + throw new UnreachableException(); + }; + + await using var client = await ConnectMcpClientAsync(mcpClientOptions); + + var toolResponse = await client.CallToolAsync("testRootsErrors", cancellationToken: TestContext.Current.CancellationToken); + var toolContent = Assert.Single(toolResponse.Content); + Assert.Equal("Server to client requests are not supported in stateless mode.", toolContent.Text); + } + + [Fact] + public async Task UnsolicitedNotification_Fails_WithInvalidOperationException() + { + InvalidOperationException? unsolicitedNotificationException = null; + + Builder.Services.AddMcpServer() + .WithHttpTransport(options => + { + options.RunSessionHandler = async (context, server, cancellationToken) => + { + unsolicitedNotificationException = await Assert.ThrowsAsync( + () => server.SendNotificationAsync(NotificationMethods.PromptListChangedNotification, TestContext.Current.CancellationToken)); + + await server.RunAsync(cancellationToken); + }; + }); + + await StartAsync(); + + await using var client = await ConnectMcpClientAsync(); + + Assert.NotNull(unsolicitedNotificationException); + Assert.Equal("Unsolicited server to client messages are not supported in stateless mode.", unsolicitedNotificationException.Message); + } + + [Fact] + public async Task ScopedServices_Resolve_FromRequestScope() + { + await StartAsync(); + + await using var client = await ConnectMcpClientAsync(); + + var toolResponse = await client.CallToolAsync("testScope", cancellationToken: TestContext.Current.CancellationToken); + var toolContent = Assert.Single(toolResponse.Content); + Assert.Equal("From request middleware!", toolContent.Text); + } + + [McpServerTool(Name = "testSamplingErrors")] + public static async Task TestSamplingErrors(IMcpServer server) + { + const string expectedSamplingErrorMessage = "Sampling is not supported in stateless mode."; + + // Even when the client has sampling support, it should not be advertised in stateless mode. + Assert.Null(server.ClientCapabilities); + + var asSamplingChatClientEx = Assert.Throws(() => server.AsSamplingChatClient()); + Assert.Equal(expectedSamplingErrorMessage, asSamplingChatClientEx.Message); + + var requestSamplingEx = await Assert.ThrowsAsync(() => server.RequestSamplingAsync([])); + Assert.Equal(expectedSamplingErrorMessage, requestSamplingEx.Message); + + var ex = await Assert.ThrowsAsync(() => server.SendRequestAsync(new JsonRpcRequest { Method = RequestMethods.SamplingCreateMessage })); + return ex.Message; + } + + [McpServerTool(Name = "testRootsErrors")] + public static async Task TestRootsErrors(IMcpServer server) + { + const string expectedRootsErrorMessage = "Roots are not supported in stateless mode."; + + // Even when the client has roots support, it should not be advertised in stateless mode. + Assert.Null(server.ClientCapabilities); + + var requestRootsEx = Assert.Throws(() => server.RequestRootsAsync(new())); + Assert.Equal(expectedRootsErrorMessage, requestRootsEx.Message); + + var ex = await Assert.ThrowsAsync(() => server.SendRequestAsync(new JsonRpcRequest { Method = RequestMethods.RootsList })); + return ex.Message; + } + + [McpServerTool(Name = "testScope")] + public static string? TestScope(ScopedService scopedService) => scopedService.State; + + public class ScopedService + { + public string? State { get; set; } + } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs index 5f126227c..2d07577b8 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs @@ -31,7 +31,7 @@ private async Task StartAsync() Services = _app.Services, }); - _app.MapPost("/mcp", async (JsonRpcMessage message) => + _app.MapPost("/mcp", (JsonRpcMessage message) => { if (message is not JsonRpcRequest request) { diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs index 3abb1aa31..bec382599 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs @@ -14,7 +14,7 @@ public class StreamableHttpServerIntegrationTests(SseServerIntegrationTestFixtur protected override SseClientTransportOptions ClientTransportOptions => new() { Endpoint = new Uri("http://localhost/"), - Name = "TestServer", + Name = "In-memory Streamable HTTP Client", UseStreamableHttp = true, }; diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index 5a530478c..bfd85687e 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -3,6 +3,7 @@ using ModelContextProtocol.Server; using ModelContextProtocol.Utils.Json; using Serilog; +using System.Diagnostics; using System.Text; using System.Text.Json; @@ -16,12 +17,11 @@ private static void ConfigureSerilog(ILoggingBuilder loggingBuilder) { Log.Logger = new LoggerConfiguration() .MinimumLevel.Verbose() // Capture all log levels - .WriteTo.File(Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "logs", "TestServer_.log"), + .WriteTo.File(Path.Combine(AppContext.BaseDirectory, "logs", "TestServer_.log"), rollingInterval: RollingInterval.Day, outputTemplate: "{Timestamp:yyyy-MM-dd HH:mm:ss.fff zzz} [{Level:u3}] {Message:lj}{NewLine}{Exception}") .CreateLogger(); - var logsPath = Path.Combine(AppContext.BaseDirectory, "testserver.log"); loggingBuilder.AddSerilog(); } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 662563520..446154189 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -581,6 +581,8 @@ private sealed class TestServerForIChatClient(bool supportsSampling) : IMcpServe supportsSampling ? new ClientCapabilities { Sampling = new SamplingCapability() } : null; + public McpServerOptions ServerOptions => new(); + public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) { CreateMessageRequestParams? rp = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.DefaultOptions); @@ -617,7 +619,6 @@ public Task SendRequestAsync(JsonRpcRequest request, Cancellati public ValueTask DisposeAsync() => default; public Implementation? ClientInfo => throw new NotImplementedException(); - public McpServerOptions ServerOptions => throw new NotImplementedException(); public IServiceProvider? Services => throw new NotImplementedException(); public LoggingLevel? LoggingLevel => throw new NotImplementedException(); public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) => From 876bdbb4ffceff4dc1ce7a6acb560f5a31126249 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Wed, 14 May 2025 04:35:08 -0700 Subject: [PATCH 4/5] Run Can_UseIHttpContextAccessor_InTool in stateless mode --- src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs | 4 ++-- .../StatelessSessionId.cs | 2 +- src/ModelContextProtocol.AspNetCore/StatelessUserId.cs | 3 +++ .../StreamableHttpHandler.cs | 4 ++-- .../ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs | 9 ++++++--- 5 files changed, 14 insertions(+), 8 deletions(-) create mode 100644 src/ModelContextProtocol.AspNetCore/StatelessUserId.cs diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs index 0903dda6e..174e0148e 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs @@ -7,7 +7,7 @@ namespace ModelContextProtocol.AspNetCore; internal sealed class HttpMcpSession( string sessionId, TTransport transport, - (string Type, string Value, string Issuer)? userIdClaim, + StatelessUserId? userId, TimeProvider timeProvider) : IAsyncDisposable where TTransport : ITransport { @@ -17,7 +17,7 @@ internal sealed class HttpMcpSession( public string Id { get; } = sessionId; public TTransport Transport { get; } = transport; - public (string Type, string Value, string Issuer)? UserIdClaim { get; } = userIdClaim; + public StatelessUserId? UserIdClaim { get; } = userId; public CancellationToken SessionClosed => _disposeCts.Token; diff --git a/src/ModelContextProtocol.AspNetCore/StatelessSessionId.cs b/src/ModelContextProtocol.AspNetCore/StatelessSessionId.cs index 7d8284c6d..c383b8be0 100644 --- a/src/ModelContextProtocol.AspNetCore/StatelessSessionId.cs +++ b/src/ModelContextProtocol.AspNetCore/StatelessSessionId.cs @@ -9,5 +9,5 @@ internal sealed class StatelessSessionId public Implementation? ClientInfo { get; init; } [JsonPropertyName("userIdClaim")] - public (string Type, string Value, string Issuer)? UserIdClaim { get; init; } + public StatelessUserId? UserIdClaim { get; init; } } diff --git a/src/ModelContextProtocol.AspNetCore/StatelessUserId.cs b/src/ModelContextProtocol.AspNetCore/StatelessUserId.cs new file mode 100644 index 000000000..a39d7dde5 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/StatelessUserId.cs @@ -0,0 +1,3 @@ +namespace ModelContextProtocol.AspNetCore; + +internal sealed record StatelessUserId(string Type, string value, string Issuer); diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index 068288544..584ba8529 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -314,7 +314,7 @@ internal static Task RunSessionAsync(HttpContext httpContext, IMcpServer session // SignalR only checks for ClaimTypes.NameIdentifier in HttpConnectionDispatcher, but AspNetCore.Antiforgery checks that plus the sub and UPN claims. // However, we short-circuit unlike antiforgery since we expect to call this to verify MCP messages a lot more frequently than // verifying antiforgery tokens from posts. - internal static (string Type, string Value, string Issuer)? GetUserIdClaim(ClaimsPrincipal user) + internal static StatelessUserId? GetUserIdClaim(ClaimsPrincipal user) { if (user?.Identity?.IsAuthenticated != true) { @@ -325,7 +325,7 @@ internal static (string Type, string Value, string Issuer)? GetUserIdClaim(Claim if (claim is { } idClaim) { - return (idClaim.Type, idClaim.Value, idClaim.Issuer); + return new(idClaim.Type, idClaim.Value, idClaim.Issuer); } return null; diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs index 89bbcb025..36ee4e8dc 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs @@ -46,10 +46,13 @@ public async Task MapMcp_ThrowsInvalidOperationException_IfWithHttpTransportIsNo [Fact] public async Task Can_UseIHttpContextAccessor_InTool() { - Assert.SkipWhen(UseStreamableHttp, "IHttpContextAccessor is not currently supported with Streamable HTTP." + - "TODO: Support it in stateless mode by manually capturing and flowing execution context."); + Assert.SkipWhen(UseStreamableHttp && !Stateless, + """ + IHttpContextAccessor is not currently supported with non-stateless Streamable HTTP. + TODO: Support it in stateless mode by manually capturing and flowing execution context. + """); - Builder.Services.AddMcpServer().WithHttpTransport().WithTools(); + Builder.Services.AddMcpServer().WithHttpTransport(ConfigureStateless).WithTools(); Builder.Services.AddHttpContextAccessor(); From 68e5eb8f6dee4846b14d3a2c353a2bc026465830 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Wed, 14 May 2025 04:40:11 -0700 Subject: [PATCH 5/5] Add internal ModelContextProtocol.AspNetCore.Stateless namespace --- src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs | 7 ++++--- .../{ => Stateless}/StatelessSessionId.cs | 4 ++-- .../{ => Stateless}/StatelessSessionIdJsonContext.cs | 2 +- .../Stateless/UserIdClaim.cs | 3 +++ src/ModelContextProtocol.AspNetCore/StatelessUserId.cs | 3 --- .../StreamableHttpHandler.cs | 3 ++- 6 files changed, 12 insertions(+), 10 deletions(-) rename src/ModelContextProtocol.AspNetCore/{ => Stateless}/StatelessSessionId.cs (71%) rename src/ModelContextProtocol.AspNetCore/{ => Stateless}/StatelessSessionIdJsonContext.cs (76%) create mode 100644 src/ModelContextProtocol.AspNetCore/Stateless/UserIdClaim.cs delete mode 100644 src/ModelContextProtocol.AspNetCore/StatelessUserId.cs diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs index 174e0148e..836dcc50b 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs @@ -1,4 +1,5 @@ -using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.AspNetCore.Stateless; +using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Server; using System.Security.Claims; @@ -7,7 +8,7 @@ namespace ModelContextProtocol.AspNetCore; internal sealed class HttpMcpSession( string sessionId, TTransport transport, - StatelessUserId? userId, + UserIdClaim? userId, TimeProvider timeProvider) : IAsyncDisposable where TTransport : ITransport { @@ -17,7 +18,7 @@ internal sealed class HttpMcpSession( public string Id { get; } = sessionId; public TTransport Transport { get; } = transport; - public StatelessUserId? UserIdClaim { get; } = userId; + public UserIdClaim? UserIdClaim { get; } = userId; public CancellationToken SessionClosed => _disposeCts.Token; diff --git a/src/ModelContextProtocol.AspNetCore/StatelessSessionId.cs b/src/ModelContextProtocol.AspNetCore/Stateless/StatelessSessionId.cs similarity index 71% rename from src/ModelContextProtocol.AspNetCore/StatelessSessionId.cs rename to src/ModelContextProtocol.AspNetCore/Stateless/StatelessSessionId.cs index c383b8be0..09eec87e6 100644 --- a/src/ModelContextProtocol.AspNetCore/StatelessSessionId.cs +++ b/src/ModelContextProtocol.AspNetCore/Stateless/StatelessSessionId.cs @@ -1,7 +1,7 @@ using ModelContextProtocol.Protocol.Types; using System.Text.Json.Serialization; -namespace ModelContextProtocol.AspNetCore; +namespace ModelContextProtocol.AspNetCore.Stateless; internal sealed class StatelessSessionId { @@ -9,5 +9,5 @@ internal sealed class StatelessSessionId public Implementation? ClientInfo { get; init; } [JsonPropertyName("userIdClaim")] - public StatelessUserId? UserIdClaim { get; init; } + public UserIdClaim? UserIdClaim { get; init; } } diff --git a/src/ModelContextProtocol.AspNetCore/StatelessSessionIdJsonContext.cs b/src/ModelContextProtocol.AspNetCore/Stateless/StatelessSessionIdJsonContext.cs similarity index 76% rename from src/ModelContextProtocol.AspNetCore/StatelessSessionIdJsonContext.cs rename to src/ModelContextProtocol.AspNetCore/Stateless/StatelessSessionIdJsonContext.cs index 2690a3b15..6963ed609 100644 --- a/src/ModelContextProtocol.AspNetCore/StatelessSessionIdJsonContext.cs +++ b/src/ModelContextProtocol.AspNetCore/Stateless/StatelessSessionIdJsonContext.cs @@ -1,6 +1,6 @@ using System.Text.Json.Serialization; -namespace ModelContextProtocol.AspNetCore; +namespace ModelContextProtocol.AspNetCore.Stateless; [JsonSerializable(typeof(StatelessSessionId))] internal sealed partial class StatelessSessionIdJsonContext : JsonSerializerContext; diff --git a/src/ModelContextProtocol.AspNetCore/Stateless/UserIdClaim.cs b/src/ModelContextProtocol.AspNetCore/Stateless/UserIdClaim.cs new file mode 100644 index 000000000..f18c1c5ff --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/Stateless/UserIdClaim.cs @@ -0,0 +1,3 @@ +namespace ModelContextProtocol.AspNetCore.Stateless; + +internal sealed record UserIdClaim(string Type, string Value, string Issuer); diff --git a/src/ModelContextProtocol.AspNetCore/StatelessUserId.cs b/src/ModelContextProtocol.AspNetCore/StatelessUserId.cs deleted file mode 100644 index a39d7dde5..000000000 --- a/src/ModelContextProtocol.AspNetCore/StatelessUserId.cs +++ /dev/null @@ -1,3 +0,0 @@ -namespace ModelContextProtocol.AspNetCore; - -internal sealed record StatelessUserId(string Type, string value, string Issuer); diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index 584ba8529..86ea6fbfd 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -5,6 +5,7 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Microsoft.Net.Http.Headers; +using ModelContextProtocol.AspNetCore.Stateless; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; @@ -314,7 +315,7 @@ internal static Task RunSessionAsync(HttpContext httpContext, IMcpServer session // SignalR only checks for ClaimTypes.NameIdentifier in HttpConnectionDispatcher, but AspNetCore.Antiforgery checks that plus the sub and UPN claims. // However, we short-circuit unlike antiforgery since we expect to call this to verify MCP messages a lot more frequently than // verifying antiforgery tokens from posts. - internal static StatelessUserId? GetUserIdClaim(ClaimsPrincipal user) + internal static UserIdClaim? GetUserIdClaim(ClaimsPrincipal user) { if (user?.Identity?.IsAuthenticated != true) {