Skip to content

Commit 774b17b

Browse files
[SSE][OAuth] Add OAuth support to SSE client (#340)
1 parent 9557d0a commit 774b17b

File tree

7 files changed

+409
-44
lines changed

7 files changed

+409
-44
lines changed

client/oauth.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ var NewMemoryTokenStore = transport.NewMemoryTokenStore
2626
// Returns an error if the URL is invalid.
2727
func NewOAuthStreamableHttpClient(baseURL string, oauthConfig OAuthConfig, options ...transport.StreamableHTTPCOption) (*Client, error) {
2828
// Add OAuth option to the list of options
29-
options = append(options, transport.WithOAuth(oauthConfig))
29+
options = append(options, transport.WithHTTPOAuth(oauthConfig))
3030

3131
trans, err := transport.NewStreamableHTTP(baseURL, options...)
3232
if err != nil {
@@ -35,6 +35,19 @@ func NewOAuthStreamableHttpClient(baseURL string, oauthConfig OAuthConfig, optio
3535
return NewClient(trans), nil
3636
}
3737

38+
// NewOAuthStreamableHttpClient creates a new streamable-http-based MCP client with OAuth support.
39+
// Returns an error if the URL is invalid.
40+
func NewOAuthSSEClient(baseURL string, oauthConfig OAuthConfig, options ...transport.ClientOption) (*Client, error) {
41+
// Add OAuth option to the list of options
42+
options = append(options, transport.WithOAuth(oauthConfig))
43+
44+
trans, err := transport.NewSSE(baseURL, options...)
45+
if err != nil {
46+
return nil, fmt.Errorf("failed to create SSE transport: %w", err)
47+
}
48+
return NewClient(trans), nil
49+
}
50+
3851
// GenerateCodeVerifier generates a code verifier for PKCE
3952
var GenerateCodeVerifier = transport.GenerateCodeVerifier
4053

client/transport/oauth.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,14 @@ func (h *OAuthHandler) GetAuthorizationHeader(ctx context.Context) (string, erro
136136
if err != nil {
137137
return "", err
138138
}
139-
return fmt.Sprintf("%s %s", token.TokenType, token.AccessToken), nil
139+
140+
// Some auth implementations are strict about token type
141+
tokenType := token.TokenType
142+
if tokenType == "bearer" {
143+
tokenType = "Bearer"
144+
}
145+
146+
return fmt.Sprintf("%s %s", tokenType, token.AccessToken), nil
140147
}
141148

142149
// getValidToken returns a valid token, refreshing if necessary

client/transport/sse.go

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"bytes"
66
"context"
77
"encoding/json"
8+
"errors"
89
"fmt"
910
"io"
1011
"net/http"
@@ -36,6 +37,9 @@ type SSE struct {
3637
started atomic.Bool
3738
closed atomic.Bool
3839
cancelSSEStream context.CancelFunc
40+
41+
// OAuth support
42+
oauthHandler *OAuthHandler
3943
}
4044

4145
type ClientOption func(*SSE)
@@ -58,6 +62,12 @@ func WithHTTPClient(httpClient *http.Client) ClientOption {
5862
}
5963
}
6064

65+
func WithOAuth(config OAuthConfig) ClientOption {
66+
return func(sc *SSE) {
67+
sc.oauthHandler = NewOAuthHandler(config)
68+
}
69+
}
70+
6171
// NewSSE creates a new SSE-based MCP client with the given base URL.
6272
// Returns an error if the URL is invalid.
6373
func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) {
@@ -78,6 +88,13 @@ func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) {
7888
opt(smc)
7989
}
8090

91+
// If OAuth is configured, set the base URL for metadata discovery
92+
if smc.oauthHandler != nil {
93+
// Extract base URL from server URL for metadata discovery
94+
baseURL := fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host)
95+
smc.oauthHandler.SetBaseURL(baseURL)
96+
}
97+
8198
return smc, nil
8299
}
83100

@@ -112,13 +129,34 @@ func (c *SSE) Start(ctx context.Context) error {
112129
}
113130
}
114131

