From 588c39c771628ef7fc07ce3a5d50cb90a6fe0953 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Thu, 29 May 2025 02:47:56 +0800 Subject: [PATCH 1/2] feat: provide more context information 1. provide context for notification 2. allow extract more info from tool call 3. inject http request part for streamable http server --- crates/rmcp/src/handler/client.rs | 36 ++++++++---- crates/rmcp/src/handler/server.rs | 23 +++++--- crates/rmcp/src/handler/server/tool.rs | 36 ++++++++++++ crates/rmcp/src/model/extension.rs | 2 +- crates/rmcp/src/service.rs | 58 +++++++++++++++---- crates/rmcp/src/service/server.rs | 7 ++- crates/rmcp/src/service/tower.rs | 2 + .../transport/streamable_http_server/tower.rs | 25 +++++--- crates/rmcp/tests/common/handlers.rs | 6 +- crates/rmcp/tests/test_notification.rs | 6 +- 10 files changed, 161 insertions(+), 40 deletions(-) diff --git a/crates/rmcp/src/handler/client.rs b/crates/rmcp/src/handler/client.rs index 005661ca..f6d4ef6e 100644 --- a/crates/rmcp/src/handler/client.rs +++ b/crates/rmcp/src/handler/client.rs @@ -1,7 +1,7 @@ use crate::{ error::Error as McpError, model::*, - service::{RequestContext, RoleClient, Service, ServiceRole}, + service::{NotificationContext, RequestContext, RoleClient, Service, ServiceRole}, }; impl Service for H { @@ -26,28 +26,29 @@ impl Service for H { async fn handle_notification( &self, notification: ::PeerNot, + context: NotificationContext, ) -> Result<(), McpError> { match notification { ServerNotification::CancelledNotification(notification) => { - self.on_cancelled(notification.params).await + self.on_cancelled(notification.params, context).await } ServerNotification::ProgressNotification(notification) => { - self.on_progress(notification.params).await + self.on_progress(notification.params, context).await } ServerNotification::LoggingMessageNotification(notification) => { - self.on_logging_message(notification.params).await + self.on_logging_message(notification.params, context).await } ServerNotification::ResourceUpdatedNotification(notification) => { - self.on_resource_updated(notification.params).await + self.on_resource_updated(notification.params, context).await } ServerNotification::ResourceListChangedNotification(_notification_no_param) => { - self.on_resource_list_changed().await + self.on_resource_list_changed(context).await } ServerNotification::ToolListChangedNotification(_notification_no_param) => { - self.on_tool_list_changed().await + self.on_tool_list_changed(context).await } ServerNotification::PromptListChangedNotification(_notification_no_param) => { - self.on_prompt_list_changed().await + self.on_prompt_list_changed(context).await } }; Ok(()) @@ -87,34 +88,47 @@ pub trait ClientHandler: Sized + Send + Sync + 'static { fn on_cancelled( &self, params: CancelledNotificationParam, + context: NotificationContext, ) -> impl Future + Send + '_ { std::future::ready(()) } fn on_progress( &self, params: ProgressNotificationParam, + context: NotificationContext, ) -> impl Future + Send + '_ { std::future::ready(()) } fn on_logging_message( &self, params: LoggingMessageNotificationParam, + context: NotificationContext, ) -> impl Future + Send + '_ { std::future::ready(()) } fn on_resource_updated( &self, params: ResourceUpdatedNotificationParam, + context: NotificationContext, ) -> impl Future + Send + '_ { std::future::ready(()) } - fn on_resource_list_changed(&self) -> impl Future + Send + '_ { + fn on_resource_list_changed( + &self, + context: NotificationContext, + ) -> impl Future + Send + '_ { std::future::ready(()) } - fn on_tool_list_changed(&self) -> impl Future + Send + '_ { + fn on_tool_list_changed( + &self, + context: NotificationContext, + ) -> impl Future + Send + '_ { std::future::ready(()) } - fn on_prompt_list_changed(&self) -> impl Future + Send + '_ { + fn on_prompt_list_changed( + &self, + context: NotificationContext, + ) -> impl Future + Send + '_ { std::future::ready(()) } diff --git a/crates/rmcp/src/handler/server.rs b/crates/rmcp/src/handler/server.rs index 07681fed..52b63832 100644 --- a/crates/rmcp/src/handler/server.rs +++ b/crates/rmcp/src/handler/server.rs @@ -1,7 +1,7 @@ use crate::{ error::Error as McpError, model::*, - service::{RequestContext, RoleServer, Service, ServiceRole}, + service::{NotificationContext, RequestContext, RoleServer, Service, ServiceRole}, }; mod resource; @@ -71,19 +71,20 @@ impl Service for H { async fn handle_notification( &self, notification: ::PeerNot, + context: NotificationContext, ) -> Result<(), McpError> { match notification { ClientNotification::CancelledNotification(notification) => { - self.on_cancelled(notification.params).await + self.on_cancelled(notification.params, context).await } ClientNotification::ProgressNotification(notification) => { - self.on_progress(notification.params).await + self.on_progress(notification.params, context).await } ClientNotification::InitializedNotification(_notification) => { - self.on_initialized().await + self.on_initialized(context).await } ClientNotification::RootsListChangedNotification(_notification) => { - self.on_roots_list_changed().await + self.on_roots_list_changed(context).await } }; Ok(()) @@ -196,20 +197,28 @@ pub trait ServerHandler: Sized + Send + Sync + 'static { fn on_cancelled( &self, notification: CancelledNotificationParam, + context: NotificationContext, ) -> impl Future + Send + '_ { std::future::ready(()) } fn on_progress( &self, notification: ProgressNotificationParam, + context: NotificationContext, ) -> impl Future + Send + '_ { std::future::ready(()) } - fn on_initialized(&self) -> impl Future + Send + '_ { + fn on_initialized( + &self, + context: NotificationContext, + ) -> impl Future + Send + '_ { tracing::info!("client initialized"); std::future::ready(()) } - fn on_roots_list_changed(&self) -> impl Future + Send + '_ { + fn on_roots_list_changed( + &self, + context: NotificationContext, + ) -> impl Future + Send + '_ { std::future::ready(()) } diff --git a/crates/rmcp/src/handler/server/tool.rs b/crates/rmcp/src/handler/server/tool.rs index 30d88727..acb44fd5 100644 --- a/crates/rmcp/src/handler/server/tool.rs +++ b/crates/rmcp/src/handler/server/tool.rs @@ -320,6 +320,42 @@ where } } +impl<'a, S> FromToolCallContextPart<'a, S> for crate::Peer { + fn from_tool_call_context_part( + context: ToolCallContext<'a, S>, + ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { + let peer = context.request_context.peer.clone(); + Ok((peer, context)) + } +} + +impl<'a, S> FromToolCallContextPart<'a, S> for crate::model::Meta { + fn from_tool_call_context_part( + mut context: ToolCallContext<'a, S>, + ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { + let mut meta = crate::model::Meta::default(); + std::mem::swap(&mut meta, &mut context.request_context.meta); + Ok((meta, context)) + } +} + +pub struct RequestId(pub crate::model::RequestId); +impl<'a, S> FromToolCallContextPart<'a, S> for RequestId { + fn from_tool_call_context_part( + context: ToolCallContext<'a, S>, + ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { + Ok((RequestId(context.request_context.id.clone()), context)) + } +} + +impl<'a, S> FromToolCallContextPart<'a, S> for RequestContext { + fn from_tool_call_context_part( + context: ToolCallContext<'a, S>, + ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { + Ok((context.request_context.clone(), context)) + } +} + impl<'s, S> ToolCallContext<'s, S> { pub fn invoke(self, h: H) -> H::Fut where diff --git a/crates/rmcp/src/model/extension.rs b/crates/rmcp/src/model/extension.rs index a9d78d4c..039fdf2e 100644 --- a/crates/rmcp/src/model/extension.rs +++ b/crates/rmcp/src/model/extension.rs @@ -49,7 +49,7 @@ pub struct Extensions { impl Extensions { /// Create an empty `Extensions`. #[inline] - pub fn new() -> Extensions { + pub const fn new() -> Extensions { Extensions { map: None } } diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index 8fc05511..57d1089d 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -76,7 +76,9 @@ pub trait ServiceRole: std::fmt::Debug + Send + Sync + 'static + Copy + Clone { type PeerResp: TransferObject; type PeerNot: TryInto + From - + TransferObject; + + TransferObject + + GetMeta + + GetExtensions; type InitializeError; const IS_CLIENT: bool; type Info: TransferObject; @@ -100,6 +102,7 @@ pub trait Service: Send + Sync + 'static { fn handle_notification( &self, notification: R::PeerNot, + context: NotificationContext, ) -> impl Future> + Send + '_; fn get_info(&self) -> R::Info; } @@ -145,8 +148,9 @@ impl Service for Box> { fn handle_notification( &self, notification: R::PeerNot, + context: NotificationContext, ) -> impl Future> + Send + '_ { - DynService::handle_notification(self.as_ref(), notification) + DynService::handle_notification(self.as_ref(), notification, context) } fn get_info(&self) -> R::Info { @@ -160,7 +164,11 @@ pub trait DynService: Send + Sync { request: R::PeerReq, context: RequestContext, ) -> BoxFuture>; - fn handle_notification(&self, notification: R::PeerNot) -> BoxFuture>; + fn handle_notification( + &self, + notification: R::PeerNot, + context: NotificationContext, + ) -> BoxFuture>; fn get_info(&self) -> R::Info; } @@ -172,8 +180,12 @@ impl> DynService for S { ) -> BoxFuture> { Box::pin(self.handle_request(request, context)) } - fn handle_notification(&self, notification: R::PeerNot) -> BoxFuture> { - Box::pin(self.handle_notification(notification)) + fn handle_notification( + &self, + notification: R::PeerNot, + context: NotificationContext, + ) -> BoxFuture> { + Box::pin(self.handle_notification(notification, context)) } fn get_info(&self) -> R::Info { self.get_info() @@ -487,6 +499,15 @@ pub struct RequestContext { pub peer: Peer, } +/// Request execution context +#[derive(Debug, Clone)] +pub struct NotificationContext { + pub meta: Meta, + pub extensions: Extensions, + /// An interface to fetch the remote client or server + pub peer: Peer, +} + /// Use this function to skip initialization process pub fn serve_directly( service: S, @@ -710,7 +731,9 @@ where })); } Event::PeerMessage(JsonRpcMessage::Request(JsonRpcRequest { - id, request, .. + id, + mut request, + .. })) => { tracing::debug!(%id, ?request, "received request"); { @@ -719,12 +742,17 @@ where let request_ct = serve_loop_ct.child_token(); let context_ct = request_ct.child_token(); local_ct_pool.insert(id.clone(), request_ct); + let mut extensions = Extensions::new(); + let mut meta = Meta::new(); + // avoid clone + std::mem::swap(&mut extensions, request.extensions_mut()); + std::mem::swap(&mut meta, request.get_meta_mut()); let context = RequestContext { ct: context_ct, id: id.clone(), peer: peer.clone(), - meta: request.get_meta().clone(), - extensions: request.extensions().clone(), + meta, + extensions, }; tokio::spawn(async move { let result = service.handle_request(request, context).await; @@ -748,7 +776,7 @@ where })) => { tracing::info!(?notification, "received notification"); // catch cancelled notification - let notification = match notification.try_into() { + let mut notification = match notification.try_into() { Ok::(cancelled) => { if let Some(ct) = local_ct_pool.remove(&cancelled.params.request_id) { tracing::info!(id = %cancelled.params.request_id, reason = cancelled.params.reason, "cancelled"); @@ -760,8 +788,18 @@ where }; { let service = shared_service.clone(); + let mut extensions = Extensions::new(); + let mut meta = Meta::new(); + // avoid clone + std::mem::swap(&mut extensions, notification.extensions_mut()); + std::mem::swap(&mut meta, notification.get_meta_mut()); + let context = NotificationContext { + peer: peer.clone(), + meta, + extensions, + }; tokio::spawn(async move { - let result = service.handle_notification(notification).await; + let result = service.handle_notification(notification, context).await; if let Err(error) = result { tracing::warn!(%error, "Error sending notification"); } diff --git a/crates/rmcp/src/service/server.rs b/crates/rmcp/src/service/server.rs index f2c5ca67..6f585a61 100644 --- a/crates/rmcp/src/service/server.rs +++ b/crates/rmcp/src/service/server.rs @@ -210,7 +210,12 @@ where Some(ClientJsonRpcMessage::notification(notification)), )); }; - let _ = service.handle_notification(notification).await; + let context = NotificationContext { + meta: notification.get_meta().clone(), + extensions: notification.extensions().clone(), + peer: peer.clone(), + }; + let _ = service.handle_notification(notification, context).await; // Continue processing service Ok(serve_inner(service, transport, peer, peer_rx, ct)) } diff --git a/crates/rmcp/src/service/tower.rs b/crates/rmcp/src/service/tower.rs index da30f15d..454a9779 100644 --- a/crates/rmcp/src/service/tower.rs +++ b/crates/rmcp/src/service/tower.rs @@ -2,6 +2,7 @@ use std::{future::poll_fn, marker::PhantomData}; use tower_service::Service as TowerService; +use super::NotificationContext; use crate::service::{RequestContext, Service, ServiceRole}; pub struct TowerHandler { @@ -42,6 +43,7 @@ where fn handle_notification( &self, _notification: R::PeerNot, + _context: NotificationContext, ) -> impl Future> + Send + '_ { std::future::ready(Ok(())) } diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index f9c0a69f..4bcfe0a6 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -10,7 +10,7 @@ use tokio_stream::wrappers::ReceiverStream; use super::session::SessionManager; use crate::{ RoleServer, - model::ClientJsonRpcMessage, + model::{ClientJsonRpcMessage, GetExtensions}, serve_server, service::serve_directly, transport::{ @@ -243,7 +243,7 @@ where // json deserialize request body let (part, body) = request.into_parts(); - let message = match expect_json(body).await { + let mut message = match expect_json(body).await { Ok(message) => message, Err(response) => return Ok(response), }; @@ -271,6 +271,20 @@ where ) .expect("valid response")); } + + // inject request part to extensions + match &mut message { + ClientJsonRpcMessage::Request(req) => { + req.request.extensions_mut().insert(part); + } + ClientJsonRpcMessage::Notification(not) => { + not.notification.extensions_mut().insert(part); + } + _ => { + // skip + } + } + match message { ClientJsonRpcMessage::Request(_) => { let stream = self @@ -379,11 +393,8 @@ where self.config.sse_keep_alive, )) } - ClientJsonRpcMessage::Notification(notification) => { - service - .handle_notification(notification.notification) - .await - .map_err(internal_error_response("handle notification"))?; + ClientJsonRpcMessage::Notification(_notification) => { + // ignore Ok(accepted_response()) } ClientJsonRpcMessage::Response(_json_rpc_response) => Ok(accepted_response()), diff --git a/crates/rmcp/tests/common/handlers.rs b/crates/rmcp/tests/common/handlers.rs index c769565f..23bb5114 100644 --- a/crates/rmcp/tests/common/handlers.rs +++ b/crates/rmcp/tests/common/handlers.rs @@ -4,8 +4,9 @@ use std::{ }; use rmcp::{ - ClientHandler, Error as McpError, RoleClient, RoleServer, ServerHandler, model::*, - service::RequestContext, + ClientHandler, Error as McpError, RoleClient, RoleServer, ServerHandler, + model::*, + service::{NotificationContext, RequestContext}, }; use serde_json::json; use tokio::sync::Notify; @@ -83,6 +84,7 @@ impl ClientHandler for TestClientHandler { fn on_logging_message( &self, params: LoggingMessageNotificationParam, + _context: NotificationContext, ) -> impl Future + Send + '_ { let receive_signal = self.receive_signal.clone(); let received_messages = self.received_messages.clone(); diff --git a/crates/rmcp/tests/test_notification.rs b/crates/rmcp/tests/test_notification.rs index 09dd5e56..09501816 100644 --- a/crates/rmcp/tests/test_notification.rs +++ b/crates/rmcp/tests/test_notification.rs @@ -52,7 +52,11 @@ pub struct Client { } impl ClientHandler for Client { - async fn on_resource_updated(&self, params: rmcp::model::ResourceUpdatedNotificationParam) { + async fn on_resource_updated( + &self, + params: rmcp::model::ResourceUpdatedNotificationParam, + _context: rmcp::service::NotificationContext, + ) { let uri = params.uri; tracing::info!("Resource updated: {}", uri); self.receive_signal.notify_one(); From cefeb830cadf29be764e2957b76b5e3c6ce2337c Mon Sep 17 00:00:00 2001 From: 4t145 Date: Thu, 29 May 2025 11:35:09 +0800 Subject: [PATCH 2/2] docs: add document for getting peer from context --- crates/rmcp/README.md | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/crates/rmcp/README.md b/crates/rmcp/README.md index b3f67d12..6afcf13a 100644 --- a/crates/rmcp/README.md +++ b/crates/rmcp/README.md @@ -146,6 +146,39 @@ let service = client.serve(transport).await?; +## Access with peer interface when handling message + +You can get the [`Peer`](crate::service::Peer) struct from [`NotificationContext`](crate::service::NotificationContext) and [`RequestContext`](crate::service::RequestContext). + +```rust, ignore +# use rmcp::{ +# ServerHandler, +# model::{LoggingLevel, LoggingMessageNotificationParam, ProgressNotificationParam}, +# service::{NotificationContext, RoleServer}, +# }; +# pub struct Handler; + +impl ServerHandler for Handler { + async fn on_progress( + &self, + notification: ProgressNotificationParam, + context: NotificationContext, + ) { + let peer = context.peer; + let _ = peer + .notify_logging_message(LoggingMessageNotificationParam { + level: LoggingLevel::Info, + logger: None, + data: serde_json::json!({ + "message": format!("Progress: {}", notification.progress), + }), + }) + .await; + } +} +``` + + ## Manage Multi Services For many cases you need to manage several service in a collection, you can call `into_dyn` to convert services into the same type.