Skip to content

Commit d4e1645

Browse files
committed
PR feedback
1 parent bd480b9 commit d4e1645

File tree

8 files changed

+203
-38
lines changed

8 files changed

+203
-38
lines changed

src/CORS/samples/SampleDestination/Startup.cs

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
33

44
using System.Net;
5+
using System.Text;
56
using System.Threading.Tasks;
67
using Microsoft.AspNetCore.Builder;
78
using Microsoft.AspNetCore.Hosting;
@@ -57,15 +58,13 @@ public void ConfigureServices(IServiceCollection services)
5758

5859
public void Configure(IApplicationBuilder app, IHostingEnvironment env)
5960
{
60-
var sampleMiddleware = new SampleMiddleware(context => Task.CompletedTask);
61-
6261
app.UseEndpointRouting(routing =>
6362
{
64-
routing.Map("/allow-origin", sampleMiddleware.Invoke).RequireCors("AllowOrigin");
65-
routing.Map("/allow-header-method", sampleMiddleware.Invoke).RequireCors("AllowHeaderMethod");
66-
routing.Map("/allow-credentials", sampleMiddleware.Invoke).RequireCors("AllowCredentials");
67-
routing.Map("/exposed-header", sampleMiddleware.Invoke).RequireCors("ExposedHeader");
68-
routing.Map("/allow-all", sampleMiddleware.Invoke).RequireCors("AllowAll");
63+
routing.Map("/allow-origin", HandleRequest).WithCorsPolicy("AllowOrigin");
64+
routing.Map("/allow-header-method", HandleRequest).WithCorsPolicy("AllowHeaderMethod");
65+
routing.Map("/allow-credentials", HandleRequest).WithCorsPolicy("AllowCredentials");
66+
routing.Map("/exposed-header", HandleRequest).WithCorsPolicy("ExposedHeader");
67+
routing.Map("/allow-all", HandleRequest).WithCorsPolicy("AllowAll");
6968
});
7069

7170
app.UseCors();
@@ -77,5 +76,19 @@ public void Configure(IApplicationBuilder app, IHostingEnvironment env)
7776
await context.Response.WriteAsync("Hello World!");
7877
});
7978
}
79+
80+
private Task HandleRequest(HttpContext context)
81+
{
82+
var content = Encoding.UTF8.GetBytes("Hello world");
83+
84+
context.Response.Headers["X-AllowedHeader"] = "Test-Value";
85+
context.Response.Headers["X-DisallowedHeader"] = "Test-Value";
86+
87+
context.Response.ContentType = "text/plain; charset=utf-8";
88+
context.Response.ContentLength = content.Length;
89+
context.Response.Body.Write(content, 0, content.Length);
90+
91+
return Task.CompletedTask;
92+
}
8093
}
8194
}

src/CORS/src/Microsoft.AspNetCore.Cors/Infrastructure/CorsEndpointConventionBuilderExtensions.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
// Copyright (c) .NET Foundation. All rights reserved.
22
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
33

4+
using System;
45
using Microsoft.AspNetCore.Cors;
56
using Microsoft.AspNetCore.Routing;
6-
using System;
7-
using System.Linq;
87

