Skip to content

Commit 94e317f

Browse files
committed
PR feedback
1 parent d4e1645 commit 94e317f

File tree

2 files changed

+88
-4
lines changed

2 files changed

+88
-4
lines changed

src/CORS/src/Microsoft.AspNetCore.Cors/Infrastructure/CorsMiddleware.cs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,13 @@ public Task Invoke(HttpContext context, ICorsPolicyProvider corsPolicyProvider)
125125

126126
private async Task InvokeCore(HttpContext context, ICorsPolicyProvider corsPolicyProvider)
127127
{
128+
// CORS policy resolution rules:
129+
//
130+
// 1. If there is an endpoint with IDisableCorsAttribute then CORS is not run
131+
// 2. If there is an endpoint with IEnableCorsAttribute that has a policy name then
132+
// fetch policy by name, prioritizing it above policy on middleware
133+
// 3. If there is no policy on middleware then use name on middleware
134+
128135
var endpoint = context.GetEndpoint();
129136

130137
// Get the most significant CORS metadata for the endpoint
@@ -136,13 +143,22 @@ private async Task InvokeCore(HttpContext context, ICorsPolicyProvider corsPolic
136143
return;
137144
}
138145

139-
string policyName = _corsPolicyName;
140-
if (corsMetadata is IEnableCorsAttribute enableCorsAttribute)
146+
var corsPolicy = _policy;
147+
var policyName = _corsPolicyName;
148+
if (corsMetadata is IEnableCorsAttribute enableCorsAttribute &&
149+
enableCorsAttribute.PolicyName != null)
141150
{
151+
// If a policy name has been provided on the endpoint metadata then prioritizing it above the static middleware policy
142152
policyName = enableCorsAttribute.PolicyName;
153+
corsPolicy = null;
154+
}
155+
156+
if (corsPolicy == null)
157+
{
158+
// Resolve policy by name if the local policy is not being used
159+
corsPolicy = await corsPolicyProvider.GetPolicyAsync(context, policyName);
143160
}
144161

145-
var corsPolicy = _policy ?? await corsPolicyProvider.GetPolicyAsync(context, policyName);
146162
if (corsPolicy == null)
147163
{
148164
Logger?.NoCorsPolicyFound();

src/CORS/test/Microsoft.AspNetCore.Cors.Test/CorsMiddlewareTests.cs

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -597,7 +597,7 @@ public async Task Invoke_HasEndpointWithNoMetadata_RunsCors()
597597
}
598598

599599
[Fact]
600-
public async Task Invoke_HasEndpointWithEnableMetadata_RunsCorsWithPolicyName()
600+
public async Task Invoke_HasEndpointWithEnableMetadata_MiddlewareHasPolicyName_RunsCorsWithPolicyName()
601601
{
602602
// Arrange
603603
var corsService = Mock.Of<ICorsService>();
@@ -626,6 +626,74 @@ public async Task Invoke_HasEndpointWithEnableMetadata_RunsCorsWithPolicyName()
626626
Times.Once);
627627
}
628628

629+
[Fact]
630+
public async Task Invoke_HasEndpointWithEnableMetadata_MiddlewareHasPolicy_RunsCorsWithPolicyName()
631+
{
632+
// Arrange
633+
var policy = new CorsPolicyBuilder().Build();
634+
var corsService = Mock.Of<ICorsService>();
635+
var mockProvider = new Mock<ICorsPolicyProvider>();
636+
var loggerFactory = NullLoggerFactory.Instance;
637+
mockProvider.Setup(o => o.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()))
638+
.Returns(Task.FromResult<CorsPolicy>(null))
639+
.Verifiable();
640+
641+
var middleware = new CorsMiddleware(
642+
Mock.Of<RequestDelegate>(),
643+
corsService,
644+
policy,
645+
loggerFactory);
646+
647+
var httpContext = new DefaultHttpContext();
648+
httpContext.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableCorsAttribute("MetadataPolicyName")), "Test endpoint"));
649+
httpContext.Request.Headers.Add(CorsConstants.Origin, new[] { "http://example.com" });
650+
651+
// Act
652+
await middleware.Invoke(httpContext, mockProvider.Object);
653+
654+
// Assert
655+
mockProvider.Verify(
656+
o => o.GetPolicyAsync(It.IsAny<HttpContext>(), "MetadataPolicyName"),
657+
Times.Once);
658+
}
659+
660+
[Fact]
661+
public async Task Invoke_HasEndpointWithEnableMetadataWithNoName_RunsCorsWithStaticPolicy()
662+
{
663+
// Arrange
664+
var policy = new CorsPolicyBuilder().Build();
665+
var mockCorsService = new Mock<ICorsService>();
666+
var mockProvider = new Mock<ICorsPolicyProvider>();
667+
var loggerFactory = NullLoggerFactory.Instance;
668+
mockProvider.Setup(o => o.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()))
669+
.Returns(Task.FromResult<CorsPolicy>(null))
670+
.Verifiable();
671+
mockCorsService.Setup(o => o.EvaluatePolicy(It.IsAny<HttpContext>(), It.IsAny<CorsPolicy>()))
672+
.Returns(new CorsResult())
673+
.Verifiable();
674+
675+
var middleware = new CorsMiddleware(
676+
Mock.Of<RequestDelegate>(),
677+
mockCorsService.Object,
678+
policy,
679+
loggerFactory);
680+
681+
var httpContext = new DefaultHttpContext();
682+
httpContext.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableCorsAttribute()), "Test endpoint"));
683+
httpContext.Request.Headers.Add(CorsConstants.Origin, new[] { "http://example.com" });
684+
685+
// Act
686+
await middleware.Invoke(httpContext, mockProvider.Object);
687+
688+
// Assert
689+
mockProvider.Verify(
690+
o => o.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()),
691+
Times.Never);
692+
mockCorsService.Verify(
693+
o => o.EvaluatePolicy(It.IsAny<HttpContext>(), policy),
694+
Times.Once);
695+
}
696+
629697
[Fact]
630698
public async Task Invoke_HasEndpointWithDisableMetadata_SkipCors()
631699
{

0 commit comments

Comments
 (0)