diff --git a/packages/tracer/src/middleware/middy.ts b/packages/tracer/src/middleware/middy.ts index 132228390d..1d1a8b63f6 100644 --- a/packages/tracer/src/middleware/middy.ts +++ b/packages/tracer/src/middleware/middy.ts @@ -1,3 +1,4 @@ +import { TRACER_KEY } from '@aws-lambda-powertools/commons/lib/middleware'; import type { Tracer } from '../Tracer'; import type { Segment, Subsegment } from 'aws-xray-sdk-core'; import type { CaptureLambdaHandlerOptions } from '../types'; @@ -40,6 +41,18 @@ const captureLambdaHandler = ( let lambdaSegment: Segment; let handlerSegment: Subsegment; + /** + * Set the cleanup function to be called in case other middlewares return early. + * + * @param request - The request object + */ + const setCleanupFunction = (request: MiddyLikeRequest): void => { + request.internal = { + ...request.internal, + [TRACER_KEY]: close, + }; + }; + const open = (): void => { const segment = target.getSegment(); if (segment === undefined) { @@ -61,9 +74,12 @@ const captureLambdaHandler = ( target.setSegment(lambdaSegment); }; - const captureLambdaHandlerBefore = async (): Promise => { + const captureLambdaHandlerBefore = async ( + request: MiddyLikeRequest + ): Promise => { if (target.isTracingEnabled()) { open(); + setCleanupFunction(request); target.annotateColdStart(); target.addServiceNameAnnotation(); } diff --git a/packages/tracer/tests/unit/middy.test.ts b/packages/tracer/tests/unit/middy.test.ts index f3fc79e1c3..7948817baf 100644 --- a/packages/tracer/tests/unit/middy.test.ts +++ b/packages/tracer/tests/unit/middy.test.ts @@ -3,7 +3,6 @@ * * @group unit/tracer/all */ - import { captureLambdaHandler } from '../../src/middleware/middy'; import middy from '@middy/core'; import { Tracer } from './../../src'; @@ -13,6 +12,7 @@ import { setContextMissingStrategy, Subsegment, } from 'aws-xray-sdk-core'; +import { cleanupMiddlewares } from '@aws-lambda-powertools/commons/lib/middleware'; jest.spyOn(console, 'debug').mockImplementation(() => null); jest.spyOn(console, 'warn').mockImplementation(() => null); @@ -306,4 +306,66 @@ describe('Middy middleware', () => { 'hello-world' ); }); + + test('when enabled, and another middleware returns early, it still closes and restores the segments correctly', async () => { + // Prepare + const tracer = new Tracer(); + const setSegmentSpy = jest + .spyOn(tracer.provider, 'setSegment') + .mockImplementation(() => ({})); + jest.spyOn(tracer, 'annotateColdStart').mockImplementation(() => ({})); + jest + .spyOn(tracer, 'addServiceNameAnnotation') + .mockImplementation(() => ({})); + const facadeSegment1 = new Segment('facade'); + const handlerSubsegment1 = new Subsegment('## index.handlerA'); + jest + .spyOn(facadeSegment1, 'addNewSubsegment') + .mockImplementation(() => handlerSubsegment1); + const facadeSegment2 = new Segment('facade'); + const handlerSubsegment2 = new Subsegment('## index.handlerB'); + jest + .spyOn(facadeSegment2, 'addNewSubsegment') + .mockImplementation(() => handlerSubsegment2); + jest + .spyOn(tracer.provider, 'getSegment') + .mockImplementationOnce(() => facadeSegment1) + .mockImplementationOnce(() => facadeSegment2); + const myCustomMiddleware = (): middy.MiddlewareObj => { + const before = async ( + request: middy.Request + ): Promise => { + // Return early on the second invocation + if (request.event.idx === 1) { + // Cleanup Powertools resources + await cleanupMiddlewares(request); + + // Then return early + return 'foo'; + } + }; + + return { + before, + }; + }; + const handler = middy((): void => { + console.log('Hello world!'); + }) + .use(captureLambdaHandler(tracer, { captureResponse: false })) + .use(myCustomMiddleware()); + + // Act + await handler({ idx: 0 }, context); + await handler({ idx: 1 }, context); + + // Assess + // Check that the subsegments are closed + expect(handlerSubsegment1.isClosed()).toBe(true); + expect(handlerSubsegment2.isClosed()).toBe(true); + // Check that the segments are restored + expect(setSegmentSpy).toHaveBeenCalledTimes(4); + expect(setSegmentSpy).toHaveBeenNthCalledWith(2, facadeSegment1); + expect(setSegmentSpy).toHaveBeenNthCalledWith(4, facadeSegment2); + }); });