Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,11 @@ type StreamableHTTPOptions struct {
// If nil, do not log.
Logger *slog.Logger

// TODO(rfindley): file a proposal to export this option, or something equivalent.
configureTransport func(req *http.Request, transport *StreamableServerTransport)
// EventStore enables stream resumption.
//
// If set, EventStore will be used to persist stream events and replay them
// upon stream resumption.
EventStore EventStore
}

// NewStreamableHTTPHandler returns a new [StreamableHTTPHandler].
Expand Down Expand Up @@ -237,12 +240,10 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
transport = &StreamableServerTransport{
SessionID: sessionID,
Stateless: h.opts.Stateless,
EventStore: h.opts.EventStore,
jsonResponse: h.opts.JSONResponse,
logger: h.opts.Logger,
}
if h.opts.configureTransport != nil {
h.opts.configureTransport(req, transport)
}

// To support stateless mode, we initialize the session with a default
// state, so that it doesn't reject subsequent requests.
Expand Down
25 changes: 9 additions & 16 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,13 @@ func TestStreamableTransports(t *testing.T) {

// Start an httptest.Server with the StreamableHTTPHandler, wrapped in a
// cookie-checking middleware.
handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{
opts := &StreamableHTTPOptions{
JSONResponse: test.useJSON,
configureTransport: func(_ *http.Request, transport *StreamableServerTransport) {
if test.replay {
transport.EventStore = NewMemoryEventStore(nil)
}
},
})
}
if test.replay {
opts.EventStore = NewMemoryEventStore(nil)
}
handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, opts)

var (
headerMu sync.Mutex
Expand Down Expand Up @@ -386,9 +385,7 @@ func testClientReplay(t *testing.T, test clientReplayTest) {
})

realServer := httptest.NewServer(mustNotPanic(t, NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{
configureTransport: func(_ *http.Request, t *StreamableServerTransport) {
t.EventStore = NewMemoryEventStore(nil) // necessary for replay
},
EventStore: NewMemoryEventStore(nil), // necessary for replay
})))
t.Cleanup(func() {
t.Log("Closing real HTTP server")
Expand Down Expand Up @@ -567,9 +564,7 @@ func TestServerInitiatedSSE(t *testing.T) {
// However, it shouldn't be necessary to use replay here, as we should be
// guaranteed that the standalone SSE stream is started by the time the
// client is connected.
configureTransport: func(_ *http.Request, transport *StreamableServerTransport) {
transport.EventStore = NewMemoryEventStore(nil)
},
EventStore: NewMemoryEventStore(nil),
}
httpServer := httptest.NewServer(mustNotPanic(t, NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, opts)))
defer httpServer.Close()
Expand Down Expand Up @@ -942,9 +937,7 @@ func TestStreamableServerTransport(t *testing.T) {

opts := &StreamableHTTPOptions{}
if test.replay {
opts.configureTransport = func(_ *http.Request, t *StreamableServerTransport) {
t.EventStore = NewMemoryEventStore(nil)
}
opts.EventStore = NewMemoryEventStore(nil)
}
// Start the streamable handler.
handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, opts)
Expand Down