Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 52 additions & 13 deletions server/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"strings"
"sync"
"sync/atomic"
"time"

"github.com/google/uuid"
"github.com/mark3labs/mcp-go/mcp"
Expand Down Expand Up @@ -52,15 +53,18 @@ var _ ClientSession = (*sseSession)(nil)
// SSEServer implements a Server-Sent Events (SSE) based MCP server.
// It provides real-time communication capabilities over HTTP using the SSE protocol.
type SSEServer struct {
server *MCPServer
baseURL string
basePath string
messageEndpoint string
useFullURLForMessageEndpoint bool
sseEndpoint string
sessions sync.Map
srv *http.Server
contextFunc SSEContextFunc
server *MCPServer
baseURL string
basePath string
useFullURLForMessageEndpoint bool
messageEndpoint string
sseEndpoint string
sessions sync.Map
srv *http.Server
contextFunc SSEContextFunc

keepAlive bool
keepAliveInterval time.Duration
}

// SSEOption defines a function type for configuring SSEServer
Expand Down Expand Up @@ -130,6 +134,19 @@ func WithHTTPServer(srv *http.Server) SSEOption {
}
}

func WithKeepAliveInterval(keepAliveInterval time.Duration) SSEOption {
return func(s *SSEServer) {
s.keepAlive = true
s.keepAliveInterval = keepAliveInterval
}
}

func WithKeepAlive(keepAlive bool) SSEOption {
return func(s *SSEServer) {
s.keepAlive = keepAlive
}
}

// WithContextFunc sets a function that will be called to customise the context
// to the server using the incoming request.
func WithSSEContextFunc(fn SSEContextFunc) SSEOption {
Expand All @@ -141,10 +158,12 @@ func WithSSEContextFunc(fn SSEContextFunc) SSEOption {
// NewSSEServer creates a new SSE server instance with the given MCP server and options.
func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer {
s := &SSEServer{
server: server,
sseEndpoint: "/sse",
messageEndpoint: "/messages",
useFullURLForMessageEndpoint: true,
server: server,
sseEndpoint: "/sse",
messageEndpoint: "/message",
useFullURLForMessageEndpoint: true,
keepAlive: false,
keepAliveInterval: 10 * time.Second,
}

// Apply all options
Expand Down Expand Up @@ -255,6 +274,26 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
}
}()

// Start keep alive : ping
if s.keepAlive {
go func() {
ticker := time.NewTicker(s.keepAliveInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
//: ping - 2025-03-27 07:44:38.682659+00:00
session.eventQueue <- fmt.Sprintf(":ping - %s\n\n", time.Now().Format(time.RFC3339))
case <-session.done:
return
case <-r.Context().Done():
return
}
}
}()
}


// Send the initial endpoint event
fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", s.GetMessageEndpointForClient(sessionID))
flusher.Flush()
Expand Down