Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 189 additions & 0 deletions server/http_transport_options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
package server

import (
"context"
"net/http"
"net/url"
"strings"
"time"
)

// HTTPContextFunc is a function that takes an existing context and the current
// request and returns a potentially modified context based on the request
// content. This can be used to inject context values from headers, for example.
type HTTPContextFunc func(ctx context.Context, r *http.Request) context.Context

// httpTransportConfigurable is an internal interface for shared HTTP transport configuration.
type httpTransportConfigurable interface {
setBasePath(string)
setDynamicBasePath(DynamicBasePathFunc)
setKeepAliveInterval(time.Duration)
setKeepAlive(bool)
setContextFunc(HTTPContextFunc)
setHTTPServer(*http.Server)
setBaseURL(string)
}

// HTTPTransportOption is a function that configures an httpTransportConfigurable.
type HTTPTransportOption func(httpTransportConfigurable)

// Option interfaces and wrappers for server configuration
// Base option interface
type HTTPServerOption interface {
isHTTPServerOption()
}

// SSE-specific option interface
type SSEOption interface {
HTTPServerOption
applyToSSE(*SSEServer)
}

// StreamableHTTP-specific option interface
type StreamableHTTPOption interface {
HTTPServerOption
applyToStreamableHTTP(*StreamableHTTPServer)
}

// Common options that work with both server types
type CommonHTTPServerOption interface {
SSEOption
StreamableHTTPOption
}

// Wrapper for SSE-specific functional options
type sseOption func(*SSEServer)

func (o sseOption) isHTTPServerOption() {}
func (o sseOption) applyToSSE(s *SSEServer) { o(s) }

// Wrapper for StreamableHTTP-specific functional options
type streamableHTTPOption func(*StreamableHTTPServer)

func (o streamableHTTPOption) isHTTPServerOption() {}
func (o streamableHTTPOption) applyToStreamableHTTP(s *StreamableHTTPServer) { o(s) }

// Refactor commonOption to use a single apply func(httpTransportConfigurable)
type commonOption struct {
apply func(httpTransportConfigurable)
}

func (o commonOption) isHTTPServerOption() {}
func (o commonOption) applyToSSE(s *SSEServer) { o.apply(s) }
func (o commonOption) applyToStreamableHTTP(s *StreamableHTTPServer) { o.apply(s) }

// TODO: This is a stub implementation of StreamableHTTPServer just to show how
// to use it with the new options interfaces.
type StreamableHTTPServer struct{}

// Add stub methods to satisfy httpTransportConfigurable

func (s *StreamableHTTPServer) setBasePath(string) {}
func (s *StreamableHTTPServer) setDynamicBasePath(DynamicBasePathFunc) {}
func (s *StreamableHTTPServer) setKeepAliveInterval(time.Duration) {}
func (s *StreamableHTTPServer) setKeepAlive(bool) {}
func (s *StreamableHTTPServer) setContextFunc(HTTPContextFunc) {}
func (s *StreamableHTTPServer) setHTTPServer(srv *http.Server) {}
func (s *StreamableHTTPServer) setBaseURL(baseURL string) {}

// Ensure the option types implement the correct interfaces
var (
_ httpTransportConfigurable = (*StreamableHTTPServer)(nil)
_ SSEOption = sseOption(nil)
_ StreamableHTTPOption = streamableHTTPOption(nil)
_ CommonHTTPServerOption = commonOption{}
)

// WithStaticBasePath adds a new option for setting a static base path.
// This is useful for mounting the server at a known, fixed path.
func WithStaticBasePath(basePath string) CommonHTTPServerOption {
return commonOption{
apply: func(c httpTransportConfigurable) {
c.setBasePath(basePath)
},
}
}

// DynamicBasePathFunc allows the user to provide a function to generate the
// base path for a given request and sessionID. This is useful for cases where
// the base path is not known at the time of SSE server creation, such as when
// using a reverse proxy or when the base path is dynamically generated. The
// function should return the base path (e.g., "/mcp/tenant123").
type DynamicBasePathFunc func(r *http.Request, sessionID string) string

// WithDynamicBasePath accepts a function for generating the base path.
// This is useful for cases where the base path is not known at the time of server creation,
// such as when using a reverse proxy or when the server is mounted at a dynamic path.
func WithDynamicBasePath(fn DynamicBasePathFunc) CommonHTTPServerOption {
return commonOption{
apply: func(c httpTransportConfigurable) {
c.setDynamicBasePath(fn)
},
}
}

