Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mcp/client_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@ func Example_roots() {

// Connect the server and client...
t1, t2 := mcp.NewInMemoryTransports()
if _, err := s.Connect(ctx, t1, nil); err != nil {
serverSession, err := s.Connect(ctx, t1, nil)
if err != nil {
log.Fatal(err)
}
defer serverSession.Close()

clientSession, err := c.Connect(ctx, t2, nil)
if err != nil {
Expand Down
33 changes: 15 additions & 18 deletions mcp/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ func TestServerRunContextCancel(t *testing.T) {
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { session.Close() })

if err := session.Ping(context.Background(), nil); err != nil {
t.Fatal(err)
}
Expand All @@ -122,35 +124,30 @@ func TestServerInterrupt(t *testing.T) {
}
requireExec(t)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

t.Log("Starting server command")
cmd := createServerCommand(t, "default")

client := mcp.NewClient(testImpl, nil)
_, err := client.Connect(ctx, &mcp.CommandTransport{Command: cmd}, nil)
t.Log("Connecting to server")

ctx := context.Background()
session, err := client.Connect(ctx, &mcp.CommandTransport{Command: cmd}, nil)
if err != nil {
t.Fatal(err)
}

// get a signal when the server process exits
onExit := make(chan struct{})
go func() {
cmd.Process.Wait()
close(onExit)
}()

// send a signal to the server process to terminate it
t.Log("Send a signal to the server process to terminate it")
if err := cmd.Process.Signal(os.Interrupt); err != nil {
t.Fatal(err)
}

// wait for the server to exit
// TODO: use synctest when available
select {
case <-time.After(5 * time.Second):
t.Fatal("server did not exit after SIGINT")
case <-onExit:
t.Log("Closing client session so server can exit immediately")
session.Close()

t.Log("Wait for process to terminate after interrupt signal")
_, err = cmd.Process.Wait()
if err == nil {
t.Errorf("unexpected error: %v", err)
}
}

Expand Down
25 changes: 8 additions & 17 deletions mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,6 @@ func TestEndToEnd(t *testing.T) {
t.Errorf("after connection, Clients() has length %d, want 1", len(got))
}

// Wait for the server to exit after the client closes its connection.
var clientWG sync.WaitGroup
clientWG.Add(1)
go func() {
if err := ss.Wait(); err != nil {
t.Errorf("server failed: %v", err)
}
clientWG.Done()
}()

loggingMessages := make(chan *LoggingMessageParams, 100) // big enough for all logging
opts := &ClientOptions{
CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) {
Expand Down Expand Up @@ -518,7 +508,9 @@ func TestEndToEnd(t *testing.T) {

// Disconnect.
cs.Close()
clientWG.Wait()
if err := ss.Wait(); err != nil {
t.Errorf("server failed: %v", err)
}

// After disconnecting, neither client nor server should have any
// connections.
Expand Down Expand Up @@ -626,6 +618,7 @@ func basicClientServerConnection(t *testing.T, client *Client, server *Server, c
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = ss.Close() })

if client == nil {
client = NewClient(testImpl, nil)
Expand All @@ -634,6 +627,8 @@ func basicClientServerConnection(t *testing.T, client *Client, server *Server, c
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = cs.Close() })

return cs, ss, func() {
cs.Close()
ss.Wait()
Expand Down Expand Up @@ -750,11 +745,7 @@ func TestMiddleware(t *testing.T) {
t.Fatal(err)
}
// Wait for the server to exit after the client closes its connection.
defer func() {
if err := ss.Wait(); err != nil {
t.Errorf("server failed: %v", err)
}
}()
t.Cleanup(func() { _ = ss.Close() })

var sbuf, cbuf bytes.Buffer
sbuf.WriteByte('\n')
Expand All @@ -773,7 +764,7 @@ func TestMiddleware(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer cs.Close()
t.Cleanup(func() { _ = cs.Close() })

if _, err := cs.ListTools(ctx, nil); err != nil {
t.Fatal(err)
Expand Down
1 change: 1 addition & 0 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,7 @@ func (s *Server) Run(ctx context.Context, t Transport) error {
select {
case <-ctx.Done():
ss.Close()
<-ssClosed // wait until waiting go routine above actually completes
s.opts.Logger.Error("server run cancelled", "error", ctx.Err())
return ctx.Err()
case err := <-ssClosed:
Expand Down
3 changes: 2 additions & 1 deletion mcp/streamable_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (

// !+streamablehandler

// TODO: Until we have a way to clean up abandoned sessions, this test will leak goroutines (see #499)
func ExampleStreamableHTTPHandler() {
// Create a new streamable handler, using the same MCP server for every request.
//
Expand Down Expand Up @@ -45,7 +46,7 @@ func ExampleStreamableHTTPHandler_middleware() {
server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.1.0"}, nil)
handler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server {
return server
}, nil)
}, &mcp.StreamableHTTPOptions{Stateless: true})
loggingHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
// Example debugging; you could also capture the response.
body, err := io.ReadAll(req.Body)
Expand Down
46 changes: 32 additions & 14 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,8 @@ func TestStreamableServerShutdown(t *testing.T) {
// network failure and receive replayed messages (if replay is configured). It
// uses a proxy that is killed and restarted to simulate a recoverable network
// outage.
//
// TODO: Until we have a way to clean up abandoned sessions, this test will leak goroutines (see #499)
func TestClientReplay(t *testing.T) {
for _, test := range []clientReplayTest{
{"default", 0, true},
Expand Down Expand Up @@ -318,7 +320,10 @@ func testClientReplay(t *testing.T, test clientReplayTest) {
})

realServer := httptest.NewServer(mustNotPanic(t, NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)))
defer realServer.Close()
t.Cleanup(func() {
t.Log("Closing real HTTP server")
realServer.Close()
})
realServerURL, err := url.Parse(realServer.URL)
if err != nil {
t.Fatalf("Failed to parse real server URL: %v", err)
Expand All @@ -345,21 +350,20 @@ func testClientReplay(t *testing.T, test clientReplayTest) {
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}
defer clientSession.Close()
t.Cleanup(func() {
t.Log("Closing clientSession")
clientSession.Close()
})

var (
wg sync.WaitGroup
callErr error
)
wg.Add(1)
toolCallResult := make(chan error, 1)
go func() {
defer wg.Done()
_, callErr = clientSession.CallTool(ctx, &CallToolParams{Name: "multiMessageTool"})
_, callErr := clientSession.CallTool(ctx, &CallToolParams{Name: "multiMessageTool"})
toolCallResult <- callErr
}()

select {
case <-serverReadyToKillProxy:
// Server has sent the first two messages and is paused.
t.Log("Server has sent the first two messages and is paused.")
case <-ctx.Done():
t.Fatalf("Context timed out before server was ready to kill proxy")
}
Expand Down Expand Up @@ -387,9 +391,9 @@ func testClientReplay(t *testing.T, test clientReplayTest) {

restartedProxy := &http.Server{Handler: proxyHandler}
go restartedProxy.Serve(listener)
defer restartedProxy.Close()
t.Cleanup(func() { restartedProxy.Close() })

wg.Wait()
callErr := <-toolCallResult

if test.wantRecovered {
// If we've recovered, we should get all 4 notifications and the tool call
Expand Down Expand Up @@ -463,14 +467,15 @@ func TestServerTransportCleanup(t *testing.T) {
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}
defer clientSession.Close()
t.Cleanup(func() { _ = clientSession.Close() })
}

for _, ch := range chans {
select {
case <-ctx.Done():
t.Errorf("did not capture transport deletion event from all session in 10 seconds")
case <-ch: // Received transport deletion signal of this session
case <-ch:
t.Log("Received session transport deletion signal")
}
}

Expand Down Expand Up @@ -1256,6 +1261,7 @@ func TestStreamableStateless(t *testing.T) {
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { cs.Close() })
res, err := cs.CallTool(ctx, &CallToolParams{Name: "greet", Arguments: hiParams{Name: "bar"}})
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -1434,6 +1440,18 @@ func TestStreamableGET(t *testing.T) {
if got, want := resp.StatusCode, http.StatusOK; got != want {
t.Errorf("GET with session ID: got status %d, want %d", got, want)
}

t.Log("Sending final DELETE request to close session and release resources")
del := newReq("DELETE", nil)
del.Header.Set(sessionIDHeader, sessionID)
resp, err = http.DefaultClient.Do(del)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if got, want := resp.StatusCode, http.StatusNoContent; got != want {
t.Errorf("DELETE with session ID: got status %d, want %d", got, want)
}
}

func TestStreamableClientContextPropagation(t *testing.T) {
Expand Down
9 changes: 8 additions & 1 deletion mcp/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,14 @@ func (r rwc) Write(p []byte) (n int, err error) {
}

func (r rwc) Close() error {
return errors.Join(r.rc.Close(), r.wc.Close())
rcErr := r.rc.Close()

var wcErr error
if r.wc != nil { // we only allow a nil writer in unit tests
wcErr = r.wc.Close()
}

return errors.Join(rcErr, wcErr)
}

// An ioConn is a transport that delimits messages with newlines across
Expand Down
2 changes: 1 addition & 1 deletion mcp/transport_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func ExampleLoggingTransport() {
if err != nil {
log.Fatal(err)
}
defer serverSession.Wait()
defer serverSession.Close()

client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil)
var b bytes.Buffer
Expand Down
3 changes: 2 additions & 1 deletion mcp/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestBatchFraming(t *testing.T) {
r, w := io.Pipe()
tport := newIOConn(rwc{r, w})
tport.outgoingBatch = make([]jsonrpc.Message, 0, 2)
defer tport.Close()
t.Cleanup(func() { tport.Close() })

// Read the two messages into a channel, for easy testing later.
read := make(chan jsonrpc.Message)
Expand Down Expand Up @@ -101,6 +101,7 @@ func TestIOConnRead(t *testing.T) {
tr := newIOConn(rwc{
rc: io.NopCloser(strings.NewReader(tt.input)),
})
t.Cleanup(func() { tr.Close() })
if tt.protocolVersion != "" {
tr.sessionUpdated(ServerSessionState{
InitializeParams: &InitializeParams{
Expand Down