Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,22 +81,53 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IServiceProvider
}
else if (p.CustomAttributes.Any())
{
var markedParameter = false;
foreach (var attribute in p.GetCustomAttributes(true))
{
if (attribute is IFromServiceMetadata)
{
return MarkServiceParameter(index);
ThrowIfMarked(markedParameter);
markedParameter = true;
MarkServiceParameter(index);
}
else if (attribute is FromKeyedServicesAttribute keyedServicesAttribute)
{
if (serviceProviderIsService is IServiceProviderIsKeyedService keyedServiceProvider &&
keyedServiceProvider.IsKeyedService(GetServiceType(p.ParameterType), keyedServicesAttribute.Key))
ThrowIfMarked(markedParameter);
markedParameter = true;

if (serviceProviderIsService is IServiceProviderIsKeyedService keyedServiceProvider)
{
if (keyedServiceProvider.IsKeyedService(GetServiceType(p.ParameterType), keyedServicesAttribute.Key))
{
KeyedServiceKeys ??= new List<(int, object)>();
KeyedServiceKeys.Add((index, keyedServicesAttribute.Key));
MarkServiceParameter(index);
}
else
{
throw new InvalidOperationException($"'{p.ParameterType}' is not in DI as a keyed service.");
}
}
else
{
KeyedServiceKeys ??= new List<(int, object)>();
KeyedServiceKeys.Add((index, keyedServicesAttribute.Key));
return MarkServiceParameter(index);
throw new InvalidOperationException($"This service provider doesn't support keyed services.");
}
}

void ThrowIfMarked(bool marked)
{
if (marked)
{
throw new InvalidOperationException(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should unify this with the exceptions throw in MVC/minimal APIs?

RE:

throw new NotSupportedException(

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That one doesn't mention which method is the culprit, maybe because it can be an unnamed delegate?

I'd like to keep the name for SignalR so it's very clear where the issue is. Maybe just add it at the beginning:
"{methodExecutor.MethodInfo.FullName}: The {nameof(FromKeyedServicesAttribute)} is not supported on parameters that are also annotated with {nameof(IFromServiceMetadata)}."

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the exception message

$"{methodExecutor.MethodInfo.DeclaringType?.Name}.{methodExecutor.MethodInfo.Name}: The {nameof(FromKeyedServicesAttribute)} is not supported on parameters that are also annotated with {nameof(IFromServiceMetadata)}.");
}
}
}

if (markedParameter)
{
// If the parameter is marked because of being a service, we don't want to consider it for method parameters during deserialization
return false;
}
}
else if (serviceProviderIsService?.IsService(GetServiceType(p.ParameterType)) == true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1363,12 +1363,6 @@ public async Task<int> ServicesAndParams(int value, [FromService] Service1 servi
return total + value;
}

public int MultipleSameKeyedServices([FromKeyedServices("service1")] Service1 service, [FromKeyedServices("service1")] Service1 service2)
{
Assert.Same(service, service2);
return 445;
}

public int ServiceWithoutAttribute(Service1 service)
{
return 1;
Expand All @@ -1391,6 +1385,15 @@ public async Task Stream(ChannelReader<int> channelReader)
await channelReader.ReadAsync();
}
}
}

public class KeyedServicesHub : TestHub
{
public int MultipleSameKeyedServices([FromKeyedServices("service1")] Service1 service, [FromKeyedServices("service1")] Service1 service2)
{
Assert.Same(service, service2);
return 445;
}

public int KeyedService([FromKeyedServices("service1")] Service1 service)
{
Expand All @@ -1414,6 +1417,13 @@ public int MultipleKeyedServices([FromKeyedServices("service1")] Service1 servic
}
}

public class BadServicesHub : Hub
{
public void BadMethod([FromKeyedServices("service1")] [FromService] Service1 service)
{
}
}

public class TooManyParamsHub : Hub
{
public void ManyParams(int a1, string a2, bool a3, float a4, string a5, int a6, int a7, int a8, int a9, int a10, int a11,
Expand Down
48 changes: 31 additions & 17 deletions src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4927,13 +4927,14 @@ public async Task KeyedServiceResolvedIfInDI()
});

provider.AddKeyedScoped<Service1>("service1");
provider.AddKeyedScoped<Service1>("service2");
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<ServicesHub>>();
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<KeyedServicesHub>>();

using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout();
var res = await client.InvokeAsync(nameof(ServicesHub.KeyedService)).DefaultTimeout();
var res = await client.InvokeAsync(nameof(KeyedServicesHub.KeyedService)).DefaultTimeout();
Assert.Equal(43L, res.Result);
}
}
Expand All @@ -4949,13 +4950,14 @@ public async Task HubMethodCanInjectKeyedServiceWithOtherParameters()
});

