diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs index 7f97e1e29761..c47953e4eebb 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs @@ -168,6 +168,14 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti { transport = HttpTransportType.WebSockets; connection = await GetOrCreateConnectionAsync(context, options); + + if (connection is not null) + { + Log.EstablishedConnection(_logger); + + // Allow the reads to be canceled + connection.Cancellation ??= new CancellationTokenSource(); + } } else { diff --git a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs index 558a974d1830..3389cebfdd9b 100644 --- a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs +++ b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs @@ -2842,6 +2842,37 @@ public async Task WebSocketConnectionClosingTriggersConnectionClosedToken() } } + [Fact] + public async Task ServerClosingClosesWebSocketConnection() + { + using (StartVerifiableLog()) + { + var manager = CreateConnectionManager(LoggerFactory); + var connection = manager.CreateConnection(); + + var dispatcher = CreateDispatcher(manager, LoggerFactory); + var services = new ServiceCollection(); + services.AddSingleton(); + var context = MakeRequest("/foo", connection, services); + SetTransport(context, HttpTransportType.WebSockets); + + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseConnectionHandler(); + var app = builder.Build(); + var options = new HttpConnectionDispatcherOptions(); + options.WebSockets.CloseTimeout = TimeSpan.FromSeconds(1); + + var executeTask = dispatcher.ExecuteAsync(context, options, app); + + // "close" server, since we're not using a server in these tests we just simulate what would be called when the server closes + await connection.DisposeAsync().DefaultTimeout(); + + await connection.ConnectionClosed.WaitForCancellationAsync().DefaultTimeout(); + + await executeTask.DefaultTimeout(); + } + } + public class CustomHttpRequestLifetimeFeature : IHttpRequestLifetimeFeature { public CancellationToken RequestAborted { get; set; }