98
namespace Microsoft.AspNetCore.Builder
109
{
1110
public static class CorsEndpointConventionBuilderExtensions
1211
{
13-
public static IEndpointConventionBuilder RequireCors(this IEndpointConventionBuilder builder, string policyName)
12+
public static IEndpointConventionBuilder WithCorsPolicy(this IEndpointConventionBuilder builder, string policyName)
1413
{
1514
if (builder == null)
1615
{

src/CORS/src/Microsoft.AspNetCore.Cors/Infrastructure/CorsMiddleware.cs

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,24 @@ public Task Invoke(HttpContext context, ICorsPolicyProvider corsPolicyProvider)
125125

126126
private async Task InvokeCore(HttpContext context, ICorsPolicyProvider corsPolicyProvider)
127127
{
128-
var corsPolicy = _policy ?? await corsPolicyProvider.GetPolicyAsync(context, ResolveCorsPolicyName(context));
128+
var endpoint = context.GetEndpoint();
129+
130+
// Get the most significant CORS metadata for the endpoint
131+
// For backwards compatibility reasons this is then downcast to Enable/Disable metadata
132+
var corsMetadata = endpoint?.Metadata.GetMetadata<ICorsAttribute>();
133+
if (corsMetadata is IDisableCorsAttribute)
134+
{
135+
await _next(context);
136+
return;
137+
}
138+
139+
string policyName = _corsPolicyName;
140+
if (corsMetadata is IEnableCorsAttribute enableCorsAttribute)
141+
{
142+
policyName = enableCorsAttribute.PolicyName;
143+
}
144+
145+
var corsPolicy = _policy ?? await corsPolicyProvider.GetPolicyAsync(context, policyName);
129146
if (corsPolicy == null)
130147
{
131148
Logger?.NoCorsPolicyFound();
@@ -150,12 +167,6 @@ private async Task InvokeCore(HttpContext context, ICorsPolicyProvider corsPolic
150167
}
151168
}
152169

153-
internal string ResolveCorsPolicyName(HttpContext context)
154-
{
155-
var endpoint = context.GetEndpoint();
156-
return endpoint?.Metadata.GetMetadata<IEnableCorsAttribute>()?.PolicyName ?? _corsPolicyName;
157-
}
158-
159170
private static Task OnResponseStarting(object state)
160171
{
161172
var (middleware, context, result) = (Tuple<CorsMiddleware, HttpContext, CorsResult>)state;
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// Copyright (c) .NET Foundation. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
3+
4+
namespace Microsoft.AspNetCore.Cors.Infrastructure
5+
{
6+
/// <summary>
7+
/// A marker interface which can be used to identify CORS metdata.
8+
/// </summary>
9+
public interface ICorsAttribute
10+
{
11+
}
12+
}

src/CORS/src/Microsoft.AspNetCore.Cors/Infrastructure/IDisableCorsAttribute.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
66
/// <summary>
77
/// An interface which can be used to identify a type which provides metdata to disable cors for a resource.
88
/// </summary>
9-
public interface IDisableCorsAttribute
9+
public interface IDisableCorsAttribute : ICorsAttribute
1010
{
1111
}
12-
}
12+
}

src/CORS/src/Microsoft.AspNetCore.Cors/Infrastructure/IEnableCorsAttribute.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
66
/// <summary>
77
/// An interface which can be used to identify a type which provides metadata needed for enabling CORS support.
88
/// </summary>
9-
public interface IEnableCorsAttribute
9+
public interface IEnableCorsAttribute : ICorsAttribute
1010
{
1111
/// <summary>
1212
/// The name of the policy which needs to be applied.
1313
/// </summary>
1414
string PolicyName { get; set; }
1515
}
16-
}
16+
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Copyright (c) .NET Foundation. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using Microsoft.AspNetCore.Builder;
7+
using Microsoft.AspNetCore.Http;
8+
using Microsoft.AspNetCore.Routing;
9+
using Xunit;
10+
11+
namespace Microsoft.AspNetCore.Cors.Infrastructure
12+
{
13+
public class CorsEndpointConventionBuilderExtensionsTests
14+
{
15+
[Fact]
16+
public void WithCorsPolicy_MetadataAdded()
17+
{
18+
// Arrange
19+
var testConventionBuilder = new TestEndpointConventionBuilder();
20+
21+
// Act
22+
testConventionBuilder.WithCorsPolicy("TestPolicyName");
23+
24+
// Assert
25+
var addCorsPolicy = Assert.Single(testConventionBuilder.Conventions);
26+
27+
var endpointModel = new TestEndpointModel();
28+
addCorsPolicy(endpointModel);
29+
var endpoint = endpointModel.Build();
30+
31+
var metadata = endpoint.Metadata.GetMetadata<IEnableCorsAttribute>();
32+
Assert.NotNull(metadata);
33+
Assert.Equal("TestPolicyName", metadata.PolicyName);
34+
}
35+
36+
private class TestEndpointModel : EndpointModel
37+
{
38+
public override Endpoint Build()
39+
{
40+
return new Endpoint(RequestDelegate, new EndpointMetadataCollection(Metadata), DisplayName);
41+
}
42+
}
43+
44+
private class TestEndpointConventionBuilder : IEndpointConventionBuilder
45+
{
46+
public IList<Action<EndpointModel>> Conventions { get; } = new List<Action<EndpointModel>>();
47+
48+
public void Apply(Action<EndpointModel> convention)
49+
{
50+
Conventions.Add(convention);
51+
}
52+
}
53+
}
54+
}

src/CORS/test/Microsoft.AspNetCore.Cors.Test/CorsMiddlewareTests.cs

Lines changed: 93 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -567,47 +567,123 @@ public async Task CorsRequest_SetsResponseHeader_IfExceptionHandlerClearsRespons
567567
}
568568

569569
[Fact]
570-
public void ResolveCorsPolicyName_NoEndpoint_UseDefaultPolicyName()
570+
public async Task Invoke_HasEndpointWithNoMetadata_RunsCors()
571571
{
572572
// Arrange
573-
var middleware = new CorsMiddleware(c => Task.CompletedTask, Mock.Of<ICorsService>(), Mock.Of<ILoggerFactory>(), "DefaultPolicyName");
574-
var context = new DefaultHttpContext();
573+
var corsService = Mock.Of<ICorsService>();
574+
var mockProvider = new Mock<ICorsPolicyProvider>();
575+
var loggerFactory = NullLoggerFactory.Instance;
576+
mockProvider.Setup(o => o.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()))
577+
.Returns(Task.FromResult<CorsPolicy>(null))
578+
.Verifiable();
579+
580+
var middleware = new CorsMiddleware(
581+
Mock.Of<RequestDelegate>(),
582+
corsService,
583+
loggerFactory,
584+
"DefaultPolicyName");
585+
586+
var httpContext = new DefaultHttpContext();
587+
httpContext.SetEndpoint(new Endpoint(c => Task.CompletedTask, EndpointMetadataCollection.Empty, "Test endpoint"));
588+
httpContext.Request.Headers.Add(CorsConstants.Origin, new[] { "http://example.com" });
575589

576590
// Act
577-
var resolvedPolicyName = middleware.ResolveCorsPolicyName(context);
591+
await middleware.Invoke(httpContext, mockProvider.Object);
578592

579593
// Assert
580-
Assert.Equal("DefaultPolicyName", resolvedPolicyName);
594+
mockProvider.Verify(
595+
o => o.GetPolicyAsync(It.IsAny<HttpContext>(), "DefaultPolicyName"),
596+
Times.Once);
581597
}
582598

583599
[Fact]
584-
public void ResolveCorsPolicyName_EndpointWithoutMetadata_UseDefaultPolicyName()
600+
public async Task Invoke_HasEndpointWithEnableMetadata_RunsCorsWithPolicyName()
585601
{
586602
// Arrange
587-
var middleware = new CorsMiddleware(c => Task.CompletedTask, Mock.Of<ICorsService>(), Mock.Of<ILoggerFactory>(), "DefaultPolicyName");
588-
var context = new DefaultHttpContext();
589-
context.SetEndpoint(new Endpoint(c => Task.CompletedTask, EndpointMetadataCollection.Empty, "Test endpoint"));
603+
var corsService = Mock.Of<ICorsService>();
604+
var mockProvider = new Mock<ICorsPolicyProvider>();
605+
var loggerFactory = NullLoggerFactory.Instance;
606+
mockProvider.Setup(o => o.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()))
607+
.Returns(Task.FromResult<CorsPolicy>(null))
608+
.Verifiable();
609+
610+
var middleware = new CorsMiddleware(
611+
Mock.Of<RequestDelegate>(),
612+
corsService,
613+
loggerFactory,
614+
"DefaultPolicyName");
615+
616+
var httpContext = new DefaultHttpContext();
617+
httpContext.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableCorsAttribute("MetadataPolicyName")), "Test endpoint"));
618+
httpContext.Request.Headers.Add(CorsConstants.Origin, new[] { "http://example.com" });
590619

