Skip to content

Commit 40c8a68

Browse files
committed
refactor(server): standardize URL path handling with normalizeURLPath
Replace manual path manipulation with a dedicated normalizeURLPath function that properly handles path joining while ensuring consistent formatting. The function: - Always starts paths with a leading slash - Never ends paths with a trailing slash (except for root path "/") - Uses path.Join internally for proper path normalization - Handles edge cases like empty segments, double slashes, and parent references This eliminates duplicated code and creates a more consistent approach to URL path handling throughout the SSE server implementation. Comprehensive tests were added to validate the function's behavior.
1 parent 1999773 commit 40c8a68

File tree

2 files changed

+129
-16
lines changed

2 files changed

+129
-16
lines changed

server/sse.go

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"net/http"
99
"net/http/httptest"
1010
"net/url"
11+
"path"
1112
"strings"
1213
"sync"
1314
"sync/atomic"
@@ -109,11 +110,7 @@ func WithBaseURL(baseURL string) SSEOption {
109110
// WithBasePath adds a new option for setting a static base path
110111
func WithBasePath(basePath string) SSEOption {
111112
return func(s *SSEServer) {
112-
// Ensure the path starts with / and doesn't end with /
113-
if !strings.HasPrefix(basePath, "/") {
114-
basePath = "/" + basePath
115-
}
116-
s.basePath = strings.TrimSuffix(basePath, "/")
113+
s.basePath = normalizeURLPath(basePath)
117114
}
118115
}
119116

@@ -126,10 +123,7 @@ func WithDynamicBasePath(fn DynamicBasePathFunc) SSEOption {
126123
if fn != nil {
127124
s.dynamicBasePathFunc = func(r *http.Request, sid string) string {
128125
bp := fn(r, sid)
129-
if !strings.HasPrefix(bp, "/") {
130-
bp = "/" + bp
131-
}
132-
return strings.TrimSuffix(bp, "/")
126+
return normalizeURLPath(bp)
133127
}
134128
}
135129
}
@@ -388,7 +382,7 @@ func (s *SSEServer) GetMessageEndpointForClient(r *http.Request, sessionID strin
388382
basePath = s.dynamicBasePathFunc(r, sessionID)
389383
}
390384

391-
endpointPath := basePath + s.messageEndpoint
385+
endpointPath := normalizeURLPath(basePath, s.messageEndpoint)
392386
if s.useFullURLForMessageEndpoint && s.baseURL != "" {
393387
endpointPath = s.baseURL + endpointPath
394388
}
@@ -515,17 +509,19 @@ func (s *SSEServer) CompleteSseEndpoint() (string, error) {
515509
if s.dynamicBasePathFunc != nil {
516510
return "", &ErrDynamicPathConfig{Method: "CompleteSseEndpoint"}
517511
}
518-
return s.baseURL + s.basePath + s.sseEndpoint, nil
512+
513+
path := normalizeURLPath(s.basePath, s.sseEndpoint)
514+
return s.baseURL + path, nil
519515
}
520516

521517
func (s *SSEServer) CompleteSsePath() string {
522518
path, err := s.CompleteSseEndpoint()
523519
if err != nil {
524-
return s.basePath + s.sseEndpoint
520+
return normalizeURLPath(s.basePath, s.sseEndpoint)
525521
}
526522
urlPath, err := s.GetUrlPath(path)
527523
if err != nil {
528-
return s.basePath + s.sseEndpoint
524+
return normalizeURLPath(s.basePath, s.sseEndpoint)
529525
}
530526
return urlPath
531527
}
@@ -534,17 +530,18 @@ func (s *SSEServer) CompleteMessageEndpoint() (string, error) {
534530
if s.dynamicBasePathFunc != nil {
535531
return "", &ErrDynamicPathConfig{Method: "CompleteMessageEndpoint"}
536532
}
537-
return s.baseURL + s.basePath + s.messageEndpoint, nil
533+
path := normalizeURLPath(s.basePath, s.messageEndpoint)
534+
return s.baseURL + path, nil
538535
}
539536

540537
func (s *SSEServer) CompleteMessagePath() string {
541538
path, err := s.CompleteMessageEndpoint()
542539
if err != nil {
543-
return s.basePath + s.messageEndpoint
540+
return normalizeURLPath(s.basePath, s.messageEndpoint)
544541
}
545542
urlPath, err := s.GetUrlPath(path)
546543
if err != nil {
547-
return s.basePath + s.messageEndpoint
544+
return normalizeURLPath(s.basePath, s.messageEndpoint)
548545
}
549546
return urlPath
550547
}
@@ -628,3 +625,21 @@ func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
628625

629626
http.NotFound(w, r)
630627
}
628+
629+
// normalizeURLPath joins path elements like path.Join but ensures the
630+
// result always starts with a leading slash and never ends with a slash
631+
func normalizeURLPath(elem ...string) string {
632+
joined := path.Join(elem...)
633+
634+
// Ensure leading slash
635+
if !strings.HasPrefix(joined, "/") {
636+
joined = "/" + joined
637+
}
638+
639+
// Remove trailing slash if not just "/"
640+
if len(joined) > 1 && strings.HasSuffix(joined, "/") {
641+
joined = joined[:len(joined)-1]
642+
}
643+
644+
return joined
645+
}

