Skip to content

Commit decdee8

Browse files
authored
Merge branch 'main' into pottekkat/structured-schema
2 parents 085c67d + a43b104 commit decdee8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+6473
-394
lines changed

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ package main
2424

2525
import (
2626
"context"
27-
"errors"
2827
"fmt"
2928

3029
"github.com/mark3labs/mcp-go/mcp"

client/client.go

Lines changed: 116 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/json"
66
"errors"
77
"fmt"
8+
"slices"
89
"sync"
910
"sync/atomic"
1011

@@ -22,6 +23,8 @@ type Client struct {
2223
requestID atomic.Int64
2324
clientCapabilities mcp.ClientCapabilities
2425
serverCapabilities mcp.ServerCapabilities
26+
protocolVersion string
27+
samplingHandler SamplingHandler
2528
}
2629

2730
type ClientOption func(*Client)
@@ -33,6 +36,21 @@ func WithClientCapabilities(capabilities mcp.ClientCapabilities) ClientOption {
3336
}
3437
}
3538

39+
// WithSamplingHandler sets the sampling handler for the client.
40+
// When set, the client will declare sampling capability during initialization.
41+
func WithSamplingHandler(handler SamplingHandler) ClientOption {
42+
return func(c *Client) {
43+
c.samplingHandler = handler
44+
}
45+
}
46+
47+
// WithSession assumes a MCP Session has already been initialized
48+
func WithSession() ClientOption {
49+
return func(c *Client) {
50+
c.initialized = true
51+
}
52+
}
53+
3654
// NewClient creates a new MCP client with the given transport.
3755
// Usage:
3856
//
@@ -71,6 +89,12 @@ func (c *Client) Start(ctx context.Context) error {
7189
handler(notification)
7290
}
7391
})
92+
93+
// Set up request handler for bidirectional communication (e.g., sampling)
94+
if bidirectional, ok := c.transport.(transport.BidirectionalInterface); ok {
95+
bidirectional.SetRequestHandler(c.handleIncomingRequest)
96+
}
97+
7498
return nil
7599
}
76100

