Skip to content

Add backpressure when rapidly creating new stateful Streamable HTTP sessions without closing them #677

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 67 additions & 9 deletions src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs
Original file line number Diff line number Diff line change
@@ -1,20 +1,27 @@
using ModelContextProtocol.AspNetCore.Stateless;
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Server;
using System.Diagnostics;
using System.Security.Claims;
using System.Threading;

namespace ModelContextProtocol.AspNetCore;

internal sealed class HttpMcpSession<TTransport>(
string sessionId,
TTransport transport,
UserIdClaim? userId,
TimeProvider timeProvider) : IAsyncDisposable
TimeProvider timeProvider,
SemaphoreSlim? idleSessionSemaphore = null) : IAsyncDisposable
where TTransport : ITransport
{
private int _referenceCount;
private int _getRequestStarted;
private CancellationTokenSource _disposeCts = new();
private bool _isDisposed;

private readonly SemaphoreSlim? _idleSessionSemaphore = idleSessionSemaphore;
private readonly CancellationTokenSource _disposeCts = new();
private readonly object _referenceCountLock = new();

public string Id { get; } = sessionId;
public TTransport Transport { get; } = transport;
Expand All @@ -30,16 +37,43 @@ internal sealed class HttpMcpSession<TTransport>(
public IMcpServer? Server { get; set; }
public Task? ServerRunTask { get; set; }

public IDisposable AcquireReference()
public IAsyncDisposable AcquireReference()
{
Interlocked.Increment(ref _referenceCount);
// We don't do idle tracking for stateless sessions, so we don't need to acquire a reference.
if (_idleSessionSemaphore is null)
{
return new NoopDisposable();
}

lock (_referenceCountLock)
{
if (!_isDisposed && ++_referenceCount == 1)
{
// Non-idle sessions should not prevent session creation.
_idleSessionSemaphore.Release();
}
}

return new UnreferenceDisposable(this);
}

public bool TryStartGetRequest() => Interlocked.Exchange(ref _getRequestStarted, 1) == 0;

public async ValueTask DisposeAsync()
{
bool shouldReleaseIdleSessionSemaphore;

lock (_referenceCountLock)
{
if (_isDisposed)
{
return;
}

_isDisposed = true;
shouldReleaseIdleSessionSemaphore = _referenceCount == 0;
}

try
{
await _disposeCts.CancelAsync();
Expand All @@ -65,21 +99,45 @@ public async ValueTask DisposeAsync()
{
await Transport.DisposeAsync();
_disposeCts.Dispose();

// If the session was disposed while it was inactive, we need to release the semaphore
// to allow new sessions to be created.
if (_idleSessionSemaphore is not null && shouldReleaseIdleSessionSemaphore)
{
_idleSessionSemaphore.Release();
}
}
}
}

public bool HasSameUserId(ClaimsPrincipal user)
=> UserIdClaim == StreamableHttpHandler.GetUserIdClaim(user);
public bool HasSameUserId(ClaimsPrincipal user) => UserIdClaim == StreamableHttpHandler.GetUserIdClaim(user);

private sealed class UnreferenceDisposable(HttpMcpSession<TTransport> session) : IDisposable
private sealed class UnreferenceDisposable(HttpMcpSession<TTransport> session) : IAsyncDisposable
{
public void Dispose()
public async ValueTask DisposeAsync()
{
if (Interlocked.Decrement(ref session._referenceCount) == 0)
Debug.Assert(session._idleSessionSemaphore is not null, "Only StreamableHttpHandler should call AcquireReference.");

bool shouldMarkSessionIdle;

lock (session._referenceCountLock)
{
shouldMarkSessionIdle = !session._isDisposed && --session._referenceCount == 0;
}

if (shouldMarkSessionIdle)
{
session.LastActivityTicks = session.TimeProvider.GetTimestamp();

// Acquire semaphore when session becomes inactive (reference count goes to 0) to slow
// down session creation until idle sessions are disposed by the background service.
await session._idleSessionSemaphore.WaitAsync();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this wait forever since there is no cancellationToken?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In theory no, since the session is already marked as idle by virtue of the reference count going to zero, the corresponding call to Release() to unblock this should not be inhibited.

For this to wait, that means we are at least at 110% of the MaxIdleSessionCount. Either idle tracking background service will eventually dispose enough sessions to get it below 110% and call Release() or another request will reactivate the session and call Release(). And any sessions waiting for this call to WaitAsync() to complete are already eligible for pruning if they haven't been reactivated.

}
}
}

private sealed class NoopDisposable : IAsyncDisposable
{
public ValueTask DisposeAsync() => ValueTask.CompletedTask;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ public class HttpServerTransportOptions
/// Past this limit, the server will log a critical error and terminate the oldest idle sessions even if they have not reached
/// their <see cref="IdleTimeout"/> until the idle session count is below this limit. Clients that keep their session open by
/// keeping a GET request open will not count towards this limit.
/// Defaults to 100,000 sessions.
/// Defaults to 10,000 sessions.
/// </remarks>
public int MaxIdleSessionCount { get; set; } = 100_000;
public int MaxIdleSessionCount { get; set; } = 10_000;

/// <summary>
/// Used for testing the <see cref="IdleTimeout"/>.
Expand Down
23 changes: 20 additions & 3 deletions src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ internal sealed class StreamableHttpHandler(

public ConcurrentDictionary<string, HttpMcpSession<StreamableHttpServerTransport>> Sessions { get; } = new(StringComparer.Ordinal);

// Semaphore to control session creation backpressure when there are too many idle sessions
// Initial and max count is 10% more than MaxIdleSessionCount (or 100 more if that's higher)
private readonly SemaphoreSlim _idleSessionSemaphore = CreateIdleSessionSemaphore(httpServerTransportOptions.Value);

public HttpServerTransportOptions HttpServerTransportOptions => httpServerTransportOptions.Value;

private IDataProtector Protector { get; } = dataProtection.CreateProtector("Microsoft.AspNetCore.StreamableHttpHandler.StatelessSessionId");
Expand Down Expand Up @@ -58,7 +62,7 @@ await WriteJsonRpcErrorAsync(context,

try
{
using var _ = session.AcquireReference();
await using var _ = session.AcquireReference();

InitializeSseResponse(context);
var wroteResponse = await session.Transport.HandlePostRequest(new HttpDuplexPipe(context), context.RequestAborted);
Expand Down Expand Up @@ -106,7 +110,7 @@ await WriteJsonRpcErrorAsync(context,
return;
}

using var _ = session.AcquireReference();
await using var _ = session.AcquireReference();
InitializeSseResponse(context);

// We should flush headers to indicate a 200 success quickly, because the initialization response
Expand Down Expand Up @@ -184,6 +188,11 @@ private async ValueTask<HttpMcpSession<StreamableHttpServerTransport>> StartNewS

if (!HttpServerTransportOptions.Stateless)
{
// Acquire semaphore before creating stateful sessions to create backpressure.
// This semaphore represents "slots" for idle sessions, and we may need to wait on the
// IdleTrackingBackgroundService to dispose idle sessions before continuing.
await _idleSessionSemaphore.WaitAsync(context.RequestAborted);

sessionId = MakeNewSessionId();
transport = new()
{
Expand Down Expand Up @@ -248,7 +257,8 @@ private async ValueTask<HttpMcpSession<StreamableHttpServerTransport>> CreateSes
context.Features.Set(server);

var userIdClaim = statelessId?.UserIdClaim ?? GetUserIdClaim(context.User);
var session = new HttpMcpSession<StreamableHttpServerTransport>(sessionId, transport, userIdClaim, HttpServerTransportOptions.TimeProvider)
var semaphore = HttpServerTransportOptions.Stateless ? null : _idleSessionSemaphore;
var session = new HttpMcpSession<StreamableHttpServerTransport>(sessionId, transport, userIdClaim, HttpServerTransportOptions.TimeProvider, semaphore)
{
Server = server,
};
Expand Down Expand Up @@ -337,6 +347,13 @@ private static bool MatchesApplicationJsonMediaType(MediaTypeHeaderValue acceptH
private static bool MatchesTextEventStreamMediaType(MediaTypeHeaderValue acceptHeaderValue)
=> acceptHeaderValue.MatchesMediaType("text/event-stream");

private static SemaphoreSlim CreateIdleSessionSemaphore(HttpServerTransportOptions options)
{
var maxIdleSessionCount = options.MaxIdleSessionCount;
var semaphoreCount = Math.Max(maxIdleSessionCount + 100, (int)(maxIdleSessionCount * 1.1));
return new SemaphoreSlim(semaphoreCount, semaphoreCount);
}

private sealed class HttpDuplexPipe(HttpContext context) : IDuplexPipe
{
public PipeReader Input => context.Request.BodyReader;
Expand Down
2 changes: 1 addition & 1 deletion src/ModelContextProtocol.Core/McpSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,7 @@ private static TimeSpan GetElapsed(long startingTimestamp) =>
[LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} method '{Method}' request handler failed.")]
private partial void LogRequestHandlerException(string endpointName, string method, Exception exception);

[LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} received request for unknown request ID '{RequestId}'.")]
[LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} received message for unknown request ID '{RequestId}'.")]
private partial void LogNoRequestFoundForMessageWithId(string endpointName, RequestId requestId);

[LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} request failed for method '{Method}': {ErrorMessage} ({ErrorCode}).")]
Expand Down
Loading