Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 42 additions & 16 deletions server/streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"mime"
"net/http"
"net/http/httptest"
"os"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -93,6 +94,15 @@ func WithLogger(logger util.Logger) StreamableHTTPOption {
}
}

// WithTLSCert sets the TLS certificate and key files for HTTPS support.
// Both certFile and keyFile must be provided to enable TLS.
func WithTLSCert(certFile, keyFile string) StreamableHTTPOption {
return func(s *StreamableHTTPServer) {
s.tlsCertFile = certFile
s.tlsKeyFile = keyFile
}
}

// StreamableHTTPServer implements a Streamable-http based MCP server.
// It communicates with clients over HTTP protocol, supporting both direct HTTP responses, and SSE streams.
// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http
Expand Down Expand Up @@ -131,6 +141,9 @@ type StreamableHTTPServer struct {
listenHeartbeatInterval time.Duration
logger util.Logger
sessionLogLevels *sessionLogLevelsStore

tlsCertFile string
tlsKeyFile string
}

// NewStreamableHTTPServer creates a new streamable-http server instance
Expand Down Expand Up @@ -188,6 +201,19 @@ func (s *StreamableHTTPServer) Start(addr string) error {
srv := s.httpServer
s.mu.Unlock()

if s.tlsCertFile != "" || s.tlsKeyFile != "" {
if s.tlsCertFile == "" || s.tlsKeyFile == "" {
return fmt.Errorf("both TLS cert and key must be provided")
}
if _, err := os.Stat(s.tlsCertFile); err != nil {
return fmt.Errorf("failed to find TLS certificate file: %w", err)
}
if _, err := os.Stat(s.tlsKeyFile); err != nil {
return fmt.Errorf("failed to find TLS key file: %w", err)
}
return srv.ListenAndServeTLS(s.tlsCertFile, s.tlsKeyFile)
}

return srv.ListenAndServe()
}

Expand Down Expand Up @@ -237,9 +263,9 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
}

// Check if this is a sampling response (has result/error but no method)
isSamplingResponse := jsonMessage.Method == "" && jsonMessage.ID != nil &&
isSamplingResponse := jsonMessage.Method == "" && jsonMessage.ID != nil &&
(jsonMessage.Result != nil || jsonMessage.Error != nil)

isInitializeRequest := jsonMessage.Method == mcp.MethodInitialize

// Handle sampling responses separately
Expand Down Expand Up @@ -390,7 +416,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
return
}
defer s.server.UnregisterSession(r.Context(), sessionID)

// Register session for sampling response delivery
s.activeSessions.Store(sessionID, session)
defer s.activeSessions.Delete(sessionID)
Expand Down Expand Up @@ -743,18 +769,18 @@ type streamableHttpSession struct {
logLevels *sessionLogLevelsStore

// Sampling support for bidirectional communication
samplingRequestChan chan samplingRequestItem // server -> client sampling requests
samplingRequests sync.Map // requestID -> pending sampling request context
requestIDCounter atomic.Int64 // for generating unique request IDs
samplingRequestChan chan samplingRequestItem // server -> client sampling requests
samplingRequests sync.Map // requestID -> pending sampling request context
requestIDCounter atomic.Int64 // for generating unique request IDs
}

func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, levels *sessionLogLevelsStore) *streamableHttpSession {
s := &streamableHttpSession{
sessionID: sessionID,
notificationChannel: make(chan mcp.JSONRPCNotification, 100),
tools: toolStore,
logLevels: levels,
samplingRequestChan: make(chan samplingRequestItem, 10),
sessionID: sessionID,
notificationChannel: make(chan mcp.JSONRPCNotification, 100),
tools: toolStore,
logLevels: levels,
samplingRequestChan: make(chan samplingRequestItem, 10),
}
return s
}
Expand Down Expand Up @@ -810,21 +836,21 @@ var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil)
func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
// Generate unique request ID
requestID := s.requestIDCounter.Add(1)

// Create response channel for this specific request
responseChan := make(chan samplingResponseItem, 1)

// Create the sampling request item
samplingRequest := samplingRequestItem{
requestID: requestID,
request: request,
response: responseChan,
}

// Store the pending request
s.samplingRequests.Store(requestID, responseChan)
defer s.samplingRequests.Delete(requestID)

// Send the sampling request via the channel (non-blocking)
select {
case s.samplingRequestChan <- samplingRequest:
Expand All @@ -834,7 +860,7 @@ func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp
default:
return nil, fmt.Errorf("sampling request queue is full - server overloaded")
}

// Wait for response or context cancellation
select {
case response := <-responseChan:
Expand Down
20 changes: 20 additions & 0 deletions server/streamable_http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,26 @@ func TestStreamableHTTP_HeaderPassthrough(t *testing.T) {
}
}

func TestStreamableHTTPServer_TLS(t *testing.T) {
t.Run("TLS options are set correctly", func(t *testing.T) {
mcpServer := NewMCPServer("test-mcp-server", "1.0.0")
certFile := "/path/to/cert.pem"
keyFile := "/path/to/key.pem"

server := NewStreamableHTTPServer(
mcpServer,
WithTLSCert(certFile, keyFile),
)

if server.tlsCertFile != certFile {
t.Errorf("Expected tlsCertFile to be %s, got %s", certFile, server.tlsCertFile)
}
if server.tlsKeyFile != keyFile {
t.Errorf("Expected tlsKeyFile to be %s, got %s", keyFile, server.tlsKeyFile)
}
})
}

func postJSON(url string, bodyObject any) (*http.Response, error) {
jsonBody, _ := json.Marshal(bodyObject)
req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody))
Expand Down
1 change: 1 addition & 0 deletions www/docs/pages/servers/basics.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ Configure transport-specific options:
httpServer := server.NewStreamableHTTPServer(s,
server.WithEndpointPath("/mcp"),
server.WithStateless(true),
server.WithTLSCert("/path/to/cert.pem", "/path/to/key.pem"),
)

if err := httpServer.Start(":8080"); err != nil {
Expand Down