From 16ba027ec86f3723705820c039f1b2d4d48df37c Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 22 Oct 2025 12:35:02 -0700 Subject: [PATCH 01/18] Add types for tasks --- src/types.ts | 144 +++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 135 insertions(+), 9 deletions(-) diff --git a/src/types.ts b/src/types.ts index e6d3fe46e..e3aff3831 100644 --- a/src/types.ts +++ b/src/types.ts @@ -5,6 +5,9 @@ export const LATEST_PROTOCOL_VERSION = '2025-06-18'; export const DEFAULT_NEGOTIATED_PROTOCOL_VERSION = '2025-03-26'; export const SUPPORTED_PROTOCOL_VERSIONS = [LATEST_PROTOCOL_VERSION, '2025-03-26', '2024-11-05', '2024-10-07']; +export const TASK_META_KEY = 'modelcontextprotocol.io/task'; +export const RELATED_TASK_META_KEY = 'modelcontextprotocol.io/related-task'; + /* JSON-RPC types */ export const JSONRPC_VERSION = '2.0'; @@ -18,12 +21,46 @@ export const ProgressTokenSchema = z.union([z.string(), z.number().int()]); */ export const CursorSchema = z.string(); +/** + * Task creation metadata, used to ask that the server create a task to represent a request. + */ +export const TaskRequestMetadataSchema = z + .object({ + /** + * The task ID to use as a reference to the created task. + */ + taskId: z.string(), + + /** + * Time in milliseconds to ask to keep task results available after completion. Only used with taskId. + */ + keepAlive: z.number().optional() + }) + .passthrough(); + +/** + * Task association metadata, used to signal which task a message originated from. + */ +export const RelatedTaskMetadataSchema = z + .object({ + taskId: z.string() + }) + .passthrough(); + const RequestMetaSchema = z .object({ /** * If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications. */ - progressToken: z.optional(ProgressTokenSchema) + progressToken: z.optional(ProgressTokenSchema), + /** + * If specified, the caller is requesting that the receiver create a task to represent the request. + */ + [TASK_META_KEY]: z.optional(TaskRequestMetadataSchema), + /** + * If specified, this request is related to the provided task. + */ + [RELATED_TASK_META_KEY]: z.optional(RelatedTaskMetadataSchema) }) .passthrough(); @@ -44,7 +81,16 @@ const BaseNotificationParamsSchema = z * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) * for notes on _meta usage. */ - _meta: z.optional(z.object({}).passthrough()) + _meta: z.optional( + z + .object({ + /** + * If specified, this request is related to the provided task. + */ + [RELATED_TASK_META_KEY]: z.optional(RelatedTaskMetadataSchema) + }) + .passthrough() + ) }) .passthrough(); @@ -59,7 +105,16 @@ export const ResultSchema = z * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) * for notes on _meta usage. */ - _meta: z.optional(z.object({}).passthrough()) + _meta: z.optional( + z + .object({ + /** + * If specified, this request is related to the provided task. + */ + [RELATED_TASK_META_KEY]: z.optional(RelatedTaskMetadataSchema) + }) + .passthrough() + ) }) .passthrough(); @@ -440,6 +495,51 @@ export const ProgressNotificationSchema = NotificationSchema.extend({ }) }); +/* Tasks */ +/** + * A pollable state object associated with a request. + */ +export const TaskSchema = z.object({ + taskId: z.string(), + status: z.enum(['submitted', 'working', 'completed', 'failed', 'cancelled', 'unknown']), + keepAlive: z.union([z.number(), z.null()]), + pollFrequency: z.optional(z.number()), + error: z.optional(z.string()) +}); + +/** + * An out-of-band notification used to inform the receiver of a task being created. + */ +export const TaskCreatedNotificationSchema = NotificationSchema.extend({ + method: z.literal('notifications/tasks/created'), + params: BaseNotificationParamsSchema +}); + +/** + * A request to get the state of a specific task. + */ +export const GetTaskRequestSchema = RequestSchema.extend({ + method: z.literal('tasks/get'), + params: BaseRequestParamsSchema.extend({ + taskId: z.string() + }) +}); + +/** + * The response to a tasks/get request. + */ +export const GetTaskResultSchema = ResultSchema.merge(TaskSchema); + +/** + * A request to get the result of a specific task. + */ +export const GetTaskPayloadRequestSchema = RequestSchema.extend({ + method: z.literal('tasks/result'), + params: BaseRequestParamsSchema.extend({ + taskId: z.string() + }) +}); + /* Pagination */ export const PaginatedRequestSchema = RequestSchema.extend({ params: BaseRequestParamsSchema.extend({ @@ -1416,20 +1516,36 @@ export const ClientRequestSchema = z.union([ SubscribeRequestSchema, UnsubscribeRequestSchema, CallToolRequestSchema, - ListToolsRequestSchema + ListToolsRequestSchema, + GetTaskRequestSchema, + GetTaskPayloadRequestSchema ]); export const ClientNotificationSchema = z.union([ CancelledNotificationSchema, ProgressNotificationSchema, InitializedNotificationSchema, - RootsListChangedNotificationSchema + RootsListChangedNotificationSchema, + TaskCreatedNotificationSchema ]); -export const ClientResultSchema = z.union([EmptyResultSchema, CreateMessageResultSchema, ElicitResultSchema, ListRootsResultSchema]); +export const ClientResultSchema = z.union([ + EmptyResultSchema, + CreateMessageResultSchema, + ElicitResultSchema, + ListRootsResultSchema, + GetTaskResultSchema +]); /* Server messages */ -export const ServerRequestSchema = z.union([PingRequestSchema, CreateMessageRequestSchema, ElicitRequestSchema, ListRootsRequestSchema]); +export const ServerRequestSchema = z.union([ + PingRequestSchema, + CreateMessageRequestSchema, + ElicitRequestSchema, + ListRootsRequestSchema, + GetTaskRequestSchema, + GetTaskPayloadRequestSchema +]); export const ServerNotificationSchema = z.union([ CancelledNotificationSchema, @@ -1438,7 +1554,8 @@ export const ServerNotificationSchema = z.union([ ResourceUpdatedNotificationSchema, ResourceListChangedNotificationSchema, ToolListChangedNotificationSchema, - PromptListChangedNotificationSchema + PromptListChangedNotificationSchema, + TaskCreatedNotificationSchema ]); export const ServerResultSchema = z.union([ @@ -1451,7 +1568,8 @@ export const ServerResultSchema = z.union([ ListResourceTemplatesResultSchema, ReadResourceResultSchema, CallToolResultSchema, - ListToolsResultSchema + ListToolsResultSchema, + GetTaskResultSchema ]); export class McpError extends Error { @@ -1550,6 +1668,14 @@ export type PingRequest = Infer; export type Progress = Infer; export type ProgressNotification = Infer; +/* Tasks */ +export type Task = Infer; +export type TaskRequestMetadata = Infer; +export type TaskCreatedNotification = Infer; +export type GetTaskRequest = Infer; +export type GetTaskResult = Infer; +export type GetTaskPayloadRequest = Infer; + /* Pagination */ export type PaginatedRequest = Infer; export type PaginatedResult = Infer; From ecef231013498fdb903ce2d753fa5bdbcea57bbd Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 22 Oct 2025 14:49:03 -0700 Subject: [PATCH 02/18] Implement PendingRequest and basic task API --- src/shared/protocol.ts | 143 +++++++++++++++++++++++++++++++++++++---- src/shared/request.ts | 63 ++++++++++++++++++ 2 files changed, 194 insertions(+), 12 deletions(-) create mode 100644 src/shared/request.ts diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 5cb969418..e447a371c 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -3,6 +3,9 @@ import { CancelledNotificationSchema, ClientCapabilities, ErrorCode, + GetTaskRequest, + GetTaskResultSchema, + GetTaskPayloadRequest, isJSONRPCError, isJSONRPCRequest, isJSONRPCResponse, @@ -17,16 +20,22 @@ import { Progress, ProgressNotification, ProgressNotificationSchema, + RELATED_TASK_META_KEY, Request, RequestId, Result, ServerCapabilities, RequestMeta, MessageExtraInfo, - RequestInfo + RequestInfo, + TaskCreatedNotificationSchema, + TASK_META_KEY, + GetTaskResult, + TaskRequestMetadata } from '../types.js'; import { Transport, TransportSendOptions } from './transport.js'; import { AuthInfo } from '../server/auth/types.js'; +import { PendingRequest } from './request.js'; /** * Callback for progress notifications. @@ -93,6 +102,11 @@ export type RequestOptions = { * If not specified, there is no maximum total timeout. */ maxTotalTimeout?: number; + + /** + * If provided, augments the request with task metadata to enable call-now, fetch-later execution patterns. + */ + task?: TaskRequestMetadata; } & TransportSendOptions; /** @@ -108,7 +122,11 @@ export type NotificationOptions = { /** * Extra data given to request handlers. */ -export type RequestHandlerExtra = { +export type RequestHandlerExtra< + SendRequestT extends Request, + SendNotificationT extends Notification, + SendResultT extends Result = Result +> = { /** * An abort signal used to communicate if the request was cancelled from the sender's side. */ @@ -152,7 +170,7 @@ export type RequestHandlerExtra>(request: SendRequestT, resultSchema: U, options?: RequestOptions) => Promise>; + sendRequest: >(request: SendRequestT, resultSchema: U, options?: RequestOptions) => Promise>; }; /** @@ -176,7 +194,7 @@ export abstract class Protocol) => Promise + (request: JSONRPCRequest, extra: RequestHandlerExtra) => Promise > = new Map(); private _requestHandlerAbortControllers: Map = new Map(); private _notificationHandlers: Map Promise> = new Map(); @@ -184,6 +202,8 @@ export abstract class Protocol = new Map(); private _timeoutInfo: Map = new Map(); private _pendingDebouncedNotifications = new Set(); + private _pendingTaskCreations: Map void; reject: (reason: Error) => void }> = new Map(); + private _requestIdToTaskId: Map = new Map(); /** * Callback for when the connection is closed for any reason. @@ -202,7 +222,10 @@ export abstract class Protocol) => Promise; + fallbackRequestHandler?: ( + request: JSONRPCRequest, + extra: RequestHandlerExtra + ) => Promise; /** * A handler to invoke for any notification types that do not have their own handler installed. @@ -219,6 +242,17 @@ export abstract class Protocol { + const taskId = notification.params?._meta?.[RELATED_TASK_META_KEY]?.taskId; + if (taskId) { + const resolver = this._pendingTaskCreations.get(taskId); + if (resolver) { + resolver.resolve(); + this._pendingTaskCreations.delete(taskId); + } + } + }); + this.setRequestHandler( PingRequestSchema, // Automatic pong by default. @@ -310,10 +344,19 @@ export abstract class Protocol = { + const fullExtra: RequestHandlerExtra = { signal: abortController.signal, sessionId: capturedTransport?.sessionId, _meta: request.params?._meta, @@ -444,6 +487,17 @@ export abstract class Protocol>(request: SendRequestT, resultSchema: T, options?: RequestOptions): Promise> { - const { relatedRequestId, resumptionToken, onresumptiontoken } = options ?? {}; + beginRequest>( + request: SendRequestT, + resultSchema: T, + options?: RequestOptions + ): PendingRequest { + const { relatedRequestId, resumptionToken, onresumptiontoken, task } = options ?? {}; + const { taskId, keepAlive } = task ?? {}; - return new Promise((resolve, reject) => { + const promise = new Promise>((resolve, reject) => { if (!this._transport) { reject(new Error('Not connected')); return; @@ -522,6 +581,21 @@ export abstract class Protocol { this._responseHandlers.delete(messageId); this._progressHandlers.delete(messageId); @@ -575,6 +649,48 @@ export abstract class Protocol>(request: SendRequestT, resultSchema: T, options?: RequestOptions): Promise> { + return this.beginRequest(request, resultSchema, options).result(); + } + + /** + * Waits for a task creation notification with the given taskId. + * Returns a promise that resolves when the notifications/tasks/created notification is received, + * or rejects if the task is cleaned up (e.g., connection closed or request completed). + */ + waitForTaskCreation(taskId: string): Promise { + return new Promise((resolve, reject) => { + this._pendingTaskCreations.set(taskId, { resolve, reject }); + }); + } + + /** + * Gets the current status of a task. + */ + async getTask(params: GetTaskRequest['params'], options?: RequestOptions): Promise { + // @ts-expect-error SendRequestT cannot directly contain GetTaskRequest, but we ensure all type instantiations contain it anyways + return this.request({ method: 'tasks/get', params }, GetTaskResultSchema, options); + } + + /** + * Retrieves the result of a completed task. + */ + async getTaskResult>( + params: GetTaskPayloadRequest['params'], + resultSchema: T, + options?: RequestOptions + ): Promise> { + // @ts-expect-error SendRequestT cannot directly contain GetTaskPayloadRequest, but we ensure all type instantiations contain it anyways + return this.request({ method: 'tasks/result', params }, resultSchema, options); } /** @@ -644,7 +760,10 @@ export abstract class Protocol >( requestSchema: T, - handler: (request: z.infer, extra: RequestHandlerExtra) => SendResultT | Promise + handler: ( + request: z.infer, + extra: RequestHandlerExtra + ) => SendResultT | Promise ): void { const method = requestSchema.shape.method.value; this.assertRequestHandlerCapability(method); diff --git a/src/shared/request.ts b/src/shared/request.ts new file mode 100644 index 000000000..eda46ccd1 --- /dev/null +++ b/src/shared/request.ts @@ -0,0 +1,63 @@ +import { ZodType } from 'zod'; +import { Protocol } from './protocol.js'; +import { Request, Notification, Result, Task, GetTaskResult } from '../types.js'; + +const DEFAULT_POLLING_INTERNAL = 5000; + +export interface TaskHandlerOptions { + onTaskStatus: (task: Task) => Promise; +} + +export class PendingRequest { + constructor( + readonly protocol: Protocol, + readonly resultHandle: Promise, + readonly resultSchema: ZodType, + readonly taskId?: string + ) {} + + /** + * Waits for a result, calling onTaskStatus if provided and a task was created. + */ + async result(options?: Partial): Promise { + if (!options?.onTaskStatus || !this.taskId) { + // No task listener or task ID provided, just block for the result + return await this.resultHandle; + } + + // Whichever is successful first (or a failure if all fail) is returned. + return Promise.allSettled([ + this.resultHandle, + (async () => { + // Blocks for a notifications/tasks/created with the provided task ID + await this.protocol.waitForTaskCreation(this.taskId!); + return await this.taskHandler(options as TaskHandlerOptions); + })() + ]).then(([result, task]) => { + if (result.status === 'fulfilled') { + return result.value; + } else if (task.status === 'fulfilled') { + return task.value; + } + + const errors: unknown[] = [result.reason, task.reason]; + throw new Error(`Both request and task handler failed: ${errors.map(e => `${e}`).join(', ')}`); + }); + } + + /** + * Encapsulates polling for a result, calling onTaskStatus after querying the task. + */ + private async taskHandler({ onTaskStatus }: TaskHandlerOptions): Promise { + // Poll for completion + let task: GetTaskResult; + do { + task = await this.protocol.getTask({ taskId: this.taskId! }); + await onTaskStatus(task); + await new Promise(resolve => setTimeout(resolve, task.pollFrequency ?? DEFAULT_POLLING_INTERNAL)); + } while (!(['complete', 'failed', 'cancelled', 'unknown'] as (typeof task.status)[]).includes(task.status)); + + // Process result + return await this.protocol.getTaskResult({ taskId: this.taskId! }, this.resultSchema); + } +} From 41f212486061259b7ce48f0f9d07aec958554a95 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 22 Oct 2025 14:54:12 -0700 Subject: [PATCH 03/18] Implement RelatedTask metadata sends --- src/shared/protocol.ts | 64 ++++++++++++++++++++++++++++++++++++++---- src/types.ts | 1 + 2 files changed, 59 insertions(+), 6 deletions(-) diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index e447a371c..3f528e454 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -31,7 +31,8 @@ import { TaskCreatedNotificationSchema, TASK_META_KEY, GetTaskResult, - TaskRequestMetadata + TaskRequestMetadata, + RelatedTaskMetadata } from '../types.js'; import { Transport, TransportSendOptions } from './transport.js'; import { AuthInfo } from '../server/auth/types.js'; @@ -107,6 +108,11 @@ export type RequestOptions = { * If provided, augments the request with task metadata to enable call-now, fetch-later execution patterns. */ task?: TaskRequestMetadata; + + /** + * If provided, associates this request with a related task. + */ + relatedTask?: RelatedTaskMetadata; } & TransportSendOptions; /** @@ -117,6 +123,11 @@ export type NotificationOptions = { * May be used to indicate to the transport which incoming request to associate this outgoing notification with. */ relatedRequestId?: RequestId; + + /** + * If provided, associates this notification with a related task. + */ + relatedTask?: RelatedTaskMetadata; }; /** @@ -548,7 +559,7 @@ export abstract class Protocol { - const { relatedRequestId, resumptionToken, onresumptiontoken, task } = options ?? {}; + const { relatedRequestId, resumptionToken, onresumptiontoken, task, relatedTask } = options ?? {}; const { taskId, keepAlive } = task ?? {}; const promise = new Promise>((resolve, reject) => { @@ -596,6 +607,17 @@ export abstract class Protocol { this._responseHandlers.delete(messageId); this._progressHandlers.delete(messageId); @@ -705,8 +727,9 @@ export abstract class Protocol this._onerror(error)); @@ -741,11 +779,25 @@ export abstract class Protocol; /* Tasks */ export type Task = Infer; export type TaskRequestMetadata = Infer; +export type RelatedTaskMetadata = Infer; export type TaskCreatedNotification = Infer; export type GetTaskRequest = Infer; export type GetTaskResult = Infer; From a8fabb61e194d1d1f1d1184e05b3da1167220efe Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 22 Oct 2025 15:45:11 -0700 Subject: [PATCH 04/18] Implement task state management --- src/shared/protocol.ts | 122 +++++++++++++++++++++++++++++++++++++++-- src/shared/request.ts | 8 +-- src/shared/task.ts | 51 +++++++++++++++++ src/types.ts | 6 +- 4 files changed, 174 insertions(+), 13 deletions(-) create mode 100644 src/shared/task.ts diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 3f528e454..292d8c902 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -4,8 +4,10 @@ import { ClientCapabilities, ErrorCode, GetTaskRequest, + GetTaskRequestSchema, GetTaskResultSchema, GetTaskPayloadRequest, + GetTaskPayloadRequestSchema, isJSONRPCError, isJSONRPCRequest, isJSONRPCResponse, @@ -31,12 +33,13 @@ import { TaskCreatedNotificationSchema, TASK_META_KEY, GetTaskResult, - TaskRequestMetadata, + TaskMetadata, RelatedTaskMetadata } from '../types.js'; import { Transport, TransportSendOptions } from './transport.js'; import { AuthInfo } from '../server/auth/types.js'; import { PendingRequest } from './request.js'; +import { TaskStore } from './task.js'; /** * Callback for progress notifications. @@ -62,6 +65,11 @@ export type ProtocolOptions = { * e.g., ['notifications/tools/list_changed'] */ debouncedNotificationMethods?: string[]; + /** + * Optional task storage implementation. If provided, the implementation will automatically + * handle task creation, status tracking, and result storage. + */ + taskStore?: TaskStore; }; /** @@ -107,7 +115,7 @@ export type RequestOptions = { /** * If provided, augments the request with task metadata to enable call-now, fetch-later execution patterns. */ - task?: TaskRequestMetadata; + task?: TaskMetadata; /** * If provided, associates this request with a related task. @@ -215,6 +223,7 @@ export abstract class Protocol(); private _pendingTaskCreations: Map void; reject: (reason: Error) => void }> = new Map(); private _requestIdToTaskId: Map = new Map(); + private _taskStore?: TaskStore; /** * Callback for when the connection is closed for any reason. @@ -245,8 +254,7 @@ export abstract class Protocol { - const controller = this._requestHandlerAbortControllers.get(notification.params.requestId); - controller?.abort(notification.params.reason); + this._oncancel(notification); }); this.setNotificationHandler(ProgressNotificationSchema, notification => { @@ -269,6 +277,65 @@ export abstract class Protocol ({}) as SendResultT ); + + // Install task handlers if TaskStore is provided + this._taskStore = _options?.taskStore; + if (this._taskStore) { + this.setRequestHandler(GetTaskRequestSchema, async request => { + const task = await this._taskStore!.getTask(request.params.taskId); + if (!task) { + throw new McpError(ErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); + } + + // @ts-expect-error SendResultT cannot contain GetTaskResult, but we include it in our derived types everywhere else + return { + ...task, + _meta: { + [RELATED_TASK_META_KEY]: { + taskId: request.params.taskId + } + } + } as SendResultT; + }); + + this.setRequestHandler(GetTaskPayloadRequestSchema, async request => { + const task = await this._taskStore!.getTask(request.params.taskId); + if (!task) { + throw new McpError(ErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); + } + + if (task.status !== 'completed') { + throw new McpError(ErrorCode.InvalidParams, `Cannot retrieve result: Task status is '${task.status}', not 'completed'`); + } + + const result = await this._taskStore!.getTaskResult(request.params.taskId); + return { + ...result, + _meta: { + ...result._meta, + [RELATED_TASK_META_KEY]: { + taskId: request.params.taskId + } + } + } as SendResultT; + }); + } + } + + private async _oncancel(notification: z.infer): Promise { + // Handle request cancellation + const controller = this._requestHandlerAbortControllers.get(notification.params.requestId); + controller?.abort(notification.params.reason); + + // If this request had a task, mark it as cancelled in storage + const taskId = this._requestIdToTaskId.get(Number(notification.params.requestId)); + if (taskId && this._taskStore) { + try { + await this._taskStore.updateTaskStatus(taskId, 'cancelled'); + } catch (error) { + this._onerror(new Error(`Failed to cancel task ${taskId}: ${error}`)); + } + } } private _setupTimeout( @@ -429,16 +496,59 @@ export abstract class Protocol handler(request, fullExtra)) .then( - result => { + async result => { if (abortController.signal.aborted) { return; } - return capturedTransport?.send({ + // If this request asked for task creation, create the task and send notification + const taskMetadata = request.params?._meta?.[TASK_META_KEY]; + if (taskMetadata && this._taskStore) { + const task = await this._taskStore!.getTask(taskMetadata.taskId); + if (task) { + throw new McpError(ErrorCode.InvalidParams, `Task ID already exists: ${taskMetadata.taskId}`); + } + + try { + await this._taskStore.createTask(taskMetadata, request.id, { + method: request.method, + params: request.params + }); + + // Send task created notification + await this.notification( + { + method: 'notifications/tasks/created', + params: { + _meta: { + [RELATED_TASK_META_KEY]: { + taskId: taskMetadata.taskId + } + } + } + } as SendNotificationT, + { relatedRequestId: request.id } + ); + } catch (error) { + this._onerror(new Error(`Failed to create task: ${error}`)); + } + } + + // Send the response + await capturedTransport?.send({ result, jsonrpc: '2.0', id: request.id }); + + // Store the result if this was a task-based request + if (taskMetadata && this._taskStore) { + try { + await this._taskStore.storeTaskResult(taskMetadata.taskId, result); + } catch (error) { + this._onerror(new Error(`Failed to store task result: ${error}`)); + } + } }, error => { if (abortController.signal.aborted) { diff --git a/src/shared/request.ts b/src/shared/request.ts index eda46ccd1..26186f2d9 100644 --- a/src/shared/request.ts +++ b/src/shared/request.ts @@ -31,7 +31,7 @@ export class PendingRequest { // Blocks for a notifications/tasks/created with the provided task ID await this.protocol.waitForTaskCreation(this.taskId!); - return await this.taskHandler(options as TaskHandlerOptions); + return await this.taskHandler(this.taskId!, options as TaskHandlerOptions); })() ]).then(([result, task]) => { if (result.status === 'fulfilled') { @@ -48,16 +48,16 @@ export class PendingRequest { + private async taskHandler(taskId: string, { onTaskStatus }: TaskHandlerOptions): Promise { // Poll for completion let task: GetTaskResult; do { - task = await this.protocol.getTask({ taskId: this.taskId! }); + task = await this.protocol.getTask({ taskId: taskId }); await onTaskStatus(task); await new Promise(resolve => setTimeout(resolve, task.pollFrequency ?? DEFAULT_POLLING_INTERNAL)); } while (!(['complete', 'failed', 'cancelled', 'unknown'] as (typeof task.status)[]).includes(task.status)); // Process result - return await this.protocol.getTaskResult({ taskId: this.taskId! }, this.resultSchema); + return await this.protocol.getTaskResult({ taskId: taskId }, this.resultSchema); } } diff --git a/src/shared/task.ts b/src/shared/task.ts new file mode 100644 index 000000000..0a4b52560 --- /dev/null +++ b/src/shared/task.ts @@ -0,0 +1,51 @@ +import { Task, TaskMetadata, Request, RequestId, Result } from '../types.js'; + +/** + * Interface for storing and retrieving task state and results. + * + * Similar to Transport, this allows pluggable task storage implementations + * (in-memory, database, distributed cache, etc.). + */ +export interface TaskStore { + /** + * Creates a new task with the given metadata and original request. + * + * @param task - The task creation metadata from the request + * @param requestId - The JSON-RPC request ID + * @param request - The original request that triggered task creation + */ + createTask(task: TaskMetadata, requestId: RequestId, request: Request): Promise; + + /** + * Gets the current status of a task. + * + * @param taskId - The task identifier + * @returns The task state including status, keepAlive, pollFrequency, and optional error + */ + getTask(taskId: string): Promise; + + /** + * Stores the result of a completed task. + * + * @param taskId - The task identifier + * @param result - The result to store + */ + storeTaskResult(taskId: string, result: Result): Promise; + + /** + * Retrieves the stored result of a task. + * + * @param taskId - The task identifier + * @returns The stored result + */ + getTaskResult(taskId: string): Promise; + + /** + * Updates a task's status (e.g., to 'cancelled', 'failed', 'completed'). + * + * @param taskId - The task identifier + * @param status - The new status + * @param error - Optional error message if status is 'failed' or 'cancelled' + */ + updateTaskStatus(taskId: string, status: Task['status'], error?: string): Promise; +} diff --git a/src/types.ts b/src/types.ts index 419310f63..23caa5f41 100644 --- a/src/types.ts +++ b/src/types.ts @@ -24,7 +24,7 @@ export const CursorSchema = z.string(); /** * Task creation metadata, used to ask that the server create a task to represent a request. */ -export const TaskRequestMetadataSchema = z +export const TaskMetadataSchema = z .object({ /** * The task ID to use as a reference to the created task. @@ -56,7 +56,7 @@ const RequestMetaSchema = z /** * If specified, the caller is requesting that the receiver create a task to represent the request. */ - [TASK_META_KEY]: z.optional(TaskRequestMetadataSchema), + [TASK_META_KEY]: z.optional(TaskMetadataSchema), /** * If specified, this request is related to the provided task. */ @@ -1670,7 +1670,7 @@ export type ProgressNotification = Infer; /* Tasks */ export type Task = Infer; -export type TaskRequestMetadata = Infer; +export type TaskMetadata = Infer; export type RelatedTaskMetadata = Infer; export type TaskCreatedNotification = Infer; export type GetTaskRequest = Infer; From b3420b3725d6d6db98d760162ddb1c9dd941b27f Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 22 Oct 2025 15:59:52 -0700 Subject: [PATCH 05/18] Attach related task metadata to request handler --- src/shared/protocol.ts | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 292d8c902..8ee3d493b 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -481,12 +481,19 @@ export abstract class Protocol = { signal: abortController.signal, sessionId: capturedTransport?.sessionId, _meta: request.params?._meta, - sendNotification: notification => this.notification(notification, { relatedRequestId: request.id }), - sendRequest: (r, resultSchema, options?) => this.request(r, resultSchema, { ...options, relatedRequestId: request.id }), + sendNotification: async notification => { + const relatedTask = taskMetadata ? { taskId: taskMetadata.taskId } : undefined; + await this.notification(notification, { relatedRequestId: request.id, relatedTask }); + }, + sendRequest: async (r, resultSchema, options?) => { + const relatedTask = taskMetadata ? { taskId: taskMetadata.taskId } : undefined; + return await this.request(r, resultSchema, { ...options, relatedRequestId: request.id, relatedTask }); + }, authInfo: extra?.authInfo, requestId: request.id, requestInfo: extra?.requestInfo @@ -502,7 +509,6 @@ export abstract class Protocol Date: Thu, 23 Oct 2025 12:26:50 -0700 Subject: [PATCH 06/18] Create task before calling handler --- src/client/index.ts | 19 +++++++++ src/shared/protocol.ts | 90 ++++++++++++++++++++++++++---------------- src/shared/request.ts | 35 ++++++++++------ src/shared/task.ts | 11 ++++++ 4 files changed, 108 insertions(+), 47 deletions(-) diff --git a/src/client/index.ts b/src/client/index.ts index 856eb18e5..b000088e6 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -1,5 +1,6 @@ import { mergeCapabilities, Protocol, ProtocolOptions, RequestOptions } from '../shared/protocol.js'; import { Transport } from '../shared/transport.js'; +import { PendingRequest } from '../shared/request.js'; import { CallToolRequest, CallToolResultSchema, @@ -326,6 +327,24 @@ export class Client< return this.request({ method: 'resources/unsubscribe', params }, EmptyResultSchema, options); } + /** + * Begins a tool call and returns a PendingRequest for granular control over task-based execution. + * + * This is useful when you want to create a task for a long-running tool call and poll for results later. + */ + beginCallTool( + params: CallToolRequest['params'], + resultSchema: typeof CallToolResultSchema | typeof CompatibilityCallToolResultSchema = CallToolResultSchema, + options?: RequestOptions + ): PendingRequest { + return this.beginRequest({ method: 'tools/call', params }, resultSchema, options); + } + + /** + * Calls a tool and waits for the result. Automatically validates structured output if the tool has an outputSchema. + * + * For task-based execution with granular control, use beginCallTool() instead. + */ async callTool( params: CallToolRequest['params'], resultSchema: typeof CallToolResultSchema | typeof CompatibilityCallToolResultSchema = CallToolResultSchema, diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 8ee3d493b..6bbdbf84d 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -501,44 +501,59 @@ export abstract class Protocol handler(request, fullExtra)) - .then( - async result => { - if (abortController.signal.aborted) { - return; + .then(async () => { + // If this request asked for task creation, create the task and send notification + if (taskMetadata && this._taskStore) { + const task = await this._taskStore!.getTask(taskMetadata.taskId); + if (task) { + throw new McpError(ErrorCode.InvalidParams, `Task ID already exists: ${taskMetadata.taskId}`); } - // If this request asked for task creation, create the task and send notification - if (taskMetadata && this._taskStore) { - const task = await this._taskStore!.getTask(taskMetadata.taskId); - if (task) { - throw new McpError(ErrorCode.InvalidParams, `Task ID already exists: ${taskMetadata.taskId}`); - } - - try { - await this._taskStore.createTask(taskMetadata, request.id, { - method: request.method, - params: request.params - }); - - // Send task created notification - await this.notification( - { - method: 'notifications/tasks/created', - params: { - _meta: { - [RELATED_TASK_META_KEY]: { - taskId: taskMetadata.taskId - } + try { + await this._taskStore.createTask(taskMetadata, request.id, { + method: request.method, + params: request.params + }); + + // Send task created notification + await this.notification( + { + method: 'notifications/tasks/created', + params: { + _meta: { + [RELATED_TASK_META_KEY]: { + taskId: taskMetadata.taskId } } - } as SendNotificationT, - { relatedRequestId: request.id } - ); + } + } as SendNotificationT, + { relatedRequestId: request.id } + ); + } catch (error) { + throw new McpError(ErrorCode.InternalError, `Failed to create task: ${taskMetadata.taskId}`); + } + } + }) + .then(async () => { + // If this request had a task, mark it as working + if (taskMetadata && this._taskStore) { + try { + await this._taskStore.updateTaskStatus(taskMetadata.taskId, 'working'); + } catch (error) { + try { + await this._taskStore.updateTaskStatus(taskMetadata.taskId, 'failed', 'Failed to mark task as working'); } catch (error) { - this._onerror(new Error(`Failed to create task: ${error}`)); + throw new McpError(ErrorCode.InternalError, `Failed to mark task as working: ${error}`); } } + } + }) + .then(() => handler(request, fullExtra)) + .then( + async result => { + if (abortController.signal.aborted) { + return; + } // Send the response await capturedTransport?.send({ @@ -552,7 +567,7 @@ export abstract class Protocol>((resolve, reject) => { + // For tasks, create an advance promise for the creation notification to avoid + // race conditions with installing this callback. + const taskCreated = taskId ? this.waitForTaskCreation(taskId) : Promise.resolve(); + + // Send the request + const result = new Promise>((resolve, reject) => { if (!this._transport) { reject(new Error('Not connected')); return; @@ -788,7 +808,7 @@ export abstract class Protocol { + private waitForTaskCreation(taskId: string): Promise { return new Promise((resolve, reject) => { this._pendingTaskCreations.set(taskId, { resolve, reject }); }); diff --git a/src/shared/request.ts b/src/shared/request.ts index 26186f2d9..fae4f1332 100644 --- a/src/shared/request.ts +++ b/src/shared/request.ts @@ -1,16 +1,21 @@ import { ZodType } from 'zod'; import { Protocol } from './protocol.js'; import { Request, Notification, Result, Task, GetTaskResult } from '../types.js'; +import { isTerminal } from './task.js'; const DEFAULT_POLLING_INTERNAL = 5000; +const DEFAULT_HANDLER = () => Promise.resolve(); + export interface TaskHandlerOptions { - onTaskStatus: (task: Task) => Promise; + onTaskCreated: () => Promise | void; + onTaskStatus: (task: Task) => Promise | void; } export class PendingRequest { constructor( readonly protocol: Protocol, + readonly taskCreatedHandle: Promise, readonly resultHandle: Promise, readonly resultSchema: ZodType, readonly taskId?: string @@ -20,24 +25,30 @@ export class PendingRequest): Promise { - if (!options?.onTaskStatus || !this.taskId) { - // No task listener or task ID provided, just block for the result + const { onTaskCreated = DEFAULT_HANDLER, onTaskStatus = DEFAULT_HANDLER } = options ?? {}; + + if (!this.taskId) { + // No task ID provided, just block for the result return await this.resultHandle; } // Whichever is successful first (or a failure if all fail) is returned. return Promise.allSettled([ - this.resultHandle, (async () => { // Blocks for a notifications/tasks/created with the provided task ID - await this.protocol.waitForTaskCreation(this.taskId!); - return await this.taskHandler(this.taskId!, options as TaskHandlerOptions); - })() - ]).then(([result, task]) => { - if (result.status === 'fulfilled') { - return result.value; - } else if (task.status === 'fulfilled') { + await this.taskCreatedHandle; + await onTaskCreated(); + return await this.taskHandler(this.taskId!, { + onTaskCreated, + onTaskStatus + }); + })(), + this.resultHandle + ]).then(([task, result]) => { + if (task.status === 'fulfilled') { return task.value; + } else if (result.status === 'fulfilled') { + return result.value; } const errors: unknown[] = [result.reason, task.reason]; @@ -55,7 +66,7 @@ export class PendingRequest setTimeout(resolve, task.pollFrequency ?? DEFAULT_POLLING_INTERNAL)); - } while (!(['complete', 'failed', 'cancelled', 'unknown'] as (typeof task.status)[]).includes(task.status)); + } while (!isTerminal(task.status)); // Process result return await this.protocol.getTaskResult({ taskId: taskId }, this.resultSchema); diff --git a/src/shared/task.ts b/src/shared/task.ts index 0a4b52560..617ab81aa 100644 --- a/src/shared/task.ts +++ b/src/shared/task.ts @@ -49,3 +49,14 @@ export interface TaskStore { */ updateTaskStatus(taskId: string, status: Task['status'], error?: string): Promise; } + +/** + * Checks if a task status represents a terminal state. + * Terminal states are those where the task has finished and will not change. + * + * @param status - The task status to check + * @returns True if the status is terminal (completed, failed, cancelled, or unknown) + */ +export function isTerminal(status: Task['status']): boolean { + return status === 'completed' || status === 'failed' || status === 'cancelled' || status === 'unknown'; +} From fcd2882303df6a332b1e4cbaef784db250875f30 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Thu, 23 Oct 2025 12:27:13 -0700 Subject: [PATCH 07/18] Create task example --- src/examples/client/simpleStreamableHttp.ts | 75 +++++++++++ src/examples/server/simpleStreamableHttp.ts | 32 ++++- src/examples/shared/inMemoryTaskStore.ts | 142 ++++++++++++++++++++ 3 files changed, 248 insertions(+), 1 deletion(-) create mode 100644 src/examples/shared/inMemoryTaskStore.ts diff --git a/src/examples/client/simpleStreamableHttp.ts b/src/examples/client/simpleStreamableHttp.ts index 10f6afcbe..697353ef4 100644 --- a/src/examples/client/simpleStreamableHttp.ts +++ b/src/examples/client/simpleStreamableHttp.ts @@ -58,6 +58,7 @@ function printHelp(): void { console.log(' reconnect - Reconnect to the server'); console.log(' list-tools - List available tools'); console.log(' call-tool [args] - Call a tool with optional JSON arguments'); + console.log(' call-tool-task [args] - Call a tool with task-based execution (example: call-tool-task delay {"duration":3000})'); console.log(' greet [name] - Call the greet tool'); console.log(' multi-greet [name] - Call the multi-greet tool with notifications'); console.log(' collect-info [type] - Test elicitation with collect-user-info tool (contact/preferences/feedback)'); @@ -141,6 +142,23 @@ function commandLoop(): void { break; } + case 'call-tool-task': + if (args.length < 2) { + console.log('Usage: call-tool-task [args]'); + } else { + const toolName = args[1]; + let toolArgs = {}; + if (args.length > 2) { + try { + toolArgs = JSON.parse(args.slice(2).join(' ')); + } catch { + console.log('Invalid JSON arguments. Using empty args.'); + } + } + await callToolTask(toolName, toolArgs); + } + break; + case 'list-prompts': await listPrompts(); break; @@ -777,6 +795,63 @@ async function readResource(uri: string): Promise { } } +async function callToolTask(name: string, args: Record): Promise { + if (!client) { + console.log('Not connected to server.'); + return; + } + + console.log(`Calling tool '${name}' with task-based execution...`); + console.log('Arguments:', args); + + // Use task-based execution - call now, fetch later + const taskId = `task-${Date.now()}`; + console.log(`Task ID: ${taskId}`); + console.log('This will return immediately while processing continues in the background...'); + + try { + // Begin the tool call with task metadata + const pendingRequest = client.beginCallTool( + { + name, + arguments: args + }, + CallToolResultSchema, + { + task: { + taskId, + keepAlive: 60000 // Keep results for 60 seconds + } + } + ); + + console.log('Waiting for task completion...'); + + await pendingRequest.result({ + onTaskCreated: () => { + console.log('Task created successfully'); + }, + onTaskStatus: task => { + console.log(` ${task.status}${task.error ? ` - ${task.error}` : ''}`); + } + }); + + console.log('Task completed! Fetching result...'); + + // Get the actual result + const result = await client.getTaskResult({ taskId }, CallToolResultSchema); + + console.log('Tool result:'); + result.content.forEach(item => { + if (item.type === 'text') { + console.log(` ${item.text}`); + } + }); + } catch (error) { + console.log(`Error with task-based execution: ${error}`); + } +} + async function cleanup(): Promise { if (client && transport) { try { diff --git a/src/examples/server/simpleStreamableHttp.ts b/src/examples/server/simpleStreamableHttp.ts index 5872cb4ac..966337f45 100644 --- a/src/examples/server/simpleStreamableHttp.ts +++ b/src/examples/server/simpleStreamableHttp.ts @@ -14,6 +14,7 @@ import { ResourceLink } from '../../types.js'; import { InMemoryEventStore } from '../shared/inMemoryEventStore.js'; +import { InMemoryTaskStore } from '../shared/inMemoryTaskStore.js'; import { setupAuthServer } from './demoInMemoryOAuthProvider.js'; import { OAuthMetadata } from 'src/shared/auth.js'; import { checkResourceAllowed } from 'src/shared/auth-utils.js'; @@ -24,6 +25,9 @@ import cors from 'cors'; const useOAuth = process.argv.includes('--oauth'); const strictOAuth = process.argv.includes('--oauth-strict'); +// Create shared task store for demonstration +const taskStore = new InMemoryTaskStore(); + // Create an MCP server with implementation details const getServer = () => { const server = new McpServer( @@ -33,7 +37,10 @@ const getServer = () => { icons: [{ src: './mcp.svg', sizes: ['512x512'], mimeType: 'image/svg+xml' }], websiteUrl: 'https://github.com/modelcontextprotocol/typescript-sdk' }, - { capabilities: { logging: {} } } + { + capabilities: { logging: {} }, + taskStore // Enable task support + } ); // Register a simple tool that returns a greeting @@ -439,6 +446,29 @@ const getServer = () => { } ); + // Register a long-running tool that demonstrates task execution + server.registerTool( + 'delay', + { + title: 'Delay', + description: 'A simple tool that delays for a specified duration, useful for testing task execution', + inputSchema: { + duration: z.number().describe('Duration in milliseconds').default(5000) + } + }, + async ({ duration }): Promise => { + await new Promise(resolve => setTimeout(resolve, duration)); + return { + content: [ + { + type: 'text', + text: `Completed ${duration}ms delay` + } + ] + }; + } + ); + return server; }; diff --git a/src/examples/shared/inMemoryTaskStore.ts b/src/examples/shared/inMemoryTaskStore.ts new file mode 100644 index 000000000..79d8a05bd --- /dev/null +++ b/src/examples/shared/inMemoryTaskStore.ts @@ -0,0 +1,142 @@ +import { Task, TaskMetadata, Request, RequestId, Result } from '../../types.js'; +import { TaskStore, isTerminal } from '../../shared/task.js'; + +interface StoredTask { + task: Task; + request: Request; + requestId: RequestId; + result?: Result; +} + +/** + * A simple in-memory implementation of TaskStore for demonstration purposes. + * + * This implementation stores all tasks in memory and provides automatic cleanup + * based on the keepAlive duration specified in the task metadata. + * + * Note: This is not suitable for production use as all data is lost on restart. + * For production, consider implementing TaskStore with a database or distributed cache. + */ +export class InMemoryTaskStore implements TaskStore { + private tasks = new Map(); + private cleanupTimers = new Map>(); + + async createTask(metadata: TaskMetadata, requestId: RequestId, request: Request): Promise { + const taskId = metadata.taskId; + + if (this.tasks.has(taskId)) { + throw new Error(`Task with ID ${taskId} already exists`); + } + + const task: Task = { + taskId, + status: 'submitted', + keepAlive: metadata.keepAlive ?? null, + pollFrequency: 500 + }; + + this.tasks.set(taskId, { + task, + request, + requestId + }); + + // Schedule cleanup if keepAlive is specified + if (metadata.keepAlive) { + const timer = setTimeout(() => { + this.tasks.delete(taskId); + this.cleanupTimers.delete(taskId); + }, metadata.keepAlive); + + this.cleanupTimers.set(taskId, timer); + } + } + + async getTask(taskId: string): Promise { + const stored = this.tasks.get(taskId); + return stored ? { ...stored.task } : null; + } + + async storeTaskResult(taskId: string, result: Result): Promise { + const stored = this.tasks.get(taskId); + if (!stored) { + throw new Error(`Task with ID ${taskId} not found`); + } + + stored.result = result; + stored.task.status = 'completed'; + + // Reset cleanup timer to start from now (if keepAlive is set) + if (stored.task.keepAlive) { + const existingTimer = this.cleanupTimers.get(taskId); + if (existingTimer) { + clearTimeout(existingTimer); + } + + const timer = setTimeout(() => { + this.tasks.delete(taskId); + this.cleanupTimers.delete(taskId); + }, stored.task.keepAlive); + + this.cleanupTimers.set(taskId, timer); + } + } + + async getTaskResult(taskId: string): Promise { + const stored = this.tasks.get(taskId); + if (!stored) { + throw new Error(`Task with ID ${taskId} not found`); + } + + if (!stored.result) { + throw new Error(`Task ${taskId} has no result stored`); + } + + return stored.result; + } + + async updateTaskStatus(taskId: string, status: Task['status'], error?: string): Promise { + const stored = this.tasks.get(taskId); + if (!stored) { + throw new Error(`Task with ID ${taskId} not found`); + } + + stored.task.status = status; + if (error) { + stored.task.error = error; + } + + // If task is in a terminal state and has keepAlive, start cleanup timer + if (isTerminal(status) && stored.task.keepAlive) { + const existingTimer = this.cleanupTimers.get(taskId); + if (existingTimer) { + clearTimeout(existingTimer); + } + + const timer = setTimeout(() => { + this.tasks.delete(taskId); + this.cleanupTimers.delete(taskId); + }, stored.task.keepAlive); + + this.cleanupTimers.set(taskId, timer); + } + } + + /** + * Cleanup all timers (useful for testing or graceful shutdown) + */ + cleanup(): void { + for (const timer of this.cleanupTimers.values()) { + clearTimeout(timer); + } + this.cleanupTimers.clear(); + this.tasks.clear(); + } + + /** + * Get all tasks (useful for debugging) + */ + getAllTasks(): Task[] { + return Array.from(this.tasks.values()).map(stored => ({ ...stored.task })); + } +} From c73b10567ca97a8ddc16f9c23dd0b303a33c8ec9 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Thu, 23 Oct 2025 12:57:40 -0700 Subject: [PATCH 08/18] Implement input_required status for tasks --- src/examples/client/simpleStreamableHttp.ts | 10 ++++++++-- src/examples/server/simpleStreamableHttp.ts | 19 +++++++++++++------ src/shared/protocol.ts | 11 ++++++++++- src/types.ts | 2 +- 4 files changed, 32 insertions(+), 10 deletions(-) diff --git a/src/examples/client/simpleStreamableHttp.ts b/src/examples/client/simpleStreamableHttp.ts index 697353ef4..0b84cdfa1 100644 --- a/src/examples/client/simpleStreamableHttp.ts +++ b/src/examples/client/simpleStreamableHttp.ts @@ -17,7 +17,8 @@ import { ElicitRequestSchema, ResourceLink, ReadResourceRequest, - ReadResourceResultSchema + ReadResourceResultSchema, + RELATED_TASK_META_KEY } from '../../types.js'; import { getDisplayName } from '../../shared/metadataUtils.js'; import Ajv from 'ajv'; @@ -249,6 +250,7 @@ async function connect(url?: string): Promise { client.setRequestHandler(ElicitRequestSchema, async request => { console.log('\n🔔 Elicitation Request Received:'); console.log(`Message: ${request.params.message}`); + console.log(`Related Task: ${request.params._meta?.[RELATED_TASK_META_KEY]?.taskId}`); console.log('Requested Schema:'); console.log(JSON.stringify(request.params.requestedSchema, null, 2)); @@ -827,12 +829,16 @@ async function callToolTask(name: string, args: Record): Promis console.log('Waiting for task completion...'); + let lastStatus = ''; await pendingRequest.result({ onTaskCreated: () => { console.log('Task created successfully'); }, onTaskStatus: task => { - console.log(` ${task.status}${task.error ? ` - ${task.error}` : ''}`); + if (lastStatus !== task.status) { + console.log(` ${task.status}${task.error ? ` - ${task.error}` : ''}`); + } + lastStatus = task.status; } }); diff --git a/src/examples/server/simpleStreamableHttp.ts b/src/examples/server/simpleStreamableHttp.ts index 966337f45..ec73a4f02 100644 --- a/src/examples/server/simpleStreamableHttp.ts +++ b/src/examples/server/simpleStreamableHttp.ts @@ -7,6 +7,7 @@ import { getOAuthProtectedResourceMetadataUrl, mcpAuthMetadataRouter } from '../ import { requireBearerAuth } from '../../server/auth/middleware/bearerAuth.js'; import { CallToolResult, + ElicitResultSchema, GetPromptResult, isInitializeRequest, PrimitiveSchemaDefinition, @@ -126,7 +127,7 @@ const getServer = () => { { infoType: z.enum(['contact', 'preferences', 'feedback']).describe('Type of information to collect') }, - async ({ infoType }): Promise => { + async ({ infoType }, extra): Promise => { let message: string; let requestedSchema: { type: 'object'; @@ -221,11 +222,17 @@ const getServer = () => { } try { - // Use the underlying server instance to elicit input from the client - const result = await server.server.elicitInput({ - message, - requestedSchema - }); + // Elicit input from the client + const result = await extra.sendRequest( + { + method: 'elicitation/create', + params: { + message, + requestedSchema + } + }, + ElicitResultSchema + ); if (result.action === 'accept') { return { diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 6bbdbf84d..382440fe6 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -492,7 +492,16 @@ export abstract class Protocol { const relatedTask = taskMetadata ? { taskId: taskMetadata.taskId } : undefined; - return await this.request(r, resultSchema, { ...options, relatedRequestId: request.id, relatedTask }); + if (taskMetadata && this._taskStore) { + await this._taskStore.updateTaskStatus(taskMetadata.taskId, 'input_required'); + } + try { + return await this.request(r, resultSchema, { ...options, relatedRequestId: request.id, relatedTask }); + } finally { + if (taskMetadata && this._taskStore) { + await this._taskStore.updateTaskStatus(taskMetadata.taskId, 'working'); + } + } }, authInfo: extra?.authInfo, requestId: request.id, diff --git a/src/types.ts b/src/types.ts index 23caa5f41..24ea5881d 100644 --- a/src/types.ts +++ b/src/types.ts @@ -501,7 +501,7 @@ export const ProgressNotificationSchema = NotificationSchema.extend({ */ export const TaskSchema = z.object({ taskId: z.string(), - status: z.enum(['submitted', 'working', 'completed', 'failed', 'cancelled', 'unknown']), + status: z.enum(['submitted', 'working', 'input_required', 'completed', 'failed', 'cancelled', 'unknown']), keepAlive: z.union([z.number(), z.null()]), pollFrequency: z.optional(z.number()), error: z.optional(z.string()) From b028061b83b060ee0545f92f4e6a7206a356aa25 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Thu, 23 Oct 2025 14:38:04 -0700 Subject: [PATCH 09/18] Implement unit tests for task support --- src/examples/shared/inMemoryTaskStore.test.ts | 374 ++++++++++++++ src/server/index.test.ts | 259 ++++++++++ src/shared/protocol.test.ts | 469 ++++++++++++++++++ src/shared/protocol.ts | 40 +- 4 files changed, 1120 insertions(+), 22 deletions(-) create mode 100644 src/examples/shared/inMemoryTaskStore.test.ts diff --git a/src/examples/shared/inMemoryTaskStore.test.ts b/src/examples/shared/inMemoryTaskStore.test.ts new file mode 100644 index 000000000..2e4020a7f --- /dev/null +++ b/src/examples/shared/inMemoryTaskStore.test.ts @@ -0,0 +1,374 @@ +import { InMemoryTaskStore } from './inMemoryTaskStore.js'; +import { TaskMetadata, Request } from '../../types.js'; + +describe('InMemoryTaskStore', () => { + let store: InMemoryTaskStore; + + beforeEach(() => { + store = new InMemoryTaskStore(); + }); + + afterEach(() => { + store.cleanup(); + }); + + describe('createTask', () => { + it('should create a new task with submitted status', async () => { + const metadata: TaskMetadata = { + taskId: 'task-1', + keepAlive: 60000 + }; + const request: Request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; + + await store.createTask(metadata, 123, request); + + const task = await store.getTask('task-1'); + expect(task).toBeDefined(); + expect(task?.taskId).toBe('task-1'); + expect(task?.status).toBe('submitted'); + expect(task?.keepAlive).toBe(60000); + expect(task?.pollFrequency).toBe(500); + }); + + it('should create task without keepAlive', async () => { + const metadata: TaskMetadata = { + taskId: 'task-no-keepalive' + }; + const request: Request = { + method: 'tools/call', + params: {} + }; + + await store.createTask(metadata, 456, request); + + const task = await store.getTask('task-no-keepalive'); + expect(task).toBeDefined(); + expect(task?.keepAlive).toBeNull(); + }); + + it('should reject duplicate taskId', async () => { + const metadata: TaskMetadata = { + taskId: 'duplicate-task' + }; + const request: Request = { + method: 'tools/call', + params: {} + }; + + await store.createTask(metadata, 789, request); + + await expect(store.createTask(metadata, 790, request)).rejects.toThrow('Task with ID duplicate-task already exists'); + }); + }); + + describe('getTask', () => { + it('should return null for non-existent task', async () => { + const task = await store.getTask('non-existent'); + expect(task).toBeNull(); + }); + + it('should return task state', async () => { + const metadata: TaskMetadata = { + taskId: 'get-test' + }; + const request: Request = { + method: 'tools/call', + params: {} + }; + + await store.createTask(metadata, 111, request); + await store.updateTaskStatus('get-test', 'working'); + + const task = await store.getTask('get-test'); + expect(task).toBeDefined(); + expect(task?.status).toBe('working'); + }); + }); + + describe('updateTaskStatus', () => { + beforeEach(async () => { + const metadata: TaskMetadata = { + taskId: 'status-test' + }; + await store.createTask(metadata, 222, { + method: 'tools/call', + params: {} + }); + }); + + it('should update task status from submitted to working', async () => { + await store.updateTaskStatus('status-test', 'working'); + + const task = await store.getTask('status-test'); + expect(task?.status).toBe('working'); + }); + + it('should update task status to input_required', async () => { + await store.updateTaskStatus('status-test', 'input_required'); + + const task = await store.getTask('status-test'); + expect(task?.status).toBe('input_required'); + }); + + it('should update task status to completed', async () => { + await store.updateTaskStatus('status-test', 'completed'); + + const task = await store.getTask('status-test'); + expect(task?.status).toBe('completed'); + }); + + it('should update task status to failed with error', async () => { + await store.updateTaskStatus('status-test', 'failed', 'Something went wrong'); + + const task = await store.getTask('status-test'); + expect(task?.status).toBe('failed'); + expect(task?.error).toBe('Something went wrong'); + }); + + it('should update task status to cancelled', async () => { + await store.updateTaskStatus('status-test', 'cancelled'); + + const task = await store.getTask('status-test'); + expect(task?.status).toBe('cancelled'); + }); + + it('should throw if task not found', async () => { + await expect(store.updateTaskStatus('non-existent', 'working')).rejects.toThrow('Task with ID non-existent not found'); + }); + }); + + describe('storeTaskResult', () => { + beforeEach(async () => { + const metadata: TaskMetadata = { + taskId: 'result-test', + keepAlive: 60000 + }; + await store.createTask(metadata, 333, { + method: 'tools/call', + params: {} + }); + }); + + it('should store task result and set status to completed', async () => { + const result = { + content: [{ type: 'text' as const, text: 'Success!' }] + }; + + await store.storeTaskResult('result-test', result); + + const task = await store.getTask('result-test'); + expect(task?.status).toBe('completed'); + + const storedResult = await store.getTaskResult('result-test'); + expect(storedResult).toEqual(result); + }); + + it('should throw if task not found', async () => { + await expect(store.storeTaskResult('non-existent', {})).rejects.toThrow('Task with ID non-existent not found'); + }); + }); + + describe('getTaskResult', () => { + it('should throw if task not found', async () => { + await expect(store.getTaskResult('non-existent')).rejects.toThrow('Task with ID non-existent not found'); + }); + + it('should throw if task has no result stored', async () => { + const metadata: TaskMetadata = { + taskId: 'no-result' + }; + await store.createTask(metadata, 444, { + method: 'tools/call', + params: {} + }); + + await expect(store.getTaskResult('no-result')).rejects.toThrow('Task no-result has no result stored'); + }); + + it('should return stored result', async () => { + const metadata: TaskMetadata = { + taskId: 'with-result' + }; + await store.createTask(metadata, 555, { + method: 'tools/call', + params: {} + }); + + const result = { + content: [{ type: 'text' as const, text: 'Result data' }] + }; + await store.storeTaskResult('with-result', result); + + const retrieved = await store.getTaskResult('with-result'); + expect(retrieved).toEqual(result); + }); + }); + + describe('keepAlive cleanup', () => { + beforeEach(() => { + jest.useFakeTimers(); + }); + + afterEach(() => { + jest.useRealTimers(); + }); + + it('should cleanup task after keepAlive duration', async () => { + const metadata: TaskMetadata = { + taskId: 'cleanup-test', + keepAlive: 1000 + }; + await store.createTask(metadata, 666, { + method: 'tools/call', + params: {} + }); + + // Task should exist initially + let task = await store.getTask('cleanup-test'); + expect(task).toBeDefined(); + + // Fast-forward past keepAlive + jest.advanceTimersByTime(1001); + + // Task should be cleaned up + task = await store.getTask('cleanup-test'); + expect(task).toBeNull(); + }); + + it('should reset cleanup timer when result is stored', async () => { + const metadata: TaskMetadata = { + taskId: 'reset-cleanup', + keepAlive: 1000 + }; + await store.createTask(metadata, 777, { + method: 'tools/call', + params: {} + }); + + // Fast-forward 500ms + jest.advanceTimersByTime(500); + + // Store result (should reset timer) + await store.storeTaskResult('reset-cleanup', { + content: [{ type: 'text' as const, text: 'Done' }] + }); + + // Fast-forward another 500ms (total 1000ms since creation, but timer was reset) + jest.advanceTimersByTime(500); + + // Task should still exist + const task = await store.getTask('reset-cleanup'); + expect(task).toBeDefined(); + + // Fast-forward remaining time + jest.advanceTimersByTime(501); + + // Now task should be cleaned up + const cleanedTask = await store.getTask('reset-cleanup'); + expect(cleanedTask).toBeNull(); + }); + + it('should not cleanup tasks without keepAlive', async () => { + const metadata: TaskMetadata = { + taskId: 'no-cleanup' + }; + await store.createTask(metadata, 888, { + method: 'tools/call', + params: {} + }); + + // Fast-forward a long time + jest.advanceTimersByTime(100000); + + // Task should still exist + const task = await store.getTask('no-cleanup'); + expect(task).toBeDefined(); + }); + + it('should start cleanup timer when task reaches terminal state', async () => { + const metadata: TaskMetadata = { + taskId: 'terminal-cleanup', + keepAlive: 1000 + }; + await store.createTask(metadata, 999, { + method: 'tools/call', + params: {} + }); + + // Task in non-terminal state, fast-forward + jest.advanceTimersByTime(1001); + + // Task should be cleaned up + let task = await store.getTask('terminal-cleanup'); + expect(task).toBeNull(); + + // Create another task + const metadata2: TaskMetadata = { + taskId: 'terminal-cleanup-2', + keepAlive: 2000 + }; + await store.createTask(metadata2, 1000, { + method: 'tools/call', + params: {} + }); + + // Update to terminal state + await store.updateTaskStatus('terminal-cleanup-2', 'completed'); + + // Fast-forward past original keepAlive + jest.advanceTimersByTime(2001); + + // Task should be cleaned up + task = await store.getTask('terminal-cleanup-2'); + expect(task).toBeNull(); + }); + }); + + describe('getAllTasks', () => { + it('should return all tasks', async () => { + await store.createTask({ taskId: 'task-1' }, 1, { + method: 'tools/call', + params: {} + }); + await store.createTask({ taskId: 'task-2' }, 2, { + method: 'tools/call', + params: {} + }); + await store.createTask({ taskId: 'task-3' }, 3, { + method: 'tools/call', + params: {} + }); + + const tasks = store.getAllTasks(); + expect(tasks).toHaveLength(3); + expect(tasks.map(t => t.taskId).sort()).toEqual(['task-1', 'task-2', 'task-3']); + }); + + it('should return empty array when no tasks', () => { + const tasks = store.getAllTasks(); + expect(tasks).toEqual([]); + }); + }); + + describe('cleanup', () => { + it('should clear all timers and tasks', async () => { + await store.createTask({ taskId: 'task-1', keepAlive: 1000 }, 1, { + method: 'tools/call', + params: {} + }); + await store.createTask({ taskId: 'task-2', keepAlive: 2000 }, 2, { + method: 'tools/call', + params: {} + }); + + expect(store.getAllTasks()).toHaveLength(2); + + store.cleanup(); + + expect(store.getAllTasks()).toHaveLength(0); + }); + }); +}); diff --git a/src/server/index.test.ts b/src/server/index.test.ts index d056707fe..6d74707cc 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -21,6 +21,8 @@ import { import { Transport } from '../shared/transport.js'; import { InMemoryTransport } from '../inMemory.js'; import { Client } from '../client/index.js'; +import { InMemoryTaskStore } from '../examples/shared/inMemoryTaskStore.js'; +import { CallToolRequestSchema, CallToolResultSchema } from '../types.js'; test('should accept latest protocol version', async () => { let sendPromiseResolve: (value: unknown) => void; @@ -955,3 +957,260 @@ test('should respect log level for transport with sessionId', async () => { await server.sendLoggingMessage(warningParams, SESSION_ID); expect(clientTransport.onmessage).toHaveBeenCalled(); }); + +describe('Task-based execution', () => { + test('server with TaskStore should handle task-based tool execution', async () => { + const taskStore = new InMemoryTaskStore(); + + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + }, + taskStore + } + ); + + // Set up a tool handler that simulates some work + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'test-tool') { + // Simulate some async work + await new Promise(resolve => setTimeout(resolve, 10)); + return { + content: [{ type: 'text', text: 'Tool executed successfully!' }] + }; + } + throw new Error('Unknown tool'); + }); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + } + } + ] + })); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Use beginCallTool to create a task + const taskId = 'test-task-123'; + const pendingRequest = client.beginCallTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + task: { + taskId, + keepAlive: 60000 + } + }); + + // Wait for the task to complete + await pendingRequest.result(); + + // Verify we can retrieve the task + const task = await client.getTask({ taskId }); + expect(task).toBeDefined(); + expect(task.status).toBe('completed'); + + // Verify we can retrieve the result + const result = await client.getTaskResult({ taskId }, CallToolResultSchema); + expect(result.content).toEqual([{ type: 'text', text: 'Tool executed successfully!' }]); + + // Cleanup + taskStore.cleanup(); + }); + + test('server without TaskStore should reject task-based requests', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + } + // No taskStore configured + } + ); + + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'test-tool') { + return { + content: [{ type: 'text', text: 'Success!' }] + }; + } + throw new Error('Unknown tool'); + }); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + } + } + ] + })); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Try to get a task when server doesn't have TaskStore + // The server will return a "Method not found" error + await expect(client.getTask({ taskId: 'non-existent' })).rejects.toThrow('Method not found'); + }); + + test('should automatically attach related-task metadata to nested requests during tool execution', async () => { + const taskStore = new InMemoryTaskStore(); + + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + }, + taskStore + } + ); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + } + } + ); + + // Track the elicitation request to verify related-task metadata + let capturedElicitRequest: z.infer | null = null; + + // Set up client elicitation handler + client.setRequestHandler(ElicitRequestSchema, async request => { + // Capture the request to verify metadata later + capturedElicitRequest = request; + + return { + action: 'accept', + content: { + username: 'test-user' + } + }; + }); + + // Set up server tool that makes a nested elicitation request + server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { + if (request.params.name === 'collect-info') { + // During tool execution, make a nested request to the client using extra.sendRequest + // This should AUTOMATICALLY attach the related-task metadata + const elicitResult = await extra.sendRequest( + { + method: 'elicitation/create', + params: { + message: 'Please provide your username', + requestedSchema: { + type: 'object', + properties: { + username: { type: 'string' } + }, + required: ['username'] + } + } + }, + z.object({ + action: z.enum(['accept', 'decline', 'cancel']), + content: z.record(z.unknown()).optional() + }) + ); + + return { + content: [ + { + type: 'text', + text: `Collected username: ${elicitResult.action === 'accept' && elicitResult.content ? (elicitResult.content as Record).username : 'none'}` + } + ] + }; + } + throw new Error('Unknown tool'); + }); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'collect-info', + description: 'Collects user info via elicitation', + inputSchema: { + type: 'object', + properties: {} + } + } + ] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Call tool WITH task metadata + const taskId = 'test-task-456'; + const pendingRequest = client.beginCallTool({ name: 'collect-info', arguments: {} }, CallToolResultSchema, { + task: { + taskId, + keepAlive: 60000 + } + }); + + // Wait for completion + await pendingRequest.result(); + + // Verify the nested elicitation request received the related-task metadata + expect(capturedElicitRequest).toBeDefined(); + expect(capturedElicitRequest!.params._meta).toBeDefined(); + expect(capturedElicitRequest!.params._meta?.['modelcontextprotocol.io/related-task']).toEqual({ + taskId: 'test-task-456' + }); + + // Verify tool result was correct + const result = await client.getTaskResult({ taskId }, CallToolResultSchema); + expect(result.content).toEqual([ + { + type: 'text', + text: 'Collected username: test-user' + } + ]); + + // Cleanup + taskStore.cleanup(); + }); +}); diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index 1c098eafa..4eccfbd91 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -742,3 +742,472 @@ describe('mergeCapabilities', () => { expect(merged).toEqual({}); }); }); + +describe('Task-based execution', () => { + let protocol: Protocol; + let transport: MockTransport; + let sendSpy: jest.SpyInstance; + + beforeEach(() => { + transport = new MockTransport(); + sendSpy = jest.spyOn(transport, 'send'); + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })(); + }); + + describe('beginRequest with task metadata', () => { + it('should inject task metadata into _meta field', async () => { + await protocol.connect(transport); + + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; + + const resultSchema = z.object({ + content: z.array(z.object({ type: z.literal('text'), text: z.string() })) + }); + + protocol.beginRequest(request, resultSchema, { + task: { + taskId: 'my-task-123', + keepAlive: 30000 + } + }); + + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + method: 'tools/call', + params: { + name: 'test-tool', + _meta: { + 'modelcontextprotocol.io/task': { + taskId: 'my-task-123', + keepAlive: 30000 + } + } + } + }), + expect.any(Object) + ); + }); + + it('should preserve existing _meta when adding task metadata', async () => { + await protocol.connect(transport); + + const request = { + method: 'tools/call', + params: { + name: 'test-tool', + _meta: { + customField: 'customValue' + } + } + }; + + const resultSchema = z.object({ + content: z.array(z.object({ type: z.literal('text'), text: z.string() })) + }); + + protocol.beginRequest(request, resultSchema, { + task: { + taskId: 'my-task-456' + } + }); + + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + params: { + name: 'test-tool', + _meta: { + customField: 'customValue', + 'modelcontextprotocol.io/task': { + taskId: 'my-task-456' + } + } + } + }), + expect.any(Object) + ); + }); + + it('should return PendingRequest object', async () => { + await protocol.connect(transport); + + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; + + const resultSchema = z.object({ + content: z.array(z.object({ type: z.literal('text'), text: z.string() })) + }); + + const pendingRequest = protocol.beginRequest(request, resultSchema, { + task: { + taskId: 'my-task-789' + } + }); + + expect(pendingRequest).toBeDefined(); + expect(pendingRequest.taskId).toBe('my-task-789'); + }); + }); + + describe('relatedTask metadata', () => { + it('should inject relatedTask metadata into _meta field', async () => { + await protocol.connect(transport); + + const request = { + method: 'notifications/message', + params: { data: 'test' } + }; + + const resultSchema = z.object({}); + + protocol.beginRequest(request, resultSchema, { + relatedTask: { + taskId: 'parent-task-123' + } + }); + + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + params: { + data: 'test', + _meta: { + 'modelcontextprotocol.io/related-task': { + taskId: 'parent-task-123' + } + } + } + }), + expect.any(Object) + ); + }); + + it('should work with notification method', async () => { + await protocol.connect(transport); + + await protocol.notification( + { + method: 'notifications/message', + params: { level: 'info', data: 'test message' } + }, + { + relatedTask: { + taskId: 'parent-task-456' + } + } + ); + + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + method: 'notifications/message', + params: { + level: 'info', + data: 'test message', + _meta: { + 'modelcontextprotocol.io/related-task': { + taskId: 'parent-task-456' + } + } + } + }), + expect.any(Object) + ); + }); + }); + + describe('task metadata combination', () => { + it('should combine task, relatedTask, and progress metadata', async () => { + await protocol.connect(transport); + + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; + + const resultSchema = z.object({ + content: z.array(z.object({ type: z.literal('text'), text: z.string() })) + }); + + protocol.beginRequest(request, resultSchema, { + task: { + taskId: 'my-task-combined' + }, + relatedTask: { + taskId: 'parent-task' + }, + onprogress: jest.fn() + }); + + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + params: { + name: 'test-tool', + _meta: { + 'modelcontextprotocol.io/task': { + taskId: 'my-task-combined' + }, + 'modelcontextprotocol.io/related-task': { + taskId: 'parent-task' + }, + progressToken: expect.any(Number) + } + } + }), + expect.any(Object) + ); + }); + }); + + describe('task status transitions', () => { + it('should transition from submitted to working when handler starts', async () => { + const mockTaskStore = { + createTask: jest.fn().mockResolvedValue(undefined), + getTask: jest.fn().mockResolvedValue(null), + updateTaskStatus: jest.fn().mockResolvedValue(undefined), + storeTaskResult: jest.fn().mockResolvedValue(undefined), + getTaskResult: jest.fn().mockResolvedValue({ content: [] }) + }; + + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + + await protocol.connect(transport); + + protocol.setRequestHandler(z.object({ method: z.literal('test/method'), params: z.any() }), async () => ({ + result: 'success' + })); + + transport.onmessage?.({ + jsonrpc: '2.0', + id: 1, + method: 'test/method', + params: { + _meta: { + 'modelcontextprotocol.io/task': { + taskId: 'test-task', + keepAlive: 60000 + } + } + } + }); + + await new Promise(resolve => setTimeout(resolve, 50)); + + expect(mockTaskStore.createTask).toHaveBeenCalledWith({ taskId: 'test-task', keepAlive: 60000 }, 1, { + method: 'test/method', + params: expect.any(Object) + }); + expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'working'); + }); + + it('should transition to input_required during extra.sendRequest', async () => { + const mockTaskStore = { + createTask: jest.fn().mockResolvedValue(undefined), + getTask: jest.fn().mockResolvedValue(null), + updateTaskStatus: jest.fn().mockResolvedValue(undefined), + storeTaskResult: jest.fn().mockResolvedValue(undefined), + getTaskResult: jest.fn().mockResolvedValue({ content: [] }) + }; + + const responsiveTransport = new MockTransport(); + responsiveTransport.send = jest.fn().mockImplementation(async (message: unknown) => { + if ( + typeof message === 'object' && + message !== null && + 'method' in message && + 'id' in message && + message.method === 'nested/request' && + responsiveTransport.onmessage + ) { + setTimeout(() => { + responsiveTransport.onmessage?.({ + jsonrpc: '2.0', + id: (message as { id: number }).id, + result: { nested: 'response' } + }); + }, 5); + } + }); + + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + + await protocol.connect(responsiveTransport); + + const capturedUpdateCalls: Array<{ taskId: string; status: string }> = []; + mockTaskStore.updateTaskStatus.mockImplementation((taskId, status) => { + capturedUpdateCalls.push({ taskId, status }); + return Promise.resolve(); + }); + + protocol.setRequestHandler(z.object({ method: z.literal('test/method'), params: z.any() }), async (_request, extra) => { + await extra.sendRequest({ method: 'nested/request', params: {} }, z.object({ nested: z.string() })); + return { result: 'success' }; + }); + + responsiveTransport.onmessage?.({ + jsonrpc: '2.0', + id: 1, + method: 'test/method', + params: { + _meta: { + 'modelcontextprotocol.io/task': { + taskId: 'test-task', + keepAlive: 60000 + } + } + } + }); + + await new Promise(resolve => setTimeout(resolve, 100)); + + expect(capturedUpdateCalls).toContainEqual({ taskId: 'test-task', status: 'working' }); + expect(capturedUpdateCalls).toContainEqual({ taskId: 'test-task', status: 'input_required' }); + + const inputRequiredIndex = capturedUpdateCalls.findIndex(c => c.status === 'input_required'); + const workingCalls = capturedUpdateCalls.filter(c => c.status === 'working'); + expect(workingCalls).toHaveLength(2); + + let workingCount = 0; + const secondWorkingIndex = capturedUpdateCalls.findIndex(c => { + if (c.status === 'working') { + workingCount++; + return workingCount === 2; + } + return false; + }); + expect(secondWorkingIndex).toBeGreaterThan(inputRequiredIndex); + }); + + it('should mark task as completed when storeTaskResult is called', async () => { + const mockTaskStore = { + createTask: jest.fn().mockResolvedValue(undefined), + getTask: jest.fn().mockResolvedValue(null), + updateTaskStatus: jest.fn().mockResolvedValue(undefined), + storeTaskResult: jest.fn().mockResolvedValue(undefined), + getTaskResult: jest.fn().mockResolvedValue({ content: [] }) + }; + + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + + await protocol.connect(transport); + + protocol.setRequestHandler(z.object({ method: z.literal('test/method'), params: z.any() }), async () => ({ + result: 'success' + })); + + transport.onmessage?.({ + jsonrpc: '2.0', + id: 1, + method: 'test/method', + params: { + _meta: { + 'modelcontextprotocol.io/task': { + taskId: 'test-task', + keepAlive: 60000 + } + } + } + }); + + await new Promise(resolve => setTimeout(resolve, 50)); + + expect(mockTaskStore.storeTaskResult).toHaveBeenCalledWith('test-task', { result: 'success' }); + }); + + it('should mark task as cancelled when notifications/cancelled is received', async () => { + const mockTaskStore = { + createTask: jest.fn().mockResolvedValue(undefined), + getTask: jest.fn().mockResolvedValue({ taskId: 'test-task', status: 'working', keepAlive: null, pollFrequency: 500 }), + updateTaskStatus: jest.fn().mockResolvedValue(undefined), + storeTaskResult: jest.fn().mockResolvedValue(undefined), + getTaskResult: jest.fn().mockResolvedValue({ content: [] }) + }; + + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + + await protocol.connect(transport); + + void protocol.request({ method: 'test/slow', params: {} }, z.object({ result: z.string() }), { + task: { taskId: 'test-task', keepAlive: 60000 } + }); + + await new Promise(resolve => setTimeout(resolve, 10)); + + transport.onmessage?.({ + jsonrpc: '2.0', + method: 'notifications/cancelled', + params: { + requestId: 0, + reason: 'User cancelled' + } + }); + + await new Promise(resolve => setTimeout(resolve, 10)); + + expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'cancelled'); + }); + + it('should mark task as failed when updateTaskStatus to working fails', async () => { + const mockTaskStore = { + createTask: jest.fn().mockResolvedValue(undefined), + getTask: jest.fn().mockResolvedValue(null), + updateTaskStatus: jest.fn().mockRejectedValueOnce(new Error('Failed to update status')).mockResolvedValue(undefined), + storeTaskResult: jest.fn().mockResolvedValue(undefined), + getTaskResult: jest.fn().mockResolvedValue({ content: [] }) + }; + + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + + await protocol.connect(transport); + + protocol.setRequestHandler(z.object({ method: z.literal('test/method'), params: z.any() }), async () => ({ + result: 'success' + })); + + transport.onmessage?.({ + jsonrpc: '2.0', + id: 1, + method: 'test/method', + params: { + _meta: { + 'modelcontextprotocol.io/task': { + taskId: 'test-task', + keepAlive: 60000 + } + } + } + }); + + await new Promise(resolve => setTimeout(resolve, 50)); + + expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'working'); + expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'failed', 'Failed to mark task as working'); + }); + }); +}); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 382440fe6..6d811953a 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -518,29 +518,25 @@ export abstract class Protocol { @@ -548,7 +544,7 @@ export abstract class Protocol Date: Thu, 23 Oct 2025 16:24:41 -0700 Subject: [PATCH 10/18] Add docs for task augmentation --- README.md | 164 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 164 insertions(+) diff --git a/README.md b/README.md index 92f56786f..47588d600 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ - [Improving Network Efficiency with Notification Debouncing](#improving-network-efficiency-with-notification-debouncing) - [Low-Level Server](#low-level-server) - [Eliciting User Input](#eliciting-user-input) + - [Task-Based Execution](#task-based-execution) - [Writing MCP Clients](#writing-mcp-clients) - [Proxy Authorization Requests Upstream](#proxy-authorization-requests-upstream) - [Backwards Compatibility](#backwards-compatibility) @@ -1301,6 +1302,169 @@ client.setRequestHandler(ElicitRequestSchema, async request => { **Note**: Elicitation requires client support. Clients must declare the `elicitation` capability during initialization. +### Task-Based Execution + +Task-based execution enables "call-now, fetch-later" patterns for long-running operations. This is useful for tools that take significant time to complete, where clients may want to disconnect and check on progress or retrieve results later. + +Common use cases include: + +- Long-running data processing or analysis +- Code migration or refactoring operations +- Complex computational tasks +- Operations that require periodic status updates + +#### Server-Side: Implementing Task Support + +To enable task-based execution, configure your server with a `TaskStore` implementation. The SDK doesn't provide a built-in TaskStore—you'll need to implement one backed by your database of choice: + +```typescript +import { Server } from '@modelcontextprotocol/sdk/server/index.js'; +import { TaskStore } from '@modelcontextprotocol/sdk/shared/task.js'; +import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js'; + +// Implement TaskStore backed by your database (e.g., PostgreSQL, Redis, etc.) +class MyTaskStore implements TaskStore { + async createTask(metadata, requestId, request) { + // Store task in your database + } + + async getTask(taskId) { + // Retrieve task from your database + } + + async updateTaskStatus(taskId, status, errorMessage?) { + // Update task status in your database + } + + async storeTaskResult(taskId, result) { + // Store task result in your database + } + + async getTaskResult(taskId) { + // Retrieve task result from your database + } +} + +const taskStore = new MyTaskStore(); + +const server = new Server( + { + name: 'task-enabled-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + }, + taskStore // Enable task support + } +); + +// Set up a long-running tool handler as usual +server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'analyze-data') { + // Simulate long-running analysis + await new Promise(resolve => setTimeout(resolve, 30000)); + + return { + content: [ + { + type: 'text', + text: 'Analysis complete!' + } + ] + }; + } + throw new Error('Unknown tool'); +}); + +server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'analyze-data', + description: 'Perform data analysis (long-running)', + inputSchema: { + type: 'object', + properties: { + dataset: { type: 'string' } + } + } + } + ] +})); +``` + +**Note**: See `src/examples/shared/inMemoryTaskStore.ts` in the SDK source for a reference implementation suitable for development and testing. + +#### Client-Side: Using Task-Based Execution + +Clients use `beginCallTool()` to initiate task-based operations. The returned `PendingRequest` object provides automatic polling and status tracking: + +```typescript +import { Client } from '@modelcontextprotocol/sdk/client/index.js'; +import { CallToolResultSchema } from '@modelcontextprotocol/sdk/types.js'; + +const client = new Client({ + name: 'task-client', + version: '1.0.0' +}); + +// ... connect to server ... + +// Initiate a task-based tool call +const taskId = 'analysis-task-123'; +const pendingRequest = client.beginCallTool( + { + name: 'analyze-data', + arguments: { dataset: 'user-data.csv' } + }, + CallToolResultSchema, + { + task: { + taskId, + keepAlive: 300000 // Keep results for 5 minutes after completion + } + } +); + +// Option 1: Wait for completion with status callbacks +const result = await pendingRequest.result({ + onTaskCreated: () => { + console.log('Task created successfully'); + }, + onTaskStatus: task => { + console.log(`Task status: ${task.status}`); + // Status can be: 'submitted', 'working', 'input_required', 'completed', 'failed', or 'cancelled' + } +}); +console.log('Task completed:', result); + +// Option 2: Fire and forget - disconnect and reconnect later +// (useful when you don't want to wait for long-running tasks) +// Later, after disconnecting and reconnecting to the server: +const taskStatus = await client.getTask({ taskId }); +console.log('Task status:', taskStatus.status); + +if (taskStatus.status === 'completed') { + const taskResult = await client.getTaskResult({ taskId }, CallToolResultSchema); + console.log('Retrieved result after reconnect:', taskResult); +} +``` + +#### Task Status Lifecycle + +Tasks transition through the following states: + +- **submitted**: Task has been created and queued +- **working**: Task is actively being processed +- **input_required**: Task is waiting for additional input (e.g., from elicitation) +- **completed**: Task finished successfully +- **failed**: Task encountered an error +- **cancelled**: Task was cancelled by the client +- **unknown**: Task status could not be determined (terminal state, rarely occurs) + +The `keepAlive` parameter determines how long the server retains task results after completion. This allows clients to retrieve results even after disconnecting and reconnecting. + ### Writing MCP Clients The SDK provides a high-level client interface: From 5dc999f60593132087f2026ba3497aa4941e96eb Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Mon, 27 Oct 2025 12:19:35 -0700 Subject: [PATCH 11/18] Implement tasks/list method --- src/examples/shared/inMemoryTaskStore.test.ts | 73 ++++++ src/examples/shared/inMemoryTaskStore.ts | 26 ++ src/shared/protocol.test.ts | 248 +++++++++++++++++- src/shared/protocol.ts | 27 ++ src/shared/task.ts | 8 + src/types.ts | 28 +- 6 files changed, 401 insertions(+), 9 deletions(-) diff --git a/src/examples/shared/inMemoryTaskStore.test.ts b/src/examples/shared/inMemoryTaskStore.test.ts index 2e4020a7f..9c8c7dab0 100644 --- a/src/examples/shared/inMemoryTaskStore.test.ts +++ b/src/examples/shared/inMemoryTaskStore.test.ts @@ -353,6 +353,79 @@ describe('InMemoryTaskStore', () => { }); }); + describe('listTasks', () => { + it('should return empty list when no tasks', async () => { + const result = await store.listTasks(); + expect(result.tasks).toEqual([]); + expect(result.nextCursor).toBeUndefined(); + }); + + it('should return all tasks when less than page size', async () => { + await store.createTask({ taskId: 'task-1' }, 1, { + method: 'tools/call', + params: {} + }); + await store.createTask({ taskId: 'task-2' }, 2, { + method: 'tools/call', + params: {} + }); + await store.createTask({ taskId: 'task-3' }, 3, { + method: 'tools/call', + params: {} + }); + + const result = await store.listTasks(); + expect(result.tasks).toHaveLength(3); + expect(result.nextCursor).toBeUndefined(); + }); + + it('should paginate when more than page size', async () => { + // Create 15 tasks (page size is 10) + for (let i = 1; i <= 15; i++) { + await store.createTask({ taskId: `task-${i}` }, i, { + method: 'tools/call', + params: {} + }); + } + + // Get first page + const page1 = await store.listTasks(); + expect(page1.tasks).toHaveLength(10); + expect(page1.nextCursor).toBeDefined(); + + // Get second page using cursor + const page2 = await store.listTasks(page1.nextCursor); + expect(page2.tasks).toHaveLength(5); + expect(page2.nextCursor).toBeUndefined(); + }); + + it('should throw error for invalid cursor', async () => { + await store.createTask({ taskId: 'task-1' }, 1, { + method: 'tools/call', + params: {} + }); + + await expect(store.listTasks('non-existent-cursor')).rejects.toThrow('Invalid cursor: non-existent-cursor'); + }); + + it('should continue from cursor correctly', async () => { + // Create tasks with predictable IDs + for (let i = 1; i <= 5; i++) { + await store.createTask({ taskId: `task-${i}` }, i, { + method: 'tools/call', + params: {} + }); + } + + // Get first 3 tasks + const allTaskIds = Array.from(store.getAllTasks().map(t => t.taskId)); + const result = await store.listTasks(allTaskIds[2]); + + // Should get tasks after task-3 + expect(result.tasks).toHaveLength(2); + }); + }); + describe('cleanup', () => { it('should clear all timers and tasks', async () => { await store.createTask({ taskId: 'task-1', keepAlive: 1000 }, 1, { diff --git a/src/examples/shared/inMemoryTaskStore.ts b/src/examples/shared/inMemoryTaskStore.ts index 79d8a05bd..c9f297c86 100644 --- a/src/examples/shared/inMemoryTaskStore.ts +++ b/src/examples/shared/inMemoryTaskStore.ts @@ -122,6 +122,32 @@ export class InMemoryTaskStore implements TaskStore { } } + async listTasks(cursor?: string): Promise<{ tasks: Task[]; nextCursor?: string }> { + const PAGE_SIZE = 10; + const allTaskIds = Array.from(this.tasks.keys()); + + let startIndex = 0; + if (cursor) { + const cursorIndex = allTaskIds.indexOf(cursor); + if (cursorIndex >= 0) { + startIndex = cursorIndex + 1; + } else { + // Invalid cursor - throw error + throw new Error(`Invalid cursor: ${cursor}`); + } + } + + const pageTaskIds = allTaskIds.slice(startIndex, startIndex + PAGE_SIZE); + const tasks = pageTaskIds.map(taskId => { + const stored = this.tasks.get(taskId)!; + return { ...stored.task }; + }); + + const nextCursor = startIndex + PAGE_SIZE < allTaskIds.length ? pageTaskIds[pageTaskIds.length - 1] : undefined; + + return { tasks, nextCursor }; + } + /** * Cleanup all timers (useful for testing or graceful shutdown) */ diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index 4eccfbd91..a84e5a0ec 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -972,7 +972,8 @@ describe('Task-based execution', () => { getTask: jest.fn().mockResolvedValue(null), updateTaskStatus: jest.fn().mockResolvedValue(undefined), storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }) + getTaskResult: jest.fn().mockResolvedValue({ content: [] }), + listTasks: jest.fn().mockResolvedValue({ tasks: [], nextCursor: undefined }) }; protocol = new (class extends Protocol { @@ -1016,7 +1017,8 @@ describe('Task-based execution', () => { getTask: jest.fn().mockResolvedValue(null), updateTaskStatus: jest.fn().mockResolvedValue(undefined), storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }) + getTaskResult: jest.fn().mockResolvedValue({ content: [] }), + listTasks: jest.fn().mockResolvedValue({ tasks: [], nextCursor: undefined }) }; const responsiveTransport = new MockTransport(); @@ -1098,7 +1100,8 @@ describe('Task-based execution', () => { getTask: jest.fn().mockResolvedValue(null), updateTaskStatus: jest.fn().mockResolvedValue(undefined), storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }) + getTaskResult: jest.fn().mockResolvedValue({ content: [] }), + listTasks: jest.fn().mockResolvedValue({ tasks: [], nextCursor: undefined }) }; protocol = new (class extends Protocol { @@ -1138,7 +1141,8 @@ describe('Task-based execution', () => { getTask: jest.fn().mockResolvedValue({ taskId: 'test-task', status: 'working', keepAlive: null, pollFrequency: 500 }), updateTaskStatus: jest.fn().mockResolvedValue(undefined), storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }) + getTaskResult: jest.fn().mockResolvedValue({ content: [] }), + listTasks: jest.fn().mockResolvedValue({ tasks: [], nextCursor: undefined }) }; protocol = new (class extends Protocol { @@ -1175,7 +1179,8 @@ describe('Task-based execution', () => { getTask: jest.fn().mockResolvedValue(null), updateTaskStatus: jest.fn().mockRejectedValueOnce(new Error('Failed to update status')).mockResolvedValue(undefined), storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }) + getTaskResult: jest.fn().mockResolvedValue({ content: [] }), + listTasks: jest.fn().mockResolvedValue({ tasks: [], nextCursor: undefined }) }; protocol = new (class extends Protocol { @@ -1210,4 +1215,237 @@ describe('Task-based execution', () => { expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'failed', 'Failed to mark task as working'); }); }); + + describe('listTasks', () => { + it('should handle tasks/list requests and return tasks from TaskStore', async () => { + const mockTaskStore = { + createTask: jest.fn().mockResolvedValue(undefined), + getTask: jest.fn().mockResolvedValue(null), + updateTaskStatus: jest.fn().mockResolvedValue(undefined), + storeTaskResult: jest.fn().mockResolvedValue(undefined), + getTaskResult: jest.fn().mockResolvedValue({ content: [] }), + listTasks: jest.fn().mockResolvedValue({ + tasks: [ + { taskId: 'task-1', status: 'completed', keepAlive: null, pollFrequency: 500 }, + { taskId: 'task-2', status: 'working', keepAlive: 60000, pollFrequency: 1000 } + ], + nextCursor: 'task-2' + }) + }; + + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + + await protocol.connect(transport); + + // Simulate receiving a tasks/list request + transport.onmessage?.({ + jsonrpc: '2.0', + id: 1, + method: 'tasks/list', + params: {} + }); + + await new Promise(resolve => setTimeout(resolve, 10)); + + expect(mockTaskStore.listTasks).toHaveBeenCalledWith(undefined); + const sentMessage = sendSpy.mock.calls[0][0]; + expect(sentMessage.jsonrpc).toBe('2.0'); + expect(sentMessage.id).toBe(1); + expect(sentMessage.result.tasks).toEqual([ + { taskId: 'task-1', status: 'completed', keepAlive: null, pollFrequency: 500 }, + { taskId: 'task-2', status: 'working', keepAlive: 60000, pollFrequency: 1000 } + ]); + expect(sentMessage.result.nextCursor).toBe('task-2'); + expect(sentMessage.result._meta).toEqual({}); + }); + + it('should handle tasks/list requests with cursor for pagination', async () => { + const mockTaskStore = { + createTask: jest.fn().mockResolvedValue(undefined), + getTask: jest.fn().mockResolvedValue(null), + updateTaskStatus: jest.fn().mockResolvedValue(undefined), + storeTaskResult: jest.fn().mockResolvedValue(undefined), + getTaskResult: jest.fn().mockResolvedValue({ content: [] }), + listTasks: jest.fn().mockResolvedValue({ + tasks: [{ taskId: 'task-3', status: 'submitted', keepAlive: null, pollFrequency: 500 }], + nextCursor: undefined + }) + }; + + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + + await protocol.connect(transport); + + // Simulate receiving a tasks/list request with cursor + transport.onmessage?.({ + jsonrpc: '2.0', + id: 2, + method: 'tasks/list', + params: { + cursor: 'task-2' + } + }); + + await new Promise(resolve => setTimeout(resolve, 10)); + + expect(mockTaskStore.listTasks).toHaveBeenCalledWith('task-2'); + const sentMessage = sendSpy.mock.calls[0][0]; + expect(sentMessage.jsonrpc).toBe('2.0'); + expect(sentMessage.id).toBe(2); + expect(sentMessage.result.tasks).toEqual([{ taskId: 'task-3', status: 'submitted', keepAlive: null, pollFrequency: 500 }]); + expect(sentMessage.result.nextCursor).toBeUndefined(); + expect(sentMessage.result._meta).toEqual({}); + }); + + it('should handle tasks/list requests with empty results', async () => { + const mockTaskStore = { + createTask: jest.fn().mockResolvedValue(undefined), + getTask: jest.fn().mockResolvedValue(null), + updateTaskStatus: jest.fn().mockResolvedValue(undefined), + storeTaskResult: jest.fn().mockResolvedValue(undefined), + getTaskResult: jest.fn().mockResolvedValue({ content: [] }), + listTasks: jest.fn().mockResolvedValue({ + tasks: [], + nextCursor: undefined + }) + }; + + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + + await protocol.connect(transport); + + // Simulate receiving a tasks/list request + transport.onmessage?.({ + jsonrpc: '2.0', + id: 3, + method: 'tasks/list', + params: {} + }); + + await new Promise(resolve => setTimeout(resolve, 10)); + + expect(mockTaskStore.listTasks).toHaveBeenCalledWith(undefined); + const sentMessage = sendSpy.mock.calls[0][0]; + expect(sentMessage.jsonrpc).toBe('2.0'); + expect(sentMessage.id).toBe(3); + expect(sentMessage.result.tasks).toEqual([]); + expect(sentMessage.result.nextCursor).toBeUndefined(); + expect(sentMessage.result._meta).toEqual({}); + }); + + it('should return error for invalid cursor', async () => { + const mockTaskStore = { + createTask: jest.fn().mockResolvedValue(undefined), + getTask: jest.fn().mockResolvedValue(null), + updateTaskStatus: jest.fn().mockResolvedValue(undefined), + storeTaskResult: jest.fn().mockResolvedValue(undefined), + getTaskResult: jest.fn().mockResolvedValue({ content: [] }), + listTasks: jest.fn().mockRejectedValue(new Error('Invalid cursor: bad-cursor')) + }; + + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + + await protocol.connect(transport); + + // Simulate receiving a tasks/list request with invalid cursor + transport.onmessage?.({ + jsonrpc: '2.0', + id: 4, + method: 'tasks/list', + params: { + cursor: 'bad-cursor' + } + }); + + await new Promise(resolve => setTimeout(resolve, 10)); + + expect(mockTaskStore.listTasks).toHaveBeenCalledWith('bad-cursor'); + const sentMessage = sendSpy.mock.calls[0][0]; + expect(sentMessage.jsonrpc).toBe('2.0'); + expect(sentMessage.id).toBe(4); + expect(sentMessage.error).toBeDefined(); + expect(sentMessage.error.code).toBe(-32602); // InvalidParams error code + expect(sentMessage.error.message).toContain('Failed to list tasks'); + expect(sentMessage.error.message).toContain('Invalid cursor'); + }); + + it('should call listTasks method from client side', async () => { + await protocol.connect(transport); + + const listTasksPromise = protocol.listTasks(); + + // Simulate server response + setTimeout(() => { + transport.onmessage?.({ + jsonrpc: '2.0', + id: sendSpy.mock.calls[0][0].id, + result: { + tasks: [{ taskId: 'task-1', status: 'completed', keepAlive: null, pollFrequency: 500 }], + nextCursor: undefined, + _meta: {} + } + }); + }, 10); + + const result = await listTasksPromise; + + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + method: 'tasks/list', + params: undefined + }), + expect.any(Object) + ); + expect(result.tasks).toHaveLength(1); + expect(result.tasks[0].taskId).toBe('task-1'); + }); + + it('should call listTasks with cursor from client side', async () => { + await protocol.connect(transport); + + const listTasksPromise = protocol.listTasks({ cursor: 'task-10' }); + + // Simulate server response + setTimeout(() => { + transport.onmessage?.({ + jsonrpc: '2.0', + id: sendSpy.mock.calls[0][0].id, + result: { + tasks: [{ taskId: 'task-11', status: 'working', keepAlive: 30000, pollFrequency: 1000 }], + nextCursor: 'task-11', + _meta: {} + } + }); + }, 10); + + const result = await listTasksPromise; + + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + method: 'tasks/list', + params: { cursor: 'task-10' } + }), + expect.any(Object) + ); + expect(result.tasks).toHaveLength(1); + expect(result.tasks[0].taskId).toBe('task-11'); + expect(result.nextCursor).toBe('task-11'); + }); + }); }); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 6d811953a..a25107613 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -8,6 +8,8 @@ import { GetTaskResultSchema, GetTaskPayloadRequest, GetTaskPayloadRequestSchema, + ListTasksRequestSchema, + ListTasksResultSchema, isJSONRPCError, isJSONRPCRequest, isJSONRPCResponse, @@ -319,6 +321,23 @@ export abstract class Protocol { + try { + const { tasks, nextCursor } = await this._taskStore!.listTasks(request.params?.cursor); + // @ts-expect-error SendResultT cannot contain ListTasksResult, but we include it in our derived types everywhere else + return { + tasks, + nextCursor, + _meta: {} + } as SendResultT; + } catch (error) { + throw new McpError( + ErrorCode.InvalidParams, + `Failed to list tasks: ${error instanceof Error ? error.message : String(error)}` + ); + } + }); } } @@ -856,6 +875,14 @@ export abstract class Protocol> { + // @ts-expect-error SendRequestT cannot directly contain ListTasksRequest, but we ensure all type instantiations contain it anyways + return this.request({ method: 'tasks/list', params }, ListTasksResultSchema, options); + } + /** * Emits a notification, which is a one-way message that does not expect a response. */ diff --git a/src/shared/task.ts b/src/shared/task.ts index 617ab81aa..fbcd22e82 100644 --- a/src/shared/task.ts +++ b/src/shared/task.ts @@ -48,6 +48,14 @@ export interface TaskStore { * @param error - Optional error message if status is 'failed' or 'cancelled' */ updateTaskStatus(taskId: string, status: Task['status'], error?: string): Promise; + + /** + * Lists tasks, optionally starting from a pagination cursor. + * + * @param cursor - Optional cursor for pagination + * @returns An object containing the tasks array and an optional nextCursor + */ + listTasks(cursor?: string): Promise<{ tasks: Task[]; nextCursor?: string }>; } /** diff --git a/src/types.ts b/src/types.ts index 24ea5881d..5b528f551 100644 --- a/src/types.ts +++ b/src/types.ts @@ -559,6 +559,20 @@ export const PaginatedResultSchema = ResultSchema.extend({ nextCursor: z.optional(CursorSchema) }); +/** + * A request to list tasks. + */ +export const ListTasksRequestSchema = PaginatedRequestSchema.extend({ + method: z.literal('tasks/list') +}); + +/** + * The response to a tasks/list request. + */ +export const ListTasksResultSchema = PaginatedResultSchema.extend({ + tasks: z.array(TaskSchema) +}); + /* Resources */ /** * The contents of a specific resource or sub-resource. @@ -1518,7 +1532,8 @@ export const ClientRequestSchema = z.union([ CallToolRequestSchema, ListToolsRequestSchema, GetTaskRequestSchema, - GetTaskPayloadRequestSchema + GetTaskPayloadRequestSchema, + ListTasksRequestSchema ]); export const ClientNotificationSchema = z.union([ @@ -1534,7 +1549,8 @@ export const ClientResultSchema = z.union([ CreateMessageResultSchema, ElicitResultSchema, ListRootsResultSchema, - GetTaskResultSchema + GetTaskResultSchema, + ListTasksResultSchema ]); /* Server messages */ @@ -1544,7 +1560,8 @@ export const ServerRequestSchema = z.union([ ElicitRequestSchema, ListRootsRequestSchema, GetTaskRequestSchema, - GetTaskPayloadRequestSchema + GetTaskPayloadRequestSchema, + ListTasksRequestSchema ]); export const ServerNotificationSchema = z.union([ @@ -1569,7 +1586,8 @@ export const ServerResultSchema = z.union([ ReadResourceResultSchema, CallToolResultSchema, ListToolsResultSchema, - GetTaskResultSchema + GetTaskResultSchema, + ListTasksResultSchema ]); export class McpError extends Error { @@ -1676,6 +1694,8 @@ export type TaskCreatedNotification = Infer; export type GetTaskResult = Infer; export type GetTaskPayloadRequest = Infer; +export type ListTasksRequest = Infer; +export type ListTasksResult = Infer; /* Pagination */ export type PaginatedRequest = Infer; From 71a956857edf5f125106deb627fb388caa655b1d Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 29 Oct 2025 14:10:05 -0700 Subject: [PATCH 12/18] Automatically execute tool calls as tasks --- package-lock.json | 32 ++++++++++++++++++------- package.json | 1 + src/client/index.ts | 8 ++++++- src/shared/protocol.test.ts | 48 ++++++++++++++++++++++++++----------- src/shared/protocol.ts | 37 +++++++++++++++++++++++++--- src/shared/request.ts | 5 ++-- 6 files changed, 102 insertions(+), 29 deletions(-) diff --git a/package-lock.json b/package-lock.json index 0f614d70e..8ee31c5d8 100644 --- a/package-lock.json +++ b/package-lock.json @@ -9,6 +9,7 @@ "version": "1.20.2", "license": "MIT", "dependencies": { + "@lukeed/uuid": "^2.0.1", "ajv": "^8.17.1", "ajv-formats": "^3.0.1", "content-type": "^1.0.5", @@ -52,19 +53,11 @@ "node": ">=18" }, "peerDependencies": { - "@cfworker/json-schema": "^4.1.1", - "ajv": "^8.17.1", - "ajv-formats": "^3.0.1" + "@cfworker/json-schema": "^4.1.1" }, "peerDependenciesMeta": { "@cfworker/json-schema": { "optional": true - }, - "ajv": { - "optional": true - }, - "ajv-formats": { - "optional": true } } }, @@ -1610,6 +1603,27 @@ "@jridgewell/sourcemap-codec": "^1.4.14" } }, + "node_modules/@lukeed/csprng": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@lukeed/csprng/-/csprng-1.1.0.tgz", + "integrity": "sha512-Z7C/xXCiGWsg0KuKsHTKJxbWhpI3Vs5GwLfOean7MGyVFGqdRgBbAjOCh6u4bbjPc/8MJ2pZmK/0DLdCbivLDA==", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/@lukeed/uuid": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/@lukeed/uuid/-/uuid-2.0.1.tgz", + "integrity": "sha512-qC72D4+CDdjGqJvkFMMEAtancHUQ7/d/tAiHf64z8MopFDmcrtbcJuerDtFceuAfQJ2pDSfCKCtbqoGBNnwg0w==", + "license": "MIT", + "dependencies": { + "@lukeed/csprng": "^1.1.0" + }, + "engines": { + "node": ">=8" + } + }, "node_modules/@noble/hashes": { "version": "1.8.0", "resolved": "https://registry.npmjs.org/@noble/hashes/-/hashes-1.8.0.tgz", diff --git a/package.json b/package.json index 22a5b41cc..d7b601af3 100644 --- a/package.json +++ b/package.json @@ -75,6 +75,7 @@ "client": "tsx src/cli.ts client" }, "dependencies": { + "@lukeed/uuid": "^2.0.1", "ajv": "^8.17.1", "ajv-formats": "^3.0.1", "content-type": "^1.0.5", diff --git a/src/client/index.ts b/src/client/index.ts index a4c70c581..da66b1102 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -1,6 +1,7 @@ import { mergeCapabilities, Protocol, type ProtocolOptions, type RequestOptions } from '../shared/protocol.js'; import type { Transport } from '../shared/transport.js'; import { PendingRequest } from '../shared/request.js'; +import { v4 as uuidv4 } from '@lukeed/uuid'; import { type CallToolRequest, CallToolResultSchema, @@ -368,7 +369,12 @@ export class Client< resultSchema: typeof CallToolResultSchema | typeof CompatibilityCallToolResultSchema = CallToolResultSchema, options?: RequestOptions ): PendingRequest { - return this.beginRequest({ method: 'tools/call', params }, resultSchema, options); + // Automatically add task metadata if not provided + const optionsWithTask = { + ...options, + task: options?.task ?? { taskId: uuidv4() } + }; + return this.beginRequest({ method: 'tools/call', params }, resultSchema, optionsWithTask); } /** diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index a84e5a0ec..b3f1d4e5e 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -1,5 +1,15 @@ import { ZodType, z } from 'zod'; -import { ClientCapabilities, ErrorCode, McpError, Notification, Request, Result, ServerCapabilities } from '../types.js'; +import { + ClientCapabilities, + ErrorCode, + McpError, + Notification, + RELATED_TASK_META_KEY, + Request, + Result, + ServerCapabilities, + TASK_META_KEY +} from '../types.js'; import { Protocol, mergeCapabilities } from './protocol.js'; import { Transport } from './transport.js'; @@ -784,7 +794,7 @@ describe('Task-based execution', () => { params: { name: 'test-tool', _meta: { - 'modelcontextprotocol.io/task': { + [TASK_META_KEY]: { taskId: 'my-task-123', keepAlive: 30000 } @@ -824,7 +834,7 @@ describe('Task-based execution', () => { name: 'test-tool', _meta: { customField: 'customValue', - 'modelcontextprotocol.io/task': { + [TASK_META_KEY]: { taskId: 'my-task-456' } } @@ -879,7 +889,7 @@ describe('Task-based execution', () => { params: { data: 'test', _meta: { - 'modelcontextprotocol.io/related-task': { + [RELATED_TASK_META_KEY]: { taskId: 'parent-task-123' } } @@ -911,7 +921,7 @@ describe('Task-based execution', () => { level: 'info', data: 'test message', _meta: { - 'modelcontextprotocol.io/related-task': { + [RELATED_TASK_META_KEY]: { taskId: 'parent-task-456' } } @@ -950,10 +960,10 @@ describe('Task-based execution', () => { params: { name: 'test-tool', _meta: { - 'modelcontextprotocol.io/task': { + [TASK_META_KEY]: { taskId: 'my-task-combined' }, - 'modelcontextprotocol.io/related-task': { + [RELATED_TASK_META_KEY]: { taskId: 'parent-task' }, progressToken: expect.any(Number) @@ -994,7 +1004,7 @@ describe('Task-based execution', () => { method: 'test/method', params: { _meta: { - 'modelcontextprotocol.io/task': { + [TASK_META_KEY]: { taskId: 'test-task', keepAlive: 60000 } @@ -1066,7 +1076,7 @@ describe('Task-based execution', () => { method: 'test/method', params: { _meta: { - 'modelcontextprotocol.io/task': { + [TASK_META_KEY]: { taskId: 'test-task', keepAlive: 60000 } @@ -1122,7 +1132,7 @@ describe('Task-based execution', () => { method: 'test/method', params: { _meta: { - 'modelcontextprotocol.io/task': { + [TASK_META_KEY]: { taskId: 'test-task', keepAlive: 60000 } @@ -1201,7 +1211,7 @@ describe('Task-based execution', () => { method: 'test/method', params: { _meta: { - 'modelcontextprotocol.io/task': { + [TASK_META_KEY]: { taskId: 'test-task', keepAlive: 60000 } @@ -1398,7 +1408,11 @@ describe('Task-based execution', () => { result: { tasks: [{ taskId: 'task-1', status: 'completed', keepAlive: null, pollFrequency: 500 }], nextCursor: undefined, - _meta: {} + _meta: { + [TASK_META_KEY]: expect.objectContaining({ + taskId: expect.any(String) + }) + } } }); }, 10); @@ -1429,7 +1443,11 @@ describe('Task-based execution', () => { result: { tasks: [{ taskId: 'task-11', status: 'working', keepAlive: 30000, pollFrequency: 1000 }], nextCursor: 'task-11', - _meta: {} + _meta: { + [TASK_META_KEY]: expect.objectContaining({ + taskId: expect.any(String) + }) + } } }); }, 10); @@ -1439,7 +1457,9 @@ describe('Task-based execution', () => { expect(sendSpy).toHaveBeenCalledWith( expect.objectContaining({ method: 'tasks/list', - params: { cursor: 'task-10' } + params: { + cursor: 'task-10' + } }), expect.any(Object) ); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index a25107613..40f346edc 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -223,7 +223,7 @@ export abstract class Protocol = new Map(); private _timeoutInfo: Map = new Map(); private _pendingDebouncedNotifications = new Set(); - private _pendingTaskCreations: Map void; reject: (reason: Error) => void }> = new Map(); + private _pendingTaskCreations: Map void; reject: (reason: unknown) => void }> = new Map(); private _requestIdToTaskId: Map = new Map(); private _taskStore?: TaskStore; @@ -400,6 +400,16 @@ export abstract class Protocol>((resolve, reject) => { + const earlyReject = (error: unknown) => { + // Clean up task tracking if we reject before sending + if (taskId) { + const resolver = this._pendingTaskCreations.get(taskId); + resolver?.reject(error); + this._pendingTaskCreations.delete(taskId); + } + reject(error); + }; + if (!this._transport) { - reject(new Error('Not connected')); + earlyReject(new Error('Not connected')); return; } if (this._options?.enforceStrictCapabilities === true) { - this.assertCapabilityForMethod(request.method); + try { + this.assertCapabilityForMethod(request.method); + } catch (e) { + earlyReject(e); + return; + } } options?.signal?.throwIfAborted(); @@ -782,6 +812,7 @@ export abstract class Protocol `${e}`).join(', ')}`); + // Both failed - prefer to throw the result error since it's usually more meaningful + // (e.g., timeout, connection error, etc.) than the task creation failure + throw result.reason; }); } From 2167b437cc6dcf62e7dc644438af446a7c43158d Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Fri, 31 Oct 2025 18:13:03 -0700 Subject: [PATCH 13/18] Implement task API tests on both the client and server --- src/client/index.test.ts | 688 +++++++++++++++++++++++++++++++++++++++ src/server/index.test.ts | 418 ++++++++++++++++++++++++ 2 files changed, 1106 insertions(+) diff --git a/src/client/index.test.ts b/src/client/index.test.ts index de37b2d90..a135a7c14 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -13,6 +13,7 @@ import { ListResourcesRequestSchema, ListToolsRequestSchema, CallToolRequestSchema, + CallToolResultSchema, CreateMessageRequestSchema, ElicitRequestSchema, ListRootsRequestSchema, @@ -21,6 +22,7 @@ import { import { Transport } from '../shared/transport.js'; import { Server } from '../server/index.js'; import { InMemoryTransport } from '../inMemory.js'; +import { InMemoryTaskStore } from '../examples/shared/inMemoryTaskStore.js'; /*** * Test: Initialize with Matching Protocol Version @@ -1239,3 +1241,689 @@ describe('outputSchema validation', () => { ); }); }); + +describe('Task-based execution', () => { + describe('Client calling server', () => { + let serverTaskStore: InMemoryTaskStore; + + beforeEach(() => { + serverTaskStore = new InMemoryTaskStore(); + }); + + afterEach(() => { + serverTaskStore?.cleanup(); + }); + + test('should create task on server via tool call', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + }, + taskStore: serverTaskStore + } + ); + + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'test-tool') { + return { + content: [{ type: 'text', text: 'Tool executed successfully!' }] + }; + } + throw new Error('Unknown tool'); + }); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + } + } + ] + })); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Client creates task on server via tool call + const taskId = 'test-task-create'; + const pendingRequest = client.beginCallTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + task: { + taskId, + keepAlive: 60000 + } + }); + + await pendingRequest.result(); + + // Verify task was created successfully + const task = await client.getTask({ taskId }); + expect(task.status).toBe('completed'); + }); + + test('should query task status from server using getTask', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + }, + taskStore: serverTaskStore + } + ); + + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'test-tool') { + return { + content: [{ type: 'text', text: 'Success!' }] + }; + } + throw new Error('Unknown tool'); + }); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + } + } + ] + })); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create a task + const taskId = 'test-task-get'; + const pending = client.beginCallTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + task: { taskId, keepAlive: 60000 } + }); + await pending.result(); + + // Query task status + const task = await client.getTask({ taskId }); + expect(task).toBeDefined(); + expect(task.taskId).toBe(taskId); + expect(task.status).toBe('completed'); + }); + + test('should query task result from server using getTaskResult', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + }, + taskStore: serverTaskStore + } + ); + + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'test-tool') { + return { + content: [{ type: 'text', text: 'Result data!' }] + }; + } + throw new Error('Unknown tool'); + }); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + } + } + ] + })); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create a task + const taskId = 'test-task-result'; + const pending = client.beginCallTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + task: { taskId, keepAlive: 60000 } + }); + await pending.result(); + + // Query task result + const result = await client.getTaskResult({ taskId }, CallToolResultSchema); + expect(result.content).toEqual([{ type: 'text', text: 'Result data!' }]); + }); + + test('should query task list from server using listTasks', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + }, + taskStore: serverTaskStore + } + ); + + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'test-tool') { + return { + content: [{ type: 'text', text: 'Success!' }] + }; + } + throw new Error('Unknown tool'); + }); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + } + } + ] + })); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create multiple tasks + const taskIds = ['task-list-1', 'task-list-2']; + + for (const taskId of taskIds) { + const pending = client.beginCallTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + task: { taskId, keepAlive: 60000 } + }); + await pending.result(); + } + + // Query task list + const taskList = await client.listTasks(); + expect(taskList.tasks.length).toBeGreaterThanOrEqual(2); + for (const taskId of taskIds) { + expect(taskList.tasks).toContainEqual( + expect.objectContaining({ + taskId, + status: 'completed' + }) + ); + } + }); + }); + + describe('Server calling client', () => { + let clientTaskStore: InMemoryTaskStore; + + beforeEach(() => { + clientTaskStore = new InMemoryTaskStore(); + }); + + afterEach(() => { + clientTaskStore?.cleanup(); + }); + + test('should create task on client via server elicitation', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async () => ({ + action: 'accept', + content: { + username: 'test-user' + } + })); + + const server = new Server({ + name: 'test-server', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Server creates task on client via elicitation + const taskId = 'elicit-task-create'; + const ElicitResultSchema = z.object({ + action: z.enum(['accept', 'decline', 'cancel']), + content: z.record(z.unknown()).optional() + }); + + const pendingRequest = server.beginRequest( + { + method: 'elicitation/create', + params: { + message: 'Please provide your username', + requestedSchema: { + type: 'object', + properties: { + username: { type: 'string' } + }, + required: ['username'] + } + } + }, + ElicitResultSchema, + { task: { taskId, keepAlive: 60000 } } + ); + + await pendingRequest.result(); + + // Verify task was created + const task = await server.getTask({ taskId }); + expect(task.status).toBe('completed'); + }); + + test('should query task status from client using getTask', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async () => ({ + action: 'accept', + content: { username: 'test-user' } + })); + + const server = new Server({ + name: 'test-server', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create a task on client + const taskId = 'elicit-task-get'; + const ElicitResultSchema = z.object({ + action: z.enum(['accept', 'decline', 'cancel']), + content: z.record(z.unknown()).optional() + }); + + const pending = server.beginRequest( + { + method: 'elicitation/create', + params: { + message: 'Please provide info', + requestedSchema: { + type: 'object', + properties: { username: { type: 'string' } } + } + } + }, + ElicitResultSchema, + { task: { taskId, keepAlive: 60000 } } + ); + await pending.result(); + + // Query task status + const task = await server.getTask({ taskId }); + expect(task).toBeDefined(); + expect(task.taskId).toBe(taskId); + expect(task.status).toBe('completed'); + }); + + test('should query task result from client using getTaskResult', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async () => ({ + action: 'accept', + content: { username: 'result-user' } + })); + + const server = new Server({ + name: 'test-server', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create a task on client + const taskId = 'elicit-task-result'; + const ElicitResultSchema = z.object({ + action: z.enum(['accept', 'decline', 'cancel']), + content: z.record(z.unknown()).optional() + }); + + const pending = server.beginRequest( + { + method: 'elicitation/create', + params: { + message: 'Please provide info', + requestedSchema: { + type: 'object', + properties: { username: { type: 'string' } } + } + } + }, + ElicitResultSchema, + { task: { taskId, keepAlive: 60000 } } + ); + await pending.result(); + + // Query task result + const result = await server.getTaskResult({ taskId }, ElicitResultSchema); + expect(result.action).toBe('accept'); + expect(result.content).toEqual({ username: 'result-user' }); + }); + + test('should query task list from client using listTasks', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async () => ({ + action: 'accept', + content: { username: 'list-user' } + })); + + const server = new Server({ + name: 'test-server', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create multiple tasks on client + const ElicitResultSchema = z.object({ + action: z.enum(['accept', 'decline', 'cancel']), + content: z.record(z.unknown()).optional() + }); + + const taskIds = ['elicit-list-1', 'elicit-list-2']; + for (const taskId of taskIds) { + const pending = server.beginRequest( + { + method: 'elicitation/create', + params: { + message: 'Please provide info', + requestedSchema: { + type: 'object', + properties: { username: { type: 'string' } } + } + } + }, + ElicitResultSchema, + { task: { taskId, keepAlive: 60000 } } + ); + await pending.result(); + } + + // Query task list + const taskList = await server.listTasks(); + expect(taskList.tasks.length).toBeGreaterThanOrEqual(2); + for (const taskId of taskIds) { + expect(taskList.tasks).toContainEqual( + expect.objectContaining({ + taskId, + status: 'completed' + }) + ); + } + }); + }); + + test('should list tasks from server with pagination', async () => { + const serverTaskStore = new InMemoryTaskStore(); + + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + }, + taskStore: serverTaskStore + } + ); + + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'test-tool') { + return { + content: [{ type: 'text', text: `Result for ${request.params.arguments?.id || 'unknown'}` }] + }; + } + throw new Error('Unknown tool'); + }); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + } + } + ] + })); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create multiple tasks + const taskIds = ['task-1', 'task-2', 'task-3']; + + for (const taskId of taskIds) { + const pending = client.beginCallTool({ name: 'test-tool', arguments: { id: taskId } }, CallToolResultSchema, { + task: { taskId, keepAlive: 60000 } + }); + await pending.result(); + } + + // List all tasks without cursor + const firstPage = await client.listTasks(); + expect(firstPage.tasks.length).toBeGreaterThan(0); + expect(firstPage.tasks.map(t => t.taskId)).toEqual(expect.arrayContaining(taskIds)); + + // If there's a cursor, test pagination + if (firstPage.nextCursor) { + const secondPage = await client.listTasks({ cursor: firstPage.nextCursor }); + expect(secondPage.tasks).toBeDefined(); + } + + serverTaskStore.cleanup(); + }); + + describe('Error scenarios', () => { + let serverTaskStore: InMemoryTaskStore; + let clientTaskStore: InMemoryTaskStore; + + beforeEach(() => { + serverTaskStore = new InMemoryTaskStore(); + clientTaskStore = new InMemoryTaskStore(); + }); + + afterEach(() => { + serverTaskStore?.cleanup(); + clientTaskStore?.cleanup(); + }); + + test('should throw error when querying non-existent task from server', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + }, + taskStore: serverTaskStore + } + ); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Try to get a task that doesn't exist + await expect(client.getTask({ taskId: 'non-existent-task' })).rejects.toThrow(); + }); + + test('should throw error when querying result of non-existent task from server', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + }, + taskStore: serverTaskStore + } + ); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Try to get result of a task that doesn't exist + await expect(client.getTaskResult({ taskId: 'non-existent-task' }, CallToolResultSchema)).rejects.toThrow(); + }); + + test('should throw error when server queries non-existent task from client', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async () => ({ + action: 'accept', + content: { username: 'test' } + })); + + const server = new Server({ + name: 'test-server', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Try to query a task that doesn't exist on client + await expect(server.getTask({ taskId: 'non-existent-task' })).rejects.toThrow(); + }); + }); +}); diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 988085199..0bf13eaf7 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -1211,4 +1211,422 @@ describe('Task-based execution', () => { // Cleanup taskStore.cleanup(); }); + + describe('Server calling client via elicitation', () => { + let clientTaskStore: InMemoryTaskStore; + + beforeEach(() => { + clientTaskStore = new InMemoryTaskStore(); + }); + + afterEach(() => { + clientTaskStore?.cleanup(); + }); + + test('should create task on client via elicitation', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async () => ({ + action: 'accept', + content: { + username: 'server-test-user', + confirmed: true + } + })); + + const server = new Server({ + name: 'test-server', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Server creates task on client via elicitation + const taskId = 'server-elicit-create'; + const ElicitResultSchema = z.object({ + action: z.enum(['accept', 'decline', 'cancel']), + content: z.record(z.unknown()).optional() + }); + + const pendingRequest = server.beginRequest( + { + method: 'elicitation/create', + params: { + message: 'Please provide your username', + requestedSchema: { + type: 'object', + properties: { + username: { type: 'string' }, + confirmed: { type: 'boolean' } + }, + required: ['username'] + } + } + }, + ElicitResultSchema, + { task: { taskId, keepAlive: 60000 } } + ); + + await pendingRequest.result(); + + // Verify task was created + const task = await server.getTask({ taskId }); + expect(task.status).toBe('completed'); + }); + + test('should query task from client using getTask', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async () => ({ + action: 'accept', + content: { username: 'get-user' } + })); + + const server = new Server({ + name: 'test-server', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create task + const taskId = 'server-elicit-get'; + const ElicitResultSchema = z.object({ + action: z.enum(['accept', 'decline', 'cancel']), + content: z.record(z.unknown()).optional() + }); + + const pending = server.beginRequest( + { + method: 'elicitation/create', + params: { + message: 'Provide info', + requestedSchema: { + type: 'object', + properties: { username: { type: 'string' } } + } + } + }, + ElicitResultSchema, + { task: { taskId, keepAlive: 60000 } } + ); + await pending.result(); + + // Query task + const task = await server.getTask({ taskId }); + expect(task).toBeDefined(); + expect(task.taskId).toBe(taskId); + expect(task.status).toBe('completed'); + }); + + test('should query task result from client using getTaskResult', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async () => ({ + action: 'accept', + content: { username: 'result-user', confirmed: true } + })); + + const server = new Server({ + name: 'test-server', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create task + const taskId = 'server-elicit-result'; + const ElicitResultSchema = z.object({ + action: z.enum(['accept', 'decline', 'cancel']), + content: z.record(z.unknown()).optional() + }); + + const pending = server.beginRequest( + { + method: 'elicitation/create', + params: { + message: 'Provide info', + requestedSchema: { + type: 'object', + properties: { + username: { type: 'string' }, + confirmed: { type: 'boolean' } + } + } + } + }, + ElicitResultSchema, + { task: { taskId, keepAlive: 60000 } } + ); + await pending.result(); + + // Query result + const result = await server.getTaskResult({ taskId }, ElicitResultSchema); + expect(result.action).toBe('accept'); + expect(result.content).toEqual({ username: 'result-user', confirmed: true }); + }); + + test('should query task list from client using listTasks', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async () => ({ + action: 'accept', + content: { username: 'list-user' } + })); + + const server = new Server({ + name: 'test-server', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create multiple tasks + const ElicitResultSchema = z.object({ + action: z.enum(['accept', 'decline', 'cancel']), + content: z.record(z.unknown()).optional() + }); + + const taskIds = ['server-elicit-list-1', 'server-elicit-list-2']; + for (const taskId of taskIds) { + const pending = server.beginRequest( + { + method: 'elicitation/create', + params: { + message: 'Provide info', + requestedSchema: { + type: 'object', + properties: { username: { type: 'string' } } + } + } + }, + ElicitResultSchema, + { task: { taskId, keepAlive: 60000 } } + ); + await pending.result(); + } + + // Query task list + const taskList = await server.listTasks(); + expect(taskList.tasks.length).toBeGreaterThanOrEqual(2); + for (const taskId of taskIds) { + expect(taskList.tasks).toContainEqual( + expect.objectContaining({ + taskId, + status: 'completed' + }) + ); + } + }); + }); + + test('should handle multiple concurrent task-based tool calls', async () => { + const taskStore = new InMemoryTaskStore(); + + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + }, + taskStore + } + ); + + // Set up a tool handler with variable delay + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'async-tool') { + const delay = (request.params.arguments?.delay as number) || 10; + await new Promise(resolve => setTimeout(resolve, delay)); + return { + content: [{ type: 'text', text: `Completed task ${request.params.arguments?.taskNum || 'unknown'}` }] + }; + } + throw new Error('Unknown tool'); + }); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'async-tool', + description: 'An async test tool', + inputSchema: { + type: 'object', + properties: { + delay: { type: 'number' }, + taskNum: { type: 'number' } + } + } + } + ] + })); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create multiple tasks concurrently + const taskIds = ['concurrent-1', 'concurrent-2', 'concurrent-3', 'concurrent-4']; + const pendingRequests = taskIds.map((taskId, index) => + client.beginCallTool({ name: 'async-tool', arguments: { delay: 10 + index * 5, taskNum: index + 1 } }, CallToolResultSchema, { + task: { taskId, keepAlive: 60000 } + }) + ); + + // Wait for all tasks to complete + await Promise.all(pendingRequests.map(p => p.result())); + + // Verify all tasks completed successfully + for (let i = 0; i < taskIds.length; i++) { + const task = await client.getTask({ taskId: taskIds[i] }); + expect(task.status).toBe('completed'); + expect(task.taskId).toBe(taskIds[i]); + + const result = await client.getTaskResult({ taskId: taskIds[i] }, CallToolResultSchema); + expect(result.content).toEqual([{ type: 'text', text: `Completed task ${i + 1}` }]); + } + + // Verify listTasks returns all tasks + const taskList = await client.listTasks(); + for (const taskId of taskIds) { + expect(taskList.tasks).toContainEqual(expect.objectContaining({ taskId })); + } + + // Cleanup + taskStore.cleanup(); + }); + + describe('Error scenarios', () => { + let taskStore: InMemoryTaskStore; + let clientTaskStore: InMemoryTaskStore; + + beforeEach(() => { + taskStore = new InMemoryTaskStore(); + clientTaskStore = new InMemoryTaskStore(); + }); + + afterEach(() => { + taskStore?.cleanup(); + clientTaskStore?.cleanup(); + }); + + test('should throw error when client queries non-existent task from server', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + }, + taskStore + } + ); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Try to query a task that doesn't exist + await expect(client.getTask({ taskId: 'non-existent-task' })).rejects.toThrow(); + }); + + test('should throw error when server queries non-existent task from client', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async () => ({ + action: 'accept', + content: { username: 'test' } + })); + + const server = new Server({ + name: 'test-server', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Try to query a task that doesn't exist on client + await expect(server.getTask({ taskId: 'non-existent-task' })).rejects.toThrow(); + }); + }); }); From 12d0f66ebd87b346a07251b9d6f9db7f2d0ff041 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Fri, 31 Oct 2025 19:38:40 -0700 Subject: [PATCH 14/18] Make default task polling interval configurable --- src/shared/protocol.ts | 7 ++++++- src/shared/request.ts | 9 ++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 40f346edc..389abfe8f 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -72,6 +72,11 @@ export type ProtocolOptions = { * handle task creation, status tracking, and result storage. */ taskStore?: TaskStore; + /** + * Default polling interval (in milliseconds) for task status checks when no pollFrequency + * is provided by the server. Defaults to 5000ms if not specified. + */ + defaultTaskPollInterval?: number; }; /** @@ -863,7 +868,7 @@ export abstract class Protocol Promise.resolve(); @@ -18,7 +18,8 @@ export class PendingRequest, readonly resultHandle: Promise, readonly resultSchema: ZodType, - readonly taskId?: string + readonly taskId?: string, + readonly defaultTaskPollInterval?: number ) {} /** @@ -66,7 +67,9 @@ export class PendingRequest setTimeout(resolve, task.pollFrequency ?? DEFAULT_POLLING_INTERNAL)); + await new Promise(resolve => + setTimeout(resolve, task.pollFrequency ?? this.defaultTaskPollInterval ?? DEFAULT_TASK_POLLING_INTERVAL) + ); } while (!isTerminal(task.status)); // Process result From bb28ef79808670f5263ac5223e60db51d2a16c27 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Mon, 3 Nov 2025 10:52:37 -0800 Subject: [PATCH 15/18] Exclude relatedTask from RequestHandlerExtra --- src/shared/protocol.ts | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 389abfe8f..d3687311d 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -145,6 +145,12 @@ export type NotificationOptions = { relatedTask?: RelatedTaskMetadata; }; +/** + * Options that can be given per request. + */ +// relatedTask is excluded as the SDK controls if this is sent according to if the source is a task. +export type TaskRequestOptions = Omit; + /** * Extra data given to request handlers. */ @@ -196,7 +202,11 @@ export type RequestHandlerExtra< * * This is used by certain transports to correctly associate related messages. */ - sendRequest: >(request: SendRequestT, resultSchema: U, options?: RequestOptions) => Promise>; + sendRequest: >( + request: SendRequestT, + resultSchema: U, + options?: TaskRequestOptions + ) => Promise>; }; /** From 0bf2b429d27e1a7981a38462dba811a2b0f1ccc9 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Mon, 3 Nov 2025 11:10:39 -0800 Subject: [PATCH 16/18] Mark tasks as cancelled only after confirming abort --- src/shared/protocol.test.ts | 290 +++++++++++++++++++++++------------- src/shared/protocol.ts | 77 +++++++--- 2 files changed, 249 insertions(+), 118 deletions(-) diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index b3f1d4e5e..76782f3cd 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -6,12 +6,16 @@ import { Notification, RELATED_TASK_META_KEY, Request, + RequestId, Result, ServerCapabilities, - TASK_META_KEY + Task, + TASK_META_KEY, + TaskMetadata } from '../types.js'; import { Protocol, mergeCapabilities } from './protocol.js'; import { Transport } from './transport.js'; +import { TaskStore } from './task.js'; // Mock Transport class class MockTransport implements Transport { @@ -26,6 +30,76 @@ class MockTransport implements Transport { async send(_message: unknown): Promise {} } +function createMockTaskStore(options?: { + onStatus?: (status: Task['status']) => void; + onList?: () => void; +}): TaskStore & { [K in keyof TaskStore]: jest.Mock, Parameters> } { + const tasks: Record = {}; + return { + createTask: jest.fn((taskMetadata: TaskMetadata, _1: RequestId, _2: Request) => { + tasks[taskMetadata.taskId] = { + taskId: taskMetadata.taskId, + status: (taskMetadata.status as Task['status'] | undefined) ?? 'submitted', + keepAlive: taskMetadata.keepAlive ?? null, + pollFrequency: (taskMetadata.pollFrequency as Task['pollFrequency'] | undefined) ?? 1000 + }; + options?.onStatus?.('submitted'); + return Promise.resolve(); + }), + getTask: jest.fn((taskId: string) => { + return Promise.resolve(tasks[taskId] ?? null); + }), + updateTaskStatus: jest.fn((taskId, status, error) => { + const task = tasks[taskId]; + if (task) { + task.status = status; + task.error = error; + options?.onStatus?.(task.status); + } + return Promise.resolve(); + }), + storeTaskResult: jest.fn((taskId: string, result: Result) => { + const task = tasks[taskId]; + if (task) { + task.status = 'completed'; + task.result = result; + options?.onStatus?.('completed'); + } + return Promise.resolve(); + }), + getTaskResult: jest.fn((taskId: string) => { + const task = tasks[taskId]; + if (task?.result) { + return Promise.resolve(task.result); + } + throw new Error('Task result not found'); + }), + listTasks: jest.fn(() => { + const result = { + tasks: Object.values(tasks) + }; + options?.onList?.(); + return Promise.resolve(result); + }) + }; +} + +function createLatch() { + let latch = false; + const waitForLatch = async () => { + while (!latch) { + await new Promise(resolve => setTimeout(resolve, 0)); + } + }; + + return { + releaseLatch: () => { + latch = true; + }, + waitForLatch + }; +} + describe('protocol tests', () => { let protocol: Protocol; let transport: MockTransport; @@ -977,14 +1051,14 @@ describe('Task-based execution', () => { describe('task status transitions', () => { it('should transition from submitted to working when handler starts', async () => { - const mockTaskStore = { - createTask: jest.fn().mockResolvedValue(undefined), - getTask: jest.fn().mockResolvedValue(null), - updateTaskStatus: jest.fn().mockResolvedValue(undefined), - storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }), - listTasks: jest.fn().mockResolvedValue({ tasks: [], nextCursor: undefined }) - }; + const workingProcessed = createLatch(); + const mockTaskStore = createMockTaskStore({ + onStatus: status => { + if (status === 'working') { + workingProcessed.releaseLatch(); + } + } + }); protocol = new (class extends Protocol { protected assertCapabilityForMethod(): void {} @@ -1012,24 +1086,17 @@ describe('Task-based execution', () => { } }); - await new Promise(resolve => setTimeout(resolve, 50)); + await workingProcessed.waitForLatch(); expect(mockTaskStore.createTask).toHaveBeenCalledWith({ taskId: 'test-task', keepAlive: 60000 }, 1, { method: 'test/method', params: expect.any(Object) }); - expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'working'); + expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'working', undefined); }); it('should transition to input_required during extra.sendRequest', async () => { - const mockTaskStore = { - createTask: jest.fn().mockResolvedValue(undefined), - getTask: jest.fn().mockResolvedValue(null), - updateTaskStatus: jest.fn().mockResolvedValue(undefined), - storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }), - listTasks: jest.fn().mockResolvedValue({ tasks: [], nextCursor: undefined }) - }; + const mockTaskStore = createMockTaskStore(); const responsiveTransport = new MockTransport(); responsiveTransport.send = jest.fn().mockImplementation(async (message: unknown) => { @@ -1105,14 +1172,14 @@ describe('Task-based execution', () => { }); it('should mark task as completed when storeTaskResult is called', async () => { - const mockTaskStore = { - createTask: jest.fn().mockResolvedValue(undefined), - getTask: jest.fn().mockResolvedValue(null), - updateTaskStatus: jest.fn().mockResolvedValue(undefined), - storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }), - listTasks: jest.fn().mockResolvedValue({ tasks: [], nextCursor: undefined }) - }; + const completeProcessed = createLatch(); + const mockTaskStore = createMockTaskStore({ + onStatus: status => { + if (status === 'completed') { + completeProcessed.releaseLatch(); + } + } + }); protocol = new (class extends Protocol { protected assertCapabilityForMethod(): void {} @@ -1140,20 +1207,20 @@ describe('Task-based execution', () => { } }); - await new Promise(resolve => setTimeout(resolve, 50)); + await completeProcessed.waitForLatch(); expect(mockTaskStore.storeTaskResult).toHaveBeenCalledWith('test-task', { result: 'success' }); }); it('should mark task as cancelled when notifications/cancelled is received', async () => { - const mockTaskStore = { - createTask: jest.fn().mockResolvedValue(undefined), - getTask: jest.fn().mockResolvedValue({ taskId: 'test-task', status: 'working', keepAlive: null, pollFrequency: 500 }), - updateTaskStatus: jest.fn().mockResolvedValue(undefined), - storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }), - listTasks: jest.fn().mockResolvedValue({ tasks: [], nextCursor: undefined }) - }; + const cancelProcessed = createLatch(); + const mockTaskStore = createMockTaskStore({ + onStatus: status => { + if (status === 'cancelled') { + cancelProcessed.releaseLatch(); + } + } + }); protocol = new (class extends Protocol { protected assertCapabilityForMethod(): void {} @@ -1163,35 +1230,50 @@ describe('Task-based execution', () => { await protocol.connect(transport); - void protocol.request({ method: 'test/slow', params: {} }, z.object({ result: z.string() }), { - task: { taskId: 'test-task', keepAlive: 60000 } + const requestInProgress = createLatch(); + const cancelSent = createLatch(); + + protocol.setRequestHandler(z.object({ method: z.literal('test/method'), params: z.any() }), async () => { + requestInProgress.releaseLatch(); + await cancelSent.waitForLatch(); + return { + result: 'success' + }; }); - await new Promise(resolve => setTimeout(resolve, 10)); + transport.onmessage?.({ + jsonrpc: '2.0', + id: 1, + method: 'test/method', + params: { + _meta: { + [TASK_META_KEY]: { + taskId: 'test-task', + keepAlive: 60000 + } + } + } + }); transport.onmessage?.({ jsonrpc: '2.0', method: 'notifications/cancelled', params: { - requestId: 0, + requestId: 1, reason: 'User cancelled' } }); - await new Promise(resolve => setTimeout(resolve, 10)); + await requestInProgress.waitForLatch(); + cancelSent.releaseLatch(); + await cancelProcessed.waitForLatch(); - expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'cancelled'); + expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'cancelled', undefined); }); it('should mark task as failed when updateTaskStatus to working fails', async () => { - const mockTaskStore = { - createTask: jest.fn().mockResolvedValue(undefined), - getTask: jest.fn().mockResolvedValue(null), - updateTaskStatus: jest.fn().mockRejectedValueOnce(new Error('Failed to update status')).mockResolvedValue(undefined), - storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }), - listTasks: jest.fn().mockResolvedValue({ tasks: [], nextCursor: undefined }) - }; + const mockTaskStore = createMockTaskStore(); + mockTaskStore.updateTaskStatus.mockRejectedValueOnce(new Error('Failed to update status')).mockResolvedValue(undefined); protocol = new (class extends Protocol { protected assertCapabilityForMethod(): void {} @@ -1221,27 +1303,42 @@ describe('Task-based execution', () => { await new Promise(resolve => setTimeout(resolve, 50)); - expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'working'); + expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'working', undefined); expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith('test-task', 'failed', 'Failed to mark task as working'); }); }); describe('listTasks', () => { it('should handle tasks/list requests and return tasks from TaskStore', async () => { - const mockTaskStore = { - createTask: jest.fn().mockResolvedValue(undefined), - getTask: jest.fn().mockResolvedValue(null), - updateTaskStatus: jest.fn().mockResolvedValue(undefined), - storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }), - listTasks: jest.fn().mockResolvedValue({ - tasks: [ - { taskId: 'task-1', status: 'completed', keepAlive: null, pollFrequency: 500 }, - { taskId: 'task-2', status: 'working', keepAlive: 60000, pollFrequency: 1000 } - ], - nextCursor: 'task-2' - }) - }; + const listedTasks = createLatch(); + const mockTaskStore = createMockTaskStore({ + onList: () => listedTasks.releaseLatch() + }); + await mockTaskStore.createTask( + { + taskId: 'task-1', + status: 'completed', + pollFrequency: 500 + }, + 1, + { + method: 'test/method', + params: {} + } + ); + await mockTaskStore.createTask( + { + taskId: 'task-2', + status: 'working', + keepAlive: 60000, + pollFrequency: 1000 + }, + 2, + { + method: 'test/method', + params: {} + } + ); protocol = new (class extends Protocol { protected assertCapabilityForMethod(): void {} @@ -1254,37 +1351,41 @@ describe('Task-based execution', () => { // Simulate receiving a tasks/list request transport.onmessage?.({ jsonrpc: '2.0', - id: 1, + id: 3, method: 'tasks/list', params: {} }); - await new Promise(resolve => setTimeout(resolve, 10)); + await listedTasks.waitForLatch(); expect(mockTaskStore.listTasks).toHaveBeenCalledWith(undefined); const sentMessage = sendSpy.mock.calls[0][0]; expect(sentMessage.jsonrpc).toBe('2.0'); - expect(sentMessage.id).toBe(1); + expect(sentMessage.id).toBe(3); expect(sentMessage.result.tasks).toEqual([ { taskId: 'task-1', status: 'completed', keepAlive: null, pollFrequency: 500 }, { taskId: 'task-2', status: 'working', keepAlive: 60000, pollFrequency: 1000 } ]); - expect(sentMessage.result.nextCursor).toBe('task-2'); expect(sentMessage.result._meta).toEqual({}); }); it('should handle tasks/list requests with cursor for pagination', async () => { - const mockTaskStore = { - createTask: jest.fn().mockResolvedValue(undefined), - getTask: jest.fn().mockResolvedValue(null), - updateTaskStatus: jest.fn().mockResolvedValue(undefined), - storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }), - listTasks: jest.fn().mockResolvedValue({ - tasks: [{ taskId: 'task-3', status: 'submitted', keepAlive: null, pollFrequency: 500 }], - nextCursor: undefined - }) - }; + const listedTasks = createLatch(); + const mockTaskStore = createMockTaskStore({ + onList: () => listedTasks.releaseLatch() + }); + await mockTaskStore.createTask( + { + taskId: 'task-3', + status: 'submitted', + pollFrequency: 500 + }, + 1, + { + method: 'test/method', + params: {} + } + ); protocol = new (class extends Protocol { protected assertCapabilityForMethod(): void {} @@ -1304,7 +1405,7 @@ describe('Task-based execution', () => { } }); - await new Promise(resolve => setTimeout(resolve, 10)); + await listedTasks.waitForLatch(); expect(mockTaskStore.listTasks).toHaveBeenCalledWith('task-2'); const sentMessage = sendSpy.mock.calls[0][0]; @@ -1316,17 +1417,10 @@ describe('Task-based execution', () => { }); it('should handle tasks/list requests with empty results', async () => { - const mockTaskStore = { - createTask: jest.fn().mockResolvedValue(undefined), - getTask: jest.fn().mockResolvedValue(null), - updateTaskStatus: jest.fn().mockResolvedValue(undefined), - storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }), - listTasks: jest.fn().mockResolvedValue({ - tasks: [], - nextCursor: undefined - }) - }; + const listedTasks = createLatch(); + const mockTaskStore = createMockTaskStore({ + onList: () => listedTasks.releaseLatch() + }); protocol = new (class extends Protocol { protected assertCapabilityForMethod(): void {} @@ -1344,7 +1438,7 @@ describe('Task-based execution', () => { params: {} }); - await new Promise(resolve => setTimeout(resolve, 10)); + await listedTasks.waitForLatch(); expect(mockTaskStore.listTasks).toHaveBeenCalledWith(undefined); const sentMessage = sendSpy.mock.calls[0][0]; @@ -1356,14 +1450,8 @@ describe('Task-based execution', () => { }); it('should return error for invalid cursor', async () => { - const mockTaskStore = { - createTask: jest.fn().mockResolvedValue(undefined), - getTask: jest.fn().mockResolvedValue(null), - updateTaskStatus: jest.fn().mockResolvedValue(undefined), - storeTaskResult: jest.fn().mockResolvedValue(undefined), - getTaskResult: jest.fn().mockResolvedValue({ content: [] }), - listTasks: jest.fn().mockRejectedValue(new Error('Invalid cursor: bad-cursor')) - }; + const mockTaskStore = createMockTaskStore(); + mockTaskStore.listTasks.mockRejectedValue(new Error('Invalid cursor: bad-cursor')); protocol = new (class extends Protocol { protected assertCapabilityForMethod(): void {} diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index d3687311d..ed407be56 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -36,12 +36,14 @@ import { TASK_META_KEY, GetTaskResult, TaskMetadata, - RelatedTaskMetadata + RelatedTaskMetadata, + CancelledNotification, + Task } from '../types.js'; import { Transport, TransportSendOptions } from './transport.js'; import { AuthInfo } from '../server/auth/types.js'; import { PendingRequest } from './request.js'; -import { TaskStore } from './task.js'; +import { isTerminal, TaskStore } from './task.js'; /** * Callback for progress notifications. @@ -239,7 +241,7 @@ export abstract class Protocol = new Map(); private _pendingDebouncedNotifications = new Set(); private _pendingTaskCreations: Map void; reject: (reason: unknown) => void }> = new Map(); - private _requestIdToTaskId: Map = new Map(); + private _requestIdToTaskId: Map = new Map(); private _taskStore?: TaskStore; /** @@ -356,16 +358,18 @@ export abstract class Protocol): Promise { + private async _oncancel(notification: CancelledNotification): Promise { // Handle request cancellation const controller = this._requestHandlerAbortControllers.get(notification.params.requestId); controller?.abort(notification.params.reason); + } + private async _postcancel(requestId: RequestId): Promise { // If this request had a task, mark it as cancelled in storage - const taskId = this._requestIdToTaskId.get(Number(notification.params.requestId)); + const taskId = this._requestIdToTaskId.get(requestId); if (taskId && this._taskStore) { try { - await this._taskStore.updateTaskStatus(taskId, 'cancelled'); + await this._setTaskStatus(taskId, 'cancelled'); } catch (error) { this._onerror(new Error(`Failed to cancel task ${taskId}: ${error}`)); } @@ -536,14 +540,16 @@ export abstract class Protocol { const relatedTask = taskMetadata ? { taskId: taskMetadata.taskId } : undefined; - if (taskMetadata && this._taskStore) { - await this._taskStore.updateTaskStatus(taskMetadata.taskId, 'input_required'); + if (taskMetadata) { + // Allow this to throw to the caller (request handler) + await this._setTaskStatus(taskMetadata.taskId, 'input_required'); } try { return await this.request(r, resultSchema, { ...options, relatedRequestId: request.id, relatedTask }); } finally { - if (taskMetadata && this._taskStore) { - await this._taskStore.updateTaskStatus(taskMetadata.taskId, 'working'); + if (taskMetadata) { + // Allow this to throw to the caller (request handler) + await this._setTaskStatus(taskMetadata.taskId, 'working'); } } }, @@ -557,7 +563,7 @@ export abstract class Protocol { // If this request asked for task creation, create the task and send notification if (taskMetadata && this._taskStore) { - const task = await this._taskStore!.getTask(taskMetadata.taskId); + const task = await this._taskStore.getTask(taskMetadata.taskId); if (task) { throw new McpError(ErrorCode.InvalidParams, `Task ID already exists: ${taskMetadata.taskId}`); } @@ -566,6 +572,7 @@ export abstract class Protocol { // If this request had a task, mark it as working - if (taskMetadata && this._taskStore) { + if (taskMetadata) { try { - await this._taskStore.updateTaskStatus(taskMetadata.taskId, 'working'); + await this._setTaskStatus(taskMetadata.taskId, 'working'); } catch { try { - await this._taskStore.updateTaskStatus(taskMetadata.taskId, 'failed', 'Failed to mark task as working'); + await this._setTaskStatus(taskMetadata.taskId, 'failed', 'Failed to mark task as working'); } catch (error) { throw new McpError(ErrorCode.InternalError, `Failed to mark task as working: ${error}`); } @@ -601,6 +608,8 @@ export abstract class Protocol { if (abortController.signal.aborted) { + // Request was cancelled + await this._postcancel(request.id); return; } @@ -620,8 +629,10 @@ export abstract class Protocol { + async error => { if (abortController.signal.aborted) { + // Request was cancelled + await this._postcancel(request.id); return; } @@ -749,7 +760,7 @@ export abstract class Protocol>((resolve, reject) => { @@ -895,12 +906,44 @@ export abstract class Protocol { + private _waitForTaskCreation(taskId: string): Promise { return new Promise((resolve, reject) => { this._pendingTaskCreations.set(taskId, { resolve, reject }); }); } + private async _setTaskStatus( + taskId: string, + status: Status, + errorReason?: ErrorReason + ) { + if (!this._taskStore) { + // No task store configured + return; + } + + try { + // Check the current task status to avoid overwriting terminal states + // as a safeguard for when the TaskStore implementation doesn't try + // to avoid this. + const task = await this._taskStore.getTask(taskId); + if (!task) { + return; + } + + if (isTerminal(task.status)) { + this._onerror( + new Error(`Failed to update status of task "${taskId}" from terminal status "${task.status}" to "${status}"`) + ); + return; + } + + await this._taskStore.updateTaskStatus(taskId, status, errorReason); + } catch (error) { + throw new Error(`Failed to update status of task "${taskId}" to "${status}": ${error}`); + } + } + /** * Gets the current status of a task. */ From 486e8edba532637f2ac50eb68ddcca2b4850ca1d Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Mon, 3 Nov 2025 15:02:52 -0800 Subject: [PATCH 17/18] Store task result before attempting to respond to client --- src/shared/protocol.ts | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index ed407be56..bf28afde7 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -613,14 +613,10 @@ export abstract class Protocol { if (abortController.signal.aborted) { From 06db60370c42d4ccc52c7531ec54c68ea7180d28 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Mon, 3 Nov 2025 15:54:44 -0800 Subject: [PATCH 18/18] Allow task polling before creation notification arrives --- src/shared/request.ts | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/shared/request.ts b/src/shared/request.ts index d6d48467f..ddfd1c1c6 100644 --- a/src/shared/request.ts +++ b/src/shared/request.ts @@ -36,13 +36,21 @@ export class PendingRequest { - // Blocks for a notifications/tasks/created with the provided task ID - await this.taskCreatedHandle; - await onTaskCreated(); - return await this.taskHandler(this.taskId!, { + // Start task handler immediately without waiting for creation notification + const taskPromise = this.taskHandler(this.taskId!, { onTaskCreated, onTaskStatus }); + + // Call onTaskCreated callback when notification arrives, but don't block taskHandler + // The promise is tied to the lifecycle of taskPromise, so it won't leak + this.taskCreatedHandle + .then(() => onTaskCreated()) + .catch(() => { + // Silently ignore if notification never arrives or fails + }); + + return await taskPromise; })(), this.resultHandle ]).then(([task, result]) => {