diff --git a/src/client/cross-spawn.test.ts b/src/client/cross-spawn.test.ts index ca2a5005c..4a213a96e 100644 --- a/src/client/cross-spawn.test.ts +++ b/src/client/cross-spawn.test.ts @@ -13,8 +13,9 @@ describe('StdioClientTransport using cross-spawn', () => { mockSpawn.mockImplementation(() => { const mockProcess: { on: jest.Mock; - stdin?: { on: jest.Mock; write: jest.Mock }; - stdout?: { on: jest.Mock }; + once: jest.Mock; + stdin?: { on: jest.Mock; write: jest.Mock; off: jest.Mock }; + stdout?: { on: jest.Mock; off: jest.Mock }; stderr?: null; } = { on: jest.fn((event: string, callback: () => void) => { @@ -23,12 +24,20 @@ describe('StdioClientTransport using cross-spawn', () => { } return mockProcess; }), + once: jest.fn((event: string, callback: () => void) => { + if (event === 'spawn') { + callback(); + } + return mockProcess; + }), stdin: { on: jest.fn(), - write: jest.fn().mockReturnValue(true) + write: jest.fn().mockReturnValue(true), + off: jest.fn() }, stdout: { - on: jest.fn() + on: jest.fn(), + off: jest.fn() }, stderr: null }; @@ -106,13 +115,16 @@ describe('StdioClientTransport using cross-spawn', () => { // get the mock process object const mockProcess: { on: jest.Mock; + once: jest.Mock; stdin: { on: jest.Mock; write: jest.Mock; once: jest.Mock; + off: jest.Mock; }; stdout: { on: jest.Mock; + off: jest.Mock; }; stderr: null; } = { @@ -122,13 +134,21 @@ describe('StdioClientTransport using cross-spawn', () => { } return mockProcess; }), + once: jest.fn((event: string, callback: () => void) => { + if (event === 'spawn') { + callback(); + } + return mockProcess; + }), stdin: { on: jest.fn(), write: jest.fn().mockReturnValue(true), - once: jest.fn() + once: jest.fn(), + off: jest.fn() }, stdout: { - on: jest.fn() + on: jest.fn(), + off: jest.fn() }, stderr: null }; diff --git a/src/client/stdio.ts b/src/client/stdio.ts index d62a3aeb6..df4d725a4 100644 --- a/src/client/stdio.ts +++ b/src/client/stdio.ts @@ -95,6 +95,8 @@ export class StdioClientTransport implements Transport { private _readBuffer: ReadBuffer = new ReadBuffer(); private _serverParams: StdioServerParameters; private _stderrStream: PassThrough | null = null; + private _onServerDataHandler?: (chunk: Buffer) => void; + private _onServerErrorHandler?: (error: Error) => void; onclose?: () => void; onerror?: (error: Error) => void; @@ -131,6 +133,18 @@ export class StdioClientTransport implements Transport { cwd: this._serverParams.cwd }); + this._onServerDataHandler = (chunk: Buffer) => { + this._readBuffer.append(chunk); + this.processReadBuffer(); + }; + this._onServerErrorHandler = (error: Error) => { + this.onerror?.(error); + }; + + this._process.stdout?.on('data', this._onServerDataHandler); + this._process.stdout?.on('error', this._onServerErrorHandler); + this._process.stdin?.on('error', this._onServerErrorHandler); + this._process.on('error', error => { if (error.name === 'AbortError') { // Expected when close() is called. @@ -141,29 +155,15 @@ export class StdioClientTransport implements Transport { reject(error); this.onerror?.(error); }); - - this._process.on('spawn', () => { - resolve(); - }); - - this._process.on('close', _code => { + this._process.once('spawn', () => resolve()); + this._process.once('close', _code => { + if (this._process) { + this.cleanupListeners(this._process); + } this._process = undefined; this.onclose?.(); }); - this._process.stdin?.on('error', error => { - this.onerror?.(error); - }); - - this._process.stdout?.on('data', chunk => { - this._readBuffer.append(chunk); - this.processReadBuffer(); - }); - - this._process.stdout?.on('error', error => { - this.onerror?.(error); - }); - if (this._stderrStream && this._process.stderr) { this._process.stderr.pipe(this._stderrStream); } @@ -209,7 +209,20 @@ export class StdioClientTransport implements Transport { } } + private cleanupListeners(process: ChildProcess) { + if (this._onServerDataHandler) { + process.stdout?.off('data', this._onServerDataHandler); + } + if (this._onServerErrorHandler) { + process.stdout?.off('error', this._onServerErrorHandler); + process.stdin?.off('error', this._onServerErrorHandler); + } + } + async close(): Promise { + if (this._process) { + this.cleanupListeners(this._process); + } this._abortController.abort(); this._process = undefined; this._readBuffer.clear();