Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions client/transport/stdio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"os"
"os/exec"
"path/filepath"
"runtime"
"sync"
"testing"
"time"
Expand All @@ -31,6 +32,10 @@ func compileTestServer(outputPath string) error {
func TestStdio(t *testing.T) {
// Compile mock server
mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server")
// Add .exe suffix on Windows
if runtime.GOOS == "windows" {
mockServerPath += ".exe"
}
if err := compileTestServer(mockServerPath); err != nil {
t.Fatalf("Failed to compile mock server: %v", err)
}
Expand Down Expand Up @@ -302,16 +307,19 @@ func TestStdioErrors(t *testing.T) {
})

t.Run("RequestBeforeStart", func(t *testing.T) {
// 创建一个新的 Stdio 实例但不调用 Start 方法
mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server")
// Add .exe suffix on Windows
if runtime.GOOS == "windows" {
mockServerPath += ".exe"
}
if err := compileTestServer(mockServerPath); err != nil {
t.Fatalf("Failed to compile mock server: %v", err)
}
defer os.Remove(mockServerPath)

uninitiatedStdio := NewStdio(mockServerPath, nil)

// 准备一个请求
// Prepare a request
request := JSONRPCRequest{
JSONRPC: "2.0",
ID: 99,
Expand All @@ -331,6 +339,10 @@ func TestStdioErrors(t *testing.T) {
t.Run("RequestAfterClose", func(t *testing.T) {
// Compile mock server
mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server")
// Add .exe suffix on Windows
if runtime.GOOS == "windows" {
mockServerPath += ".exe"
}
if err := compileTestServer(mockServerPath); err != nil {
t.Fatalf("Failed to compile mock server: %v", err)
}
Expand Down
13 changes: 10 additions & 3 deletions server/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ type SSEServer struct {

keepAlive bool
keepAliveInterval time.Duration

mu sync.RWMutex
Comment on lines 67 to +69
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Inconsistent protection pattern for the sessions map

While you've added mutex protection for the sessions map in the handleMessage method, the same protection is not applied in other methods that access this map (e.g., handleSSE, SendEventToSession). Either remove the mutex usage for sessions.Load (as sync.Map is already thread-safe) or consistently apply the protection pattern across all methods.

Also applies to: 212-219


🏁 Script executed:

#!/bin/bash
# Find all places where sessions map is accessed without mutex protection

echo "Checking for access to sessions map without mutex protection..."
rg -n "sessions\.(Load|Store|Delete|Range)" --type go server/sse.go | grep -v "mu\.(RLock|Lock)"

Length of output: 502


Ensure consistent synchronization for the sessions map

We’ve found multiple s.sessions accesses that aren’t wrapped by your mu lock, yet you’ve also added locking in only some methods. Since you’re using a sync.Map (which is already safe for concurrent use), you should either:

• Remove the mu entirely and drop the extra locking around handleMessage, relying solely on sync.Map.
• Or keep mu and wrap all sessions.Load/Store/Delete/Range calls under mu.RLock/mu.Lock.

Unprotected accesses in server/sse.go:

  • Line 213: s.sessions.Range(...)
  • Line 217: s.sessions.Delete(key)
  • Line 255: s.sessions.Store(sessionID, session)
  • Line 256: defer s.sessions.Delete(sessionID)
  • Line 349: sessionI, ok := s.sessions.Load(sessionID)
  • Line 416: sessionI, ok := s.sessions.Load(sessionID)

Please pick one synchronization pattern and apply it consistently to avoid races.

}

// SSEOption defines a function type for configuring SSEServer
Expand Down Expand Up @@ -189,18 +191,24 @@ func NewTestServer(server *MCPServer, opts ...SSEOption) *httptest.Server {
// Start begins serving SSE connections on the specified address.
// It sets up HTTP handlers for SSE and message endpoints.
func (s *SSEServer) Start(addr string) error {
s.mu.Lock()
s.srv = &http.Server{
Addr: addr,
Handler: s,
}
s.mu.Unlock()

return s.srv.ListenAndServe()
}

// Shutdown gracefully stops the SSE server, closing all active sessions
// and shutting down the HTTP server.
func (s *SSEServer) Shutdown(ctx context.Context) error {
if s.srv != nil {
s.mu.RLock()
srv := s.srv
s.mu.RUnlock()

if srv != nil {
s.sessions.Range(func(key, value interface{}) bool {
if session, ok := value.(*sseSession); ok {
close(session.done)
Expand All @@ -209,7 +217,7 @@ func (s *SSEServer) Shutdown(ctx context.Context) error {
return true
})

return s.srv.Shutdown(ctx)
return srv.Shutdown(ctx)
}
return nil
}
Expand Down Expand Up @@ -335,7 +343,6 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Missing sessionId")
return
}

sessionI, ok := s.sessions.Load(sessionID)
if !ok {
s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Invalid session ID")
Expand Down