diff --git a/docs/docfx/articles/rate-limiting.md b/docs/docfx/articles/rate-limiting.md new file mode 100644 index 000000000..39a488c48 --- /dev/null +++ b/docs/docfx/articles/rate-limiting.md @@ -0,0 +1,85 @@ +# Rate Limiting + +## Introduction +The reverse proxy can be used to rate-limit requests before they are proxied to the destination servers. This can reduce load on the destination servers, add a layer of protection, and ensure consistent policies are implemented across your applications. + +> This feature is only available when using .NET 7.0 or later + +## Defaults + +No rate limiting is performed on requests unless enabled in the route or application configuration. However, the Rate Limiting middleware (`app.UseRateLimiter()`) can apply a default limiter applied to all routes, and this doesn't require any opt-in from the config. + +Example: +```c# +builder.Services.AddRateLimiter(options => options.GlobalLimiter = globalLimiter); +``` + +## Configuration +Rate Limiter policies can be specified per route via [RouteConfig.RateLimiterPolicy](xref:Yarp.ReverseProxy.Configuration.RouteConfig) and can be bound from the `Routes` sections of the config file. As with other route properties, this can be modified and reloaded without restarting the proxy. Policy names are case insensitive. + +Example: +```JSON +{ + "ReverseProxy": { + "Routes": { + "route1" : { + "ClusterId": "cluster1", + "RateLimiterPolicy": "customPolicy", + "Match": { + "Hosts": [ "localhost" ] + }, + } + }, + "Clusters": { + "cluster1": { + "Destinations": { + "cluster1/destination1": { + "Address": "https://localhost:10001/" + } + } + } + } + } +} +``` + +[RateLimiter policies](https://learn.microsoft.com/aspnet/core/performance/rate-limit) are an ASP.NET Core concept that the proxy utilizes. The proxy provides the above configuration to specify a policy per route and the rest is handled by existing ASP.NET Core rate limiting middleware. + +RateLimiter policies can be configured in Startup.ConfigureServices as follows: +```c# +public void ConfigureServices(IServiceCollection services) +{ + services.AddRateLimiter(options => + { + options.AddFixedWindowLimiter("customPolicy", opt => + { + opt.PermitLimit = 4; + opt.Window = TimeSpan.FromSeconds(12); + opt.QueueProcessingOrder = QueueProcessingOrder.OldestFirst; + opt.QueueLimit = 2; + }); + }); +} +``` + +In Startup.Configure add the RateLimiter middleware between Routing and Endpoints. + +```c# +public void Configure(IApplicationBuilder app) +{ + app.UseRouting(); + + app.UseRateLimiter(); + + app.UseEndpoints(endpoints => + { + endpoints.MapReverseProxy(); + }); +} +``` + +See the [Rate Limiting](https://learn.microsoft.com/aspnet/core/performance/rate-limit) docs for setting up your preferred kind of rate limiting. + +### Disable Rate Limiting + +Specifying the value `disable` in a route's `RateLimiterPolicy` parameter means the rate limiter middleware will not apply any policies to this route, even the default policy. diff --git a/docs/docfx/articles/toc.yml b/docs/docfx/articles/toc.yml index da16a848a..38a1eb346 100644 --- a/docs/docfx/articles/toc.yml +++ b/docs/docfx/articles/toc.yml @@ -20,6 +20,8 @@ href: header-routing.md - name: Authentication and Authorization href: authn-authz.md +- name: Rate Limiting + href: rate-limiting.md - name: Cross-Origin Requests (CORS) href: cors.md - name: Session Affinity diff --git a/samples/KubernetesIngress.Sample/README.md b/samples/KubernetesIngress.Sample/README.md index 43b317136..dce8af4b1 100644 --- a/samples/KubernetesIngress.Sample/README.md +++ b/samples/KubernetesIngress.Sample/README.md @@ -42,6 +42,7 @@ metadata: namespace: default annotations: yarp.ingress.kubernetes.io/authorization-policy: authzpolicy + yarp.ingress.kubernetes.io/rate-limiter-policy: ratelimiterpolicy yarp.ingress.kubernetes.io/transforms: | - PathRemovePrefix: "/apis" yarp.ingress.kubernetes.io/route-headers: | @@ -73,6 +74,7 @@ The table below lists the available annotations. |Annotation|Data Type| |---|---| |yarp.ingress.kubernetes.io/authorization-policy|string| +|yarp.ingress.kubernetes.io/rate-limiter-policy|string| |yarp.ingress.kubernetes.io/backend-protocol|string| |yarp.ingress.kubernetes.io/cors-policy|string| |yarp.ingress.kubernetes.io/health-check|[ActivateHealthCheckConfig](https://microsoft.github.io/reverse-proxy/api/Yarp.ReverseProxy.Configuration.ActiveHealthCheckConfig.html)| @@ -90,6 +92,12 @@ See https://microsoft.github.io/reverse-proxy/articles/authn-authz.html for a li `yarp.ingress.kubernetes.io/authorization-policy: anonymous` +#### RateLimiter Policy + +See https://microsoft.github.io/reverse-proxy/articles/rate-limiting.html for a list of available policies, or how to add your own custom policies. + +`yarp.ingress.kubernetes.io/rate-limiter-policy: mypolicy` + #### Backend Protocol Specifies the protocol of the backend service. Defaults to http. diff --git a/src/Kubernetes.Controller/Converters/YarpIngressOptions.cs b/src/Kubernetes.Controller/Converters/YarpIngressOptions.cs index 164ecfbd4..6d2957698 100644 --- a/src/Kubernetes.Controller/Converters/YarpIngressOptions.cs +++ b/src/Kubernetes.Controller/Converters/YarpIngressOptions.cs @@ -11,6 +11,9 @@ internal sealed class YarpIngressOptions public bool Https { get; set; } public List> Transforms { get; set; } public string AuthorizationPolicy { get; set; } +#if NET7_0_OR_GREATER + public string RateLimiterPolicy { get; set; } +#endif public SessionAffinityConfig SessionAffinity { get; set; } public HttpClientConfig HttpClientConfig { get; set; } public string LoadBalancingPolicy { get; set; } @@ -38,4 +41,4 @@ public RouteHeader ToRouteHeader() IsCaseSensitive = IsCaseSensitive }; } -} \ No newline at end of file +} diff --git a/src/Kubernetes.Controller/Converters/YarpParser.cs b/src/Kubernetes.Controller/Converters/YarpParser.cs index 4e3092c1e..013373abb 100644 --- a/src/Kubernetes.Controller/Converters/YarpParser.cs +++ b/src/Kubernetes.Controller/Converters/YarpParser.cs @@ -104,6 +104,9 @@ private static void HandleIngressRulePath(YarpIngressContext ingressContext, V1S RouteId = $"{ingressContext.Ingress.Metadata.Name}.{ingressContext.Ingress.Metadata.NamespaceProperty}:{host}{path.Path}", Transforms = ingressContext.Options.Transforms, AuthorizationPolicy = ingressContext.Options.AuthorizationPolicy, +#if NET7_0_OR_GREATER + RateLimiterPolicy = ingressContext.Options.RateLimiterPolicy, +#endif CorsPolicy = ingressContext.Options.CorsPolicy, Metadata = ingressContext.Options.RouteMetadata, Order = ingressContext.Options.RouteOrder, @@ -171,16 +174,22 @@ private static YarpIngressOptions HandleAnnotations(YarpIngressContext context, if (annotations.TryGetValue("yarp.ingress.kubernetes.io/backend-protocol", out var http)) { - options.Https = http.Equals("https", StringComparison.OrdinalIgnoreCase); + options.Https = http.Equals("https", StringComparison.OrdinalIgnoreCase); } if (annotations.TryGetValue("yarp.ingress.kubernetes.io/transforms", out var transforms)) { - options.Transforms = YamlDeserializer.Deserialize>>(transforms); + options.Transforms = YamlDeserializer.Deserialize>>(transforms); } if (annotations.TryGetValue("yarp.ingress.kubernetes.io/authorization-policy", out var authorizationPolicy)) { options.AuthorizationPolicy = authorizationPolicy; } +#if NET7_0_OR_GREATER + if (annotations.TryGetValue("yarp.ingress.kubernetes.io/rate-limiter-policy", out var rateLimiterPolicy)) + { + options.RateLimiterPolicy = rateLimiterPolicy; + } +#endif if (annotations.TryGetValue("yarp.ingress.kubernetes.io/cors-policy", out var corsPolicy)) { options.CorsPolicy = corsPolicy; diff --git a/src/ReverseProxy/Configuration/ConfigProvider/ConfigurationConfigProvider.cs b/src/ReverseProxy/Configuration/ConfigProvider/ConfigurationConfigProvider.cs index 5198143c0..def838ded 100644 --- a/src/ReverseProxy/Configuration/ConfigProvider/ConfigurationConfigProvider.cs +++ b/src/ReverseProxy/Configuration/ConfigProvider/ConfigurationConfigProvider.cs @@ -147,6 +147,9 @@ private static RouteConfig CreateRoute(IConfigurationSection section) MaxRequestBodySize = section.ReadInt64(nameof(RouteConfig.MaxRequestBodySize)), ClusterId = section[nameof(RouteConfig.ClusterId)], AuthorizationPolicy = section[nameof(RouteConfig.AuthorizationPolicy)], +#if NET7_0_OR_GREATER + RateLimiterPolicy = section[nameof(RouteConfig.RateLimiterPolicy)], +#endif CorsPolicy = section[nameof(RouteConfig.CorsPolicy)], Metadata = section.GetSection(nameof(RouteConfig.Metadata)).ReadStringDictionary(), Transforms = CreateTransforms(section.GetSection(nameof(RouteConfig.Transforms))), diff --git a/src/ReverseProxy/Configuration/ConfigValidator.cs b/src/ReverseProxy/Configuration/ConfigValidator.cs index 72a24cecd..b06cc277f 100644 --- a/src/ReverseProxy/Configuration/ConfigValidator.cs +++ b/src/ReverseProxy/Configuration/ConfigValidator.cs @@ -29,6 +29,7 @@ internal sealed class ConfigValidator : IConfigValidator private readonly ITransformBuilder _transformBuilder; private readonly IAuthorizationPolicyProvider _authorizationPolicyProvider; + private readonly IYarpRateLimiterPolicyProvider _rateLimiterPolicyProvider; private readonly ICorsPolicyProvider _corsPolicyProvider; private readonly IDictionary _loadBalancingPolicies; private readonly IDictionary _affinityFailurePolicies; @@ -40,6 +41,7 @@ internal sealed class ConfigValidator : IConfigValidator public ConfigValidator(ITransformBuilder transformBuilder, IAuthorizationPolicyProvider authorizationPolicyProvider, + IYarpRateLimiterPolicyProvider rateLimiterPolicyProvider, ICorsPolicyProvider corsPolicyProvider, IEnumerable loadBalancingPolicies, IEnumerable affinityFailurePolicies, @@ -50,6 +52,7 @@ public ConfigValidator(ITransformBuilder transformBuilder, { _transformBuilder = transformBuilder ?? throw new ArgumentNullException(nameof(transformBuilder)); _authorizationPolicyProvider = authorizationPolicyProvider ?? throw new ArgumentNullException(nameof(authorizationPolicyProvider)); + _rateLimiterPolicyProvider = rateLimiterPolicyProvider ?? throw new ArgumentNullException(nameof(rateLimiterPolicyProvider)); _corsPolicyProvider = corsPolicyProvider ?? throw new ArgumentNullException(nameof(corsPolicyProvider)); _loadBalancingPolicies = loadBalancingPolicies?.ToDictionaryByUniqueId(p => p.Name) ?? throw new ArgumentNullException(nameof(loadBalancingPolicies)); _affinityFailurePolicies = affinityFailurePolicies?.ToDictionaryByUniqueId(p => p.Name) ?? throw new ArgumentNullException(nameof(affinityFailurePolicies)); @@ -72,6 +75,9 @@ public async ValueTask> ValidateRouteAsync(RouteConfig route) errors.AddRange(_transformBuilder.ValidateRoute(route)); await ValidateAuthorizationPolicyAsync(errors, route.AuthorizationPolicy, route.RouteId); +#if NET7_0_OR_GREATER + await ValidateRateLimiterPolicyAsync(errors, route.RateLimiterPolicy, route.RouteId); +#endif await ValidateCorsPolicyAsync(errors, route.CorsPolicy, route.RouteId); if (route.Match is null) @@ -287,6 +293,35 @@ private async ValueTask ValidateAuthorizationPolicyAsync(IList errors } } + private async ValueTask ValidateRateLimiterPolicyAsync(IList errors, string? rateLimiterPolicyName, string routeId) + { + if (string.IsNullOrEmpty(rateLimiterPolicyName)) + { + return; + } + + try + { + var policy = await _rateLimiterPolicyProvider.GetPolicyAsync(rateLimiterPolicyName); + + if (policy is null) + { + errors.Add(new ArgumentException($"RateLimiter policy '{rateLimiterPolicyName}' not found for route '{routeId}'.")); + return; + } + + if (string.Equals(RateLimitingConstants.Default, rateLimiterPolicyName, StringComparison.OrdinalIgnoreCase) + || string.Equals(RateLimitingConstants.Disable, rateLimiterPolicyName, StringComparison.OrdinalIgnoreCase)) + { + errors.Add(new ArgumentException($"The application has registered a RateLimiter policy named '{rateLimiterPolicyName}' that conflicts with the reserved RateLimiter policy name used on this route. The registered policy name needs to be changed for this route to function.")); + } + } + catch (Exception ex) + { + errors.Add(new ArgumentException($"Unable to retrieve the RateLimiter policy '{rateLimiterPolicyName}' for route '{routeId}'.", ex)); + } + } + private async ValueTask ValidateCorsPolicyAsync(IList errors, string? corsPolicyName, string routeId) { if (string.IsNullOrEmpty(corsPolicyName)) diff --git a/src/ReverseProxy/Configuration/IYarpRateLimiterPolicyProvider.cs b/src/ReverseProxy/Configuration/IYarpRateLimiterPolicyProvider.cs new file mode 100644 index 000000000..7a1315af1 --- /dev/null +++ b/src/ReverseProxy/Configuration/IYarpRateLimiterPolicyProvider.cs @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#if NET7_0_OR_GREATER +using System; +using System.Collections; +using System.Reflection; +using Microsoft.AspNetCore.RateLimiting; +using Microsoft.Extensions.Options; +#endif + +using System.Threading.Tasks; + +namespace Yarp.ReverseProxy.Configuration; + +// TODO: update or remove this once AspNetCore provides a mechanism to validate the RateLimiter policies https://github.com/dotnet/aspnetcore/issues/45684 + + +internal interface IYarpRateLimiterPolicyProvider +{ + ValueTask GetPolicyAsync(string policyName); +} + +internal class YarpRateLimiterPolicyProvider : IYarpRateLimiterPolicyProvider +{ +#if NET7_0_OR_GREATER + private readonly RateLimiterOptions _rateLimiterOptions; + + private readonly IDictionary _policyMap, _unactivatedPolicyMap; + + public YarpRateLimiterPolicyProvider(IOptions rateLimiterOptions) + { + _rateLimiterOptions = rateLimiterOptions?.Value ?? throw new ArgumentNullException(nameof(rateLimiterOptions)); + + var type = typeof(RateLimiterOptions); + var flags = BindingFlags.Instance | BindingFlags.NonPublic; + _policyMap = type.GetProperty("PolicyMap", flags)?.GetValue(_rateLimiterOptions, null) as IDictionary + ?? throw new NotSupportedException("This version of YARP is incompatible with the current version of ASP.NET Core."); + _unactivatedPolicyMap = type.GetProperty("UnactivatedPolicyMap", flags)?.GetValue(_rateLimiterOptions, null) as IDictionary + ?? throw new NotSupportedException("This version of YARP is incompatible with the current version of ASP.NET Core."); + } + + public ValueTask GetPolicyAsync(string policyName) + { + return ValueTask.FromResult(_policyMap[policyName] ?? _unactivatedPolicyMap[policyName]); + } +#else + public ValueTask GetPolicyAsync(string policyName) + { + return default; + } +#endif +} diff --git a/src/ReverseProxy/Configuration/RateLimitingConstants.cs b/src/ReverseProxy/Configuration/RateLimitingConstants.cs new file mode 100644 index 000000000..49a5cd1d2 --- /dev/null +++ b/src/ReverseProxy/Configuration/RateLimitingConstants.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Yarp.ReverseProxy.Configuration; + +internal static class RateLimitingConstants +{ + internal const string Default = "Default"; + internal const string Disable = "Disable"; +} diff --git a/src/ReverseProxy/Configuration/RouteConfig.cs b/src/ReverseProxy/Configuration/RouteConfig.cs index c773b95d3..0ca059266 100644 --- a/src/ReverseProxy/Configuration/RouteConfig.cs +++ b/src/ReverseProxy/Configuration/RouteConfig.cs @@ -43,7 +43,15 @@ public sealed record RouteConfig /// Set to "Anonymous" to disable all authorization checks for this route. /// public string? AuthorizationPolicy { get; init; } - +#if NET7_0_OR_GREATER + /// + /// The name of the RateLimiterPolicy to apply to this route. + /// If not set then only the GlobalLimiter will apply. + /// Set to "Disable" to disable rate limiting for this route. + /// Set to "Default" or leave empty to use the global rate limits, if any. + /// + public string? RateLimiterPolicy { get; init; } +#endif /// /// The name of the CorsPolicy to apply to this route. /// If not set then the route won't be automatically matched for cors preflight requests. @@ -79,6 +87,9 @@ public bool Equals(RouteConfig? other) && string.Equals(RouteId, other.RouteId, StringComparison.OrdinalIgnoreCase) && string.Equals(ClusterId, other.ClusterId, StringComparison.OrdinalIgnoreCase) && string.Equals(AuthorizationPolicy, other.AuthorizationPolicy, StringComparison.OrdinalIgnoreCase) +#if NET7_0_OR_GREATER + && string.Equals(RateLimiterPolicy, other.RateLimiterPolicy, StringComparison.OrdinalIgnoreCase) +#endif && string.Equals(CorsPolicy, other.CorsPolicy, StringComparison.OrdinalIgnoreCase) && Match == other.Match && CaseSensitiveEqualHelper.Equals(Metadata, other.Metadata) @@ -87,13 +98,19 @@ public bool Equals(RouteConfig? other) public override int GetHashCode() { - return HashCode.Combine(Order, - RouteId?.GetHashCode(StringComparison.OrdinalIgnoreCase), - ClusterId?.GetHashCode(StringComparison.OrdinalIgnoreCase), - AuthorizationPolicy?.GetHashCode(StringComparison.OrdinalIgnoreCase), - CorsPolicy?.GetHashCode(StringComparison.OrdinalIgnoreCase), - Match, - CaseSensitiveEqualHelper.GetHashCode(Metadata), - CaseSensitiveEqualHelper.GetHashCode(Transforms)); + // HashCode.Combine(...) takes only 8 arguments + var hash = new HashCode(); + hash.Add(Order); + hash.Add(RouteId?.GetHashCode(StringComparison.OrdinalIgnoreCase)); + hash.Add(ClusterId?.GetHashCode(StringComparison.OrdinalIgnoreCase)); + hash.Add(AuthorizationPolicy?.GetHashCode(StringComparison.OrdinalIgnoreCase)); +#if NET7_0_OR_GREATER + hash.Add(RateLimiterPolicy?.GetHashCode(StringComparison.OrdinalIgnoreCase)); +#endif + hash.Add(CorsPolicy?.GetHashCode(StringComparison.OrdinalIgnoreCase)); + hash.Add(Match); + hash.Add(CaseSensitiveEqualHelper.GetHashCode(Metadata)); + hash.Add(CaseSensitiveEqualHelper.GetHashCode(Transforms)); + return hash.ToHashCode(); } } diff --git a/src/ReverseProxy/Management/IReverseProxyBuilderExtensions.cs b/src/ReverseProxy/Management/IReverseProxyBuilderExtensions.cs index 158831d1d..6fb05e439 100644 --- a/src/ReverseProxy/Management/IReverseProxyBuilderExtensions.cs +++ b/src/ReverseProxy/Management/IReverseProxyBuilderExtensions.cs @@ -22,6 +22,7 @@ internal static class IReverseProxyBuilderExtensions { public static IReverseProxyBuilder AddConfigBuilder(this IReverseProxyBuilder builder) { + builder.Services.TryAddSingleton(); builder.Services.TryAddSingleton(); builder.Services.TryAddSingleton(); builder.AddTransformFactory(); diff --git a/src/ReverseProxy/Routing/ProxyEndpointFactory.cs b/src/ReverseProxy/Routing/ProxyEndpointFactory.cs index 34da22d89..2ab8aeae8 100644 --- a/src/ReverseProxy/Routing/ProxyEndpointFactory.cs +++ b/src/ReverseProxy/Routing/ProxyEndpointFactory.cs @@ -9,17 +9,24 @@ using Microsoft.AspNetCore.Cors; using Microsoft.AspNetCore.Cors.Infrastructure; using Microsoft.AspNetCore.Http; +#if NET7_0_OR_GREATER +using Microsoft.AspNetCore.RateLimiting; +#endif using Microsoft.AspNetCore.Routing; using Microsoft.AspNetCore.Routing.Patterns; using Yarp.ReverseProxy.Model; using CorsConstants = Yarp.ReverseProxy.Configuration.CorsConstants; using AuthorizationConstants = Yarp.ReverseProxy.Configuration.AuthorizationConstants; +using RateLimitingConstants = Yarp.ReverseProxy.Configuration.RateLimitingConstants; namespace Yarp.ReverseProxy.Routing; internal sealed class ProxyEndpointFactory { private static readonly IAuthorizeData _defaultAuthorization = new AuthorizeAttribute(); +#if NET7_0_OR_GREATER + private static readonly DisableRateLimitingAttribute _disableRateLimit = new(); +#endif private static readonly IEnableCorsAttribute _defaultCors = new EnableCorsAttribute(); private static readonly IDisableCorsAttribute _disableCors = new DisableCorsAttribute(); private static readonly IAllowAnonymous _allowAnonymous = new AllowAnonymousAttribute(); @@ -110,6 +117,21 @@ public Endpoint CreateEndpoint(RouteModel route, IReadOnlyList() { { "f", "f1" } } }; var g = a with { Order = null }; var h = a with { RouteId = "h" }; +#if NET7_0_OR_GREATER + var i = a with { RateLimiterPolicy = "i" }; +#endif Assert.False(a.Equals(b)); Assert.False(a.Equals(c)); @@ -122,6 +134,9 @@ public void Equals_Negative() Assert.False(a.Equals(f)); Assert.False(a.Equals(g)); Assert.False(a.Equals(h)); +#if NET7_0_OR_GREATER + Assert.False(a.Equals(i)); +#endif } [Fact] @@ -136,6 +151,9 @@ public void RouteConfig_CanBeJsonSerialized() var route1 = new RouteConfig() { AuthorizationPolicy = "a", +#if NET7_0_OR_GREATER + RateLimiterPolicy = "rl", +#endif ClusterId = "c", CorsPolicy = "co", Match = new RouteMatch() diff --git a/test/ReverseProxy.Tests/Configuration/YarpRateLimiterPolicyProviderTests.cs b/test/ReverseProxy.Tests/Configuration/YarpRateLimiterPolicyProviderTests.cs new file mode 100644 index 000000000..a00aea558 --- /dev/null +++ b/test/ReverseProxy.Tests/Configuration/YarpRateLimiterPolicyProviderTests.cs @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#if NET7_0_OR_GREATER +using System; +using System.Threading.Tasks; +using System.Threading.RateLimiting; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.RateLimiting; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Yarp.ReverseProxy.Configuration; + +public class YarpRateLimiterPolicyProviderTests +{ + [Fact] + public async Task GetPolicyAsync_Works() + { + var services = new ServiceCollection(); + + services.AddRateLimiter(options => + { + options.AddFixedWindowLimiter("customPolicy", opt => + { + opt.PermitLimit = 4; + opt.Window = TimeSpan.FromSeconds(12); + opt.QueueProcessingOrder = QueueProcessingOrder.OldestFirst; + opt.QueueLimit = 2; + }); + }); + + services.AddReverseProxy(); + var provider = services.BuildServiceProvider(); + var rateLimiterPolicyProvider = provider.GetRequiredService(); + Assert.Null(await rateLimiterPolicyProvider.GetPolicyAsync("anotherPolicy")); + Assert.NotNull(await rateLimiterPolicyProvider.GetPolicyAsync("customPolicy")); + } +} +#endif diff --git a/test/ReverseProxy.Tests/Routing/ProxyEndpointFactoryTests.cs b/test/ReverseProxy.Tests/Routing/ProxyEndpointFactoryTests.cs index ddfaee8fc..c1c3a88ab 100644 --- a/test/ReverseProxy.Tests/Routing/ProxyEndpointFactoryTests.cs +++ b/test/ReverseProxy.Tests/Routing/ProxyEndpointFactoryTests.cs @@ -5,6 +5,9 @@ using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Cors; using Microsoft.AspNetCore.Cors.Infrastructure; +#if NET7_0_OR_GREATER +using Microsoft.AspNetCore.RateLimiting; +#endif using Microsoft.AspNetCore.Routing; using Microsoft.AspNetCore.Routing.Patterns; using Microsoft.Extensions.DependencyInjection; @@ -321,6 +324,101 @@ public void AddEndpoint_NoAuth_Works() Assert.Null(routeEndpoint.Metadata.GetMetadata()); } +#if NET7_0_OR_GREATER + [Fact] + public void AddEndpoint_DefaultRateLimiter_Works() + { + var services = CreateServices(); + var factory = services.GetRequiredService(); + factory.SetProxyPipeline(context => Task.CompletedTask); + + var route = new RouteConfig + { + RouteId = "route1", + RateLimiterPolicy = "defaulT", + Order = 12, + Match = new RouteMatch(), + }; + var cluster = new ClusterState("cluster1"); + var routeState = new RouteState("route1"); + + var (routeEndpoint, _) = CreateEndpoint(factory, routeState, route, cluster); + + Assert.Null(routeEndpoint.Metadata.GetMetadata()); + Assert.Null(routeEndpoint.Metadata.GetMetadata()); + } + + [Fact] + public void AddEndpoint_CustomRateLimiter_Works() + { + var services = CreateServices(); + var factory = services.GetRequiredService(); + factory.SetProxyPipeline(context => Task.CompletedTask); + + var route = new RouteConfig + { + RouteId = "route1", + RateLimiterPolicy = "custom", + Order = 12, + Match = new RouteMatch(), + }; + var cluster = new ClusterState("cluster1"); + var routeState = new RouteState("route1"); + + var (routeEndpoint, _) = CreateEndpoint(factory, routeState, route, cluster); + + var attribute = routeEndpoint.Metadata.GetMetadata(); + Assert.NotNull(attribute); + Assert.Equal("custom", attribute.PolicyName); + Assert.Null(routeEndpoint.Metadata.GetMetadata()); + } + + [Fact] + public void AddEndpoint_DisableRateLimiter_Works() + { + var services = CreateServices(); + var factory = services.GetRequiredService(); + factory.SetProxyPipeline(context => Task.CompletedTask); + + var route = new RouteConfig + { + RouteId = "route1", + RateLimiterPolicy = "disAble", + Order = 12, + Match = new RouteMatch(), + }; + var cluster = new ClusterState("cluster1"); + var routeState = new RouteState("route1"); + + var (routeEndpoint, _) = CreateEndpoint(factory, routeState, route, cluster); + + Assert.NotNull(routeEndpoint.Metadata.GetMetadata()); + Assert.Null(routeEndpoint.Metadata.GetMetadata()); + } + + [Fact] + public void AddEndpoint_NoRateLimiter_Works() + { + var services = CreateServices(); + var factory = services.GetRequiredService(); + factory.SetProxyPipeline(context => Task.CompletedTask); + + var route = new RouteConfig + { + RouteId = "route1", + Order = 12, + Match = new RouteMatch(), + }; + var cluster = new ClusterState("cluster1"); + var routeState = new RouteState("route1"); + + var (routeEndpoint, _) = CreateEndpoint(factory, routeState, route, cluster); + + Assert.Null(routeEndpoint.Metadata.GetMetadata()); + Assert.Null(routeEndpoint.Metadata.GetMetadata()); + } +#endif + [Fact] public void AddEndpoint_DefaultCors_Works() {