// WithKeepAliveInterval sets the keep-alive interval for the transport.
// When enabled, the server will periodically send ping events to keep the connection alive.
func WithKeepAliveInterval(interval time.Duration) CommonHTTPServerOption {
return commonOption{
apply: func(c httpTransportConfigurable) {
c.setKeepAliveInterval(interval)
},
}
}

// WithKeepAlive enables or disables keep-alive for the transport.
// When enabled, the server will send periodic keep-alive events to clients.
func WithKeepAlive(keepAlive bool) CommonHTTPServerOption {
return commonOption{
apply: func(c httpTransportConfigurable) {
c.setKeepAlive(keepAlive)
},
}
}

// WithHTTPContextFunc sets a function that will be called to customize the context
// for the server using the incoming request. This is useful for injecting
// context values from headers or other request properties.
func WithHTTPContextFunc(fn HTTPContextFunc) CommonHTTPServerOption {
return commonOption{
apply: func(c httpTransportConfigurable) {
c.setContextFunc(fn)
},
}
}

// WithBaseURL sets the base URL for the HTTP transport server.
// This is useful for configuring the externally visible base URL for clients.
func WithBaseURL(baseURL string) CommonHTTPServerOption {
return commonOption{
apply: func(c httpTransportConfigurable) {
if baseURL != "" {
u, err := url.Parse(baseURL)
if err != nil {
return
}
if u.Scheme != "http" && u.Scheme != "https" {
return
}
if u.Host == "" || strings.HasPrefix(u.Host, ":") {
return
}
if len(u.Query()) > 0 {
return
}
}
c.setBaseURL(strings.TrimSuffix(baseURL, "/"))
},
}
}

