diff --git a/src/SignalR/server/Core/src/Internal/HubClients`T.cs b/src/SignalR/server/Core/src/Internal/HubClients`T.cs index e168174b6464..4e5a1e45ec21 100644 --- a/src/SignalR/server/Core/src/Internal/HubClients`T.cs +++ b/src/SignalR/server/Core/src/Internal/HubClients`T.cs @@ -15,6 +15,11 @@ public HubClients(HubLifetimeManager lifetimeManager) public T All { get; } + public T Single(string connectionId) + { + return TypedClientBuilder.Build(new SingleClientProxyWithInvoke(_lifetimeManager, connectionId)); + } + public T AllExcept(IReadOnlyList excludedConnectionIds) { return TypedClientBuilder.Build(new AllClientsExceptProxy(_lifetimeManager, excludedConnectionIds)); diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs index db7359370a33..75d7191c302a 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs @@ -530,6 +530,8 @@ public interface Test { Task Send(string message); Task Broadcast(string message); + + Task GetClientResult(int value); } public class OnConnectedThrowsHub : Hub diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs index b179f92acbc5..06d929c910e8 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs @@ -103,4 +103,68 @@ public async Task ThrowsWhenParallelHubInvokesNotEnabled() } } } + + [Fact] + public async Task CanUseClientResultsWithIHubContext() + { + using (StartVerifiableLog()) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(null, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using var client = new TestClient(); + + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + // Wait for a connection, or for the endpoint to fail. + await client.Connected.OrThrowIfOtherFails(connectionHandlerTask).DefaultTimeout(); + + var context = serviceProvider.GetRequiredService>(); + var resultTask = context.Clients.Single(client.Connection.ConnectionId).InvokeAsync("GetClientResult", 1); + + var message = await client.ReadAsync().DefaultTimeout(); + var invocation = Assert.IsType(message); + + Assert.Single(invocation.Arguments); + Assert.Equal(1L, invocation.Arguments[0]); + Assert.Equal("GetClientResult", invocation.Target); + + await client.SendHubMessageAsync(CompletionMessage.WithResult(invocation.InvocationId, 2)).DefaultTimeout(); + + var result = await resultTask.DefaultTimeout(); + Assert.Equal(2, result); + } + } + + [Fact] + public async Task CanUseClientResultsWithIHubContextT() + { + using (StartVerifiableLog()) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(null, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using var client = new TestClient(); + + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + // Wait for a connection, or for the endpoint to fail. + await client.Connected.OrThrowIfOtherFails(connectionHandlerTask).DefaultTimeout(); + + var context = serviceProvider.GetRequiredService>(); + var resultTask = context.Clients.Single(client.Connection.ConnectionId).GetClientResult(1); + + var message = await client.ReadAsync().DefaultTimeout(); + var invocation = Assert.IsType(message); + + Assert.Single(invocation.Arguments); + Assert.Equal(1L, invocation.Arguments[0]); + Assert.Equal("GetClientResult", invocation.Target); + + await client.SendHubMessageAsync(CompletionMessage.WithResult(invocation.InvocationId, 2)).DefaultTimeout(); + + var result = await resultTask.DefaultTimeout(); + Assert.Equal(2, result); + } + } }