5
5
"encoding/json"
6
6
"errors"
7
7
"fmt"
8
+ "slices"
8
9
"sync"
9
10
"sync/atomic"
10
11
@@ -22,6 +23,8 @@ type Client struct {
22
23
requestID atomic.Int64
23
24
clientCapabilities mcp.ClientCapabilities
24
25
serverCapabilities mcp.ServerCapabilities
26
+ protocolVersion string
27
+ samplingHandler SamplingHandler
25
28
}
26
29
27
30
type ClientOption func (* Client )
@@ -33,6 +36,21 @@ func WithClientCapabilities(capabilities mcp.ClientCapabilities) ClientOption {
33
36
}
34
37
}
35
38
39
+ // WithSamplingHandler sets the sampling handler for the client.
40
+ // When set, the client will declare sampling capability during initialization.
41
+ func WithSamplingHandler (handler SamplingHandler ) ClientOption {
42
+ return func (c * Client ) {
43
+ c .samplingHandler = handler
44
+ }
45
+ }
46
+
47
+ // WithSession assumes a MCP Session has already been initialized
48
+ func WithSession () ClientOption {
49
+ return func (c * Client ) {
50
+ c .initialized = true
51
+ }
52
+ }
53
+
36
54
// NewClient creates a new MCP client with the given transport.
37
55
// Usage:
38
56
//
@@ -71,6 +89,12 @@ func (c *Client) Start(ctx context.Context) error {
71
89
handler (notification )
72
90
}
73
91
})
92
+
93
+ // Set up request handler for bidirectional communication (e.g., sampling)
94
+ if bidirectional , ok := c .transport .(transport.BidirectionalInterface ); ok {
95
+ bidirectional .SetRequestHandler (c .handleIncomingRequest )
96
+ }
97
+
74
98
return nil
75
99
}
76
100
@@ -111,7 +135,7 @@ func (c *Client) sendRequest(
111
135
112
136
response , err := c .transport .SendRequest (ctx , request )
113
137
if err != nil {
114
- return nil , fmt . Errorf ( "transport error: %w" , err )
138
+ return nil , transport . NewError ( err )
115
139
}
116
140
117
141
if response .Error != nil {
@@ -127,6 +151,12 @@ func (c *Client) Initialize(
127
151
ctx context.Context ,
128
152
request mcp.InitializeRequest ,
129
153
) (* mcp.InitializeResult , error ) {
154
+ // Merge client capabilities with sampling capability if handler is configured
155
+ capabilities := request .Params .Capabilities
156
+ if c .samplingHandler != nil {
157
+ capabilities .Sampling = & struct {}{}
158
+ }
159
+
130
160
// Ensure we send a params object with all required fields
131
161
params := struct {
132
162
ProtocolVersion string `json:"protocolVersion"`
@@ -135,7 +165,7 @@ func (c *Client) Initialize(
135
165
}{
136
166
ProtocolVersion : request .Params .ProtocolVersion ,
137
167
ClientInfo : request .Params .ClientInfo ,
138
- Capabilities : request . Params . Capabilities , // Will be empty struct if not set
168
+ Capabilities : capabilities ,
139
169
}
140
170
141
171
response , err := c .sendRequest (ctx , "initialize" , params )
@@ -148,8 +178,19 @@ func (c *Client) Initialize(
148
178
return nil , fmt .Errorf ("failed to unmarshal response: %w" , err )
149
179
}
150
180
151
- // Store serverCapabilities
181
+ // Validate protocol version
182
+ if ! slices .Contains (mcp .ValidProtocolVersions , result .ProtocolVersion ) {
183
+ return nil , mcp.UnsupportedProtocolVersionError {Version : result .ProtocolVersion }
184
+ }
185
+
186
+ // Store serverCapabilities and protocol version
152
187
c .serverCapabilities = result .Capabilities
188
+ c .protocolVersion = result .ProtocolVersion
189
+
190
+ // Set protocol version on HTTP transports
191
+ if httpConn , ok := c .transport .(transport.HTTPConnection ); ok {
192
+ httpConn .SetProtocolVersion (result .ProtocolVersion )
193
+ }
153
194
154
195
// Send initialized notification
155
196
notification := mcp.JSONRPCNotification {
@@ -398,6 +439,64 @@ func (c *Client) Complete(
398
439
return & result , nil
399
440
}
400
441
442
+ // handleIncomingRequest processes incoming requests from the server.
443
+ // This is the main entry point for server-to-client requests like sampling.
444
+ func (c * Client ) handleIncomingRequest (ctx context.Context , request transport.JSONRPCRequest ) (* transport.JSONRPCResponse , error ) {
445
+ switch request .Method {
446
+ case string (mcp .MethodSamplingCreateMessage ):
447
+ return c .handleSamplingRequestTransport (ctx , request )
448
+ default :
449
+ return nil , fmt .Errorf ("unsupported request method: %s" , request .Method )
450
+ }
451
+ }
452
+
453
+ // handleSamplingRequestTransport handles sampling requests at the transport level.
454
+ func (c * Client ) handleSamplingRequestTransport (ctx context.Context , request transport.JSONRPCRequest ) (* transport.JSONRPCResponse , error ) {
455
+ if c .samplingHandler == nil {
456
+ return nil , fmt .Errorf ("no sampling handler configured" )
457
+ }
458
+
459
+ // Parse the request parameters
460
+ var params mcp.CreateMessageParams
461
+ if request .Params != nil {
462
+ paramsBytes , err := json .Marshal (request .Params )
463
+ if err != nil {
464
+ return nil , fmt .Errorf ("failed to marshal params: %w" , err )
465
+ }
466
+ if err := json .Unmarshal (paramsBytes , & params ); err != nil {
467
+ return nil , fmt .Errorf ("failed to unmarshal params: %w" , err )
468
+ }
469
+ }
470
+
471
+ // Create the MCP request
472
+ mcpRequest := mcp.CreateMessageRequest {
473
+ Request : mcp.Request {
474
+ Method : string (mcp .MethodSamplingCreateMessage ),
475
+ },
476
+ CreateMessageParams : params ,
477
+ }
478
+
479
+ // Call the sampling handler
480
+ result , err := c .samplingHandler .CreateMessage (ctx , mcpRequest )
481
+ if err != nil {
482
+ return nil , err
483
+ }
484
+
485
+ // Marshal the result
486
+ resultBytes , err := json .Marshal (result )
487
+ if err != nil {
488
+ return nil , fmt .Errorf ("failed to marshal result: %w" , err )
489
+ }
490
+
491
+ // Create the transport response
492
+ response := & transport.JSONRPCResponse {
493
+ JSONRPC : mcp .JSONRPC_VERSION ,
494
+ ID : request .ID ,
495
+ Result : json .RawMessage (resultBytes ),
496
+ }
497
+
498
+ return response , nil
499
+ }
401
500
func listByPage [T any ](
402
501
ctx context.Context ,
403
502
client * Client ,
@@ -432,3 +531,17 @@ func (c *Client) GetServerCapabilities() mcp.ServerCapabilities {
432
531
func (c * Client ) GetClientCapabilities () mcp.ClientCapabilities {
433
532
return c .clientCapabilities
434
533
}
534
+
535
+ // GetSessionId returns the session ID of the transport.
536
+ // If the transport does not support sessions, it returns an empty string.
537
+ func (c * Client ) GetSessionId () string {
538
+ if c .transport == nil {
539
+ return ""
540
+ }
541
+ return c .transport .GetSessionId ()
542
+ }
543
+
544
+ // IsInitialized returns true if the client has been initialized.
545
+ func (c * Client ) IsInitialized () bool {
546
+ return c .initialized
547
+ }
0 commit comments