591620
// Act
592-
var resolvedPolicyName = middleware.ResolveCorsPolicyName(context);
621+
await middleware.Invoke(httpContext, mockProvider.Object);
593622

594623
// Assert
595-
Assert.Equal("DefaultPolicyName", resolvedPolicyName);
624+
mockProvider.Verify(
625+
o => o.GetPolicyAsync(It.IsAny<HttpContext>(), "MetadataPolicyName"),
626+
Times.Once);
596627
}
597628

598629
[Fact]
599-
public void ResolveCorsPolicyName_EndpointWithMetadata_UseDefaultPolicyName()
630+
public async Task Invoke_HasEndpointWithDisableMetadata_SkipCors()
600631
{
601632
// Arrange
602-
var middleware = new CorsMiddleware(c => Task.CompletedTask, Mock.Of<ICorsService>(), Mock.Of<ILoggerFactory>(), "DefaultPolicyName");
603-
var context = new DefaultHttpContext();
604-
context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableCorsAttribute("MetadataPolicyName")), "Test endpoint"));
633+
var corsService = Mock.Of<ICorsService>();
634+
var mockProvider = new Mock<ICorsPolicyProvider>();
635+
var loggerFactory = NullLoggerFactory.Instance;
636+
mockProvider.Setup(o => o.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()))
637+
.Returns(Task.FromResult<CorsPolicy>(null))
638+
.Verifiable();
639+
640+
var middleware = new CorsMiddleware(
641+
Mock.Of<RequestDelegate>(),
642+
corsService,
643+
loggerFactory,
644+
"DefaultPolicyName");
645+
646+
var httpContext = new DefaultHttpContext();
647+
httpContext.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new DisableCorsAttribute()), "Test endpoint"));
648+
httpContext.Request.Headers.Add(CorsConstants.Origin, new[] { "http://example.com" });
649+
650+
// Act
651+
await middleware.Invoke(httpContext, mockProvider.Object);
652+
653+
// Assert
654+
mockProvider.Verify(
655+
o => o.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()),
656+
Times.Never);
657+
}
658+
659+
[Fact]
660+
public async Task Invoke_HasEndpointWithMutlipleMetadata_SkipCorsBecauseOfMetadataOrder()
661+
{
662+
// Arrange
663+
var corsService = Mock.Of<ICorsService>();
664+
var mockProvider = new Mock<ICorsPolicyProvider>();
665+
var loggerFactory = NullLoggerFactory.Instance;
666+
mockProvider.Setup(o => o.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()))
667+
.Returns(Task.FromResult<CorsPolicy>(null))
668+
.Verifiable();
669+
670+
var middleware = new CorsMiddleware(
671+
Mock.Of<RequestDelegate>(),
672+
corsService,
673+
loggerFactory,
674+
"DefaultPolicyName");
675+
676+
var httpContext = new DefaultHttpContext();
677+
httpContext.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableCorsAttribute("MetadataPolicyName"), new DisableCorsAttribute()), "Test endpoint"));
678+
httpContext.Request.Headers.Add(CorsConstants.Origin, new[] { "http://example.com" });
605679

606680
// Act
607-
var resolvedPolicyName = middleware.ResolveCorsPolicyName(context);
681+
await middleware.Invoke(httpContext, mockProvider.Object);
608682

609683
// Assert
610-
Assert.Equal("MetadataPolicyName", resolvedPolicyName);
684+
mockProvider.Verify(
685+
o => o.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()),
686+
Times.Never);
611687
}
612688
}
613689
}

0 commit comments

Comments
 (0)