Skip to content

Commit 3297854

Browse files
authored
Merge branch 'main' into dev
2 parents 807d5b7 + c509bf9 commit 3297854

File tree

9 files changed

+557
-157
lines changed

9 files changed

+557
-157
lines changed

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212

1313
[![Tutorial](http://img.youtube.com/vi/qoaeYMrXJH0/0.jpg)](http://www.youtube.com/watch?v=qoaeYMrXJH0 "Tutorial")
1414

15+
<br>
16+
17+
Discuss the SDK on [Discord](https://discord.gg/RqSS2NQVsY)
18+
1519
</div>
1620

1721
```go
@@ -122,6 +126,7 @@ func main() {
122126
"1.0.0",
123127
server.WithResourceCapabilities(true, true),
124128
server.WithLogging(),
129+
server.WithRecovery(),
125130
)
126131

127132
// Add a calculator tool
@@ -522,6 +527,12 @@ initialization.
522527
Add the `Hooks` to the server at the time of creation using the
523528
`server.WithHooks` option.
524529

530+
### Tool Handler Middleware
531+
532+
Add middleware to tool call handlers using the `server.WithToolHandlerMiddleware` option. Middlewares can be registered on server creation and are applied on every tool call.
533+
534+
A recovery middleware option is available to recover from panics in a tool call and can be added to the server with the `server.WithRecovery` option.
535+
525536
## Contributing
526537

527538
<details>

client/client.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package client
33

44
import (
55
"context"
6+
"encoding/json"
7+
"fmt"
68

79
"github.com/mark3labs/mcp-go/mcp"
810
)
@@ -18,12 +20,25 @@ type MCPClient interface {
1820
// Ping checks if the server is alive
1921
Ping(ctx context.Context) error
2022

23+
// ListResourcesByPage manually list resources by page.
24+
ListResourcesByPage(
25+
ctx context.Context,
26+
request mcp.ListResourcesRequest,
27+
) (*mcp.ListResourcesResult, error)
28+
2129
// ListResources requests a list of available resources from the server
2230
ListResources(
2331
ctx context.Context,
2432
request mcp.ListResourcesRequest,
2533
) (*mcp.ListResourcesResult, error)
2634

35+
// ListResourceTemplatesByPage manually list resource templates by page.
36+
ListResourceTemplatesByPage(
37+
ctx context.Context,
38+
request mcp.ListResourceTemplatesRequest,
39+
) (*mcp.ListResourceTemplatesResult,
40+
error)
41+
2742
// ListResourceTemplates requests a list of available resource templates from the server
2843
ListResourceTemplates(
2944
ctx context.Context,
@@ -43,6 +58,12 @@ type MCPClient interface {
4358
// Unsubscribe cancels notifications for a specific resource
4459
Unsubscribe(ctx context.Context, request mcp.UnsubscribeRequest) error
4560

61+
// ListPromptsByPage manually list prompts by page.
62+
ListPromptsByPage(
63+
ctx context.Context,
64+
request mcp.ListPromptsRequest,
65+
) (*mcp.ListPromptsResult, error)
66+
4667
// ListPrompts requests a list of available prompts from the server
4768
ListPrompts(
4869
ctx context.Context,
@@ -55,6 +76,12 @@ type MCPClient interface {
5576
request mcp.GetPromptRequest,
5677
) (*mcp.GetPromptResult, error)
5778

79+
// ListToolsByPage manually list tools by page.
80+
ListToolsByPage(
81+
ctx context.Context,
82+
request mcp.ListToolsRequest,
83+
) (*mcp.ListToolsResult, error)
84+
5885
// ListTools requests a list of available tools from the server
5986
ListTools(
6087
ctx context.Context,
@@ -82,3 +109,26 @@ type MCPClient interface {
82109
// OnNotification registers a handler for notifications
83110
OnNotification(handler func(notification mcp.JSONRPCNotification))
84111
}
112+
113+
type mcpClient interface {
114+
MCPClient
115+
116+
sendRequest(ctx context.Context, method string, params interface{}) (*json.RawMessage, error)
117+
}
118+
119+
func listByPage[T any](
120+
ctx context.Context,
121+
client mcpClient,
122+
request mcp.PaginatedRequest,
123+
method string,
124+
) (*T, error) {
125+
response, err := client.sendRequest(ctx, method, request.Params)
126+
if err != nil {
127+
return nil, err
128+
}
129+
var result T
130+
if err := json.Unmarshal(*response, &result); err != nil {
131+
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
132+
}
133+
return &result, nil
134+
}

client/sse.go

Lines changed: 128 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,19 @@ import (
2323
// while sending requests over regular HTTP POST calls. The client handles
2424
// automatic reconnection and message routing between requests and responses.
2525
type SSEMCPClient struct {
26-
baseURL *url.URL
27-
endpoint *url.URL
28-
httpClient *http.Client
29-
requestID atomic.Int64
30-
responses map[int64]chan RPCResponse
31-
mu sync.RWMutex
32-
done chan struct{}
33-
initialized bool
34-
notifications []func(mcp.JSONRPCNotification)
35-
notifyMu sync.RWMutex
36-
endpointChan chan struct{}
37-
capabilities mcp.ServerCapabilities
38-
headers map[string]string
39-
sseReadTimeout time.Duration
26+
baseURL *url.URL
27+
endpoint *url.URL
28+
httpClient *http.Client
29+
requestID atomic.Int64
30+
responses map[int64]chan RPCResponse
31+
mu sync.RWMutex
32+
done chan struct{}
33+
initialized bool
34+
notifications []func(mcp.JSONRPCNotification)
35+
notifyMu sync.RWMutex
36+
endpointChan chan struct{}
37+
capabilities mcp.ServerCapabilities
38+
headers map[string]string
4039
}
4140

4241
type ClientOption func(*SSEMCPClient)
@@ -68,13 +67,12 @@ func NewSSEMCPClient(baseURL string, options ...ClientOption) (*SSEMCPClient, er
6867
}
6968

7069
smc := &SSEMCPClient{
71-
baseURL: parsedURL,
72-
httpClient: &http.Client{},
73-
responses: make(map[int64]chan RPCResponse),
74-
done: make(chan struct{}),
75-
endpointChan: make(chan struct{}),
76-
sseReadTimeout: 30 * time.Second,
77-
headers: make(map[string]string),
70+
baseURL: parsedURL,
71+
httpClient: &http.Client{},
72+
responses: make(map[int64]chan RPCResponse),
73+
done: make(chan struct{}),
74+
endpointChan: make(chan struct{}),
75+
headers: make(map[string]string),
7876
}
7977

8078
for _, opt := range options {
@@ -99,6 +97,9 @@ func (c *SSEMCPClient) Start(ctx context.Context) error {
9997
req.Header.Set("Accept", "text/event-stream")
10098
req.Header.Set("Cache-Control", "no-cache")
10199
req.Header.Set("Connection", "keep-alive")
100+
for k, v := range c.headers {
101+
req.Header.Set(k, v)
102+
}
102103

103104
resp, err := c.httpClient.Do(req)
104105
if err != nil {
@@ -134,12 +135,9 @@ func (c *SSEMCPClient) readSSE(reader io.ReadCloser) {
134135
br := bufio.NewReader(reader)
135136
var event, data string
136137

137-
ctx, cancel := context.WithTimeout(context.Background(), c.sseReadTimeout)
138-
defer cancel()
139-
140138
for {
141139
select {
142-
case <-ctx.Done():
140+
case <-c.done:
143141
return
144142
default:
145143
line, err := br.ReadString('\n')
@@ -405,7 +403,7 @@ func (c *SSEMCPClient) Initialize(
405403
err,
406404
)
407405
}
408-
resp.Body.Close()
406+
defer resp.Body.Close()
409407

410408
c.initialized = true
411409
return &result, nil
@@ -416,42 +414,77 @@ func (c *SSEMCPClient) Ping(ctx context.Context) error {
416414
return err
417415
}
418416

419-
func (c *SSEMCPClient) ListResources(
417+
// ListResourcesByPage manually list resources by page.
418+
func (c *SSEMCPClient) ListResourcesByPage(
420419
ctx context.Context,
421420
request mcp.ListResourcesRequest,
422421
) (*mcp.ListResourcesResult, error) {
423-
response, err := c.sendRequest(ctx, "resources/list", request.Params)
422+
result, err := listByPage[mcp.ListResourcesResult](ctx, c, request.PaginatedRequest, "resources/list")
424423
if err != nil {
425424
return nil, err
426425
}
426+
return result, nil
427+
}
427428

428-
var result mcp.ListResourcesResult
429-
if err := json.Unmarshal(*response, &result); err != nil {
430-
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
429+
func (c *SSEMCPClient) ListResources(
430+
ctx context.Context,
431+
request mcp.ListResourcesRequest,
432+
) (*mcp.ListResourcesResult, error) {
433+
result, err := c.ListResourcesByPage(ctx, request)
434+
if err != nil {
435+
return nil, err
431436
}
437+
for result.NextCursor != "" {
438+
select {
439+
case <-ctx.Done():
440+
return nil, ctx.Err()
441+
default:
442+
request.Params.Cursor = result.NextCursor
443+
newPageRes, err := c.ListResourcesByPage(ctx, request)
444+
if err != nil {
445+
return nil, err
446+
}
447+
result.Resources = append(result.Resources, newPageRes.Resources...)
448+
result.NextCursor = newPageRes.NextCursor
449+
}
450+
}
451+
return result, nil
452+
}
432453

433-
return &result, nil
454+
func (c *SSEMCPClient) ListResourceTemplatesByPage(
455+
ctx context.Context,
456+
request mcp.ListResourceTemplatesRequest,
457+
) (*mcp.ListResourceTemplatesResult, error) {
458+
result, err := listByPage[mcp.ListResourceTemplatesResult](ctx, c, request.PaginatedRequest, "resources/templates/list")
459+
if err != nil {
460+
return nil, err
461+
}
462+
return result, nil
434463
}
435464

436465
func (c *SSEMCPClient) ListResourceTemplates(
437466
ctx context.Context,
438467
request mcp.ListResourceTemplatesRequest,
439468
) (*mcp.ListResourceTemplatesResult, error) {
440-
response, err := c.sendRequest(
441-
ctx,
442-
"resources/templates/list",
443-
request.Params,
444-
)
469+
result, err := c.ListResourceTemplatesByPage(ctx, request)
445470
if err != nil {
446471
return nil, err
447472
}
448-
449-
var result mcp.ListResourceTemplatesResult
450-
if err := json.Unmarshal(*response, &result); err != nil {
451-
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
473+
for result.NextCursor != "" {
474+
select {
475+
case <-ctx.Done():
476+
return nil, ctx.Err()
477+
default:
478+
request.Params.Cursor = result.NextCursor
479+
newPageRes, err := c.ListResourceTemplatesByPage(ctx, request)
480+
if err != nil {
481+
return nil, err
482+
}
483+
result.ResourceTemplates = append(result.ResourceTemplates, newPageRes.ResourceTemplates...)
484+
result.NextCursor = newPageRes.NextCursor
485+
}
452486
}
453-
454-
return &result, nil
487+
return result, nil
455488
}
456489

457490
func (c *SSEMCPClient) ReadResource(
@@ -482,21 +515,40 @@ func (c *SSEMCPClient) Unsubscribe(
482515
return err
483516
}
484517

485-
func (c *SSEMCPClient) ListPrompts(
518+
func (c *SSEMCPClient) ListPromptsByPage(
486519
ctx context.Context,
487520
request mcp.ListPromptsRequest,
488521
) (*mcp.ListPromptsResult, error) {
489-
response, err := c.sendRequest(ctx, "prompts/list", request.Params)
522+
result, err := listByPage[mcp.ListPromptsResult](ctx, c, request.PaginatedRequest, "prompts/list")
490523
if err != nil {
491524
return nil, err
492525
}
526+
return result, nil
527+
}
493528

494-
var result mcp.ListPromptsResult
495-
if err := json.Unmarshal(*response, &result); err != nil {
496-
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
529+
func (c *SSEMCPClient) ListPrompts(
530+
ctx context.Context,
531+
request mcp.ListPromptsRequest,
532+
) (*mcp.ListPromptsResult, error) {
533+
result, err := c.ListPromptsByPage(ctx, request)
534+
if err != nil {
535+
return nil, err
497536
}
498-
499-
return &result, nil
537+
for result.NextCursor != "" {
538+
select {
539+
case <-ctx.Done():
540+
return nil, ctx.Err()
541+
default:
542+
request.Params.Cursor = result.NextCursor
543+
newPageRes, err := c.ListPromptsByPage(ctx, request)
544+
if err != nil {
545+
return nil, err
546+
}
547+
result.Prompts = append(result.Prompts, newPageRes.Prompts...)
548+
result.NextCursor = newPageRes.NextCursor
549+
}
550+
}
551+
return result, nil
500552
}
501553

502554
func (c *SSEMCPClient) GetPrompt(
@@ -511,21 +563,40 @@ func (c *SSEMCPClient) GetPrompt(
511563
return mcp.ParseGetPromptResult(response)
512564
}
513565

514-
func (c *SSEMCPClient) ListTools(
566+
func (c *SSEMCPClient) ListToolsByPage(
515567
ctx context.Context,
516568
request mcp.ListToolsRequest,
517569
) (*mcp.ListToolsResult, error) {
518-
response, err := c.sendRequest(ctx, "tools/list", request.Params)
570+
result, err := listByPage[mcp.ListToolsResult](ctx, c, request.PaginatedRequest, "tools/list")
519571
if err != nil {
520572
return nil, err
521573
}
574+
return result, nil
575+
}
522576

523-
var result mcp.ListToolsResult
524-
if err := json.Unmarshal(*response, &result); err != nil {
525-
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
577+
func (c *SSEMCPClient) ListTools(
578+
ctx context.Context,
579+
request mcp.ListToolsRequest,
580+
) (*mcp.ListToolsResult, error) {
581+
result, err := c.ListToolsByPage(ctx, request)
582+
if err != nil {
583+
return nil, err
526584
}
527-
528-
return &result, nil
585+
for result.NextCursor != "" {
586+
select {
587+
case <-ctx.Done():
588+
return nil, ctx.Err()
589+
default:
590+
request.Params.Cursor = result.NextCursor
591+
newPageRes, err := c.ListToolsByPage(ctx, request)
592+
if err != nil {
593+
return nil, err
594+
}
595+
result.Tools = append(result.Tools, newPageRes.Tools...)
596+
result.NextCursor = newPageRes.NextCursor
597+
}
598+
}
599+
return result, nil
529600
}
530601

531602
func (c *SSEMCPClient) CallTool(

0 commit comments

Comments
 (0)