diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index 1d037b988..fb5ecd130 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -71,6 +71,44 @@ describe("protocol tests", () => { jest.useRealTimers(); }); + test("should not reset timeout when resetTimeoutOnProgress is false", async () => { + await protocol.connect(transport); + const request = { method: "example", params: {} }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string(), + }); + const onProgressMock = jest.fn(); + const requestPromise = protocol.request(request, mockSchema, { + timeout: 1000, + resetTimeoutOnProgress: false, + onprogress: onProgressMock, + }); + + jest.advanceTimersByTime(800); + + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: "2.0", + method: "notifications/progress", + params: { + progressToken: 0, + progress: 50, + total: 100, + }, + }); + } + await Promise.resolve(); + + expect(onProgressMock).toHaveBeenCalledWith({ + progress: 50, + total: 100, + }); + + jest.advanceTimersByTime(201); + + await expect(requestPromise).rejects.toThrow("Request timed out"); + }); + test("should reset timeout when progress notification is received", async () => { await protocol.connect(transport); const request = { method: "example", params: {} }; diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index a5b6ad51e..a6e47184b 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -103,6 +103,7 @@ type TimeoutInfo = { startTime: number; timeout: number; maxTotalTimeout?: number; + resetTimeoutOnProgress: boolean; onTimeout: () => void; }; @@ -184,13 +185,15 @@ export abstract class Protocol< messageId: number, timeout: number, maxTotalTimeout: number | undefined, - onTimeout: () => void + onTimeout: () => void, + resetTimeoutOnProgress: boolean = false ) { this._timeoutInfo.set(messageId, { timeoutId: setTimeout(onTimeout, timeout), startTime: Date.now(), timeout, maxTotalTimeout, + resetTimeoutOnProgress, onTimeout }); } @@ -369,7 +372,9 @@ export abstract class Protocol< } const responseHandler = this._responseHandlers.get(messageId); - if (this._timeoutInfo.has(messageId) && responseHandler) { + const timeoutInfo = this._timeoutInfo.get(messageId); + + if (timeoutInfo && responseHandler && timeoutInfo.resetTimeoutOnProgress) { try { this._resetTimeout(messageId); } catch (error) { @@ -531,7 +536,7 @@ export abstract class Protocol< { timeout } )); - this._setupTimeout(messageId, timeout, options?.maxTotalTimeout, timeoutHandler); + this._setupTimeout(messageId, timeout, options?.maxTotalTimeout, timeoutHandler, options?.resetTimeoutOnProgress ?? false); this._transport.send(jsonrpcRequest).catch((error) => { this._cleanupTimeout(messageId);