@@ -111,7 +135,7 @@ func (c *Client) sendRequest(
111135

112136
response, err := c.transport.SendRequest(ctx, request)
113137
if err != nil {
114-
return nil, fmt.Errorf("transport error: %w", err)
138+
return nil, transport.NewError(err)
115139
}
116140

117141
if response.Error != nil {
@@ -127,6 +151,12 @@ func (c *Client) Initialize(
127151
ctx context.Context,
128152
request mcp.InitializeRequest,
129153
) (*mcp.InitializeResult, error) {
154+
// Merge client capabilities with sampling capability if handler is configured
155+
capabilities := request.Params.Capabilities
156+
if c.samplingHandler != nil {
157+
capabilities.Sampling = &struct{}{}
158+
}
159+
130160
// Ensure we send a params object with all required fields
131161
params := struct {
132162
ProtocolVersion string `json:"protocolVersion"`
@@ -135,7 +165,7 @@ func (c *Client) Initialize(
135165
}{
136166
ProtocolVersion: request.Params.ProtocolVersion,
137167
ClientInfo: request.Params.ClientInfo,
138-
Capabilities: request.Params.Capabilities, // Will be empty struct if not set
168+
Capabilities: capabilities,
139169
}
140170

141171
response, err := c.sendRequest(ctx, "initialize", params)
@@ -148,8 +178,19 @@ func (c *Client) Initialize(
148178
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
149179
}
150180

151-
// Store serverCapabilities
181+
// Validate protocol version
182+
if !slices.Contains(mcp.ValidProtocolVersions, result.ProtocolVersion) {
183+
return nil, mcp.UnsupportedProtocolVersionError{Version: result.ProtocolVersion}
184+
}
185+
186+
// Store serverCapabilities and protocol version
152187
c.serverCapabilities = result.Capabilities
188+
c.protocolVersion = result.ProtocolVersion
189+
190+
// Set protocol version on HTTP transports
191+
if httpConn, ok := c.transport.(transport.HTTPConnection); ok {
192+
httpConn.SetProtocolVersion(result.ProtocolVersion)
193+
}
153194

154195
// Send initialized notification
155196
notification := mcp.JSONRPCNotification{
@@ -398,6 +439,64 @@ func (c *Client) Complete(
398439
return &result, nil
399440
}
400441

442+
// handleIncomingRequest processes incoming requests from the server.
443+
// This is the main entry point for server-to-client requests like sampling.
444+
func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
445+
switch request.Method {
446+
case string(mcp.MethodSamplingCreateMessage):
447+
return c.handleSamplingRequestTransport(ctx, request)
448+
default:
449+
return nil, fmt.Errorf("unsupported request method: %s", request.Method)
450+
}
451+
}
452+
453+
// handleSamplingRequestTransport handles sampling requests at the transport level.
454+
func (c *Client) handleSamplingRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
455+
if c.samplingHandler == nil {
456+
return nil, fmt.Errorf("no sampling handler configured")
457+
}
458+
459+
// Parse the request parameters
460+
var params mcp.CreateMessageParams
461+
if request.Params != nil {
462+
paramsBytes, err := json.Marshal(request.Params)
463+
if err != nil {
464+
return nil, fmt.Errorf("failed to marshal params: %w", err)
465+
}
466+
if err := json.Unmarshal(paramsBytes, &params); err != nil {
467+
return nil, fmt.Errorf("failed to unmarshal params: %w", err)
468+
}
469+
}
470+
471+
// Create the MCP request
472+
mcpRequest := mcp.CreateMessageRequest{
473+
Request: mcp.Request{
474+
Method: string(mcp.MethodSamplingCreateMessage),
475+
},
476+
CreateMessageParams: params,
477+
}
478+
479+
// Call the sampling handler
480+
result, err := c.samplingHandler.CreateMessage(ctx, mcpRequest)
481+
if err != nil {
482+
return nil, err
483+
}
484+
485+
// Marshal the result
486+
resultBytes, err := json.Marshal(result)
487+
if err != nil {
488+
return nil, fmt.Errorf("failed to marshal result: %w", err)
489+
}
490+
491+
// Create the transport response
492+
response := &transport.JSONRPCResponse{
493+
JSONRPC: mcp.JSONRPC_VERSION,
494+
ID: request.ID,
495+
Result: json.RawMessage(resultBytes),
496+
}
497+
498+
return response, nil
499+
}
401500
func listByPage[T any](
402501
ctx context.Context,
403502
client *Client,
@@ -432,3 +531,17 @@ func (c *Client) GetServerCapabilities() mcp.ServerCapabilities {
432531
func (c *Client) GetClientCapabilities() mcp.ClientCapabilities {
433532
return c.clientCapabilities
434533
}
534+
535+
// GetSessionId returns the session ID of the transport.
536+
// If the transport does not support sessions, it returns an empty string.
537+
func (c *Client) GetSessionId() string {
538+
if c.transport == nil {
539+
return ""
540+
}
541+
return c.transport.GetSessionId()
542+
}
543+
544+
// IsInitialized returns true if the client has been initialized.
545+
func (c *Client) IsInitialized() bool {
546+
return c.initialized
547+
}

client/http.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,10 @@ func NewStreamableHttpClient(baseURL string, options ...transport.StreamableHTTP
1313
if err != nil {
1414
return nil, fmt.Errorf("failed to create SSE transport: %w", err)
1515
}
16-
return NewClient(trans), nil
16+
clientOptions := make([]ClientOption, 0)
17+
sessionID := trans.GetSessionId()
18+
if sessionID != "" {
19+
clientOptions = append(clientOptions, WithSession())
20+
}
21+
return NewClient(trans, clientOptions...), nil
1722
}

0 commit comments

Comments
 (0)