Skip to content

Commit ce420f8

Browse files
DRAFT: SEP-1034: Default values for Elicitation Schemas (#1096)
1 parent 2da89db commit ce420f8

File tree

4 files changed

+212
-22
lines changed

4 files changed

+212
-22
lines changed

src/client/index.ts

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,41 @@ import {
3636
SUPPORTED_PROTOCOL_VERSIONS,
3737
type SubscribeRequest,
3838
type Tool,
39-
type UnsubscribeRequest
39+
type UnsubscribeRequest,
40+
ElicitResultSchema,
41+
ElicitRequestSchema
4042
} from '../types.js';
4143
import { AjvJsonSchemaValidator } from '../validation/ajv-provider.js';
4244
import type { JsonSchemaType, JsonSchemaValidator, jsonSchemaValidator } from '../validation/types.js';
45+
import { ZodLiteral, ZodObject, z } from 'zod';
46+
import type { RequestHandlerExtra } from '../shared/protocol.js';
47+
48+
/**
49+
* Elicitation default application helper. Applies defaults to the data based on the schema.
50+
*
51+
* @param schema - The schema to apply defaults to.
52+
* @param data - The data to apply defaults to.
53+
*/
54+
function applyElicitationDefaults(schema: JsonSchemaType | undefined, data: unknown): void {
55+
if (!schema || data === null || typeof data !== 'object') return;
56+
57+
// Handle object properties
58+
if (schema.type === 'object' && schema.properties && typeof schema.properties === 'object') {
59+
const obj = data as Record<string, unknown>;
60+
const props = schema.properties as Record<string, JsonSchemaType & { default?: unknown }>;
61+
for (const key of Object.keys(props)) {
62+
const propSchema = props[key];
63+
// If missing or explicitly undefined, apply default if present
64+
if (obj[key] === undefined && Object.prototype.hasOwnProperty.call(propSchema, 'default')) {
65+
obj[key] = propSchema.default;
66+
}
67+
// Recurse into existing nested objects/arrays
68+
if (obj[key] !== undefined) {
69+
applyElicitationDefaults(propSchema, obj[key]);
70+
}
71+
}
72+
}
73+
}
4374

4475
export type ClientOptions = ProtocolOptions & {
4576
/**
@@ -141,6 +172,64 @@ export class Client<
141172
this._capabilities = mergeCapabilities(this._capabilities, capabilities);
142173
}
143174

175+
/**
176+
* Override request handler registration to enforce client-side validation for elicitation.
177+
*/
178+
public override setRequestHandler<
179+
T extends ZodObject<{
180+
method: ZodLiteral<string>;
181+
}>
182+
>(
183+
requestSchema: T,
184+
handler: (
185+
request: z.infer<T>,
186+
extra: RequestHandlerExtra<ClientRequest | RequestT, ClientNotification | NotificationT>
187+
) => ClientResult | ResultT | Promise<ClientResult | ResultT>
188+
): void {
189+
const method = requestSchema.shape.method.value;
190+
if (method === 'elicitation/create') {
191+
const wrappedHandler = async (
192+
request: z.infer<T>,
193+
extra: RequestHandlerExtra<ClientRequest | RequestT, ClientNotification | NotificationT>
194+
): Promise<ClientResult | ResultT> => {
195+
const validatedRequest = ElicitRequestSchema.safeParse(request);
196+
if (!validatedRequest.success) {
197+
throw new McpError(ErrorCode.InvalidParams, `Invalid elicitation request: ${validatedRequest.error.message}`);
198+
}
199+
200+
const result = await Promise.resolve(handler(request, extra));
201+
202+
const validationResult = ElicitResultSchema.safeParse(result);
203+
if (!validationResult.success) {
204+
throw new McpError(ErrorCode.InvalidParams, `Invalid elicitation result: ${validationResult.error.message}`);
205+
}
206+
207+
const validatedResult = validationResult.data;
208+
209+
if (
210+
this._capabilities.elicitation?.applyDefaults &&
211+
validatedResult.action === 'accept' &&
212+
validatedResult.content &&
213+
validatedRequest.data.params.requestedSchema
214+
) {
215+
try {
216+
applyElicitationDefaults(validatedRequest.data.params.requestedSchema, validatedResult.content);
217+
} catch {
218+
// gracefully ignore errors in default application
219+
}
220+
}
221+
222+
return validatedResult;
223+
};
224+
225+
// Install the wrapped handler
226+
return super.setRequestHandler(requestSchema, wrappedHandler as unknown as typeof handler);
227+
}
228+
229+
// Non-elicitation handlers use default behavior
230+
return super.setRequestHandler(requestSchema, handler);
231+
}
232+
144233
protected assertCapability(capability: keyof ServerCapabilities, method: string): void {
145234
if (!this._serverCapabilities?.[capability]) {
146235
throw new Error(`Server does not support ${capability} (required for ${method})`);

src/server/elicitation.test.ts

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,72 @@ function testElicitationFlow(validatorProvider: typeof ajvProvider | typeof cfWo
609609
});
610610
});
611611

612+
test(`${validatorName}: should default missing fields from schema defaults`, async () => {
613+
const server = new Server(
614+
{ name: 'test-server', version: '1.0.0' },
615+
{
616+
capabilities: {},
617+
jsonSchemaValidator: validatorProvider
618+
}
619+
);
620+
621+
const client = new Client(
622+
{ name: 'test-client', version: '1.0.0' },
623+
{
624+
capabilities: {
625+
elicitation: {
626+
applyDefaults: true
627+
}
628+
}
629+
}
630+
);
631+
632+
// Client returns no values; SDK should apply defaults automatically (and validate)
633+
client.setRequestHandler(ElicitRequestSchema, request => {
634+
expect(request.params.requestedSchema).toEqual({
635+
type: 'object',
636+
properties: {
637+
subscribe: { type: 'boolean', default: true },
638+
nickname: { type: 'string', default: 'Guest' },
639+
age: { type: 'integer', minimum: 0, maximum: 150, default: 18 },
640+
color: { type: 'string', enum: ['red', 'green'], default: 'green' }
641+
},
642+
required: ['subscribe', 'nickname', 'age', 'color']
643+
});
644+
return {
645+
action: 'accept',
646+
content: {}
647+
};
648+
});
649+
650+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
651+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
652+
653+
const result = await server.elicitInput({
654+
message: 'Provide your preferences',
655+
requestedSchema: {
656+
type: 'object',
657+
properties: {
658+
subscribe: { type: 'boolean', default: true },
659+
nickname: { type: 'string', default: 'Guest' },
660+
age: { type: 'integer', minimum: 0, maximum: 150, default: 18 },
661+
color: { type: 'string', enum: ['red', 'green'], default: 'green' }
662+
},
663+
required: ['subscribe', 'nickname', 'age', 'color']
664+
}
665+
});
666+
667+
expect(result).toEqual({
668+
action: 'accept',
669+
content: {
670+
subscribe: true,
671+
nickname: 'Guest',
672+
age: 18,
673+
color: 'green'
674+
}
675+
});
676+
});
677+
612678
test(`${validatorName}: should reject invalid email format`, async () => {
613679
const server = new Server(
614680
{ name: 'test-server', version: '1.0.0' },

src/spec.types.test.ts

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,19 @@ type MakeUnknownsNotOptional<T> =
6262
}
6363
: T;
6464

65+
// Targeted fix: in spec, treat ClientCapabilities.elicitation?: object as Record<string, unknown>
66+
type FixSpecClientCapabilities<T> = T extends { elicitation?: object }
67+
? Omit<T, 'elicitation'> & { elicitation?: Record<string, unknown> }
68+
: T;
69+
70+
type FixSpecInitializeRequestParams<T> = T extends { capabilities: infer C }
71+
? Omit<T, 'capabilities'> & { capabilities: FixSpecClientCapabilities<C> }
72+
: T;
73+
74+
type FixSpecInitializeRequest<T> = T extends { params: infer P } ? Omit<T, 'params'> & { params: FixSpecInitializeRequestParams<P> } : T;
75+
76+
type FixSpecClientRequest<T> = T extends { params: infer P } ? Omit<T, 'params'> & { params: FixSpecInitializeRequestParams<P> } : T;
77+
6578
const sdkTypeChecks = {
6679
RequestParams: (sdk: SDKTypes.RequestParams, spec: SpecTypes.RequestParams) => {
6780
sdk = spec;
@@ -75,7 +88,10 @@ const sdkTypeChecks = {
7588
sdk = spec;
7689
spec = sdk;
7790
},
78-
InitializeRequestParams: (sdk: SDKTypes.InitializeRequestParams, spec: SpecTypes.InitializeRequestParams) => {
91+
InitializeRequestParams: (
92+
sdk: SDKTypes.InitializeRequestParams,
93+
spec: FixSpecInitializeRequestParams<SpecTypes.InitializeRequestParams>
94+
) => {
7995
sdk = spec;
8096
spec = sdk;
8197
},
@@ -480,23 +496,29 @@ const sdkTypeChecks = {
480496
sdk = spec;
481497
spec = sdk;
482498
},
483-
InitializeRequest: (sdk: WithJSONRPCRequest<SDKTypes.InitializeRequest>, spec: SpecTypes.InitializeRequest) => {
499+
InitializeRequest: (
500+
sdk: WithJSONRPCRequest<SDKTypes.InitializeRequest>,
501+
spec: FixSpecInitializeRequest<SpecTypes.InitializeRequest>
502+
) => {
484503
sdk = spec;
485504
spec = sdk;
486505
},
487506
InitializeResult: (sdk: SDKTypes.InitializeResult, spec: SpecTypes.InitializeResult) => {
488507
sdk = spec;
489508
spec = sdk;
490509
},
491-
ClientCapabilities: (sdk: SDKTypes.ClientCapabilities, spec: SpecTypes.ClientCapabilities) => {
510+
ClientCapabilities: (sdk: SDKTypes.ClientCapabilities, spec: FixSpecClientCapabilities<SpecTypes.ClientCapabilities>) => {
492511
sdk = spec;
493512
spec = sdk;
494513
},
495514
ServerCapabilities: (sdk: SDKTypes.ServerCapabilities, spec: SpecTypes.ServerCapabilities) => {
496515
sdk = spec;
497516
spec = sdk;
498517
},
499-
ClientRequest: (sdk: RemovePassthrough<WithJSONRPCRequest<SDKTypes.ClientRequest>>, spec: SpecTypes.ClientRequest) => {
518+
ClientRequest: (
519+
sdk: RemovePassthrough<WithJSONRPCRequest<SDKTypes.ClientRequest>>,
520+
spec: FixSpecClientRequest<SpecTypes.ClientRequest>
521+
) => {
500522
sdk = spec;
501523
spec = sdk;
502524
},

src/types.ts

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,17 @@ export const ClientCapabilitiesSchema = z.object({
286286
/**
287287
* Present if the client supports eliciting user input.
288288
*/
289-
elicitation: AssertObjectSchema.optional(),
289+
elicitation: z.intersection(
290+
z
291+
.object({
292+
/**
293+
* Whether the client should apply defaults to the user input.
294+
*/
295+
applyDefaults: z.boolean().optional()
296+
})
297+
.optional(),
298+
z.record(z.string(), z.unknown()).optional()
299+
),
290300
/**
291301
* Present if the client supports listing roots.
292302
*/
@@ -1198,49 +1208,52 @@ export const CreateMessageResultSchema = ResultSchema.extend({
11981208
*/
11991209
export const BooleanSchemaSchema = z.object({
12001210
type: z.literal('boolean'),
1201-
title: z.optional(z.string()),
1202-
description: z.optional(z.string()),
1203-
default: z.optional(z.boolean())
1211+
title: z.string().optional(),
1212+
description: z.string().optional(),
1213+
default: z.boolean().optional()
12041214
});
12051215

12061216
/**
12071217
* Primitive schema definition for string fields.
12081218
*/
12091219
export const StringSchemaSchema = z.object({
12101220
type: z.literal('string'),
1211-
title: z.optional(z.string()),
1212-
description: z.optional(z.string()),
1213-
minLength: z.optional(z.number()),
1214-
maxLength: z.optional(z.number()),
1215-
format: z.optional(z.enum(['email', 'uri', 'date', 'date-time']))
1221+
title: z.string().optional(),
1222+
description: z.string().optional(),
1223+
minLength: z.number().optional(),
1224+
maxLength: z.number().optional(),
1225+
format: z.enum(['email', 'uri', 'date', 'date-time']).optional(),
1226+
default: z.string().optional()
12161227
});
12171228

12181229
/**
12191230
* Primitive schema definition for number fields.
12201231
*/
12211232
export const NumberSchemaSchema = z.object({
12221233
type: z.enum(['number', 'integer']),
1223-
title: z.optional(z.string()),
1224-
description: z.optional(z.string()),
1225-
minimum: z.optional(z.number()),
1226-
maximum: z.optional(z.number())
1234+
title: z.string().optional(),
1235+
description: z.string().optional(),
1236+
minimum: z.number().optional(),
1237+
maximum: z.number().optional(),
1238+
default: z.number().optional()
12271239
});
12281240

12291241
/**
12301242
* Primitive schema definition for enum fields.
12311243
*/
12321244
export const EnumSchemaSchema = z.object({
12331245
type: z.literal('string'),
1234-
title: z.optional(z.string()),
1235-
description: z.optional(z.string()),
1246+
title: z.string().optional(),
1247+
description: z.string().optional(),
12361248
enum: z.array(z.string()),
1237-
enumNames: z.optional(z.array(z.string()))
1249+
enumNames: z.array(z.string()).optional(),
1250+
default: z.string().optional()
12381251
});
12391252

12401253
/**
12411254
* Union of all primitive schema definitions.
12421255
*/
1243-
export const PrimitiveSchemaDefinitionSchema = z.union([BooleanSchemaSchema, StringSchemaSchema, NumberSchemaSchema, EnumSchemaSchema]);
1256+
export const PrimitiveSchemaDefinitionSchema = z.union([EnumSchemaSchema, BooleanSchemaSchema, StringSchemaSchema, NumberSchemaSchema]);
12441257

12451258
/**
12461259
* Parameters for an `elicitation/create` request.

0 commit comments

Comments
 (0)