// WithHTTPServer sets the HTTP server instance for the transport.
// This is useful for advanced scenarios where you want to provide your own http.Server.
func WithHTTPServer(srv *http.Server) CommonHTTPServerOption {
return commonOption{
apply: func(c httpTransportConfigurable) {
c.setHTTPServer(srv)
},
}
}
131 changes: 48 additions & 83 deletions server/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,6 @@ type sseSession struct {
// content. This can be used to inject context values from headers, for example.
type SSEContextFunc func(ctx context.Context, r *http.Request) context.Context

// DynamicBasePathFunc allows the user to provide a function to generate the
// base path for a given request and sessionID. This is useful for cases where
// the base path is not known at the time of SSE server creation, such as when
// using a reverse proxy or when the base path is dynamically generated. The
// function should return the base path (e.g., "/mcp/tenant123").
type DynamicBasePathFunc func(r *http.Request, sessionID string) string

func (s *sseSession) SessionID() string {
return s.sessionID
}
Expand Down Expand Up @@ -100,7 +93,7 @@ type SSEServer struct {
sseEndpoint string
sessions sync.Map
srv *http.Server
contextFunc SSEContextFunc
contextFunc HTTPContextFunc
dynamicBasePathFunc DynamicBasePathFunc

keepAlive bool
Expand All @@ -109,37 +102,41 @@ type SSEServer struct {
mu sync.RWMutex
}

// SSEOption defines a function type for configuring SSEServer
type SSEOption func(*SSEServer)
// Ensure SSEServer implements httpTransportConfigurable
var _ httpTransportConfigurable = (*SSEServer)(nil)

// WithBaseURL sets the base URL for the SSE server
func WithBaseURL(baseURL string) SSEOption {
return func(s *SSEServer) {
if baseURL != "" {
u, err := url.Parse(baseURL)
if err != nil {
return
}
if u.Scheme != "http" && u.Scheme != "https" {
return
}
// Check if the host is empty or only contains a port
if u.Host == "" || strings.HasPrefix(u.Host, ":") {
return
}
if len(u.Query()) > 0 {
return
}
func (s *SSEServer) setBasePath(basePath string) {
s.basePath = normalizeURLPath(basePath)
}

func (s *SSEServer) setDynamicBasePath(fn DynamicBasePathFunc) {
if fn != nil {
s.dynamicBasePathFunc = func(r *http.Request, sid string) string {
bp := fn(r, sid)
return normalizeURLPath(bp)
}
s.baseURL = strings.TrimSuffix(baseURL, "/")
}
}

// WithStaticBasePath adds a new option for setting a static base path
func WithStaticBasePath(basePath string) SSEOption {
return func(s *SSEServer) {
s.basePath = normalizeURLPath(basePath)
}
func (s *SSEServer) setKeepAliveInterval(interval time.Duration) {
s.keepAlive = true
s.keepAliveInterval = interval
}

func (s *SSEServer) setKeepAlive(keepAlive bool) {
s.keepAlive = keepAlive
}

func (s *SSEServer) setContextFunc(fn HTTPContextFunc) {
s.contextFunc = fn
}

func (s *SSEServer) setHTTPServer(srv *http.Server) {
s.srv = srv
}

func (s *SSEServer) setBaseURL(baseURL string) {
s.baseURL = baseURL
}

// WithBasePath adds a new option for setting a static base path.
Expand All @@ -151,26 +148,11 @@ func WithBasePath(basePath string) SSEOption {
return WithStaticBasePath(basePath)
}

// WithDynamicBasePath accepts a function for generating the base path. This is
// useful for cases where the base path is not known at the time of SSE server
// creation, such as when using a reverse proxy or when the server is mounted
// at a dynamic path.
func WithDynamicBasePath(fn DynamicBasePathFunc) SSEOption {
return func(s *SSEServer) {
if fn != nil {
s.dynamicBasePathFunc = func(r *http.Request, sid string) string {
bp := fn(r, sid)
return normalizeURLPath(bp)
}
}
}
}

// WithMessageEndpoint sets the message endpoint path
func WithMessageEndpoint(endpoint string) SSEOption {
return func(s *SSEServer) {
return sseOption(func(s *SSEServer) {
s.messageEndpoint = endpoint
}
})
}

// WithAppendQueryToMessageEndpoint configures the SSE server to append the original request's
Expand All @@ -179,53 +161,37 @@ func WithMessageEndpoint(endpoint string) SSEOption {
// SSE connection request and carry them over to subsequent message requests, maintaining
// context or authentication details across the communication channel.
func WithAppendQueryToMessageEndpoint() SSEOption {
return func(s *SSEServer) {
return sseOption(func(s *SSEServer) {
s.appendQueryToMessageEndpoint = true
}
})
}

// WithUseFullURLForMessageEndpoint controls whether the SSE server returns a complete URL (including baseURL)
// or just the path portion for the message endpoint. Set to false when clients will concatenate
// the baseURL themselves to avoid malformed URLs like "http://localhost/mcphttp://localhost/mcp/message".
func WithUseFullURLForMessageEndpoint(useFullURLForMessageEndpoint bool) SSEOption {
return func(s *SSEServer) {
return sseOption(func(s *SSEServer) {
s.useFullURLForMessageEndpoint = useFullURLForMessageEndpoint
}
})
}

// WithSSEEndpoint sets the SSE endpoint path
func WithSSEEndpoint(endpoint string) SSEOption {
return func(s *SSEServer) {
return sseOption(func(s *SSEServer) {
s.sseEndpoint = endpoint
}
}

// WithHTTPServer sets the HTTP server instance
func WithHTTPServer(srv *http.Server) SSEOption {
return func(s *SSEServer) {
s.srv = srv
}
}

func WithKeepAliveInterval(keepAliveInterval time.Duration) SSEOption {
return func(s *SSEServer) {
s.keepAlive = true
s.keepAliveInterval = keepAliveInterval
}
}

func WithKeepAlive(keepAlive bool) SSEOption {
return func(s *SSEServer) {
s.keepAlive = keepAlive
}
})
}

// WithSSEContextFunc sets a function that will be called to customise the context
// to the server using the incoming request.
//
// Deprecated: Use WithContextFunc instead. This will be removed in a future version.
//
//go:deprecated
func WithSSEContextFunc(fn SSEContextFunc) SSEOption {
return func(s *SSEServer) {
s.contextFunc = fn
}
return sseOption(func(s *SSEServer) {
WithHTTPContextFunc(HTTPContextFunc(fn)).applyToSSE(s)
})
}

// NewSSEServer creates a new SSE server instance with the given MCP server and options.
Expand All @@ -241,16 +207,15 @@ func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer {

// Apply all options
for _, opt := range opts {
opt(s)
opt.applyToSSE(s)
}

return s
}

// NewTestServer creates a test server for testing purposes
// NewTestServer creates a test server for testing purposes.
func NewTestServer(server *MCPServer, opts ...SSEOption) *httptest.Server {
sseServer := NewSSEServer(server, opts...)

testServer := httptest.NewServer(sseServer)
sseServer.baseURL = testServer.URL
return testServer
Expand Down