diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index 04820dde3..64b10d6de 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -23,9 +23,9 @@ internal sealed class StreamableHttpHandler( ILoggerFactory loggerFactory, IServiceProvider applicationServices) { - private static JsonTypeInfo s_errorTypeInfo = GetRequiredJsonTypeInfo(); - private static MediaTypeHeaderValue ApplicationJsonMediaType = new("application/json"); - private static MediaTypeHeaderValue TextEventStreamMediaType = new("text/event-stream"); + private static readonly JsonTypeInfo s_errorTypeInfo = GetRequiredJsonTypeInfo(); + private static readonly MediaTypeHeaderValue s_applicationJsonMediaType = new("application/json"); + private static readonly MediaTypeHeaderValue s_textEventStreamMediaType = new("text/event-stream"); public ConcurrentDictionary> Sessions { get; } = new(StringComparer.Ordinal); @@ -36,7 +36,7 @@ public async Task HandlePostRequestAsync(HttpContext context) // so we have to do this manually. The spec doesn't mandate that servers MUST reject these requests, // but it's probably good to at least start out trying to be strict. var acceptHeaders = context.Request.GetTypedHeaders().Accept; - if (!acceptHeaders.Contains(ApplicationJsonMediaType) || !acceptHeaders.Contains(TextEventStreamMediaType)) + if (!acceptHeaders.Contains(s_applicationJsonMediaType) || !acceptHeaders.Contains(s_textEventStreamMediaType)) { await WriteJsonRpcErrorAsync(context, "Not Acceptable: Client must accept both application/json and text/event-stream", @@ -64,7 +64,7 @@ await WriteJsonRpcErrorAsync(context, public async Task HandleGetRequestAsync(HttpContext context) { var acceptHeaders = context.Request.GetTypedHeaders().Accept; - if (!acceptHeaders.Contains(TextEventStreamMediaType)) + if (!acceptHeaders.Contains(s_textEventStreamMediaType)) { await WriteJsonRpcErrorAsync(context, "Not Acceptable: Client must accept text/event-stream", diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs index 5d952f8a6..c083764a3 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs @@ -87,40 +87,17 @@ public override async Task SendMessageAsync( messageId = messageWithId.Id.ToString(); } - var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint) + using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint) { Content = content, }; - CopyAdditionalHeaders(httpRequestMessage.Headers); + StreamableHttpClientSessionTransport.CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders); var response = await _httpClient.SendAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false); response.EnsureSuccessStatusCode(); var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); - // Check if the message was an initialize request - if (message is JsonRpcRequest request && request.Method == RequestMethods.Initialize) - { - // If the response is not a JSON-RPC response, it is an SSE message - if (string.IsNullOrEmpty(responseContent) || responseContent.Equals("accepted", StringComparison.OrdinalIgnoreCase)) - { - LogAcceptedPost(Name, messageId); - // The response will arrive as an SSE message - } - else - { - JsonRpcResponse initializeResponse = JsonSerializer.Deserialize(responseContent, McpJsonUtilities.JsonContext.Default.JsonRpcResponse) ?? - throw new InvalidOperationException("Failed to initialize client"); - - LogTransportReceivedMessage(Name, messageId); - await WriteMessageAsync(initializeResponse, cancellationToken).ConfigureAwait(false); - LogTransportMessageWritten(Name, messageId); - } - - return; - } - - // Otherwise, check if the response was accepted (the response will come as an SSE message) if (string.IsNullOrEmpty(responseContent) || responseContent.Equals("accepted", StringComparison.OrdinalIgnoreCase)) { LogAcceptedPost(Name, messageId); @@ -177,17 +154,13 @@ public override async ValueTask DisposeAsync() } } - internal Uri? MessageEndpoint => _messageEndpoint; - - internal SseClientTransportOptions Options => _options; - private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) { try { using var request = new HttpRequestMessage(HttpMethod.Get, _sseEndpoint); request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); - CopyAdditionalHeaders(request.Headers); + StreamableHttpClientSessionTransport.CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders); using var response = await _httpClient.SendAsync( request, @@ -251,15 +224,7 @@ private async Task ProcessSseMessage(string data, CancellationToken cancellation return; } - string messageId = "(no id)"; - if (message is JsonRpcMessageWithId messageWithId) - { - messageId = messageWithId.Id.ToString(); - } - - LogTransportReceivedMessage(Name, messageId); await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); - LogTransportMessageWritten(Name, messageId); } catch (JsonException ex) { @@ -290,20 +255,6 @@ private void HandleEndpointEvent(string data) _connectionEstablished.TrySetResult(true); } - private void CopyAdditionalHeaders(HttpRequestHeaders headers) - { - if (_options.AdditionalHeaders is not null) - { - foreach (var header in _options.AdditionalHeaders) - { - if (!headers.TryAddWithoutValidation(header.Key, header.Value)) - { - throw new InvalidOperationException($"Failed to add header '{header.Key}' with value '{header.Value}' from {nameof(SseClientTransportOptions.AdditionalHeaders)}."); - } - } - } - } - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} accepted SSE transport POST for message ID '{MessageId}'.")] private partial void LogAcceptedPost(string endpointName, string messageId); diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs index 832d67275..1b2865572 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs @@ -57,6 +57,11 @@ public SseClientTransport(SseClientTransportOptions transportOptions, HttpClient /// public async Task ConnectAsync(CancellationToken cancellationToken = default) { + if (_options.UseStreamableHttp) + { + return new StreamableHttpClientSessionTransport(_options, _httpClient, _loggerFactory, Name); + } + var sessionTransport = new SseClientSessionTransport(_options, _httpClient, _loggerFactory, Name); try diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs index 0a36a15f9..b83204ae5 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs @@ -30,13 +30,20 @@ public required Uri Endpoint } } + /// + /// Gets or sets a value indicating whether to use "Streamable HTTP" for the transport rather than "HTTP with SSE". Defaults to false. + /// Streamable HTTP transport specification. + /// HTTP with SSE transport specification. + /// + public bool UseStreamableHttp { get; init; } + /// /// Gets a transport identifier used for logging purposes. /// public string? Name { get; init; } /// - /// Gets or sets a timeout used to establish the initial connection to the SSE server. + /// Gets or sets a timeout used to establish the initial connection to the SSE server. Defaults to 30 seconds. /// /// /// This timeout controls how long the client waits for: diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs index 6fcdf0a8b..1abc05b57 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs @@ -146,15 +146,7 @@ private async Task ProcessMessageAsync(string line, CancellationToken cancellati var message = (JsonRpcMessage?)JsonSerializer.Deserialize(line.AsSpan().Trim(), McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))); if (message != null) { - string messageId = "(no id)"; - if (message is JsonRpcMessageWithId messageWithId) - { - messageId = messageWithId.Id.ToString(); - } - - LogTransportReceivedMessage(Name, messageId); await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); - LogTransportMessageWritten(Name, messageId); } else { diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs index acf18984a..fba41782f 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs @@ -111,15 +111,7 @@ private async Task ReadMessagesAsync() { if (JsonSerializer.Deserialize(line, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))) is JsonRpcMessage message) { - string messageId = "(no id)"; - if (message is JsonRpcMessageWithId messageWithId) - { - messageId = messageWithId.Id.ToString(); - } - - LogTransportReceivedMessage(Name, messageId); await WriteMessageAsync(message, shutdownToken).ConfigureAwait(false); - LogTransportMessageWritten(Name, messageId); } else { diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpClientSessionTransport.cs new file mode 100644 index 000000000..7697c28e0 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpClientSessionTransport.cs @@ -0,0 +1,241 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Utils; +using ModelContextProtocol.Utils.Json; +using System.Net.Http.Headers; +using System.Net.ServerSentEvents; +using System.Text.Json; + +#if NET +using System.Net.Http.Json; +#else +using System.Text; +#endif + +namespace ModelContextProtocol.Protocol.Transport; + +/// +/// The Streamable HTTP client transport implementation +/// +internal sealed partial class StreamableHttpClientSessionTransport : TransportBase +{ + private static readonly MediaTypeWithQualityHeaderValue s_applicationJsonMediaType = new("application/json"); + private static readonly MediaTypeWithQualityHeaderValue s_textEventStreamMediaType = new("text/event-stream"); + + private readonly HttpClient _httpClient; + private readonly SseClientTransportOptions _options; + private readonly CancellationTokenSource _connectionCts; + private readonly ILogger _logger; + + private string? _mcpSessionId; + private Task? _getReceiveTask; + + public StreamableHttpClientSessionTransport(SseClientTransportOptions transportOptions, HttpClient httpClient, ILoggerFactory? loggerFactory, string endpointName) + : base(endpointName, loggerFactory) + { + Throw.IfNull(transportOptions); + Throw.IfNull(httpClient); + + _options = transportOptions; + _httpClient = httpClient; + _connectionCts = new CancellationTokenSource(); + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; + + // We connect with the initialization request with the MCP transport. This means that any errors won't be observed + // until the first call to SendMessageAsync. Fortunately, that happens internally in McpClientFactory.ConnectAsync + // so we still throw any connection-related Exceptions from there and never expose a pre-connected client to the user. + SetConnected(true); + } + + /// + public override async Task SendMessageAsync( + JsonRpcMessage message, + CancellationToken cancellationToken = default) + { + using var sendCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _connectionCts.Token); + cancellationToken = sendCts.Token; + +#if NET + using var content = JsonContent.Create(message, McpJsonUtilities.DefaultOptions.GetTypeInfo()); +#else + using var content = new StringContent( + JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), + Encoding.UTF8, + "application/json" + ); +#endif + + using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _options.Endpoint) + { + Content = content, + Headers = + { + Accept = { s_applicationJsonMediaType, s_textEventStreamMediaType }, + }, + }; + + CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, _mcpSessionId); + using var response = await _httpClient.SendAsync(httpRequestMessage, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false); + + response.EnsureSuccessStatusCode(); + + var rpcRequest = message as JsonRpcRequest; + JsonRpcMessage? rpcResponseCandidate = null; + + if (response.Content.Headers.ContentType?.MediaType == "application/json") + { + var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); + rpcResponseCandidate = await ProcessMessageAsync(responseContent, cancellationToken).ConfigureAwait(false); + } + else if (response.Content.Headers.ContentType?.MediaType == "text/event-stream") + { + using var responseBodyStream = await response.Content.ReadAsStreamAsync(cancellationToken); + rpcResponseCandidate = await ProcessSseResponseAsync(responseBodyStream, rpcRequest, cancellationToken).ConfigureAwait(false); + } + + if (rpcRequest is null) + { + return; + } + + if (rpcResponseCandidate is not JsonRpcMessageWithId messageWithId || messageWithId.Id != rpcRequest.Id) + { + throw new McpException($"Streamable HTTP POST response completed without a reply to request with ID: {rpcRequest.Id}"); + } + + if (rpcRequest.Method == RequestMethods.Initialize && rpcResponseCandidate is JsonRpcResponse) + { + // We've successfully initialized! Copy session-id and start GET request if any. + if (response.Headers.TryGetValues("mcp-session-id", out var sessionIdValues)) + { + _mcpSessionId = sessionIdValues.FirstOrDefault(); + } + + _getReceiveTask = ReceiveUnsolicitedMessagesAsync(); + } + } + + public override async ValueTask DisposeAsync() + { + try + { + await _connectionCts.CancelAsync().ConfigureAwait(false); + + try + { + if (_getReceiveTask != null) + { + await _getReceiveTask.ConfigureAwait(false); + } + } + catch (OperationCanceledException) + { + } + finally + { + _connectionCts.Dispose(); + } + } + finally + { + SetConnected(false); + } + } + + private async Task ReceiveUnsolicitedMessagesAsync() + { + // Send a GET request to handle any unsolicited messages not sent over a POST response. + using var request = new HttpRequestMessage(HttpMethod.Get, _options.Endpoint); + request.Headers.Accept.Add(s_textEventStreamMediaType); + CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, _mcpSessionId); + + using var response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, _connectionCts.Token).ConfigureAwait(false); + + if (!response.IsSuccessStatusCode) + { + // Server support for the GET request is optional. If it fails, we don't care. It just means we won't receive unsolicited messages. + return; + } + + using var responseStream = await response.Content.ReadAsStreamAsync(_connectionCts.Token).ConfigureAwait(false); + await ProcessSseResponseAsync(responseStream, relatedRpcRequest: null, _connectionCts.Token).ConfigureAwait(false); + } + + private async Task ProcessSseResponseAsync(Stream responseStream, JsonRpcRequest? relatedRpcRequest, CancellationToken cancellationToken) + { + await foreach (SseItem sseEvent in SseParser.Create(responseStream).EnumerateAsync(cancellationToken).ConfigureAwait(false)) + { + if (sseEvent.EventType != "message") + { + continue; + } + + var message = await ProcessMessageAsync(sseEvent.Data, cancellationToken).ConfigureAwait(false); + + // The server SHOULD end the response here anyway, but we won't leave it to chance. This transport makes + // a GET request for any notifications that might need to be sent after the completion of each POST. + if (message is JsonRpcMessageWithId messageWithId && relatedRpcRequest?.Id == messageWithId.Id) + { + return messageWithId; + } + } + + return null; + } + + private async Task ProcessMessageAsync(string data, CancellationToken cancellationToken) + { + try + { + var message = JsonSerializer.Deserialize(data, McpJsonUtilities.JsonContext.Default.JsonRpcMessage); + if (message is null) + { + LogTransportMessageParseUnexpectedTypeSensitive(Name, data); + return null; + } + + await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); + return message; + } + catch (JsonException ex) + { + LogJsonException(ex, data); + } + + return null; + } + + private void LogJsonException(JsonException ex, string data) + { + if (_logger.IsEnabled(LogLevel.Trace)) + { + LogTransportMessageParseFailedSensitive(Name, data, ex); + } + else + { + LogTransportMessageParseFailed(Name, ex); + } + } + + internal static void CopyAdditionalHeaders(HttpRequestHeaders headers, Dictionary? additionalHeaders, string? sessionId = null) + { + if (sessionId is not null) + { + headers.Add("mcp-session-id", sessionId); + } + + if (additionalHeaders is null) + { + return; + } + + foreach (var header in additionalHeaders) + { + if (!headers.TryAddWithoutValidation(header.Key, header.Value)) + { + throw new InvalidOperationException($"Failed to add header '{header.Key}' with value '{header.Value}' from {nameof(SseClientTransportOptions.AdditionalHeaders)}."); + } + } + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpPostTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpPostTransport.cs index 2bd8f2784..29d062677 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamableHttpPostTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamableHttpPostTransport.cs @@ -61,7 +61,7 @@ public async ValueTask DisposeAsync() { yield return message; - if (message.Data is JsonRpcResponse response) + if (message.Data is JsonRpcMessageWithId response) { if (_pendingRequests.Remove(response.Id) && _pendingRequests.Count == 0) { diff --git a/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs b/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs index af9cdaefd..1af8d91fb 100644 --- a/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs +++ b/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs @@ -37,7 +37,7 @@ protected TransportBase(string name, ILoggerFactory? loggerFactory) _messageChannel = Channel.CreateUnbounded(new UnboundedChannelOptions { SingleReader = true, - SingleWriter = true, + SingleWriter = false, }); } @@ -76,6 +76,12 @@ protected async Task WriteMessageAsync(JsonRpcMessage message, CancellationToken throw new InvalidOperationException("Transport is not connected"); } + if (_logger.IsEnabled(LogLevel.Debug)) + { + var messageId = (message as JsonRpcMessageWithId)?.Id.ToString() ?? "(no id)"; + LogTransportReceivedMessage(Name, messageId); + } + await _messageChannel.Writer.WriteAsync(message, cancellationToken).ConfigureAwait(false); } @@ -115,9 +121,6 @@ protected void SetConnected(bool isConnected) [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} transport received message with ID '{MessageId}'.")] private protected partial void LogTransportReceivedMessage(string endpointName, string messageId); - [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} transport sent message with ID '{MessageId}'.")] - private protected partial void LogTransportMessageWritten(string endpointName, string messageId); - [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} transport received unexpected message. Message: '{Message}'.")] private protected partial void LogTransportMessageParseUnexpectedTypeSensitive(string endpointName, string message); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs new file mode 100644 index 000000000..57a6c6ad9 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs @@ -0,0 +1,273 @@ +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Tests.Utils; + +namespace ModelContextProtocol.AspNetCore.Tests; + +public abstract class HttpServerIntegrationTests : LoggedTest, IClassFixture +{ + protected readonly SseServerIntegrationTestFixture _fixture; + + public HttpServerIntegrationTests(SseServerIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper) + : base(testOutputHelper) + { + _fixture = fixture; + _fixture.Initialize(testOutputHelper, ClientTransportOptions); + } + + public override void Dispose() + { + _fixture.TestCompleted(); + base.Dispose(); + } + + protected abstract SseClientTransportOptions ClientTransportOptions { get; } + + private Task GetClientAsync(McpClientOptions? options = null) + { + return _fixture.ConnectMcpClientAsync(options, LoggerFactory); + } + + [Fact] + public async Task ConnectAndPing_Sse_TestServer() + { + // Arrange + + // Act + await using var client = await GetClientAsync(); + await client.PingAsync(TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(client); + } + + [Fact] + public async Task Connect_TestServer_ShouldProvideServerFields() + { + // Arrange + + // Act + await using var client = await GetClientAsync(); + + // Assert + Assert.NotNull(client.ServerCapabilities); + Assert.NotNull(client.ServerInfo); + } + + [Fact] + public async Task ListTools_Sse_TestServer() + { + // arrange + + // act + await using var client = await GetClientAsync(); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + // assert + Assert.NotNull(tools); + } + + [Fact] + public async Task CallTool_Sse_EchoServer() + { + // arrange + + // act + await using var client = await GetClientAsync(); + var result = await client.CallToolAsync( + "echo", + new Dictionary + { + ["message"] = "Hello MCP!" + }, + cancellationToken: TestContext.Current.CancellationToken + ); + + // assert + Assert.NotNull(result); + Assert.False(result.IsError); + var textContent = Assert.Single(result.Content, c => c.Type == "text"); + Assert.Equal("Echo: Hello MCP!", textContent.Text); + } + + [Fact] + public async Task ListResources_Sse_TestServer() + { + // arrange + + // act + await using var client = await GetClientAsync(); + + IList allResources = await client.ListResourcesAsync(TestContext.Current.CancellationToken); + + // The everything server provides 100 test resources + Assert.Equal(100, allResources.Count); + } + + [Fact] + public async Task ReadResource_Sse_TextResource() + { + // arrange + + // act + await using var client = await GetClientAsync(); + // Odd numbered resources are text in the everything server (despite the docs saying otherwise) + // 1 is index 0, which is "even" in the 0-based index + // We copied this oddity to the test server + var result = await client.ReadResourceAsync("test://static/resource/1", TestContext.Current.CancellationToken); + + Assert.NotNull(result); + Assert.Single(result.Contents); + + TextResourceContents textContent = Assert.IsType(result.Contents[0]); + Assert.NotNull(textContent.Text); + } + + [Fact] + public async Task ReadResource_Sse_BinaryResource() + { + // arrange + + // act + await using var client = await GetClientAsync(); + // Even numbered resources are binary in the everything server (despite the docs saying otherwise) + // 2 is index 1, which is "odd" in the 0-based index + // We copied this oddity to the test server + var result = await client.ReadResourceAsync("test://static/resource/2", TestContext.Current.CancellationToken); + + Assert.NotNull(result); + Assert.Single(result.Contents); + + BlobResourceContents blobContent = Assert.IsType(result.Contents[0]); + Assert.NotNull(blobContent.Blob); + } + + [Fact] + public async Task ListPrompts_Sse_TestServer() + { + // arrange + + // act + await using var client = await GetClientAsync(); + var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); + + // assert + Assert.NotNull(prompts); + Assert.NotEmpty(prompts); + // We could add specific assertions for the known prompts + Assert.Contains(prompts, p => p.Name == "simple_prompt"); + Assert.Contains(prompts, p => p.Name == "complex_prompt"); + } + + [Fact] + public async Task GetPrompt_Sse_SimplePrompt() + { + // arrange + + // act + await using var client = await GetClientAsync(); + var result = await client.GetPromptAsync("simple_prompt", null, cancellationToken: TestContext.Current.CancellationToken); + + // assert + Assert.NotNull(result); + Assert.NotEmpty(result.Messages); + } + + [Fact] + public async Task GetPrompt_Sse_ComplexPrompt() + { + // arrange + + // act + await using var client = await GetClientAsync(); + var arguments = new Dictionary + { + { "temperature", "0.7" }, + { "style", "formal" } + }; + var result = await client.GetPromptAsync("complex_prompt", arguments, cancellationToken: TestContext.Current.CancellationToken); + + // assert + Assert.NotNull(result); + Assert.NotEmpty(result.Messages); + } + + [Fact] + public async Task GetPrompt_Sse_NonExistent_ThrowsException() + { + // arrange + + // act + await using var client = await GetClientAsync(); + await Assert.ThrowsAsync(() => + client.GetPromptAsync("non_existent_prompt", null, cancellationToken: TestContext.Current.CancellationToken)); + } + + [Fact] + public async Task Sampling_Sse_TestServer() + { + // arrange + // Set up the sampling handler + int samplingHandlerCalls = 0; +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + McpClientOptions options = new(); + options.Capabilities = new(); + options.Capabilities.Sampling ??= new(); + options.Capabilities.Sampling.SamplingHandler = async (_, _, _) => + { + samplingHandlerCalls++; + return new CreateMessageResult + { + Model = "test-model", + Role = Role.Assistant, + Content = new Content + { + Type = "text", + Text = "Test response" + } + }; + }; + await using var client = await GetClientAsync(options); +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously + + // Call the server's sampleLLM tool which should trigger our sampling handler + var result = await client.CallToolAsync("sampleLLM", new Dictionary + { + ["prompt"] = "Test prompt", + ["maxTokens"] = 100 + }, + cancellationToken: TestContext.Current.CancellationToken); + + // assert + Assert.NotNull(result); + var textContent = Assert.Single(result.Content); + Assert.Equal("text", textContent.Type); + Assert.False(string.IsNullOrEmpty(textContent.Text)); + } + + [Fact] + public async Task CallTool_Sse_EchoServer_Concurrently() + { + await using var client1 = await GetClientAsync(); + await using var client2 = await GetClientAsync(); + + for (int i = 0; i < 4; i++) + { + var client = (i % 2 == 0) ? client1 : client2; + var result = await client.CallToolAsync( + "echo", + new Dictionary + { + ["message"] = $"Hello MCP! {i}" + }, + cancellationToken: TestContext.Current.CancellationToken + ); + + Assert.NotNull(result); + Assert.False(result.IsError); + var textContent = Assert.Single(result.Content, c => c.Type == "text"); + Assert.Equal($"Echo: Hello MCP! {i}", textContent.Text); + } + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs new file mode 100644 index 000000000..d385623a2 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs @@ -0,0 +1,68 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Client; + +namespace ModelContextProtocol.AspNetCore.Tests; + +public class MapMcpSseTests(ITestOutputHelper outputHelper) : MapMcpTests(outputHelper) +{ + protected override bool UseStreamableHttp => false; + + [Theory] + [InlineData("/a", "/a/sse")] + [InlineData("/a/", "/a/sse")] + [InlineData("/a/b", "/a/b/sse")] + public async Task CanConnect_WithMcpClient_AfterCustomizingRoute(string routePattern, string requestPath) + { + Builder.Services.AddMcpServer(options => + { + options.ServerInfo = new() + { + Name = "TestCustomRouteServer", + Version = "1.0.0", + }; + }).WithHttpTransport(); + await using var app = Builder.Build(); + + app.MapMcp(routePattern); + + await app.StartAsync(TestContext.Current.CancellationToken); + + var mcpClient = await ConnectAsync(requestPath); + + Assert.Equal("TestCustomRouteServer", mcpClient.ServerInfo.Name); + } + + [Fact] + public async Task Can_UseHttpContextAccessor_InTool() + { + Builder.Services.AddMcpServer().WithHttpTransport().WithTools(); + + Builder.Services.AddHttpContextAccessor(); + + await using var app = Builder.Build(); + + app.Use(next => + { + return async context => + { + context.User = CreateUser("TestUser"); + await next(context); + }; + }); + + app.MapMcp(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + var mcpClient = await ConnectAsync(); + + var response = await mcpClient.CallToolAsync( + "EchoWithUserName", + new Dictionary() { ["message"] = "Hello world!" }, + cancellationToken: TestContext.Current.CancellationToken); + + var content = Assert.Single(response.Content); + Assert.Equal("TestUser: Hello world!", content.Text); + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs new file mode 100644 index 000000000..30632a8e6 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs @@ -0,0 +1,36 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; + +namespace ModelContextProtocol.AspNetCore.Tests; + +public class MapMcpStreamableHttpTests(ITestOutputHelper outputHelper) : MapMcpTests(outputHelper) +{ + protected override bool UseStreamableHttp => true; + + [Theory] + [InlineData("/a", "/a")] + [InlineData("/a", "/a/")] + [InlineData("/a/", "/a/")] + [InlineData("/a/", "/a")] + [InlineData("/a/b", "/a/b")] + public async Task CanConnect_WithMcpClient_AfterCustomizingRoute(string routePattern, string requestPath) + { + Builder.Services.AddMcpServer(options => + { + options.ServerInfo = new() + { + Name = "TestCustomRouteServer", + Version = "1.0.0", + }; + }).WithHttpTransport(); + await using var app = Builder.Build(); + + app.MapMcp(routePattern); + + await app.StartAsync(TestContext.Current.CancellationToken); + + var mcpClient = await ConnectAsync(requestPath); + + Assert.Equal("TestCustomRouteServer", mcpClient.ServerInfo.Name); + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs index f11a0a51a..70b028e22 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs @@ -11,13 +11,18 @@ namespace ModelContextProtocol.AspNetCore.Tests; -public class MapMcpTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper) +public abstract class MapMcpTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper) { - private async Task ConnectAsync(string? path = "/sse") + protected abstract bool UseStreamableHttp { get; } + + protected async Task ConnectAsync(string? path = null) { + path ??= UseStreamableHttp ? "/" : "/sse"; + var sseClientTransportOptions = new SseClientTransportOptions() { Endpoint = new Uri($"http://localhost{path}"), + UseStreamableHttp = UseStreamableHttp, }; await using var transport = new SseClientTransport(sseClientTransportOptions, HttpClient, LoggerFactory); return await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); @@ -56,65 +61,6 @@ public async Task Allows_Customizing_Route(string pattern) Assert.Equal($"data: {pattern}/message", dataLine[..dataLine.IndexOf('?')]); } - [Theory] - [InlineData("/a", "/a/sse")] - [InlineData("/a/", "/a/sse")] - [InlineData("/a/b", "/a/b/sse")] - public async Task CanConnect_WithMcpClient_AfterCustomizingRoute(string routePattern, string requestPath) - { - Builder.Services.AddMcpServer(options => - { - options.ServerInfo = new() - { - Name = "TestCustomRouteServer", - Version = "1.0.0", - }; - }).WithHttpTransport(); - await using var app = Builder.Build(); - - app.MapMcp(routePattern); - - await app.StartAsync(TestContext.Current.CancellationToken); - - var mcpClient = await ConnectAsync(requestPath); - - Assert.Equal("TestCustomRouteServer", mcpClient.ServerInfo.Name); - } - - [Fact] - public async Task Can_UseHttpContextAccessor_InTool() - { - Builder.Services.AddMcpServer().WithHttpTransport().WithTools(); - - Builder.Services.AddHttpContextAccessor(); - - await using var app = Builder.Build(); - - app.Use(next => - { - return async context => - { - context.User = CreateUser("TestUser"); - await next(context); - }; - }); - - app.MapMcp(); - - await app.StartAsync(TestContext.Current.CancellationToken); - - var mcpClient = await ConnectAsync(); - - var response = await mcpClient.CallToolAsync( - "EchoWithUserName", - new Dictionary() { ["message"] = "Hello world!" }, - cancellationToken: TestContext.Current.CancellationToken); - - var content = Assert.Single(response.Content); - Assert.Equal("TestUser: Hello world!", content.Text); - } - - [Fact] public async Task Messages_FromNewUser_AreRejected() { @@ -144,13 +90,13 @@ public async Task Messages_FromNewUser_AreRejected() Assert.Equal(HttpStatusCode.Forbidden, httpRequestException.StatusCode); } - private ClaimsPrincipal CreateUser(string name) + protected ClaimsPrincipal CreateUser(string name) => new ClaimsPrincipal(new ClaimsIdentity( [new Claim("name", name), new Claim(ClaimTypes.NameIdentifier, name)], "TestAuthType", "name", "role")); [McpServerToolType] - private class EchoHttpContextUserTools(IHttpContextAccessor contextAccessor) + protected class EchoHttpContextUserTools(IHttpContextAccessor contextAccessor) { [McpServerTool, Description("Echoes the input back to the client with their user name.")] public string EchoWithUserName(string message) diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs index 41b8d8fa7..5553163f5 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs @@ -18,10 +18,9 @@ public class SseServerIntegrationTestFixture : IAsyncDisposable // multiple tests, so this dispatches the output to the current test. private readonly DelegatingTestOutputHelper _delegatingTestOutputHelper = new(); - private SseClientTransportOptions DefaultTransportOptions { get; } = new() + private SseClientTransportOptions DefaultTransportOptions { get; set; } = new() { - Endpoint = new Uri("http://localhost/sse"), - Name = "TestServer", + Endpoint = new("http://localhost/"), }; public SseServerIntegrationTestFixture() @@ -37,8 +36,9 @@ public SseServerIntegrationTestFixture() HttpClient = new HttpClient(socketsHttpHandler) { - BaseAddress = DefaultTransportOptions.Endpoint, + BaseAddress = new("http://localhost/"), }; + _serverTask = Program.MainAsync([], new XunitLoggerProvider(_delegatingTestOutputHelper), _inMemoryTransport, _stopCts.Token); } @@ -53,9 +53,10 @@ public Task ConnectMcpClientAsync(McpClientOptions? options, ILogger TestContext.Current.CancellationToken); } - public void Initialize(ITestOutputHelper output) + public void Initialize(ITestOutputHelper output, SseClientTransportOptions clientTransportOptions) { _delegatingTestOutputHelper.CurrentTestOutputHelper = output; + DefaultTransportOptions = clientTransportOptions; } public void TestCompleted() diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs index 10a6316a9..ee1834a67 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs @@ -1,279 +1,23 @@ -using ModelContextProtocol.Client; -using ModelContextProtocol.Protocol.Types; -using ModelContextProtocol.Tests.Utils; +using ModelContextProtocol.Protocol.Transport; using System.Net; using System.Text; namespace ModelContextProtocol.AspNetCore.Tests; -public class SseServerIntegrationTests : LoggedTest, IClassFixture -{ - private readonly SseServerIntegrationTestFixture _fixture; - - public SseServerIntegrationTests(SseServerIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper) - : base(testOutputHelper) - { - _fixture = fixture; - _fixture.Initialize(testOutputHelper); - } - - public override void Dispose() - { - _fixture.TestCompleted(); - base.Dispose(); - } - - private Task GetClientAsync(McpClientOptions? options = null) - { - return _fixture.ConnectMcpClientAsync(options, LoggerFactory); - } - - [Fact] - public async Task ConnectAndPing_Sse_TestServer() - { - // Arrange - - // Act - await using var client = await GetClientAsync(); - await client.PingAsync(TestContext.Current.CancellationToken); - - // Assert - Assert.NotNull(client); - } - - [Fact] - public async Task Connect_TestServer_ShouldProvideServerFields() - { - // Arrange - - // Act - await using var client = await GetClientAsync(); - - // Assert - Assert.NotNull(client.ServerCapabilities); - Assert.NotNull(client.ServerInfo); - } - - [Fact] - public async Task ListTools_Sse_TestServer() - { - // arrange - - // act - await using var client = await GetClientAsync(); - var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); - - // assert - Assert.NotNull(tools); - } - - [Fact] - public async Task CallTool_Sse_EchoServer() - { - // arrange - - // act - await using var client = await GetClientAsync(); - var result = await client.CallToolAsync( - "echo", - new Dictionary - { - ["message"] = "Hello MCP!" - }, - cancellationToken: TestContext.Current.CancellationToken - ); - - // assert - Assert.NotNull(result); - Assert.False(result.IsError); - var textContent = Assert.Single(result.Content, c => c.Type == "text"); - Assert.Equal("Echo: Hello MCP!", textContent.Text); - } - - [Fact] - public async Task ListResources_Sse_TestServer() - { - // arrange - - // act - await using var client = await GetClientAsync(); - - IList allResources = await client.ListResourcesAsync(TestContext.Current.CancellationToken); - - // The everything server provides 100 test resources - Assert.Equal(100, allResources.Count); - } - - [Fact] - public async Task ReadResource_Sse_TextResource() - { - // arrange - - // act - await using var client = await GetClientAsync(); - // Odd numbered resources are text in the everything server (despite the docs saying otherwise) - // 1 is index 0, which is "even" in the 0-based index - // We copied this oddity to the test server - var result = await client.ReadResourceAsync("test://static/resource/1", TestContext.Current.CancellationToken); - - Assert.NotNull(result); - Assert.Single(result.Contents); - - TextResourceContents textContent = Assert.IsType(result.Contents[0]); - Assert.NotNull(textContent.Text); - } - - [Fact] - public async Task ReadResource_Sse_BinaryResource() - { - // arrange - - // act - await using var client = await GetClientAsync(); - // Even numbered resources are binary in the everything server (despite the docs saying otherwise) - // 2 is index 1, which is "odd" in the 0-based index - // We copied this oddity to the test server - var result = await client.ReadResourceAsync("test://static/resource/2", TestContext.Current.CancellationToken); - - Assert.NotNull(result); - Assert.Single(result.Contents); - - BlobResourceContents blobContent = Assert.IsType(result.Contents[0]); - Assert.NotNull(blobContent.Blob); - } - - [Fact] - public async Task ListPrompts_Sse_TestServer() - { - // arrange - - // act - await using var client = await GetClientAsync(); - var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); - - // assert - Assert.NotNull(prompts); - Assert.NotEmpty(prompts); - // We could add specific assertions for the known prompts - Assert.Contains(prompts, p => p.Name == "simple_prompt"); - Assert.Contains(prompts, p => p.Name == "complex_prompt"); - } - - [Fact] - public async Task GetPrompt_Sse_SimplePrompt() - { - // arrange - - // act - await using var client = await GetClientAsync(); - var result = await client.GetPromptAsync("simple_prompt", null, cancellationToken: TestContext.Current.CancellationToken); - - // assert - Assert.NotNull(result); - Assert.NotEmpty(result.Messages); - } - - [Fact] - public async Task GetPrompt_Sse_ComplexPrompt() - { - // arrange +public class SseServerIntegrationTests(SseServerIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper) + : HttpServerIntegrationTests(fixture, testOutputHelper) - // act - await using var client = await GetClientAsync(); - var arguments = new Dictionary - { - { "temperature", "0.7" }, - { "style", "formal" } - }; - var result = await client.GetPromptAsync("complex_prompt", arguments, cancellationToken: TestContext.Current.CancellationToken); - - // assert - Assert.NotNull(result); - Assert.NotEmpty(result.Messages); - } - - [Fact] - public async Task GetPrompt_Sse_NonExistent_ThrowsException() - { - // arrange - - // act - await using var client = await GetClientAsync(); - await Assert.ThrowsAsync(() => - client.GetPromptAsync("non_existent_prompt", null, cancellationToken: TestContext.Current.CancellationToken)); - } - - [Fact] - public async Task Sampling_Sse_TestServer() - { - // arrange - // Set up the sampling handler - int samplingHandlerCalls = 0; -#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously - McpClientOptions options = new(); - options.Capabilities = new(); - options.Capabilities.Sampling ??= new(); - options.Capabilities.Sampling.SamplingHandler = async (_, _, _) => - { - samplingHandlerCalls++; - return new CreateMessageResult - { - Model = "test-model", - Role = Role.Assistant, - Content = new Content - { - Type = "text", - Text = "Test response" - } - }; - }; - await using var client = await GetClientAsync(options); -#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously - - // Call the server's sampleLLM tool which should trigger our sampling handler - var result = await client.CallToolAsync("sampleLLM", new Dictionary - { - ["prompt"] = "Test prompt", - ["maxTokens"] = 100 - }, - cancellationToken: TestContext.Current.CancellationToken); - - // assert - Assert.NotNull(result); - var textContent = Assert.Single(result.Content); - Assert.Equal("text", textContent.Type); - Assert.False(string.IsNullOrEmpty(textContent.Text)); - } - - [Fact] - public async Task CallTool_Sse_EchoServer_Concurrently() +{ + protected override SseClientTransportOptions ClientTransportOptions => new() { - await using var client1 = await GetClientAsync(); - await using var client2 = await GetClientAsync(); - - for (int i = 0; i < 4; i++) - { - var client = (i % 2 == 0) ? client1 : client2; - var result = await client.CallToolAsync( - "echo", - new Dictionary - { - ["message"] = $"Hello MCP! {i}" - }, - cancellationToken: TestContext.Current.CancellationToken - ); - - Assert.NotNull(result); - Assert.False(result.IsError); - var textContent = Assert.Single(result.Content, c => c.Type == "text"); - Assert.Equal($"Echo: Hello MCP! {i}", textContent.Text); - } - } + Endpoint = new Uri("http://localhost/sse"), + Name = "TestServer", + }; [Fact] public async Task EventSourceResponse_Includes_ExpectedHeaders() { - using var sseResponse = await _fixture.HttpClient.GetAsync("", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); + using var sseResponse = await _fixture.HttpClient.GetAsync("/sse", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); sseResponse.EnsureSuccessStatusCode(); @@ -289,7 +33,7 @@ public async Task EventSourceStream_Includes_MessageEventType() { // Simulate our own MCP client handshake using a plain HttpClient so we can look for "event: message" // in the raw SSE response stream which is not exposed by the real MCP client. - await using var sseResponse = await _fixture.HttpClient.GetStreamAsync("", TestContext.Current.CancellationToken); + await using var sseResponse = await _fixture.HttpClient.GetStreamAsync("/sse", TestContext.Current.CancellationToken); using var streamReader = new StreamReader(sseResponse); var endpointEvent = await streamReader.ReadLineAsync(TestContext.Current.CancellationToken); @@ -305,7 +49,7 @@ public async Task EventSourceStream_Includes_MessageEventType() """; using (var initializeRequestBody = new StringContent(initializeRequest, Encoding.UTF8, "application/json")) { - var response = await _fixture.HttpClient.PostAsync(messageEndpoint, initializeRequestBody, TestContext.Current.CancellationToken); + using var response = await _fixture.HttpClient.PostAsync(messageEndpoint, initializeRequestBody, TestContext.Current.CancellationToken); Assert.Equal(HttpStatusCode.Accepted, response.StatusCode); } @@ -314,7 +58,7 @@ public async Task EventSourceStream_Includes_MessageEventType() """; using (var initializedNotificationBody = new StringContent(initializedNotification, Encoding.UTF8, "application/json")) { - var response = await _fixture.HttpClient.PostAsync(messageEndpoint, initializedNotificationBody, TestContext.Current.CancellationToken); + using var response = await _fixture.HttpClient.PostAsync(messageEndpoint, initializedNotificationBody, TestContext.Current.CancellationToken); Assert.Equal(HttpStatusCode.Accepted, response.StatusCode); } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs new file mode 100644 index 000000000..5f126227c --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs @@ -0,0 +1,167 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Json; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Server; +using ModelContextProtocol.Utils.Json; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; + +namespace ModelContextProtocol.AspNetCore.Tests; + +public class StreamableHttpClientConformanceTests(ITestOutputHelper outputHelper) : KestrelInMemoryTest(outputHelper), IAsyncDisposable +{ + private WebApplication? _app; + + private async Task StartAsync() + { + Builder.Services.Configure(options => + { + options.SerializerOptions.TypeInfoResolverChain.Add(McpJsonUtilities.DefaultOptions.TypeInfoResolver!); + }); + _app = Builder.Build(); + + var echoTool = McpServerTool.Create(Echo, new() + { + Services = _app.Services, + }); + + _app.MapPost("/mcp", async (JsonRpcMessage message) => + { + if (message is not JsonRpcRequest request) + { + // Ignore all non-request notifications. + return Results.Accepted(); + } + + if (request.Method == "initialize") + { + return Results.Json(new JsonRpcResponse + { + Id = request.Id, + Result = JsonSerializer.SerializeToNode(new InitializeResult + { + ProtocolVersion = "2024-11-05", + Capabilities = new() + { + Tools = new(), + }, + ServerInfo = new Implementation + { + Name = "my-mcp", + Version = "0.0.1", + }, + }, McpJsonUtilities.DefaultOptions) + }); + } + + if (request.Method == "tools/list") + { + return Results.Json(new JsonRpcResponse + { + Id = request.Id, + Result = JsonSerializer.SerializeToNode(new ListToolsResult + { + Tools = [echoTool.ProtocolTool] + }, McpJsonUtilities.DefaultOptions), + }); + } + + if (request.Method == "tools/call") + { + var parameters = JsonSerializer.Deserialize(request.Params, GetJsonTypeInfo()); + Assert.NotNull(parameters?.Arguments); + + return Results.Json(new JsonRpcResponse + { + Id = request.Id, + Result = JsonSerializer.SerializeToNode(new CallToolResponse() + { + Content = [new() { Text = parameters.Arguments["message"].ToString() }], + }, McpJsonUtilities.DefaultOptions), + }); + } + + throw new Exception("Unexpected message!"); + }); + + await _app.StartAsync(TestContext.Current.CancellationToken); + } + + [Fact] + public async Task CanCallToolOnSessionlessStreamableHttpServer() + { + await StartAsync(); + + await using var transport = new SseClientTransport(new() + { + Endpoint = new("http://localhost/mcp"), + UseStreamableHttp = true, + }, HttpClient, LoggerFactory); + + await using var client = await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + var echoTool = Assert.Single(tools); + Assert.Equal("echo", echoTool.Name); + await CallEchoAndValidateAsync(echoTool); + } + + + [Fact] + public async Task CanCallToolConcurrently() + { + await StartAsync(); + + await using var transport = new SseClientTransport(new() + { + Endpoint = new("http://localhost/mcp"), + UseStreamableHttp = true, + }, HttpClient, LoggerFactory); + + await using var client = await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + var echoTool = Assert.Single(tools); + Assert.Equal("echo", echoTool.Name); + + var echoTasks = new Task[100]; + for (int i = 0; i < echoTasks.Length; i++) + { + echoTasks[i] = CallEchoAndValidateAsync(echoTool); + } + + await Task.WhenAll(echoTasks); + } + + private static async Task CallEchoAndValidateAsync(McpClientTool echoTool) + { + var response = await echoTool.CallAsync(new Dictionary() { ["message"] = "Hello world!" }, cancellationToken: TestContext.Current.CancellationToken); + Assert.NotNull(response); + var content = Assert.Single(response.Content); + Assert.Equal("text", content.Type); + Assert.Equal("Hello world!", content.Text); + } + + public async ValueTask DisposeAsync() + { + if (_app is not null) + { + await _app.DisposeAsync(); + } + base.Dispose(); + } + + private static JsonTypeInfo GetJsonTypeInfo() => (JsonTypeInfo)McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(T)); + + [McpServerTool(Name = "echo")] + private static string Echo(string message) + { + return message; + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs similarity index 98% rename from tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpTests.cs rename to tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs index fa3f8fe07..56e25936d 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs @@ -18,7 +18,7 @@ namespace ModelContextProtocol.AspNetCore.Tests; -public class StreamableHttpTests(ITestOutputHelper outputHelper) : KestrelInMemoryTest(outputHelper), IAsyncDisposable +public class StreamableHttpServerConformanceTests(ITestOutputHelper outputHelper) : KestrelInMemoryTest(outputHelper), IAsyncDisposable { private static McpServerTool[] Tools { get; } = [ McpServerTool.Create(EchoAsync), @@ -35,7 +35,7 @@ private async Task StartAsync() { options.ServerInfo = new Implementation { - Name = nameof(StreamableHttpTests), + Name = nameof(StreamableHttpServerConformanceTests), Version = "73", }; }).WithTools(Tools).WithHttpTransport(); @@ -563,7 +563,7 @@ private string CallToolWithProgressToken(string toolName, string arguments = "{} private static InitializeResult AssertServerInfo(JsonRpcResponse rpcResponse) { var initializeResult = AssertType(rpcResponse.Result); - Assert.Equal(nameof(StreamableHttpTests), initializeResult.ServerInfo.Name); + Assert.Equal(nameof(StreamableHttpServerConformanceTests), initializeResult.ServerInfo.Name); Assert.Equal("73", initializeResult.ServerInfo.Version); return initializeResult; } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs new file mode 100644 index 000000000..9d3048929 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs @@ -0,0 +1,64 @@ +using ModelContextProtocol.Protocol.Transport; +using System.Net; +using System.Text; + +namespace ModelContextProtocol.AspNetCore.Tests; + +public class StreamableHttpServerIntegrationTests(SseServerIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper) + : HttpServerIntegrationTests(fixture, testOutputHelper) + +{ + private const string InitializeRequest = """ + {"jsonrpc":"2.0","id":"1","method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"IntegrationTestClient","version":"1.0.0"}}} + """; + + protected override SseClientTransportOptions ClientTransportOptions => new() + { + Endpoint = new Uri("http://localhost/"), + Name = "TestServer", + UseStreamableHttp = true, + }; + + [Fact] + public async Task EventSourceResponse_Includes_ExpectedHeaders() + { + using var initializeRequestBody = new StringContent(InitializeRequest, Encoding.UTF8, "application/json"); + using var postRequest = new HttpRequestMessage(HttpMethod.Post, "/") + { + Headers = + { + Accept = { new("application/json"), new("text/event-stream")} + }, + Content = initializeRequestBody, + }; + using var sseResponse = await _fixture.HttpClient.SendAsync(postRequest, TestContext.Current.CancellationToken); + + sseResponse.EnsureSuccessStatusCode(); + + Assert.Equal("text/event-stream", sseResponse.Content.Headers.ContentType?.MediaType); + Assert.Equal("identity", sseResponse.Content.Headers.ContentEncoding.ToString()); + Assert.NotNull(sseResponse.Headers.CacheControl); + Assert.True(sseResponse.Headers.CacheControl.NoStore); + Assert.True(sseResponse.Headers.CacheControl.NoCache); + } + + [Fact] + public async Task EventSourceStream_Includes_MessageEventType() + { + using var initializeRequestBody = new StringContent(InitializeRequest, Encoding.UTF8, "application/json"); + using var postRequest = new HttpRequestMessage(HttpMethod.Post, "/") + { + Headers = + { + Accept = { new("application/json"), new("text/event-stream")} + }, + Content = initializeRequestBody, + }; + using var sseResponse = await _fixture.HttpClient.SendAsync(postRequest, TestContext.Current.CancellationToken); + using var sseResponseStream = await sseResponse.Content.ReadAsStreamAsync(TestContext.Current.CancellationToken); + using var streamReader = new StreamReader(sseResponseStream); + + var messageEvent = await streamReader.ReadLineAsync(TestContext.Current.CancellationToken); + Assert.Equal("event: message", messageEvent); + } +} diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs index 1f74a9565..baf22f3d6 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs @@ -120,52 +120,6 @@ public async Task SendMessageAsync_Handles_Accepted_Response() Assert.True(true); } - [Fact] - public async Task SendMessageAsync_Handles_Accepted_Json_RPC_Response() - { - using var mockHttpHandler = new MockHttpHandler(); - using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(_transportOptions, httpClient, LoggerFactory); - - var eventSourcePipe = new Pipe(); - var eventSourceData = "event: endpoint\r\ndata: /sseendpoint\r\n\r\n"u8; - eventSourceData.CopyTo(eventSourcePipe.Writer.GetSpan(eventSourceData.Length)); - eventSourcePipe.Writer.Advance(eventSourceData.Length); - await eventSourcePipe.Writer.FlushAsync(TestContext.Current.CancellationToken); - - var firstCall = true; - mockHttpHandler.RequestHandler = (request) => - { - if (request.Method == HttpMethod.Post && request.RequestUri?.AbsoluteUri == "http://localhost:8080/sseendpoint") - { - return Task.FromResult(new HttpResponseMessage - { - StatusCode = HttpStatusCode.OK, - Content = new StringContent("{\"jsonrpc\":\"2.0\", \"id\": \"44\", \"result\": null}") - }); - } - else - { - if (!firstCall) - throw new IOException("Abort"); - else - firstCall = false; - - return Task.FromResult(new HttpResponseMessage - { - StatusCode = HttpStatusCode.OK, - Content = new StreamContent(eventSourcePipe.Reader.AsStream()), - }); - } - }; - - await using var session = await transport.ConnectAsync(TestContext.Current.CancellationToken); - - await session.SendMessageAsync(new JsonRpcRequest() { Method = RequestMethods.Initialize, Id = new RequestId(44) }, CancellationToken.None); - Assert.True(true); - eventSourcePipe.Writer.Complete(); - } - [Fact] public async Task ReceiveMessagesAsync_Handles_Messages() {