From 7f6c0465f51ec44c06aa08281f3d5aa1af7a0c26 Mon Sep 17 00:00:00 2001 From: Tuomas Date: Mon, 17 Feb 2025 15:57:02 -0600 Subject: [PATCH 1/3] Add timeout reset on progress notifications --- src/shared/protocol.test.ts | 176 ++++++++++++++++++++++++++++++++++++ src/shared/protocol.ts | 166 ++++++++++++++++++++++++---------- 2 files changed, 292 insertions(+), 50 deletions(-) diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index 3073d0af4..bf0357b2a 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -62,6 +62,182 @@ describe("protocol tests", () => { await transport.close(); expect(oncloseMock).toHaveBeenCalled(); }); + + test("should reset timeout when progress notification is received", async () => { + jest.useFakeTimers(); + + await protocol.connect(transport); + const request = { method: "example", params: {} }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string(), + }); + + const onProgressMock = jest.fn(); + const requestPromise = protocol.request(request, mockSchema, { + timeout: 1000, // Increased timeout for more reliable testing + resetTimeoutOnProgress: true, + onprogress: onProgressMock, + }); + + // Advance time close to timeout + jest.advanceTimersByTime(800); + + // Send progress notification + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: "2.0", + method: "notifications/progress", + params: { + progressToken: 0, + progress: 50, + total: 100, + }, + }); + } + + // Run all pending promises to ensure progress handler is called + await Promise.resolve(); + + // Verify progress handler was called + expect(onProgressMock).toHaveBeenCalledWith({ + progress: 50, + total: 100, + }); + + // Send success response + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: "2.0", + id: 0, + result: { result: "success" }, + }); + } + + // Run all pending promises + await Promise.resolve(); + + await expect(requestPromise).resolves.toEqual({ result: "success" }); + + jest.useRealTimers(); + }); + + test("should respect maxTotalTimeout", async () => { + jest.useFakeTimers(); + + await protocol.connect(transport); + const request = { method: "example", params: {} }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string(), + }); + + const onProgressMock = jest.fn(); + const requestPromise = protocol.request(request, mockSchema, { + timeout: 1000, + maxTotalTimeout: 100, + resetTimeoutOnProgress: true, + onprogress: onProgressMock, + }); + + // Advance time beyond maxTotalTimeout + jest.advanceTimersByTime(150); + + // Send progress notification after maxTotalTimeout + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: "2.0", + method: "notifications/progress", + params: { + progressToken: 0, + progress: 50, + total: 100, + }, + }); + } + + await expect(requestPromise).rejects.toThrow("Maximum total timeout exceeded"); + expect(onProgressMock).not.toHaveBeenCalled(); + + jest.useRealTimers(); + }); + + test("should timeout if no progress received within timeout period", async () => { + jest.useFakeTimers(); + + await protocol.connect(transport); + const request = { method: "example", params: {} }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string(), + }); + + const requestPromise = protocol.request(request, mockSchema, { + timeout: 100, + resetTimeoutOnProgress: true, + }); + + // Advance time beyond timeout + jest.advanceTimersByTime(101); + + await expect(requestPromise).rejects.toThrow("Request timed out"); + + jest.useRealTimers(); + }); + + test("should handle multiple progress notifications correctly", async () => { + jest.useFakeTimers(); + + await protocol.connect(transport); + const request = { method: "example", params: {} }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string(), + }); + + const onProgressMock = jest.fn(); + const requestPromise = protocol.request(request, mockSchema, { + timeout: 1000, + resetTimeoutOnProgress: true, + onprogress: onProgressMock, + }); + + // Simulate multiple progress updates + for (let i = 1; i <= 3; i++) { + // Advance close to timeout + jest.advanceTimersByTime(800); + + // Send progress notification + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: "2.0", + method: "notifications/progress", + params: { + progressToken: 0, + progress: i * 25, + total: 100, + }, + }); + } + + // Verify progress handler was called + await Promise.resolve(); + expect(onProgressMock).toHaveBeenNthCalledWith(i, { + progress: i * 25, + total: 100, + }); + } + + // Send success response + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: "2.0", + id: 0, + result: { result: "success" }, + }); + } + + await Promise.resolve(); + await expect(requestPromise).resolves.toEqual({ result: "success" }); + + jest.useRealTimers(); + }); }); describe("mergeCapabilities", () => { diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index a4f211c67..7e58cbe15 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -64,6 +64,20 @@ export type RequestOptions = { * If not specified, `DEFAULT_REQUEST_TIMEOUT_MSEC` will be used as the timeout. */ timeout?: number; + + /** + * If true, receiving a progress notification will reset the request timeout. + * This is useful for long-running operations that send periodic progress updates. + * Default: false + */ + resetTimeoutOnProgress?: boolean; + + /** + * Maximum total time (in milliseconds) to wait for a response, even if progress notifications are received. + * Only used when resetTimeoutOnProgress is true. + * If not specified, there is no maximum total timeout. + */ + maxTotalTimeout?: number; }; /** @@ -76,6 +90,17 @@ export type RequestHandlerExtra = { signal: AbortSignal; }; +/** + * Information about a request's timeout state + */ +type TimeoutInfo = { + timeoutId: ReturnType; + startTime: number; + timeout: number; + maxTotalTimeout?: number; + onTimeout: () => void; +}; + /** * Implements MCP protocol framing on top of a pluggable transport, including * features like request/response linking, notifications, and progress. @@ -105,6 +130,7 @@ export abstract class Protocol< (response: JSONRPCResponse | Error) => void > = new Map(); private _progressHandlers: Map = new Map(); + private _timeoutInfo: Map = new Map(); /** * Callback for when the connection is closed for any reason. @@ -149,6 +175,49 @@ export abstract class Protocol< ); } + private _setupTimeout( + messageId: number, + timeout: number, + maxTotalTimeout: number | undefined, + onTimeout: () => void + ) { + this._timeoutInfo.set(messageId, { + timeoutId: setTimeout(onTimeout, timeout), + startTime: Date.now(), + timeout, + maxTotalTimeout, + onTimeout + }); + } + + private _resetTimeout(messageId: number, cancel: (reason: unknown) => void): boolean { + const info = this._timeoutInfo.get(messageId); + if (!info) return false; + + const totalElapsed = Date.now() - info.startTime; + if (info.maxTotalTimeout && totalElapsed >= info.maxTotalTimeout) { + this._timeoutInfo.delete(messageId); + cancel(new McpError( + ErrorCode.RequestTimeout, + "Maximum total timeout exceeded", + { maxTotalTimeout: info.maxTotalTimeout, totalElapsed } + )); + return false; + } + + clearTimeout(info.timeoutId); + info.timeoutId = setTimeout(info.onTimeout, info.timeout); + return true; + } + + private _cleanupTimeout(messageId: number) { + const info = this._timeoutInfo.get(messageId); + if (info) { + clearTimeout(info.timeoutId); + this._timeoutInfo.delete(messageId); + } + } + /** * Attaches to the given transport, starts it, and starts listening for messages. * @@ -281,22 +350,27 @@ export abstract class Protocol< private _onprogress(notification: ProgressNotification): void { const { progressToken, ...params } = notification.params; - const handler = this._progressHandlers.get(Number(progressToken)); - if (handler === undefined) { - this._onerror( - new Error( - `Received a progress notification for an unknown token: ${JSON.stringify(notification)}`, - ), - ); + const messageId = Number(progressToken); + + const handler = this._progressHandlers.get(messageId); + if (!handler) { + this._onerror(new Error(`Received a progress notification for an unknown token: ${JSON.stringify(notification)}`)); return; } + const responseHandler = this._responseHandlers.get(messageId); + if (this._timeoutInfo.has(messageId) && responseHandler) { + if (!this._resetTimeout(messageId, (reason) => responseHandler(reason as Error))) { + return; + } + } + handler(params); } private _onresponse(response: JSONRPCResponse | JSONRPCError): void { - const messageId = response.id; - const handler = this._responseHandlers.get(Number(messageId)); + const messageId = Number(response.id); + const handler = this._responseHandlers.get(messageId); if (handler === undefined) { this._onerror( new Error( @@ -306,8 +380,10 @@ export abstract class Protocol< return; } - this._responseHandlers.delete(Number(messageId)); - this._progressHandlers.delete(Number(messageId)); + this._responseHandlers.delete(messageId); + this._progressHandlers.delete(messageId); + this._cleanupTimeout(messageId); + if ("result" in response) { handler(response); } else { @@ -393,32 +469,10 @@ export abstract class Protocol< }; } - let timeoutId: ReturnType | undefined = undefined; - - this._responseHandlers.set(messageId, (response) => { - if (timeoutId !== undefined) { - clearTimeout(timeoutId); - } - - if (options?.signal?.aborted) { - return; - } - - if (response instanceof Error) { - return reject(response); - } - - try { - const result = resultSchema.parse(response.result); - resolve(result); - } catch (error) { - reject(error); - } - }); - const cancel = (reason: unknown) => { this._responseHandlers.delete(messageId); this._progressHandlers.delete(messageId); + this._cleanupTimeout(messageId); this._transport ?.send({ @@ -436,30 +490,42 @@ export abstract class Protocol< reject(reason); }; - options?.signal?.addEventListener("abort", () => { - if (timeoutId !== undefined) { - clearTimeout(timeoutId); + this._responseHandlers.set(messageId, (response) => { + if (options?.signal?.aborted) { + return; + } + + if (response instanceof Error) { + return reject(response); + } + + try { + const result = resultSchema.parse(response.result); + resolve(result); + } catch (error) { + reject(error); } + }); + options?.signal?.addEventListener("abort", () => { cancel(options?.signal?.reason); }); const timeout = options?.timeout ?? DEFAULT_REQUEST_TIMEOUT_MSEC; - timeoutId = setTimeout( - () => - cancel( - new McpError(ErrorCode.RequestTimeout, "Request timed out", { - timeout, - }), - ), - timeout, - ); + const timeoutHandler = () => cancel(new McpError( + ErrorCode.RequestTimeout, + "Request timed out", + { timeout } + )); + + if (options?.resetTimeoutOnProgress) { + this._setupTimeout(messageId, timeout, options.maxTotalTimeout, timeoutHandler); + } else { + this._setupTimeout(messageId, timeout, undefined, timeoutHandler); + } this._transport.send(jsonrpcRequest).catch((error) => { - if (timeoutId !== undefined) { - clearTimeout(timeoutId); - } - + this._cleanupTimeout(messageId); reject(error); }); }); From 28415d3917c1d31a13adeeaaeab5420602a2a8b6 Mon Sep 17 00:00:00 2001 From: Tuomas Date: Wed, 19 Feb 2025 12:23:16 -0600 Subject: [PATCH 2/3] Refactor progress notification timeout handling in protocol --- src/shared/protocol.test.ts | 276 +++++++++++++++++------------------- src/shared/protocol.ts | 24 ++-- 2 files changed, 140 insertions(+), 160 deletions(-) diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index bf0357b2a..1d037b988 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -63,180 +63,160 @@ describe("protocol tests", () => { expect(oncloseMock).toHaveBeenCalled(); }); - test("should reset timeout when progress notification is received", async () => { - jest.useFakeTimers(); - - await protocol.connect(transport); - const request = { method: "example", params: {} }; - const mockSchema: ZodType<{ result: string }> = z.object({ - result: z.string(), + describe("progress notification timeout behavior", () => { + beforeEach(() => { + jest.useFakeTimers(); }); - - const onProgressMock = jest.fn(); - const requestPromise = protocol.request(request, mockSchema, { - timeout: 1000, // Increased timeout for more reliable testing - resetTimeoutOnProgress: true, - onprogress: onProgressMock, + afterEach(() => { + jest.useRealTimers(); }); - // Advance time close to timeout - jest.advanceTimersByTime(800); - - // Send progress notification - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: "2.0", - method: "notifications/progress", - params: { - progressToken: 0, - progress: 50, - total: 100, - }, + test("should reset timeout when progress notification is received", async () => { + await protocol.connect(transport); + const request = { method: "example", params: {} }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string(), }); - } - - // Run all pending promises to ensure progress handler is called - await Promise.resolve(); - - // Verify progress handler was called - expect(onProgressMock).toHaveBeenCalledWith({ - progress: 50, - total: 100, - }); - - // Send success response - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: "2.0", - id: 0, - result: { result: "success" }, + const onProgressMock = jest.fn(); + const requestPromise = protocol.request(request, mockSchema, { + timeout: 1000, + resetTimeoutOnProgress: true, + onprogress: onProgressMock, }); - } - - // Run all pending promises - await Promise.resolve(); - - await expect(requestPromise).resolves.toEqual({ result: "success" }); - - jest.useRealTimers(); - }); - - test("should respect maxTotalTimeout", async () => { - jest.useFakeTimers(); - - await protocol.connect(transport); - const request = { method: "example", params: {} }; - const mockSchema: ZodType<{ result: string }> = z.object({ - result: z.string(), - }); - - const onProgressMock = jest.fn(); - const requestPromise = protocol.request(request, mockSchema, { - timeout: 1000, - maxTotalTimeout: 100, - resetTimeoutOnProgress: true, - onprogress: onProgressMock, - }); - - // Advance time beyond maxTotalTimeout - jest.advanceTimersByTime(150); - - // Send progress notification after maxTotalTimeout - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: "2.0", - method: "notifications/progress", - params: { - progressToken: 0, - progress: 50, - total: 100, - }, + jest.advanceTimersByTime(800); + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: "2.0", + method: "notifications/progress", + params: { + progressToken: 0, + progress: 50, + total: 100, + }, + }); + } + await Promise.resolve(); + expect(onProgressMock).toHaveBeenCalledWith({ + progress: 50, + total: 100, }); - } - - await expect(requestPromise).rejects.toThrow("Maximum total timeout exceeded"); - expect(onProgressMock).not.toHaveBeenCalled(); - - jest.useRealTimers(); - }); - - test("should timeout if no progress received within timeout period", async () => { - jest.useFakeTimers(); - - await protocol.connect(transport); - const request = { method: "example", params: {} }; - const mockSchema: ZodType<{ result: string }> = z.object({ - result: z.string(), - }); - - const requestPromise = protocol.request(request, mockSchema, { - timeout: 100, - resetTimeoutOnProgress: true, - }); - - // Advance time beyond timeout - jest.advanceTimersByTime(101); - - await expect(requestPromise).rejects.toThrow("Request timed out"); - - jest.useRealTimers(); - }); - - test("should handle multiple progress notifications correctly", async () => { - jest.useFakeTimers(); - - await protocol.connect(transport); - const request = { method: "example", params: {} }; - const mockSchema: ZodType<{ result: string }> = z.object({ - result: z.string(), - }); - - const onProgressMock = jest.fn(); - const requestPromise = protocol.request(request, mockSchema, { - timeout: 1000, - resetTimeoutOnProgress: true, - onprogress: onProgressMock, + jest.advanceTimersByTime(800); + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: "2.0", + id: 0, + result: { result: "success" }, + }); + } + await Promise.resolve(); + await expect(requestPromise).resolves.toEqual({ result: "success" }); }); - // Simulate multiple progress updates - for (let i = 1; i <= 3; i++) { - // Advance close to timeout - jest.advanceTimersByTime(800); + test("should respect maxTotalTimeout", async () => { + await protocol.connect(transport); + const request = { method: "example", params: {} }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string(), + }); + const onProgressMock = jest.fn(); + const requestPromise = protocol.request(request, mockSchema, { + timeout: 1000, + maxTotalTimeout: 150, + resetTimeoutOnProgress: true, + onprogress: onProgressMock, + }); - // Send progress notification + // First progress notification should work + jest.advanceTimersByTime(80); if (transport.onmessage) { transport.onmessage({ jsonrpc: "2.0", method: "notifications/progress", params: { progressToken: 0, - progress: i * 25, + progress: 50, total: 100, }, }); } - - // Verify progress handler was called await Promise.resolve(); - expect(onProgressMock).toHaveBeenNthCalledWith(i, { - progress: i * 25, + expect(onProgressMock).toHaveBeenCalledWith({ + progress: 50, total: 100, }); - } + jest.advanceTimersByTime(80); + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: "2.0", + method: "notifications/progress", + params: { + progressToken: 0, + progress: 75, + total: 100, + }, + }); + } + await expect(requestPromise).rejects.toThrow("Maximum total timeout exceeded"); + expect(onProgressMock).toHaveBeenCalledTimes(1); + }); - // Send success response - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: "2.0", - id: 0, - result: { result: "success" }, + test("should timeout if no progress received within timeout period", async () => { + await protocol.connect(transport); + const request = { method: "example", params: {} }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string(), }); - } + const requestPromise = protocol.request(request, mockSchema, { + timeout: 100, + resetTimeoutOnProgress: true, + }); + jest.advanceTimersByTime(101); + await expect(requestPromise).rejects.toThrow("Request timed out"); + }); - await Promise.resolve(); - await expect(requestPromise).resolves.toEqual({ result: "success" }); + test("should handle multiple progress notifications correctly", async () => { + await protocol.connect(transport); + const request = { method: "example", params: {} }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string(), + }); + const onProgressMock = jest.fn(); + const requestPromise = protocol.request(request, mockSchema, { + timeout: 1000, + resetTimeoutOnProgress: true, + onprogress: onProgressMock, + }); - jest.useRealTimers(); + // Simulate multiple progress updates + for (let i = 1; i <= 3; i++) { + jest.advanceTimersByTime(800); + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: "2.0", + method: "notifications/progress", + params: { + progressToken: 0, + progress: i * 25, + total: 100, + }, + }); + } + await Promise.resolve(); + expect(onProgressMock).toHaveBeenNthCalledWith(i, { + progress: i * 25, + total: 100, + }); + } + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: "2.0", + id: 0, + result: { result: "success" }, + }); + } + await Promise.resolve(); + await expect(requestPromise).resolves.toEqual({ result: "success" }); + }); }); }); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 7e58cbe15..a0f3d0751 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -73,8 +73,8 @@ export type RequestOptions = { resetTimeoutOnProgress?: boolean; /** - * Maximum total time (in milliseconds) to wait for a response, even if progress notifications are received. - * Only used when resetTimeoutOnProgress is true. + * Maximum total time (in milliseconds) to wait for a response. + * If exceeded, an McpError with code `RequestTimeout` will be raised, regardless of progress notifications. * If not specified, there is no maximum total timeout. */ maxTotalTimeout?: number; @@ -190,19 +190,18 @@ export abstract class Protocol< }); } - private _resetTimeout(messageId: number, cancel: (reason: unknown) => void): boolean { + private _resetTimeout(messageId: number): boolean { const info = this._timeoutInfo.get(messageId); if (!info) return false; const totalElapsed = Date.now() - info.startTime; if (info.maxTotalTimeout && totalElapsed >= info.maxTotalTimeout) { this._timeoutInfo.delete(messageId); - cancel(new McpError( + throw new McpError( ErrorCode.RequestTimeout, "Maximum total timeout exceeded", { maxTotalTimeout: info.maxTotalTimeout, totalElapsed } - )); - return false; + ); } clearTimeout(info.timeoutId); @@ -360,7 +359,12 @@ export abstract class Protocol< const responseHandler = this._responseHandlers.get(messageId); if (this._timeoutInfo.has(messageId) && responseHandler) { - if (!this._resetTimeout(messageId, (reason) => responseHandler(reason as Error))) { + try { + if (!this._resetTimeout(messageId)) { + return; + } + } catch (error) { + responseHandler(error as Error); return; } } @@ -518,11 +522,7 @@ export abstract class Protocol< { timeout } )); - if (options?.resetTimeoutOnProgress) { - this._setupTimeout(messageId, timeout, options.maxTotalTimeout, timeoutHandler); - } else { - this._setupTimeout(messageId, timeout, undefined, timeoutHandler); - } + this._setupTimeout(messageId, timeout, options?.maxTotalTimeout, timeoutHandler); this._transport.send(jsonrpcRequest).catch((error) => { this._cleanupTimeout(messageId); From b09240bc1082a30785f056162da8fa3a57b794f6 Mon Sep 17 00:00:00 2001 From: Tuomas Date: Thu, 20 Feb 2025 15:36:57 -0600 Subject: [PATCH 3/3] only return early from progress notification on actual timeout errors --- src/shared/protocol.ts | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index a0f3d0751..97213bf0c 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -360,9 +360,7 @@ export abstract class Protocol< const responseHandler = this._responseHandlers.get(messageId); if (this._timeoutInfo.has(messageId) && responseHandler) { try { - if (!this._resetTimeout(messageId)) { - return; - } + this._resetTimeout(messageId); } catch (error) { responseHandler(error as Error); return;