provider.AddKeyedScoped<Service1>("service1");
provider.AddKeyedScoped<Service1>("service2");
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<ServicesHub>>();
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<KeyedServicesHub>>();

using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout();
var res = await client.InvokeAsync(nameof(ServicesHub.KeyedServiceWithParam), 91).DefaultTimeout();
var res = await client.InvokeAsync(nameof(KeyedServicesHub.KeyedServiceWithParam), 91).DefaultTimeout();
Assert.Equal(1183L, res.Result);
}
}
Expand All @@ -4971,14 +4973,15 @@ public async Task HubMethodCanInjectKeyedServiceWithNonKeyedService()
});

provider.AddKeyedScoped<Service1>("service1");
provider.AddKeyedScoped<Service1>("service2");
provider.AddScoped<Service2>();
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<ServicesHub>>();
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<KeyedServicesHub>>();

using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout();
var res = await client.InvokeAsync(nameof(ServicesHub.KeyedServiceNonKeyedService)).DefaultTimeout();
var res = await client.InvokeAsync(nameof(KeyedServicesHub.KeyedServiceNonKeyedService)).DefaultTimeout();
Assert.Equal(11L, res.Result);
}
}
Expand All @@ -4996,12 +4999,12 @@ public async Task MultipleKeyedServicesResolved()
provider.AddKeyedScoped<Service1>("service1");
provider.AddKeyedScoped<Service1>("service2");
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<ServicesHub>>();
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<KeyedServicesHub>>();

using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout();
var res = await client.InvokeAsync(nameof(ServicesHub.MultipleKeyedServices)).DefaultTimeout();
var res = await client.InvokeAsync(nameof(KeyedServicesHub.MultipleKeyedServices)).DefaultTimeout();
Assert.Equal(45L, res.Result);
}
}
Expand All @@ -5017,19 +5020,20 @@ public async Task MultipleKeyedServicesWithSameNameResolved()
});

provider.AddKeyedScoped<Service1>("service1");
provider.AddKeyedScoped<Service1>("service2");
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<ServicesHub>>();
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<KeyedServicesHub>>();

using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout();
var res = await client.InvokeAsync(nameof(ServicesHub.MultipleSameKeyedServices)).DefaultTimeout();
var res = await client.InvokeAsync(nameof(KeyedServicesHub.MultipleSameKeyedServices)).DefaultTimeout();
Assert.Equal(445L, res.Result);
}
}

[Fact]
public async Task KeyedServiceNotResolvedIfNotInDI()
public void KeyedServiceNotResolvedIfNotInDI()
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider =>
{
Expand All @@ -5038,14 +5042,24 @@ public async Task KeyedServiceNotResolvedIfNotInDI()
options.EnableDetailedErrors = true;
});
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<ServicesHub>>();
var ex = Assert.Throws<InvalidOperationException>(() => serviceProvider.GetService<HubConnectionHandler<KeyedServicesHub>>());
Assert.Equal("'Microsoft.AspNetCore.SignalR.Tests.Service1' is not in DI as a keyed service.", ex.Message);
}

using (var client = new TestClient())
[Fact]
public void KeyedServiceAndFromServiceOnSameParameterInvalidWithKeyedServiceInDI()
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider =>
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout();
var res = await client.InvokeAsync(nameof(ServicesHub.KeyedService)).DefaultTimeout();
Assert.Equal("Failed to invoke 'KeyedService' due to an error on the server. InvalidDataException: Invocation provides 0 argument(s) but target expects 1.", res.Error);
}
provider.AddSignalR(options =>
{
options.EnableDetailedErrors = true;
});

provider.AddKeyedScoped<Service1>("service1");
});
var ex = Assert.Throws<InvalidOperationException>(() => serviceProvider.GetService<HubConnectionHandler<BadServicesHub>>());
Assert.Equal("BadServicesHub.BadMethod: The FromKeyedServicesAttribute is not supported on parameters that are also annotated with IFromServiceMetadata.", ex.Message);
}

[Fact]
Expand Down