Skip to content

Commit f009e8e

Browse files
committed
Address review comments
1 parent 72448c8 commit f009e8e

File tree

1 file changed

+66
-21
lines changed

1 file changed

+66
-21
lines changed

crates/rmcp/src/transport/sse_client.rs

Lines changed: 66 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,10 @@ impl<C: SseClient> SseClientTransport<C> {
121121

122122
let mut sse_stream = client.get_stream(sse_endpoint.clone(), None, None).await?;
123123
let message_endpoint = if let Some(endpoint) = config.use_message_endpoint.clone() {
124-
endpoint.parse::<http::Uri>()?
124+
let ep = endpoint.parse::<http::Uri>()?;
125+
let mut sse_endpoint_parts = sse_endpoint.clone().into_parts();
126+
sse_endpoint_parts.path_and_query = ep.into_parts().path_and_query;
127+
Uri::from_parts(sse_endpoint_parts)?
125128
} else {
126129
// wait the endpoint event
127130
loop {
@@ -133,29 +136,11 @@ impl<C: SseClient> SseClientTransport<C> {
133136
continue;
134137
};
135138
let ep = sse.data.unwrap_or_default();
136-
// Join the result and
137-
let sse_endpoint = if ep.starts_with("/") {
138-
// Absolute path, take as-is
139-
ep
140-
} else {
141-
// Relative path, merge with base
142-
sse_endpoint
143-
.path_and_query()
144-
.map(|p| p.path())
145-
.unwrap_or_default()
146-
.to_string()
147-
+ ep.as_str()
148-
};
149-
break sse_endpoint.parse::<http::Uri>()?;
139+
140+
break message_endpoint(sse_endpoint.clone(), ep)?;
150141
}
151142
};
152143

153-
// sse: <authority><sse_pq> -> <authority><message_pq>
154-
let message_endpoint = {
155-
let mut sse_endpoint_parts = sse_endpoint.clone().into_parts();
156-
sse_endpoint_parts.path_and_query = message_endpoint.into_parts().path_and_query;
157-
Uri::from_parts(sse_endpoint_parts)?
158-
};
159144
let stream = Box::pin(SseAutoReconnectStream::new(
160145
sse_stream,
161146
SseClientReconnect {
@@ -173,6 +158,36 @@ impl<C: SseClient> SseClientTransport<C> {
173158
}
174159
}
175160

161+
fn message_endpoint(base: http::Uri, endpoint: String) -> Result<http::Uri, http::uri::InvalidUri> {
162+
// If endpoint is a full URL, parse and return it directly
163+
if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
164+
return endpoint.parse::<http::Uri>();
165+
}
166+
167+
let mut base_parts = base.into_parts();
168+
let endpoint_clone = endpoint.clone();
169+
170+
if endpoint.starts_with("?") {
171+
// Query only - keep base path and append query
172+
if let Some(base_path_and_query) = &base_parts.path_and_query {
173+
let base_path = base_path_and_query.path();
174+
base_parts.path_and_query = Some(format!("{}{}", base_path, endpoint).parse()?);
175+
} else {
176+
base_parts.path_and_query = Some(format!("/{}", endpoint).parse()?);
177+
}
178+
} else {
179+
// Path (with optional query) - replace entire path_and_query
180+
let path_to_use = if endpoint.starts_with("/") {
181+
endpoint // Use absolute path as-is
182+
} else {
183+
format!("/{}", endpoint) // Make relative path absolute
184+
};
185+
base_parts.path_and_query = Some(path_to_use.parse()?);
186+
}
187+
188+
http::Uri::from_parts(base_parts).map_err(|_| endpoint_clone.parse::<http::Uri>().unwrap_err())
189+
}
190+
176191
#[derive(Debug, Clone)]
177192
pub struct SseClientConfig {
178193
/// client sse endpoint
@@ -201,3 +216,33 @@ impl Default for SseClientConfig {
201216
}
202217
}
203218
}
219+
220+
#[cfg(test)]
221+
mod tests {
222+
use super::*;
223+
224+
#[test]
225+
fn test_message_endpoint() {
226+
let base_url = "https://localhost/sse".parse::<http::Uri>().unwrap();
227+
228+
// Query only
229+
let result = message_endpoint(base_url.clone(), "?sessionId=x".to_string()).unwrap();
230+
assert_eq!(result.to_string(), "https://localhost/sse?sessionId=x");
231+
232+
// Relative path with query
233+
let result = message_endpoint(base_url.clone(), "mypath?sessionId=x".to_string()).unwrap();
234+
assert_eq!(result.to_string(), "https://localhost/mypath?sessionId=x");
235+
236+
// Absolute path with query
237+
let result = message_endpoint(base_url.clone(), "/xxx?sessionId=x".to_string()).unwrap();
238+
assert_eq!(result.to_string(), "https://localhost/xxx?sessionId=x");
239+
240+
// Full URL
241+
let result = message_endpoint(
242+
base_url.clone(),
243+
"http://example.com/xxx?sessionId=x".to_string(),
244+
)
245+
.unwrap();
246+
assert_eq!(result.to_string(), "http://example.com/xxx?sessionId=x");
247+
}
248+
}

0 commit comments

Comments
 (0)