Skip to content

Commit 557a80c

Browse files
ezynda3opencode
andcommitted
feat: implement protocol version negotiation
Implement protocol version negotiation following the TypeScript SDK approach: - Update LATEST_PROTOCOL_VERSION to 2025-06-18 - Add client-side validation of server protocol version - Return UnsupportedProtocolVersionError for incompatible versions - Add Mcp-Protocol-Version header support for HTTP transports - Implement SetProtocolVersion method on HTTP connections - Add comprehensive tests for protocol negotiation This ensures both client and server agree on a mutually supported protocol version, preventing compatibility issues. 🤖 Generated with [opencode](https://opencode.ai) Co-Authored-By: opencode <[email protected]>
1 parent 7c38b56 commit 557a80c

File tree

11 files changed

+320
-68
lines changed

11 files changed

+320
-68
lines changed

client/client.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/json"
66
"errors"
77
"fmt"
8+
"slices"
89
"sync"
910
"sync/atomic"
1011

@@ -22,6 +23,7 @@ type Client struct {
2223
requestID atomic.Int64
2324
clientCapabilities mcp.ClientCapabilities
2425
serverCapabilities mcp.ServerCapabilities
26+
protocolVersion string
2527
samplingHandler SamplingHandler
2628
}
2729

@@ -176,8 +178,19 @@ func (c *Client) Initialize(
176178
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
177179
}
178180

179-
// Store serverCapabilities
181+
// Validate protocol version
182+
if !slices.Contains(mcp.ValidProtocolVersions, result.ProtocolVersion) {
183+
return nil, mcp.UnsupportedProtocolVersionError{Version: result.ProtocolVersion}
184+
}
185+
186+
// Store serverCapabilities and protocol version
180187
c.serverCapabilities = result.Capabilities
188+
c.protocolVersion = result.ProtocolVersion
189+
190+
// Set protocol version on HTTP transports
191+
if httpConn, ok := c.transport.(transport.HTTPConnection); ok {
192+
httpConn.SetProtocolVersion(result.ProtocolVersion)
193+
}
181194

182195
// Send initialized notification
183196
notification := mcp.JSONRPCNotification{

client/protocol_negotiation_test.go

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
package client
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"strings"
8+
"testing"
9+
10+
"github.com/mark3labs/mcp-go/client/transport"
11+
"github.com/mark3labs/mcp-go/mcp"
12+
)
13+
14+
// mockProtocolTransport implements transport.Interface for testing protocol negotiation
15+
type mockProtocolTransport struct {
16+
responses map[string]string
17+
notificationHandler func(mcp.JSONRPCNotification)
18+
started bool
19+
closed bool
20+
}
21+
22+
func (m *mockProtocolTransport) Start(ctx context.Context) error {
23+
m.started = true
24+
return nil
25+
}
26+
27+
func (m *mockProtocolTransport) SendRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
28+
responseStr, ok := m.responses[request.Method]
29+
if !ok {
30+
return nil, fmt.Errorf("no mock response for method %s", request.Method)
31+
}
32+
33+
return &transport.JSONRPCResponse{
34+
JSONRPC: "2.0",
35+
ID: request.ID,
36+
Result: json.RawMessage(responseStr),
37+
}, nil
38+
}
39+
40+
func (m *mockProtocolTransport) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
41+
return nil
42+
}
43+
44+
func (m *mockProtocolTransport) SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) {
45+
m.notificationHandler = handler
46+
}
47+
48+
func (m *mockProtocolTransport) Close() error {
49+
m.closed = true
50+
return nil
51+
}
52+
53+
func (m *mockProtocolTransport) GetSessionId() string {
54+
return "mock-session"
55+
}
56+
57+
func TestProtocolVersionNegotiation(t *testing.T) {
58+
tests := []struct {
59+
name string
60+
serverVersion string
61+
expectError bool
62+
errorContains string
63+
}{
64+
{
65+
name: "supported latest version",
66+
serverVersion: mcp.LATEST_PROTOCOL_VERSION,
67+
expectError: false,
68+
},
69+
{
70+
name: "supported older version 2025-03-26",
71+
serverVersion: "2025-03-26",
72+
expectError: false,
73+
},
74+
{
75+
name: "supported older version 2024-11-05",
76+
serverVersion: "2024-11-05",
77+
expectError: false,
78+
},
79+
{
80+
name: "unsupported version",
81+
serverVersion: "2023-01-01",
82+
expectError: true,
83+
errorContains: "unsupported protocol version",
84+
},
85+
{
86+
name: "unsupported future version",
87+
serverVersion: "2030-01-01",
88+
expectError: true,
89+
errorContains: "unsupported protocol version",
90+
},
91+
}
92+
93+
for _, tt := range tests {
94+
t.Run(tt.name, func(t *testing.T) {
95+
// Create mock transport that returns specific version
96+
mockTransport := &mockProtocolTransport{
97+
responses: map[string]string{
98+
"initialize": fmt.Sprintf(`{
99+
"protocolVersion": "%s",
100+
"capabilities": {},
101+
"serverInfo": {"name": "test", "version": "1.0"}
102+
}`, tt.serverVersion),
103+
},
104+
}
105+
106+
client := NewClient(mockTransport)
107+
108+
_, err := client.Initialize(context.Background(), mcp.InitializeRequest{
109+
Params: mcp.InitializeParams{
110+
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
111+
ClientInfo: mcp.Implementation{Name: "test-client", Version: "1.0"},
112+
Capabilities: mcp.ClientCapabilities{},
113+
},
114+
})
115+
116+
if tt.expectError {
117+
if err == nil {
118+
t.Errorf("expected error but got none")
119+
} else if !strings.Contains(err.Error(), tt.errorContains) {
120+
t.Errorf("expected error containing %q, got %q", tt.errorContains, err.Error())
121+
}
122+
// Verify it's the correct error type
123+
if !mcp.IsUnsupportedProtocolVersion(err) {
124+
t.Errorf("expected UnsupportedProtocolVersionError, got %T", err)
125+
}
126+
} else {
127+
if err != nil {
128+
t.Errorf("unexpected error: %v", err)
129+
}
130+
// Verify the protocol version was stored
131+
if client.protocolVersion != tt.serverVersion {
132+
t.Errorf("expected protocol version %q, got %q", tt.serverVersion, client.protocolVersion)
133+
}
134+
}
135+
})
136+
}
137+
}
138+
139+
// mockHTTPTransport implements both transport.Interface and transport.HTTPConnection
140+
type mockHTTPTransport struct {
141+
mockProtocolTransport
142+
protocolVersion string
143+
}
144+
145+
func (m *mockHTTPTransport) SetProtocolVersion(version string) {
146+
m.protocolVersion = version
147+
}
148+
149+
func TestProtocolVersionHeaderSetting(t *testing.T) {
150+
// Create mock HTTP transport
151+
mockTransport := &mockHTTPTransport{
152+
mockProtocolTransport: mockProtocolTransport{
153+
responses: map[string]string{
154+
"initialize": fmt.Sprintf(`{
155+
"protocolVersion": "%s",
156+
"capabilities": {},
157+
"serverInfo": {"name": "test", "version": "1.0"}
158+
}`, mcp.LATEST_PROTOCOL_VERSION),
159+
},
160+
},
161+
}
162+
163+
client := NewClient(mockTransport)
164+
165+
_, err := client.Initialize(context.Background(), mcp.InitializeRequest{
166+
Params: mcp.InitializeParams{
167+
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
168+
ClientInfo: mcp.Implementation{Name: "test-client", Version: "1.0"},
169+
Capabilities: mcp.ClientCapabilities{},
170+
},
171+
})
172+
173+
if err != nil {
174+
t.Fatalf("unexpected error: %v", err)
175+
}
176+
177+
// Verify SetProtocolVersion was called on HTTP transport
178+
if mockTransport.protocolVersion != mcp.LATEST_PROTOCOL_VERSION {
179+
t.Errorf("expected SetProtocolVersion to be called with %q, got %q",
180+
mcp.LATEST_PROTOCOL_VERSION, mockTransport.protocolVersion)
181+
}
182+
}

