diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index be401a93..738f2446 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -80,6 +80,7 @@ tracing-subscriber = { version = "0.3", features = [ "std", "fmt", ] } +async-trait = "0.1" [[test]] name = "test_tool_macros" required-features = ["server"] @@ -105,3 +106,8 @@ name = "test_logging" required-features = ["server", "client"] path = "tests/test_logging.rs" +[[test]] +name = "test_message_protocol" +required-features = ["client"] +path = "tests/test_message_protocol.rs" + diff --git a/crates/rmcp/src/handler/client.rs b/crates/rmcp/src/handler/client.rs index 3443e635..13bd9677 100644 --- a/crates/rmcp/src/handler/client.rs +++ b/crates/rmcp/src/handler/client.rs @@ -84,6 +84,7 @@ pub trait ClientHandler: Sized + Send + Sync + 'static { McpError::method_not_found::(), )) } + fn list_roots( &self, context: RequestContext, diff --git a/crates/rmcp/src/model.rs b/crates/rmcp/src/model.rs index eb44a420..33980464 100644 --- a/crates/rmcp/src/model.rs +++ b/crates/rmcp/src/model.rs @@ -659,6 +659,16 @@ pub struct SamplingMessage { pub content: Content, } +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub enum ContextInclusion { + #[serde(rename = "allServers")] + AllServers, + #[serde(rename = "none")] + None, + #[serde(rename = "thisServer")] + ThisServer, +} + #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(rename_all = "camelCase")] pub struct CreateMessageRequestParam { @@ -668,7 +678,7 @@ pub struct CreateMessageRequestParam { #[serde(skip_serializing_if = "Option::is_none")] pub system_prompt: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub include_context: Option, + pub include_context: Option, #[serde(skip_serializing_if = "Option::is_none")] pub temperature: Option, pub max_tokens: u32, diff --git a/crates/rmcp/tests/common/handlers.rs b/crates/rmcp/tests/common/handlers.rs new file mode 100644 index 00000000..d2212b63 --- /dev/null +++ b/crates/rmcp/tests/common/handlers.rs @@ -0,0 +1,193 @@ +use std::{ + future::Future, + sync::{Arc, Mutex}, +}; + +use rmcp::{ + ClientHandler, Error as McpError, RoleClient, RoleServer, ServerHandler, + model::*, + service::{Peer, RequestContext}, +}; +use serde_json::json; +use tokio::sync::Notify; + +#[derive(Clone)] +pub struct TestClientHandler { + pub peer: Option>, + pub honor_this_server: bool, + pub honor_all_servers: bool, + pub receive_signal: Arc, + pub received_messages: Arc>>, +} + +impl TestClientHandler { + #[allow(dead_code)] + pub fn new(honor_this_server: bool, honor_all_servers: bool) -> Self { + Self { + peer: None, + honor_this_server, + honor_all_servers, + receive_signal: Arc::new(Notify::new()), + received_messages: Arc::new(Mutex::new(Vec::new())), + } + } + + #[allow(dead_code)] + pub fn with_notification( + honor_this_server: bool, + honor_all_servers: bool, + receive_signal: Arc, + received_messages: Arc>>, + ) -> Self { + Self { + peer: None, + honor_this_server, + honor_all_servers, + receive_signal, + received_messages, + } + } +} + +impl ClientHandler for TestClientHandler { + fn get_peer(&self) -> Option> { + self.peer.clone() + } + + fn set_peer(&mut self, peer: Peer) { + self.peer = Some(peer); + } + + async fn create_message( + &self, + params: CreateMessageRequestParam, + _context: RequestContext, + ) -> Result { + // First validate that there's at least one User message + if !params.messages.iter().any(|msg| msg.role == Role::User) { + return Err(McpError::invalid_request( + "Message sequence must contain at least one user message", + Some(json!({"messages": params.messages})), + )); + } + + // Create response based on context inclusion + let response = match params.include_context { + Some(ContextInclusion::ThisServer) if self.honor_this_server => { + "Test response with context: test context" + } + Some(ContextInclusion::AllServers) if self.honor_all_servers => { + "Test response with context: test context" + } + _ => "Test response without context", + }; + + Ok(CreateMessageResult { + message: SamplingMessage { + role: Role::Assistant, + content: Content::text(response.to_string()), + }, + model: "test-model".to_string(), + stop_reason: Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()), + }) + } + + fn on_logging_message( + &self, + params: LoggingMessageNotificationParam, + ) -> impl Future + Send + '_ { + let receive_signal = self.receive_signal.clone(); + let received_messages = self.received_messages.clone(); + + async move { + println!("Client: Received log message: {:?}", params); + let mut messages = received_messages.lock().unwrap(); + messages.push(params); + receive_signal.notify_one(); + } + } +} + +pub struct TestServer {} + +impl TestServer { + #[allow(dead_code)] + pub fn new() -> Self { + Self {} + } +} + +impl ServerHandler for TestServer { + fn get_info(&self) -> ServerInfo { + ServerInfo { + capabilities: ServerCapabilities::builder().enable_logging().build(), + ..Default::default() + } + } + + fn set_level( + &self, + request: SetLevelRequestParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + let peer = context.peer; + async move { + let (data, logger) = match request.level { + LoggingLevel::Error => ( + serde_json::json!({ + "message": "Failed to process request", + "error_code": "E1001", + "error_details": "Connection timeout", + "timestamp": chrono::Utc::now().to_rfc3339(), + }), + Some("error_handler".to_string()), + ), + LoggingLevel::Debug => ( + serde_json::json!({ + "message": "Processing request", + "function": "handle_request", + "line": 42, + "context": { + "request_id": "req-123", + "user_id": "user-456" + }, + "timestamp": chrono::Utc::now().to_rfc3339(), + }), + Some("debug_logger".to_string()), + ), + LoggingLevel::Info => ( + serde_json::json!({ + "message": "System status update", + "status": "healthy", + "metrics": { + "requests_per_second": 150, + "average_latency_ms": 45, + "error_rate": 0.01 + }, + "timestamp": chrono::Utc::now().to_rfc3339(), + }), + Some("monitoring".to_string()), + ), + _ => ( + serde_json::json!({ + "message": format!("Message at level {:?}", request.level), + "timestamp": chrono::Utc::now().to_rfc3339(), + }), + None, + ), + }; + + if let Err(e) = peer + .notify_logging_message(LoggingMessageNotificationParam { + level: request.level, + data, + logger, + }) + .await + { + panic!("Failed to send notification: {}", e); + } + Ok(()) + } + } +} diff --git a/crates/rmcp/tests/common/mod.rs b/crates/rmcp/tests/common/mod.rs index 09bb58d2..49196065 100644 --- a/crates/rmcp/tests/common/mod.rs +++ b/crates/rmcp/tests/common/mod.rs @@ -1 +1,2 @@ pub mod calculator; +pub mod handlers; diff --git a/crates/rmcp/tests/test_logging.rs b/crates/rmcp/tests/test_logging.rs index b3d95677..eb63773f 100644 --- a/crates/rmcp/tests/test_logging.rs +++ b/crates/rmcp/tests/test_logging.rs @@ -1,133 +1,24 @@ // cargo test --features "server client" --package rmcp test_logging -use std::{ - future::Future, - sync::{Arc, Mutex}, -}; +mod common; + +use std::sync::{Arc, Mutex}; +use common::handlers::{TestClientHandler, TestServer}; use rmcp::{ - ClientHandler, Error as McpError, Peer, RoleClient, RoleServer, ServerHandler, ServiceExt, - model::{ - LoggingLevel, LoggingMessageNotificationParam, ServerCapabilities, ServerInfo, - SetLevelRequestParam, - }, - service::RequestContext, + ServiceExt, + model::{LoggingLevel, LoggingMessageNotificationParam, SetLevelRequestParam}, }; +use serde_json::json; use tokio::sync::Notify; -pub struct LoggingClient { - receive_signal: Arc, - received_messages: Arc>>, - peer: Option>, -} - -impl ClientHandler for LoggingClient { - async fn on_logging_message(&self, params: LoggingMessageNotificationParam) { - println!("Client: Received log message: {:?}", params); - let mut messages = self.received_messages.lock().unwrap(); - messages.push(params); - self.receive_signal.notify_one(); - } - - fn set_peer(&mut self, peer: Peer) { - self.peer.replace(peer); - } - - fn get_peer(&self) -> Option> { - self.peer.clone() - } -} - -pub struct TestServer {} - -impl TestServer { - fn new() -> Self { - Self {} - } -} - -impl ServerHandler for TestServer { - fn get_info(&self) -> ServerInfo { - ServerInfo { - capabilities: ServerCapabilities::builder().enable_logging().build(), - ..Default::default() - } - } - - fn set_level( - &self, - request: SetLevelRequestParam, - context: RequestContext, - ) -> impl Future> + Send + '_ { - let peer = context.peer; - async move { - let (data, logger) = match request.level { - LoggingLevel::Error => ( - serde_json::json!({ - "message": "Failed to process request", - "error_code": "E1001", - "error_details": "Connection timeout", - "timestamp": chrono::Utc::now().to_rfc3339(), - }), - Some("error_handler".to_string()), - ), - LoggingLevel::Debug => ( - serde_json::json!({ - "message": "Processing request", - "function": "handle_request", - "line": 42, - "context": { - "request_id": "req-123", - "user_id": "user-456" - }, - "timestamp": chrono::Utc::now().to_rfc3339(), - }), - Some("debug_logger".to_string()), - ), - LoggingLevel::Info => ( - serde_json::json!({ - "message": "System status update", - "status": "healthy", - "metrics": { - "requests_per_second": 150, - "average_latency_ms": 45, - "error_rate": 0.01 - }, - "timestamp": chrono::Utc::now().to_rfc3339(), - }), - Some("monitoring".to_string()), - ), - _ => ( - serde_json::json!({ - "message": format!("Message at level {:?}", request.level), - "timestamp": chrono::Utc::now().to_rfc3339(), - }), - None, - ), - }; - - if let Err(e) = peer - .notify_logging_message(LoggingMessageNotificationParam { - level: request.level, - data, - logger, - }) - .await - { - panic!("Failed to send notification: {}", e); - } - Ok(()) - } - } -} - #[tokio::test] async fn test_logging_spec_compliance() -> anyhow::Result<()> { let (server_transport, client_transport) = tokio::io::duplex(4096); let receive_signal = Arc::new(Notify::new()); - let received_messages = Arc::new(Mutex::new(Vec::new())); + let received_messages = Arc::new(Mutex::new(Vec::::new())); - // Start server - tokio::spawn(async move { + // Start server in a separate task + let server_handle = tokio::spawn(async move { let server = TestServer::new().serve(server_transport).await?; // Test server can send messages before level is set @@ -147,15 +38,16 @@ async fn test_logging_spec_compliance() -> anyhow::Result<()> { anyhow::Ok(()) }); - let client = LoggingClient { - receive_signal: receive_signal.clone(), - received_messages: received_messages.clone(), - peer: None, - } + let client = TestClientHandler::with_notification( + true, + true, + receive_signal.clone(), + received_messages.clone(), + ) .serve(client_transport) .await?; - // Verify server-initiated message + // Wait for the initial server message receive_signal.notified().await; { let mut messages = received_messages.lock().unwrap(); @@ -173,6 +65,8 @@ async fn test_logging_spec_compliance() -> anyhow::Result<()> { .peer() .set_level(SetLevelRequestParam { level }) .await?; + + // Wait for each message response receive_signal.notified().await; let mut messages = received_messages.lock().unwrap(); @@ -194,7 +88,12 @@ async fn test_logging_spec_compliance() -> anyhow::Result<()> { messages.clear(); } + // Important: Cancel the client before ending the test client.cancel().await?; + + // Wait for server to complete + server_handle.await??; + Ok(()) } @@ -202,32 +101,31 @@ async fn test_logging_spec_compliance() -> anyhow::Result<()> { async fn test_logging_user_scenarios() -> anyhow::Result<()> { let (server_transport, client_transport) = tokio::io::duplex(4096); let receive_signal = Arc::new(Notify::new()); - let received_messages = Arc::new(Mutex::new(Vec::new())); + let received_messages = Arc::new(Mutex::new(Vec::::new())); - // Start server - tokio::spawn(async move { + let server_handle = tokio::spawn(async move { let server = TestServer::new().serve(server_transport).await?; server.waiting().await?; anyhow::Ok(()) }); - let client = LoggingClient { - receive_signal: receive_signal.clone(), - received_messages: received_messages.clone(), - peer: None, - } + let client = TestClientHandler::with_notification( + true, + true, + receive_signal.clone(), + received_messages.clone(), + ) .serve(client_transport) .await?; // Test 1: Error reporting scenario - // User should see detailed error information client .peer() .set_level(SetLevelRequestParam { level: LoggingLevel::Error, }) .await?; - receive_signal.notified().await; + receive_signal.notified().await; // Wait for response { let messages = received_messages.lock().unwrap(); let msg = &messages[0]; @@ -247,14 +145,13 @@ async fn test_logging_user_scenarios() -> anyhow::Result<()> { } // Test 2: Debug scenario - // User debugging their application should see detailed information client .peer() .set_level(SetLevelRequestParam { level: LoggingLevel::Debug, }) .await?; - receive_signal.notified().await; + receive_signal.notified().await; // Wait for response { let messages = received_messages.lock().unwrap(); let msg = messages.last().unwrap(); @@ -271,14 +168,13 @@ async fn test_logging_user_scenarios() -> anyhow::Result<()> { } // Test 3: Production monitoring scenario - // User monitoring production should see important status updates client .peer() .set_level(SetLevelRequestParam { level: LoggingLevel::Info, }) .await?; - receive_signal.notified().await; + receive_signal.notified().await; // Wait for response { let messages = received_messages.lock().unwrap(); let msg = messages.last().unwrap(); @@ -287,7 +183,10 @@ async fn test_logging_user_scenarios() -> anyhow::Result<()> { assert!(data.contains_key("metrics"), "Should include metrics"); } + // Important: Cancel client and wait for server before ending client.cancel().await?; + server_handle.await??; + Ok(()) } @@ -327,3 +226,126 @@ fn test_logging_level_serialization() { ); } } + +#[tokio::test] +async fn test_logging_edge_cases() -> anyhow::Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + let receive_signal = Arc::new(Notify::new()); + let received_messages = Arc::new(Mutex::new(Vec::::new())); + + let server_handle = tokio::spawn(async move { + let server = TestServer::new().serve(server_transport).await?; + server.waiting().await?; + anyhow::Ok(()) + }); + + let client = TestClientHandler::with_notification( + true, + true, + receive_signal.clone(), + received_messages.clone(), + ) + .serve(client_transport) + .await?; + + // Test all logging levels from spec + for level in [ + LoggingLevel::Alert, + LoggingLevel::Critical, + LoggingLevel::Notice, // These weren't tested before + ] { + client + .peer() + .set_level(SetLevelRequestParam { level }) + .await?; + receive_signal.notified().await; + + let messages = received_messages.lock().unwrap(); + let msg = messages.last().unwrap(); + assert_eq!(msg.level, level); + } + + client.cancel().await?; + server_handle.await??; + Ok(()) +} + +#[tokio::test] +async fn test_logging_optional_fields() -> anyhow::Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + let receive_signal = Arc::new(Notify::new()); + let received_messages = Arc::new(Mutex::new(Vec::::new())); + + let server_handle = tokio::spawn(async move { + let server = TestServer::new().serve(server_transport).await?; + + // Test message with and without optional logger field + for (level, has_logger) in [(LoggingLevel::Info, true), (LoggingLevel::Debug, false)] { + server + .peer() + .notify_logging_message(LoggingMessageNotificationParam { + level, + data: json!({"test": "data"}), + logger: has_logger.then(|| "test_logger".to_string()), + }) + .await?; + } + + server.waiting().await?; + anyhow::Ok(()) + }); + + let client = TestClientHandler::with_notification( + true, + true, + receive_signal.clone(), + received_messages.clone(), + ) + .serve(client_transport) + .await?; + + // Wait for the initial server message + receive_signal.notified().await; + { + let mut messages = received_messages.lock().unwrap(); + assert_eq!(messages.len(), 2, "Should receive two messages"); + messages.clear(); + } + + // Test level filtering and message format + for level in [LoggingLevel::Info, LoggingLevel::Debug] { + client + .peer() + .set_level(SetLevelRequestParam { level }) + .await?; + + // Wait for each message response + receive_signal.notified().await; + + let mut messages = received_messages.lock().unwrap(); + let msg = messages.last().unwrap(); + + // Verify required fields + assert_eq!(msg.level, level); + assert!(msg.data.is_object()); + + // Verify data format + let data = msg.data.as_object().unwrap(); + assert!(data.contains_key("message")); + assert!(data.contains_key("timestamp")); + + // Verify timestamp + let timestamp = data["timestamp"].as_str().unwrap(); + chrono::DateTime::parse_from_rfc3339(timestamp).expect("RFC3339 timestamp"); + + messages.clear(); + } + + // Important: Cancel the client before ending the test + client.cancel().await?; + + // Wait for server to complete + server_handle.await??; + + Ok(()) +} diff --git a/crates/rmcp/tests/test_message_protocol.rs b/crates/rmcp/tests/test_message_protocol.rs new file mode 100644 index 00000000..c75567f4 --- /dev/null +++ b/crates/rmcp/tests/test_message_protocol.rs @@ -0,0 +1,529 @@ +//cargo test --test test_message_protocol --features "client server" + +mod common; +use common::handlers::{TestClientHandler, TestServer}; +use rmcp::{ + ServiceExt, + model::*, + service::{RequestContext, Service}, +}; +use tokio_util::sync::CancellationToken; + +// Tests start here +#[tokio::test] +async fn test_message_roles() { + let messages = vec![ + SamplingMessage { + role: Role::User, + content: Content::text("user message"), + }, + SamplingMessage { + role: Role::Assistant, + content: Content::text("assistant message"), + }, + ]; + + // Verify all roles can be serialized/deserialized correctly + let json = serde_json::to_string(&messages).unwrap(); + let deserialized: Vec = serde_json::from_str(&json).unwrap(); + assert_eq!(messages, deserialized); +} + +#[tokio::test] +async fn test_context_inclusion_integration() -> anyhow::Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + + // Start server + let server_handle = tokio::spawn(async move { + let server = TestServer::new().serve(server_transport).await?; + server.waiting().await?; + anyhow::Ok(()) + }); + + // Start client that honors context requests + let handler = TestClientHandler::new(true, true); + let client = handler.clone().serve(client_transport).await?; + + // Test ThisServer context inclusion + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![SamplingMessage { + role: Role::User, + content: Content::text("test message"), + }], + include_context: Some(ContextInclusion::ThisServer), + model_preferences: None, + system_prompt: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + }, + }); + + let result = handler + .handle_request( + request.clone(), + RequestContext { + peer: client.peer().clone(), + ct: CancellationToken::new(), + id: NumberOrString::Number(1), + }, + ) + .await?; + + if let ClientResult::CreateMessageResult(result) = result { + let text = result.message.content.as_text().unwrap().text.as_str(); + assert!( + text.contains("test context"), + "Response should include context for ThisServer" + ); + } else { + panic!("Expected CreateMessageResult"); + } + + // Test AllServers context inclusion + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![SamplingMessage { + role: Role::User, + content: Content::text("test message"), + }], + include_context: Some(ContextInclusion::AllServers), + model_preferences: None, + system_prompt: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + }, + }); + + let result = handler + .handle_request( + request.clone(), + RequestContext { + peer: client.peer().clone(), + ct: CancellationToken::new(), + id: NumberOrString::Number(2), + }, + ) + .await?; + + if let ClientResult::CreateMessageResult(result) = result { + let text = result.message.content.as_text().unwrap().text.as_str(); + assert!( + text.contains("test context"), + "Response should include context for AllServers" + ); + } else { + panic!("Expected CreateMessageResult"); + } + + // Test No context inclusion + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![SamplingMessage { + role: Role::User, + content: Content::text("test message"), + }], + include_context: Some(ContextInclusion::None), + model_preferences: None, + system_prompt: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + }, + }); + + let result = handler + .handle_request( + request.clone(), + RequestContext { + peer: client.peer().clone(), + ct: CancellationToken::new(), + id: NumberOrString::Number(3), + }, + ) + .await?; + + if let ClientResult::CreateMessageResult(result) = result { + let text = result.message.content.as_text().unwrap().text.as_str(); + assert!( + !text.contains("test context"), + "Response should not include context for None" + ); + } else { + panic!("Expected CreateMessageResult"); + } + + client.cancel().await?; + server_handle.await??; + Ok(()) +} + +#[tokio::test] +async fn test_context_inclusion_ignored_integration() -> anyhow::Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + + // Start server + let server_handle = tokio::spawn(async move { + let server = TestServer::new().serve(server_transport).await?; + server.waiting().await?; + anyhow::Ok(()) + }); + + // Start client that ignores context requests + let handler = TestClientHandler::new(false, false); + let client = handler.clone().serve(client_transport).await?; + + // Test that context requests are ignored + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![SamplingMessage { + role: Role::User, + content: Content::text("test message"), + }], + include_context: Some(ContextInclusion::ThisServer), + model_preferences: None, + system_prompt: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + }, + }); + + let result = handler + .handle_request( + request.clone(), + RequestContext { + peer: client.peer().clone(), + ct: CancellationToken::new(), + id: NumberOrString::Number(1), + }, + ) + .await?; + + if let ClientResult::CreateMessageResult(result) = result { + let text = result.message.content.as_text().unwrap().text.as_str(); + assert!( + !text.contains("test context"), + "Context should be ignored when client chooses not to honor requests" + ); + } else { + panic!("Expected CreateMessageResult"); + } + + client.cancel().await?; + server_handle.await??; + Ok(()) +} + +#[tokio::test] +async fn test_message_sequence_integration() -> anyhow::Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + + // Start server + let server_handle = tokio::spawn(async move { + let server = TestServer::new().serve(server_transport).await?; + server.waiting().await?; + anyhow::Ok(()) + }); + + // Start client + let handler = TestClientHandler::new(true, true); + let client = handler.clone().serve(client_transport).await?; + + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![ + SamplingMessage { + role: Role::User, + content: Content::text("first message"), + }, + SamplingMessage { + role: Role::Assistant, + content: Content::text("second message"), + }, + ], + include_context: Some(ContextInclusion::ThisServer), + model_preferences: None, + system_prompt: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + }, + }); + + let result = handler + .handle_request( + request.clone(), + RequestContext { + peer: client.peer().clone(), + ct: CancellationToken::new(), + id: NumberOrString::Number(1), + }, + ) + .await?; + + if let ClientResult::CreateMessageResult(result) = result { + let text = result.message.content.as_text().unwrap().text.as_str(); + assert!( + text.contains("test context"), + "Response should include context when ThisServer is specified" + ); + assert_eq!(result.model, "test-model"); + assert_eq!( + result.stop_reason, + Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()) + ); + } else { + panic!("Expected CreateMessageResult"); + } + + client.cancel().await?; + server_handle.await??; + Ok(()) +} + +#[tokio::test] +async fn test_message_sequence_validation_integration() -> anyhow::Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + + let server_handle = tokio::spawn(async move { + let server = TestServer::new().serve(server_transport).await?; + server.waiting().await?; + anyhow::Ok(()) + }); + + let handler = TestClientHandler::new(true, true); + let client = handler.clone().serve(client_transport).await?; + + // Test valid sequence: User -> Assistant -> User + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![ + SamplingMessage { + role: Role::User, + content: Content::text("first user message"), + }, + SamplingMessage { + role: Role::Assistant, + content: Content::text("first assistant response"), + }, + SamplingMessage { + role: Role::User, + content: Content::text("second user message"), + }, + ], + include_context: None, + model_preferences: None, + system_prompt: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + }, + }); + + let result = handler + .handle_request( + request.clone(), + RequestContext { + peer: client.peer().clone(), + ct: CancellationToken::new(), + id: NumberOrString::Number(1), + }, + ) + .await?; + + assert!(matches!(result, ClientResult::CreateMessageResult(_))); + + // Test invalid: No user message + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![SamplingMessage { + role: Role::Assistant, + content: Content::text("assistant message"), + }], + include_context: None, + model_preferences: None, + system_prompt: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + }, + }); + + let result = handler + .handle_request( + request.clone(), + RequestContext { + peer: client.peer().clone(), + ct: CancellationToken::new(), + id: NumberOrString::Number(2), + }, + ) + .await; + + assert!(result.is_err()); + + client.cancel().await?; + server_handle.await??; + Ok(()) +} + +#[tokio::test] +async fn test_selective_context_handling_integration() -> anyhow::Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + + let server_handle = tokio::spawn(async move { + let server = TestServer::new().serve(server_transport).await?; + server.waiting().await?; + anyhow::Ok(()) + }); + + // Client that only honors ThisServer but ignores AllServers + let handler = TestClientHandler::new(true, false); + let client = handler.clone().serve(client_transport).await?; + + // Test ThisServer is honored + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![SamplingMessage { + role: Role::User, + content: Content::text("test message"), + }], + include_context: Some(ContextInclusion::ThisServer), + model_preferences: None, + system_prompt: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + }, + }); + + let result = handler + .handle_request( + request.clone(), + RequestContext { + peer: client.peer().clone(), + ct: CancellationToken::new(), + id: NumberOrString::Number(1), + }, + ) + .await?; + + if let ClientResult::CreateMessageResult(result) = result { + let text = result.message.content.as_text().unwrap().text.as_str(); + assert!( + text.contains("test context"), + "ThisServer context request should be honored" + ); + } + + // Test AllServers is ignored + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![SamplingMessage { + role: Role::User, + content: Content::text("test message"), + }], + include_context: Some(ContextInclusion::AllServers), + model_preferences: None, + system_prompt: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + }, + }); + + let result = handler + .handle_request( + request.clone(), + RequestContext { + peer: client.peer().clone(), + ct: CancellationToken::new(), + id: NumberOrString::Number(2), + }, + ) + .await?; + + if let ClientResult::CreateMessageResult(result) = result { + let text = result.message.content.as_text().unwrap().text.as_str(); + assert!( + !text.contains("test context"), + "AllServers context request should be ignored" + ); + } + + client.cancel().await?; + server_handle.await??; + Ok(()) +} + +#[tokio::test] +async fn test_context_inclusion() -> anyhow::Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + let server_handle = tokio::spawn(async move { + let server = TestServer::new().serve(server_transport).await?; + server.waiting().await?; + anyhow::Ok(()) + }); + + let handler = TestClientHandler::new(true, true); + let client = handler.clone().serve(client_transport).await?; + + // Test context handling + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![SamplingMessage { + role: Role::User, + content: Content::text("test"), + }], + include_context: Some(ContextInclusion::ThisServer), + model_preferences: None, + system_prompt: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + }, + }); + + let result = handler + .handle_request( + request.clone(), + RequestContext { + peer: client.peer().clone(), + ct: CancellationToken::new(), + id: NumberOrString::Number(1), + }, + ) + .await?; + + if let ClientResult::CreateMessageResult(result) = result { + let text = result.message.content.as_text().unwrap().text.as_str(); + assert!(text.contains("test context")); + } + + client.cancel().await?; + server_handle.await??; + Ok(()) +}