Skip to content

Commit e24a31a

Browse files
committed
Improve progress reporting
- Enable injecting an `IProgress<>` into a tool call to report progress when the client has supplied a progress token - Add a custom type for ProgressToken and ProgressNotification to map to the schema - Simplify RequestId to match ProgressToken's shape - Serialize StdioClientTransport's SendMessageAsync implementation - Add strongly-typed names for all request methods - Change all request IDs created by McpJsonRpcEndpoint to include a guid for that endpoint
1 parent f55302b commit e24a31a

27 files changed

+611
-231
lines changed

src/ModelContextProtocol/Client/McpClient.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ public McpClient(IClientTransport transport, McpClientOptions options, McpServer
4141
}
4242

4343
SetRequestHandler<CreateMessageRequestParams, CreateMessageResult>(
44-
"sampling/createMessage",
44+
RequestMethods.SamplingCreateMessage,
4545
(request, ct) => samplingHandler(request, ct));
4646
}
4747

@@ -53,7 +53,7 @@ public McpClient(IClientTransport transport, McpClientOptions options, McpServer
5353
}
5454

5555
SetRequestHandler<ListRootsRequestParams, ListRootsResult>(
56-
"roots/list",
56+
RequestMethods.RootsList,
5757
(request, ct) => rootsHandler(request, ct));
5858
}
5959
}
@@ -124,7 +124,7 @@ private async Task InitializeAsync(CancellationToken cancellationToken)
124124
var initializeResponse = await SendRequestAsync<InitializeResult>(
125125
new JsonRpcRequest
126126
{
127-
Method = "initialize",
127+
Method = RequestMethods.Initialize,
128128
Params = new InitializeRequestParams()
129129
{
130130
ProtocolVersion = _options.ProtocolVersion,
@@ -152,7 +152,7 @@ private async Task InitializeAsync(CancellationToken cancellationToken)
152152

153153
// Send initialized notification
154154
await SendMessageAsync(
155-
new JsonRpcNotification { Method = "notifications/initialized" },
155+
new JsonRpcNotification { Method = NotificationMethods.InitializedNotification },
156156
initializationCts.Token).ConfigureAwait(false);
157157
}
158158
catch (OperationCanceledException) when (initializationCts.IsCancellationRequested)

src/ModelContextProtocol/Client/McpClientExtensions.cs

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ public static Task PingAsync(this IMcpClient client, CancellationToken cancellat
4141
Throw.IfNull(client);
4242

4343
return client.SendRequestAsync<dynamic>(
44-
CreateRequest("ping", null),
44+
CreateRequest(RequestMethods.Ping, null),
4545
cancellationToken);
4646
}
4747

@@ -61,7 +61,7 @@ public static async Task<IList<McpClientTool>> ListToolsAsync(
6161
do
6262
{
6363
var toolResults = await client.SendRequestAsync<ListToolsResult>(
64-
CreateRequest("tools/list", CreateCursorDictionary(cursor)),
64+
CreateRequest(RequestMethods.ToolsList, CreateCursorDictionary(cursor)),
6565
cancellationToken).ConfigureAwait(false);
6666

6767
tools ??= new List<McpClientTool>(toolResults.Tools.Count);
@@ -96,7 +96,7 @@ public static async IAsyncEnumerable<McpClientTool> EnumerateToolsAsync(
9696
do
9797
{
9898
var toolResults = await client.SendRequestAsync<ListToolsResult>(
99-
CreateRequest("tools/list", CreateCursorDictionary(cursor)),
99+
CreateRequest(RequestMethods.ToolsList, CreateCursorDictionary(cursor)),
100100
cancellationToken).ConfigureAwait(false);
101101

102102
foreach (var tool in toolResults.Tools)
@@ -126,7 +126,7 @@ public static async Task<IList<Prompt>> ListPromptsAsync(
126126
do
127127
{
128128
var promptResults = await client.SendRequestAsync<ListPromptsResult>(
129-
CreateRequest("prompts/list", CreateCursorDictionary(cursor)),
129+
CreateRequest(RequestMethods.PromptsList, CreateCursorDictionary(cursor)),
130130
cancellationToken).ConfigureAwait(false);
131131

132132
if (prompts is null)
@@ -164,7 +164,7 @@ public static async IAsyncEnumerable<Prompt> EnumeratePromptsAsync(
164164
do
165165
{
166166
var promptResults = await client.SendRequestAsync<ListPromptsResult>(
167-
CreateRequest("prompts/list", CreateCursorDictionary(cursor)),
167+
CreateRequest(RequestMethods.PromptsList, CreateCursorDictionary(cursor)),
168168
cancellationToken).ConfigureAwait(false);
169169

170170
foreach (var prompt in promptResults.Prompts)
@@ -192,7 +192,7 @@ public static Task<GetPromptResult> GetPromptAsync(
192192
Throw.IfNullOrWhiteSpace(name);
193193

194194
return client.SendRequestAsync<GetPromptResult>(
195-
CreateRequest("prompts/get", CreateParametersDictionary(name, arguments)),
195+
CreateRequest(RequestMethods.PromptsGet, CreateParametersDictionary(name, arguments)),
196196
cancellationToken);
197197
}
198198

@@ -213,7 +213,7 @@ public static async Task<IList<ResourceTemplate>> ListResourceTemplatesAsync(
213213
do
214214
{
215215
var templateResults = await client.SendRequestAsync<ListResourceTemplatesResult>(
216-
CreateRequest("resources/templates/list", CreateCursorDictionary(cursor)),
216+
CreateRequest(RequestMethods.ResourcesTemplatesList, CreateCursorDictionary(cursor)),
217217
cancellationToken).ConfigureAwait(false);
218218

219219
if (templates is null)
@@ -251,7 +251,7 @@ public static async IAsyncEnumerable<ResourceTemplate> EnumerateResourceTemplate
251251
do
252252
{
253253
var templateResults = await client.SendRequestAsync<ListResourceTemplatesResult>(
254-
CreateRequest("resources/templates/list", CreateCursorDictionary(cursor)),
254+
CreateRequest(RequestMethods.ResourcesTemplatesList, CreateCursorDictionary(cursor)),
255255
cancellationToken).ConfigureAwait(false);
256256

257257
foreach (var template in templateResults.ResourceTemplates)
@@ -281,7 +281,7 @@ public static async Task<IList<Resource>> ListResourcesAsync(
281281
do
282282
{
283283
var resourceResults = await client.SendRequestAsync<ListResourcesResult>(
284-
CreateRequest("resources/list", CreateCursorDictionary(cursor)),
284+
CreateRequest(RequestMethods.ResourcesList, CreateCursorDictionary(cursor)),
285285
cancellationToken).ConfigureAwait(false);
286286

287287
if (resources is null)
@@ -319,7 +319,7 @@ public static async IAsyncEnumerable<Resource> EnumerateResourcesAsync(
319319
do
320320
{
321321
var resourceResults = await client.SendRequestAsync<ListResourcesResult>(
322-
CreateRequest("resources/list", CreateCursorDictionary(cursor)),
322+
CreateRequest(RequestMethods.ResourcesList, CreateCursorDictionary(cursor)),
323323
cancellationToken).ConfigureAwait(false);
324324

325325
foreach (var resource in resourceResults.Resources)
@@ -345,7 +345,7 @@ public static Task<ReadResourceResult> ReadResourceAsync(
345345
Throw.IfNullOrWhiteSpace(uri);
346346

347347
return client.SendRequestAsync<ReadResourceResult>(
348-
CreateRequest("resources/read", new() { ["uri"] = uri }),
348+
CreateRequest(RequestMethods.ResourcesRead, new() { ["uri"] = uri }),
349349
cancellationToken);
350350
}
351351

@@ -369,7 +369,7 @@ public static Task<CompleteResult> GetCompletionAsync(this IMcpClient client, Re
369369
}
370370

371371
return client.SendRequestAsync<CompleteResult>(
372-
CreateRequest("completion/complete", new()
372+
CreateRequest(RequestMethods.CompletionComplete, new()
373373
{
374374
["ref"] = reference,
375375
["argument"] = new Argument { Name = argumentName, Value = argumentValue }
@@ -389,7 +389,7 @@ public static Task SubscribeToResourceAsync(this IMcpClient client, string uri,
389389
Throw.IfNullOrWhiteSpace(uri);
390390

391391
return client.SendRequestAsync<EmptyResult>(
392-
CreateRequest("resources/subscribe", new() { ["uri"] = uri }),
392+
CreateRequest(RequestMethods.ResourcesSubscribe, new() { ["uri"] = uri }),
393393
cancellationToken);
394394
}
395395

@@ -405,7 +405,7 @@ public static Task UnsubscribeFromResourceAsync(this IMcpClient client, string u
405405
Throw.IfNullOrWhiteSpace(uri);
406406

407407
return client.SendRequestAsync<EmptyResult>(
408-
CreateRequest("resources/unsubscribe", new() { ["uri"] = uri }),
408+
CreateRequest(RequestMethods.ResourcesUnsubscribe, new() { ["uri"] = uri }),
409409
cancellationToken);
410410
}
411411

@@ -424,7 +424,7 @@ public static Task<CallToolResponse> CallToolAsync(
424424
Throw.IfNull(toolName);
425425

426426
return client.SendRequestAsync<CallToolResponse>(
427-
CreateRequest("tools/call", CreateParametersDictionary(toolName, arguments)),
427+
CreateRequest(RequestMethods.ToolsCall, CreateParametersDictionary(toolName, arguments)),
428428
cancellationToken);
429429
}
430430

@@ -570,7 +570,7 @@ public static Task SetLoggingLevel(this IMcpClient client, LoggingLevel level, C
570570
Throw.IfNull(client);
571571

572572
return client.SendRequestAsync<EmptyResult>(
573-
CreateRequest("logging/setLevel", new() { ["level"] = level }),
573+
CreateRequest(RequestMethods.LoggingSetLevel, new() { ["level"] = level }),
574574
cancellationToken);
575575
}
576576

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
namespace ModelContextProtocol;
2+
3+
/// <summary>Provides an <see cref="IProgress{ProgressNotificationValue}"/> that's a nop.</summary>
4+
internal sealed class NullProgress : IProgress<ProgressNotificationValue>
5+
{
6+
public static NullProgress Instance { get; } = new();
7+
8+
/// <inheritdoc />
9+
public void Report(ProgressNotificationValue value)
10+
{
11+
}
12+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
namespace ModelContextProtocol;
2+
3+
/// <summary>Provides a progress value that can be sent using <see cref="IProgress{ProgressNotificationValue}"/>.</summary>
4+
public record struct ProgressNotificationValue
5+
{
6+
/// <summary>Gets or sets the progress thus far.</summary>
7+
public required float Progress { get; init; }
8+
9+
/// <summary>Gets or sets the total number of items to process (or total progress required), if known.</summary>
10+
public float? Total { get; init; }
11+
12+
/// <summary>Gets or sets an optional message describing the current progress.</summary>
13+
public string? Message { get; init; }
14+
}

src/ModelContextProtocol/Protocol/Messages/NotificationMethods.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,27 @@ public static class NotificationMethods
3434
/// Sent by the server when a log message is generated.
3535
/// </summary>
3636
public const string LoggingMessageNotification = "notifications/message";
37+
38+
/// <summary>
39+
/// Sent from the client to the server after initialization has finished.
40+
/// </summary>
41+
public const string InitializedNotification = "notifications/initialized";
42+
43+
/// <summary>
44+
/// Sent to inform the receiver of a progress update for a long-running request.
45+
/// </summary>
46+
public const string ProgressNotification = "notifications/progress";
47+
48+
/// <summary>
49+
/// Sent by either side to indicate that it is cancelling a previously-issued request.
50+
/// </summary>
51+
/// <remarks>
52+
/// The request SHOULD still be in-flight, but due to communication latency, it is always possible that this notification
53+
/// MAY arrive after the request has already finished.
54+
///
55+
/// This notification indicates that the result will be unused, so any associated processing SHOULD cease.
56+
///
57+
/// A client MUST NOT attempt to cancel its `initialize` request.".
58+
/// </remarks>
59+
public const string CancelledNotification = "notifications/cancelled";
3760
}

src/ModelContextProtocol/Protocol/Messages/OperationNames.cs

Lines changed: 0 additions & 43 deletions
This file was deleted.
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
using System.ComponentModel;
2+
using System.Text.Json;
3+
using System.Text.Json.Serialization;
4+
5+
namespace ModelContextProtocol.Protocol.Messages;
6+
7+
/// <summary>
8+
/// An out-of-band notification used to inform the receiver of a progress update for a long-running request.
9+
/// <see href="https://github.com/modelcontextprotocol/specification/blob/main/schema/">See the schema for details</see>
10+
/// </summary>
11+
[JsonConverter(typeof(Converter))]
12+
public class ProgressNotification
13+
{
14+
/// <summary>
15+
/// The progress token which was given in the initial request, used to associate this notification with the request that is proceeding.
16+
/// </summary>
17+
public required ProgressToken ProgressToken { get; init; }
18+
19+
/// <summary>
20+
/// The progress thus far. This should increase every time progress is made, even if the total is unknown.
21+
/// </summary>
22+
public required ProgressNotificationValue Progress { get; init; }
23+
24+
/// <summary>Provides a <see cref="JsonConverter"/> for <see cref="ProgressNotification"/>.</summary>
25+
[EditorBrowsable(EditorBrowsableState.Never)]
26+
public sealed class Converter : JsonConverter<ProgressNotification>
27+
{
28+
/// <inheritdoc />
29+
public override ProgressNotification? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
30+
{
31+
ProgressToken? progressToken = null;
32+
float? progress = null;
33+
float? total = null;
34+
string? message = null;
35+
36+
while (reader.Read() && reader.TokenType != JsonTokenType.EndObject)
37+
{
38+
if (reader.TokenType == JsonTokenType.PropertyName)
39+
{
40+
var propertyName = reader.GetString();
41+
reader.Read();
42+
switch (propertyName)
43+
{
44+
case "progressToken":
45+
if (JsonSerializer.Deserialize(ref reader, options.GetTypeInfo(typeof(ProgressToken))) is not ProgressToken token)
46+
{
47+
throw new JsonException("Invalid value for 'progressToken'.");
48+
}
49+
progressToken = token;
50+
break;
51+
52+
case "progress":
53+
progress = reader.GetSingle();
54+
break;
55+
56+
case "total":
57+
total = reader.GetSingle();
58+
break;
59+
60+
case "message":
61+
message = reader.GetString();
62+
break;
63+
}
64+
}
65+
}
66+
67+
if (progress is null)
68+
{
69+
throw new JsonException("Missing required property 'progress'.");
70+
}
71+
72+
if (progressToken is null)
73+
{
74+
throw new JsonException("Missing required property 'progressToken'.");
75+
}
76+
77+
return new ProgressNotification
78+
{
79+
ProgressToken = progressToken.GetValueOrDefault(),
80+
Progress = new ProgressNotificationValue()
81+
{
82+
Progress = progress.GetValueOrDefault(),
83+
Total = total,
84+
Message = message
85+
}
86+
};
87+
}
88+
89+
/// <inheritdoc />
90+
public override void Write(Utf8JsonWriter writer, ProgressNotification value, JsonSerializerOptions options)
91+
{
92+
writer.WriteStartObject();
93+
94+
writer.WritePropertyName("progressToken");
95+
JsonSerializer.Serialize(writer, value.ProgressToken, options.GetTypeInfo(typeof(ProgressToken)));
96+
97+
writer.WriteNumber("progress", value.Progress.Progress);
98+
99+
if (value.Progress.Total is { } total)
100+
{
101+
writer.WriteNumber("total", total);
102+
}
103+
104+
if (value.Progress.Message is { } message)
105+
{
106+
writer.WriteString("message", message);
107+
}
108+
109+
writer.WriteEndObject();
110+
}
111+
}
112+
}

0 commit comments

Comments
 (0)