diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs index b5b147d149e6..a3beea775975 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs @@ -172,9 +172,6 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti if (connection is not null) { Log.EstablishedConnection(_logger); - - // Allow the reads to be canceled - connection.Cancellation ??= new CancellationTokenSource(); } } else @@ -198,7 +195,7 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti if (connection.TransportType != HttpTransportType.WebSockets || connection.UseStatefulReconnect) { - if (!await connection.CancelPreviousPoll(context)) + if (connection.ApplicationTask is not null && !await connection.CancelPreviousPoll(context)) { // Connection closed. It's already set the response status code. return; @@ -215,6 +212,9 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti case HttpTransportType.None: break; case HttpTransportType.WebSockets: + // Allow the reads to be canceled + connection.Cancellation ??= new CancellationTokenSource(); + var isReconnect = connection.ApplicationTask is not null; var ws = new WebSocketsServerTransport(options.WebSockets, connection.Application, connection, _loggerFactory); if (!connection.TryActivatePersistentConnection(connectionDelegate, ws, currentRequestTcs.Task, context, _logger)) @@ -376,6 +376,11 @@ private async Task ProcessNegotiate(HttpContext context, HttpConnectionDispatche if (error == null) { connection = CreateConnection(options, clientProtocolVersion, useStatefulReconnect); + if (connection.Status == HttpConnectionStatus.Disposed) + { + // Happens if the server is shutting down when a new negotiate request comes in + error = "The connection was closed before negotiation completed."; + } } // Set the Connection ID on the logging scope so that logs from now on will have the diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs index ba658aa38ac7..9bb6b90769ed 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs @@ -26,6 +26,9 @@ internal sealed partial class HttpConnectionManager private readonly ILogger _connectionLogger; private readonly TimeSpan _disconnectTimeout; private readonly HttpConnectionsMetrics _metrics; + private readonly IHostApplicationLifetime _applicationLifetime; + private readonly Lock _closeLock = new(); + private bool _closed; public HttpConnectionManager(ILoggerFactory loggerFactory, IHostApplicationLifetime appLifetime, IOptions connectionOptions, HttpConnectionsMetrics metrics) { @@ -34,6 +37,7 @@ public HttpConnectionManager(ILoggerFactory loggerFactory, IHostApplicationLifet _nextHeartbeat = new PeriodicTimer(_heartbeatTickRate); _disconnectTimeout = connectionOptions.Value.DisconnectTimeout ?? ConnectionOptionsSetup.DefaultDisconectTimeout; _metrics = metrics; + _applicationLifetime = appLifetime; // Register these last as the callbacks could run immediately appLifetime.ApplicationStarted.Register(Start); @@ -82,6 +86,12 @@ internal HttpConnectionContext CreateConnection(HttpConnectionDispatcherOptions _connections.TryAdd(connectionToken, connection); + // If the application is stopping don't allow new connections to be created + if (_applicationLifetime.ApplicationStopping.IsCancellationRequested || _closed) + { + CloseConnections(); + } + return connection; } @@ -184,20 +194,28 @@ public void Scan() public void CloseConnections() { - // Stop firing the timer - _nextHeartbeat.Dispose(); + lock (_closeLock) + { + if (!_closed) + { + // Stop firing the timer + _nextHeartbeat.Dispose(); - var tasks = new List(_connections.Count); + _closed = true; + } - // REVIEW: In the future we can consider a hybrid where we first try to wait for shutdown - // for a certain time frame then after some grace period we shutdown more aggressively - foreach (var c in _connections) - { - // We're shutting down so don't wait for closing the application - tasks.Add(DisposeAndRemoveAsync(c.Value, closeGracefully: false, HttpConnectionStopStatus.AppShutdown)); - } + var tasks = new List(_connections.Count); + + // REVIEW: In the future we can consider a hybrid where we first try to wait for shutdown + // for a certain time frame then after some grace period we shutdown more aggressively + foreach (var c in _connections) + { + // We're shutting down so don't wait for closing the application + tasks.Add(DisposeAndRemoveAsync(c.Value, closeGracefully: false, HttpConnectionStopStatus.AppShutdown)); + } - Task.WaitAll(tasks.ToArray(), TimeSpan.FromSeconds(5)); + Task.WaitAll(tasks.ToArray(), TimeSpan.FromSeconds(5)); + } } internal async Task DisposeAndRemoveAsync(HttpConnectionContext connection, bool closeGracefully, HttpConnectionStopStatus status) diff --git a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs index 3ba5e15984bf..34dc370e37eb 100644 --- a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs +++ b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs @@ -194,6 +194,38 @@ public async Task NoNegotiateVersionInQueryStringThrowsWhenMinProtocolVersionIsS } } + [Fact] + public async Task NegotiateAfterApplicationStoppingReturnsError() + { + using (StartVerifiableLog()) + { + var appLifetime = new TestApplicationLifetime(); + var manager = CreateConnectionManager(LoggerFactory, appLifetime); + appLifetime.Start(); + + appLifetime.StopApplication(); + + var dispatcher = CreateDispatcher(manager, LoggerFactory); + var context = new DefaultHttpContext(); + var services = new ServiceCollection(); + services.AddSingleton(); + services.AddOptions(); + var ms = new MemoryStream(); + context.Request.Path = "/foo"; + context.Request.Method = "POST"; + context.Response.Body = ms; + context.Request.QueryString = new QueryString(""); + await dispatcher.ExecuteNegotiateAsync(context, new HttpConnectionDispatcherOptions()); + var negotiateResponse = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(ms.ToArray())); + + var error = negotiateResponse.Value("error"); + Assert.Equal("The connection was closed before negotiation completed.", error); + + var connectionId = negotiateResponse.Value("connectionId"); + Assert.Null(connectionId); + } + } + [Theory] [InlineData(HttpTransportType.LongPolling)] [InlineData(HttpTransportType.ServerSentEvents)] @@ -2517,6 +2549,120 @@ public async Task DisableReconnectDisallowsReplacementConnection() } } + [Fact] + public async Task StatefulReconnectionConnectionClosesOnApplicationStopping() + { + // ReconnectConnectionHandler can throw OperationCanceledException during Pipe.ReadAsync + using (StartVerifiableLog(wc => wc.EventId.Name == "FailedDispose")) + { + var appLifetime = new TestApplicationLifetime(); + var manager = CreateConnectionManager(LoggerFactory, appLifetime); + var options = new HttpConnectionDispatcherOptions() { AllowStatefulReconnects = true }; + options.WebSockets.CloseTimeout = TimeSpan.FromMilliseconds(1); + // pretend negotiate occurred + var connection = manager.CreateConnection(options, negotiateVersion: 1, useStatefulReconnect: true); + connection.TransportType = HttpTransportType.WebSockets; + + var dispatcher = CreateDispatcher(manager, LoggerFactory); + var services = new ServiceCollection(); + + var context = MakeRequest("/foo", connection, services); + SetTransport(context, HttpTransportType.WebSockets); + + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseConnectionHandler(); + var app = builder.Build(); + + var initialWebSocketTask = dispatcher.ExecuteAsync(context, options, app); + +#pragma warning disable CA2252 // This API requires opting into preview features + var reconnectFeature = connection.Features.Get(); +#pragma warning restore CA2252 // This API requires opting into preview features + Assert.NotNull(reconnectFeature); + + var websocketFeature = (TestWebSocketConnectionFeature)context.Features.Get(); + await websocketFeature.Accepted.DefaultTimeout(); + + // Stop application should cause the connection to close and new connection attempts to fail + appLifetime.StopApplication(); + var webSocketMessage = await websocketFeature.Client.GetNextMessageAsync().DefaultTimeout(); + + Assert.Equal(WebSocketCloseStatus.NormalClosure, webSocketMessage.CloseStatus); + + await initialWebSocketTask.DefaultTimeout(); + + // New websocket connection with previous connection token + context = MakeRequest("/foo", connection, services); + SetTransport(context, HttpTransportType.WebSockets); + + await dispatcher.ExecuteAsync(context, options, app).DefaultTimeout(); + + // Should complete immediately with 404 as the connection is closed + Assert.Equal(404, context.Response.StatusCode); + } + } + + [Fact] + public async Task StatefulReconnectionConnectionThatReconnectedClosesOnApplicationStopping() + { + // ReconnectConnectionHandler can throw OperationCanceledException during Pipe.ReadAsync + using (StartVerifiableLog(wc => wc.EventId.Name == "FailedDispose")) + { + var appLifetime = new TestApplicationLifetime(); + var manager = CreateConnectionManager(LoggerFactory, appLifetime); + var options = new HttpConnectionDispatcherOptions() { AllowStatefulReconnects = true }; + options.WebSockets.CloseTimeout = TimeSpan.FromMilliseconds(1); + // pretend negotiate occurred + var connection = manager.CreateConnection(options, negotiateVersion: 1, useStatefulReconnect: true); + connection.TransportType = HttpTransportType.WebSockets; + + var dispatcher = CreateDispatcher(manager, LoggerFactory); + var services = new ServiceCollection(); + + var context = MakeRequest("/foo", connection, services); + SetTransport(context, HttpTransportType.WebSockets); + + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseConnectionHandler(); + var app = builder.Build(); + + var initialWebSocketTask = dispatcher.ExecuteAsync(context, options, app); + +#pragma warning disable CA2252 // This API requires opting into preview features + var reconnectFeature = connection.Features.Get(); +#pragma warning restore CA2252 // This API requires opting into preview features + Assert.NotNull(reconnectFeature); + + var websocketFeature = (TestWebSocketConnectionFeature)context.Features.Get(); + await websocketFeature.Accepted.DefaultTimeout(); + + // New websocket connection with previous connection token + context = MakeRequest("/foo", connection, services); + SetTransport(context, HttpTransportType.WebSockets); + + var secondWebSocketTask = dispatcher.ExecuteAsync(context, options, app).DefaultTimeout(); + + await initialWebSocketTask.DefaultTimeout(); + + // Stop application should cause the connection to close and new connection attempts to fail + appLifetime.StopApplication(); + var webSocketMessage = await websocketFeature.Client.GetNextMessageAsync().DefaultTimeout(); + + Assert.Equal(WebSocketCloseStatus.NormalClosure, webSocketMessage.CloseStatus); + + await secondWebSocketTask.DefaultTimeout(); + + // New websocket connection with previous connection token + context = MakeRequest("/foo", connection, services); + SetTransport(context, HttpTransportType.WebSockets); + + await dispatcher.ExecuteAsync(context, options, app).DefaultTimeout(); + + // Should complete immediately with 404 as the connection is closed + Assert.Equal(404, context.Response.StatusCode); + } + } + private class ControllableMemoryStream : MemoryStream { private readonly SyncPoint _syncPoint; @@ -3766,18 +3912,24 @@ private static void SetTransport(HttpContext context, HttpTransportType transpor } } + private static HttpConnectionManager CreateConnectionManager(ILoggerFactory loggerFactory, IHostApplicationLifetime hostApplicationLifetime) + { + return CreateConnectionManager(loggerFactory, null, null, hostApplicationLifetime); + } + private static HttpConnectionManager CreateConnectionManager(ILoggerFactory loggerFactory, HttpConnectionsMetrics metrics = null) { return CreateConnectionManager(loggerFactory, null, metrics); } - private static HttpConnectionManager CreateConnectionManager(ILoggerFactory loggerFactory, TimeSpan? disconnectTimeout, HttpConnectionsMetrics metrics = null) + private static HttpConnectionManager CreateConnectionManager(ILoggerFactory loggerFactory, TimeSpan? disconnectTimeout, + HttpConnectionsMetrics metrics = null, IHostApplicationLifetime hostApplicationLifetime = null) { var connectionOptions = new ConnectionOptions(); connectionOptions.DisconnectTimeout = disconnectTimeout; return new HttpConnectionManager( loggerFactory ?? new LoggerFactory(), - new EmptyApplicationLifetime(), + hostApplicationLifetime ?? new EmptyApplicationLifetime(), Options.Create(connectionOptions), metrics ?? new HttpConnectionsMetrics(new TestMeterFactory())); } diff --git a/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs b/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs index 529b372c65f9..bbf2998f3a46 100644 --- a/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs +++ b/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs @@ -425,6 +425,26 @@ public async Task ApplicationLifetimeCanStartBeforeHttpConnectionManagerInitiali } } + [Fact] + public async Task ApplicationLifetimeStoppingApplicationStopsNewIncomingConnections() + { + using (StartVerifiableLog()) + { + var appLifetime = new TestApplicationLifetime(); + var connectionManager = CreateConnectionManager(LoggerFactory, appLifetime); + + appLifetime.Start(); + + appLifetime.StopApplication(); + + var connection = connectionManager.CreateConnection(); + + Assert.Equal(HttpConnectionStatus.Disposed, connection.Status); + var result = await connection.Application.Output.FlushAsync(); + Assert.True(result.IsCompleted); + } + } + private static HttpConnectionManager CreateConnectionManager(ILoggerFactory loggerFactory, IHostApplicationLifetime lifetime = null, HttpConnectionsMetrics metrics = null) { lifetime ??= new EmptyApplicationLifetime();