From fb60fc4cdf259ab954d3e375a47713ab5bce1bc5 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Fri, 4 Apr 2025 01:52:23 +0800 Subject: [PATCH] feat: revision-2025-03-26 without streamable http 1. Suppot revision 2025-03-26 data types 2. Support meta, progress tokne, and extensions in request/notification, 3. Remove `Message`, use `JsonRpcMessage` directly --- crates/rmcp-macros/src/tool.rs | 9 +- crates/rmcp/src/handler/server.rs | 8 +- crates/rmcp/src/handler/server/tool.rs | 2 +- crates/rmcp/src/model.rs | 252 +++++++++++------- crates/rmcp/src/model/annotated.rs | 5 +- crates/rmcp/src/model/capabilities.rs | 15 +- crates/rmcp/src/model/content.rs | 10 + crates/rmcp/src/model/extension.rs | 337 +++++++++++++++++++++++++ crates/rmcp/src/model/meta.rs | 150 +++++++++++ crates/rmcp/src/model/serde_impl.rs | 267 ++++++++++++++++++++ crates/rmcp/src/model/tool.rs | 104 ++++++++ crates/rmcp/src/service.rs | 219 +++++++++++----- crates/rmcp/src/service/client.rs | 75 ++++-- crates/rmcp/src/service/server.rs | 50 ++-- examples/servers/src/common/counter.rs | 6 +- 15 files changed, 1280 insertions(+), 229 deletions(-) create mode 100644 crates/rmcp/src/model/extension.rs create mode 100644 crates/rmcp/src/model/meta.rs create mode 100644 crates/rmcp/src/model/serde_impl.rs diff --git a/crates/rmcp-macros/src/tool.rs b/crates/rmcp-macros/src/tool.rs index a59e4b3f..d092fb63 100644 --- a/crates/rmcp-macros/src/tool.rs +++ b/crates/rmcp-macros/src/tool.rs @@ -349,6 +349,7 @@ pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Resu name: #name.into(), description: Some(#description.into()), input_schema: #schema.into(), + annotations: None } } } @@ -413,10 +414,10 @@ pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Resu Some(line) } }); - let trivial_argrextraction_part = quote! { + let trivial_arg_extraction_part = quote! { #(#trivial_args)* }; - let processed_argrextraction_part = match &mut tool_macro_attrs.params { + let processed_arg_extraction_part = match &mut tool_macro_attrs.params { ToolParams::Aggregated { rust_type } => { let PatType { pat, ty, .. } = rust_type; quote! { @@ -487,8 +488,8 @@ pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Resu #raw_fn_vis async fn #tool_call_fn_ident(context: rmcp::handler::server::tool::ToolCallContext<'_, Self>) -> std::result::Result { use rmcp::handler::server::tool::*; - #trivial_argrextraction_part - #processed_argrextraction_part + #trivial_arg_extraction_part + #processed_arg_extraction_part #call } } diff --git a/crates/rmcp/src/handler/server.rs b/crates/rmcp/src/handler/server.rs index 669564f8..808c9604 100644 --- a/crates/rmcp/src/handler/server.rs +++ b/crates/rmcp/src/handler/server.rs @@ -141,21 +141,21 @@ pub trait ServerHandler: Sized + Send + Sync + 'static { } fn list_prompts( &self, - request: PaginatedRequestParam, + request: Option, context: RequestContext, ) -> impl Future> + Send + '_ { std::future::ready(Ok(ListPromptsResult::default())) } fn list_resources( &self, - request: PaginatedRequestParam, + request: Option, context: RequestContext, ) -> impl Future> + Send + '_ { std::future::ready(Ok(ListResourcesResult::default())) } fn list_resource_templates( &self, - request: PaginatedRequestParam, + request: Option, context: RequestContext, ) -> impl Future> + Send + '_ { std::future::ready(Ok(ListResourceTemplatesResult::default())) @@ -192,7 +192,7 @@ pub trait ServerHandler: Sized + Send + Sync + 'static { } fn list_tools( &self, - request: PaginatedRequestParam, + request: Option, context: RequestContext, ) -> impl Future> + Send + '_ { std::future::ready(Ok(ListToolsResult::default())) diff --git a/crates/rmcp/src/handler/server/tool.rs b/crates/rmcp/src/handler/server/tool.rs index 5057420c..e931cbb7 100644 --- a/crates/rmcp/src/handler/server/tool.rs +++ b/crates/rmcp/src/handler/server/tool.rs @@ -453,7 +453,7 @@ macro_rules! tool_box { (@derive $tool_box:ident) => { async fn list_tools( &self, - _: $crate::model::PaginatedRequestParam, + _: Option<$crate::model::PaginatedRequestParam>, _: $crate::service::RequestContext<$crate::service::RoleServer>, ) -> Result<$crate::model::ListToolsResult, $crate::Error> { Ok($crate::model::ListToolsResult { diff --git a/crates/rmcp/src/model.rs b/crates/rmcp/src/model.rs index 50adc909..71bd661e 100644 --- a/crates/rmcp/src/model.rs +++ b/crates/rmcp/src/model.rs @@ -2,18 +2,24 @@ use std::{borrow::Cow, sync::Arc}; mod annotated; mod capabilities; mod content; +mod extension; +mod meta; mod prompt; mod resource; +mod serde_impl; mod tool; - pub use annotated::*; pub use capabilities::*; pub use content::*; +pub use extension::*; +pub use meta::*; pub use prompt::*; pub use resource::*; use serde::{Deserialize, Serialize}; use serde_json::Value; pub use tool::*; + +/// You can use [`crate::object!`] or [`crate::model::object`] to create a json object quickly. pub type JsonObject = serde_json::Map; /// unwrap the JsonObject under [`serde_json::Value`] @@ -28,6 +34,7 @@ pub fn object(value: serde_json::Value) -> JsonObject { } } +/// Use this macro just like [`serde_json::json!`] #[cfg(feature = "macros")] #[macro_export] macro_rules! object { @@ -85,7 +92,7 @@ macro_rules! const_string { const_string!(JsonRpcVersion2_0 = "2.0"); -#[derive(Debug, Clone, Eq, PartialEq, Hash)] +#[derive(Debug, Clone, Eq, PartialEq, Hash, PartialOrd)] pub struct ProtocolVersion(Cow<'static, str>); impl Default for ProtocolVersion { @@ -94,8 +101,9 @@ impl Default for ProtocolVersion { } } impl ProtocolVersion { - pub const LATEST: Self = Self(Cow::Borrowed("2024-11-05")); - pub const V_2024_11_05: Self = Self::LATEST; + pub const V_2025_03_26: Self = Self(Cow::Borrowed("2025-03-26")); + pub const V_2024_11_05: Self = Self(Cow::Borrowed("2024-11-05")); + pub const LATEST: Self = Self::V_2025_03_26; } impl Serialize for ProtocolVersion { @@ -116,17 +124,28 @@ impl<'de> Deserialize<'de> for ProtocolVersion { #[allow(clippy::single_match)] match s.as_str() { "2024-11-05" => return Ok(ProtocolVersion::V_2024_11_05), + "2025-03-26" => return Ok(ProtocolVersion::V_2025_03_26), _ => {} } Ok(ProtocolVersion(Cow::Owned(s))) } } + #[derive(Debug, Clone, Eq, PartialEq, Hash)] pub enum NumberOrString { Number(u32), String(Arc), } +impl NumberOrString { + pub fn into_json_value(self) -> Value { + match self { + NumberOrString::Number(n) => Value::Number(serde_json::Number::from(n)), + NumberOrString::String(s) => Value::String(s.to_string()), + } + } +} + impl std::fmt::Display for NumberOrString { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -166,41 +185,58 @@ impl<'de> Deserialize<'de> for NumberOrString { } pub type RequestId = NumberOrString; -pub type ProgressToken = NumberOrString; -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct WithMeta

