Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions examples/custom_sse_pattern/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package main

import (
"context"
"fmt"
"log"
"net/http"

"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)

// Custom context function for SSE connections
func customContextFunc(ctx context.Context, r *http.Request) context.Context {
params := server.GetRouteParams(ctx)
log.Printf("SSE Connection Established - Route Parameters: %+v", params)
log.Printf("Request Path: %s", r.URL.Path)
return ctx
}

// Message handler for simulating message sending
func messageHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
// Get channel parameter from context
channel := server.GetRouteParam(ctx, "channel")
log.Printf("Processing Message - Channel Parameter: %s", channel)

if channel == "" {
return mcp.NewToolResultText("Failed to get channel parameter"), nil
}

message := fmt.Sprintf("Message sent to channel: %s", channel)
return mcp.NewToolResultText(message), nil
}

func main() {
// Create MCP Server
mcpServer := server.NewMCPServer("test-server", "1.0.0")

// Register test tool
mcpServer.AddTool(mcp.NewTool("send_message"), messageHandler)

// Create SSE Server with custom route pattern
sseServer := server.NewSSEServer(mcpServer,
server.WithBaseURL("http://localhost:8080"),
server.WithSSEPattern("/:channel/sse"),
server.WithSSEContextFunc(customContextFunc),
)

// Start server
log.Printf("Server started on port :8080")
log.Printf("Test URL: http://localhost:8080/test/sse")
log.Printf("Test URL: http://localhost:8080/news/sse")

if err := sseServer.Start(":8080"); err != nil {
log.Fatalf("Server error: %v", err)
}
}
157 changes: 124 additions & 33 deletions server/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,36 @@ type sseSession struct {
sessionID string
notificationChannel chan mcp.JSONRPCNotification
initialized atomic.Bool
routeParams RouteParams // Store route parameters in session
}

// SSEContextFunc is a function that takes an existing context and the current
// request and returns a potentially modified context based on the request
// content. This can be used to inject context values from headers, for example.
type SSEContextFunc func(ctx context.Context, r *http.Request) context.Context

// RouteParamsKey is the key type for storing route parameters in context
type RouteParamsKey struct{}

// RouteParams stores path parameters
type RouteParams map[string]string

Comment on lines +36 to +41
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Avoid exporting RouteParamsKey and RouteParams unless they are part of the public API

Both symbols are exported, yet they are only consumed inside server package (and the example).
Un‑exporting prevents accidental key collisions across packages and keeps your API surface minimal.

-type RouteParamsKey struct{}
-type RouteParams map[string]string
+type routeParamsKey struct{}
+type RouteParams map[string]string // keep this exported only if other pkgs really need it

If RouteParams itself is also private to server, rename it similarly (type routeParams map[string]string).
Update the occurrences accordingly.

// GetRouteParam retrieves a route parameter from context
func GetRouteParam(ctx context.Context, key string) string {
if params, ok := ctx.Value(RouteParamsKey{}).(RouteParams); ok {
return params[key]
}
return ""
}

// GetRouteParams retrieves all route parameters from context
func GetRouteParams(ctx context.Context) RouteParams {
if params, ok := ctx.Value(RouteParamsKey{}).(RouteParams); ok {
return params
}
return RouteParams{}
}

func (s *sseSession) SessionID() string {
return s.sessionID
}
Expand All @@ -53,18 +76,18 @@ var _ ClientSession = (*sseSession)(nil)
// SSEServer implements a Server-Sent Events (SSE) based MCP server.
// It provides real-time communication capabilities over HTTP using the SSE protocol.
type SSEServer struct {
server *MCPServer
baseURL string
basePath string
useFullURLForMessageEndpoint bool
messageEndpoint string
sseEndpoint string
sessions sync.Map
srv *http.Server
contextFunc SSEContextFunc

keepAlive bool
keepAliveInterval time.Duration
server *MCPServer
baseURL string
basePath string
useFullURLForMessageEndpoint bool
messageEndpoint string
sseEndpoint string
ssePattern string
sessions sync.Map
srv *http.Server
contextFunc SSEContextFunc
keepAlive bool
keepAliveInterval time.Duration
}

// SSEOption defines a function type for configuring SSEServer
Expand Down Expand Up @@ -127,6 +150,13 @@ func WithSSEEndpoint(endpoint string) SSEOption {
}
}

// WithSSEPattern sets the SSE endpoint pattern with route parameters
func WithSSEPattern(pattern string) SSEOption {
return func(s *SSEServer) {
s.ssePattern = pattern
}
}

