diff --git a/mcp/client_example_test.go b/mcp/client_example_test.go index 225dabac..cc7146c2 100644 --- a/mcp/client_example_test.go +++ b/mcp/client_example_test.go @@ -42,9 +42,11 @@ func Example_roots() { // Connect the server and client... t1, t2 := mcp.NewInMemoryTransports() - if _, err := s.Connect(ctx, t1, nil); err != nil { + serverSession, err := s.Connect(ctx, t1, nil) + if err != nil { log.Fatal(err) } + defer serverSession.Close() clientSession, err := c.Connect(ctx, t2, nil) if err != nil { diff --git a/mcp/cmd_test.go b/mcp/cmd_test.go index cbaadcb0..ebf6b592 100644 --- a/mcp/cmd_test.go +++ b/mcp/cmd_test.go @@ -97,6 +97,8 @@ func TestServerRunContextCancel(t *testing.T) { if err != nil { t.Fatal(err) } + t.Cleanup(func() { session.Close() }) + if err := session.Ping(context.Background(), nil); err != nil { t.Fatal(err) } @@ -122,35 +124,30 @@ func TestServerInterrupt(t *testing.T) { } requireExec(t) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - + t.Log("Starting server command") cmd := createServerCommand(t, "default") client := mcp.NewClient(testImpl, nil) - _, err := client.Connect(ctx, &mcp.CommandTransport{Command: cmd}, nil) + t.Log("Connecting to server") + + ctx := context.Background() + session, err := client.Connect(ctx, &mcp.CommandTransport{Command: cmd}, nil) if err != nil { t.Fatal(err) } - // get a signal when the server process exits - onExit := make(chan struct{}) - go func() { - cmd.Process.Wait() - close(onExit) - }() - - // send a signal to the server process to terminate it + t.Log("Send a signal to the server process to terminate it") if err := cmd.Process.Signal(os.Interrupt); err != nil { t.Fatal(err) } - // wait for the server to exit - // TODO: use synctest when available - select { - case <-time.After(5 * time.Second): - t.Fatal("server did not exit after SIGINT") - case <-onExit: + t.Log("Closing client session so server can exit immediately") + session.Close() + + t.Log("Wait for process 
to terminate after interrupt signal") +	_, err = cmd.Process.Wait() +	if err != nil { +		t.Errorf("unexpected error: %v", err) 	} } diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 0e1ee1f2..a78c1525 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -118,16 +118,6 @@ func TestEndToEnd(t *testing.T) { t.Errorf("after connection, Clients() has length %d, want 1", len(got)) } - // Wait for the server to exit after the client closes its connection. - var clientWG sync.WaitGroup - clientWG.Add(1) - go func() { - if err := ss.Wait(); err != nil { - t.Errorf("server failed: %v", err) - } - clientWG.Done() - }() - loggingMessages := make(chan *LoggingMessageParams, 100) // big enough for all logging opts := &ClientOptions{ CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) { @@ -518,7 +508,9 @@ func TestEndToEnd(t *testing.T) { // Disconnect. cs.Close() - clientWG.Wait() + if err := ss.Wait(); err != nil { + t.Errorf("server failed: %v", err) + } // After disconnecting, neither client nor server should have any // connections. @@ -626,6 +618,7 @@ func basicClientServerConnection(t *testing.T, client *Client, server *Server, c if err != nil { t.Fatal(err) } + t.Cleanup(func() { _ = ss.Close() }) if client == nil { client = NewClient(testImpl, nil) @@ -634,6 +627,8 @@ func basicClientServerConnection(t *testing.T, client *Client, server *Server, c if err != nil { t.Fatal(err) } + t.Cleanup(func() { _ = cs.Close() }) + return cs, ss, func() { cs.Close() ss.Wait() @@ -750,11 +745,7 @@ func TestMiddleware(t *testing.T) { t.Fatal(err) } // Wait for the server to exit after the client closes its connection. 
- defer func() { - if err := ss.Wait(); err != nil { - t.Errorf("server failed: %v", err) - } - }() + t.Cleanup(func() { _ = ss.Close() }) var sbuf, cbuf bytes.Buffer sbuf.WriteByte('\n') @@ -773,7 +764,7 @@ func TestMiddleware(t *testing.T) { if err != nil { t.Fatal(err) } - defer cs.Close() + t.Cleanup(func() { _ = cs.Close() }) if _, err := cs.ListTools(ctx, nil); err != nil { t.Fatal(err) diff --git a/mcp/server.go b/mcp/server.go index f5b99cf7..1ed170b4 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -776,6 +776,7 @@ func (s *Server) Run(ctx context.Context, t Transport) error { select { case <-ctx.Done(): ss.Close() + <-ssClosed // wait until waiting go routine above actually completes s.opts.Logger.Error("server run cancelled", "error", ctx.Err()) return ctx.Err() case err := <-ssClosed: diff --git a/mcp/streamable_example_test.go b/mcp/streamable_example_test.go index 430f2745..f1cdf90a 100644 --- a/mcp/streamable_example_test.go +++ b/mcp/streamable_example_test.go @@ -18,6 +18,7 @@ import ( // !+streamablehandler +// TODO: Until we have a way to clean up abandoned sessions, this test will leak goroutines (see #499) func ExampleStreamableHTTPHandler() { // Create a new streamable handler, using the same MCP server for every request. // @@ -45,7 +46,7 @@ func ExampleStreamableHTTPHandler_middleware() { server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.1.0"}, nil) handler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server { return server - }, nil) + }, &mcp.StreamableHTTPOptions{Stateless: true}) loggingHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { // Example debugging; you could also capture the response. 
body, err := io.ReadAll(req.Body) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 6ccaebf7..7cdee1ed 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -271,6 +271,8 @@ func TestStreamableServerShutdown(t *testing.T) { // network failure and receive replayed messages (if replay is configured). It // uses a proxy that is killed and restarted to simulate a recoverable network // outage. +// +// TODO: Until we have a way to clean up abandoned sessions, this test will leak goroutines (see #499) func TestClientReplay(t *testing.T) { for _, test := range []clientReplayTest{ {"default", 0, true}, @@ -318,7 +320,10 @@ func testClientReplay(t *testing.T, test clientReplayTest) { }) realServer := httptest.NewServer(mustNotPanic(t, NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil))) - defer realServer.Close() + t.Cleanup(func() { + t.Log("Closing real HTTP server") + realServer.Close() + }) realServerURL, err := url.Parse(realServer.URL) if err != nil { t.Fatalf("Failed to parse real server URL: %v", err) @@ -345,21 +350,20 @@ func testClientReplay(t *testing.T, test clientReplayTest) { if err != nil { t.Fatalf("client.Connect() failed: %v", err) } - defer clientSession.Close() + t.Cleanup(func() { + t.Log("Closing clientSession") + clientSession.Close() + }) - var ( - wg sync.WaitGroup - callErr error - ) - wg.Add(1) + toolCallResult := make(chan error, 1) go func() { - defer wg.Done() - _, callErr = clientSession.CallTool(ctx, &CallToolParams{Name: "multiMessageTool"}) + _, callErr := clientSession.CallTool(ctx, &CallToolParams{Name: "multiMessageTool"}) + toolCallResult <- callErr }() select { case <-serverReadyToKillProxy: - // Server has sent the first two messages and is paused. 
+ t.Log("Server has sent the first two messages and is paused.") case <-ctx.Done(): t.Fatalf("Context timed out before server was ready to kill proxy") } @@ -387,9 +391,9 @@ func testClientReplay(t *testing.T, test clientReplayTest) { restartedProxy := &http.Server{Handler: proxyHandler} go restartedProxy.Serve(listener) - defer restartedProxy.Close() + t.Cleanup(func() { restartedProxy.Close() }) - wg.Wait() + callErr := <-toolCallResult if test.wantRecovered { // If we've recovered, we should get all 4 notifications and the tool call @@ -463,14 +467,15 @@ func TestServerTransportCleanup(t *testing.T) { if err != nil { t.Fatalf("client.Connect() failed: %v", err) } - defer clientSession.Close() + t.Cleanup(func() { _ = clientSession.Close() }) } for _, ch := range chans { select { case <-ctx.Done(): t.Errorf("did not capture transport deletion event from all session in 10 seconds") - case <-ch: // Received transport deletion signal of this session + case <-ch: + t.Log("Received session transport deletion signal") } } @@ -1256,6 +1261,7 @@ func TestStreamableStateless(t *testing.T) { if err != nil { t.Fatal(err) } + t.Cleanup(func() { cs.Close() }) res, err := cs.CallTool(ctx, &CallToolParams{Name: "greet", Arguments: hiParams{Name: "bar"}}) if err != nil { t.Fatal(err) @@ -1434,6 +1440,18 @@ func TestStreamableGET(t *testing.T) { if got, want := resp.StatusCode, http.StatusOK; got != want { t.Errorf("GET with session ID: got status %d, want %d", got, want) } + + t.Log("Sending final DELETE request to close session and release resources") + del := newReq("DELETE", nil) + del.Header.Set(sessionIDHeader, sessionID) + resp, err = http.DefaultClient.Do(del) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if got, want := resp.StatusCode, http.StatusNoContent; got != want { + t.Errorf("DELETE with session ID: got status %d, want %d", got, want) + } } func TestStreamableClientContextPropagation(t *testing.T) { diff --git a/mcp/transport.go 
b/mcp/transport.go index f2f93f4b..1beab470 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -310,7 +310,14 @@ func (r rwc) Write(p []byte) (n int, err error) { } func (r rwc) Close() error { - return errors.Join(r.rc.Close(), r.wc.Close()) + rcErr := r.rc.Close() + + var wcErr error + if r.wc != nil { // we only allow a nil writer in unit tests + wcErr = r.wc.Close() + } + + return errors.Join(rcErr, wcErr) } // An ioConn is a transport that delimits messages with newlines across diff --git a/mcp/transport_example_test.go b/mcp/transport_example_test.go index ab54a422..7390ea4e 100644 --- a/mcp/transport_example_test.go +++ b/mcp/transport_example_test.go @@ -28,7 +28,7 @@ func ExampleLoggingTransport() { if err != nil { log.Fatal(err) } - defer serverSession.Wait() + defer serverSession.Close() client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) var b bytes.Buffer diff --git a/mcp/transport_test.go b/mcp/transport_test.go index 10804a87..aeff3663 100644 --- a/mcp/transport_test.go +++ b/mcp/transport_test.go @@ -25,7 +25,7 @@ func TestBatchFraming(t *testing.T) { r, w := io.Pipe() tport := newIOConn(rwc{r, w}) tport.outgoingBatch = make([]jsonrpc.Message, 0, 2) - defer tport.Close() + t.Cleanup(func() { tport.Close() }) // Read the two messages into a channel, for easy testing later. read := make(chan jsonrpc.Message) @@ -101,6 +101,7 @@ func TestIOConnRead(t *testing.T) { tr := newIOConn(rwc{ rc: io.NopCloser(strings.NewReader(tt.input)), }) + t.Cleanup(func() { tr.Close() }) if tt.protocolVersion != "" { tr.sessionUpdated(ServerSessionState{ InitializeParams: &InitializeParams{