5
5
"bytes"
6
6
"context"
7
7
"encoding/json"
8
+ "errors"
8
9
"fmt"
9
10
"io"
10
11
"net/http"
@@ -36,6 +37,9 @@ type SSE struct {
36
37
started atomic.Bool
37
38
closed atomic.Bool
38
39
cancelSSEStream context.CancelFunc
40
+
41
+ // OAuth support
42
+ oauthHandler * OAuthHandler
39
43
}
40
44
41
45
type ClientOption func (* SSE )
@@ -58,6 +62,12 @@ func WithHTTPClient(httpClient *http.Client) ClientOption {
58
62
}
59
63
}
60
64
65
+ func WithOAuth (config OAuthConfig ) ClientOption {
66
+ return func (sc * SSE ) {
67
+ sc .oauthHandler = NewOAuthHandler (config )
68
+ }
69
+ }
70
+
61
71
// NewSSE creates a new SSE-based MCP client with the given base URL.
62
72
// Returns an error if the URL is invalid.
63
73
func NewSSE (baseURL string , options ... ClientOption ) (* SSE , error ) {
@@ -78,6 +88,13 @@ func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) {
78
88
opt (smc )
79
89
}
80
90
91
+ // If OAuth is configured, set the base URL for metadata discovery
92
+ if smc .oauthHandler != nil {
93
+ // Extract base URL from server URL for metadata discovery
94
+ baseURL := fmt .Sprintf ("%s://%s" , parsedURL .Scheme , parsedURL .Host )
95
+ smc .oauthHandler .SetBaseURL (baseURL )
96
+ }
97
+
81
98
return smc , nil
82
99
}
83
100
@@ -112,13 +129,34 @@ func (c *SSE) Start(ctx context.Context) error {
112
129
}
113
130
}
114
131
132
+ // Add OAuth authorization if configured
133
+ if c .oauthHandler != nil {
134
+ authHeader , err := c .oauthHandler .GetAuthorizationHeader (ctx )
135
+ if err != nil {
136
+ // If we get an authorization error, return a specific error that can be handled by the client
137
+ if err .Error () == "no valid token available, authorization required" {
138
+ return & OAuthAuthorizationRequiredError {
139
+ Handler : c .oauthHandler ,
140
+ }
141
+ }
142
+ return fmt .Errorf ("failed to get authorization header: %w" , err )
143
+ }
144
+ req .Header .Set ("Authorization" , authHeader )
145
+ }
146
+
115
147
resp , err := c .httpClient .Do (req )
116
148
if err != nil {
117
149
return fmt .Errorf ("failed to connect to SSE stream: %w" , err )
118
150
}
119
151
120
152
if resp .StatusCode != http .StatusOK {
121
153
resp .Body .Close ()
154
+ // Handle OAuth unauthorized error
155
+ if resp .StatusCode == http .StatusUnauthorized && c .oauthHandler != nil {
156
+ return & OAuthAuthorizationRequiredError {
157
+ Handler : c .oauthHandler ,
158
+ }
159
+ }
122
160
return fmt .Errorf ("unexpected status code: %d" , resp .StatusCode )
123
161
}
124
162
@@ -281,6 +319,22 @@ func (c *SSE) SendRequest(
281
319
for k , v := range c .headers {
282
320
req .Header .Set (k , v )
283
321
}
322
+
323
+ // Add OAuth authorization if configured
324
+ if c .oauthHandler != nil {
325
+ authHeader , err := c .oauthHandler .GetAuthorizationHeader (ctx )
326
+ if err != nil {
327
+ // If we get an authorization error, return a specific error that can be handled by the client
328
+ if err .Error () == "no valid token available, authorization required" {
329
+ return nil , & OAuthAuthorizationRequiredError {
330
+ Handler : c .oauthHandler ,
331
+ }
332
+ }
333
+ return nil , fmt .Errorf ("failed to get authorization header: %w" , err )
334
+ }
335
+ req .Header .Set ("Authorization" , authHeader )
336
+ }
337
+
284
338
if c .headerFunc != nil {
285
339
for k , v := range c .headerFunc (ctx ) {
286
340
req .Header .Set (k , v )
@@ -320,6 +374,13 @@ func (c *SSE) SendRequest(
320
374
if resp .StatusCode != http .StatusOK && resp .StatusCode != http .StatusAccepted {
321
375
deleteResponseChan ()
322
376
377
+ // Handle OAuth unauthorized error
378
+ if resp .StatusCode == http .StatusUnauthorized && c .oauthHandler != nil {
379
+ return nil , & OAuthAuthorizationRequiredError {
380
+ Handler : c .oauthHandler ,
381
+ }
382
+ }
383
+
323
384
return nil , fmt .Errorf ("request failed with status %d: %s" , resp .StatusCode , body )
324
385
}
325
386
@@ -385,6 +446,22 @@ func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNoti
385
446
for k , v := range c .headers {
386
447
req .Header .Set (k , v )
387
448
}
449
+
450
+ // Add OAuth authorization if configured
451
+ if c .oauthHandler != nil {
452
+ authHeader , err := c .oauthHandler .GetAuthorizationHeader (ctx )
453
+ if err != nil {
454
+ // If we get an authorization error, return a specific error that can be handled by the client
455
+ if errors .Is (err , ErrOAuthAuthorizationRequired ) {
456
+ return & OAuthAuthorizationRequiredError {
457
+ Handler : c .oauthHandler ,
458
+ }
459
+ }
460
+ return fmt .Errorf ("failed to get authorization header: %w" , err )
461
+ }
462
+ req .Header .Set ("Authorization" , authHeader )
463
+ }
464
+
388
465
if c .headerFunc != nil {
389
466
for k , v := range c .headerFunc (ctx ) {
390
467
req .Header .Set (k , v )
@@ -398,6 +475,13 @@ func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNoti
398
475
defer resp .Body .Close ()
399
476
400
477
if resp .StatusCode != http .StatusOK && resp .StatusCode != http .StatusAccepted {
478
+ // Handle OAuth unauthorized error
479
+ if resp .StatusCode == http .StatusUnauthorized && c .oauthHandler != nil {
480
+ return & OAuthAuthorizationRequiredError {
481
+ Handler : c .oauthHandler ,
482
+ }
483
+ }
484
+
401
485
body , _ := io .ReadAll (resp .Body )
402
486
return fmt .Errorf (
403
487
"notification failed with status %d: %s" ,
@@ -418,3 +502,13 @@ func (c *SSE) GetEndpoint() *url.URL {
418
502
func (c * SSE ) GetBaseURL () * url.URL {
419
503
return c .baseURL
420
504
}
505
+
506
+ // GetOAuthHandler returns the OAuth handler if configured
507
+ func (c * SSE ) GetOAuthHandler () * OAuthHandler {
508
+ return c .oauthHandler
509
+ }
510
+
511
+ // IsOAuthEnabled returns true if OAuth is enabled
512
+ func (c * SSE ) IsOAuthEnabled () bool {
513
+ return c .oauthHandler != nil
514
+ }
0 commit comments