server/sse_test.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,6 +1031,104 @@ func TestSSEServer(t *testing.T) {
10311031
messagePath := sseServer.CompleteMessagePath()
10321032
require.Equal(t, sseServer.basePath+sseServer.messageEndpoint, messagePath)
10331033
})
1034+
1035+
t.Run("TestNormalizeURLPath", func(t *testing.T) {
1036+
tests := []struct {
1037+
name string
1038+
inputs []string
1039+
expected string
1040+
}{
1041+
// Basic path joining
1042+
{
1043+
name: "empty inputs",
1044+
inputs: []string{"", ""},
1045+
expected: "/",
1046+
},
1047+
{
1048+
name: "single path segment",
1049+
inputs: []string{"mcp"},
1050+
expected: "/mcp",
1051+
},
1052+
{
1053+
name: "multiple path segments",
1054+
inputs: []string{"mcp", "api", "message"},
1055+
expected: "/mcp/api/message",
1056+
},
1057+
1058+
// Leading slash handling
1059+
{
1060+
name: "already has leading slash",
1061+
inputs: []string{"/mcp", "message"},
1062+
expected: "/mcp/message",
1063+
},
1064+
{
1065+
name: "mixed leading slashes",
1066+
inputs: []string{"/mcp", "/message"},
1067+
expected: "/mcp/message",
1068+
},
1069+
1070+
// Trailing slash handling
1071+
{
1072+
name: "with trailing slashes",
1073+
inputs: []string{"mcp/", "message/"},
1074+
expected: "/mcp/message",
1075+
},
1076+
{
1077+
name: "mixed trailing slashes",
1078+
inputs: []string{"mcp", "message/"},
1079+
expected: "/mcp/message",
1080+
},
1081+
{
1082+
name: "root path",
1083+
inputs: []string{"/"},
1084+
expected: "/",
1085+
},
1086+
1087+
// Path normalization
1088+
{
1089+
name: "normalize double slashes",
1090+
inputs: []string{"mcp//api", "//message"},
1091+
expected: "/mcp/api/message",
1092+
},
1093+
{
1094+
name: "normalize parent directory",
1095+
inputs: []string{"mcp/parent/../child", "message"},
1096+
expected: "/mcp/child/message",
1097+
},
1098+
{
1099+
name: "normalize current directory",
1100+
inputs: []string{"mcp/./api", "./message"},
1101+
expected: "/mcp/api/message",
1102+
},
1103+
1104+
// Complex cases
1105+
{
1106+
name: "complex mixed case",
1107+
inputs: []string{"/mcp/", "/api//", "message/"},
1108+
expected: "/mcp/api/message",
1109+
},
1110+
{
1111+
name: "absolute path in second segment",
1112+
inputs: []string{"tenant", "/message"},
1113+
expected: "/tenant/message",
1114+
},
1115+
{
1116+
name: "URL pattern with parameters",
1117+
inputs: []string{"/mcp/{tenant}", "message"},
1118+
expected: "/mcp/{tenant}/message",
1119+
},
1120+
}
1121+
1122+
for _, tt := range tests {
1123+
t.Run(tt.name, func(t *testing.T) {
1124+
result := normalizeURLPath(tt.inputs...)
1125+
if result != tt.expected {
1126+
t.Errorf("normalizeURLPath(%q) = %q, want %q",
1127+
tt.inputs, result, tt.expected)
1128+
}
1129+
})
1130+
}
1131+
})
10341132
}
10351133

10361134
func readSeeEvent(sseResp *http.Response) (string, error) {

0 commit comments

Comments
 (0)