Skip to content

Commit b7a18c5

Browse files
committed
feat: implement MCP elicitation support (#413)
* Add ElicitationRequest, ElicitationResult, and related types to mcp/types.go * Implement server-side RequestElicitation method with session support * Add client-side ElicitationHandler interface and request handling * Implement elicitation in stdio and in-process transports * Add comprehensive tests following sampling patterns * Create elicitation example demonstrating usage patterns * Use 'Elicitation' prefix for type names to maintain clarity
1 parent 35ebaa5 commit b7a18c5

File tree

13 files changed

+1291
-33
lines changed

13 files changed

+1291
-33
lines changed

client/client.go

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ type Client struct {
2525
serverCapabilities mcp.ServerCapabilities
2626
protocolVersion string
2727
samplingHandler SamplingHandler
28+
elicitationHandler ElicitationHandler
2829
}
2930

3031
type ClientOption func(*Client)
@@ -44,6 +45,14 @@ func WithSamplingHandler(handler SamplingHandler) ClientOption {
4445
}
4546
}
4647

48+
// WithElicitationHandler sets the elicitation handler for the client.
49+
// When set, the client will declare elicitation capability during initialization.
50+
func WithElicitationHandler(handler ElicitationHandler) ClientOption {
51+
return func(c *Client) {
52+
c.elicitationHandler = handler
53+
}
54+
}
55+
4756
// WithSession assumes a MCP Session has already been initialized
4857
func WithSession() ClientOption {
4958
return func(c *Client) {
@@ -167,6 +176,10 @@ func (c *Client) Initialize(
167176
if c.samplingHandler != nil {
168177
capabilities.Sampling = &struct{}{}
169178
}
179+
// Add elicitation capability if handler is configured
180+
if c.elicitationHandler != nil {
181+
capabilities.Elicitation = &struct{}{}
182+
}
170183

171184
// Ensure we send a params object with all required fields
172185
params := struct {
@@ -451,11 +464,13 @@ func (c *Client) Complete(
451464
}
452465

453466
// handleIncomingRequest processes incoming requests from the server.
454-
// This is the main entry point for server-to-client requests like sampling.
467+
// This is the main entry point for server-to-client requests like sampling and elicitation.
455468
func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
456469
switch request.Method {
457470
case string(mcp.MethodSamplingCreateMessage):
458471
return c.handleSamplingRequestTransport(ctx, request)
472+
case string(mcp.MethodElicitationCreate):
473+
return c.handleElicitationRequestTransport(ctx, request)
459474
default:
460475
return nil, fmt.Errorf("unsupported request method: %s", request.Method)
461476
}
@@ -508,6 +523,55 @@ func (c *Client) handleSamplingRequestTransport(ctx context.Context, request tra
508523

509524
return response, nil
510525
}
526+
527+
// handleElicitationRequestTransport handles elicitation requests at the transport level.
528+
func (c *Client) handleElicitationRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
529+
if c.elicitationHandler == nil {
530+
return nil, fmt.Errorf("no elicitation handler configured")
531+
}
532+
533+
// Parse the request parameters
534+
var params mcp.ElicitationParams
535+
if request.Params != nil {
536+
paramsBytes, err := json.Marshal(request.Params)
537+
if err != nil {
538+
return nil, fmt.Errorf("failed to marshal params: %w", err)
539+
}
540+
if err := json.Unmarshal(paramsBytes, &params); err != nil {
541+
return nil, fmt.Errorf("failed to unmarshal params: %w", err)
542+
}
543+
}
544+
545+
// Create the MCP request
546+
mcpRequest := mcp.ElicitationRequest{
547+
Request: mcp.Request{
548+
Method: string(mcp.MethodElicitationCreate),
549+
},
550+
Params: params,
551+
}
552+
553+
// Call the elicitation handler
554+
result, err := c.elicitationHandler.Elicit(ctx, mcpRequest)
555+
if err != nil {
556+
return nil, err
557+
}
558+
559+
// Marshal the result
560+
resultBytes, err := json.Marshal(result)
561+
if err != nil {
562+
return nil, fmt.Errorf("failed to marshal result: %w", err)
563+
}
564+
565+
// Create the transport response
566+
response := &transport.JSONRPCResponse{
567+
JSONRPC: mcp.JSONRPC_VERSION,
568+
ID: request.ID,
569+
Result: json.RawMessage(resultBytes),
570+
}
571+
572+
return response, nil
573+
}
574+
511575
func listByPage[T any](
512576
ctx context.Context,
513577
client *Client,

client/elicitation.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package client
2+
3+
import (
4+
"context"
5+
6+
"github.com/mark3labs/mcp-go/mcp"
7+
)
8+
9+
// ElicitationHandler defines the interface for handling elicitation requests from servers.
10+
// Clients can implement this interface to request additional information from users.
11+
type ElicitationHandler interface {
12+
// Elicit handles an elicitation request from the server and returns the user's response.
13+
// The implementation should:
14+
// 1. Present the request message to the user
15+
// 2. Validate input against the requested schema
16+
// 3. Allow the user to accept, decline, or cancel
17+
// 4. Return the appropriate response
18+
Elicit(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error)
19+
}

client/elicitation_test.go

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
package client
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"testing"
8+
9+
"github.com/mark3labs/mcp-go/client/transport"
10+
"github.com/mark3labs/mcp-go/mcp"
11+
)
12+
13+
// mockElicitationHandler implements ElicitationHandler for testing
14+
type mockElicitationHandler struct {
15+
result *mcp.ElicitationResult
16+
err error
17+
}
18+
19+
func (m *mockElicitationHandler) Elicit(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) {
20+
if m.err != nil {
21+
return nil, m.err
22+
}
23+
return m.result, nil
24+
}
25+
26+
func TestClient_HandleElicitationRequest(t *testing.T) {
27+
tests := []struct {
28+
name string
29+
handler ElicitationHandler
30+
expectedError string
31+
}{
32+
{
33+
name: "no handler configured",
34+
handler: nil,
35+
expectedError: "no elicitation handler configured",
36+
},
37+
{
38+
name: "successful elicitation - accept",
39+
handler: &mockElicitationHandler{
40+
result: &mcp.ElicitationResult{
41+
Response: mcp.ElicitationResponse{
42+
Type: mcp.ElicitationResponseTypeAccept,
43+
Value: map[string]interface{}{
44+
"name": "test-project",
45+
"framework": "react",
46+
},
47+
},
48+
},
49+
},
50+
},
51+
{
52+
name: "successful elicitation - decline",
53+
handler: &mockElicitationHandler{
54+
result: &mcp.ElicitationResult{
55+
Response: mcp.ElicitationResponse{
56+
Type: mcp.ElicitationResponseTypeDecline,
57+
},
58+
},
59+
},
60+
},
61+
{
62+
name: "successful elicitation - cancel",
63+
handler: &mockElicitationHandler{
64+
result: &mcp.ElicitationResult{
65+
Response: mcp.ElicitationResponse{
66+
Type: mcp.ElicitationResponseTypeCancel,
67+
},
68+
},
69+
},
70+
},
71+
{
72+
name: "handler returns error",
73+
handler: &mockElicitationHandler{
74+
err: fmt.Errorf("user interaction failed"),
75+
},
76+
expectedError: "user interaction failed",
77+
},
78+
}
79+
80+
for _, tt := range tests {
81+
t.Run(tt.name, func(t *testing.T) {
82+
client := &Client{elicitationHandler: tt.handler}
83+
84+
request := transport.JSONRPCRequest{
85+
ID: mcp.NewRequestId(1),
86+
Method: string(mcp.MethodElicitationCreate),
87+
Params: map[string]interface{}{
88+
"message": "Please provide project details",
89+
"requestedSchema": map[string]interface{}{
90+
"type": "object",
91+
"properties": map[string]interface{}{
92+
"name": map[string]interface{}{"type": "string"},
93+
"framework": map[string]interface{}{"type": "string"},
94+
},
95+
},
96+
},
97+
}
98+
99+
result, err := client.handleElicitationRequestTransport(context.Background(), request)
100+
101+
if tt.expectedError != "" {
102+
if err == nil {
103+
t.Errorf("expected error %q, got nil", tt.expectedError)
104+
} else if err.Error() != tt.expectedError {
105+
t.Errorf("expected error %q, got %q", tt.expectedError, err.Error())
106+
}
107+
} else {
108+
if err != nil {
109+
t.Errorf("unexpected error: %v", err)
110+
}
111+
if result == nil {
112+
t.Error("expected result, got nil")
113+
} else {
114+
// Verify the response is properly formatted
115+
var elicitationResult mcp.ElicitationResult
116+
if err := json.Unmarshal(result.Result, &elicitationResult); err != nil {
117+
t.Errorf("failed to unmarshal result: %v", err)
118+
}
119+
}
120+
}
121+
})
122+
}
123+
}
124+
125+
func TestWithElicitationHandler(t *testing.T) {
126+
handler := &mockElicitationHandler{}
127+
client := &Client{}
128+
129+
option := WithElicitationHandler(handler)
130+
option(client)
131+
132+
if client.elicitationHandler != handler {
133+
t.Error("elicitation handler not set correctly")
134+
}
135+
}
136+
137+
func TestClient_Initialize_WithElicitationHandler(t *testing.T) {
138+
mockTransport := &mockElicitationTransport{
139+
sendRequestFunc: func(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
140+
// Verify that elicitation capability is included
141+
// The client internally converts the typed params to a map for transport
142+
// So we check if we're getting the initialize request
143+
if request.Method != "initialize" {
144+
t.Fatalf("expected initialize method, got %s", request.Method)
145+
}
146+
147+
// Return successful initialization response
148+
result := mcp.InitializeResult{
149+
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
150+
ServerInfo: mcp.Implementation{
151+
Name: "test-server",
152+
Version: "1.0.0",
153+
},
154+
Capabilities: mcp.ServerCapabilities{},
155+
}
156+
157+
resultBytes, _ := json.Marshal(result)
158+
return &transport.JSONRPCResponse{
159+
ID: request.ID,
160+
Result: json.RawMessage(resultBytes),
161+
}, nil
162+
},
163+
sendNotificationFunc: func(ctx context.Context, notification mcp.JSONRPCNotification) error {
164+
return nil
165+
},
166+
}
167+
168+
handler := &mockElicitationHandler{}
169+
client := NewClient(mockTransport, WithElicitationHandler(handler))
170+
171+
err := client.Start(context.Background())
172+
if err != nil {
173+
t.Fatalf("failed to start client: %v", err)
174+
}
175+
176+
_, err = client.Initialize(context.Background(), mcp.InitializeRequest{
177+
Params: mcp.InitializeParams{
178+
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
179+
ClientInfo: mcp.Implementation{
180+
Name: "test-client",
181+
Version: "1.0.0",
182+
},
183+
Capabilities: mcp.ClientCapabilities{},
184+
},
185+
})
186+
187+
if err != nil {
188+
t.Fatalf("failed to initialize: %v", err)
189+
}
190+
}
191+
192+
// mockElicitationTransport implements transport.Interface for testing
193+
type mockElicitationTransport struct {
194+
sendRequestFunc func(context.Context, transport.JSONRPCRequest) (*transport.JSONRPCResponse, error)
195+
sendNotificationFunc func(context.Context, mcp.JSONRPCNotification) error
196+
}
197+
198+
func (m *mockElicitationTransport) Start(ctx context.Context) error {
199+
return nil
200+
}
201+
202+
func (m *mockElicitationTransport) Close() error {
203+
return nil
204+
}
205+
206+
func (m *mockElicitationTransport) SendRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
207+
if m.sendRequestFunc != nil {
208+
return m.sendRequestFunc(ctx, request)
209+
}
210+
return nil, nil
211+
}
212+
213+
func (m *mockElicitationTransport) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
214+
if m.sendNotificationFunc != nil {
215+
return m.sendNotificationFunc(ctx, notification)
216+
}
217+
return nil
218+
}
219+
220+
func (m *mockElicitationTransport) SetNotificationHandler(handler func(mcp.JSONRPCNotification)) {
221+
}
222+
223+
func (m *mockElicitationTransport) GetSessionId() string {
224+
return "mock-session"
225+
}

0 commit comments

Comments
 (0)