Skip to content
Merged
66 changes: 63 additions & 3 deletions src/Http/Http.Extensions/src/RequestDelegateFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Linq.Expressions;
using System.Reflection;
using System.Security.Claims;
using System.Text;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Http.Metadata;
using Microsoft.Extensions.DependencyInjection;
Expand Down Expand Up @@ -189,6 +190,13 @@ private static Expression[] CreateArguments(ParameterInfo[]? parameters, Factory
args[i] = CreateArgument(parameters[i], factoryContext);
}

if (factoryContext.HasMultipleBodyParameters)
{
var errorMessage = BuildErrorMessageForMultipleBodyParameters(factoryContext);
throw new InvalidOperationException(errorMessage);

}

return args;
}

Expand All @@ -203,6 +211,7 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext

if (parameterCustomAttributes.OfType<IFromRouteMetadata>().FirstOrDefault() is { } routeAttribute)
{
factoryContext.TrackedParameters.Add(parameter.Name, RequestDelegateFactoryConstants.RouteAttribue);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great call to add a separate class for all the constants we're introducing here.

if (factoryContext.RouteParameters is { } routeParams && !routeParams.Contains(parameter.Name, StringComparer.OrdinalIgnoreCase))
{
throw new InvalidOperationException($"{parameter.Name} is not a route paramter.");
Expand All @@ -212,18 +221,22 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext
}
else if (parameterCustomAttributes.OfType<IFromQueryMetadata>().FirstOrDefault() is { } queryAttribute)
{
factoryContext.TrackedParameters.Add(parameter.Name, RequestDelegateFactoryConstants.QueryAttribue);
return BindParameterFromProperty(parameter, QueryExpr, queryAttribute.Name ?? parameter.Name, factoryContext);
}
else if (parameterCustomAttributes.OfType<IFromHeaderMetadata>().FirstOrDefault() is { } headerAttribute)
{
factoryContext.TrackedParameters.Add(parameter.Name, RequestDelegateFactoryConstants.HeaderAttribue);
return BindParameterFromProperty(parameter, HeadersExpr, headerAttribute.Name ?? parameter.Name, factoryContext);
}
else if (parameterCustomAttributes.OfType<IFromBodyMetadata>().FirstOrDefault() is { } bodyAttribute)
{
factoryContext.TrackedParameters.Add(parameter.Name, RequestDelegateFactoryConstants.BodyAttribue);
return BindParameterFromBody(parameter, bodyAttribute.AllowEmpty, factoryContext);
}
else if (parameter.CustomAttributes.Any(a => typeof(IFromServiceMetadata).IsAssignableFrom(a.AttributeType)))
{
factoryContext.TrackedParameters.Add(parameter.Name, RequestDelegateFactoryConstants.ServiceAttribue);
return BindParameterFromService(parameter);
}
else if (parameter.ParameterType == typeof(HttpContext))
Expand Down Expand Up @@ -254,18 +267,22 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext
// when RDF.Create is manually invoked.
if (factoryContext.RouteParameters is { } routeParams)
{

if (routeParams.Contains(parameter.Name, StringComparer.OrdinalIgnoreCase))
{
// We're in the fallback case and we have a parameter and route parameter match so don't fallback
// to query string in this case
factoryContext.TrackedParameters.Add(parameter.Name, RequestDelegateFactoryConstants.RouteParameter);
return BindParameterFromProperty(parameter, RouteValuesExpr, parameter.Name, factoryContext);
}
else
{
factoryContext.TrackedParameters.Add(parameter.Name, RequestDelegateFactoryConstants.QueryStringParameter);
return BindParameterFromProperty(parameter, QueryExpr, parameter.Name, factoryContext);
}
}

factoryContext.TrackedParameters.Add(parameter.Name, RequestDelegateFactoryConstants.RouteOrQueryStringParameter);
return BindParameterFromRouteValueOrQueryString(parameter, parameter.Name, factoryContext);
}
else
Expand All @@ -274,10 +291,12 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext
{
if (serviceProviderIsService.IsService(parameter.ParameterType))
{
factoryContext.TrackedParameters.Add(parameter.Name, RequestDelegateFactoryConstants.ServiceParameter);
return Expression.Call(GetRequiredServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr);
}
}

