Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 80 additions & 1 deletion client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type Client struct {
requestID atomic.Int64
clientCapabilities mcp.ClientCapabilities
serverCapabilities mcp.ServerCapabilities
samplingHandler SamplingHandler
}

type ClientOption func(*Client)
Expand All @@ -33,6 +34,14 @@ func WithClientCapabilities(capabilities mcp.ClientCapabilities) ClientOption {
}
}

// WithSamplingHandler sets the sampling handler for the client.
// When set, the client will declare sampling capability during initialization.
func WithSamplingHandler(handler SamplingHandler) ClientOption {
return func(c *Client) {
c.samplingHandler = handler
}
}

// WithSession assumes a MCP Session has already been initialized
func WithSession() ClientOption {
return func(c *Client) {
Expand Down Expand Up @@ -78,6 +87,12 @@ func (c *Client) Start(ctx context.Context) error {
handler(notification)
}
})

// Set up request handler for bidirectional communication (e.g., sampling)
if bidirectional, ok := c.transport.(transport.BidirectionalInterface); ok {
bidirectional.SetRequestHandler(c.handleIncomingRequest)
}

return nil
}

Expand Down Expand Up @@ -134,6 +149,12 @@ func (c *Client) Initialize(
ctx context.Context,
request mcp.InitializeRequest,
) (*mcp.InitializeResult, error) {
// Merge client capabilities with sampling capability if handler is configured
capabilities := request.Params.Capabilities
if c.samplingHandler != nil {
capabilities.Sampling = &struct{}{}
}

// Ensure we send a params object with all required fields
params := struct {
ProtocolVersion string `json:"protocolVersion"`
Expand All @@ -142,7 +163,7 @@ func (c *Client) Initialize(
}{
ProtocolVersion: request.Params.ProtocolVersion,
ClientInfo: request.Params.ClientInfo,
Capabilities: request.Params.Capabilities, // Will be empty struct if not set
Capabilities: capabilities,
}

response, err := c.sendRequest(ctx, "initialize", params)
Expand Down Expand Up @@ -405,6 +426,64 @@ func (c *Client) Complete(
return &result, nil
}

// handleIncomingRequest processes incoming requests from the server.
// This is the main entry point for server-to-client requests like sampling.
func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
switch request.Method {
case string(mcp.MethodSamplingCreateMessage):
return c.handleSamplingRequestTransport(ctx, request)
default:
return nil, fmt.Errorf("unsupported request method: %s", request.Method)
}
}

// handleSamplingRequestTransport handles sampling requests at the transport level.
func (c *Client) handleSamplingRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
if c.samplingHandler == nil {
return nil, fmt.Errorf("no sampling handler configured")
}

// Parse the request parameters
var params mcp.CreateMessageParams
if request.Params != nil {
paramsBytes, err := json.Marshal(request.Params)
if err != nil {
return nil, fmt.Errorf("failed to marshal params: %w", err)
}
if err := json.Unmarshal(paramsBytes, &params); err != nil {
return nil, fmt.Errorf("failed to unmarshal params: %w", err)
}
}

// Create the MCP request
mcpRequest := mcp.CreateMessageRequest{
Request: mcp.Request{
Method: string(mcp.MethodSamplingCreateMessage),
},
CreateMessageParams: params,
}

// Call the sampling handler
result, err := c.samplingHandler.CreateMessage(ctx, mcpRequest)
if err != nil {
return nil, err
}

// Marshal the result
resultBytes, err := json.Marshal(result)
if err != nil {
return nil, fmt.Errorf("failed to marshal result: %w", err)
}

// Create the transport response
response := &transport.JSONRPCResponse{
JSONRPC: mcp.JSONRPC_VERSION,
ID: request.ID,
Result: json.RawMessage(resultBytes),
}

return response, nil
}
func listByPage[T any](
ctx context.Context,
client *Client,
Expand Down
20 changes: 20 additions & 0 deletions client/sampling.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package client

import (
"context"

"github.com/mark3labs/mcp-go/mcp"
)

// SamplingHandler defines the interface for handling sampling requests from servers.
// Clients can implement this interface to provide LLM sampling capabilities to servers.
type SamplingHandler interface {
// CreateMessage handles a sampling request from the server and returns the generated message.
// The implementation should:
// 1. Validate the request parameters
// 2. Optionally prompt the user for approval (human-in-the-loop)
// 3. Select an appropriate model based on preferences
// 4. Generate the response using the selected model
// 5. Return the result with model information and stop reason
CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error)
}
Loading