diff --git a/mcp/client.go b/mcp/client.go
index fb77b3f3..d7e3ae5a 100644
--- a/mcp/client.go
+++ b/mcp/client.go
@@ -11,6 +11,7 @@ import (
 	"iter"
 	"slices"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/google/jsonschema-go/jsonschema"
@@ -177,7 +178,11 @@ func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptio
 // Call [ClientSession.Close] to close the connection, or await server
 // termination with [ClientSession.Wait].
 type ClientSession struct {
-	onClose func()
+	// Ensure that onClose is called at most once.
+	// We defensively use an atomic CompareAndSwap rather than a sync.Once, in case the
+	// onClose callback triggers a re-entrant call to Close.
+	calledOnClose atomic.Bool
+	onClose       func()
 
 	conn   *jsonrpc2.Connection
 	client *Client
@@ -205,6 +210,8 @@ func (cs *ClientSession) ID() string {
 // Close performs a graceful close of the connection, preventing new requests
 // from being handled, and waiting for ongoing requests to return. Close then
 // terminates the connection.
+//
+// Close is idempotent and concurrency safe.
 func (cs *ClientSession) Close() error {
 	// Note: keepaliveCancel access is safe without a mutex because:
 	// 1. keepaliveCancel is only written once during startKeepalive (happens-before all Close calls)
@@ -216,7 +223,7 @@ func (cs *ClientSession) Close() error {
 	}
 
 	err := cs.conn.Close()
-	if cs.onClose != nil {
+	if cs.onClose != nil && cs.calledOnClose.CompareAndSwap(false, true) {
 		cs.onClose()
 	}
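A minimal, self-contained sketch of the at-most-once close pattern above (not part of the patch; the `closer` type and `main` are illustrative only). The atomic CompareAndSwap is preferred over sync.Once because Once.Do does not return until the callback returns, so a re-entrant Close from inside the callback would deadlock; with the CAS, the nested call is simply a no-op.

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// closer mirrors ClientSession/ServerSession: onClose runs at most once,
// even when the callback itself calls Close again.
type closer struct {
	calledOnClose atomic.Bool
	onClose       func()
}

func (c *closer) Close() error {
	// The first caller flips false->true and runs onClose; everyone else skips it.
	if c.onClose != nil && c.calledOnClose.CompareAndSwap(false, true) {
		c.onClose()
	}
	return nil
}

func main() {
	c := &closer{}
	c.onClose = func() {
		fmt.Println("cleanup runs exactly once")
		c.Close() // re-entrant call: the CAS already succeeded, so this is a no-op
	}
	c.Close()
	c.Close() // idempotent
}
```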
diff --git a/mcp/server.go b/mcp/server.go
index 17b68c44..29be8ff1 100644
--- a/mcp/server.go
+++ b/mcp/server.go
@@ -19,6 +19,7 @@ import (
 	"reflect"
 	"slices"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/google/jsonschema-go/jsonschema"
@@ -825,7 +826,7 @@ func (s *Server) disconnect(cc *ServerSession) {
 type ServerSessionOptions struct {
 	State *ServerSessionState
 
-	onClose func()
+	onClose func() // used to clean up associated resources
 }
 
 // Connect connects the MCP server over the given transport and starts handling
@@ -920,7 +921,11 @@ func newServerRequest[P Params](ss *ServerSession, params P) *ServerRequest[P] {
 // Call [ServerSession.Close] to close the connection, or await client
 // termination with [ServerSession.Wait].
 type ServerSession struct {
-	onClose func()
+	// Ensure that onClose is called at most once.
+	// We defensively use an atomic CompareAndSwap rather than a sync.Once, in case the
+	// onClose callback triggers a re-entrant call to Close.
+	calledOnClose atomic.Bool
+	onClose       func()
 
 	server *Server
 	conn   *jsonrpc2.Connection
@@ -1185,6 +1190,8 @@ func (ss *ServerSession) setLevel(_ context.Context, params *SetLoggingLevelPara
 // Close performs a graceful shutdown of the connection, preventing new
 // requests from being handled, and waiting for ongoing requests to return.
 // Close then terminates the connection.
+//
+// Close is idempotent and concurrency safe.
 func (ss *ServerSession) Close() error {
 	if ss.keepaliveCancel != nil {
 		// Note: keepaliveCancel access is safe without a mutex because:
@@ -1196,7 +1203,7 @@ func (ss *ServerSession) Close() error {
 	}
 
 	err := ss.conn.Close()
-	if ss.onClose != nil {
+	if ss.onClose != nil && ss.calledOnClose.CompareAndSwap(false, true) {
 		ss.onClose()
 	}
diff --git a/mcp/streamable.go b/mcp/streamable.go
index 8773ea70..d8ce45e7 100644
--- a/mcp/streamable.go
+++ b/mcp/streamable.go
@@ -12,6 +12,7 @@ import (
 	"fmt"
 	"io"
 	"log/slog"
+	"maps"
 	"math"
 	"math/rand/v2"
 	"net/http"
@@ -40,12 +41,76 @@ type StreamableHTTPHandler struct {
 	getServer func(*http.Request) *Server
 	opts      StreamableHTTPOptions
 
-	onTransportDeletion func(sessionID string) // for testing only
+	onTransportDeletion func(sessionID string) // for testing
 
-	mu sync.Mutex
-	// TODO: we should store the ServerSession along with the transport, because
-	// we need to cancel keepalive requests when closing the transport.
-	transports map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header)
+	mu       sync.Mutex
+	sessions map[string]*sessionInfo // keyed by session ID
+}
+
+type sessionInfo struct {
+	session   *ServerSession
+	transport *StreamableServerTransport
+
+	// If timeout is set, automatically close the session after an idle period.
+	timeout time.Duration
+	timerMu sync.Mutex
+	refs    int // reference count
+	timer   *time.Timer
+}
+
+// startPOST signals that a POST request for this session is starting (which
+// carries a client->server message), pausing the session timeout if it was
+// running.
+//
+// TODO: we may want to also pause the timer when resuming non-standalone SSE
+// streams, but that is tricky to implement. Clients should generally make
+// keepalive pings if they want to keep the session live.
+func (i *sessionInfo) startPOST() {
+	if i.timeout <= 0 {
+		return
+	}
+
+	i.timerMu.Lock()
+	defer i.timerMu.Unlock()
+
+	if i.timer == nil {
+		return // timer stopped permanently
+	}
+	if i.refs == 0 {
+		i.timer.Stop()
+	}
+	i.refs++
+}
+
+// endPOST signals that a request for this session is ending, starting the
+// timeout if there are no other requests running.
+func (i *sessionInfo) endPOST() {
+	if i.timeout <= 0 {
+		return
+	}
+
+	i.timerMu.Lock()
+	defer i.timerMu.Unlock()
+
+	if i.timer == nil {
+		return // timer stopped permanently
+	}
+
+	i.refs--
+	assert(i.refs >= 0, "negative ref count")
+	if i.refs == 0 {
+		i.timer.Reset(i.timeout)
+	}
+}
+
+// stopTimer stops the inactivity timer permanently.
+func (i *sessionInfo) stopTimer() {
+	i.timerMu.Lock()
+	defer i.timerMu.Unlock()
+	if i.timer != nil {
+		i.timer.Stop()
+		i.timer = nil
+	}
 }
 
 // StreamableHTTPOptions configures the StreamableHTTPHandler.
@@ -77,6 +142,14 @@ type StreamableHTTPOptions struct {
 	// If set, EventStore will be used to persist stream events and replay them
 	// upon stream resumption.
 	EventStore EventStore
+
+	// SessionTimeout configures a timeout for idle sessions.
+	//
+	// When sessions receive no new HTTP requests from the client for this
+	// duration, they are automatically closed.
+	//
+	// If SessionTimeout is the zero value, idle sessions are never closed.
+	SessionTimeout time.Duration
 }
 
 // NewStreamableHTTPHandler returns a new [StreamableHTTPHandler].
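A hedged usage sketch of the new option (not part of the patch): only `NewStreamableHTTPHandler` and `StreamableHTTPOptions.SessionTimeout` come from this change; the import path and the `mcp.Implementation` fields are assumptions about the SDK's public API.

```go
package main

import (
	"log"
	"net/http"
	"time"

	"github.com/modelcontextprotocol/go-sdk/mcp" // assumed import path
)

func main() {
	server := mcp.NewServer(&mcp.Implementation{Name: "example", Version: "v0.1.0"}, nil)

	// Sessions that receive no HTTP requests from their client for a minute
	// are closed automatically. Leaving SessionTimeout at its zero value
	// preserves the previous behavior: idle sessions are never closed.
	handler := mcp.NewStreamableHTTPHandler(
		func(*http.Request) *mcp.Server { return server },
		&mcp.StreamableHTTPOptions{SessionTimeout: time.Minute},
	)

	log.Fatal(http.ListenAndServe("localhost:8080", handler))
}
```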
@@ -86,8 +159,8 @@ type StreamableHTTPOptions struct {
 // If getServer returns nil, a 400 Bad Request will be served.
 func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *StreamableHTTPOptions) *StreamableHTTPHandler {
 	h := &StreamableHTTPHandler{
-		getServer:  getServer,
-		transports: make(map[string]*StreamableServerTransport),
+		getServer: getServer,
+		sessions:  make(map[string]*sessionInfo),
 	}
 	if opts != nil {
 		h.opts = *opts
@@ -100,7 +173,7 @@ func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *Strea
 	return h
 }
 
-// closeAll closes all ongoing sessions.
+// closeAll closes all ongoing sessions, for tests.
 //
 // TODO(rfindley): investigate the best API for callers to configure their
 // session lifecycle. (?)
@@ -108,12 +181,19 @@ func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *Strea
 // Should we allow passing in a session store? That would allow the handler to
 // be stateless.
 func (h *StreamableHTTPHandler) closeAll() {
+	// TODO: if we ever expose this outside of tests, we'll need to do better
+	// than simply collecting sessions while holding the lock: we need to prevent
+	// new sessions from being added.
+	//
+	// Currently, sessions remove themselves from h.sessions when closed, so we
+	// can't call Close while holding the lock.
 	h.mu.Lock()
-	defer h.mu.Unlock()
-	for _, s := range h.transports {
-		s.connection.Close()
+	sessionInfos := slices.Collect(maps.Values(h.sessions))
+	h.sessions = nil
+	h.mu.Unlock()
+	for _, s := range sessionInfos {
+		s.session.Close()
 	}
-	h.transports = nil
 }
 
 func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
@@ -144,12 +224,12 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
 	}
 
 	sessionID := req.Header.Get(sessionIDHeader)
-	var transport *StreamableServerTransport
+	var sessInfo *sessionInfo
 	if sessionID != "" {
 		h.mu.Lock()
-		transport = h.transports[sessionID]
+		sessInfo = h.sessions[sessionID]
 		h.mu.Unlock()
-		if transport == nil && !h.opts.Stateless {
+		if sessInfo == nil && !h.opts.Stateless {
 			// Unless we're in 'stateless' mode, which doesn't perform any Session-ID
 			// validation, we require that the session ID matches a known session.
 			//
@@ -164,11 +244,10 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
 			http.Error(w, "Bad Request: DELETE requires an Mcp-Session-Id header", http.StatusBadRequest)
 			return
 		}
-		if transport != nil { // transport may be nil in stateless mode
-			h.mu.Lock()
-			delete(h.transports, transport.SessionID)
-			h.mu.Unlock()
-			transport.connection.Close()
+		if sessInfo != nil { // sessInfo may be nil in stateless mode
+			// Closing the session also removes it from h.sessions, due to the
+			// onClose callback.
+			sessInfo.session.Close()
 		}
 		w.WriteHeader(http.StatusNoContent)
 		return
@@ -225,7 +304,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
 		return
 	}
 
-	if transport == nil {
+	if sessInfo == nil {
 		server := h.getServer(req)
 		if server == nil {
 			// The getServer argument to NewStreamableHTTPHandler returned nil.
@@ -237,7 +316,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
 			// existing transport.
 			sessionID = server.opts.GetSessionID()
 		}
-		transport = &StreamableServerTransport{
+		transport := &StreamableServerTransport{
 			SessionID:  sessionID,
 			Stateless:  h.opts.Stateless,
 			EventStore: h.opts.EventStore,
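For illustration of the DELETE handling above (not part of the patch), this is roughly what explicit session termination looks like on the wire: a DELETE carrying the session's Mcp-Session-Id header, which the handler answers with 204 No Content after closing the session. The endpoint path and session ID below are placeholders.

```go
package main

import (
	"fmt"
	"log"
	"net/http"
)

func main() {
	// Placeholders: a real client uses the endpoint it connected to and the
	// Mcp-Session-Id value it received during initialization.
	req, err := http.NewRequest(http.MethodDelete, "http://localhost:8080/mcp", nil)
	if err != nil {
		log.Fatal(err)
	}
	req.Header.Set("Mcp-Session-Id", "123")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	// The handler closes the session (which removes it from h.sessions via
	// the onClose callback) and responds 204.
	fmt.Println(resp.Status) // expect "204 No Content"
}
```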
@@ -301,10 +380,13 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
 			connectOpts = &ServerSessionOptions{
 				onClose: func() {
 					h.mu.Lock()
-					delete(h.transports, transport.SessionID)
-					h.mu.Unlock()
-					if h.onTransportDeletion != nil {
-						h.onTransportDeletion(transport.SessionID)
+					defer h.mu.Unlock()
+					if info, ok := h.sessions[transport.SessionID]; ok {
+						info.stopTimer()
+						delete(h.sessions, transport.SessionID)
+						if h.onTransportDeletion != nil {
+							h.onTransportDeletion(transport.SessionID)
+						}
 					}
 				},
 			}
@@ -313,23 +395,44 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
 		// Pass req.Context() here, to allow middleware to add context values.
 		// The context is detached in the jsonrpc2 library when handling the
 		// long-running stream.
-		ss, err := server.Connect(req.Context(), transport, connectOpts)
+		session, err := server.Connect(req.Context(), transport, connectOpts)
 		if err != nil {
 			http.Error(w, "failed connection", http.StatusInternalServerError)
 			return
 		}
+		sessInfo = &sessionInfo{
+			session:   session,
+			transport: transport,
+		}
+
 		if h.opts.Stateless {
 			// Stateless mode: close the session when the request exits.
-			defer ss.Close() // close the fake session after handling the request
+			defer session.Close() // close the fake session after handling the request
 		} else {
 			// Otherwise, save the transport so that it can be reused
+
+			// Clean up the session when it times out.
+			//
+			// Note that the timer here may fire multiple times, but
+			// sessInfo.session.Close is idempotent.
+			if h.opts.SessionTimeout > 0 {
+				sessInfo.timeout = h.opts.SessionTimeout
+				sessInfo.timer = time.AfterFunc(sessInfo.timeout, func() {
+					sessInfo.session.Close()
+				})
+			}
 			h.mu.Lock()
-			h.transports[transport.SessionID] = transport
+			h.sessions[transport.SessionID] = sessInfo
 			h.mu.Unlock()
 		}
 	}
 
-	transport.ServeHTTP(w, req)
+	if req.Method == http.MethodPost {
+		sessInfo.startPOST()
+		defer sessInfo.endPOST()
+	}
+
+	sessInfo.transport.ServeHTTP(w, req)
 }
 
 // A StreamableServerTransport implements the server side of the MCP streamable
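A distilled, standalone sketch of the reference counting that startPOST/endPOST perform above (the `idleTimer` type here is hypothetical, not in the patch): the timer is stopped while any POST is in flight and only restarted once the last one finishes, so a long-running request cannot be cut off by the idle timeout.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// idleTimer fires onIdle only after timeout elapses with no request in flight.
type idleTimer struct {
	timeout time.Duration

	mu    sync.Mutex
	refs  int // in-flight requests
	timer *time.Timer
}

func newIdleTimer(timeout time.Duration, onIdle func()) *idleTimer {
	return &idleTimer{timeout: timeout, timer: time.AfterFunc(timeout, onIdle)}
}

// begin pauses the countdown while a request is being handled.
func (t *idleTimer) begin() {
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.refs == 0 {
		t.timer.Stop()
	}
	t.refs++
}

// end restarts the countdown once the last in-flight request finishes.
func (t *idleTimer) end() {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.refs--
	if t.refs == 0 {
		t.timer.Reset(t.timeout)
	}
}

func main() {
	done := make(chan struct{})
	t := newIdleTimer(50*time.Millisecond, func() {
		fmt.Println("idle: closing session")
		close(done)
	})

	t.begin()
	time.Sleep(200 * time.Millisecond) // a slow request; the timer stays paused
	t.end()

	<-done // fires about 50ms after the last request ends
}
```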
@@ -1383,9 +1486,12 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
 		go c.handleJSON(requestSummary, resp)
 
 	case "text/event-stream":
-		jsonReq, _ := msg.(*jsonrpc.Request)
+		var forCall *jsonrpc.Request
+		if jsonReq, ok := msg.(*jsonrpc.Request); ok && jsonReq.IsCall() {
+			forCall = jsonReq
+		}
 		// TODO: should we cancel this logical SSE request if/when jsonReq is canceled?
-		go c.handleSSE(requestSummary, resp, false, jsonReq)
+		go c.handleSSE(requestSummary, resp, false, forCall)
 
 	default:
 		resp.Body.Close()
@@ -1435,9 +1541,9 @@ func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Resp
 // handleSSE manages the lifecycle of an SSE connection. It can be either
 // persistent (for the main GET listener) or temporary (for a POST response).
 //
-// If forReq is set, it is the request that initiated the stream, and the
+// If forCall is set, it is the call that initiated the stream, and the
 // stream is complete when we receive its response.
-func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *http.Response, persistent bool, forReq *jsonrpc2.Request) {
+func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *http.Response, persistent bool, forCall *jsonrpc2.Request) {
 	resp := initialResp
 	var lastEventID string
 	for {
@@ -1447,7 +1553,7 @@ func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *htt
 		// Eventually, if we don't get the response, we should stop trying and
 		// fail the request.
 		if resp != nil {
-			eventID, clientClosed := c.processStream(requestSummary, resp, forReq)
+			eventID, clientClosed := c.processStream(requestSummary, resp, forCall)
 			lastEventID = eventID
 
 			// If the connection was closed by the client, we're done.
@@ -1510,11 +1616,11 @@ func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *htt
 // incoming channel. It returns the ID of the last processed event and a flag
 // indicating if the connection was closed by the client. If resp is nil, it
 // returns "", false.
-func (c *streamableClientConn) processStream(requestSummary string, resp *http.Response, forReq *jsonrpc.Request) (lastEventID string, clientClosed bool) {
+func (c *streamableClientConn) processStream(requestSummary string, resp *http.Response, forCall *jsonrpc.Request) (lastEventID string, clientClosed bool) {
 	defer resp.Body.Close()
 	for evt, err := range scanEvents(resp.Body) {
 		if err != nil {
-			return lastEventID, false
+			break
 		}
 
 		if evt.ID != "" {
@@ -1529,10 +1635,10 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R
 
 		select {
 		case c.incoming <- msg:
-			if jsonResp, ok := msg.(*jsonrpc.Response); ok && forReq != nil {
+			if jsonResp, ok := msg.(*jsonrpc.Response); ok && forCall != nil {
 				// TODO: we should never get a response when forReq is nil (the standalone SSE request).
 				// We should detect this case.
-				if jsonResp.ID == forReq.ID {
+				if jsonResp.ID == forCall.ID {
 					return "", true
 				}
 			}
@@ -1542,7 +1648,20 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R
 		}
 	}
 	// The loop finished without an error, indicating the server closed the stream.
-	return "", false
+	//
+	// If the lastEventID is "", the stream is not retryable and we should
+	// report a synthetic error for the call.
+	if lastEventID == "" && forCall != nil {
+		errmsg := &jsonrpc2.Response{
+			ID:    forCall.ID,
+			Error: fmt.Errorf("request terminated without response"),
+		}
+		select {
+		case c.incoming <- errmsg:
+		case <-c.done:
+		}
+	}
+	return lastEventID, false
 }
 
 // reconnect handles the logic of retrying a connection with an exponential
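What a caller observes with this change, as a hedged sketch (the helper function below is illustrative and the import path is an assumption): if the server ends a call's SSE stream without ever delivering its response and without event IDs to resume from, the call now fails promptly with the synthetic error above rather than hanging.

```go
package example

import (
	"context"
	"log"
	"strings"

	"github.com/modelcontextprotocol/go-sdk/mcp" // assumed import path
)

// listToolsOrFailFast demonstrates the error surfaced for an unresumable,
// response-less stream. Matching on the message is a change detector, as the
// new test below also notes.
func listToolsOrFailFast(ctx context.Context, session *mcp.ClientSession) {
	if _, err := session.ListTools(ctx, nil); err != nil {
		if strings.Contains(err.Error(), "terminated without response") {
			log.Printf("stream ended with no response and cannot be resumed: %v", err)
			return
		}
		log.Printf("ListTools failed: %v", err)
	}
}
```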
diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go
index faca04c6..4a4f5c65 100644
--- a/mcp/streamable_client_test.go
+++ b/mcp/streamable_client_test.go
@@ -34,7 +34,6 @@ type streamableResponse struct {
 	body                string // or ""
 	optional            bool   // if set, request need not be sent
 	wantProtocolVersion string // if "", unchecked
-	callback            func() // if set, called after the request is handled
 }
 
 type fakeResponses map[streamableRequestKey]*streamableResponse
@@ -96,9 +95,6 @@ func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Reques
 		http.Error(w, "no response", http.StatusInternalServerError)
 		return
 	}
-	if resp.callback != nil {
-		defer resp.callback()
-	}
 	for k, v := range resp.header {
 		w.Header().Set(k, v)
 	}
@@ -411,3 +407,38 @@ func TestStreamableClientStrictness(t *testing.T) {
 		})
 	}
 }
+
+func TestStreamableClientUnresumableRequest(t *testing.T) {
+	// This test verifies that the client fails fast when making a request that
+	// is unresumable, because it does not contain any events.
+	ctx := context.Background()
+	fake := &fakeStreamableServer{
+		t: t,
+		responses: fakeResponses{
+			{"POST", "", methodInitialize}: {
+				header: header{
+					"Content-Type":  "text/event-stream",
+					sessionIDHeader: "123",
+				},
+				body: "",
+			},
+			{"DELETE", "123", ""}: {optional: true},
+		},
+	}
+	httpServer := httptest.NewServer(fake)
+	defer httpServer.Close()
+
+	transport := &StreamableClientTransport{Endpoint: httpServer.URL}
+	client := NewClient(testImpl, nil)
+	cs, err := client.Connect(ctx, transport, nil)
+	if err == nil {
+		cs.Close()
+		t.Fatalf("Connect succeeded unexpectedly")
+	}
+	// This may be a bit of a change detector, but for now check that we're
+	// actually exercising the early failure codepath.
+	msg := "terminated without response"
+	if !strings.Contains(err.Error(), msg) {
+		t.Errorf("Connect: got error %v, want containing %q", err, msg)
+	}
+}
diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go
index a2e5c73c..0f38a0f4 100644
--- a/mcp/streamable_test.go
+++ b/mcp/streamable_test.go
@@ -20,6 +20,7 @@ import (
 	"net/url"
 	"os"
 	"runtime"
+	"slices"
 	"sort"
 	"strings"
 	"sync"
@@ -581,8 +582,8 @@ func TestServerTransportCleanup(t *testing.T) {
 	}
 
 	handler.mu.Lock()
-	if len(handler.transports) != 0 {
-		t.Errorf("want empty transports map, find %v entries from handler's transports map", len(handler.transports))
+	if len(handler.sessions) != 0 {
+		t.Errorf("want empty sessions map, found %v entries in handler's sessions map", len(handler.sessions))
 	}
 	handler.mu.Unlock()
 }
@@ -1623,7 +1624,93 @@ func TestStreamableClientContextPropagation(t *testing.T) {
 	case <-time.After(100 * time.Millisecond):
 		t.Error("Connection context was not cancelled when parent was cancelled")
 	}
+}
+
+func TestStreamableSessionTimeout(t *testing.T) {
+	// TODO: this test relies on timing and may be flaky.
+	// Fixing with testing/synctest is challenging because it uses real I/O (via
+	// httptest.NewServer).
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	server := NewServer(testImpl, nil)
+
+	deleted := make(chan string, 1)
+	handler := NewStreamableHTTPHandler(
+		func(req *http.Request) *Server { return server },
+		&StreamableHTTPOptions{
+			SessionTimeout: 50 * time.Millisecond,
+		},
+	)
+	handler.onTransportDeletion = func(sessionID string) {
+		deleted <- sessionID
+	}
+
+	httpServer := httptest.NewServer(mustNotPanic(t, handler))
+	defer httpServer.Close()
+
+	// Connect a client to create a session.
+	client := NewClient(testImpl, nil)
+	session, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, nil)
+	if err != nil {
+		t.Fatalf("client.Connect() failed: %v", err)
+	}
+	defer session.Close()
+
+	sessionID := session.ID()
+	if sessionID == "" {
+		t.Fatal("client session has empty ID")
+	}
+
+	// Verify the session exists on the server.
+	serverSessions := slices.Collect(server.Sessions())
+	if len(serverSessions) != 1 {
+		t.Fatalf("got %d sessions, want 1", len(serverSessions))
+	}
+	if got := serverSessions[0].ID(); got != sessionID {
+		t.Fatalf("server session is %q, want %q", got, sessionID)
+	}
+
+	// Test that (possibly concurrent) requests keep the session alive.
+	//
+	// Spin up two goroutines, each making a request every 10ms. These requests
+	// should keep the server from timing out.
+	var wg sync.WaitGroup
+	wg.Add(2)
+	for range 2 {
+		go func() {
+			defer wg.Done()
+
+			for range 20 {
+				if _, err := session.ListTools(ctx, nil); err != nil {
+					t.Errorf("ListTools failed: %v", err)
+				}
+				time.Sleep(10 * time.Millisecond)
+			}
+		}()
+	}
+
+	wg.Wait()
+
+	// Wait for the session to be cleaned up.
+	select {
+	case deletedID := <-deleted:
+		if deletedID != sessionID {
+			t.Errorf("deleted session ID = %q, want %q", deletedID, sessionID)
+		}
+	case <-ctx.Done():
+		t.Fatal("timed out waiting for session cleanup")
+	}
+
+	// Verify the session is gone from both handler and server.
+	handler.mu.Lock()
+	if len(handler.sessions) != 0 {
+		t.Errorf("handler.sessions is not empty; length %d", len(handler.sessions))
+	}
+	if ss := slices.Collect(server.Sessions()); len(ss) != 0 {
+		t.Errorf("server.Sessions() is not empty; length %d", len(ss))
+	}
+	handler.mu.Unlock()
 }
 
 // mustNotPanic is a helper to enforce that test handlers do not panic (see