factoryContext.TrackedParameters.Add(parameter.Name, RequestDelegateFactoryConstants.BodyParameter);
return BindParameterFromBody(parameter, allowEmpty: false, factoryContext);
}
}
Expand Down Expand Up @@ -500,7 +519,6 @@ private static Expression AddResponseWritingToMethodCall(Expression methodCall,
return async (target, httpContext) =>
{
object? bodyValue = defaultBodyValue;

var feature = httpContext.Features.Get<IHttpRequestBodyDetectionFeature>();
if (feature?.CanHaveBody == true)
{
Expand All @@ -515,12 +533,12 @@ private static Expression AddResponseWritingToMethodCall(Expression methodCall,
}
catch (InvalidDataException ex)
{

Log.RequestBodyInvalidDataException(httpContext, ex);
httpContext.Response.StatusCode = 400;
return;
}
}

await invoker(target, httpContext, bodyValue);
};
}
Expand Down Expand Up @@ -725,7 +743,14 @@ private static Expression BindParameterFromBody(ParameterInfo parameter, bool al
{
if (factoryContext.JsonRequestBodyType is not null)
{
throw new InvalidOperationException("Action cannot have more than one FromBody attribute.");
factoryContext.HasMultipleBodyParameters = true;
var parameterName = parameter.Name;
if (parameterName is not null && factoryContext.TrackedParameters.ContainsKey(parameterName))
{
factoryContext.TrackedParameters.Remove(parameterName);
factoryContext.TrackedParameters.Add(parameterName, "UNKNOWN");

}
}

var nullability = NullabilityContext.Create(parameter);
Expand Down Expand Up @@ -963,6 +988,24 @@ private class FactoryContext
public bool UsingTempSourceString { get; set; }
public List<ParameterExpression> ExtraLocals { get; } = new();
public List<Expression> ParamCheckExpressions { get; } = new();

public Dictionary<string, string> TrackedParameters { get; } = new();
public bool HasMultipleBodyParameters { get; set; }
}

private static class RequestDelegateFactoryConstants
{
public const string RouteAttribue = "Route (Attribute)";
public const string QueryAttribue = "Query (Attribute)";
public const string HeaderAttribue = "Header (Attribute)";
public const string BodyAttribue = "Body (Attribute)";
public const string ServiceAttribue = "Service (Attribute)";
public const string RouteParameter = "Route (Inferred)";
public const string QueryStringParameter = "Query String (Inferred)";
public const string ServiceParameter = "Services (Inferred)";
public const string BodyParameter = "Body (Inferred)";
public const string RouteOrQueryStringParameter = "Route or Query String (Inferred)";

}

private static partial class Log
Expand Down Expand Up @@ -1032,5 +1075,22 @@ private static void SetPlaintextContentType(HttpContext httpContext)
{
httpContext.Response.ContentType ??= "text/plain; charset=utf-8";
}

private static string BuildErrorMessageForMultipleBodyParameters(FactoryContext factoryContext)
{
var errorMessage = new StringBuilder();
errorMessage.Append($"Failure to infer one or more parameters.\n");
errorMessage.Append("Below is the list of parameters that we found: \n\n");
errorMessage.Append($"{"Parameter",-20}|{"Source",-30} \n");
errorMessage.Append("---------------------------------------------------------------------------------\n");

foreach (var kv in factoryContext.TrackedParameters)
{
errorMessage.Append($"{kv.Key,-19} | {kv.Value,-15}\n");
}
errorMessage.Append("\n\n");
errorMessage.Append("Did you mean to register the \"UNKNOWN\" parameters as a Service?\n\n");
return errorMessage.ToString();
}
}
}