Skip to content

Commit 010bdbc

Browse files
authored
mcp: fix goroutine leaks in unit tests (#496)
This PR fixes goroutine leaks in all unit tests. To find the leaks I integrated https://github.com/uber-go/goleak and I suggest to keep using it to catch any future regressions. I used the folowing script to more easily find individual tests which were leaking goroutines: ``` for t in $(go test -run=Nothing -list=. ./mcp | grep -v ok); do go test ./mcp -run="$t" > /dev/null && echo -n . || echo -e "\nDetected leak: $t"; done ```
1 parent 14e776d commit 010bdbc

File tree

9 files changed

+72
-54
lines changed

9 files changed

+72
-54
lines changed

mcp/client_example_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,11 @@ func Example_roots() {
4242

4343
// Connect the server and client...
4444
t1, t2 := mcp.NewInMemoryTransports()
45-
if _, err := s.Connect(ctx, t1, nil); err != nil {
45+
serverSession, err := s.Connect(ctx, t1, nil)
46+
if err != nil {
4647
log.Fatal(err)
4748
}
49+
defer serverSession.Close()
4850

4951
clientSession, err := c.Connect(ctx, t2, nil)
5052
if err != nil {

mcp/cmd_test.go

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ func TestServerRunContextCancel(t *testing.T) {
9797
if err != nil {
9898
t.Fatal(err)
9999
}
100+
t.Cleanup(func() { session.Close() })
101+
100102
if err := session.Ping(context.Background(), nil); err != nil {
101103
t.Fatal(err)
102104
}
@@ -122,35 +124,30 @@ func TestServerInterrupt(t *testing.T) {
122124
}
123125
requireExec(t)
124126

125-
ctx, cancel := context.WithCancel(context.Background())
126-
defer cancel()
127-
127+
t.Log("Starting server command")
128128
cmd := createServerCommand(t, "default")
129129

130130
client := mcp.NewClient(testImpl, nil)
131-
_, err := client.Connect(ctx, &mcp.CommandTransport{Command: cmd}, nil)
131+
t.Log("Connecting to server")
132+
133+
ctx := context.Background()
134+
session, err := client.Connect(ctx, &mcp.CommandTransport{Command: cmd}, nil)
132135
if err != nil {
133136
t.Fatal(err)
134137
}
135138

136-
// get a signal when the server process exits
137-
onExit := make(chan struct{})
138-
go func() {
139-
cmd.Process.Wait()
140-
close(onExit)
141-
}()
142-
143-
// send a signal to the server process to terminate it
139+
t.Log("Send a signal to the server process to terminate it")
144140
if err := cmd.Process.Signal(os.Interrupt); err != nil {
145141
t.Fatal(err)
146142
}
147143

148-
// wait for the server to exit
149-
// TODO: use synctest when available
150-
select {
151-
case <-time.After(5 * time.Second):
152-
t.Fatal("server did not exit after SIGINT")
153-
case <-onExit:
144+
t.Log("Closing client session so server can exit immediately")
145+
session.Close()
146+
147+
t.Log("Wait for process to terminate after interrupt signal")
148+
_, err = cmd.Process.Wait()
149+
if err == nil {
150+
t.Errorf("unexpected error: %v", err)
154151
}
155152
}
156153

mcp/mcp_test.go

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -118,16 +118,6 @@ func TestEndToEnd(t *testing.T) {
118118
t.Errorf("after connection, Clients() has length %d, want 1", len(got))
119119
}
120120

121-
// Wait for the server to exit after the client closes its connection.
122-
var clientWG sync.WaitGroup
123-
clientWG.Add(1)
124-
go func() {
125-
if err := ss.Wait(); err != nil {
126-
t.Errorf("server failed: %v", err)
127-
}
128-
clientWG.Done()
129-
}()
130-
131121
loggingMessages := make(chan *LoggingMessageParams, 100) // big enough for all logging
132122
opts := &ClientOptions{
133123
CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) {
@@ -518,7 +508,9 @@ func TestEndToEnd(t *testing.T) {
518508

519509
// Disconnect.
520510
cs.Close()
521-
clientWG.Wait()
511+
if err := ss.Wait(); err != nil {
512+
t.Errorf("server failed: %v", err)
513+
}
522514

523515
// After disconnecting, neither client nor server should have any
524516
// connections.
@@ -626,6 +618,7 @@ func basicClientServerConnection(t *testing.T, client *Client, server *Server, c
626618
if err != nil {
627619
t.Fatal(err)
628620
}
621+
t.Cleanup(func() { _ = ss.Close() })
629622

630623
if client == nil {
631624
client = NewClient(testImpl, nil)
@@ -634,6 +627,8 @@ func basicClientServerConnection(t *testing.T, client *Client, server *Server, c
634627
if err != nil {
635628
t.Fatal(err)
636629
}
630+
t.Cleanup(func() { _ = cs.Close() })
631+
637632
return cs, ss, func() {
638633
cs.Close()
639634
ss.Wait()
@@ -750,11 +745,7 @@ func TestMiddleware(t *testing.T) {
750745
t.Fatal(err)
751746
}
752747
// Wait for the server to exit after the client closes its connection.
753-
defer func() {
754-
if err := ss.Wait(); err != nil {
755-
t.Errorf("server failed: %v", err)
756-
}
757-
}()
748+
t.Cleanup(func() { _ = ss.Close() })
758749

759750
var sbuf, cbuf bytes.Buffer
760751
sbuf.WriteByte('\n')
@@ -773,7 +764,7 @@ func TestMiddleware(t *testing.T) {
773764
if err != nil {
774765
t.Fatal(err)
775766
}
776-
defer cs.Close()
767+
t.Cleanup(func() { _ = cs.Close() })
777768

778769
if _, err := cs.ListTools(ctx, nil); err != nil {
779770
t.Fatal(err)

mcp/server.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,7 @@ func (s *Server) Run(ctx context.Context, t Transport) error {
778778
select {
779779
case <-ctx.Done():
780780
ss.Close()
781+
<-ssClosed // wait until the waiting goroutine above actually completes
781782
s.opts.Logger.Error("server run cancelled", "error", ctx.Err())
782783
return ctx.Err()
783784
case err := <-ssClosed:

mcp/streamable_example_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818

1919
// !+streamablehandler
2020

21+
// TODO: Until we have a way to clean up abandoned sessions, this test will leak goroutines (see #499)
2122
func ExampleStreamableHTTPHandler() {
2223
// Create a new streamable handler, using the same MCP server for every request.
2324
//
@@ -45,7 +46,7 @@ func ExampleStreamableHTTPHandler_middleware() {
4546
server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.1.0"}, nil)
4647
handler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server {
4748
return server
48-
}, nil)
49+
}, &mcp.StreamableHTTPOptions{Stateless: true})
4950
loggingHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
5051
// Example debugging; you could also capture the response.
5152
body, err := io.ReadAll(req.Body)

mcp/streamable_test.go

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,8 @@ func TestStreamableServerShutdown(t *testing.T) {
322322
// network failure and receive replayed messages (if replay is configured). It
323323
// uses a proxy that is killed and restarted to simulate a recoverable network
324324
// outage.
325+
//
326+
// TODO: Until we have a way to clean up abandoned sessions, this test will leak goroutines (see #499)
325327
func TestClientReplay(t *testing.T) {
326328
for _, test := range []clientReplayTest{
327329
{"default", 0, true},
@@ -369,7 +371,10 @@ func testClientReplay(t *testing.T, test clientReplayTest) {
369371
})
370372

371373
realServer := httptest.NewServer(mustNotPanic(t, NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)))
372-
defer realServer.Close()
374+
t.Cleanup(func() {
375+
t.Log("Closing real HTTP server")
376+
realServer.Close()
377+
})
373378
realServerURL, err := url.Parse(realServer.URL)
374379
if err != nil {
375380
t.Fatalf("Failed to parse real server URL: %v", err)
@@ -396,21 +401,20 @@ func testClientReplay(t *testing.T, test clientReplayTest) {
396401
if err != nil {
397402
t.Fatalf("client.Connect() failed: %v", err)
398403
}
399-
defer clientSession.Close()
404+
t.Cleanup(func() {
405+
t.Log("Closing clientSession")
406+
clientSession.Close()
407+
})
400408

401-
var (
402-
wg sync.WaitGroup
403-
callErr error
404-
)
405-
wg.Add(1)
409+
toolCallResult := make(chan error, 1)
406410
go func() {
407-
defer wg.Done()
408-
_, callErr = clientSession.CallTool(ctx, &CallToolParams{Name: "multiMessageTool"})
411+
_, callErr := clientSession.CallTool(ctx, &CallToolParams{Name: "multiMessageTool"})
412+
toolCallResult <- callErr
409413
}()
410414

411415
select {
412416
case <-serverReadyToKillProxy:
413-
// Server has sent the first two messages and is paused.
417+
t.Log("Server has sent the first two messages and is paused.")
414418
case <-ctx.Done():
415419
t.Fatalf("Context timed out before server was ready to kill proxy")
416420
}
@@ -438,9 +442,9 @@ func testClientReplay(t *testing.T, test clientReplayTest) {
438442

439443
restartedProxy := &http.Server{Handler: proxyHandler}
440444
go restartedProxy.Serve(listener)
441-
defer restartedProxy.Close()
445+
t.Cleanup(func() { restartedProxy.Close() })
442446

443-
wg.Wait()
447+
callErr := <-toolCallResult
444448

445449
if test.wantRecovered {
446450
// If we've recovered, we should get all 4 notifications and the tool call
@@ -514,14 +518,15 @@ func TestServerTransportCleanup(t *testing.T) {
514518
if err != nil {
515519
t.Fatalf("client.Connect() failed: %v", err)
516520
}
517-
defer clientSession.Close()
521+
t.Cleanup(func() { _ = clientSession.Close() })
518522
}
519523

520524
for _, ch := range chans {
521525
select {
522526
case <-ctx.Done():
523527
t.Errorf("did not capture transport deletion event from all session in 10 seconds")
524-
case <-ch: // Received transport deletion signal of this session
528+
case <-ch:
529+
t.Log("Received session transport deletion signal")
525530
}
526531
}
527532

@@ -1307,6 +1312,7 @@ func TestStreamableStateless(t *testing.T) {
13071312
if err != nil {
13081313
t.Fatal(err)
13091314
}
1315+
t.Cleanup(func() { cs.Close() })
13101316
res, err := cs.CallTool(ctx, &CallToolParams{Name: "greet", Arguments: hiParams{Name: "bar"}})
13111317
if err != nil {
13121318
t.Fatal(err)
@@ -1485,6 +1491,18 @@ func TestStreamableGET(t *testing.T) {
14851491
if got, want := resp.StatusCode, http.StatusOK; got != want {
14861492
t.Errorf("GET with session ID: got status %d, want %d", got, want)
14871493
}
1494+
1495+
t.Log("Sending final DELETE request to close session and release resources")
1496+
del := newReq("DELETE", nil)
1497+
del.Header.Set(sessionIDHeader, sessionID)
1498+
resp, err = http.DefaultClient.Do(del)
1499+
if err != nil {
1500+
t.Fatal(err)
1501+
}
1502+
defer resp.Body.Close()
1503+
if got, want := resp.StatusCode, http.StatusNoContent; got != want {
1504+
t.Errorf("DELETE with session ID: got status %d, want %d", got, want)
1505+
}
14881506
}
14891507

14901508
func TestStreamableClientContextPropagation(t *testing.T) {

mcp/transport.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,14 @@ func (r rwc) Write(p []byte) (n int, err error) {
310310
}
311311

312312
func (r rwc) Close() error {
313-
return errors.Join(r.rc.Close(), r.wc.Close())
313+
rcErr := r.rc.Close()
314+
315+
var wcErr error
316+
if r.wc != nil { // we only allow a nil writer in unit tests
317+
wcErr = r.wc.Close()
318+
}
319+
320+
return errors.Join(rcErr, wcErr)
314321
}
315322

316323
// An ioConn is a transport that delimits messages with newlines across

mcp/transport_example_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func ExampleLoggingTransport() {
2828
if err != nil {
2929
log.Fatal(err)
3030
}
31-
defer serverSession.Wait()
31+
defer serverSession.Close()
3232

3333
client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil)
3434
var b bytes.Buffer

mcp/transport_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func TestBatchFraming(t *testing.T) {
2525
r, w := io.Pipe()
2626
tport := newIOConn(rwc{r, w})
2727
tport.outgoingBatch = make([]jsonrpc.Message, 0, 2)
28-
defer tport.Close()
28+
t.Cleanup(func() { tport.Close() })
2929

3030
// Read the two messages into a channel, for easy testing later.
3131
read := make(chan jsonrpc.Message)
@@ -101,6 +101,7 @@ func TestIOConnRead(t *testing.T) {
101101
tr := newIOConn(rwc{
102102
rc: io.NopCloser(strings.NewReader(tt.input)),
103103
})
104+
t.Cleanup(func() { tr.Close() })
104105
if tt.protocolVersion != "" {
105106
tr.sessionUpdated(ServerSessionState{
106107
InitializeParams: &InitializeParams{

0 commit comments

Comments
 (0)