client/stdio_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ func TestStdioMCPClient(t *testing.T) {
9393
defer cancel()
9494

9595
request := mcp.InitializeRequest{}
96-
request.Params.ProtocolVersion = "1.0"
96+
request.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
9797
request.Params.ClientInfo = mcp.Implementation{
9898
Name: "test-client",
9999
Version: "1.0.0",

client/transport/interface.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@ type BidirectionalInterface interface {
4747
SetRequestHandler(handler RequestHandler)
4848
}
4949

50+
// HTTPConnection is a Transport that runs over HTTP and supports
51+
// protocol version headers.
52+
type HTTPConnection interface {
53+
Interface
54+
SetProtocolVersion(version string)
55+
}
56+
5057
type JSONRPCRequest struct {
5158
JSONRPC string `json:"jsonrpc"`
5259
ID mcp.RequestId `json:"id"`

client/transport/sse.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ type SSE struct {
3737
started atomic.Bool
3838
closed atomic.Bool
3939
cancelSSEStream context.CancelFunc
40+
protocolVersion atomic.Value // string
4041

4142
// OAuth support
4243
oauthHandler *OAuthHandler
@@ -324,6 +325,12 @@ func (c *SSE) SendRequest(
324325

325326
// Set headers
326327
req.Header.Set("Content-Type", "application/json")
328+
// Set protocol version header if negotiated
329+
if v := c.protocolVersion.Load(); v != nil {
330+
if version, ok := v.(string); ok && version != "" {
331+
req.Header.Set(headerKeyProtocolVersion, version)
332+
}
333+
}
327334
for k, v := range c.headers {
328335
req.Header.Set(k, v)
329336
}
@@ -434,6 +441,11 @@ func (c *SSE) GetSessionId() string {
434441
return ""
435442
}
436443

444+
// SetProtocolVersion sets the negotiated protocol version for this connection.
445+
func (c *SSE) SetProtocolVersion(version string) {
446+
c.protocolVersion.Store(version)
447+
}
448+
437449
// SendNotification sends a JSON-RPC notification to the server without expecting a response.
438450
func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
439451
if c.endpoint == nil {
@@ -456,6 +468,12 @@ func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNoti
456468
}
457469

458470
req.Header.Set("Content-Type", "application/json")
471+
// Set protocol version header if negotiated
472+
if v := c.protocolVersion.Load(); v != nil {
473+
if version, ok := v.(string); ok && version != "" {
474+
req.Header.Set(headerKeyProtocolVersion, version)
475+
}
476+
}
459477
// Set custom HTTP headers
460478
for k, v := range c.headers {
461479
req.Header.Set(k, v)

client/transport/streamable_http.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ type StreamableHTTP struct {
102102
logger util.Logger
103103
getListeningEnabled bool
104104

105-
sessionID atomic.Value // string
105+
sessionID atomic.Value // string
106+
protocolVersion atomic.Value // string
106107

107108
initialized chan struct{}
108109
initializedOnce sync.Once
@@ -207,8 +208,14 @@ func (c *StreamableHTTP) Close() error {
207208
return nil
208209
}
209210

211+
// SetProtocolVersion sets the negotiated protocol version for this connection.
212+
func (c *StreamableHTTP) SetProtocolVersion(version string) {
213+
c.protocolVersion.Store(version)
214+
}
215+
210216
const (
211-
headerKeySessionID = "Mcp-Session-Id"
217+
headerKeySessionID = "Mcp-Session-Id"
218+
headerKeyProtocolVersion = "Mcp-Protocol-Version"
212219
)
213220

214221
// ErrOAuthAuthorizationRequired is a sentinel error for OAuth authorization required
@@ -337,6 +344,12 @@ func (c *StreamableHTTP) sendHTTP(
337344
if sessionID != "" {
338345
req.Header.Set(headerKeySessionID, sessionID)
339346
}
347+
// Set protocol version header if negotiated
348+
if v := c.protocolVersion.Load(); v != nil {
349+
if version, ok := v.(string); ok && version != "" {
350+
req.Header.Set(headerKeyProtocolVersion, version)
351+
}
352+
}
340353
for k, v := range c.headers {
341354
req.Header.Set(k, v)
342355
}

mcp/errors.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package mcp
2+
3+
import "fmt"
4+
5+
// UnsupportedProtocolVersionError is returned when the server responds with
6+
// a protocol version that the client doesn't support.
7+
type UnsupportedProtocolVersionError struct {
8+
Version string
9+
}
10+
11+
func (e UnsupportedProtocolVersionError) Error() string {
12+
return fmt.Sprintf("unsupported protocol version: %q", e.Version)
13+
}
14+
15+
// IsUnsupportedProtocolVersion checks if an error is an UnsupportedProtocolVersionError
16+
func IsUnsupportedProtocolVersion(err error) bool {
17+
_, ok := err.(UnsupportedProtocolVersionError)
18+
return ok
19+
}

mcp/types.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,13 @@ func (t *URITemplate) UnmarshalJSON(data []byte) error {
9797
type JSONRPCMessage any
9898

9999
// LATEST_PROTOCOL_VERSION is the most recent version of the MCP protocol.
100-
const LATEST_PROTOCOL_VERSION = "2025-03-26"
100+
const LATEST_PROTOCOL_VERSION = "2025-06-18"
101101

102102
// ValidProtocolVersions lists all known valid MCP protocol versions.
103103
var ValidProtocolVersions = []string{
104-
"2024-11-05",
105104
LATEST_PROTOCOL_VERSION,
105+
"2025-03-26",
106+
"2024-11-05",
106107
}
107108

108109
// JSONRPC_VERSION is the version of JSON-RPC used by MCP.

server/streamable_http.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,8 @@ func (s *StreamableHTTPServer) Shutdown(ctx context.Context) error {
208208
// --- internal methods ---
209209

210210
const (
211-
headerKeySessionID = "Mcp-Session-Id"
211+
headerKeySessionID = "Mcp-Session-Id"
212+
headerKeyProtocolVersion = "Mcp-Protocol-Version"
212213
)
213214

214215
func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request) {

0 commit comments

Comments
 (0)