Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions crates/rmcp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ tracing-subscriber = { version = "0.3", features = [
"std",
"fmt",
] }
async-trait = "0.1"
[[test]]
name = "test_tool_macros"
required-features = ["server"]
Expand All @@ -105,3 +106,8 @@ name = "test_logging"
required-features = ["server", "client"]
path = "tests/test_logging.rs"

[[test]]
name = "test_message_protocol"
required-features = ["client"]
path = "tests/test_message_protocol.rs"

1 change: 1 addition & 0 deletions crates/rmcp/src/handler/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ pub trait ClientHandler: Sized + Send + Sync + 'static {
McpError::method_not_found::<CreateMessageRequestMethod>(),
))
}

fn list_roots(
&self,
context: RequestContext<RoleClient>,
Expand Down
12 changes: 11 additions & 1 deletion crates/rmcp/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,16 @@ pub struct SamplingMessage {
pub content: Content,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub enum ContextInclusion {
#[serde(rename = "allServers")]
AllServers,
#[serde(rename = "none")]
None,
#[serde(rename = "thisServer")]
ThisServer,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct CreateMessageRequestParam {
Expand All @@ -668,7 +678,7 @@ pub struct CreateMessageRequestParam {
#[serde(skip_serializing_if = "Option::is_none")]
pub system_prompt: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub include_context: Option<String>,
pub include_context: Option<ContextInclusion>,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this could be breaking change users are not following specification

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is OK, user who really use this feature should found we didn't implement this correctlly

#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
pub max_tokens: u32,
Expand Down
193 changes: 193 additions & 0 deletions crates/rmcp/tests/common/handlers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
use std::{
future::Future,
sync::{Arc, Mutex},
};

use rmcp::{
ClientHandler, Error as McpError, RoleClient, RoleServer, ServerHandler,
model::*,
service::{Peer, RequestContext},
};
use serde_json::json;
use tokio::sync::Notify;

#[derive(Clone)]
pub struct TestClientHandler {
pub peer: Option<Peer<RoleClient>>,
pub honor_this_server: bool,
pub honor_all_servers: bool,
pub receive_signal: Arc<Notify>,
pub received_messages: Arc<Mutex<Vec<LoggingMessageNotificationParam>>>,
}

impl TestClientHandler {
#[allow(dead_code)]
pub fn new(honor_this_server: bool, honor_all_servers: bool) -> Self {
Self {
peer: None,
honor_this_server,
honor_all_servers,
receive_signal: Arc::new(Notify::new()),
received_messages: Arc::new(Mutex::new(Vec::new())),
}
}

#[allow(dead_code)]
pub fn with_notification(
honor_this_server: bool,
honor_all_servers: bool,
receive_signal: Arc<Notify>,
received_messages: Arc<Mutex<Vec<LoggingMessageNotificationParam>>>,
) -> Self {
Self {
peer: None,
honor_this_server,
honor_all_servers,
receive_signal,
received_messages,
}
}
}

impl ClientHandler for TestClientHandler {
fn get_peer(&self) -> Option<Peer<RoleClient>> {
self.peer.clone()
}

fn set_peer(&mut self, peer: Peer<RoleClient>) {
self.peer = Some(peer);
}

async fn create_message(
&self,
params: CreateMessageRequestParam,
_context: RequestContext<RoleClient>,
) -> Result<CreateMessageResult, McpError> {
// First validate that there's at least one User message
if !params.messages.iter().any(|msg| msg.role == Role::User) {
return Err(McpError::invalid_request(
"Message sequence must contain at least one user message",
Some(json!({"messages": params.messages})),
));
}

// Create response based on context inclusion
let response = match params.include_context {
Some(ContextInclusion::ThisServer) if self.honor_this_server => {
"Test response with context: test context"
}
Some(ContextInclusion::AllServers) if self.honor_all_servers => {
"Test response with context: test context"
}
_ => "Test response without context",
};

Ok(CreateMessageResult {
message: SamplingMessage {
role: Role::Assistant,
content: Content::text(response.to_string()),
},
model: "test-model".to_string(),
stop_reason: Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()),
})
}

fn on_logging_message(
&self,
params: LoggingMessageNotificationParam,
) -> impl Future<Output = ()> + Send + '_ {
let receive_signal = self.receive_signal.clone();
let received_messages = self.received_messages.clone();

async move {
println!("Client: Received log message: {:?}", params);
let mut messages = received_messages.lock().unwrap();
messages.push(params);
receive_signal.notify_one();
}
}
}

pub struct TestServer {}

impl TestServer {
#[allow(dead_code)]
pub fn new() -> Self {
Self {}
}
}

impl ServerHandler for TestServer {
fn get_info(&self) -> ServerInfo {
ServerInfo {
capabilities: ServerCapabilities::builder().enable_logging().build(),
..Default::default()
}
}

fn set_level(
&self,
request: SetLevelRequestParam,
context: RequestContext<RoleServer>,
) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
let peer = context.peer;
async move {
let (data, logger) = match request.level {
LoggingLevel::Error => (
serde_json::json!({
"message": "Failed to process request",
"error_code": "E1001",
"error_details": "Connection timeout",
"timestamp": chrono::Utc::now().to_rfc3339(),
}),
Some("error_handler".to_string()),
),
LoggingLevel::Debug => (
serde_json::json!({
"message": "Processing request",
"function": "handle_request",
"line": 42,
"context": {
"request_id": "req-123",
"user_id": "user-456"
},
"timestamp": chrono::Utc::now().to_rfc3339(),
}),
Some("debug_logger".to_string()),
),
LoggingLevel::Info => (
serde_json::json!({
"message": "System status update",
"status": "healthy",
"metrics": {
"requests_per_second": 150,
"average_latency_ms": 45,
"error_rate": 0.01
},
"timestamp": chrono::Utc::now().to_rfc3339(),
}),
Some("monitoring".to_string()),
),
_ => (
serde_json::json!({
"message": format!("Message at level {:?}", request.level),
"timestamp": chrono::Utc::now().to_rfc3339(),
}),
None,
),
};

if let Err(e) = peer
.notify_logging_message(LoggingMessageNotificationParam {
level: request.level,
data,
logger,
})
.await
{
panic!("Failed to send notification: {}", e);
}
Ok(())
}
}
}
1 change: 1 addition & 0 deletions crates/rmcp/tests/common/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod calculator;
pub mod handlers;
Loading
Loading