@@ -17,6 +17,18 @@ import (
17
17
"github.com/mark3labs/mcp-go/mcp"
18
18
)
19
19
20
+ // OnBeforeRequestFunc is called before sending the request, with context.
21
+ type OnBeforeRequestFunc func (ctx context.Context , req * http.Request )
22
+
23
+ // OnAfterResponseFunc is called after receiving the response, with context. (Regardless of error, when err is not nil resp may be nil.) The req parameter is included.
24
+ type OnAfterResponseFunc func (ctx context.Context , req * http.Request , resp * http.Response , err error )
25
+
26
+ // SSEHooks supports multiple before and after processing functions.
27
+ type SSEHooks struct {
28
+ OnBeforeRequest []OnBeforeRequestFunc
29
+ OnAfterResponse []OnAfterResponseFunc
30
+ }
31
+
20
32
// SSE implements the transport layer of the MCP protocol using Server-Sent Events (SSE).
21
33
// It maintains a persistent HTTP connection to receive server-pushed events
22
34
// while sending requests over regular HTTP POST calls. The client handles
@@ -32,6 +44,8 @@ type SSE struct {
32
44
endpointChan chan struct {}
33
45
headers map [string ]string
34
46
47
+ hooks SSEHooks
48
+
35
49
started atomic.Bool
36
50
closed atomic.Bool
37
51
cancelSSEStream context.CancelFunc
@@ -45,6 +59,13 @@ func WithHeaders(headers map[string]string) ClientOption {
45
59
}
46
60
}
47
61
62
+ // Register a set of hooks (overwrites existing hooks)
63
+ func WithSSEHooks (hooks SSEHooks ) ClientOption {
64
+ return func (sc * SSE ) {
65
+ sc .hooks = hooks
66
+ }
67
+ }
68
+
48
69
// NewSSE creates a new SSE-based MCP client with the given base URL.
49
70
// Returns an error if the URL is invalid.
50
71
func NewSSE (baseURL string , options ... ClientOption ) (* SSE , error ) {
@@ -261,6 +282,13 @@ func (c *SSE) SendRequest(
261
282
req .Header .Set (k , v )
262
283
}
263
284
285
+ // hooks: before request
286
+ for _ , hook := range c .hooks .OnBeforeRequest {
287
+ if hook != nil {
288
+ hook (ctx , req )
289
+ }
290
+ }
291
+
264
292
// Register response channel
265
293
responseChan := make (chan * JSONRPCResponse , 1 )
266
294
c .mu .Lock ()
@@ -274,6 +302,14 @@ func (c *SSE) SendRequest(
274
302
275
303
// Send request
276
304
resp , err := c .httpClient .Do (req )
305
+
306
+ // hooks: after response
307
+ for _ , hook := range c .hooks .OnAfterResponse {
308
+ if hook != nil {
309
+ hook (ctx , req , resp , err )
310
+ }
311
+ }
312
+
277
313
if err != nil {
278
314
deleteResponseChan ()
279
315
return nil , fmt .Errorf ("failed to send request: %w" , err )
@@ -348,7 +384,22 @@ func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNoti
348
384
req .Header .Set (k , v )
349
385
}
350
386
387
+ // hooks: before request
388
+ for _ , hook := range c .hooks .OnBeforeRequest {
389
+ if hook != nil {
390
+ hook (ctx , req )
391
+ }
392
+ }
393
+
351
394
resp , err := c .httpClient .Do (req )
395
+
396
+ // hooks: after response
397
+ for _ , hook := range c .hooks .OnAfterResponse {
398
+ if hook != nil {
399
+ hook (ctx , req , resp , err )
400
+ }
401
+ }
402
+
352
403
if err != nil {
353
404
return fmt .Errorf ("failed to send notification: %w" , err )
354
405
}
0 commit comments