{ - #[serde(skip_serializing_if = "Option::is_none")] - pub _meta: Option, - #[serde(flatten)] - pub inner: P, -} #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -#[serde(rename_all = "camelCase")] -pub struct RequestMeta { - progress_token: ProgressToken, +#[serde(transparent)] +pub struct ProgressToken(pub NumberOrString); +#[derive(Debug, Clone)] +pub struct Request { + pub method: M, + // #[serde(skip_serializing_if = "Option::is_none")] + pub params: P, + /// extensions will carry anything possible in the context, including [`Meta`] + /// + /// this is similar with the Extensions in `http` crate + pub extensions: Extensions, } -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct Request>> { +#[derive(Debug, Clone)] +pub struct RequestOptionalParam { pub method: M, // #[serde(skip_serializing_if = "Option::is_none")] - pub params: P, + pub params: Option

, + /// extensions will carry anything possible in the context, including [`Meta`] + /// + /// this is similar with the Extensions in `http` crate + pub extensions: Extensions, } -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] + +#[derive(Debug, Clone)] pub struct RequestNoParam { pub method: M, + /// extensions will carry anything possible in the context, including [`Meta`] + /// + /// this is similar with the Extensions in `http` crate + pub extensions: Extensions, } -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct Notification>> { +#[derive(Debug, Clone)] +pub struct Notification { pub method: M, pub params: P, + /// extensions will carry anything possible in the context, including [`Meta`] + /// + /// this is similar with the Extensions in `http` crate + pub extensions: Extensions, } -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[derive(Debug, Clone)] pub struct NotificationNoParam { pub method: M, + /// extensions will carry anything possible in the context, including [`Meta`] + /// + /// this is similar with the Extensions in `http` crate + pub extensions: Extensions, } #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] @@ -210,9 +246,9 @@ pub struct JsonRpcRequest { #[serde(flatten)] pub request: R, } -type DefaultResponse = WithMeta; +type DefaultResponse = JsonObject; #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct JsonRpcResponse { +pub struct JsonRpcResponse { pub jsonrpc: JsonRpcVersion2_0, pub id: RequestId, pub result: R, @@ -295,95 +331,111 @@ impl ErrorData { #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(untagged)] -pub enum JsonRpcMessage { +pub enum JsonRpcBatchRequestItem { Request(JsonRpcRequest), + Notification(JsonRpcNotification), +} + +impl JsonRpcBatchRequestItem { + pub fn into_non_batch_message(self) -> JsonRpcMessage { + match self { + JsonRpcBatchRequestItem::Request(r) => JsonRpcMessage::Request(r), + JsonRpcBatchRequestItem::Notification(n) => JsonRpcMessage::Notification(n), + } + } +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(untagged)] +pub enum JsonRpcBatchResponseItem { Response(JsonRpcResponse), - Notification(JsonRpcNotification), Error(JsonRpcError), } -impl JsonRpcMessage { - pub fn into_message(self) -> Message { +impl JsonRpcBatchResponseItem { + pub fn into_non_batch_message(self) -> JsonRpcMessage { match self { - JsonRpcMessage::Request(JsonRpcRequest { id, request, .. }) => { - Message::Request(request, id) - } - JsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => { - Message::Response(result, id) - } - JsonRpcMessage::Notification(JsonRpcNotification { notification, .. }) => { - Message::Notification(notification) - } - JsonRpcMessage::Error(JsonRpcError { id, error, .. }) => Message::Error(error, id), + JsonRpcBatchResponseItem::Response(r) => JsonRpcMessage::Response(r), + JsonRpcBatchResponseItem::Error(e) => JsonRpcMessage::Error(e), } } } -#[derive(Debug, Clone, PartialEq)] -pub enum Message { - Request(Req, RequestId), - Response(Resp, RequestId), - Error(ErrorData, RequestId), - Notification(Noti), +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(untagged)] +pub enum JsonRpcMessage { + Request(JsonRpcRequest), + Response(JsonRpcResponse), + Notification(JsonRpcNotification), + BatchRequest(Vec>), + BatchResponse(Vec>), + Error(JsonRpcError), } -impl Message { - pub fn into_notification(self) -> Option { +impl JsonRpcMessage { + #[inline] + pub const fn request(request: Req, id: RequestId) -> Self { + JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: JsonRpcVersion2_0, + id, + request, + }) + } + #[inline] + pub const fn response(response: Resp, id: RequestId) -> Self { + JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: JsonRpcVersion2_0, + id, + result: response, + }) + } + #[inline] + pub const fn error(error: ErrorData, id: RequestId) -> Self { + JsonRpcMessage::Error(JsonRpcError { + jsonrpc: JsonRpcVersion2_0, + id, + error, + }) + } + #[inline] + pub const fn notification(notification: Not) -> Self { + JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: JsonRpcVersion2_0, + notification, + }) + } + pub fn into_request(self) -> Option<(Req, RequestId)> { match self { - Message::Notification(notification) => Some(notification), + JsonRpcMessage::Request(r) => Some((r.request, r.id)), _ => None, } } pub fn into_response(self) -> Option<(Resp, RequestId)> { match self { - Message::Response(result, id) => Some((result, id)), + JsonRpcMessage::Response(r) => Some((r.result, r.id)), _ => None, } } - pub fn into_request(self) -> Option<(Req, RequestId)> { + pub fn into_notification(self) -> Option { match self { - Message::Request(request, id) => Some((request, id)), + JsonRpcMessage::Notification(n) => Some(n.notification), _ => None, } } pub fn into_error(self) -> Option<(ErrorData, RequestId)> { match self { - Message::Error(error, id) => Some((error, id)), + JsonRpcMessage::Error(e) => Some((e.error, e.id)), _ => None, } } pub fn into_result(self) -> Option<(Result, RequestId)> { match self { - Message::Response(result, id) => Some((Ok(result), id)), - Message::Error(error, id) => Some((Err(error), id)), + JsonRpcMessage::Response(r) => Some((Ok(r.result), r.id)), + JsonRpcMessage::Error(e) => Some((Err(e.error), e.id)), + _ => None, } } - pub fn into_json_rpc_message(self) -> JsonRpcMessage { - match self { - Message::Request(request, id) => JsonRpcMessage::Request(JsonRpcRequest { - jsonrpc: JsonRpcVersion2_0, - id, - request, - }), - Message::Response(result, id) => JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: JsonRpcVersion2_0, - id, - result, - }), - Message::Error(error, id) => JsonRpcMessage::Error(JsonRpcError { - jsonrpc: JsonRpcVersion2_0, - id, - error, - }), - Message::Notification(notification) => { - JsonRpcMessage::Notification(JsonRpcNotification { - jsonrpc: JsonRpcVersion2_0, - notification, - }) - } - } - } } /// # Empty result @@ -493,12 +545,10 @@ impl Implementation { #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] #[serde(rename_all = "camelCase")] -pub struct PaginatedRequestParamInner { +pub struct PaginatedRequestParam { #[serde(skip_serializing_if = "Option::is_none")] pub cursor: Option, } - -pub type PaginatedRequestParam = Option; const_string!(PingRequestMethod = "ping"); pub type PingRequest = RequestNoParam; @@ -512,6 +562,9 @@ pub struct ProgressNotificationParam { /// Total number of items to process (or total progress required), if known #[serde(skip_serializing_if = "Option::is_none")] pub total: Option, + /// An optional message describing the current progress. + #[serde(skip_serializing_if = "Option::is_none")] + pub message: Option, } pub type ProgressNotification = Notification; @@ -533,14 +586,15 @@ macro_rules! paginated_result { } const_string!(ListResourcesRequestMethod = "resources/list"); -pub type ListResourcesRequest = Request; +pub type ListResourcesRequest = + RequestOptionalParam; paginated_result!(ListResourcesResult { resources: Vec }); const_string!(ListResourceTemplatesRequestMethod = "resources/templates/list"); pub type ListResourceTemplatesRequest = - Request; + RequestOptionalParam; paginated_result!(ListResourceTemplatesResult { resource_templates: Vec }); @@ -589,7 +643,7 @@ pub type ResourceUpdatedNotification = Notification; const_string!(ListPromptsRequestMethod = "prompts/list"); -pub type ListPromptsRequest = Request; +pub type ListPromptsRequest = RequestOptionalParam; paginated_result!(ListPromptsResult { prompts: Vec }); @@ -793,7 +847,7 @@ impl CallToolResult { } const_string!(ListToolsRequestMethod = "tools/list"); -pub type ListToolsRequest = Request; +pub type ListToolsRequest = RequestOptionalParam; paginated_result!( ListToolsResult { tools: Vec @@ -840,7 +894,7 @@ macro_rules! ts_union { export type $U: ident = $(|)?$($V: ident)|*; ) => { - #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] + #[derive(Debug, Serialize, Deserialize, Clone)] #[serde(untagged)] pub enum $U { $($V($V),)* @@ -884,7 +938,6 @@ impl ClientResult { } pub type ClientJsonRpcMessage = JsonRpcMessage; -pub type ClientMessage = Message; ts_union!( export type ServerRequest = @@ -926,7 +979,6 @@ impl ServerResult { } pub type ServerJsonRpcMessage = JsonRpcMessage; -pub type ServerMessage = Message; impl TryInto for ServerNotification { type Error = ServerNotification; @@ -975,12 +1027,14 @@ mod tests { }); let message: ClientJsonRpcMessage = serde_json::from_value(raw.clone()).expect("invalid notification"); - let message = message.into_message(); match &message { - ClientMessage::Notification(ClientNotification::InitializedNotification(_n)) => {} + ClientJsonRpcMessage::Notification(JsonRpcNotification { + notification: ClientNotification::InitializedNotification(_n), + .. + }) => {} _ => panic!("Expected Notification"), } - let json = serde_json::to_value(message.into_json_rpc_message()).expect("valid json"); + let json = serde_json::to_value(message).expect("valid json"); assert_eq!(json, raw); } @@ -999,7 +1053,7 @@ mod tests { assert_eq!(r.id, RequestId::Number(1)); assert_eq!(r.request.method, "request"); assert_eq!( - &r.request.params.as_ref().unwrap().inner, + &r.request.params, json!({"key": "value"}) .as_object() .expect("should be an object") @@ -1057,10 +1111,7 @@ mod tests { }); let request: ClientJsonRpcMessage = serde_json::from_value(request.clone()).expect("invalid request"); - let (request, id) = request - .into_message() - .into_request() - .expect("expect request"); + let (request, id) = request.into_request().expect("should be a request"); assert_eq!(id, RequestId::Number(1)); match request { ClientRequest::InitializeRequest(Request { @@ -1071,6 +1122,7 @@ mod tests { capabilities, client_info, }, + .. }) => { assert_eq!(capabilities.roots.unwrap().list_changed, Some(true)); assert_eq!(capabilities.sampling.unwrap().len(), 0); @@ -1083,7 +1135,6 @@ mod tests { serde_json::from_value(raw_response_json.clone()).expect("invalid response"); let (response, id) = server_response .clone() - .into_message() .into_response() .expect("expect response"); assert_eq!(id, RequestId::Number(1)); @@ -1113,4 +1164,11 @@ mod tests { assert_eq!(server_response_json, raw_response_json); } + + #[test] + fn test_protocol_version_order() { + let v1 = ProtocolVersion::V_2024_11_05; + let v2 = ProtocolVersion::V_2025_03_26; + assert!(v1 < v2); + } } diff --git a/crates/rmcp/src/model/annotated.rs b/crates/rmcp/src/model/annotated.rs index 4401ae04..31ea65db 100644 --- a/crates/rmcp/src/model/annotated.rs +++ b/crates/rmcp/src/model/annotated.rs @@ -4,8 +4,8 @@ use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use super::{ - RawContent, RawEmbeddedResource, RawImageContent, RawResource, RawResourceTemplate, - RawTextContent, Role, + RawAudioContent, RawContent, RawEmbeddedResource, RawImageContent, RawResource, + RawResourceTemplate, RawTextContent, Role, }; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] @@ -162,6 +162,7 @@ macro_rules! annotate { annotate!(RawContent); annotate!(RawTextContent); annotate!(RawImageContent); +annotate!(RawAudioContent); annotate!(RawEmbeddedResource); annotate!(RawResource); annotate!(RawResourceTemplate); diff --git a/crates/rmcp/src/model/capabilities.rs b/crates/rmcp/src/model/capabilities.rs index 087e4058..7e853402 100644 --- a/crates/rmcp/src/model/capabilities.rs +++ b/crates/rmcp/src/model/capabilities.rs @@ -76,6 +76,8 @@ pub struct ServerCapabilities { #[serde(skip_serializing_if = "Option::is_none")] pub logging: Option, #[serde(skip_serializing_if = "Option::is_none")] + pub completions: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub prompts: Option, #[serde(skip_serializing_if = "Option::is_none")] pub resources: Option, @@ -192,14 +194,15 @@ builder! { ServerCapabilities { experimental: ExperimentalCapabilities, logging: JsonObject, + completions: JsonObject, prompts: PromptsCapability, resources: ResourcesCapability, tools: ToolsCapability } } -impl - ServerCapabilitiesBuilder> +impl + ServerCapabilitiesBuilder> { pub fn enable_tool_list_changed(mut self) -> Self { if let Some(c) = self.tools.as_mut() { @@ -209,8 +212,8 @@ impl } } -impl - ServerCapabilitiesBuilder> +impl + ServerCapabilitiesBuilder> { pub fn enable_prompts_list_changed(mut self) -> Self { if let Some(c) = self.prompts.as_mut() { @@ -220,8 +223,8 @@ impl } } -impl - ServerCapabilitiesBuilder> +impl + ServerCapabilitiesBuilder> { pub fn enable_resources_list_changed(mut self) -> Self { if let Some(c) = self.resources.as_mut() { diff --git a/crates/rmcp/src/model/content.rs b/crates/rmcp/src/model/content.rs index 15cf0507..a7774d8c 100644 --- a/crates/rmcp/src/model/content.rs +++ b/crates/rmcp/src/model/content.rs @@ -37,12 +37,22 @@ impl EmbeddedResource { } } +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RawAudioContent { + pub data: String, + pub mime_type: String, +} + +pub type AudioContent = Annotated; + #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "camelCase")] pub enum RawContent { Text(RawTextContent), Image(RawImageContent), Resource(RawEmbeddedResource), + Audio(AudioContent), } pub type Content = Annotated; diff --git a/crates/rmcp/src/model/extension.rs b/crates/rmcp/src/model/extension.rs new file mode 100644 index 00000000..a9d78d4c --- /dev/null +++ b/crates/rmcp/src/model/extension.rs @@ -0,0 +1,337 @@ +//! A container for those extra data could be carried on request or notification. +//! +//! This file is copied and modified from crate [http](https://github.com/hyperium/http). +//! +//! - Original code license: +//! - Original code: +use std::{ + any::{Any, TypeId}, + collections::HashMap, + fmt, + hash::{BuildHasherDefault, Hasher}, +}; + +type AnyMap = HashMap, BuildHasherDefault>; + +// With TypeIds as keys, there's no need to hash them. They are already hashes +// themselves, coming from the compiler. The IdHasher just holds the u64 of +// the TypeId, and then returns it, instead of doing any bit fiddling. +#[derive(Default)] +struct IdHasher(u64); + +impl Hasher for IdHasher { + fn write(&mut self, _: &[u8]) { + unreachable!("TypeId calls write_u64"); + } + + #[inline] + fn write_u64(&mut self, id: u64) { + self.0 = id; + } + + #[inline] + fn finish(&self) -> u64 { + self.0 + } +} + +/// A type map of protocol extensions. +/// +/// `Extensions` can be used by `Request` `Notification` and `Response` to store +/// extra data derived from the underlying protocol. +#[derive(Clone, Default)] +pub struct Extensions { + // If extensions are never used, no need to carry around an empty HashMap. + // That's 3 words. Instead, this is only 1 word. + map: Option>, +} + +impl Extensions { + /// Create an empty `Extensions`. + #[inline] + pub fn new() -> Extensions { + Extensions { map: None } + } + + /// Insert a type into this `Extensions`. + /// + /// If a extension of this type already existed, it will + /// be returned and replaced with the new one. + /// + /// # Example + /// + /// ``` + /// # use rmcp::model::Extensions; + /// let mut ext = Extensions::new(); + /// assert!(ext.insert(5i32).is_none()); + /// assert!(ext.insert(4u8).is_none()); + /// assert_eq!(ext.insert(9i32), Some(5i32)); + /// ``` + pub fn insert(&mut self, val: T) -> Option { + self.map + .get_or_insert_with(Box::default) + .insert(TypeId::of::(), Box::new(val)) + .and_then(|boxed| boxed.into_any().downcast().ok().map(|boxed| *boxed)) + } + + /// Get a reference to a type previously inserted on this `Extensions`. + /// + /// # Example + /// + /// ``` + /// # use rmcp::model::Extensions; + /// let mut ext = Extensions::new(); + /// assert!(ext.get::().is_none()); + /// ext.insert(5i32); + /// + /// assert_eq!(ext.get::(), Some(&5i32)); + /// ``` + pub fn get(&self) -> Option<&T> { + self.map + .as_ref() + .and_then(|map| map.get(&TypeId::of::())) + .and_then(|boxed| (**boxed).as_any().downcast_ref()) + } + + /// Get a mutable reference to a type previously inserted on this `Extensions`. + /// + /// # Example + /// + /// ``` + /// # use rmcp::model::Extensions; + /// let mut ext = Extensions::new(); + /// ext.insert(String::from("Hello")); + /// ext.get_mut::().unwrap().push_str(" World"); + /// + /// assert_eq!(ext.get::().unwrap(), "Hello World"); + /// ``` + pub fn get_mut(&mut self) -> Option<&mut T> { + self.map + .as_mut() + .and_then(|map| map.get_mut(&TypeId::of::())) + .and_then(|boxed| (**boxed).as_any_mut().downcast_mut()) + } + + /// Get a mutable reference to a type, inserting `value` if not already present on this + /// `Extensions`. + /// + /// # Example + /// + /// ``` + /// # use rmcp::model::Extensions; + /// let mut ext = Extensions::new(); + /// *ext.get_or_insert(1i32) += 2; + /// + /// assert_eq!(*ext.get::().unwrap(), 3); + /// ``` + pub fn get_or_insert(&mut self, value: T) -> &mut T { + self.get_or_insert_with(|| value) + } + + /// Get a mutable reference to a type, inserting the value created by `f` if not already present + /// on this `Extensions`. + /// + /// # Example + /// + /// ``` + /// # use rmcp::model::Extensions; + /// let mut ext = Extensions::new(); + /// *ext.get_or_insert_with(|| 1i32) += 2; + /// + /// assert_eq!(*ext.get::().unwrap(), 3); + /// ``` + pub fn get_or_insert_with T>( + &mut self, + f: F, + ) -> &mut T { + let out = self + .map + .get_or_insert_with(Box::default) + .entry(TypeId::of::()) + .or_insert_with(|| Box::new(f())); + (**out).as_any_mut().downcast_mut().unwrap() + } + + /// Get a mutable reference to a type, inserting the type's default value if not already present + /// on this `Extensions`. + /// + /// # Example + /// + /// ``` + /// # use rmcp::model::Extensions; + /// let mut ext = Extensions::new(); + /// *ext.get_or_insert_default::() += 2; + /// + /// assert_eq!(*ext.get::().unwrap(), 2); + /// ``` + pub fn get_or_insert_default(&mut self) -> &mut T { + self.get_or_insert_with(T::default) + } + + /// Remove a type from this `Extensions`. + /// + /// If a extension of this type existed, it will be returned. + /// + /// # Example + /// + /// ``` + /// # use rmcp::model::Extensions; + /// let mut ext = Extensions::new(); + /// ext.insert(5i32); + /// assert_eq!(ext.remove::(), Some(5i32)); + /// assert!(ext.get::().is_none()); + /// ``` + pub fn remove(&mut self) -> Option { + self.map + .as_mut() + .and_then(|map| map.remove(&TypeId::of::())) + .and_then(|boxed| boxed.into_any().downcast().ok().map(|boxed| *boxed)) + } + + /// Clear the `Extensions` of all inserted extensions. + /// + /// # Example + /// + /// ``` + /// # use rmcp::model::Extensions; + /// let mut ext = Extensions::new(); + /// ext.insert(5i32); + /// ext.clear(); + /// + /// assert!(ext.get::().is_none()); + /// ``` + #[inline] + pub fn clear(&mut self) { + if let Some(ref mut map) = self.map { + map.clear(); + } + } + + /// Check whether the extension set is empty or not. + /// + /// # Example + /// + /// ``` + /// # use rmcp::model::Extensions; + /// let mut ext = Extensions::new(); + /// assert!(ext.is_empty()); + /// ext.insert(5i32); + /// assert!(!ext.is_empty()); + /// ``` + #[inline] + pub fn is_empty(&self) -> bool { + self.map.as_ref().is_none_or(|map| map.is_empty()) + } + + /// Get the number of extensions available. + /// + /// # Example + /// + /// ``` + /// # use rmcp::model::Extensions; + /// let mut ext = Extensions::new(); + /// assert_eq!(ext.len(), 0); + /// ext.insert(5i32); + /// assert_eq!(ext.len(), 1); + /// ``` + #[inline] + pub fn len(&self) -> usize { + self.map.as_ref().map_or(0, |map| map.len()) + } + + /// Extends `self` with another `Extensions`. + /// + /// If an instance of a specific type exists in both, the one in `self` is overwritten with the + /// one from `other`. + /// + /// # Example + /// + /// ``` + /// # use rmcp::model::Extensions; + /// let mut ext_a = Extensions::new(); + /// ext_a.insert(8u8); + /// ext_a.insert(16u16); + /// + /// let mut ext_b = Extensions::new(); + /// ext_b.insert(4u8); + /// ext_b.insert("hello"); + /// + /// ext_a.extend(ext_b); + /// assert_eq!(ext_a.len(), 3); + /// assert_eq!(ext_a.get::(), Some(&4u8)); + /// assert_eq!(ext_a.get::(), Some(&16u16)); + /// assert_eq!(ext_a.get::<&'static str>().copied(), Some("hello")); + /// ``` + pub fn extend(&mut self, other: Self) { + if let Some(other) = other.map { + if let Some(map) = &mut self.map { + map.extend(*other); + } else { + self.map = Some(other); + } + } + } +} + +impl fmt::Debug for Extensions { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Extensions").finish() + } +} + +trait AnyClone: Any { + fn clone_box(&self) -> Box; + fn as_any(&self) -> &dyn Any; + fn as_any_mut(&mut self) -> &mut dyn Any; + fn into_any(self: Box) -> Box; +} + +impl AnyClone for T { + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + fn into_any(self: Box) -> Box { + self + } +} + +impl Clone for Box { + fn clone(&self) -> Self { + (**self).clone_box() + } +} + +#[test] +fn test_extensions() { + #[derive(Clone, Debug, PartialEq)] + struct MyType(i32); + + let mut extensions = Extensions::new(); + + extensions.insert(5i32); + extensions.insert(MyType(10)); + + assert_eq!(extensions.get(), Some(&5i32)); + assert_eq!(extensions.get_mut(), Some(&mut 5i32)); + + let ext2 = extensions.clone(); + + assert_eq!(extensions.remove::(), Some(5i32)); + assert!(extensions.get::().is_none()); + + // clone still has it + assert_eq!(ext2.get(), Some(&5i32)); + assert_eq!(ext2.get(), Some(&MyType(10))); + + assert_eq!(extensions.get::(), None); + assert_eq!(extensions.get(), Some(&MyType(10))); +} diff --git a/crates/rmcp/src/model/meta.rs b/crates/rmcp/src/model/meta.rs new file mode 100644 index 00000000..5c978896 --- /dev/null +++ b/crates/rmcp/src/model/meta.rs @@ -0,0 +1,150 @@ +use std::ops::{Deref, DerefMut}; + +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use super::{ + ClientNotification, ClientRequest, Extensions, JsonObject, NumberOrString, ProgressToken, + ServerNotification, ServerRequest, +}; + +pub trait GetMeta { + fn get_meta_mut(&mut self) -> &mut Meta; + fn get_meta(&self) -> &Meta; +} + +macro_rules! variant_extension { + ( + $Enum: ident { + $($variant: ident)* + } + ) => { + impl $Enum { + pub fn extensions(&self) -> &Extensions { + match self { + $( + $Enum::$variant(v) => &v.extensions, + )* + } + } + pub fn extensions_mut(&mut self) -> &mut Extensions { + match self { + $( + $Enum::$variant(v) => &mut v.extensions, + )* + } + } + } + impl GetMeta for $Enum { + fn get_meta_mut(&mut self) -> &mut Meta { + self.extensions_mut().get_or_insert_default() + } + fn get_meta(&self) -> &Meta { + self.extensions().get::().unwrap_or(Meta::static_empty()) + } + } + }; +} + +variant_extension! { + ClientRequest { + PingRequest + InitializeRequest + CompleteRequest + SetLevelRequest + GetPromptRequest + ListPromptsRequest + ListResourcesRequest + ListResourceTemplatesRequest + ReadResourceRequest + SubscribeRequest + UnsubscribeRequest + CallToolRequest + ListToolsRequest + } +} + +variant_extension! { + ServerRequest { + PingRequest + CreateMessageRequest + ListRootsRequest + } +} + +variant_extension! { + ClientNotification { + CancelledNotification + ProgressNotification + InitializedNotification + RootsListChangedNotification + } +} + +variant_extension! { + ServerNotification { + CancelledNotification + ProgressNotification + LoggingMessageNotification + ResourceUpdatedNotification + ResourceListChangedNotification + ToolListChangedNotification + PromptListChangedNotification + } +} +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +#[serde(transparent)] +pub struct Meta(pub JsonObject); +const PROGRESS_TOKEN_FIELD: &str = "progressToken"; +impl Meta { + pub fn new() -> Self { + Self(JsonObject::new()) + } + + pub(crate) fn static_empty() -> &'static Self { + static EMPTY: std::sync::OnceLock = std::sync::OnceLock::new(); + EMPTY.get_or_init(Default::default) + } + + pub fn get_progress_token(&self) -> Option { + self.0.get(PROGRESS_TOKEN_FIELD).and_then(|v| match v { + Value::String(s) => Some(ProgressToken(NumberOrString::String(s.to_string().into()))), + Value::Number(n) => n + .as_u64() + .map(|n| ProgressToken(NumberOrString::Number(n as u32))), + _ => None, + }) + } + + pub fn set_progress_token(&mut self, token: ProgressToken) { + match token.0 { + NumberOrString::String(ref s) => self.0.insert( + PROGRESS_TOKEN_FIELD.to_string(), + Value::String(s.to_string()), + ), + NumberOrString::Number(n) => self + .0 + .insert(PROGRESS_TOKEN_FIELD.to_string(), Value::Number(n.into())), + }; + } + + pub fn extend(&mut self, other: Meta) { + for (k, v) in other.0.into_iter() { + self.0.insert(k, v); + } + } +} + +impl Deref for Meta { + type Target = JsonObject; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for Meta { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} diff --git a/crates/rmcp/src/model/serde_impl.rs b/crates/rmcp/src/model/serde_impl.rs new file mode 100644 index 00000000..09222d52 --- /dev/null +++ b/crates/rmcp/src/model/serde_impl.rs @@ -0,0 +1,267 @@ +use std::borrow::Cow; + +use serde::{Deserialize, Serialize}; + +use super::{ + Extensions, Meta, Notification, NotificationNoParam, Request, RequestNoParam, + RequestOptionalParam, +}; +#[derive(Serialize, Deserialize)] +struct WithMeta<'a, P> { + #[serde(skip_serializing_if = "Option::is_none")] + _meta: Option>, + #[serde(flatten)] + _rest: P, +} + +#[derive(Serialize, Deserialize)] +struct Proxy<'a, M, P> { + method: M, + params: WithMeta<'a, P>, +} + +#[derive(Serialize, Deserialize)] +struct ProxyOptionalParam<'a, M, P> { + method: M, + params: Option>, +} + +#[derive(Serialize, Deserialize)] +struct ProxyNoParam { + method: M, +} + +impl Serialize for Request +where + M: Serialize, + R: Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let extensions = &self.extensions; + let _meta = extensions.get::().map(Cow::Borrowed); + Proxy::serialize( + &Proxy { + method: &self.method, + params: WithMeta { + _rest: &self.params, + _meta, + }, + }, + serializer, + ) + } +} + +impl<'de, M, R> Deserialize<'de> for Request +where + M: Deserialize<'de>, + R: Deserialize<'de>, +{ + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let body = Proxy::deserialize(deserializer)?; + let _meta = body.params._meta.map(|m| m.into_owned()); + let mut extensions = Extensions::new(); + if let Some(meta) = _meta { + extensions.insert(meta); + } + Ok(Request { + extensions, + method: body.method, + params: body.params._rest, + }) + } +} + +impl Serialize for RequestOptionalParam +where + M: Serialize, + R: Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let extensions = &self.extensions; + let _meta = extensions.get::().map(Cow::Borrowed); + Proxy::serialize( + &Proxy { + method: &self.method, + params: WithMeta { + _rest: &self.params, + _meta, + }, + }, + serializer, + ) + } +} + +impl<'de, M, R> Deserialize<'de> for RequestOptionalParam +where + M: Deserialize<'de>, + R: Deserialize<'de>, +{ + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let body = ProxyOptionalParam::<'_, _, Option>::deserialize(deserializer)?; + let mut params = None; + let mut _meta = None; + if let Some(body_params) = body.params { + params = body_params._rest; + _meta = body_params._meta.map(|m| m.into_owned()); + } + let mut extensions = Extensions::new(); + if let Some(meta) = _meta { + extensions.insert(meta); + } + Ok(RequestOptionalParam { + extensions, + method: body.method, + params, + }) + } +} + +impl Serialize for RequestNoParam +where + M: Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let extensions = &self.extensions; + let _meta = extensions.get::().map(Cow::Borrowed); + ProxyNoParam::serialize( + &ProxyNoParam { + method: &self.method, + }, + serializer, + ) + } +} + +impl<'de, M> Deserialize<'de> for RequestNoParam +where + M: Deserialize<'de>, +{ + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let body = ProxyNoParam::<_>::deserialize(deserializer)?; + let extensions = Extensions::new(); + Ok(RequestNoParam { + extensions, + method: body.method, + }) + } +} + +impl Serialize for Notification +where + M: Serialize, + R: Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let extensions = &self.extensions; + let _meta = extensions.get::().map(Cow::Borrowed); + Proxy::serialize( + &Proxy { + method: &self.method, + params: WithMeta { + _rest: &self.params, + _meta, + }, + }, + serializer, + ) + } +} + +impl<'de, M, R> Deserialize<'de> for Notification +where + M: Deserialize<'de>, + R: Deserialize<'de>, +{ + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let body = Proxy::deserialize(deserializer)?; + let _meta = body.params._meta.map(|m| m.into_owned()); + let mut extensions = Extensions::new(); + if let Some(meta) = _meta { + extensions.insert(meta); + } + Ok(Notification { + extensions, + method: body.method, + params: body.params._rest, + }) + } +} + +impl Serialize for NotificationNoParam +where + M: Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let extensions = &self.extensions; + let _meta = extensions.get::().map(Cow::Borrowed); + ProxyNoParam::serialize( + &ProxyNoParam { + method: &self.method, + }, + serializer, + ) + } +} + +impl<'de, M> Deserialize<'de> for NotificationNoParam +where + M: Deserialize<'de>, +{ + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let body = ProxyNoParam::<_>::deserialize(deserializer)?; + let extensions = Extensions::new(); + Ok(NotificationNoParam { + extensions, + method: body.method, + }) + } +} + +#[cfg(test)] +mod test { + use serde_json::json; + + use crate::model::ListToolsRequest; + + #[test] + fn test_deserialize_lost_tools_request() { + let _req: ListToolsRequest = serde_json::from_value(json!( + { + "method": "tools/list", + } + )) + .unwrap(); + } +} diff --git a/crates/rmcp/src/model/tool.rs b/crates/rmcp/src/model/tool.rs index cb3e359b..3d973eb5 100644 --- a/crates/rmcp/src/model/tool.rs +++ b/crates/rmcp/src/model/tool.rs @@ -18,6 +18,102 @@ pub struct Tool { pub description: Option>, /// A JSON Schema object defining the expected parameters for the tool pub input_schema: Arc, + /// Optional additional tool information. + pub annotations: Option, +} + +/// Additional properties describing a Tool to clients. +/// +/// NOTE: all properties in ToolAnnotations are **hints**. +/// They are not guaranteed to provide a faithful description of +/// tool behavior (including descriptive properties like `title`). +/// +/// Clients should never make tool use decisions based on ToolAnnotations +/// received from untrusted servers. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +pub struct ToolAnnotations { + /// A human-readable title for the tool. + pub title: Option, + + /// If true, the tool does not modify its environment. + /// + /// Default: false + pub read_only_hint: Option, + + /// If true, the tool may perform destructive updates to its environment. + /// If false, the tool performs only additive updates. + /// + /// (This property is meaningful only when `readOnlyHint == false`) + /// + /// Default: true + /// A human-readable description of the tool's purpose. + pub destructive_hint: Option, + + /// If true, calling the tool repeatedly with the same arguments + /// will have no additional effect on the its environment. + /// + /// (This property is meaningful only when `readOnlyHint == false`) + /// + /// Default: false. + pub idempotent_hint: Option, + + /// If true, this tool may interact with an "open world" of external + /// entities. If false, the tool's domain of interaction is closed. + /// For example, the world of a web search tool is open, whereas that + /// of a memory tool is not. + /// + /// Default: true + pub open_world_hint: Option, +} + +impl ToolAnnotations { + pub fn new() -> Self { + Self::default() + } + pub fn with_title(title: T) -> Self + where + T: Into, + { + ToolAnnotations { + title: Some(title.into()), + ..Self::default() + } + } + pub fn read_only(self, read_only: bool) -> Self { + ToolAnnotations { + read_only_hint: Some(read_only), + ..self + } + } + pub fn destructive(self, destructive: bool) -> Self { + ToolAnnotations { + destructive_hint: Some(destructive), + ..self + } + } + pub fn idempotent(self, idempotent: bool) -> Self { + ToolAnnotations { + idempotent_hint: Some(idempotent), + ..self + } + } + pub fn open_world(self, open_world: bool) -> Self { + ToolAnnotations { + open_world_hint: Some(open_world), + ..self + } + } + + /// If not set, defaults to true. + pub fn is_destructive(&self) -> bool { + self.destructive_hint.unwrap_or(true) + } + + /// If not set, defaults to false. + pub fn is_idempotent(&self) -> bool { + self.idempotent_hint.unwrap_or(false) + } } impl Tool { @@ -32,6 +128,14 @@ impl Tool { name: name.into(), description: Some(description.into()), input_schema: input_schema.into(), + annotations: None, + } + } + + pub fn annotate(self, annotations: ToolAnnotations) -> Self { + Tool { + annotations: Some(annotations), + ..self } } diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index 9ddb37e7..b23c61fa 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -4,7 +4,10 @@ use thiserror::Error; use crate::{ error::Error as McpError, model::{ - CancelledNotification, CancelledNotificationParam, JsonRpcMessage, Message, RequestId, + CancelledNotification, CancelledNotificationParam, GetMeta, JsonRpcBatchRequestItem, + JsonRpcBatchResponseItem, JsonRpcError, JsonRpcMessage, JsonRpcNotification, + JsonRpcRequest, JsonRpcResponse, Meta, NumberOrString, ProgressToken, RequestId, + ServerJsonRpcMessage, }, transport::IntoTransport, }; @@ -56,12 +59,12 @@ impl TransferObject for T where #[allow(private_bounds, reason = "there's no the third implementation")] pub trait ServiceRole: std::fmt::Debug + Send + Sync + 'static + Copy + Clone { - type Req: TransferObject; + type Req: TransferObject + GetMeta; type Resp: TransferObject; type Not: TryInto + From + TransferObject; - type PeerReq: TransferObject; + type PeerReq: TransferObject + GetMeta; type PeerResp: TransferObject; type PeerNot: TryInto + From @@ -79,11 +82,6 @@ pub type RxJsonRpcMessage = JsonRpcMessage< ::PeerNot, >; -pub type TxMessage = - Message<::Req, ::Resp, ::Not>; -pub type RxMessage = - Message<::PeerReq, ::PeerResp, ::PeerNot>; - pub trait Service: Send + Sync + 'static { fn handle_request( &self, @@ -192,7 +190,7 @@ impl> DynService for S { } use std::{ - collections::HashMap, + collections::{HashMap, VecDeque}, ops::Deref, sync::{Arc, atomic::AtomicU32}, time::Duration, @@ -204,17 +202,32 @@ pub trait RequestIdProvider: Send + Sync + 'static { fn next_request_id(&self) -> RequestId; } +pub trait ProgressTokenProvider: Send + Sync + 'static { + fn next_progress_token(&self) -> ProgressToken; +} + +pub type AtomicU32RequestIdProvider = AtomicU32Provider; +pub type AtomicU32ProgressTokenProvider = AtomicU32Provider; + #[derive(Debug, Default)] -pub struct AtomicU32RequestIdProvider { +pub struct AtomicU32Provider { id: AtomicU32, } -impl RequestIdProvider for AtomicU32RequestIdProvider { +impl RequestIdProvider for AtomicU32Provider { fn next_request_id(&self) -> RequestId { RequestId::Number(self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst)) } } +impl ProgressTokenProvider for AtomicU32Provider { + fn next_progress_token(&self) -> ProgressToken { + ProgressToken(NumberOrString::Number( + self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst), + )) + } +} + type Responder = tokio::sync::oneshot::Sender; /// A handle to a remote request @@ -228,6 +241,7 @@ pub struct RequestHandle { pub options: PeerRequestOptions, pub peer: Peer, pub id: RequestId, + pub progress_token: ProgressToken, } impl RequestHandle { @@ -251,6 +265,7 @@ impl RequestHandle { reason: Some(Self::REQUEST_TIMEOUT_REASON.to_owned()), }, method: crate::model::CancelledNotificationMethod, + extensions: Default::default(), }; let _ = self.peer.send_notification(notification.into()).await; error @@ -271,6 +286,7 @@ impl RequestHandle { reason, }, method: crate::model::CancelledNotificationMethod, + extensions: Default::default(), }; self.peer.send_notification(notification.into()).await?; Ok(()) @@ -278,24 +294,28 @@ impl RequestHandle { } #[derive(Debug)] -pub enum PeerSinkMessage { - Request( - R::Req, - RequestId, - Responder>, - ), - Notification(R::Not, Responder>), +pub(crate) enum PeerSinkMessage { + Request { + request: R::Req, + id: RequestId, + responder: Responder>, + }, + Notification { + notification: R::Not, + responder: Responder>, + }, } /// An interface to fetch the remote client or server /// /// For general purpose, call [`Peer::send_request`] or [`Peer::send_notification`] to send message to remote peer. /// -/// To create a cancellable request, call [`Peer::send_cancellable_request`]. +/// To create a cancellable request, call [`Peer::send_request_with_option`]. #[derive(Clone)] pub struct Peer { tx: mpsc::Sender>, request_id_provider: Arc, + progress_token_provider: Arc, info: Arc, } @@ -312,7 +332,8 @@ type ProxyOutbound = mpsc::Receiver>; #[derive(Debug, Default)] pub struct PeerRequestOptions { - timeout: Option, + pub timeout: Option, + pub meta: Option, } impl PeerRequestOptions { @@ -323,7 +344,7 @@ impl PeerRequestOptions { impl Peer { const CLIENT_CHANNEL_BUFFER_SIZE: usize = 1024; - pub fn new( + pub(crate) fn new( request_id_provider: Arc, peer_info: R::PeerInfo, ) -> (Peer, ProxyOutbound) { @@ -332,6 +353,7 @@ impl Peer { Self { tx, request_id_provider, + progress_token_provider: Arc::new(AtomicU32ProgressTokenProvider::default()), info: peer_info.into(), }, rx, @@ -340,7 +362,10 @@ impl Peer { pub async fn send_notification(&self, notification: R::Not) -> Result<(), ServiceError> { let (responder, receiver) = tokio::sync::oneshot::channel(); self.tx - .send(PeerSinkMessage::Notification(notification, responder)) + .send(PeerSinkMessage::Notification { + notification, + responder, + }) .await .map_err(|_m| { ServiceError::Transport(std::io::Error::other("disconnected: receiver dropped")) @@ -350,25 +375,46 @@ impl Peer { })? } pub async fn send_request(&self, request: R::Req) -> Result { - self.send_cancellable_request(request, PeerRequestOptions::no_options()) + self.send_request_with_option(request, PeerRequestOptions::no_options()) .await? .await_response() .await } + pub async fn send_cancellable_request( &self, request: R::Req, options: PeerRequestOptions, + ) -> Result, ServiceError> { + self.send_request_with_option(request, options).await + } + + pub async fn send_request_with_option( + &self, + mut request: R::Req, + options: PeerRequestOptions, ) -> Result, ServiceError> { let id = self.request_id_provider.next_request_id(); + let progress_token = self.progress_token_provider.next_progress_token(); + request + .get_meta_mut() + .set_progress_token(progress_token.clone()); + if let Some(meta) = options.meta.clone() { + request.get_meta_mut().extend(meta); + } let (responder, receiver) = tokio::sync::oneshot::channel(); self.tx - .send(PeerSinkMessage::Request(request, id.clone(), responder)) + .send(PeerSinkMessage::Request { + request, + id: id.clone(), + responder, + }) .await .map_err(|_m| ServiceError::Transport(std::io::Error::other("disconnected")))?; Ok(RequestHandle { id, rx: receiver, + progress_token, options, peer: self.clone(), }) @@ -424,6 +470,7 @@ pub struct RequestContext { /// this token will be cancelled when the [`CancelledNotification`] is received. pub ct: CancellationToken, pub id: RequestId, + pub meta: Meta, /// An interface to fetch the remote client or server pub peer: Peer, } @@ -464,7 +511,7 @@ async fn serve_inner( mut service: S, transport: T, peer_info: R::PeerInfo, - id_provider: Arc, + id_provider: Arc, ct: CancellationToken, ) -> Result, E> where @@ -475,9 +522,8 @@ where { use futures::{SinkExt, StreamExt}; const SINK_PROXY_BUFFER_SIZE: usize = 64; - let (sink_proxy_tx, mut sink_proxy_rx) = tokio::sync::mpsc::channel::< - Message<::Req, ::Resp, ::Not>, - >(SINK_PROXY_BUFFER_SIZE); + let (sink_proxy_tx, mut sink_proxy_rx) = + tokio::sync::mpsc::channel::>(SINK_PROXY_BUFFER_SIZE); if R::IS_CLIENT { tracing::info!(?peer_info, "Service initialized as client"); @@ -501,6 +547,7 @@ where let (mut sink, mut stream) = transport.into_transport(); let mut sink = std::pin::pin!(sink); let mut stream = std::pin::pin!(stream); + let mut batch_messages = VecDeque::>::new(); #[derive(Debug)] enum Event { ProxyMessage(P), @@ -508,57 +555,66 @@ where ToSink(T), } let quit_reason = loop { - let evt = tokio::select! { - m = sink_proxy_rx.recv() => { - if let Some(m) = m { - Event::ToSink(m) - } else { - continue + let evt = if let Some(m) = batch_messages.pop_front() { + Event::PeerMessage(m) + } else { + tokio::select! { + m = sink_proxy_rx.recv() => { + if let Some(m) = m { + Event::ToSink(m) + } else { + continue + } } - } - m = stream.next() => { - if let Some(m) = m { - Event::PeerMessage(m.into_message()) - } else { - // input stream closed - tracing::info!("input stream terminated"); - break QuitReason::Closed + m = stream.next() => { + if let Some(m) = m { + Event::PeerMessage(m) + } else { + // input stream closed + tracing::info!("input stream terminated"); + break QuitReason::Closed + } } - } - m = peer_proxy.recv() => { - if let Some(m) = m { - Event::ProxyMessage(m) - } else { - continue + m = peer_proxy.recv() => { + if let Some(m) = m { + Event::ProxyMessage(m) + } else { + continue + } + } + _ = serve_loop_ct.cancelled() => { + tracing::info!("task cancelled"); + break QuitReason::Cancelled } - } - _ = serve_loop_ct.cancelled() => { - tracing::info!("task cancelled"); - break QuitReason::Cancelled } }; + tracing::debug!(?evt, "new event"); match evt { // response and error - Event::ToSink(e) => { - if let Some(id) = match &e { - Message::Response(_, id) => Some(id), - Message::Error(_, id) => Some(id), + Event::ToSink(m) => { + if let Some(id) = match &m { + JsonRpcMessage::Response(response) => Some(&response.id), + JsonRpcMessage::Error(error) => Some(&error.id), _ => None, } { if let Some(ct) = local_ct_pool.remove(id) { ct.cancel(); } - let send_result = sink.send(e.into_json_rpc_message()).await; + let send_result = sink.send(m).await; if let Err(error) = send_result { tracing::error!(%error, "fail to response message"); } } } - Event::ProxyMessage(PeerSinkMessage::Request(request, id, responder)) => { + Event::ProxyMessage(PeerSinkMessage::Request { + request, + id, + responder, + }) => { local_responder_pool.insert(id.clone(), responder); let send_result = sink - .send(Message::Request(request, id.clone()).into_json_rpc_message()) + .send(JsonRpcMessage::request(request, id.clone())) .await; if let Err(e) = send_result { if let Some(responder) = local_responder_pool.remove(&id) { @@ -567,7 +623,10 @@ where } } } - Event::ProxyMessage(PeerSinkMessage::Notification(notification, responder)) => { + Event::ProxyMessage(PeerSinkMessage::Notification { + notification, + responder, + }) => { // catch cancellation notification let mut cancellation_param = None; let notification = match notification.try_into() { @@ -577,9 +636,7 @@ where } Err(notification) => notification, }; - let send_result = sink - .send(Message::Notification(notification).into_json_rpc_message()) - .await; + let send_result = sink.send(JsonRpcMessage::notification(notification)).await; let response = if let Err(e) = send_result { Err(ServiceError::Transport(std::io::Error::other(e))) } else { @@ -595,7 +652,9 @@ where } } } - Event::PeerMessage(Message::Request(request, id)) => { + Event::PeerMessage(JsonRpcMessage::Request(JsonRpcRequest { + id, request, .. + })) => { tracing::info!(%id, ?request, "received request"); { let service = shared_service.clone(); @@ -607,24 +666,28 @@ where ct: context_ct, id: id.clone(), peer: peer.clone(), + meta: request.get_meta().clone(), }; tokio::spawn(async move { let result = service.handle_request(request, context).await; let response = match result { Ok(result) => { tracing::info!(%id, ?result, "response message"); - Message::Response(result, id) + JsonRpcMessage::response(result, id) } Err(error) => { tracing::warn!(%id, ?error, "response error"); - Message::Error(error, id) + JsonRpcMessage::error(error, id) } }; let _send_result = sink.send(response).await; }); } } - Event::PeerMessage(Message::Notification(notification)) => { + Event::PeerMessage(JsonRpcMessage::Notification(JsonRpcNotification { + notification, + .. + })) => { tracing::info!(?notification, "received notification"); // catch cancelled notification let notification = match notification.try_into() { @@ -647,7 +710,11 @@ where }); } } - Event::PeerMessage(Message::Response(result, id)) => { + Event::PeerMessage(JsonRpcMessage::Response(JsonRpcResponse { + result, + id, + .. + })) => { if let Some(responder) = local_responder_pool.remove(&id) { let response_result = responder.send(Ok(result)); if let Err(_error) = response_result { @@ -655,7 +722,7 @@ where } } } - Event::PeerMessage(Message::Error(error, id)) => { + Event::PeerMessage(JsonRpcMessage::Error(JsonRpcError { error, id, .. })) => { if let Some(responder) = local_responder_pool.remove(&id) { let _response_result = responder.send(Err(ServiceError::McpError(error))); if let Err(_error) = _response_result { @@ -663,6 +730,20 @@ where } } } + Event::PeerMessage(JsonRpcMessage::BatchRequest(batch)) => { + batch_messages.extend( + batch + .into_iter() + .map(JsonRpcBatchRequestItem::into_non_batch_message), + ); + } + Event::PeerMessage(JsonRpcMessage::BatchResponse(batch)) => { + batch_messages.extend( + batch + .into_iter() + .map(JsonRpcBatchResponseItem::into_non_batch_message), + ); + } } }; let sink_close_result = sink.close().await; diff --git a/crates/rmcp/src/service/client.rs b/crates/rmcp/src/service/client.rs index b2e30e0c..19ff38c6 100644 --- a/crates/rmcp/src/service/client.rs +++ b/crates/rmcp/src/service/client.rs @@ -4,16 +4,16 @@ use thiserror::Error; use super::*; use crate::model::{ CallToolRequest, CallToolRequestParam, CallToolResult, CancelledNotification, - CancelledNotificationParam, ClientInfo, ClientMessage, ClientNotification, ClientRequest, - ClientResult, CompleteRequest, CompleteRequestParam, CompleteResult, GetPromptRequest, - GetPromptRequestParam, GetPromptResult, InitializeRequest, InitializedNotification, - JsonRpcResponse, ListPromptsRequest, ListPromptsResult, ListResourceTemplatesRequest, - ListResourceTemplatesResult, ListResourcesRequest, ListResourcesResult, ListToolsRequest, - ListToolsResult, PaginatedRequestParam, PaginatedRequestParamInner, ProgressNotification, - ProgressNotificationParam, ReadResourceRequest, ReadResourceRequestParam, ReadResourceResult, - RequestId, RootsListChangedNotification, ServerInfo, ServerJsonRpcMessage, ServerNotification, - ServerRequest, ServerResult, SetLevelRequest, SetLevelRequestParam, SubscribeRequest, - SubscribeRequestParam, UnsubscribeRequest, UnsubscribeRequestParam, + CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage, ClientNotification, + ClientRequest, ClientResult, CompleteRequest, CompleteRequestParam, CompleteResult, + GetPromptRequest, GetPromptRequestParam, GetPromptResult, InitializeRequest, + InitializedNotification, JsonRpcResponse, ListPromptsRequest, ListPromptsResult, + ListResourceTemplatesRequest, ListResourceTemplatesResult, ListResourcesRequest, + ListResourcesResult, ListToolsRequest, ListToolsResult, PaginatedRequestParam, + ProgressNotification, ProgressNotificationParam, ReadResourceRequest, ReadResourceRequestParam, + ReadResourceResult, RequestId, RootsListChangedNotification, ServerInfo, ServerJsonRpcMessage, + ServerNotification, ServerRequest, ServerResult, SetLevelRequest, SetLevelRequestParam, + SubscribeRequest, SubscribeRequestParam, UnsubscribeRequest, UnsubscribeRequestParam, }; /// It represents the error that may occur when serving the client. @@ -141,11 +141,12 @@ where let init_request = InitializeRequest { method: Default::default(), params: service.get_info(), + extensions: Default::default(), }; - sink.send( - ClientMessage::Request(ClientRequest::InitializeRequest(init_request), id.clone()) - .into_json_rpc_message(), - ) + sink.send(ClientJsonRpcMessage::request( + ClientRequest::InitializeRequest(init_request), + id.clone(), + )) .await?; let (response, response_id) = expect_response(&mut stream, "initialize response") @@ -166,12 +167,13 @@ where }; // send notification - let notification = ClientMessage::Notification(ClientNotification::InitializedNotification( - InitializedNotification { + let notification = ClientJsonRpcMessage::notification( + ClientNotification::InitializedNotification(InitializedNotification { method: Default::default(), - }, - )); - sink.send(notification.into_json_rpc_message()).await?; + extensions: Default::default(), + }), + ); + sink.send(notification).await?; serve_inner(service, (sink, stream), initialize_result, id_provider, ct).await } @@ -195,6 +197,22 @@ macro_rules! method { .send_request(ClientRequest::$Req($Req { method: Default::default(), params, + extensions: Default::default(), + })) + .await?; + match result { + ServerResult::$Resp(result) => Ok(result), + _ => Err(ServiceError::UnexpectedResponse), + } + } + }; + (peer_req $method:ident $Req:ident($Param: ident)? => $Resp: ident ) => { + pub async fn $method(&self, params: Option<$Param>) -> Result<$Resp, ServiceError> { + let result = self + .send_request(ClientRequest::$Req($Req { + method: Default::default(), + params, + extensions: Default::default(), })) .await?; match result { @@ -209,6 +227,7 @@ macro_rules! method { .send_request(ClientRequest::$Req($Req { method: Default::default(), params, + extensions: Default::default(), })) .await?; match result { @@ -223,6 +242,7 @@ macro_rules! method { self.send_notification(ClientNotification::$Not($Not { method: Default::default(), params, + extensions: Default::default(), })) .await?; Ok(()) @@ -232,6 +252,7 @@ macro_rules! method { pub async fn $method(&self) -> Result<(), ServiceError> { self.send_notification(ClientNotification::$Not($Not { method: Default::default(), + extensions: Default::default(), })) .await?; Ok(()) @@ -243,14 +264,14 @@ impl Peer { method!(peer_req complete CompleteRequest(CompleteRequestParam) => CompleteResult); method!(peer_req set_level SetLevelRequest(SetLevelRequestParam)); method!(peer_req get_prompt GetPromptRequest(GetPromptRequestParam) => GetPromptResult); - method!(peer_req list_prompts ListPromptsRequest(PaginatedRequestParam) => ListPromptsResult); - method!(peer_req list_resources ListResourcesRequest(PaginatedRequestParam) => ListResourcesResult); - method!(peer_req list_resource_templates ListResourceTemplatesRequest(PaginatedRequestParam) => ListResourceTemplatesResult); + method!(peer_req list_prompts ListPromptsRequest(PaginatedRequestParam)? => ListPromptsResult); + method!(peer_req list_resources ListResourcesRequest(PaginatedRequestParam)? => ListResourcesResult); + method!(peer_req list_resource_templates ListResourceTemplatesRequest(PaginatedRequestParam)? => ListResourceTemplatesResult); method!(peer_req read_resource ReadResourceRequest(ReadResourceRequestParam) => ReadResourceResult); method!(peer_req subscribe SubscribeRequest(SubscribeRequestParam) ); method!(peer_req unsubscribe UnsubscribeRequest(UnsubscribeRequestParam)); method!(peer_req call_tool CallToolRequest(CallToolRequestParam) => CallToolResult); - method!(peer_req list_tools ListToolsRequest(PaginatedRequestParam) => ListToolsResult); + method!(peer_req list_tools ListToolsRequest(PaginatedRequestParam)? => ListToolsResult); method!(peer_not notify_cancelled CancelledNotification(CancelledNotificationParam)); method!(peer_not notify_progress ProgressNotification(ProgressNotificationParam)); @@ -267,7 +288,7 @@ impl Peer { let mut cursor = None; loop { let result = self - .list_tools(Some(PaginatedRequestParamInner { cursor })) + .list_tools(Some(PaginatedRequestParam { cursor })) .await?; tools.extend(result.tools); cursor = result.next_cursor; @@ -286,7 +307,7 @@ impl Peer { let mut cursor = None; loop { let result = self - .list_prompts(Some(PaginatedRequestParamInner { cursor })) + .list_prompts(Some(PaginatedRequestParam { cursor })) .await?; prompts.extend(result.prompts); cursor = result.next_cursor; @@ -305,7 +326,7 @@ impl Peer { let mut cursor = None; loop { let result = self - .list_resources(Some(PaginatedRequestParamInner { cursor })) + .list_resources(Some(PaginatedRequestParam { cursor })) .await?; resources.extend(result.resources); cursor = result.next_cursor; @@ -326,7 +347,7 @@ impl Peer { let mut cursor = None; loop { let result = self - .list_resource_templates(Some(PaginatedRequestParamInner { cursor })) + .list_resource_templates(Some(PaginatedRequestParam { cursor })) .await?; resource_templates.extend(result.resource_templates); cursor = result.next_cursor; diff --git a/crates/rmcp/src/service/server.rs b/crates/rmcp/src/service/server.rs index 9270488b..120e2741 100644 --- a/crates/rmcp/src/service/server.rs +++ b/crates/rmcp/src/service/server.rs @@ -4,12 +4,12 @@ use thiserror::Error; use super::*; use crate::model::{ CancelledNotification, CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage, - ClientMessage, ClientNotification, ClientRequest, ClientResult, CreateMessageRequest, + ClientNotification, ClientRequest, ClientResult, CreateMessageRequest, CreateMessageRequestParam, CreateMessageResult, ListRootsRequest, ListRootsResult, LoggingMessageNotification, LoggingMessageNotificationParam, ProgressNotification, ProgressNotificationParam, PromptListChangedNotification, ResourceListChangedNotification, - ResourceUpdatedNotification, ResourceUpdatedNotificationParam, ServerInfo, ServerMessage, - ServerNotification, ServerRequest, ServerResult, ToolListChangedNotification, + ResourceUpdatedNotification, ResourceUpdatedNotificationParam, ServerInfo, ServerNotification, + ServerRequest, ServerResult, ToolListChangedNotification, }; #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] @@ -33,10 +33,10 @@ impl ServiceRole for RoleServer { #[derive(Error, Debug)] pub enum ServerError { #[error("expect initialized request, but received: {0:?}")] - ExpectedInitRequest(Option), + ExpectedInitRequest(Option), #[error("expect initialized notification, but received: {0:?}")] - ExpectedInitNotification(Option), + ExpectedInitNotification(Option), #[error("connection closed: {0}")] ConnectionClosed(String), @@ -75,15 +75,17 @@ where } /// Helper function to get the next message from the stream -async fn expect_next_message(stream: &mut S, context: &str) -> Result +async fn expect_next_message( + stream: &mut S, + context: &str, +) -> Result where S: StreamExt + Unpin, { - Ok(stream + stream .next() .await - .ok_or_else(|| ServerError::ConnectionClosed(context.to_string()))? - .into_message()) + .ok_or_else(|| ServerError::ConnectionClosed(context.to_string())) } /// Helper function to expect a request from the stream @@ -144,16 +146,28 @@ where let ClientRequest::InitializeRequest(peer_info) = request else { return Err(handle_server_error(ServerError::ExpectedInitRequest(Some( - ClientMessage::Request(request, id), + ClientJsonRpcMessage::request(request, id), )))); }; // Send initialize response - let init_response = service.get_info(); - sink.send( - ServerMessage::Response(ServerResult::InitializeResult(init_response), id) - .into_json_rpc_message(), - ) + let mut init_response = service.get_info(); + let protocol_version = match peer_info + .params + .protocol_version + .partial_cmp(&init_response.protocol_version) + .ok_or(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "unsupported protocol version", + ))? { + std::cmp::Ordering::Less => peer_info.params.protocol_version.clone(), + _ => init_response.protocol_version, + }; + init_response.protocol_version = protocol_version; + sink.send(ServerJsonRpcMessage::response( + ServerResult::InitializeResult(init_response), + id, + )) .await?; // Wait for initialize notification @@ -163,7 +177,7 @@ where let ClientNotification::InitializedNotification(_) = notification else { return Err(handle_server_error(ServerError::ExpectedInitNotification( - Some(ClientMessage::Notification(notification)), + Some(ClientJsonRpcMessage::notification(notification)), ))); }; @@ -177,6 +191,7 @@ macro_rules! method { let result = self .send_request(ServerRequest::$Req($Req { method: Default::default(), + extensions: Default::default(), })) .await?; match result { @@ -191,6 +206,7 @@ macro_rules! method { .send_request(ServerRequest::$Req($Req { method: Default::default(), params, + extensions: Default::default(), })) .await?; match result { @@ -224,6 +240,7 @@ macro_rules! method { self.send_notification(ServerNotification::$Not($Not { method: Default::default(), params, + extensions: Default::default(), })) .await?; Ok(()) @@ -233,6 +250,7 @@ macro_rules! method { pub async fn $method(&self) -> Result<(), ServiceError> { self.send_notification(ServerNotification::$Not($Not { method: Default::default(), + extensions: Default::default(), })) .await?; Ok(()) diff --git a/examples/servers/src/common/counter.rs b/examples/servers/src/common/counter.rs index 81151205..4a29204a 100644 --- a/examples/servers/src/common/counter.rs +++ b/examples/servers/src/common/counter.rs @@ -98,7 +98,7 @@ impl ServerHandler for Counter { async fn list_resources( &self, - _request: PaginatedRequestParam, + _request: Option, _: RequestContext, ) -> Result { Ok(ListResourcesResult { @@ -139,7 +139,7 @@ impl ServerHandler for Counter { async fn list_prompts( &self, - _request: PaginatedRequestParam, + _request: Option, _: RequestContext, ) -> Result { Ok(ListPromptsResult { @@ -185,7 +185,7 @@ impl ServerHandler for Counter { async fn list_resource_templates( &self, - _request: PaginatedRequestParam, + _request: Option, _: RequestContext, ) -> Result { Ok(ListResourceTemplatesResult {