diff --git a/libraries/src/AWS.Lambda.Powertools.Logging/Internal/LoggingLambdaContext.cs b/libraries/src/AWS.Lambda.Powertools.Logging/Internal/LoggingLambdaContext.cs index a8846b155..17c5a3a8c 100644 --- a/libraries/src/AWS.Lambda.Powertools.Logging/Internal/LoggingLambdaContext.cs +++ b/libraries/src/AWS.Lambda.Powertools.Logging/Internal/LoggingLambdaContext.cs @@ -74,24 +74,21 @@ public static bool Extract(AspectEventArgs args) return false; var index = Array.FindIndex(args.Method.GetParameters(), p => p.ParameterType == typeof(ILambdaContext)); - if (index >= 0) - { - var x = (ILambdaContext)args.Args[index]; - - Instance = new LoggingLambdaContext - { - AwsRequestId = x.AwsRequestId, - FunctionName = x.FunctionName, - FunctionVersion = x.FunctionVersion, - InvokedFunctionArn = x.InvokedFunctionArn, - LogGroupName = x.LogGroupName, - LogStreamName = x.LogStreamName, - MemoryLimitInMB = x.MemoryLimitInMB - }; - return true; - } + if (index < 0 || args.Args[index] == null || args.Args[index] is not ILambdaContext) return false; + + var x = (ILambdaContext)args.Args[index]; - return false; + Instance = new LoggingLambdaContext + { + AwsRequestId = x.AwsRequestId, + FunctionName = x.FunctionName, + FunctionVersion = x.FunctionVersion, + InvokedFunctionArn = x.InvokedFunctionArn, + LogGroupName = x.LogGroupName, + LogStreamName = x.LogStreamName, + MemoryLimitInMB = x.MemoryLimitInMB + }; + return true; } /// diff --git a/libraries/tests/AWS.Lambda.Powertools.Logging.Tests/Context/LambdaContextTest.cs b/libraries/tests/AWS.Lambda.Powertools.Logging.Tests/Context/LambdaContextTest.cs index feb9283e9..31e980ba8 100644 --- a/libraries/tests/AWS.Lambda.Powertools.Logging.Tests/Context/LambdaContextTest.cs +++ b/libraries/tests/AWS.Lambda.Powertools.Logging.Tests/Context/LambdaContextTest.cs @@ -56,6 +56,33 @@ public void Extract_WhenHasLambdaContextArgument_InitializesLambdaContextInfo() Assert.Null(LoggingLambdaContext.Instance); } + [Fact] + public void Extract_When_LambdaContext_Is_Null_But_Not_First_Parameter_Returns_False() + { + // Arrange + ILambdaContext lambdaContext = null; + var args = Substitute.For(); + var method = Substitute.For(); + var parameter1 = Substitute.For(); + var parameter2 = Substitute.For(); + + // Setup parameters + parameter1.ParameterType.Returns(typeof(string)); + parameter2.ParameterType.Returns(typeof(ILambdaContext)); + + // Setup method + method.GetParameters().Returns(new[] { parameter1, parameter2 }); + + // Setup args + args.Method = method; + args.Args = new object[] { "requestContext", lambdaContext }; + + // Act && Assert + LoggingLambdaContext.Clear(); + Assert.Null(LoggingLambdaContext.Instance); + Assert.False(LoggingLambdaContext.Extract(args)); + } + [Fact] public void Extract_When_Args_Null_Returns_False() {