diff --git a/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs b/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs index b946fab9490f..4ea62842e762 100644 --- a/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs +++ b/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs @@ -81,22 +81,53 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IServiceProvider } else if (p.CustomAttributes.Any()) { + var markedParameter = false; foreach (var attribute in p.GetCustomAttributes(true)) { if (attribute is IFromServiceMetadata) { - return MarkServiceParameter(index); + ThrowIfMarked(markedParameter); + markedParameter = true; + MarkServiceParameter(index); } else if (attribute is FromKeyedServicesAttribute keyedServicesAttribute) { - if (serviceProviderIsService is IServiceProviderIsKeyedService keyedServiceProvider && - keyedServiceProvider.IsKeyedService(GetServiceType(p.ParameterType), keyedServicesAttribute.Key)) + ThrowIfMarked(markedParameter); + markedParameter = true; + + if (serviceProviderIsService is IServiceProviderIsKeyedService keyedServiceProvider) + { + if (keyedServiceProvider.IsKeyedService(GetServiceType(p.ParameterType), keyedServicesAttribute.Key)) + { + KeyedServiceKeys ??= new List<(int, object)>(); + KeyedServiceKeys.Add((index, keyedServicesAttribute.Key)); + MarkServiceParameter(index); + } + else + { + throw new InvalidOperationException($"'{p.ParameterType}' is not in DI as a keyed service."); + } + } + else { - KeyedServiceKeys ??= new List<(int, object)>(); - KeyedServiceKeys.Add((index, keyedServicesAttribute.Key)); - return MarkServiceParameter(index); + throw new InvalidOperationException($"This service provider doesn't support keyed services."); } } + + void ThrowIfMarked(bool marked) + { + if (marked) + { + throw new InvalidOperationException( + $"{methodExecutor.MethodInfo.DeclaringType?.Name}.{methodExecutor.MethodInfo.Name}: The {nameof(FromKeyedServicesAttribute)} is not supported on parameters that are also annotated with {nameof(IFromServiceMetadata)}."); + } + } + } + + if (markedParameter) + { + // If the parameter is marked because of being a service, we don't want to consider it for method parameters during deserialization + return false; } } else if (serviceProviderIsService?.IsService(GetServiceType(p.ParameterType)) == true) diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs index ca0c77daba92..7e1855a1dd23 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs @@ -1363,12 +1363,6 @@ public async Task ServicesAndParams(int value, [FromService] Service1 servi return total + value; } - public int MultipleSameKeyedServices([FromKeyedServices("service1")] Service1 service, [FromKeyedServices("service1")] Service1 service2) - { - Assert.Same(service, service2); - return 445; - } - public int ServiceWithoutAttribute(Service1 service) { return 1; @@ -1391,6 +1385,15 @@ public async Task Stream(ChannelReader channelReader) await channelReader.ReadAsync(); } } +} + +public class KeyedServicesHub : TestHub +{ + public int MultipleSameKeyedServices([FromKeyedServices("service1")] Service1 service, [FromKeyedServices("service1")] Service1 service2) + { + Assert.Same(service, service2); + return 445; + } public int KeyedService([FromKeyedServices("service1")] Service1 service) { @@ -1414,6 +1417,13 @@ public int MultipleKeyedServices([FromKeyedServices("service1")] Service1 servic } } +public class BadServicesHub : Hub +{ + public void BadMethod([FromKeyedServices("service1")] [FromService] Service1 service) + { + } +} + public class TooManyParamsHub : Hub { public void ManyParams(int a1, string a2, bool a3, float a4, string a5, int a6, int a7, int a8, int a9, int a10, int a11, diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs index fc57c512353c..870efbcd23b2 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs @@ -4927,13 +4927,14 @@ public async Task KeyedServiceResolvedIfInDI() }); provider.AddKeyedScoped("service1"); + provider.AddKeyedScoped("service2"); }); - var connectionHandler = serviceProvider.GetService>(); + var connectionHandler = serviceProvider.GetService>(); using (var client = new TestClient()) { var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); - var res = await client.InvokeAsync(nameof(ServicesHub.KeyedService)).DefaultTimeout(); + var res = await client.InvokeAsync(nameof(KeyedServicesHub.KeyedService)).DefaultTimeout(); Assert.Equal(43L, res.Result); } } @@ -4949,13 +4950,14 @@ public async Task HubMethodCanInjectKeyedServiceWithOtherParameters() }); provider.AddKeyedScoped("service1"); + provider.AddKeyedScoped("service2"); }); - var connectionHandler = serviceProvider.GetService>(); + var connectionHandler = serviceProvider.GetService>(); using (var client = new TestClient()) { var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); - var res = await client.InvokeAsync(nameof(ServicesHub.KeyedServiceWithParam), 91).DefaultTimeout(); + var res = await client.InvokeAsync(nameof(KeyedServicesHub.KeyedServiceWithParam), 91).DefaultTimeout(); Assert.Equal(1183L, res.Result); } } @@ -4971,14 +4973,15 @@ public async Task HubMethodCanInjectKeyedServiceWithNonKeyedService() }); provider.AddKeyedScoped("service1"); + provider.AddKeyedScoped("service2"); provider.AddScoped(); }); - var connectionHandler = serviceProvider.GetService>(); + var connectionHandler = serviceProvider.GetService>(); using (var client = new TestClient()) { var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); - var res = await client.InvokeAsync(nameof(ServicesHub.KeyedServiceNonKeyedService)).DefaultTimeout(); + var res = await client.InvokeAsync(nameof(KeyedServicesHub.KeyedServiceNonKeyedService)).DefaultTimeout(); Assert.Equal(11L, res.Result); } } @@ -4996,12 +4999,12 @@ public async Task MultipleKeyedServicesResolved() provider.AddKeyedScoped("service1"); provider.AddKeyedScoped("service2"); }); - var connectionHandler = serviceProvider.GetService>(); + var connectionHandler = serviceProvider.GetService>(); using (var client = new TestClient()) { var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); - var res = await client.InvokeAsync(nameof(ServicesHub.MultipleKeyedServices)).DefaultTimeout(); + var res = await client.InvokeAsync(nameof(KeyedServicesHub.MultipleKeyedServices)).DefaultTimeout(); Assert.Equal(45L, res.Result); } } @@ -5017,19 +5020,20 @@ public async Task MultipleKeyedServicesWithSameNameResolved() }); provider.AddKeyedScoped("service1"); + provider.AddKeyedScoped("service2"); }); - var connectionHandler = serviceProvider.GetService>(); + var connectionHandler = serviceProvider.GetService>(); using (var client = new TestClient()) { var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); - var res = await client.InvokeAsync(nameof(ServicesHub.MultipleSameKeyedServices)).DefaultTimeout(); + var res = await client.InvokeAsync(nameof(KeyedServicesHub.MultipleSameKeyedServices)).DefaultTimeout(); Assert.Equal(445L, res.Result); } } [Fact] - public async Task KeyedServiceNotResolvedIfNotInDI() + public void KeyedServiceNotResolvedIfNotInDI() { var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider => { @@ -5038,14 +5042,24 @@ public async Task KeyedServiceNotResolvedIfNotInDI() options.EnableDetailedErrors = true; }); }); - var connectionHandler = serviceProvider.GetService>(); + var ex = Assert.Throws(() => serviceProvider.GetService>()); + Assert.Equal("'Microsoft.AspNetCore.SignalR.Tests.Service1' is not in DI as a keyed service.", ex.Message); + } - using (var client = new TestClient()) + [Fact] + public void KeyedServiceAndFromServiceOnSameParameterInvalidWithKeyedServiceInDI() + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider => { - var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); - var res = await client.InvokeAsync(nameof(ServicesHub.KeyedService)).DefaultTimeout(); - Assert.Equal("Failed to invoke 'KeyedService' due to an error on the server. InvalidDataException: Invocation provides 0 argument(s) but target expects 1.", res.Error); - } + provider.AddSignalR(options => + { + options.EnableDetailedErrors = true; + }); + + provider.AddKeyedScoped("service1"); + }); + var ex = Assert.Throws(() => serviceProvider.GetService>()); + Assert.Equal("BadServicesHub.BadMethod: The FromKeyedServicesAttribute is not supported on parameters that are also annotated with IFromServiceMetadata.", ex.Message); } [Fact]