Skip to content

Commit fc869cb

Browse files
committed
test(context): test the way context can be requested and added, refactor test so client server reusable.
1 parent 72e7533 commit fc869cb

File tree

7 files changed

+872
-134
lines changed

7 files changed

+872
-134
lines changed

crates/rmcp/Cargo.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ tracing-subscriber = { version = "0.3", features = [
8080
"std",
8181
"fmt",
8282
] }
83+
async-trait = "0.1"
8384
[[test]]
8485
name = "test_tool_macros"
8586
required-features = ["server"]
@@ -105,3 +106,8 @@ name = "test_logging"
105106
required-features = ["server", "client"]
106107
path = "tests/test_logging.rs"
107108

109+
[[test]]
110+
name = "test_message_protocol"
111+
required-features = ["client"]
112+
path = "tests/test_message_protocol.rs"
113+

crates/rmcp/src/handler/client.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ pub trait ClientHandler: Sized + Send + Sync + 'static {
8484
McpError::method_not_found::<CreateMessageRequestMethod>(),
8585
))
8686
}
87+
8788
fn list_roots(
8889
&self,
8990
context: RequestContext<RoleClient>,

crates/rmcp/src/model.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,16 @@ pub struct SamplingMessage {
659659
pub content: Content,
660660
}
661661

662+
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
663+
pub enum ContextInclusion {
664+
#[serde(rename = "allServers")]
665+
AllServers,
666+
#[serde(rename = "none")]
667+
None,
668+
#[serde(rename = "thisServer")]
669+
ThisServer,
670+
}
671+
662672
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
663673
#[serde(rename_all = "camelCase")]
664674
pub struct CreateMessageRequestParam {
@@ -668,7 +678,7 @@ pub struct CreateMessageRequestParam {
668678
#[serde(skip_serializing_if = "Option::is_none")]
669679
pub system_prompt: Option<String>,
670680
#[serde(skip_serializing_if = "Option::is_none")]
671-
pub include_context: Option<String>,
681+
pub include_context: Option<ContextInclusion>,
672682
#[serde(skip_serializing_if = "Option::is_none")]
673683
pub temperature: Option<f32>,
674684
pub max_tokens: u32,
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
use async_trait::async_trait;
2+
use rmcp::{
3+
ClientHandler, Error as McpError, RoleClient, RoleServer, ServerHandler,
4+
model::*,
5+
service::{Peer, RequestContext, Service},
6+
};
7+
use serde_json::json;
8+
use std::{
9+
future::Future,
10+
sync::{Arc, Mutex},
11+
};
12+
use tokio::sync::Notify;
13+
14+
#[derive(Clone)]
15+
pub struct TestClientHandler {
16+
pub peer: Option<Peer<RoleClient>>,
17+
pub honor_this_server: bool,
18+
pub honor_all_servers: bool,
19+
pub receive_signal: Arc<Notify>,
20+
pub received_messages: Arc<Mutex<Vec<LoggingMessageNotificationParam>>>,
21+
}
22+
23+
impl TestClientHandler {
24+
pub fn new(honor_this_server: bool, honor_all_servers: bool) -> Self {
25+
Self {
26+
peer: None,
27+
honor_this_server,
28+
honor_all_servers,
29+
receive_signal: Arc::new(Notify::new()),
30+
received_messages: Arc::new(Mutex::new(Vec::new())),
31+
}
32+
}
33+
34+
pub fn with_notification(
35+
honor_this_server: bool,
36+
honor_all_servers: bool,
37+
receive_signal: Arc<Notify>,
38+
received_messages: Arc<Mutex<Vec<LoggingMessageNotificationParam>>>,
39+
) -> Self {
40+
Self {
41+
peer: None,
42+
honor_this_server,
43+
honor_all_servers,
44+
receive_signal,
45+
received_messages,
46+
}
47+
}
48+
}
49+
50+
impl ClientHandler for TestClientHandler {
51+
fn get_peer(&self) -> Option<Peer<RoleClient>> {
52+
self.peer.clone()
53+
}
54+
55+
fn set_peer(&mut self, peer: Peer<RoleClient>) {
56+
self.peer = Some(peer);
57+
}
58+
59+
fn create_message(
60+
&self,
61+
params: CreateMessageRequestParam,
62+
_context: RequestContext<RoleClient>,
63+
) -> impl Future<Output = Result<CreateMessageResult, McpError>> + Send + '_ {
64+
async move {
65+
// First validate that there's at least one User message
66+
if !params.messages.iter().any(|msg| msg.role == Role::User) {
67+
return Err(McpError::invalid_request(
68+
"Message sequence must contain at least one user message",
69+
Some(json!({"messages": params.messages})),
70+
));
71+
}
72+
73+
// Create response based on context inclusion
74+
let response = match params.include_context {
75+
Some(ContextInclusion::ThisServer) if self.honor_this_server => {
76+
"Test response with context: test context"
77+
}
78+
Some(ContextInclusion::AllServers) if self.honor_all_servers => {
79+
"Test response with context: test context"
80+
}
81+
_ => "Test response without context",
82+
};
83+
84+
Ok(CreateMessageResult {
85+
message: SamplingMessage {
86+
role: Role::Assistant,
87+
content: Content::text(response.to_string()),
88+
},
89+
model: "test-model".to_string(),
90+
stop_reason: Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()),
91+
})
92+
}
93+
}
94+
95+
fn on_logging_message(
96+
&self,
97+
params: LoggingMessageNotificationParam,
98+
) -> impl Future<Output = ()> + Send + '_ {
99+
let receive_signal = self.receive_signal.clone();
100+
let received_messages = self.received_messages.clone();
101+
102+
async move {
103+
println!("Client: Received log message: {:?}", params);
104+
let mut messages = received_messages.lock().unwrap();
105+
messages.push(params);
106+
receive_signal.notify_one();
107+
}
108+
}
109+
}
110+
111+
pub struct TestServer {}
112+
113+
impl TestServer {
114+
pub fn new() -> Self {
115+
Self {}
116+
}
117+
}
118+
119+
impl ServerHandler for TestServer {
120+
fn get_info(&self) -> ServerInfo {
121+
ServerInfo {
122+
capabilities: ServerCapabilities::builder().enable_logging().build(),
123+
..Default::default()
124+
}
125+
}
126+
127+
fn set_level(
128+
&self,
129+
request: SetLevelRequestParam,
130+
context: RequestContext<RoleServer>,
131+
) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
132+
let peer = context.peer;
133+
async move {
134+
let (data, logger) = match request.level {
135+
LoggingLevel::Error => (
136+
serde_json::json!({
137+
"message": "Failed to process request",
138+
"error_code": "E1001",
139+
"error_details": "Connection timeout",
140+
"timestamp": chrono::Utc::now().to_rfc3339(),
141+
}),
142+
Some("error_handler".to_string()),
143+
),
144+
LoggingLevel::Debug => (
145+
serde_json::json!({
146+
"message": "Processing request",
147+
"function": "handle_request",
148+
"line": 42,
149+
"context": {
150+
"request_id": "req-123",
151+
"user_id": "user-456"
152+
},
153+
"timestamp": chrono::Utc::now().to_rfc3339(),
154+
}),
155+
Some("debug_logger".to_string()),
156+
),
157+
LoggingLevel::Info => (
158+
serde_json::json!({
159+
"message": "System status update",
160+
"status": "healthy",
161+
"metrics": {
162+
"requests_per_second": 150,
163+
"average_latency_ms": 45,
164+
"error_rate": 0.01
165+
},
166+
"timestamp": chrono::Utc::now().to_rfc3339(),
167+
}),
168+
Some("monitoring".to_string()),
169+
),
170+
_ => (
171+
serde_json::json!({
172+
"message": format!("Message at level {:?}", request.level),
173+
"timestamp": chrono::Utc::now().to_rfc3339(),
174+
}),
175+
None,
176+
),
177+
};
178+
179+
if let Err(e) = peer
180+
.notify_logging_message(LoggingMessageNotificationParam {
181+
level: request.level,
182+
data,
183+
logger,
184+
})
185+
.await
186+
{
187+
panic!("Failed to send notification: {}", e);
188+
}
189+
Ok(())
190+
}
191+
}
192+
}

crates/rmcp/tests/common/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
pub mod calculator;
2+
pub mod handlers;

0 commit comments

Comments
 (0)