From 3e4308f1ef5e0bafb2a2c06edc3e6b538a89bb4a Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Sun, 30 Mar 2025 22:28:00 -0400 Subject: [PATCH 1/2] Improve progress reporting - Enable injecting an `IProgress<>` into a tool call to report progress when the client has supplied a progress token - Add a custom type for ProgressToken and ProgressNotification to map to the schema - Simplify RequestId to match ProgressToken's shape - Serialize StdioClientTransport's SendMessageAsync implementation - Add strongly-typed names for all request methods - Change all request IDs created by McpJsonRpcEndpoint to include a guid for that endpoint --- src/ModelContextProtocol/Client/McpClient.cs | 34 ++--- .../Client/McpClientExtensions.cs | 32 ++--- src/ModelContextProtocol/NopProgress.cs | 12 ++ .../ProgressNotificationValue.cs | 14 ++ .../Protocol/Messages/NotificationMethods.cs | 23 ++++ .../Protocol/Messages/OperationNames.cs | 43 ------- .../Protocol/Messages/ProgressNotification.cs | 112 ++++++++++++++++ .../Protocol/Messages/ProgressToken.cs | 102 +++++++++++++++ .../Protocol/Messages/RequestId.cs | 120 ++++++++++-------- .../Protocol/Messages/RequestIdConverter.cs | 37 ------ .../Protocol/Messages/RequestMethods.cs | 82 ++++++++++++ .../Transport/SseClientSessionTransport.cs | 2 +- .../Transport/StdioClientStreamTransport.cs | 3 + .../Protocol/Types/ListRootsRequestParams.cs | 6 +- .../Protocol/Types/RequestParamsMetadata.cs | 3 +- .../Server/AIFunctionMcpServerTool.cs | 22 ++++ src/ModelContextProtocol/Server/McpServer.cs | 28 ++-- .../Server/McpServerExtensions.cs | 4 +- src/ModelContextProtocol/Shared/McpSession.cs | 12 +- src/ModelContextProtocol/TokenProgress.cs | 31 +++++ .../Utils/Json/McpJsonUtilities.cs | 1 + .../McpServerBuilderExtensionsToolsTests.cs | 68 +++++++++- .../Server/McpServerTests.cs | 42 +++--- .../Transport/SseClientTransportTests.cs | 6 +- .../Transport/StdioServerTransportTests.cs | 10 +- .../Utils/InMemoryTestSseServer.cs | 6 +- .../Utils/TestServerTransport.cs | 4 +- 27 files changed, 631 insertions(+), 228 deletions(-) create mode 100644 src/ModelContextProtocol/NopProgress.cs create mode 100644 src/ModelContextProtocol/ProgressNotificationValue.cs delete mode 100644 src/ModelContextProtocol/Protocol/Messages/OperationNames.cs create mode 100644 src/ModelContextProtocol/Protocol/Messages/ProgressNotification.cs create mode 100644 src/ModelContextProtocol/Protocol/Messages/ProgressToken.cs delete mode 100644 src/ModelContextProtocol/Protocol/Messages/RequestIdConverter.cs create mode 100644 src/ModelContextProtocol/Protocol/Messages/RequestMethods.cs create mode 100644 src/ModelContextProtocol/TokenProgress.cs diff --git a/src/ModelContextProtocol/Client/McpClient.cs b/src/ModelContextProtocol/Client/McpClient.cs index b774389f1..f2e0fa5f5 100644 --- a/src/ModelContextProtocol/Client/McpClient.cs +++ b/src/ModelContextProtocol/Client/McpClient.cs @@ -42,7 +42,7 @@ public McpClient(IClientTransport clientTransport, McpClientOptions options, Mcp } SetRequestHandler( - "sampling/createMessage", + RequestMethods.SamplingCreateMessage, (request, ct) => samplingHandler(request, ct)); } @@ -54,7 +54,7 @@ public McpClient(IClientTransport clientTransport, McpClientOptions options, Mcp } SetRequestHandler( - "roots/list", + RequestMethods.RootsList, (request, ct) => rootsHandler(request, ct)); } } @@ -89,21 +89,21 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); initializationCts.CancelAfter(_options.InitializationTimeout); - try - { - // Send initialize request - var initializeResponse = await SendRequestAsync( - new JsonRpcRequest + try + { + // Send initialize request + var initializeResponse = await SendRequestAsync( + new JsonRpcRequest + { + Method = RequestMethods.Initialize, + Params = new InitializeRequestParams() { - Method = "initialize", - Params = new InitializeRequestParams() - { - ProtocolVersion = _options.ProtocolVersion, - Capabilities = _options.Capabilities ?? new ClientCapabilities(), - ClientInfo = _options.ClientInfo, - } - }, - initializationCts.Token).ConfigureAwait(false); + ProtocolVersion = _options.ProtocolVersion, + Capabilities = _options.Capabilities ?? new ClientCapabilities(), + ClientInfo = _options.ClientInfo + } + }, + initializationCts.Token).ConfigureAwait(false); // Store server information _logger.ServerCapabilitiesReceived(EndpointName, @@ -123,7 +123,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) // Send initialized notification await SendMessageAsync( - new JsonRpcNotification { Method = "notifications/initialized" }, + new JsonRpcNotification { Method = NotificationMethods.InitializedNotification }, initializationCts.Token).ConfigureAwait(false); } catch (OperationCanceledException) when (initializationCts.IsCancellationRequested) diff --git a/src/ModelContextProtocol/Client/McpClientExtensions.cs b/src/ModelContextProtocol/Client/McpClientExtensions.cs index 436adb264..5227285a1 100644 --- a/src/ModelContextProtocol/Client/McpClientExtensions.cs +++ b/src/ModelContextProtocol/Client/McpClientExtensions.cs @@ -41,7 +41,7 @@ public static Task PingAsync(this IMcpClient client, CancellationToken cancellat Throw.IfNull(client); return client.SendRequestAsync( - CreateRequest("ping", null), + CreateRequest(RequestMethods.Ping, null), cancellationToken); } @@ -61,7 +61,7 @@ public static async Task> ListToolsAsync( do { var toolResults = await client.SendRequestAsync( - CreateRequest("tools/list", CreateCursorDictionary(cursor)), + CreateRequest(RequestMethods.ToolsList, CreateCursorDictionary(cursor)), cancellationToken).ConfigureAwait(false); tools ??= new List(toolResults.Tools.Count); @@ -96,7 +96,7 @@ public static async IAsyncEnumerable EnumerateToolsAsync( do { var toolResults = await client.SendRequestAsync( - CreateRequest("tools/list", CreateCursorDictionary(cursor)), + CreateRequest(RequestMethods.ToolsList, CreateCursorDictionary(cursor)), cancellationToken).ConfigureAwait(false); foreach (var tool in toolResults.Tools) @@ -126,7 +126,7 @@ public static async Task> ListPromptsAsync( do { var promptResults = await client.SendRequestAsync( - CreateRequest("prompts/list", CreateCursorDictionary(cursor)), + CreateRequest(RequestMethods.PromptsList, CreateCursorDictionary(cursor)), cancellationToken).ConfigureAwait(false); if (prompts is null) @@ -164,7 +164,7 @@ public static async IAsyncEnumerable EnumeratePromptsAsync( do { var promptResults = await client.SendRequestAsync( - CreateRequest("prompts/list", CreateCursorDictionary(cursor)), + CreateRequest(RequestMethods.PromptsList, CreateCursorDictionary(cursor)), cancellationToken).ConfigureAwait(false); foreach (var prompt in promptResults.Prompts) @@ -192,7 +192,7 @@ public static Task GetPromptAsync( Throw.IfNullOrWhiteSpace(name); return client.SendRequestAsync( - CreateRequest("prompts/get", CreateParametersDictionary(name, arguments)), + CreateRequest(RequestMethods.PromptsGet, CreateParametersDictionary(name, arguments)), cancellationToken); } @@ -213,7 +213,7 @@ public static async Task> ListResourceTemplatesAsync( do { var templateResults = await client.SendRequestAsync( - CreateRequest("resources/templates/list", CreateCursorDictionary(cursor)), + CreateRequest(RequestMethods.ResourcesTemplatesList, CreateCursorDictionary(cursor)), cancellationToken).ConfigureAwait(false); if (templates is null) @@ -251,7 +251,7 @@ public static async IAsyncEnumerable EnumerateResourceTemplate do { var templateResults = await client.SendRequestAsync( - CreateRequest("resources/templates/list", CreateCursorDictionary(cursor)), + CreateRequest(RequestMethods.ResourcesTemplatesList, CreateCursorDictionary(cursor)), cancellationToken).ConfigureAwait(false); foreach (var template in templateResults.ResourceTemplates) @@ -281,7 +281,7 @@ public static async Task> ListResourcesAsync( do { var resourceResults = await client.SendRequestAsync( - CreateRequest("resources/list", CreateCursorDictionary(cursor)), + CreateRequest(RequestMethods.ResourcesList, CreateCursorDictionary(cursor)), cancellationToken).ConfigureAwait(false); if (resources is null) @@ -319,7 +319,7 @@ public static async IAsyncEnumerable EnumerateResourcesAsync( do { var resourceResults = await client.SendRequestAsync( - CreateRequest("resources/list", CreateCursorDictionary(cursor)), + CreateRequest(RequestMethods.ResourcesList, CreateCursorDictionary(cursor)), cancellationToken).ConfigureAwait(false); foreach (var resource in resourceResults.Resources) @@ -345,7 +345,7 @@ public static Task ReadResourceAsync( Throw.IfNullOrWhiteSpace(uri); return client.SendRequestAsync( - CreateRequest("resources/read", new() { ["uri"] = uri }), + CreateRequest(RequestMethods.ResourcesRead, new() { ["uri"] = uri }), cancellationToken); } @@ -369,7 +369,7 @@ public static Task GetCompletionAsync(this IMcpClient client, Re } return client.SendRequestAsync( - CreateRequest("completion/complete", new() + CreateRequest(RequestMethods.CompletionComplete, new() { ["ref"] = reference, ["argument"] = new Argument { Name = argumentName, Value = argumentValue } @@ -389,7 +389,7 @@ public static Task SubscribeToResourceAsync(this IMcpClient client, string uri, Throw.IfNullOrWhiteSpace(uri); return client.SendRequestAsync( - CreateRequest("resources/subscribe", new() { ["uri"] = uri }), + CreateRequest(RequestMethods.ResourcesSubscribe, new() { ["uri"] = uri }), cancellationToken); } @@ -405,7 +405,7 @@ public static Task UnsubscribeFromResourceAsync(this IMcpClient client, string u Throw.IfNullOrWhiteSpace(uri); return client.SendRequestAsync( - CreateRequest("resources/unsubscribe", new() { ["uri"] = uri }), + CreateRequest(RequestMethods.ResourcesUnsubscribe, new() { ["uri"] = uri }), cancellationToken); } @@ -424,7 +424,7 @@ public static Task CallToolAsync( Throw.IfNull(toolName); return client.SendRequestAsync( - CreateRequest("tools/call", CreateParametersDictionary(toolName, arguments)), + CreateRequest(RequestMethods.ToolsCall, CreateParametersDictionary(toolName, arguments)), cancellationToken); } @@ -560,7 +560,7 @@ public static Task SetLoggingLevel(this IMcpClient client, LoggingLevel level, C Throw.IfNull(client); return client.SendRequestAsync( - CreateRequest("logging/setLevel", new() { ["level"] = level }), + CreateRequest(RequestMethods.LoggingSetLevel, new() { ["level"] = level }), cancellationToken); } diff --git a/src/ModelContextProtocol/NopProgress.cs b/src/ModelContextProtocol/NopProgress.cs new file mode 100644 index 000000000..182ab9734 --- /dev/null +++ b/src/ModelContextProtocol/NopProgress.cs @@ -0,0 +1,12 @@ +namespace ModelContextProtocol; + +/// Provides an that's a nop. +internal sealed class NullProgress : IProgress +{ + public static NullProgress Instance { get; } = new(); + + /// + public void Report(ProgressNotificationValue value) + { + } +} diff --git a/src/ModelContextProtocol/ProgressNotificationValue.cs b/src/ModelContextProtocol/ProgressNotificationValue.cs new file mode 100644 index 000000000..7aa8c9e9a --- /dev/null +++ b/src/ModelContextProtocol/ProgressNotificationValue.cs @@ -0,0 +1,14 @@ +namespace ModelContextProtocol; + +/// Provides a progress value that can be sent using . +public record struct ProgressNotificationValue +{ + /// Gets or sets the progress thus far. + public required float Progress { get; init; } + + /// Gets or sets the total number of items to process (or total progress required), if known. + public float? Total { get; init; } + + /// Gets or sets an optional message describing the current progress. + public string? Message { get; init; } +} diff --git a/src/ModelContextProtocol/Protocol/Messages/NotificationMethods.cs b/src/ModelContextProtocol/Protocol/Messages/NotificationMethods.cs index 521d3eb33..8fb50c3bc 100644 --- a/src/ModelContextProtocol/Protocol/Messages/NotificationMethods.cs +++ b/src/ModelContextProtocol/Protocol/Messages/NotificationMethods.cs @@ -34,4 +34,27 @@ public static class NotificationMethods /// Sent by the server when a log message is generated. /// public const string LoggingMessageNotification = "notifications/message"; + + /// + /// Sent from the client to the server after initialization has finished. + /// + public const string InitializedNotification = "notifications/initialized"; + + /// + /// Sent to inform the receiver of a progress update for a long-running request. + /// + public const string ProgressNotification = "notifications/progress"; + + /// + /// Sent by either side to indicate that it is cancelling a previously-issued request. + /// + /// + /// The request SHOULD still be in-flight, but due to communication latency, it is always possible that this notification + /// MAY arrive after the request has already finished. + /// + /// This notification indicates that the result will be unused, so any associated processing SHOULD cease. + /// + /// A client MUST NOT attempt to cancel its `initialize` request.". + /// + public const string CancelledNotification = "notifications/cancelled"; } \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Messages/OperationNames.cs b/src/ModelContextProtocol/Protocol/Messages/OperationNames.cs deleted file mode 100644 index 76f22d438..000000000 --- a/src/ModelContextProtocol/Protocol/Messages/OperationNames.cs +++ /dev/null @@ -1,43 +0,0 @@ -namespace ModelContextProtocol.Protocol.Messages; - -/// Provides names of standard operations for use with registering handlers. -/// -/// These values should not be inspected or relied on for their exact values. -/// They serve only as opaque keys. They will be stable for the lifetime of a process -/// but may change between versions of this library. -/// -public static class OperationNames -{ - /// Gets the name of the sampling operation. - public static string Sampling { get; } = "operation/sampling"; - - /// Gets the name of the roots operation. - public static string Roots { get; } = "operation/roots"; - - /// Gets the name of the list tools operation. - public static string ListTools { get; } = "operation/listTools"; - - /// Gets the name of the call tool operation. - public static string CallTool { get; } = "operation/callTool"; - - /// Gets the name of the list prompts operation. - public static string ListPrompts { get; } = "operation/listPrompts"; - - /// Gets the name of the get prompt operation. - public static string GetPrompt { get; } = "operation/getPrompt"; - - /// Gets the name of the list resources operation. - public static string ListResources { get; } = "operation/listResources"; - - /// Gets the name of the read resource operation. - public static string ReadResource { get; } = "operation/readResource"; - - /// Gets the name of the get completion operation. - public static string GetCompletion { get; } = "operation/getCompletion"; - - /// Gets the name of the subscribe to resources operation. - public static string SubscribeToResources { get; } = "operation/subscribeToResources"; - - /// Gets the name of the subscribe to resources operation. - public static string UnsubscribeFromResources { get; } = "operation/unsubscribeFromResources"; -} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Messages/ProgressNotification.cs b/src/ModelContextProtocol/Protocol/Messages/ProgressNotification.cs new file mode 100644 index 000000000..5b351009a --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Messages/ProgressNotification.cs @@ -0,0 +1,112 @@ +using System.ComponentModel; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol.Messages; + +/// +/// An out-of-band notification used to inform the receiver of a progress update for a long-running request. +/// See the schema for details +/// +[JsonConverter(typeof(Converter))] +public class ProgressNotification +{ + /// + /// The progress token which was given in the initial request, used to associate this notification with the request that is proceeding. + /// + public required ProgressToken ProgressToken { get; init; } + + /// + /// The progress thus far. This should increase every time progress is made, even if the total is unknown. + /// + public required ProgressNotificationValue Progress { get; init; } + + /// Provides a for . + [EditorBrowsable(EditorBrowsableState.Never)] + public sealed class Converter : JsonConverter + { + /// + public override ProgressNotification? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + ProgressToken? progressToken = null; + float? progress = null; + float? total = null; + string? message = null; + + while (reader.Read() && reader.TokenType != JsonTokenType.EndObject) + { + if (reader.TokenType == JsonTokenType.PropertyName) + { + var propertyName = reader.GetString(); + reader.Read(); + switch (propertyName) + { + case "progressToken": + if (JsonSerializer.Deserialize(ref reader, options.GetTypeInfo(typeof(ProgressToken))) is not ProgressToken token) + { + throw new JsonException("Invalid value for 'progressToken'."); + } + progressToken = token; + break; + + case "progress": + progress = reader.GetSingle(); + break; + + case "total": + total = reader.GetSingle(); + break; + + case "message": + message = reader.GetString(); + break; + } + } + } + + if (progress is null) + { + throw new JsonException("Missing required property 'progress'."); + } + + if (progressToken is null) + { + throw new JsonException("Missing required property 'progressToken'."); + } + + return new ProgressNotification + { + ProgressToken = progressToken.GetValueOrDefault(), + Progress = new ProgressNotificationValue() + { + Progress = progress.GetValueOrDefault(), + Total = total, + Message = message + } + }; + } + + /// + public override void Write(Utf8JsonWriter writer, ProgressNotification value, JsonSerializerOptions options) + { + writer.WriteStartObject(); + + writer.WritePropertyName("progressToken"); + JsonSerializer.Serialize(writer, value.ProgressToken, options.GetTypeInfo(typeof(ProgressToken))); + + writer.WriteNumber("progress", value.Progress.Progress); + + if (value.Progress.Total is { } total) + { + writer.WriteNumber("total", total); + } + + if (value.Progress.Message is { } message) + { + writer.WriteString("message", message); + } + + writer.WriteEndObject(); + } + } +} diff --git a/src/ModelContextProtocol/Protocol/Messages/ProgressToken.cs b/src/ModelContextProtocol/Protocol/Messages/ProgressToken.cs new file mode 100644 index 000000000..b3f4a8a93 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Messages/ProgressToken.cs @@ -0,0 +1,102 @@ +using ModelContextProtocol.Utils; +using System.ComponentModel; +using System.Globalization; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol.Messages; + +/// +/// Represents a progress token, which can be either a string or an integer. +/// +[JsonConverter(typeof(Converter))] +public readonly struct ProgressToken : IEquatable +{ + /// The id, either a string or a boxed long or null. + private readonly object? _id; + + /// Initializes a new instance of the with a specified value. + /// The required ID value. + public ProgressToken(string value) + { + Throw.IfNull(value); + _id = value; + } + + /// Initializes a new instance of the with a specified value. + /// The required ID value. + public ProgressToken(long value) + { + // Box the long. Progress tokens are almost always strings in practice, so this should be rare. + _id = value; + } + + /// Gets whether the identifier is uninitialized. + public bool IsDefault => _id is null; + + /// + public override string? ToString() => + _id is string stringValue ? stringValue : + _id is long longValue ? longValue.ToString(CultureInfo.InvariantCulture) : + null; + + /// + /// Compares this ProgressToken to another ProgressToken. + /// + public bool Equals(ProgressToken other) => Equals(_id, other._id); + + /// + public override bool Equals(object? obj) => obj is ProgressToken other && Equals(other); + + /// + public override int GetHashCode() => _id?.GetHashCode() ?? 0; + + /// + /// Compares two ProgressTokens for equality. + /// + public static bool operator ==(ProgressToken left, ProgressToken right) => left.Equals(right); + + /// + /// Compares two ProgressTokens for inequality. + /// + public static bool operator !=(ProgressToken left, ProgressToken right) => !left.Equals(right); + + /// + /// JSON converter for ProgressToken that handles both string and number values. + /// + [EditorBrowsable(EditorBrowsableState.Never)] + public sealed class Converter : JsonConverter + { + /// + public override ProgressToken Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + return reader.TokenType switch + { + JsonTokenType.String => new(reader.GetString()!), + JsonTokenType.Number => new(reader.GetInt64()), + _ => throw new JsonException("progressToken must be a string or an integer"), + }; + } + + /// + public override void Write(Utf8JsonWriter writer, ProgressToken value, JsonSerializerOptions options) + { + Throw.IfNull(writer); + + switch (value._id) + { + case string str: + writer.WriteStringValue(str); + return; + + case long longValue: + writer.WriteNumberValue(longValue); + return; + + case null: + writer.WriteStringValue(string.Empty); + return; + } + } + } +} diff --git a/src/ModelContextProtocol/Protocol/Messages/RequestId.cs b/src/ModelContextProtocol/Protocol/Messages/RequestId.cs index b6bdbb02b..f2dc12edd 100644 --- a/src/ModelContextProtocol/Protocol/Messages/RequestId.cs +++ b/src/ModelContextProtocol/Protocol/Messages/RequestId.cs @@ -1,76 +1,55 @@ -using System.Text.Json.Serialization; +using ModelContextProtocol.Utils; +using System.ComponentModel; +using System.Globalization; +using System.Text.Json; +using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol.Messages; /// -/// Represents a JSON-RPC request identifier which can be either a string or a number. +/// Represents a JSON-RPC request identifier, which can be either a string or an integer. /// -[JsonConverter(typeof(RequestIdConverter))] +[JsonConverter(typeof(Converter))] public readonly struct RequestId : IEquatable { - private readonly object _value; + /// The id, either a string or a boxed long or null. + private readonly object? _id; - private RequestId(object value) + /// Initializes a new instance of the with a specified value. + /// The required ID value. + public RequestId(string value) { - _value = value; + Throw.IfNull(value); + _id = value; } - /// - /// Creates a new RequestId from a string. - /// - /// The Id - /// Wrapped Id object - public static RequestId FromString(string value) => new(value); - - /// - /// Creates a new RequestId from a number. - /// - /// The Id - /// Wrapped Id object - public static RequestId FromNumber(long value) => new(value); - - /// - /// Checks if the RequestId is a string. - /// - public bool IsString => _value is string; - - /// - /// Checks if the RequestId is a number. - /// - public bool IsNumber => _value is long; - - /// - /// Checks if the request id is valid (has a value) - /// - public bool IsValid => _value != null; - - /// - /// Gets the RequestId as a string. - /// - /// Thrown if the RequestId is not a string" - public string AsString => _value as string ?? throw new InvalidOperationException("RequestId is not a string"); + /// Initializes a new instance of the with a specified value. + /// The required ID value. + public RequestId(long value) + { + // Box the long. Request IDs are almost always strings in practice, so this should be rare. + _id = value; + } - /// - /// Gets the RequestId as a number. - /// - /// Thrown if the RequestId is not a number"" - public long AsNumber => _value is long number ? number : throw new InvalidOperationException("RequestId is not a number"); + /// Gets whether the identifier is uninitialized. + public bool IsDefault => _id is null; - /// - /// Returns the string representation of the RequestId. Will box the value if it is a number. - /// - public override string ToString() => _value.ToString() ?? ""; + /// + public override string ToString() => + _id is string stringValue ? stringValue : + _id is long longValue ? longValue.ToString(CultureInfo.InvariantCulture) : + string.Empty; /// /// Compares this RequestId to another RequestId. /// - public bool Equals(RequestId other) => _value.Equals(other._value); + public bool Equals(RequestId other) => Equals(_id, other._id); /// public override bool Equals(object? obj) => obj is RequestId other && Equals(other); /// - public override int GetHashCode() => _value.GetHashCode(); + public override int GetHashCode() => _id?.GetHashCode() ?? 0; /// /// Compares two RequestIds for equality. @@ -81,4 +60,43 @@ private RequestId(object value) /// Compares two RequestIds for inequality. /// public static bool operator !=(RequestId left, RequestId right) => !left.Equals(right); + + /// + /// JSON converter for RequestId that handles both string and number values. + /// + [EditorBrowsable(EditorBrowsableState.Never)] + public sealed class Converter: JsonConverter + { + /// + public override RequestId Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + return reader.TokenType switch + { + JsonTokenType.String => new(reader.GetString()!), + JsonTokenType.Number => new(reader.GetInt64()), + _ => throw new JsonException("requestId must be a string or an integer"), + }; + } + + /// + public override void Write(Utf8JsonWriter writer, RequestId value, JsonSerializerOptions options) + { + Throw.IfNull(writer); + + switch (value._id) + { + case string str: + writer.WriteStringValue(str); + return; + + case long longValue: + writer.WriteNumberValue(longValue); + return; + + case null: + writer.WriteStringValue(string.Empty); + return; + } + } + } } diff --git a/src/ModelContextProtocol/Protocol/Messages/RequestIdConverter.cs b/src/ModelContextProtocol/Protocol/Messages/RequestIdConverter.cs deleted file mode 100644 index 2d67c3127..000000000 --- a/src/ModelContextProtocol/Protocol/Messages/RequestIdConverter.cs +++ /dev/null @@ -1,37 +0,0 @@ -using ModelContextProtocol.Utils; -using System.Text.Json; -using System.Text.Json.Serialization; - -namespace ModelContextProtocol.Protocol.Messages; - -/// -/// JSON converter for RequestId that handles both string and number values. -/// -public class RequestIdConverter : JsonConverter -{ - /// - public override RequestId Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) - { - return reader.TokenType switch - { - JsonTokenType.String => RequestId.FromString(reader.GetString()!), - JsonTokenType.Number => RequestId.FromNumber(reader.GetInt64()), - _ => throw new JsonException("RequestId must be either a string or a number"), - }; - } - - /// - public override void Write(Utf8JsonWriter writer, RequestId value, JsonSerializerOptions options) - { - Throw.IfNull(writer); - - if (value.IsString) - { - writer.WriteStringValue(value.AsString); - } - else - { - writer.WriteNumberValue(value.AsNumber); - } - } -} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Messages/RequestMethods.cs b/src/ModelContextProtocol/Protocol/Messages/RequestMethods.cs new file mode 100644 index 000000000..3d0eb39b8 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Messages/RequestMethods.cs @@ -0,0 +1,82 @@ +namespace ModelContextProtocol.Protocol.Messages; + +/// +/// Provides names for request methods used in the Model Context Protocol (MCP). +/// +public static class RequestMethods +{ + /// + /// Sent from the client to request a list of tools the server has. + /// + public const string ToolsList = "tools/list"; + + /// + /// Used by the client to invoke a tool provided by the server. + /// + public const string ToolsCall = "tools/call"; + + /// + /// Sent from the client to request a list of prompts and prompt templates the server has. + /// + public const string PromptsList = "prompts/list"; + + /// + /// Used by the client to get a prompt provided by the server. + /// + public const string PromptsGet = "prompts/get"; + + /// + /// Sent from the client to request a list of resources the server has. + /// + public const string ResourcesList = "resources/list"; + + /// + /// Sent from the client to the server, to read a specific resource URI. + /// + public const string ResourcesRead = "resources/read"; + + /// + /// Sent from the client to request a list of resource templates the server has. + /// + public const string ResourcesTemplatesList = "resources/templates/list"; + + /// + /// Sent from the client to request resources/updated notifications from the server whenever a particular resource changes. + /// + public const string ResourcesSubscribe = "resources/subscribe"; + + /// + /// Sent from the client to request cancellation of resources/updated notifications from the server. + /// + public const string ResourcesUnsubscribe = "resources/unsubscribe"; + + /// + /// Sent from the server to request a list of root URIs from the client. + /// + public const string RootsList = "roots/list"; + + /// + /// A ping, issued by either the server or the client, to check that the other party is still alive. + /// + public const string Ping = "ping"; + + /// + /// A request from the client to the server, to enable or adjust logging. + /// + public const string LoggingSetLevel = "logging/setLevel"; + + /// + /// A request from the client to the server, to ask for completion options. + /// + public const string CompletionComplete = "completion/complete"; + + /// + /// A request from the server to sample an LLM via the client. + /// + public const string SamplingCreateMessage = "sampling/createMessage"; + + /// + /// This request is sent from the client to the server when it first connects, asking it to begin initialization. + /// + public const string Initialize = "initialize"; +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs index ac96dc02f..9d3e24c53 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs @@ -118,7 +118,7 @@ public override async Task SendMessageAsync( var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); // Check if the message was an initialize request - if (message is JsonRpcRequest request && request.Method == "initialize") + if (message is JsonRpcRequest request && request.Method == RequestMethods.Initialize) { // If the response is not a JSON-RPC response, it is an SSE message if (responseContent.Equals("accepted", StringComparison.OrdinalIgnoreCase)) diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs index fc213474e..2f3cf4204 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs @@ -21,6 +21,7 @@ internal sealed class StdioClientStreamTransport : TransportBase private readonly ILogger _logger; private readonly JsonSerializerOptions _jsonOptions; private readonly DataReceivedEventHandler _logProcessErrors; + private readonly SemaphoreSlim _sendLock = new(1, 1); private Process? _process; private Task? _readTask; private CancellationTokenSource? _shutdownCts; @@ -150,6 +151,8 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) /// public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) { + using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false); + if (!IsConnected || _process?.HasExited == true) { _logger.TransportNotConnected(EndpointName); diff --git a/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs index 23dbfd6a1..dae1b75c1 100644 --- a/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs +++ b/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs @@ -1,4 +1,6 @@ -namespace ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Protocol.Messages; + +namespace ModelContextProtocol.Protocol.Types; /// /// A request from the server to get a list of root URIs from the client. @@ -10,5 +12,5 @@ public class ListRootsRequestParams /// Optional progress token for out-of-band progress notifications. /// [System.Text.Json.Serialization.JsonPropertyName("progressToken")] - public string? ProgressToken { get; init; } + public ProgressToken? ProgressToken { get; init; } } diff --git a/src/ModelContextProtocol/Protocol/Types/RequestParamsMetadata.cs b/src/ModelContextProtocol/Protocol/Types/RequestParamsMetadata.cs index 151064e4d..a4c3ff531 100644 --- a/src/ModelContextProtocol/Protocol/Types/RequestParamsMetadata.cs +++ b/src/ModelContextProtocol/Protocol/Types/RequestParamsMetadata.cs @@ -1,3 +1,4 @@ +using ModelContextProtocol.Protocol.Messages; using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol.Types; @@ -11,5 +12,5 @@ public class RequestParamsMetadata /// If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications. /// [JsonPropertyName("progressToken")] - public object ProgressToken { get; set; } = default!; + public ProgressToken? ProgressToken { get; set; } = default!; } diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs index f65837710..47b8514de 100644 --- a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs @@ -1,6 +1,7 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Shared; using ModelContextProtocol.Utils; using ModelContextProtocol.Utils.Json; using System.Diagnostics.CodeAnalysis; @@ -98,6 +99,27 @@ private static TemporaryAIFunctionFactoryOptions CreateAIFunctionFactoryOptions( }; } + if (pi.ParameterType == typeof(IProgress)) + { + // Bind IProgress to the progress token in the request, + // if there is one. If we can't get one, return a nop progress. + return new() + { + ExcludeFromSchema = true, + BindParameter = (pi, args) => + { + var requestContent = GetRequestContext(args); + if (requestContent?.Server is { } server && + requestContent?.Params?.Meta?.ProgressToken is { } progressToken) + { + return new TokenProgress(server, progressToken); + } + + return NullProgress.Instance; + }, + }; + } + // We assume that if the services used to create the tool support a particular type, // so too do the services associated with the server. This is the same basic assumption // made in ASP.NET. diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index f2ed6e3ed..161736f7a 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -78,7 +78,7 @@ private McpServer(McpServerOptions options, ILoggerFactory? loggerFactory, IServ }); }; - AddNotificationHandler("notifications/initialized", _ => + AddNotificationHandler(NotificationMethods.InitializedNotification, _ => { if (ServerOptions.Capabilities?.Tools?.ToolCollection is { } tools) { @@ -176,13 +176,13 @@ public override async ValueTask DisposeUnsynchronizedAsync() private void SetPingHandler() { - SetRequestHandler("ping", + SetRequestHandler(RequestMethods.Ping, (request, _) => Task.FromResult(new PingResult())); } private void SetInitializeHandler(McpServerOptions options) { - SetRequestHandler("initialize", + SetRequestHandler(RequestMethods.Initialize, (request, _) => { ClientCapabilities = request?.Capabilities ?? new(); @@ -205,7 +205,7 @@ private void SetInitializeHandler(McpServerOptions options) private void SetCompletionHandler(McpServerOptions options) { // This capability is not optional, so return an empty result if there is no handler. - SetRequestHandler("completion/complete", + SetRequestHandler(RequestMethods.CompletionComplete, options.GetCompletionHandler is { } handler ? (request, ct) => handler(new(this, request), ct) : (request, ct) => Task.FromResult(new CompleteResult() { Completion = new() { Values = [], Total = 0, HasMore = false } })); @@ -229,11 +229,11 @@ private void SetResourcesHandler(McpServerOptions options) listResourcesHandler ??= (static (_, _) => Task.FromResult(new ListResourcesResult())); - SetRequestHandler("resources/list", (request, ct) => listResourcesHandler(new(this, request), ct)); - SetRequestHandler("resources/read", (request, ct) => readResourceHandler(new(this, request), ct)); + SetRequestHandler(RequestMethods.ResourcesList, (request, ct) => listResourcesHandler(new(this, request), ct)); + SetRequestHandler(RequestMethods.ResourcesRead, (request, ct) => readResourceHandler(new(this, request), ct)); listResourceTemplatesHandler ??= (static (_, _) => Task.FromResult(new ListResourceTemplatesResult())); - SetRequestHandler("resources/templates/list", (request, ct) => listResourceTemplatesHandler(new(this, request), ct)); + SetRequestHandler(RequestMethods.ResourcesTemplatesList, (request, ct) => listResourceTemplatesHandler(new(this, request), ct)); if (resourcesCapability.Subscribe is not true) { @@ -247,8 +247,8 @@ private void SetResourcesHandler(McpServerOptions options) throw new McpServerException("Resources capability was enabled with subscribe support, but SubscribeToResources and/or UnsubscribeFromResources handlers were not specified."); } - SetRequestHandler("resources/subscribe", (request, ct) => subscribeHandler(new(this, request), ct)); - SetRequestHandler("resources/unsubscribe", (request, ct) => unsubscribeHandler(new(this, request), ct)); + SetRequestHandler(RequestMethods.ResourcesSubscribe, (request, ct) => subscribeHandler(new(this, request), ct)); + SetRequestHandler(RequestMethods.ResourcesUnsubscribe, (request, ct) => unsubscribeHandler(new(this, request), ct)); } private void SetPromptsHandler(McpServerOptions options) @@ -264,8 +264,8 @@ private void SetPromptsHandler(McpServerOptions options) throw new McpServerException("Prompts capability was enabled, but ListPrompts and/or GetPrompt handlers were not specified."); } - SetRequestHandler("prompts/list", (request, ct) => listPromptsHandler(new(this, request), ct)); - SetRequestHandler("prompts/get", (request, ct) => getPromptHandler(new(this, request), ct)); + SetRequestHandler(RequestMethods.PromptsList, (request, ct) => listPromptsHandler(new(this, request), ct)); + SetRequestHandler(RequestMethods.PromptsGet, (request, ct) => getPromptHandler(new(this, request), ct)); } private void SetToolsHandler(McpServerOptions options) @@ -363,8 +363,8 @@ private void SetToolsHandler(McpServerOptions options) } } - SetRequestHandler("tools/list", (request, ct) => listToolsHandler(new(this, request), ct)); - SetRequestHandler("tools/call", (request, ct) => callToolHandler(new(this, request), ct)); + SetRequestHandler(RequestMethods.ToolsList, (request, ct) => listToolsHandler(new(this, request), ct)); + SetRequestHandler(RequestMethods.ToolsCall, (request, ct) => callToolHandler(new(this, request), ct)); } private void SetSetLoggingLevelHandler(McpServerOptions options) @@ -379,6 +379,6 @@ private void SetSetLoggingLevelHandler(McpServerOptions options) throw new McpServerException("Logging capability was enabled, but SetLoggingLevelHandler was not specified."); } - SetRequestHandler("logging/setLevel", (request, ct) => setLoggingLevelHandler(new(this, request), ct)); + SetRequestHandler(RequestMethods.LoggingSetLevel, (request, ct) => setLoggingLevelHandler(new(this, request), ct)); } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Server/McpServerExtensions.cs b/src/ModelContextProtocol/Server/McpServerExtensions.cs index 7bff56642..3b541ec80 100644 --- a/src/ModelContextProtocol/Server/McpServerExtensions.cs +++ b/src/ModelContextProtocol/Server/McpServerExtensions.cs @@ -26,7 +26,7 @@ public static Task RequestSamplingAsync( } return server.SendRequestAsync( - new JsonRpcRequest { Method = "sampling/createMessage", Params = request }, + new JsonRpcRequest { Method = RequestMethods.SamplingCreateMessage, Params = request }, cancellationToken); } @@ -165,7 +165,7 @@ public static Task RequestRootsAsync( } return server.SendRequestAsync( - new JsonRpcRequest { Method = "roots/list", Params = request }, + new JsonRpcRequest { Method = RequestMethods.RootsList, Params = request }, cancellationToken); } diff --git a/src/ModelContextProtocol/Shared/McpSession.cs b/src/ModelContextProtocol/Shared/McpSession.cs index 187bd8dba..831d40c42 100644 --- a/src/ModelContextProtocol/Shared/McpSession.cs +++ b/src/ModelContextProtocol/Shared/McpSession.cs @@ -23,8 +23,9 @@ internal sealed class McpSession : IDisposable private readonly ConcurrentDictionary> _pendingRequests = []; private readonly JsonSerializerOptions _jsonOptions; private readonly ILogger _logger; - - private int _nextRequestId; + + private readonly string _id = Guid.NewGuid().ToString("N"); + private long _nextRequestId; /// /// Initializes a new instance of the class. @@ -141,7 +142,7 @@ private async Task HandleNotification(JsonRpcNotification notification) private void HandleMessageWithId(IJsonRpcMessage message, IJsonRpcMessageWithId messageWithId) { - if (!messageWithId.Id.IsValid) + if (messageWithId.Id.IsDefault) { _logger.RequestHasInvalidId(EndpointName); } @@ -212,7 +213,10 @@ public async Task SendRequestAsync(JsonRpcRequest request, Can } // Set request ID - request.Id = RequestId.FromNumber(Interlocked.Increment(ref _nextRequestId)); + if (request.Id.IsDefault) + { + request.Id = new RequestId($"{_id}-{Interlocked.Increment(ref _nextRequestId)}"); + } var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); _pendingRequests[request.Id] = tcs; diff --git a/src/ModelContextProtocol/TokenProgress.cs b/src/ModelContextProtocol/TokenProgress.cs new file mode 100644 index 000000000..7cc97236a --- /dev/null +++ b/src/ModelContextProtocol/TokenProgress.cs @@ -0,0 +1,31 @@ +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Server; +using ModelContextProtocol.Shared; + +namespace ModelContextProtocol; + +/// +/// Provides an tied to a specific progress token and that will issue +/// progress notifications to the supplied endpoint. +/// +internal sealed class TokenProgress(IMcpServer server, ProgressToken progressToken) : IProgress +{ + /// + public void Report(ProgressNotificationValue value) + { + _ = server.SendMessageAsync(new JsonRpcNotification() + { + Method = NotificationMethods.ProgressNotification, + Params = new ProgressNotification() + { + ProgressToken = progressToken, + Progress = new() + { + Progress = value.Progress, + Total = value.Total, + Message = value.Message, + }, + }, + }, CancellationToken.None); + } +} diff --git a/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs b/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs index d8513ef85..68d13eb2b 100644 --- a/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs +++ b/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs @@ -142,6 +142,7 @@ internal static bool IsValidMcpToolSchema(JsonElement element) [JsonSerializable(typeof(ListToolsResult))] [JsonSerializable(typeof(LoggingMessageNotificationParams))] [JsonSerializable(typeof(PingResult))] + [JsonSerializable(typeof(ProgressNotification))] [JsonSerializable(typeof(ReadResourceRequestParams))] [JsonSerializable(typeof(ReadResourceResult))] [JsonSerializable(typeof(ResourceUpdatedNotificationParams))] diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 51da7d976..aeeb9007a 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -6,9 +6,11 @@ using ModelContextProtocol.Configuration; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; using ModelContextProtocol.Tests.Transport; using ModelContextProtocol.Tests.Utils; +using System.Collections.Concurrent; using System.ComponentModel; using System.IO.Pipelines; using System.Text.Json; @@ -90,7 +92,7 @@ public async Task Can_List_Registered_Tools() IMcpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(11, tools.Count); + Assert.Equal(12, tools.Count); McpClientTool echoTool = tools.First(t => t.Name == "Echo"); Assert.Equal("Echo", echoTool.Name); @@ -137,7 +139,7 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T cancellationToken: TestContext.Current.CancellationToken)) { var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(11, tools.Count); + Assert.Equal(12, tools.Count); McpClientTool echoTool = tools.First(t => t.Name == "Echo"); Assert.Equal("Echo", echoTool.Name); @@ -164,10 +166,10 @@ public async Task Can_Be_Notified_Of_Tool_Changes() IMcpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(11, tools.Count); + Assert.Equal(12, tools.Count); Channel listChanged = Channel.CreateUnbounded(); - client.AddNotificationHandler("notifications/tools/list_changed", notification => + client.AddNotificationHandler(NotificationMethods.ToolListChangedNotification, notification => { listChanged.Writer.TryWrite(notification); return Task.CompletedTask; @@ -185,7 +187,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes() await notificationRead; tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(12, tools.Count); + Assert.Equal(13, tools.Count); Assert.Contains(tools, t => t.Name == "NewTool"); notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken); @@ -194,7 +196,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes() await notificationRead; tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(11, tools.Count); + Assert.Equal(12, tools.Count); Assert.DoesNotContain(tools, t => t.Name == "NewTool"); } @@ -516,6 +518,49 @@ public void Create_ExtractsToolAnnotations_SomeSet() Assert.Null(annotations.ReadOnlyHint); } + [Fact] + public async Task HandlesIProgressParameter() + { + ConcurrentQueue notifications = new(); + + IMcpClient client = await CreateMcpClientForServer(); + client.AddNotificationHandler(NotificationMethods.ProgressNotification, notification => + { + ProgressNotification pn = JsonSerializer.Deserialize((JsonElement)notification.Params!)!; + notifications.Enqueue(pn); + return Task.CompletedTask; + }); + + var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); + Assert.NotNull(tools); + Assert.NotEmpty(tools); + + McpClientTool progressTool = tools.First(t => t.Name == nameof(EchoTool.SendsProgressNotifications)); + + var result = await client.SendRequestAsync(new JsonRpcRequest() + { + Method = RequestMethods.ToolsCall, + Params = new CallToolRequestParams() + { + Name = progressTool.ProtocolTool.Name, + Meta = new() { ProgressToken = new("abc123") }, + }, + }, TestContext.Current.CancellationToken); + + Assert.Contains("done", JsonSerializer.Serialize(result)); + SpinWait.SpinUntil(() => notifications.Count == 10, TimeSpan.FromSeconds(10)); + + ProgressNotification[] array = notifications.OrderBy(n => n.Progress.Progress).ToArray(); + Assert.Equal(10, array.Length); + for (int i = 0; i < array.Length; i++) + { + Assert.Equal("abc123", array[i].ProgressToken.ToString()); + Assert.Equal(i, array[i].Progress.Progress); + Assert.Equal(10, array[i].Progress.Total); + Assert.Equal($"Progress {i}", array[i].Progress.Message); + } + } + [McpServerToolType] public sealed class EchoTool(ObjectWithId objectFromDI) { @@ -583,6 +628,17 @@ public static string EchoComplex(ComplexObject complex) [McpServerTool] public string GetCtorParameter() => $"{_randomValue}:{objectFromDI.Id}"; + + [McpServerTool] + public string SendsProgressNotifications(IProgress progress) + { + for (int i = 0; i < 10; i++) + { + progress.Report(new() { Progress = i, Total = 10, Message = $"Progress {i}" }); + } + + return "done"; + } } [McpServerToolType] diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index f7455291a..30a32257f 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -129,7 +129,7 @@ public async Task RequestSamplingAsync_Should_SendRequest() Assert.NotNull(result); Assert.NotEmpty(transport.SentMessages); Assert.IsType(transport.SentMessages[0]); - Assert.Equal("sampling/createMessage", ((JsonRpcRequest)transport.SentMessages[0]).Method); + Assert.Equal(RequestMethods.SamplingCreateMessage, ((JsonRpcRequest)transport.SentMessages[0]).Method); await transport.DisposeAsync(); await runTask; @@ -162,7 +162,7 @@ public async Task RequestRootsAsync_Should_SendRequest() Assert.NotNull(result); Assert.NotEmpty(transport.SentMessages); Assert.IsType(transport.SentMessages[0]); - Assert.Equal("roots/list", ((JsonRpcRequest)transport.SentMessages[0]).Method); + Assert.Equal(RequestMethods.RootsList, ((JsonRpcRequest)transport.SentMessages[0]).Method); await transport.DisposeAsync(); await runTask; @@ -185,7 +185,7 @@ public async Task Can_Handle_Ping_Requests() { await Can_Handle_Requests( serverCapabilities: null, - method: "ping", + method: RequestMethods.Ping, configureOptions: null, assertResult: response => { @@ -198,7 +198,7 @@ public async Task Can_Handle_Initialize_Requests() { await Can_Handle_Requests( serverCapabilities: null, - method: "initialize", + method: RequestMethods.Initialize, configureOptions: null, assertResult: response => { @@ -216,7 +216,7 @@ public async Task Can_Handle_Completion_Requests() { await Can_Handle_Requests( serverCapabilities: null, - method: "completion/complete", + method: RequestMethods.CompletionComplete, configureOptions: null, assertResult: response => { @@ -235,7 +235,7 @@ public async Task Can_Handle_Completion_Requests_With_Handler() { await Can_Handle_Requests( serverCapabilities: null, - method: "completion/complete", + method: RequestMethods.CompletionComplete, configureOptions: options => { options.GetCompletionHandler = (request, ct) => @@ -287,7 +287,7 @@ await Can_Handle_Requests( ReadResourceHandler = (request, ct) => throw new NotImplementedException(), } }, - "resources/templates/list", + RequestMethods.ResourcesTemplatesList, configureOptions: null, assertResult: response => { @@ -318,7 +318,7 @@ await Can_Handle_Requests( ReadResourceHandler = (request, ct) => throw new NotImplementedException(), } }, - "resources/list", + RequestMethods.ResourcesList, configureOptions: null, assertResult: response => { @@ -334,7 +334,7 @@ await Can_Handle_Requests( [Fact] public async Task Can_Handle_Resources_List_Requests_Throws_Exception_If_No_Handler_Assigned() { - await Throws_Exception_If_No_Handler_Assigned(new ServerCapabilities { Resources = new() }, "resources/list", "ListResources handler not configured"); + await Throws_Exception_If_No_Handler_Assigned(new ServerCapabilities { Resources = new() }, RequestMethods.ResourcesList, "ListResources handler not configured"); } [Fact] @@ -355,7 +355,7 @@ await Can_Handle_Requests( ListResourcesHandler = (request, ct) => throw new NotImplementedException(), } }, - method: "resources/read", + method: RequestMethods.ResourcesRead, configureOptions: null, assertResult: response => { @@ -373,7 +373,7 @@ await Can_Handle_Requests( [Fact] public async Task Can_Handle_Resources_Read_Requests_Throws_Exception_If_No_Handler_Assigned() { - await Throws_Exception_If_No_Handler_Assigned(new ServerCapabilities { Resources = new() }, "resources/read", "ReadResource handler not configured"); + await Throws_Exception_If_No_Handler_Assigned(new ServerCapabilities { Resources = new() }, RequestMethods.ResourcesRead, "ReadResource handler not configured"); } [Fact] @@ -394,7 +394,7 @@ await Can_Handle_Requests( GetPromptHandler = (request, ct) => throw new NotImplementedException(), }, }, - method: "prompts/list", + method: RequestMethods.PromptsList, configureOptions: null, assertResult: response => { @@ -410,7 +410,7 @@ await Can_Handle_Requests( [Fact] public async Task Can_Handle_List_Prompts_Requests_Throws_Exception_If_No_Handler_Assigned() { - await Throws_Exception_If_No_Handler_Assigned(new ServerCapabilities { Prompts = new() }, "prompts/list", "ListPrompts handler not configured"); + await Throws_Exception_If_No_Handler_Assigned(new ServerCapabilities { Prompts = new() }, RequestMethods.PromptsList, "ListPrompts handler not configured"); } [Fact] @@ -425,7 +425,7 @@ await Can_Handle_Requests( ListPromptsHandler = (request, ct) => throw new NotImplementedException(), } }, - method: "prompts/get", + method: RequestMethods.PromptsGet, configureOptions: null, assertResult: response => { @@ -439,7 +439,7 @@ await Can_Handle_Requests( [Fact] public async Task Can_Handle_Get_Prompts_Requests_Throws_Exception_If_No_Handler_Assigned() { - await Throws_Exception_If_No_Handler_Assigned(new ServerCapabilities { Prompts = new() }, "prompts/get", "GetPrompt handler not configured"); + await Throws_Exception_If_No_Handler_Assigned(new ServerCapabilities { Prompts = new() }, RequestMethods.PromptsGet, "GetPrompt handler not configured"); } [Fact] @@ -460,7 +460,7 @@ await Can_Handle_Requests( CallToolHandler = (request, ct) => throw new NotImplementedException(), } }, - method: "tools/list", + method: RequestMethods.ToolsList, configureOptions: null, assertResult: response => { @@ -475,7 +475,7 @@ await Can_Handle_Requests( [Fact] public async Task Can_Handle_List_Tools_Requests_Throws_Exception_If_No_Handler_Assigned() { - await Throws_Exception_If_No_Handler_Assigned(new ServerCapabilities { Tools = new() }, "tools/list", "ListTools handler not configured"); + await Throws_Exception_If_No_Handler_Assigned(new ServerCapabilities { Tools = new() }, RequestMethods.ToolsList, "ListTools handler not configured"); } [Fact] @@ -496,7 +496,7 @@ await Can_Handle_Requests( ListToolsHandler = (request, ct) => throw new NotImplementedException(), } }, - method: "tools/call", + method: RequestMethods.ToolsCall, configureOptions: null, assertResult: response => { @@ -511,7 +511,7 @@ await Can_Handle_Requests( [Fact] public async Task Can_Handle_Call_Tool_Requests_Throws_Exception_If_No_Handler_Assigned() { - await Throws_Exception_If_No_Handler_Assigned(new ServerCapabilities { Tools = new() }, "tools/call", "CallTool handler not configured"); + await Throws_Exception_If_No_Handler_Assigned(new ServerCapabilities { Tools = new() }, RequestMethods.ToolsCall, "CallTool handler not configured"); } private async Task Can_Handle_Requests(ServerCapabilities? serverCapabilities, string method, Action? configureOptions, Action assertResult) @@ -528,7 +528,7 @@ private async Task Can_Handle_Requests(ServerCapabilities? serverCapabilities, s transport.OnMessageSent = (message) => { - if (message is JsonRpcResponse response && response.Id.AsNumber == 55) + if (message is JsonRpcResponse response && response.Id.ToString() == "55") receivedMessage.SetResult(response); }; @@ -536,7 +536,7 @@ await transport.SendMessageAsync( new JsonRpcRequest { Method = method, - Id = RequestId.FromNumber(55) + Id = new RequestId(55) } ); diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs index ca4a53634..43f818676 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs @@ -136,7 +136,7 @@ public async Task SendMessageAsync_Handles_Accepted_Response() }; await using var session = await transport.ConnectAsync(TestContext.Current.CancellationToken); - await session.SendMessageAsync(new JsonRpcRequest() { Method = "initialize", Id = RequestId.FromNumber(44) }, CancellationToken.None); + await session.SendMessageAsync(new JsonRpcRequest() { Method = RequestMethods.Initialize, Id = new RequestId(44) }, CancellationToken.None); Assert.True(true); } @@ -181,7 +181,7 @@ public async Task SendMessageAsync_Handles_Accepted_Json_RPC_Response() await using var session = await transport.ConnectAsync(TestContext.Current.CancellationToken); - await session.SendMessageAsync(new JsonRpcRequest() { Method = "initialize", Id = RequestId.FromNumber(44) }, CancellationToken.None); + await session.SendMessageAsync(new JsonRpcRequest() { Method = RequestMethods.Initialize, Id = new RequestId(44) }, CancellationToken.None); Assert.True(true); eventSourcePipe.Writer.Complete(); } @@ -214,7 +214,7 @@ public async Task ReceiveMessagesAsync_Handles_Messages() Assert.True(session.MessageReader.TryRead(out var message)); Assert.NotNull(message); Assert.IsType(message); - Assert.Equal("44", ((JsonRpcRequest)message).Id.AsString); + Assert.Equal("44", ((JsonRpcRequest)message).Id.ToString()); } [Fact] diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs index 2801d1339..c1df1887b 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs @@ -77,7 +77,7 @@ public async Task SendMessageAsync_Should_Send_Message() // Verify transport is connected Assert.True(transport.IsConnected, "Transport should be connected after StartListeningAsync"); - var message = new JsonRpcRequest { Method = "test", Id = RequestId.FromNumber(44) }; + var message = new JsonRpcRequest { Method = "test", Id = new RequestId(44) }; await transport.SendMessageAsync(message, TestContext.Current.CancellationToken); @@ -100,7 +100,7 @@ public async Task DisposeAsync_Should_Dispose_Resources() [Fact] public async Task ReadMessagesAsync_Should_Read_Messages() { - var message = new JsonRpcRequest { Method = "test", Id = RequestId.FromNumber(44) }; + var message = new JsonRpcRequest { Method = "test", Id = new RequestId(44) }; var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions); // Use a reader that won't terminate @@ -125,7 +125,7 @@ public async Task ReadMessagesAsync_Should_Read_Messages() Assert.True(transport.MessageReader.TryPeek(out var readMessage)); Assert.NotNull(readMessage); Assert.IsType(readMessage); - Assert.Equal(44, ((JsonRpcRequest)readMessage).Id.AsNumber); + Assert.Equal("44", ((JsonRpcRequest)readMessage).Id.ToString()); } [Fact] @@ -158,7 +158,7 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters() var chineseMessage = new JsonRpcRequest { Method = "test", - Id = RequestId.FromNumber(44), + Id = new RequestId(44), Params = new Dictionary { ["text"] = JsonSerializer.SerializeToElement(chineseText) @@ -180,7 +180,7 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters() var emojiMessage = new JsonRpcRequest { Method = "test", - Id = RequestId.FromNumber(45), + Id = new RequestId(45), Params = new Dictionary { ["text"] = JsonSerializer.SerializeToElement(emojiText) diff --git a/tests/ModelContextProtocol.Tests/Utils/InMemoryTestSseServer.cs b/tests/ModelContextProtocol.Tests/Utils/InMemoryTestSseServer.cs index 453664b95..0bdfde192 100644 --- a/tests/ModelContextProtocol.Tests/Utils/InMemoryTestSseServer.cs +++ b/tests/ModelContextProtocol.Tests/Utils/InMemoryTestSseServer.cs @@ -191,7 +191,7 @@ private async Task HandlePostMessageAsync(HttpListenerContext context, Cancellat string content = await reader.ReadToEndAsync(cancellationToken); var jsonRpcNotification = JsonSerializer.Deserialize(content); - if (jsonRpcNotification != null && jsonRpcNotification.Method != "initialize") + if (jsonRpcNotification != null && jsonRpcNotification.Method != RequestMethods.Initialize) { // Test server so just ignore notifications @@ -209,7 +209,7 @@ private async Task HandlePostMessageAsync(HttpListenerContext context, Cancellat if (jsonRpcRequest != null) { - if (jsonRpcRequest.Method == "initialize") + if (jsonRpcRequest.Method == RequestMethods.Initialize) { await HandleInitializationRequest(response, jsonRpcRequest); } @@ -266,7 +266,7 @@ private static async Task SendJsonRpcErrorAsync(HttpListenerResponse response, R { var errorResponse = new JsonRpcError { - Id = id ?? RequestId.FromString("error"), + Id = id ?? new RequestId("error"), JsonRpc = "2.0", Error = new JsonRpcErrorDetail { diff --git a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs b/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs index 16c9ac86f..33a133616 100644 --- a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs +++ b/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs @@ -39,9 +39,9 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca SentMessages.Add(message); if (message is JsonRpcRequest request) { - if (request.Method == "roots/list") + if (request.Method == RequestMethods.RootsList) await ListRoots(request, cancellationToken); - else if (request.Method == "sampling/createMessage") + else if (request.Method == RequestMethods.SamplingCreateMessage) await Sampling(request, cancellationToken); else await WriteMessageAsync(request, cancellationToken); From 6330dd0456eed00173696152b3280abbb0b93252 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Mon, 31 Mar 2025 09:22:51 -0400 Subject: [PATCH 2/2] Address PR feedback --- .../Protocol/Messages/ProgressNotification.cs | 6 +----- src/ModelContextProtocol/Protocol/Messages/ProgressToken.cs | 2 +- src/ModelContextProtocol/Protocol/Messages/RequestId.cs | 2 +- .../Configuration/McpServerBuilderExtensionsToolsTests.cs | 2 +- .../Transport/SseClientTransportTests.cs | 2 +- 5 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/ModelContextProtocol/Protocol/Messages/ProgressNotification.cs b/src/ModelContextProtocol/Protocol/Messages/ProgressNotification.cs index 5b351009a..8509076f1 100644 --- a/src/ModelContextProtocol/Protocol/Messages/ProgressNotification.cs +++ b/src/ModelContextProtocol/Protocol/Messages/ProgressNotification.cs @@ -42,11 +42,7 @@ public sealed class Converter : JsonConverter switch (propertyName) { case "progressToken": - if (JsonSerializer.Deserialize(ref reader, options.GetTypeInfo(typeof(ProgressToken))) is not ProgressToken token) - { - throw new JsonException("Invalid value for 'progressToken'."); - } - progressToken = token; + progressToken = (ProgressToken)JsonSerializer.Deserialize(ref reader, options.GetTypeInfo(typeof(ProgressToken)))!; break; case "progress": diff --git a/src/ModelContextProtocol/Protocol/Messages/ProgressToken.cs b/src/ModelContextProtocol/Protocol/Messages/ProgressToken.cs index b3f4a8a93..6183eb92e 100644 --- a/src/ModelContextProtocol/Protocol/Messages/ProgressToken.cs +++ b/src/ModelContextProtocol/Protocol/Messages/ProgressToken.cs @@ -36,7 +36,7 @@ public ProgressToken(long value) /// public override string? ToString() => - _id is string stringValue ? stringValue : + _id is string stringValue ? $"\"{stringValue}\"" : _id is long longValue ? longValue.ToString(CultureInfo.InvariantCulture) : null; diff --git a/src/ModelContextProtocol/Protocol/Messages/RequestId.cs b/src/ModelContextProtocol/Protocol/Messages/RequestId.cs index f2dc12edd..e6fc74418 100644 --- a/src/ModelContextProtocol/Protocol/Messages/RequestId.cs +++ b/src/ModelContextProtocol/Protocol/Messages/RequestId.cs @@ -36,7 +36,7 @@ public RequestId(long value) /// public override string ToString() => - _id is string stringValue ? stringValue : + _id is string stringValue ? $"\"{stringValue}\"" : _id is long longValue ? longValue.ToString(CultureInfo.InvariantCulture) : string.Empty; diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index aeeb9007a..baeb2c7c4 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -554,7 +554,7 @@ public async Task HandlesIProgressParameter() Assert.Equal(10, array.Length); for (int i = 0; i < array.Length; i++) { - Assert.Equal("abc123", array[i].ProgressToken.ToString()); + Assert.Equal("\"abc123\"", array[i].ProgressToken.ToString()); Assert.Equal(i, array[i].Progress.Progress); Assert.Equal(10, array[i].Progress.Total); Assert.Equal($"Progress {i}", array[i].Progress.Message); diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs index 43f818676..8e64dd3b2 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs @@ -214,7 +214,7 @@ public async Task ReceiveMessagesAsync_Handles_Messages() Assert.True(session.MessageReader.TryRead(out var message)); Assert.NotNull(message); Assert.IsType(message); - Assert.Equal("44", ((JsonRpcRequest)message).Id.ToString()); + Assert.Equal("\"44\"", ((JsonRpcRequest)message).Id.ToString()); } [Fact]