// WithHTTPServer sets the HTTP server instance
func WithHTTPServer(srv *http.Server) SSEOption {
return func(s *SSEServer) {
Expand All @@ -147,8 +177,7 @@ func WithKeepAlive(keepAlive bool) SSEOption {
}
}

// WithContextFunc sets a function that will be called to customise the context
// to the server using the incoming request.
// WithSSEContextFunc sets a function that will be called to customise the context
func WithSSEContextFunc(fn SSEContextFunc) SSEOption {
return func(s *SSEServer) {
s.contextFunc = fn
Expand All @@ -158,12 +187,12 @@ func WithSSEContextFunc(fn SSEContextFunc) SSEOption {
// NewSSEServer creates a new SSE server instance with the given MCP server and options.
func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer {
s := &SSEServer{
server: server,
sseEndpoint: "/sse",
messageEndpoint: "/message",
useFullURLForMessageEndpoint: true,
keepAlive: false,
keepAliveInterval: 10 * time.Second,
server: server,
sseEndpoint: "/sse",
messageEndpoint: "/message",
useFullURLForMessageEndpoint: true,
keepAlive: false,
keepAliveInterval: 10 * time.Second,
}

// Apply all options
Expand Down Expand Up @@ -241,12 +270,21 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
eventQueue: make(chan string, 100), // Buffer for events
sessionID: sessionID,
notificationChannel: make(chan mcp.JSONRPCNotification, 100),
routeParams: GetRouteParams(r.Context()), // Store route parameters from context
}

s.sessions.Store(sessionID, session)
defer s.sessions.Delete(sessionID)

if err := s.server.RegisterSession(r.Context(), session); err != nil {
// Create base context with session
ctx := s.server.WithContext(r.Context(), session)

// Apply custom context function if set
if s.contextFunc != nil {
ctx = s.contextFunc(ctx, r)
}

if err := s.server.RegisterSession(ctx, session); err != nil {
http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusInternalServerError)
return
}
Expand All @@ -268,7 +306,7 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
}
case <-session.done:
return
case <-r.Context().Done():
case <-ctx.Done():
return
}
}
Expand All @@ -286,14 +324,13 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
session.eventQueue <- fmt.Sprintf(":ping - %s\n\n", time.Now().Format(time.RFC3339))
case <-session.done:
return
case <-r.Context().Done():
case <-ctx.Done():
return
}
}
}()
}


// Send the initial endpoint event
fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", s.GetMessageEndpointForClient(sessionID))
flusher.Flush()
Expand All @@ -305,7 +342,7 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
// Write the event to the response
fmt.Fprint(w, event)
flusher.Flush()
case <-r.Context().Done():
case <-ctx.Done():
close(session.done)
return
}
Expand Down Expand Up @@ -343,8 +380,15 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
}
session := sessionI.(*sseSession)

// Set the client context before handling the message
// Create base context with session
ctx := s.server.WithContext(r.Context(), session)

// Add stored route parameters to context
if len(session.routeParams) > 0 {
ctx = context.WithValue(ctx, RouteParamsKey{}, session.routeParams)
}

// Apply custom context function if set
if s.contextFunc != nil {
ctx = s.contextFunc(ctx, r)
}
Expand All @@ -356,7 +400,7 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
return
}

// Process message through MCPServer
// Process message through MCPServer with the context containing route parameters
response := s.server.HandleMessage(ctx, rawMessage)

// Only send response if there is one (not for notifications)
Expand Down Expand Up @@ -423,6 +467,7 @@ func (s *SSEServer) SendEventToSession(
return fmt.Errorf("event queue full")
}
}

func (s *SSEServer) GetUrlPath(input string) (string, error) {
parse, err := url.Parse(input)
if err != nil {
Expand All @@ -434,6 +479,7 @@ func (s *SSEServer) GetUrlPath(input string) (string, error) {
func (s *SSEServer) CompleteSseEndpoint() string {
return s.baseURL + s.basePath + s.sseEndpoint
}

func (s *SSEServer) CompleteSsePath() string {
path, err := s.GetUrlPath(s.CompleteSseEndpoint())
if err != nil {
Expand All @@ -445,6 +491,7 @@ func (s *SSEServer) CompleteSsePath() string {
func (s *SSEServer) CompleteMessageEndpoint() string {
return s.baseURL + s.basePath + s.messageEndpoint
}

func (s *SSEServer) CompleteMessagePath() string {
path, err := s.GetUrlPath(s.CompleteMessageEndpoint())
if err != nil {
Expand All @@ -456,17 +503,61 @@ func (s *SSEServer) CompleteMessagePath() string {
// ServeHTTP implements the http.Handler interface.
func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
// Use exact path matching rather than Contains
ssePath := s.CompleteSsePath()
if ssePath != "" && path == ssePath {
s.handleSSE(w, r)
return
}
messagePath := s.CompleteMessagePath()

// Handle message endpoint
if messagePath != "" && path == messagePath {
s.handleMessage(w, r)
return
}

// Handle SSE endpoint with route parameters
if s.ssePattern != "" {
// Try pattern matching if pattern is set
fullPattern := s.basePath + s.ssePattern
matches, params := matchPath(fullPattern, path)
if matches {
// Create new context with route parameters
ctx := context.WithValue(r.Context(), RouteParamsKey{}, params)
s.handleSSE(w, r.WithContext(ctx))
return
}
// If pattern is set but doesn't match, return 404
http.NotFound(w, r)
return
}

// If no pattern is set, use the default SSE endpoint
ssePath := s.CompleteSsePath()
if ssePath != "" && path == ssePath {
s.handleSSE(w, r)
return
}

http.NotFound(w, r)
}

// matchPath checks if the given path matches the pattern and extracts parameters
// pattern format: /user/:id/profile/:type
func matchPath(pattern, path string) (bool, RouteParams) {
patternParts := strings.Split(strings.Trim(pattern, "/"), "/")
pathParts := strings.Split(strings.Trim(path, "/"), "/")

if len(patternParts) != len(pathParts) {
return false, nil
}

params := make(RouteParams)
for i, part := range patternParts {
if strings.HasPrefix(part, ":") {
// This is a parameter
paramName := strings.TrimPrefix(part, ":")
params[paramName] = pathParts[i]
} else if part != pathParts[i] {
// Static part doesn't match
return false, nil
}
}

return true, params
}