Skip to content

Commit 7162a07

Browse files
committed
feat: add hooks for sse client
1 parent 33c98f1 commit 7162a07

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

client/transport/sse.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,18 @@ import (
1717
"github.com/mark3labs/mcp-go/mcp"
1818
)
1919

20+
// OnBeforeRequestFunc is called before sending the request, with context.
21+
type OnBeforeRequestFunc func(ctx context.Context, req *http.Request)
22+
23+
// OnAfterResponseFunc is called after receiving the response, with context. (Regardless of error, when err is not nil resp may be nil.) The req parameter is included.
24+
type OnAfterResponseFunc func(ctx context.Context, req *http.Request, resp *http.Response, err error)
25+
26+
// SSEHooks supports multiple before and after processing functions.
27+
type SSEHooks struct {
28+
OnBeforeRequest []OnBeforeRequestFunc
29+
OnAfterResponse []OnAfterResponseFunc
30+
}
31+
2032
// SSE implements the transport layer of the MCP protocol using Server-Sent Events (SSE).
2133
// It maintains a persistent HTTP connection to receive server-pushed events
2234
// while sending requests over regular HTTP POST calls. The client handles
@@ -32,6 +44,8 @@ type SSE struct {
3244
endpointChan chan struct{}
3345
headers map[string]string
3446

47+
hooks SSEHooks
48+
3549
started atomic.Bool
3650
closed atomic.Bool
3751
cancelSSEStream context.CancelFunc
@@ -45,6 +59,13 @@ func WithHeaders(headers map[string]string) ClientOption {
4559
}
4660
}
4761

62+
// Register a set of hooks (overwrites existing hooks)
63+
func WithSSEHooks(hooks SSEHooks) ClientOption {
64+
return func(sc *SSE) {
65+
sc.hooks = hooks
66+
}
67+
}
68+
4869
// NewSSE creates a new SSE-based MCP client with the given base URL.
4970
// Returns an error if the URL is invalid.
5071
func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) {
@@ -261,6 +282,13 @@ func (c *SSE) SendRequest(
261282
req.Header.Set(k, v)
262283
}
263284

285+
// hooks: before request
286+
for _, hook := range c.hooks.OnBeforeRequest {
287+
if hook != nil {
288+
hook(ctx, req)
289+
}
290+
}
291+
264292
// Register response channel
265293
responseChan := make(chan *JSONRPCResponse, 1)
266294
c.mu.Lock()
@@ -274,6 +302,14 @@ func (c *SSE) SendRequest(
274302

275303
// Send request
276304
resp, err := c.httpClient.Do(req)
305+
306+
// hooks: after response
307+
for _, hook := range c.hooks.OnAfterResponse {
308+
if hook != nil {
309+
hook(ctx, req, resp, err)
310+
}
311+
}
312+
277313
if err != nil {
278314
deleteResponseChan()
279315
return nil, fmt.Errorf("failed to send request: %w", err)
@@ -348,7 +384,22 @@ func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNoti
348384
req.Header.Set(k, v)
349385
}
350386

387+
// hooks: before request
388+
for _, hook := range c.hooks.OnBeforeRequest {
389+
if hook != nil {
390+
hook(ctx, req)
391+
}
392+
}
393+
351394
resp, err := c.httpClient.Do(req)
395+
396+
// hooks: after response
397+
for _, hook := range c.hooks.OnAfterResponse {
398+
if hook != nil {
399+
hook(ctx, req, resp, err)
400+
}
401+
}
402+
352403
if err != nil {
353404
return fmt.Errorf("failed to send notification: %w", err)
354405
}

0 commit comments

Comments
 (0)