132+
// Add OAuth authorization if configured
133+
if c.oauthHandler != nil {
134+
authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
135+
if err != nil {
136+
// If we get an authorization error, return a specific error that can be handled by the client
137+
if err.Error() == "no valid token available, authorization required" {
138+
return &OAuthAuthorizationRequiredError{
139+
Handler: c.oauthHandler,
140+
}
141+
}
142+
return fmt.Errorf("failed to get authorization header: %w", err)
143+
}
144+
req.Header.Set("Authorization", authHeader)
145+
}
146+
115147
resp, err := c.httpClient.Do(req)
116148
if err != nil {
117149
return fmt.Errorf("failed to connect to SSE stream: %w", err)
118150
}
119151

120152
if resp.StatusCode != http.StatusOK {
121153
resp.Body.Close()
154+
// Handle OAuth unauthorized error
155+
if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
156+
return &OAuthAuthorizationRequiredError{
157+
Handler: c.oauthHandler,
158+
}
159+
}
122160
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
123161
}
124162

@@ -281,6 +319,22 @@ func (c *SSE) SendRequest(
281319
for k, v := range c.headers {
282320
req.Header.Set(k, v)
283321
}
322+
323+
// Add OAuth authorization if configured
324+
if c.oauthHandler != nil {
325+
authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
326+
if err != nil {
327+
// If we get an authorization error, return a specific error that can be handled by the client
328+
if err.Error() == "no valid token available, authorization required" {
329+
return nil, &OAuthAuthorizationRequiredError{
330+
Handler: c.oauthHandler,
331+
}
332+
}
333+
return nil, fmt.Errorf("failed to get authorization header: %w", err)
334+
}
335+
req.Header.Set("Authorization", authHeader)
336+
}
337+
284338
if c.headerFunc != nil {
285339
for k, v := range c.headerFunc(ctx) {
286340
req.Header.Set(k, v)
@@ -320,6 +374,13 @@ func (c *SSE) SendRequest(
320374
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
321375
deleteResponseChan()
322376

377+
// Handle OAuth unauthorized error
378+
if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
379+
return nil, &OAuthAuthorizationRequiredError{
380+
Handler: c.oauthHandler,
381+
}
382+
}
383+
323384
return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
324385
}
325386

@@ -385,6 +446,22 @@ func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNoti
385446
for k, v := range c.headers {
386447
req.Header.Set(k, v)
387448
}
449+
450+
// Add OAuth authorization if configured
451+
if c.oauthHandler != nil {
452+
authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
453+
if err != nil {
454+
// If we get an authorization error, return a specific error that can be handled by the client
455+
if errors.Is(err, ErrOAuthAuthorizationRequired) {
456+
return &OAuthAuthorizationRequiredError{
457+
Handler: c.oauthHandler,
458+
}
459+
}
460+
return fmt.Errorf("failed to get authorization header: %w", err)
461+
}
462+
req.Header.Set("Authorization", authHeader)
463+
}
464+
388465
if c.headerFunc != nil {
389466
for k, v := range c.headerFunc(ctx) {
390467
req.Header.Set(k, v)
@@ -398,6 +475,13 @@ func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNoti
398475
defer resp.Body.Close()
399476

400477
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
478+
// Handle OAuth unauthorized error
479+
if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
480+
return &OAuthAuthorizationRequiredError{
481+
Handler: c.oauthHandler,
482+
}
483+
}
484+
401485
body, _ := io.ReadAll(resp.Body)
402486
return fmt.Errorf(
403487
"notification failed with status %d: %s",
@@ -418,3 +502,13 @@ func (c *SSE) GetEndpoint() *url.URL {
418502
func (c *SSE) GetBaseURL() *url.URL {
419503
return c.baseURL
420504
}
505+
506+
// GetOAuthHandler returns the OAuth handler if configured
507+
func (c *SSE) GetOAuthHandler() *OAuthHandler {
508+
return c.oauthHandler
509+
}
510+
511+
// IsOAuthEnabled returns true if OAuth is enabled
512+
func (c *SSE) IsOAuthEnabled() bool {
513+
return c.oauthHandler != nil
514+
}

0 commit comments

Comments
 (0)