diff --git a/src/SignalR/server/Core/src/IHubClients`T.cs b/src/SignalR/server/Core/src/IHubClients`T.cs index 0dee6f33b19a..06dce7a609cd 100644 --- a/src/SignalR/server/Core/src/IHubClients`T.cs +++ b/src/SignalR/server/Core/src/IHubClients`T.cs @@ -14,7 +14,7 @@ public interface IHubClients /// /// The connection ID. /// A client caller. - T Single(string connectionId) => throw new NotImplementedException(); + T Single(string connectionId) => Client(connectionId); /// /// Gets a that can be used to invoke methods on all clients connected to the hub. diff --git a/src/SignalR/server/Core/src/Internal/HubClients.cs b/src/SignalR/server/Core/src/Internal/HubClients.cs index 9c87ef5388d4..9ddaee52ebc2 100644 --- a/src/SignalR/server/Core/src/Internal/HubClients.cs +++ b/src/SignalR/server/Core/src/Internal/HubClients.cs @@ -15,7 +15,7 @@ public HubClients(HubLifetimeManager lifetimeManager) public ISingleClientProxy Single(string connectionId) { - return new SingleClientProxyWithInvoke(_lifetimeManager, connectionId); + return new SingleClientProxy(_lifetimeManager, connectionId); } public IClientProxy All { get; } diff --git a/src/SignalR/server/Core/src/Internal/HubClients`T.cs b/src/SignalR/server/Core/src/Internal/HubClients`T.cs index 4e5a1e45ec21..e168174b6464 100644 --- a/src/SignalR/server/Core/src/Internal/HubClients`T.cs +++ b/src/SignalR/server/Core/src/Internal/HubClients`T.cs @@ -15,11 +15,6 @@ public HubClients(HubLifetimeManager lifetimeManager) public T All { get; } - public T Single(string connectionId) - { - return TypedClientBuilder.Build(new SingleClientProxyWithInvoke(_lifetimeManager, connectionId)); - } - public T AllExcept(IReadOnlyList excludedConnectionIds) { return TypedClientBuilder.Build(new AllClientsExceptProxy(_lifetimeManager, excludedConnectionIds)); diff --git a/src/SignalR/server/Core/src/Internal/Proxies.cs b/src/SignalR/server/Core/src/Internal/Proxies.cs index 440a0504064f..46f42beead3c 100644 --- a/src/SignalR/server/Core/src/Internal/Proxies.cs +++ b/src/SignalR/server/Core/src/Internal/Proxies.cs @@ -122,23 +122,6 @@ public Task SendCoreAsync(string method, object?[] args, CancellationToken cance } } -internal sealed class SingleClientProxy : IClientProxy where THub : Hub -{ - private readonly string _connectionId; - private readonly HubLifetimeManager _lifetimeManager; - - public SingleClientProxy(HubLifetimeManager lifetimeManager, string connectionId) - { - _lifetimeManager = lifetimeManager; - _connectionId = connectionId; - } - - public Task SendCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) - { - return _lifetimeManager.SendConnectionAsync(_connectionId, method, args, cancellationToken); - } -} - internal sealed class MultipleClientProxy : IClientProxy where THub : Hub { private readonly HubLifetimeManager _lifetimeManager; @@ -156,12 +139,12 @@ public Task SendCoreAsync(string method, object?[] args, CancellationToken cance } } -internal sealed class SingleClientProxyWithInvoke : ISingleClientProxy where THub : Hub +internal sealed class SingleClientProxy : ISingleClientProxy where THub : Hub { private readonly string _connectionId; private readonly HubLifetimeManager _lifetimeManager; - public SingleClientProxyWithInvoke(HubLifetimeManager lifetimeManager, string connectionId) + public SingleClientProxy(HubLifetimeManager lifetimeManager, string connectionId) { _lifetimeManager = lifetimeManager; _connectionId = connectionId; diff --git a/src/SignalR/server/Core/src/Internal/TypedHubClients.cs b/src/SignalR/server/Core/src/Internal/TypedHubClients.cs index 3a8e68ae3853..fa7023ac332b 100644 --- a/src/SignalR/server/Core/src/Internal/TypedHubClients.cs +++ b/src/SignalR/server/Core/src/Internal/TypedHubClients.cs @@ -12,7 +12,7 @@ public TypedHubClients(IHubCallerClients dynamicContext) _hubClients = dynamicContext; } - public T Single(string connectionId) => TypedClientBuilder.Build(_hubClients.Single(connectionId)); + public T Client(string connectionId) => TypedClientBuilder.Build(_hubClients.Single(connectionId)); public T All => TypedClientBuilder.Build(_hubClients.All); @@ -22,11 +22,6 @@ public TypedHubClients(IHubCallerClients dynamicContext) public T AllExcept(IReadOnlyList excludedConnectionIds) => TypedClientBuilder.Build(_hubClients.AllExcept(excludedConnectionIds)); - public T Client(string connectionId) - { - return TypedClientBuilder.Build(_hubClients.Client(connectionId)); - } - public T Group(string groupName) { return TypedClientBuilder.Build(_hubClients.Group(groupName)); diff --git a/src/SignalR/server/SignalR/test/ClientProxyTests.cs b/src/SignalR/server/SignalR/test/ClientProxyTests.cs index ede98b1c505a..784fc98bf01f 100644 --- a/src/SignalR/server/SignalR/test/ClientProxyTests.cs +++ b/src/SignalR/server/SignalR/test/ClientProxyTests.cs @@ -212,7 +212,7 @@ public async Task SingleClientProxyWithInvoke_ThrowsNotSupported() { var hubLifetimeManager = new EmptyHubLifetimeManager(); - var proxy = new SingleClientProxyWithInvoke(hubLifetimeManager, ""); + var proxy = new SingleClientProxy(hubLifetimeManager, ""); var ex = await Assert.ThrowsAsync(async () => await proxy.InvokeAsync("method")).DefaultTimeout(); Assert.Equal("EmptyHubLifetimeManager`1 does not support client return values.", ex.Message); } diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs index 75d7191c302a..f72cd15f3b06 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs @@ -440,7 +440,7 @@ public Task SendToCaller(string message) } } -public class HubT : Hub +public class HubT : Hub { public override Task OnConnectedAsync() { @@ -524,9 +524,15 @@ public Task SendToCaller(string message) { return Clients.Caller.Send(message); } + + public async Task GetClientResultThreeWays(int singleValue, int clientValue, int callerValue) => + new ClientResults( + await Clients.Single(Context.ConnectionId).GetClientResult(singleValue), + await Clients.Client(Context.ConnectionId).GetClientResult(clientValue), + await Clients.Caller.GetClientResult(callerValue)); } -public interface Test +public interface ITest { Task Send(string message); Task Broadcast(string message); @@ -534,6 +540,8 @@ public interface Test Task GetClientResult(int value); } +public record ClientResults(int SingleResult, int ClientResult, int CallerResult); + public class OnConnectedThrowsHub : Hub { public override Task OnConnectedAsync() diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs index 06d929c910e8..8fdf84cc68e5 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs @@ -145,26 +145,75 @@ public async Task CanUseClientResultsWithIHubContextT() var connectionHandler = serviceProvider.GetService>(); using var client = new TestClient(); + var connectionId = client.Connection.ConnectionId; var connectionHandlerTask = await client.ConnectAsync(connectionHandler); // Wait for a connection, or for the endpoint to fail. await client.Connected.OrThrowIfOtherFails(connectionHandlerTask).DefaultTimeout(); - var context = serviceProvider.GetRequiredService>(); - var resultTask = context.Clients.Single(client.Connection.ConnectionId).GetClientResult(1); + var context = serviceProvider.GetRequiredService>(); - var message = await client.ReadAsync().DefaultTimeout(); - var invocation = Assert.IsType(message); + async Task AssertClientResult(Task resultTask) + { + var message = await client.ReadAsync().DefaultTimeout(); + var invocation = Assert.IsType(message); - Assert.Single(invocation.Arguments); - Assert.Equal(1L, invocation.Arguments[0]); - Assert.Equal("GetClientResult", invocation.Target); + Assert.Single(invocation.Arguments); + Assert.Equal(1L, invocation.Arguments[0]); + Assert.Equal("GetClientResult", invocation.Target); - await client.SendHubMessageAsync(CompletionMessage.WithResult(invocation.InvocationId, 2)).DefaultTimeout(); + await client.SendHubMessageAsync(CompletionMessage.WithResult(invocation.InvocationId, 2)).DefaultTimeout(); - var result = await resultTask.DefaultTimeout(); - Assert.Equal(2, result); + var result = await resultTask.DefaultTimeout(); + Assert.Equal(2, result); + } + + await AssertClientResult(context.Clients.Single(connectionId).GetClientResult(1)); + await AssertClientResult(context.Clients.Client(connectionId).GetClientResult(1)); } } + + [Fact] + public async Task CanReturnClientResultToTypedHubThreeWays() + { + using (StartVerifiableLog()) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => + { + // Waiting for a client result blocks the hub dispatcher pipeline, need to allow multiple invocations + builder.AddSignalR(o => o.MaximumParallelInvocationsPerClient = 2); + }, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using var client = new TestClient(invocationBinder: new GetClientResultThreeWaysInvocationBinder()); + + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + + var invocationId = await client.SendHubMessageAsync(new InvocationMessage( + invocationId: "1", + nameof(HubT.GetClientResultThreeWays), + new object[] { 5, 6, 7 })).DefaultTimeout(); + + // Send back "value + 4" to all three invocations. + for (int i = 0; i < 3; i++) + { + // Hub asks client for a result, this is an invocation message with an ID. + var invocationMessage = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.NotNull(invocationMessage.InvocationId); + var res = 4 + (int)invocationMessage.Arguments[0]; + await client.SendHubMessageAsync(CompletionMessage.WithResult(invocationMessage.InvocationId, res)).DefaultTimeout(); + } + + var completion = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.Equal(new ClientResults(9, 10, 11), completion.Result); + } + } + + private class GetClientResultThreeWaysInvocationBinder : IInvocationBinder + { + public IReadOnlyList GetParameterTypes(string methodName) => new[] { typeof(int) }; + public Type GetReturnType(string invocationId) => typeof(ClientResults); + public Type GetStreamItemType(string streamId) => throw new NotImplementedException(); + } }