From e3a724ebcb7826ddcdbf1876cb4f7d9dac16ff16 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Mon, 16 Jun 2025 19:21:20 +0800 Subject: [PATCH 01/12] refactor: refactor tool macros and router implementation - Updated the `#[tool(tool_box)]` macro to `#[tool_router]` across various modules for consistency. - Enhanced the `Calculator`, `Counter`, and `GenericService` structs to utilize `ToolRouter` for handling tool calls. - Introduced `Parameters` struct for better parameter handling in tool functions. - Added new methods for listing tools and calling tools in server handlers. - Improved test cases to reflect changes in tool routing and parameter handling. - Updated documentation and examples to align with the new router structure. --- crates/rmcp-macros/Cargo.toml | 2 +- crates/rmcp-macros/README.md | 4 +- crates/rmcp-macros/src/lib.rs | 16 +- crates/rmcp-macros/src/tool.rs | 893 ++++-------------- crates/rmcp-macros/src/tool_inherite.rs | 785 +++++++++++++++ crates/rmcp-macros/src/tool_router.rs | 90 ++ crates/rmcp/Cargo.toml | 2 +- crates/rmcp/README.md | 24 +- crates/rmcp/src/handler/server.rs | 14 + crates/rmcp/src/handler/server/router.rs | 97 ++ .../rmcp/src/handler/server/router/promt.rs | 0 crates/rmcp/src/handler/server/router/tool.rs | 364 +++++++ crates/rmcp/src/handler/server/tool.rs | 405 +++----- crates/rmcp/src/lib.rs | 6 +- crates/rmcp/src/model.rs | 11 + crates/rmcp/tests/common/calculator.rs | 48 +- crates/rmcp/tests/test_complex_schema.rs | 10 +- .../rmcp/tests/test_tool_macro_annotations.rs | 18 +- crates/rmcp/tests/test_tool_macros.rs | 125 ++- crates/rmcp/tests/test_tool_routers.rs | 74 ++ crates/rmcp/tests/test_with_js.rs | 2 +- docs/readme/README.zh-cn.md | 4 +- examples/servers/src/common/calculator.rs | 55 +- examples/servers/src/common/counter.rs | 21 +- .../servers/src/common/generic_service.rs | 37 +- 25 files changed, 2016 insertions(+), 1091 deletions(-) create mode 100644 crates/rmcp-macros/src/tool_inherite.rs create mode 100644 crates/rmcp-macros/src/tool_router.rs create mode 100644 crates/rmcp/src/handler/server/router.rs create mode 100644 crates/rmcp/src/handler/server/router/promt.rs create mode 100644 crates/rmcp/src/handler/server/router/tool.rs create mode 100644 crates/rmcp/tests/test_tool_routers.rs diff --git a/crates/rmcp-macros/Cargo.toml b/crates/rmcp-macros/Cargo.toml index 9afa8b69..c7f85684 100644 --- a/crates/rmcp-macros/Cargo.toml +++ b/crates/rmcp-macros/Cargo.toml @@ -19,7 +19,7 @@ syn = {version = "2", features = ["full"]} quote = "1" proc-macro2 = "1" serde_json = "1.0" - +darling = { version = "0.20" } [features] [dev-dependencies] \ No newline at end of file diff --git a/crates/rmcp-macros/README.md b/crates/rmcp-macros/README.md index 223e164f..873b89d4 100644 --- a/crates/rmcp-macros/README.md +++ b/crates/rmcp-macros/README.md @@ -26,7 +26,7 @@ fn calculator(&self, #[tool(param)] a: i32, #[tool(param)] b: i32) -> Result Result { @@ -34,7 +34,7 @@ impl MyHandler { } #[tool] - fn tool2(&self) -> Result { + fn tool(&self) -> Result { // Tool 2 implementation } } diff --git a/crates/rmcp-macros/src/lib.rs b/crates/rmcp-macros/src/lib.rs index ffe44ec5..b2d6a7b0 100644 --- a/crates/rmcp-macros/src/lib.rs +++ b/crates/rmcp-macros/src/lib.rs @@ -1,11 +1,25 @@ #[allow(unused_imports)] use proc_macro::TokenStream; +// mod tool_inherite; mod tool; - +mod tool_router; +// #[proc_macro_attribute] +// pub fn tool(attr: TokenStream, input: TokenStream) -> TokenStream { +// tool_inherite::tool(attr.into(), input.into()) +// .unwrap_or_else(|err| err.to_compile_error()) +// .into() +// } #[proc_macro_attribute] pub fn tool(attr: TokenStream, input: TokenStream) -> TokenStream { tool::tool(attr.into(), input.into()) .unwrap_or_else(|err| err.to_compile_error()) .into() } + +#[proc_macro_attribute] +pub fn tool_router(attr: TokenStream, input: TokenStream) -> TokenStream { + tool_router::tool_router(attr.into(), input.into()) + .unwrap_or_else(|err| err.to_compile_error()) + .into() +} \ No newline at end of file diff --git a/crates/rmcp-macros/src/tool.rs b/crates/rmcp-macros/src/tool.rs index 956ab320..d32f8189 100644 --- a/crates/rmcp-macros/src/tool.rs +++ b/crates/rmcp-macros/src/tool.rs @@ -1,344 +1,93 @@ -use std::collections::HashSet; - +use darling::FromMeta; +use darling::ast::NestedMeta; use proc_macro2::TokenStream; -use quote::{ToTokens, quote}; -use serde_json::json; -use syn::{ - Expr, FnArg, Ident, ItemFn, ItemImpl, Lit, MetaList, PatType, Token, Type, Visibility, - parse::Parse, parse_quote, spanned::Spanned, -}; - -/// Stores tool annotation attributes -#[derive(Default, Clone)] -struct ToolAnnotationAttrs(pub serde_json::Map); - -impl Parse for ToolAnnotationAttrs { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let mut attrs = serde_json::Map::new(); - - while !input.is_empty() { - let key: Ident = input.parse()?; - input.parse::()?; - let value: Lit = input.parse()?; - let value = match value { - Lit::Str(s) => json!(s.value()), - Lit::Bool(b) => json!(b.value), - _ => { - return Err(syn::Error::new( - key.span(), - "annotations must be string or boolean literals", - )); - } - }; - attrs.insert(key.to_string(), value); - if input.is_empty() { - break; - } - input.parse::()?; - } - - Ok(ToolAnnotationAttrs(attrs)) - } -} - -#[derive(Default)] -struct ToolImplItemAttrs { - tool_box: Option>, -} - -impl Parse for ToolImplItemAttrs { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let mut tool_box = None; - while !input.is_empty() { - let key: Ident = input.parse()?; - match key.to_string().as_str() { - "tool_box" => { - tool_box = Some(None); - if input.lookahead1().peek(Token![=]) { - input.parse::()?; - let value: Ident = input.parse()?; - tool_box = Some(Some(value)); - } - } - _ => { - return Err(syn::Error::new(key.span(), "unknown attribute")); - } - } - if input.is_empty() { - break; - } - input.parse::()?; - } - - Ok(ToolImplItemAttrs { tool_box }) - } -} - -#[derive(Default)] -struct ToolFnItemAttrs { - name: Option, - description: Option, - vis: Option, - annotations: Option, -} - -impl Parse for ToolFnItemAttrs { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let mut name = None; - let mut description = None; - let mut vis = None; - let mut annotations = None; - - while !input.is_empty() { - let key: Ident = input.parse()?; - input.parse::()?; - match key.to_string().as_str() { - "name" => { - let value: Expr = input.parse()?; - name = Some(value); - } - "description" => { - let value: Expr = input.parse()?; - description = Some(value); - } - "vis" => { - let value: Visibility = input.parse()?; - vis = Some(value); - } - "annotations" => { - // Parse the annotations as a nested structure - let content; - syn::braced!(content in input); - let value = content.parse()?; - annotations = Some(value); - } - _ => { - return Err(syn::Error::new(key.span(), "unknown attribute")); - } - } - if input.is_empty() { - break; - } - input.parse::()?; - } - - Ok(ToolFnItemAttrs { +use quote::{ToTokens, format_ident, quote}; +use syn::{Expr, Ident, ImplItemFn, ReturnType}; +#[derive(FromMeta, Default, Debug)] +#[darling(default)] +pub struct ToolAttribute { + /// The name of the tool + pub name: Option, + pub description: Option, + /// A JSON Schema object defining the expected parameters for the tool + pub input_schema: Option, + /// Optional additional tool information. + pub annotations: Option, +} + +pub struct ResolvedToolAttribute { + pub name: String, + pub description: Option, + pub input_schema: Expr, + pub annotations: Expr, +} + +impl ResolvedToolAttribute { + pub fn into_fn(self, fn_ident: Ident) -> syn::Result { + let Self { name, description, - vis, + input_schema, annotations, - }) - } -} - -struct ToolFnParamAttrs { - serde_meta: Vec, - schemars_meta: Vec, - ident: Ident, - rust_type: Box, -} - -impl ToTokens for ToolFnParamAttrs { - fn to_tokens(&self, tokens: &mut TokenStream) { - let ident = &self.ident; - let rust_type = &self.rust_type; - let serde_meta = &self.serde_meta; - let schemars_meta = &self.schemars_meta; - tokens.extend(quote! { - #(#[#serde_meta])* - #(#[#schemars_meta])* - pub #ident: #rust_type, - }); - } -} - -#[derive(Default)] - -enum ToolParams { - Aggregated { - rust_type: PatType, - }, - Params { - attrs: Vec, - }, - #[default] - NoParam, -} - -#[derive(Default)] -struct ToolAttrs { - fn_item: ToolFnItemAttrs, - params: ToolParams, -} -const TOOL_IDENT: &str = "tool"; -const SERDE_IDENT: &str = "serde"; -const SCHEMARS_IDENT: &str = "schemars"; -const PARAM_IDENT: &str = "param"; -const AGGREGATED_IDENT: &str = "aggr"; -const REQ_IDENT: &str = "req"; - -pub enum ParamMarker { - Param, - Aggregated, -} - -impl Parse for ParamMarker { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let ident: Ident = input.parse()?; - match ident.to_string().as_str() { - PARAM_IDENT => Ok(ParamMarker::Param), - AGGREGATED_IDENT | REQ_IDENT => Ok(ParamMarker::Aggregated), - _ => Err(syn::Error::new(ident.span(), "unknown attribute")), - } - } -} - -pub enum ToolItem { - Fn(ItemFn), - Impl(ItemImpl), -} - -impl Parse for ToolItem { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let lookahead = input.lookahead1(); - if lookahead.peek(Token![impl]) { - let item = input.parse::()?; - Ok(ToolItem::Impl(item)) + } = self; + let description = if let Some(description) = description { + quote! { Some(#description.into()) } } else { - let item = input.parse::()?; - Ok(ToolItem::Fn(item)) - } - } -} - -// dispatch impl function item and impl block item -pub(crate) fn tool(attr: TokenStream, input: TokenStream) -> syn::Result { - let tool_item = syn::parse2::(input)?; - match tool_item { - ToolItem::Fn(item) => tool_fn_item(attr, item), - ToolItem::Impl(item) => tool_impl_item(attr, item), - } -} - -pub(crate) fn tool_impl_item(attr: TokenStream, mut input: ItemImpl) -> syn::Result { - let tool_impl_attr: ToolImplItemAttrs = syn::parse2(attr)?; - let tool_box_ident = tool_impl_attr.tool_box; - - // get all tool function ident - let mut tool_fn_idents = Vec::new(); - for item in &input.items { - if let syn::ImplItem::Fn(method) = item { - for attr in &method.attrs { - if attr.path().is_ident(TOOL_IDENT) { - tool_fn_idents.push(method.sig.ident.clone()); + quote! { None } + }; + let tokens = quote! { + pub fn #fn_ident() -> rmcp::model::Tool { + rmcp::model::Tool { + name: #name.into(), + description: #description, + input_schema: #input_schema, + annotations: #annotations, } } - } + }; + syn::parse2::(tokens) } +} - // handle different cases - if input.trait_.is_some() { - if let Some(ident) = tool_box_ident { - // check if there are generic parameters - if !input.generics.params.is_empty() { - // for trait implementation with generic parameters, directly use the already generated *_inner method - - // generate call_tool method - input.items.push(parse_quote! { - async fn call_tool( - &self, - request: rmcp::model::CallToolRequestParam, - context: rmcp::service::RequestContext, - ) -> Result { - self.call_tool_inner(request, context).await - } - }); +#[derive(FromMeta, Debug, Default)] +#[darling(default)] +pub struct ToolAnnotationsAttribute { + /// A human-readable title for the tool. + pub title: Option, - // generate list_tools method - input.items.push(parse_quote! { - async fn list_tools( - &self, - request: Option, - context: rmcp::service::RequestContext, - ) -> Result { - self.list_tools_inner(request, context).await - } - }); - } else { - // if there are no generic parameters, add tool box derive - input.items.push(parse_quote!( - rmcp::tool_box!(@derive #ident); - )); - } - } else { - return Err(syn::Error::new( - proc_macro2::Span::call_site(), - "tool_box attribute is required for trait implementation", - )); - } - } else if let Some(ident) = tool_box_ident { - // if it is a normal impl block - if !input.generics.params.is_empty() { - // if there are generic parameters, not use tool_box! macro, but generate code directly + /// If true, the tool does not modify its environment. + /// + /// Default: false + pub read_only_hint: Option, - // create call code for each tool function - let match_arms = tool_fn_idents.iter().map(|ident| { - let attr_fn = Ident::new(&format!("{}_tool_attr", ident), ident.span()); - let call_fn = Ident::new(&format!("{}_tool_call", ident), ident.span()); - quote! { - name if name == Self::#attr_fn().name => { - Self::#call_fn(tcc).await - } - } - }); + /// If true, the tool may perform destructive updates to its environment. + /// If false, the tool performs only additive updates. + /// + /// (This property is meaningful only when `readOnlyHint == false`) + /// + /// Default: true + /// A human-readable description of the tool's purpose. + pub destructive_hint: Option, - let tool_attrs = tool_fn_idents.iter().map(|ident| { - let attr_fn = Ident::new(&format!("{}_tool_attr", ident), ident.span()); - quote! { Self::#attr_fn() } - }); + /// If true, calling the tool repeatedly with the same arguments + /// will have no additional effect on the its environment. + /// + /// (This property is meaningful only when `readOnlyHint == false`) + /// + /// Default: false. + pub idempotent_hint: Option, - // implement call_tool method - input.items.push(parse_quote! { - async fn call_tool_inner( - &self, - request: rmcp::model::CallToolRequestParam, - context: rmcp::service::RequestContext, - ) -> Result { - let tcc = rmcp::handler::server::tool::ToolCallContext::new(self, request, context); - match tcc.name() { - #(#match_arms,)* - _ => Err(rmcp::Error::invalid_params("tool not found", None)), - } - } - }); - - // implement list_tools method - input.items.push(parse_quote! { - async fn list_tools_inner( - &self, - _: Option, - _: rmcp::service::RequestContext, - ) -> Result { - Ok(rmcp::model::ListToolsResult { - next_cursor: None, - tools: vec![#(#tool_attrs),*], - }) - } - }); - } else { - // if there are no generic parameters, use the original tool_box! macro - let this_type_ident = &input.self_ty; - input.items.push(parse_quote!( - rmcp::tool_box!(#this_type_ident { - #(#tool_fn_idents),* - } #ident); - )); - } - } + /// If true, this tool may interact with an "open world" of external + /// entities. If false, the tool's domain of interaction is closed. + /// For example, the world of a web search tool is open, whereas that + /// of a memory tool is not. + /// + /// Default: true + pub open_world_hint: Option, +} - Ok(quote! { - #input - }) +fn none_expr() -> Expr { + syn::parse2::(quote! { None }).unwrap() } // extract doc line from attribute @@ -364,422 +113,142 @@ fn extract_doc_line(attr: &syn::Attribute) -> Option { (!content.is_empty()).then_some(content) } -pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Result { - let mut tool_macro_attrs = ToolAttrs::default(); - let args: ToolFnItemAttrs = syn::parse2(attr)?; - tool_macro_attrs.fn_item = args; - // let mut fommated_fn_args: Punctuated = Punctuated::new(); - let mut unextractable_args_indexes = HashSet::new(); - for (index, mut fn_arg) in input_fn.sig.inputs.iter_mut().enumerate() { - enum Caught { - Param(ToolFnParamAttrs), - Aggregated(PatType), - } - let mut caught = None; - match &mut fn_arg { - FnArg::Receiver(_) => { - continue; - } - FnArg::Typed(pat_type) => { - let mut serde_metas = Vec::new(); - let mut schemars_metas = Vec::new(); - let mut arg_ident = match pat_type.pat.as_ref() { - syn::Pat::Ident(pat_ident) => Some(pat_ident.ident.clone()), - _ => None, - }; - let raw_attrs: Vec<_> = pat_type.attrs.drain(..).collect(); - for attr in raw_attrs { - match &attr.meta { - syn::Meta::List(meta_list) => { - if meta_list.path.is_ident(TOOL_IDENT) { - let pat_type = pat_type.clone(); - let marker = meta_list.parse_args::()?; - match marker { - ParamMarker::Param => { - let Some(arg_ident) = arg_ident.take() else { - return Err(syn::Error::new( - proc_macro2::Span::call_site(), - "input param must have an ident as name", - )); - }; - caught.replace(Caught::Param(ToolFnParamAttrs { - serde_meta: Vec::new(), - schemars_meta: Vec::new(), - ident: arg_ident, - rust_type: pat_type.ty.clone(), - })); - } - ParamMarker::Aggregated => { - caught.replace(Caught::Aggregated(pat_type.clone())); - } - } - } else if meta_list.path.is_ident(SERDE_IDENT) { - serde_metas.push(meta_list.clone()); - } else if meta_list.path.is_ident(SCHEMARS_IDENT) { - schemars_metas.push(meta_list.clone()); - } else { - pat_type.attrs.push(attr); - } - } - _ => { - pat_type.attrs.push(attr); - } - } - } - match caught { - Some(Caught::Param(mut param)) => { - param.serde_meta = serde_metas; - param.schemars_meta = schemars_metas; - match &mut tool_macro_attrs.params { - ToolParams::Params { attrs } => { - attrs.push(param); - } - _ => { - tool_macro_attrs.params = ToolParams::Params { attrs: vec![param] }; - } - } - unextractable_args_indexes.insert(index); - } - Some(Caught::Aggregated(rust_type)) => { - if let ToolParams::Params { .. } = tool_macro_attrs.params { - return Err(syn::Error::new( - rust_type.span(), - "cannot mix aggregated and individual parameters", - )); - } - tool_macro_attrs.params = ToolParams::Aggregated { rust_type }; - unextractable_args_indexes.insert(index); - } - None => {} - } - } - } - } - - // input_fn.sig.inputs = fommated_fn_args; - let name = if let Some(expr) = tool_macro_attrs.fn_item.name { - expr +pub fn tool(attr: TokenStream, input: TokenStream) -> syn::Result { + let attribute = if attr.is_empty() { + Default::default() } else { - let fn_name = &input_fn.sig.ident; - parse_quote! { - stringify!(#fn_name) - } + let attr_args = NestedMeta::parse_meta_list(attr)?; + ToolAttribute::from_list(&attr_args)? }; - let tool_attr_fn_ident = Ident::new( - &format!("{}_tool_attr", input_fn.sig.ident), - proc_macro2::Span::call_site(), - ); - - // generate get tool attr function - let tool_attr_fn = { - let description = if let Some(expr) = tool_macro_attrs.fn_item.description { - // Use explicitly provided description if available - expr - } else { - // Try to extract documentation comments - let doc_content = input_fn - .attrs - .iter() - .filter_map(extract_doc_line) - .collect::>() - .join("\n"); + let mut fn_item = syn::parse2::(input.clone())?; + let fn_ident = &fn_item.sig.ident; - parse_quote! { - #doc_content.trim().to_string() - } - }; - let schema = match &tool_macro_attrs.params { - ToolParams::Aggregated { rust_type } => { - let ty = &rust_type.ty; - let schema = quote! { - rmcp::handler::server::tool::cached_schema_for_type::<#ty>() - }; - schema - } - ToolParams::Params { attrs, .. } => { - let (param_type, temp_param_type_name) = - create_request_type(attrs, input_fn.sig.ident.to_string()); - let schema = quote! { + let tool_attr_fn_ident = format_ident!("{}_tool_attr", fn_ident); + let input_schema_expr = if let Some(input_schema) = attribute.input_schema { + input_schema + } else { + // try to find some parameters wrapper in the function + let params_ty = fn_item.sig.inputs.iter().find_map(|input| { + if let syn::FnArg::Typed(pat_type) = input { + if let syn::Type::Path(type_path) = &*pat_type.ty { + if type_path + .path + .segments + .last() + .is_some_and(|type_name| type_name.ident == "Parameters") { - #param_type - rmcp::handler::server::tool::cached_schema_for_type::<#temp_param_type_name>() + return Some(pat_type.ty.clone()); } - }; - schema - } - ToolParams::NoParam => { - quote! { - rmcp::handler::server::tool::cached_schema_for_type::() } } - }; - let input_fn_attrs = &input_fn.attrs; - let input_fn_vis = &input_fn.vis; - - let annotations_code = if let Some(annotations) = &tool_macro_attrs.fn_item.annotations { - let annotations = - serde_json::to_string(&annotations.0).expect("failed to serialize annotations"); - quote! { - Some(serde_json::from_str::(&#annotations).expect("Could not parse tool annotations")) - } + None + }); + if let Some(params_ty) = params_ty { + // if found, use the Parameters schema + syn::parse2::(quote! { + rmcp::handler::server::tool::cached_schema_for_type::<#params_ty>() + })? } else { - quote! { None } - }; - - quote! { - #(#input_fn_attrs)* - #input_fn_vis fn #tool_attr_fn_ident() -> rmcp::model::Tool { - rmcp::model::Tool { - name: #name.into(), - description: Some(#description.into()), - input_schema: #schema.into(), - annotations: #annotations_code, - } - } + // if not found, use the default EmptyObject schema + syn::parse2::(quote! { + rmcp::handler::server::tool::cached_schema_for_type::() + })? } }; - - // generate wrapped tool function - let tool_call_fn = { - // wrapper function have the same sig: - // async fn #tool_tool_call(context: rmcp::handler::server::tool::ToolCallContext<'_, Self>) - // -> std::result::Result - // - // and the block part should be like: - // { - // use rmcp::handler::server::tool::*; - // let (t0, context) = ::from_tool_call_context_part(context)?; - // let (t1, context) = ::from_tool_call_context_part(context)?; - // ... - // let (tn, context) = ::from_tool_call_context_part(context)?; - // // for params - // ... expand helper types here - // let (__rmcp_tool_req, context) = rmcp::model::JsonObject::from_tool_call_context_part(context)?; - // let __#TOOL_ToolCallParam { param_0, param_1, param_2, .. } = parse_json_object(__rmcp_tool_req)?; - // // for aggr - // let (Parameters(aggr), context) = >::from_tool_call_context_part(context)?; - // Self::#tool_ident(to, param_0, t1, param_1, ..., param_2, tn, aggr).await.into_call_tool_result() - // - // } - // - // - // - - // for receiver type, name it as __rmcp_tool_receiver - let is_async = input_fn.sig.asyncness.is_some(); - let receiver_ident = || Ident::new("__rmcp_tool_receiver", proc_macro2::Span::call_site()); - // generate the extraction part for trivial args - let trivial_args = input_fn - .sig - .inputs - .iter() - .enumerate() - .filter_map(|(index, arg)| { - if unextractable_args_indexes.contains(&index) { - None - } else { - // get ident/type pair - let line = match arg { - FnArg::Typed(pat_type) => { - let pat = &pat_type.pat; - let ty = &pat_type.ty; - quote! { - let (#pat, context) = <#ty>::from_tool_call_context_part(context)?; - } - } - FnArg::Receiver(r) => { - let ty = r.ty.clone(); - let pat = receiver_ident(); - quote! { - let (#pat, context) = <#ty>::from_tool_call_context_part(context)?; - } - } - }; - Some(line) - } - }); - let trivial_arg_extraction_part = quote! { - #(#trivial_args)* - }; - let processed_arg_extraction_part = match &mut tool_macro_attrs.params { - ToolParams::Aggregated { rust_type } => { - let PatType { pat, ty, .. } = rust_type; - quote! { - let (Parameters(#pat), context) = >::from_tool_call_context_part(context)?; - } - } - ToolParams::Params { attrs } => { - let (param_type, temp_param_type_name) = - create_request_type(attrs, input_fn.sig.ident.to_string()); - - let params_ident = attrs.iter().map(|attr| &attr.ident).collect::>(); - quote! { - #param_type - let (__rmcp_tool_req, context) = rmcp::model::JsonObject::from_tool_call_context_part(context)?; - let #temp_param_type_name { - #(#params_ident,)* - } = parse_json_object(__rmcp_tool_req)?; - } - } - ToolParams::NoParam => { - quote! {} - } - }; - // generate the execution part - // has receiver? - let params = &input_fn - .sig - .inputs - .iter() - .map(|fn_arg| match fn_arg { - FnArg::Receiver(_) => { - let pat = receiver_ident(); - quote! { #pat } - } - FnArg::Typed(pat_type) => { - let pat = &pat_type.pat.clone(); - quote! { #pat } - } + let annotations_expr = if let Some(annotations) = attribute.annotations { + let ToolAnnotationsAttribute { + title, + read_only_hint, + destructive_hint, + idempotent_hint, + open_world_hint, + } = annotations; + fn wrap_option(x: Option) -> TokenStream { + x.map(|x| quote! {Some(#x.into())}) + .unwrap_or(quote! { None }) + } + let title = wrap_option(title); + let read_only_hint = wrap_option(read_only_hint); + let destructive_hint = wrap_option(destructive_hint); + let idempotent_hint = wrap_option(idempotent_hint); + let open_world_hint = wrap_option(open_world_hint); + let token_stream = quote! { + Some(rmcp::model::ToolAnnotations { + title: #title, + read_only_hint: #read_only_hint, + destructive_hint: #destructive_hint, + idempotent_hint: #idempotent_hint, + open_world_hint: #open_world_hint, }) - .collect::>(); - let raw_fn_ident = &input_fn.sig.ident; - let call = if is_async { - quote! { - Self::#raw_fn_ident(#(#params),*).await.into_call_tool_result() - } - } else { - quote! { - Self::#raw_fn_ident(#(#params),*).into_call_tool_result() - } }; - // assemble the whole function - let tool_call_fn_ident = Ident::new( - &format!("{}_tool_call", input_fn.sig.ident), - proc_macro2::Span::call_site(), - ); - let raw_fn_vis = tool_macro_attrs - .fn_item - .vis - .as_ref() - .unwrap_or(&input_fn.vis); - let raw_fn_attr = &input_fn - .attrs - .iter() - .filter(|attr| !attr.path().is_ident(TOOL_IDENT)) - .collect::>(); - quote! { - #(#raw_fn_attr)* - #raw_fn_vis async fn #tool_call_fn_ident(context: rmcp::handler::server::tool::ToolCallContext<'_, Self>) - -> std::result::Result { - use rmcp::handler::server::tool::*; - #trivial_arg_extraction_part - #processed_arg_extraction_part - #call + syn::parse2::(token_stream)? + } else { + none_expr() + }; + let resolved_tool_attr = ResolvedToolAttribute { + name: attribute.name.unwrap_or_else(|| fn_ident.to_string()), + description: attribute.description.or_else(|| { + let doc_content = fn_item + .attrs + .iter() + .filter_map(extract_doc_line) + .collect::>() + .join("\n"); + if doc_content.is_empty() { + None + } else { + Some(doc_content) } - } + }), + input_schema: input_schema_expr, + annotations: annotations_expr, }; + let tool_attr_fn = resolved_tool_attr.into_fn(tool_attr_fn_ident)?; + // modify the the input function + if fn_item.sig.asyncness.is_some() { + // 1. remove asyncness from sig + // 2. make return type: `std::pin::Pin + Send + '_>>` + // 3. make body: { Box::pin(async move { #body }) } + let new_output = syn::parse2::({ + match &fn_item.sig.output { + syn::ReturnType::Default => { + quote! { -> std::pin::Pin + Send + '_>> } + } + syn::ReturnType::Type(_, ty) => { + quote! { -> std::pin::Pin + Send + '_>> } + } + } + })?; + let prev_block = &fn_item.block; + let new_block = syn::parse2::(quote! { + { Box::pin(async move #prev_block ) } + })?; + fn_item.sig.asyncness = None; + fn_item.sig.output = new_output; + fn_item.block = new_block; + } Ok(quote! { #tool_attr_fn - #tool_call_fn - #input_fn + #fn_item }) } -fn create_request_type(attrs: &[ToolFnParamAttrs], tool_name: String) -> (TokenStream, Ident) { - let pascal_case_tool_name = tool_name.to_ascii_uppercase(); - let temp_param_type_name = Ident::new( - &format!("__{pascal_case_tool_name}ToolCallParam",), - proc_macro2::Span::call_site(), - ); - ( - quote! { - use rmcp::{serde, schemars}; - #[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)] - pub struct #temp_param_type_name { - #(#attrs)* - } - }, - temp_param_type_name, - ) -} - #[cfg(test)] mod test { use super::*; - #[test] - fn test_tool_sync_macro() -> syn::Result<()> { - let attr = quote! { - name = "test_tool", - description = "test tool", - vis = - }; - let input = quote! { - fn sum(&self, #[tool(aggr)] req: StructRequest) -> Result { - Ok(CallToolResult::success(vec![Content::text((req.a + req.b).to_string())])) - } - }; - let input = tool(attr, input)?; - - println!("input: {:#}", input); - Ok(()) - } - #[test] fn test_trait_tool_macro() -> syn::Result<()> { let attr = quote! { - tool_box = Calculator + name = "direct-annotated-tool", + annotations(title = "Annotated Tool", read_only_hint = true) }; let input = quote! { - impl ServerHandler for Calculator { - #[tool] - fn get_info(&self) -> ServerInfo { - ServerInfo { - instructions: Some("A simple calculator".into()), - ..Default::default() - } - } + async fn async_method(&self, Parameters(Request { fields }): Parameters) { + drop(fields) } }; let input = tool(attr, input)?; - println!("input: {:#}", input); - Ok(()) - } - #[test] - fn test_doc_comment_description() -> syn::Result<()> { - let attr = quote! {}; // No explicit description - let input = quote! { - /// This is a test description from doc comments - /// with multiple lines - fn test_function(&self) -> Result<(), Error> { - Ok(()) - } - }; - let result = tool(attr, input)?; - - // The output should contain the description from doc comments - let result_str = result.to_string(); - assert!(result_str.contains("This is a test description from doc comments")); - assert!(result_str.contains("with multiple lines")); - - Ok(()) - } - #[test] - fn test_explicit_description_priority() -> syn::Result<()> { - let attr = quote! { - description = "Explicit description has priority" - }; - let input = quote! { - /// Doc comment description that should be ignored - fn test_function(&self) -> Result<(), Error> { - Ok(()) - } - }; - let result = tool(attr, input)?; - - // The output should contain the explicit description - let result_str = result.to_string(); - assert!(result_str.contains("Explicit description has priority")); Ok(()) } } diff --git a/crates/rmcp-macros/src/tool_inherite.rs b/crates/rmcp-macros/src/tool_inherite.rs new file mode 100644 index 00000000..4c292ed0 --- /dev/null +++ b/crates/rmcp-macros/src/tool_inherite.rs @@ -0,0 +1,785 @@ +use std::collections::HashSet; + +use proc_macro2::TokenStream; +use quote::{ToTokens, quote}; +use serde_json::json; +use syn::{ + Expr, FnArg, Ident, ItemFn, ItemImpl, Lit, MetaList, PatType, Token, Type, Visibility, + parse::Parse, parse_quote, spanned::Spanned, +}; + +/// Stores tool annotation attributes +#[derive(Default, Clone)] +struct ToolAnnotationAttrs(pub serde_json::Map); + +impl Parse for ToolAnnotationAttrs { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let mut attrs = serde_json::Map::new(); + + while !input.is_empty() { + let key: Ident = input.parse()?; + input.parse::()?; + let value: Lit = input.parse()?; + let value = match value { + Lit::Str(s) => json!(s.value()), + Lit::Bool(b) => json!(b.value), + _ => { + return Err(syn::Error::new( + key.span(), + "annotations must be string or boolean literals", + )); + } + }; + attrs.insert(key.to_string(), value); + if input.is_empty() { + break; + } + input.parse::()?; + } + + Ok(ToolAnnotationAttrs(attrs)) + } +} + +#[derive(Default)] +struct ToolImplItemAttrs { + tool_box: Option>, +} + +impl Parse for ToolImplItemAttrs { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let mut tool_box = None; + while !input.is_empty() { + let key: Ident = input.parse()?; + match key.to_string().as_str() { + "tool_box" => { + tool_box = Some(None); + if input.lookahead1().peek(Token![=]) { + input.parse::()?; + let value: Ident = input.parse()?; + tool_box = Some(Some(value)); + } + } + _ => { + return Err(syn::Error::new(key.span(), "unknown attribute")); + } + } + if input.is_empty() { + break; + } + input.parse::()?; + } + + Ok(ToolImplItemAttrs { tool_box }) + } +} + +#[derive(Default)] +struct ToolFnItemAttrs { + name: Option, + description: Option, + vis: Option, + annotations: Option, +} + +impl Parse for ToolFnItemAttrs { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let mut name = None; + let mut description = None; + let mut vis = None; + let mut annotations = None; + + while !input.is_empty() { + let key: Ident = input.parse()?; + input.parse::()?; + match key.to_string().as_str() { + "name" => { + let value: Expr = input.parse()?; + name = Some(value); + } + "description" => { + let value: Expr = input.parse()?; + description = Some(value); + } + "vis" => { + let value: Visibility = input.parse()?; + vis = Some(value); + } + "annotations" => { + // Parse the annotations as a nested structure + let content; + syn::braced!(content in input); + let value = content.parse()?; + annotations = Some(value); + } + _ => { + return Err(syn::Error::new(key.span(), "unknown attribute")); + } + } + if input.is_empty() { + break; + } + input.parse::()?; + } + + Ok(ToolFnItemAttrs { + name, + description, + vis, + annotations, + }) + } +} + +struct ToolFnParamAttrs { + serde_meta: Vec, + schemars_meta: Vec, + ident: Ident, + rust_type: Box, +} + +impl ToTokens for ToolFnParamAttrs { + fn to_tokens(&self, tokens: &mut TokenStream) { + let ident = &self.ident; + let rust_type = &self.rust_type; + let serde_meta = &self.serde_meta; + let schemars_meta = &self.schemars_meta; + tokens.extend(quote! { + #(#[#serde_meta])* + #(#[#schemars_meta])* + pub #ident: #rust_type, + }); + } +} + +#[derive(Default)] + +enum ToolParams { + Aggregated { + rust_type: PatType, + }, + Params { + attrs: Vec, + }, + #[default] + NoParam, +} + +#[derive(Default)] +struct ToolAttrs { + fn_item: ToolFnItemAttrs, + params: ToolParams, +} +const TOOL_IDENT: &str = "tool"; +const SERDE_IDENT: &str = "serde"; +const SCHEMARS_IDENT: &str = "schemars"; +const PARAM_IDENT: &str = "param"; +const AGGREGATED_IDENT: &str = "aggr"; +const REQ_IDENT: &str = "req"; + +pub enum ParamMarker { + Param, + Aggregated, +} + +impl Parse for ParamMarker { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let ident: Ident = input.parse()?; + match ident.to_string().as_str() { + PARAM_IDENT => Ok(ParamMarker::Param), + AGGREGATED_IDENT | REQ_IDENT => Ok(ParamMarker::Aggregated), + _ => Err(syn::Error::new(ident.span(), "unknown attribute")), + } + } +} + +pub enum ToolItem { + Fn(ItemFn), + Impl(ItemImpl), +} + +impl Parse for ToolItem { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let lookahead = input.lookahead1(); + if lookahead.peek(Token![impl]) { + let item = input.parse::()?; + Ok(ToolItem::Impl(item)) + } else { + let item = input.parse::()?; + Ok(ToolItem::Fn(item)) + } + } +} + +// dispatch impl function item and impl block item +pub(crate) fn tool(attr: TokenStream, input: TokenStream) -> syn::Result { + let tool_item = syn::parse2::(input)?; + match tool_item { + ToolItem::Fn(item) => tool_fn_item(attr, item), + ToolItem::Impl(item) => tool_impl_item(attr, item), + } +} + +pub(crate) fn tool_impl_item(attr: TokenStream, mut input: ItemImpl) -> syn::Result { + let tool_impl_attr: ToolImplItemAttrs = syn::parse2(attr)?; + let tool_box_ident = tool_impl_attr.tool_box; + + // get all tool function ident + let mut tool_fn_idents = Vec::new(); + for item in &input.items { + if let syn::ImplItem::Fn(method) = item { + for attr in &method.attrs { + if attr.path().is_ident(TOOL_IDENT) { + tool_fn_idents.push(method.sig.ident.clone()); + } + } + } + } + + // handle different cases + if input.trait_.is_some() { + if let Some(ident) = tool_box_ident { + // check if there are generic parameters + if !input.generics.params.is_empty() { + // for trait implementation with generic parameters, directly use the already generated *_inner method + + // generate call_tool method + input.items.push(parse_quote! { + async fn call_tool( + &self, + request: rmcp::model::CallToolRequestParam, + context: rmcp::service::RequestContext, + ) -> Result { + self.call_tool_inner(request, context).await + } + }); + + // generate list_tools method + input.items.push(parse_quote! { + async fn list_tools( + &self, + request: Option, + context: rmcp::service::RequestContext, + ) -> Result { + self.list_tools_inner(request, context).await + } + }); + } else { + // if there are no generic parameters, add tool box derive + input.items.push(parse_quote!( + rmcp::tool_box!(@derive #ident); + )); + } + } else { + return Err(syn::Error::new( + proc_macro2::Span::call_site(), + "tool_box attribute is required for trait implementation", + )); + } + } else if let Some(ident) = tool_box_ident { + // if it is a normal impl block + if !input.generics.params.is_empty() { + // if there are generic parameters, not use tool_box! macro, but generate code directly + + // create call code for each tool function + let match_arms = tool_fn_idents.iter().map(|ident| { + let attr_fn = Ident::new(&format!("{}_tool_attr", ident), ident.span()); + let call_fn = Ident::new(&format!("{}_tool_call", ident), ident.span()); + quote! { + name if name == Self::#attr_fn().name => { + Self::#call_fn(tcc).await + } + } + }); + + let tool_attrs = tool_fn_idents.iter().map(|ident| { + let attr_fn = Ident::new(&format!("{}_tool_attr", ident), ident.span()); + quote! { Self::#attr_fn() } + }); + + // implement call_tool method + input.items.push(parse_quote! { + async fn call_tool_inner( + &self, + request: rmcp::model::CallToolRequestParam, + context: rmcp::service::RequestContext, + ) -> Result { + let tcc = rmcp::handler::server::tool::ToolCallContext::new(self, request, context); + match tcc.name() { + #(#match_arms,)* + _ => Err(rmcp::Error::invalid_params("tool not found", None)), + } + } + }); + + // implement list_tools method + input.items.push(parse_quote! { + async fn list_tools_inner( + &self, + _: Option, + _: rmcp::service::RequestContext, + ) -> Result { + Ok(rmcp::model::ListToolsResult { + next_cursor: None, + tools: vec![#(#tool_attrs),*], + }) + } + }); + } else { + // if there are no generic parameters, use the original tool_box! macro + let this_type_ident = &input.self_ty; + input.items.push(parse_quote!( + rmcp::tool_box!(#this_type_ident { + #(#tool_fn_idents),* + } #ident); + )); + } + } + + Ok(quote! { + #input + }) +} + +// extract doc line from attribute +fn extract_doc_line(attr: &syn::Attribute) -> Option { + if !attr.path().is_ident("doc") { + return None; + } + + let syn::Meta::NameValue(name_value) = &attr.meta else { + return None; + }; + + let syn::Expr::Lit(expr_lit) = &name_value.value else { + return None; + }; + + let syn::Lit::Str(lit_str) = &expr_lit.lit else { + return None; + }; + + let content = lit_str.value().trim().to_string(); + + (!content.is_empty()).then_some(content) +} + +pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Result { + let mut tool_macro_attrs = ToolAttrs::default(); + let args: ToolFnItemAttrs = syn::parse2(attr)?; + tool_macro_attrs.fn_item = args; + // let mut fommated_fn_args: Punctuated = Punctuated::new(); + let mut unextractable_args_indexes = HashSet::new(); + for (index, mut fn_arg) in input_fn.sig.inputs.iter_mut().enumerate() { + enum Caught { + Param(ToolFnParamAttrs), + Aggregated(PatType), + } + let mut caught = None; + match &mut fn_arg { + FnArg::Receiver(_) => { + continue; + } + FnArg::Typed(pat_type) => { + let mut serde_metas = Vec::new(); + let mut schemars_metas = Vec::new(); + let mut arg_ident = match pat_type.pat.as_ref() { + syn::Pat::Ident(pat_ident) => Some(pat_ident.ident.clone()), + _ => None, + }; + let raw_attrs: Vec<_> = pat_type.attrs.drain(..).collect(); + for attr in raw_attrs { + match &attr.meta { + syn::Meta::List(meta_list) => { + if meta_list.path.is_ident(TOOL_IDENT) { + let pat_type = pat_type.clone(); + let marker = meta_list.parse_args::()?; + match marker { + ParamMarker::Param => { + let Some(arg_ident) = arg_ident.take() else { + return Err(syn::Error::new( + proc_macro2::Span::call_site(), + "input param must have an ident as name", + )); + }; + caught.replace(Caught::Param(ToolFnParamAttrs { + serde_meta: Vec::new(), + schemars_meta: Vec::new(), + ident: arg_ident, + rust_type: pat_type.ty.clone(), + })); + } + ParamMarker::Aggregated => { + caught.replace(Caught::Aggregated(pat_type.clone())); + } + } + } else if meta_list.path.is_ident(SERDE_IDENT) { + serde_metas.push(meta_list.clone()); + } else if meta_list.path.is_ident(SCHEMARS_IDENT) { + schemars_metas.push(meta_list.clone()); + } else { + pat_type.attrs.push(attr); + } + } + _ => { + pat_type.attrs.push(attr); + } + } + } + match caught { + Some(Caught::Param(mut param)) => { + param.serde_meta = serde_metas; + param.schemars_meta = schemars_metas; + match &mut tool_macro_attrs.params { + ToolParams::Params { attrs } => { + attrs.push(param); + } + _ => { + tool_macro_attrs.params = ToolParams::Params { attrs: vec![param] }; + } + } + unextractable_args_indexes.insert(index); + } + Some(Caught::Aggregated(rust_type)) => { + if let ToolParams::Params { .. } = tool_macro_attrs.params { + return Err(syn::Error::new( + rust_type.span(), + "cannot mix aggregated and individual parameters", + )); + } + tool_macro_attrs.params = ToolParams::Aggregated { rust_type }; + unextractable_args_indexes.insert(index); + } + None => {} + } + } + } + } + + // input_fn.sig.inputs = fommated_fn_args; + let name = if let Some(expr) = tool_macro_attrs.fn_item.name { + expr + } else { + let fn_name = &input_fn.sig.ident; + parse_quote! { + stringify!(#fn_name) + } + }; + let tool_attr_fn_ident = Ident::new( + &format!("{}_tool_attr", input_fn.sig.ident), + proc_macro2::Span::call_site(), + ); + + // generate get tool attr function + let tool_attr_fn = { + let description = if let Some(expr) = tool_macro_attrs.fn_item.description { + // Use explicitly provided description if available + expr + } else { + // Try to extract documentation comments + let doc_content = input_fn + .attrs + .iter() + .filter_map(extract_doc_line) + .collect::>() + .join("\n"); + + parse_quote! { + #doc_content.trim().to_string() + } + }; + let schema = match &tool_macro_attrs.params { + ToolParams::Aggregated { rust_type } => { + let ty = &rust_type.ty; + let schema = quote! { + rmcp::handler::server::tool::cached_schema_for_type::<#ty>() + }; + schema + } + ToolParams::Params { attrs, .. } => { + let (param_type, temp_param_type_name) = + create_request_type(attrs, input_fn.sig.ident.to_string()); + let schema = quote! { + { + #param_type + rmcp::handler::server::tool::cached_schema_for_type::<#temp_param_type_name>() + } + }; + schema + } + ToolParams::NoParam => { + quote! { + rmcp::handler::server::tool::cached_schema_for_type::() + } + } + }; + let input_fn_attrs = &input_fn.attrs; + let input_fn_vis = &input_fn.vis; + + let annotations_code = if let Some(annotations) = &tool_macro_attrs.fn_item.annotations { + let annotations = + serde_json::to_string(&annotations.0).expect("failed to serialize annotations"); + quote! { + Some(serde_json::from_str::(&#annotations).expect("Could not parse tool annotations")) + } + } else { + quote! { None } + }; + + quote! { + #(#input_fn_attrs)* + #input_fn_vis fn #tool_attr_fn_ident() -> rmcp::model::Tool { + rmcp::model::Tool { + name: #name.into(), + description: Some(#description.into()), + input_schema: #schema.into(), + annotations: #annotations_code, + } + } + } + }; + + // generate wrapped tool function + let tool_call_fn = { + // wrapper function have the same sig: + // async fn #tool_tool_call(context: rmcp::handler::server::tool::ToolCallContext) + // -> std::result::Result + // + // and the block part should be like: + // { + // use rmcp::handler::server::tool::*; + // let t0 = ::from_tool_call_context_part(&mut context)?; + // let t1 = ::from_tool_call_context_part(&mut context)?; + // ... + // let tn = ::from_tool_call_context_part(&mut context)?; + // // for params + // ... expand helper types here + // let __rmcp_tool_req = rmcp::model::JsonObject::from_tool_call_context_part(&mut context)?; + // let __#TOOL_ToolCallParam { param_0, param_1, param_2, .. } = parse_json_object(__rmcp_tool_req)?; + // // for aggr + // let Parameters(aggr) = >::from_tool_call_context_part(&mut context)?; + // Self::#tool_ident(to, param_0, t1, param_1, ..., param_2, tn, aggr).await.into_call_tool_result() + // + // } + // + // + // + + // for receiver type, name it as __rmcp_tool_receiver + let is_async = input_fn.sig.asyncness.is_some(); + let receiver_ident = || Ident::new("__rmcp_tool_receiver", proc_macro2::Span::call_site()); + // generate the extraction part for trivial args + let trivial_args = input_fn + .sig + .inputs + .iter() + .enumerate() + .filter_map(|(index, arg)| { + if unextractable_args_indexes.contains(&index) { + None + } else { + // get ident/type pair + let line = match arg { + FnArg::Typed(pat_type) => { + let pat = &pat_type.pat; + let ty = &pat_type.ty; + quote! { + let #pat = <#ty>::from_tool_call_context_part(&mut context)?; + } + } + FnArg::Receiver(r) => { + let ty = r.ty.clone(); + let pat = receiver_ident(); + quote! { + let #pat = <#ty>::from_tool_call_context_part(&mut context)?; + } + } + }; + Some(line) + } + }); + let trivial_arg_extraction_part = quote! { + #(#trivial_args)* + }; + let processed_arg_extraction_part = match &mut tool_macro_attrs.params { + ToolParams::Aggregated { rust_type } => { + let PatType { pat, ty, .. } = rust_type; + quote! { + let Parameters(#pat) = >::from_tool_call_context_part(&mut context)?; + } + } + ToolParams::Params { attrs } => { + let (param_type, temp_param_type_name) = + create_request_type(attrs, input_fn.sig.ident.to_string()); + + let params_ident = attrs.iter().map(|attr| &attr.ident).collect::>(); + quote! { + #param_type + let __rmcp_tool_req = rmcp::model::JsonObject::from_tool_call_context_part(&mut context)?; + let #temp_param_type_name { + #(#params_ident,)* + } = parse_json_object(__rmcp_tool_req)?; + } + } + ToolParams::NoParam => { + quote! {} + } + }; + // generate the execution part + // has receiver? + let params = &input_fn + .sig + .inputs + .iter() + .map(|fn_arg| match fn_arg { + FnArg::Receiver(_) => { + let pat = receiver_ident(); + quote! { #pat } + } + FnArg::Typed(pat_type) => { + let pat = &pat_type.pat.clone(); + quote! { #pat } + } + }) + .collect::>(); + let raw_fn_ident = &input_fn.sig.ident; + let call = if is_async { + quote! { + Self::#raw_fn_ident(#(#params),*).await.into_call_tool_result() + } + } else { + quote! { + Self::#raw_fn_ident(#(#params),*).into_call_tool_result() + } + }; + // assemble the whole function + let tool_call_fn_ident = Ident::new( + &format!("{}_tool_call", input_fn.sig.ident), + proc_macro2::Span::call_site(), + ); + let raw_fn_vis = tool_macro_attrs + .fn_item + .vis + .as_ref() + .unwrap_or(&input_fn.vis); + let raw_fn_attr = &input_fn + .attrs + .iter() + .filter(|attr| !attr.path().is_ident(TOOL_IDENT)) + .collect::>(); + quote! { + #(#raw_fn_attr)* + #raw_fn_vis async fn #tool_call_fn_ident(context: rmcp::handler::server::tool::ToolCallContext) + -> std::result::Result { + use rmcp::handler::server::tool::*; + #trivial_arg_extraction_part + #processed_arg_extraction_part + #call + } + } + }; + Ok(quote! { + #tool_attr_fn + #tool_call_fn + #input_fn + }) +} + +fn create_request_type(attrs: &[ToolFnParamAttrs], tool_name: String) -> (TokenStream, Ident) { + let pascal_case_tool_name = tool_name.to_ascii_uppercase(); + let temp_param_type_name = Ident::new( + &format!("__{pascal_case_tool_name}ToolCallParam",), + proc_macro2::Span::call_site(), + ); + ( + quote! { + use rmcp::{serde, schemars}; + #[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)] + pub struct #temp_param_type_name { + #(#attrs)* + } + }, + temp_param_type_name, + ) +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn test_tool_sync_macro() -> syn::Result<()> { + let attr = quote! { + name = "test_tool", + description = "test tool", + vis = + }; + let input = quote! { + fn sum(&self, #[tool(aggr)] req: StructRequest) -> Result { + Ok(CallToolResult::success(vec![Content::text((req.a + req.b).to_string())])) + } + }; + let input = tool(attr, input)?; + + println!("input: {:#}", input); + Ok(()) + } + + #[test] + fn test_trait_tool_macro() -> syn::Result<()> { + let attr = quote! { + tool_box = Calculator + }; + let input = quote! { + impl ServerHandler for Calculator { + #[tool] + fn get_info(&self) -> ServerInfo { + ServerInfo { + instructions: Some("A simple calculator".into()), + ..Default::default() + } + } + } + }; + let input = tool(attr, input)?; + + println!("input: {:#}", input); + Ok(()) + } + #[test] + fn test_doc_comment_description() -> syn::Result<()> { + let attr = quote! {}; // No explicit description + let input = quote! { + /// This is a test description from doc comments + /// with multiple lines + fn test_function(&self) -> Result<(), Error> { + Ok(()) + } + }; + let result = tool(attr, input)?; + + // The output should contain the description from doc comments + let result_str = result.to_string(); + assert!(result_str.contains("This is a test description from doc comments")); + assert!(result_str.contains("with multiple lines")); + + Ok(()) + } + #[test] + fn test_explicit_description_priority() -> syn::Result<()> { + let attr = quote! { + description = "Explicit description has priority" + }; + let input = quote! { + /// Doc comment description that should be ignored + fn test_function(&self) -> Result<(), Error> { + Ok(()) + assert!(result_str.contains("Explicit description has priority")); + } + }; + let result = tool(attr, input)?; + + // The output should contain the explicit description + let result_str = result.to_string(); + Ok(()) + } +} diff --git a/crates/rmcp-macros/src/tool_router.rs b/crates/rmcp-macros/src/tool_router.rs new file mode 100644 index 00000000..337cf00c --- /dev/null +++ b/crates/rmcp-macros/src/tool_router.rs @@ -0,0 +1,90 @@ +//! ```ignore +//! #[rmcp::tool_router(router)] +//! impl Handler { +//! +//! } +//! ``` +//! + +use darling::{FromMeta, ast::NestedMeta}; +use proc_macro2::TokenStream; +use quote::{ToTokens, format_ident, quote}; +use syn::{Ident, ImplItem, ItemImpl, Visibility}; + +#[derive(FromMeta)] +#[darling(default)] +pub struct ToolRouterAttribute { + pub router: Ident, + pub vis: Option, +} + +impl Default for ToolRouterAttribute { + fn default() -> Self { + Self { + router: format_ident!("tool_router"), + vis: None, + } + } +} + +pub fn tool_router(attr: TokenStream, input: TokenStream) -> syn::Result { + let attr_args = NestedMeta::parse_meta_list(attr)?; + let ToolRouterAttribute { router, vis } = ToolRouterAttribute::from_list(&attr_args)?; + let mut item_impl = syn::parse2::(input.clone())?; + // find all function marked with `#[rmcp::tool]` + let tool_attr_fns: Vec<_> = item_impl + .items + .iter() + .filter_map(|item| { + if let syn::ImplItem::Fn(fn_item) = item { + fn_item + .attrs + .iter() + .any(|attr| { + attr.path() + .segments + .last() + .is_some_and(|seg| seg.ident == "tool") + }) + .then_some(&fn_item.sig.ident) + } else { + None + } + }) + .collect(); + let mut routers = vec![]; + for handler in tool_attr_fns { + let tool_attr_fn_ident = format_ident!("{handler}_tool_attr"); + routers.push(quote! { + .with_route(Self::#tool_attr_fn_ident(), Self::#handler) + }) + } + let router_fn = syn::parse2::(quote! { + #vis fn #router() -> rmcp::handler::server::router::tool::ToolRouter { + rmcp::handler::server::router::tool::ToolRouter::::new() + #(#routers)* + } + })?; + item_impl.items.push(router_fn); + Ok(item_impl.into_token_stream()) +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn test_router_attr() -> Result<(), Box> { + let attr = quote! { + router = test_router, + }; + let attr_args = NestedMeta::parse_meta_list(attr)?; + let ToolRouterAttribute { router, vis } = ToolRouterAttribute::from_list(&attr_args)?; + println!("router: {}", router); + if let Some(vis) = vis { + println!("visibility: {}", vis.to_token_stream()); + } else { + println!("visibility: None"); + } + Ok(()) + } +} diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 9754f12f..cb144376 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -63,7 +63,7 @@ http-body-util = { version = "0.1", optional = true } bytes = { version = "1", optional = true } # macro rmcp-macros = { version = "0.1", workspace = true, optional = true } - +inventory = "0.3" [target.'cfg(not(all(target_family = "wasm", target_os = "unknown")))'.dependencies] chrono = { version = "0.4.38", features = ["serde"] } diff --git a/crates/rmcp/README.md b/crates/rmcp/README.md index b57c9814..3130f0da 100644 --- a/crates/rmcp/README.md +++ b/crates/rmcp/README.md @@ -15,20 +15,22 @@ wait for the first release. Creating a server with tools is simple using the `#[tool]` macro: ```rust, ignore -use rmcp::{Error as McpError, ServiceExt, model::*, tool, transport::stdio}; +use rmcp::{Error as McpError, ServiceExt, model::*, tool, tool_router, transport::stdio, handler::server::tool::ToolCallContext, handler::server::router::tool::ToolRouter}; use std::sync::Arc; use tokio::sync::Mutex; #[derive(Clone)] pub struct Counter { counter: Arc>, + tool_router: ToolRouter, } -#[tool(tool_box)] +#[tool_router] impl Counter { fn new() -> Self { Self { counter: Arc::new(Mutex::new(0)), + tool_router: Self::tool_router(), } } @@ -51,7 +53,7 @@ impl Counter { } // Implement the server handler -#[tool(tool_box)] +#[tool_router] impl rmcp::ServerHandler for Counter { fn get_info(&self) -> ServerInfo { ServerInfo { @@ -60,6 +62,22 @@ impl rmcp::ServerHandler for Counter { ..Default::default() } } + async fn call_tool( + &self, + request: CallToolRequestParam, + context: rmcp::service::RequestContext, + ) -> Result { + let tcc = ToolCallContext::new(self, request, context); + self.tool_router.call(tcc).await + } + + async fn list_tools( + &self, + _request: Option, + _context: rmcp::service::RequestContext, + ) -> Result { + Ok(ListToolsResult::with_all_items(self.tool_router.list_all())) + } } // Run the server diff --git a/crates/rmcp/src/handler/server.rs b/crates/rmcp/src/handler/server.rs index 52b63832..e532927e 100644 --- a/crates/rmcp/src/handler/server.rs +++ b/crates/rmcp/src/handler/server.rs @@ -7,6 +7,7 @@ use crate::{ mod resource; pub mod tool; pub mod wrapper; +pub mod router; impl Service for H { async fn handle_request( &self, @@ -184,6 +185,19 @@ pub trait ServerHandler: Sized + Send + Sync + 'static { request: CallToolRequestParam, context: RequestContext, ) -> impl Future> + Send + '_ { + // async move { + // let router = router::tool::GlobalStaticRouters::get::().await; + // router.call(tool::ToolCallContext { + // request_context: context, + // service: todo!(), + // name: todo!(), + // arguments: todo!(), + // }).await.map_err(|e| { + // tracing::error!("call tool error: {}", e); + // e + // }) + + // }; std::future::ready(Err(McpError::method_not_found::())) } fn list_tools( diff --git a/crates/rmcp/src/handler/server/router.rs b/crates/rmcp/src/handler/server/router.rs new file mode 100644 index 00000000..926c7eb7 --- /dev/null +++ b/crates/rmcp/src/handler/server/router.rs @@ -0,0 +1,97 @@ +use std::sync::Arc; + +use tool::{IntoToolRoute, ToolRoute}; + +use crate::{ + RoleServer, Service, + model::{ClientRequest, ListToolsResult, ServerResult}, + service::NotificationContext, +}; + +use super::ServerHandler; + +pub mod tool; + +pub struct Router { + pub tool_router: tool::ToolRouter, + pub service: Arc, +} + +impl Router +where + S: ServerHandler, +{ + pub fn new(service: S) -> Self { + Self { + tool_router: tool::ToolRouter::new(), + service: Arc::new(service), + } + } + + pub fn with_tool(mut self, route: R) -> Self + where + R: IntoToolRoute, + { + self.tool_router.add_route(route.into_tool_route()); + self + } + + pub fn with_tools(mut self, routes: impl IntoIterator>) -> Self { + for route in routes { + self.tool_router.add_route(route); + } + self + } +} + +impl Service for Router +where + S: ServerHandler, +{ + async fn handle_notification( + &self, + notification: ::PeerNot, + context: NotificationContext, + ) -> Result<(), crate::Error> { + self.service + .handle_notification(notification, context) + .await + } + async fn handle_request( + &self, + request: ::PeerReq, + context: crate::service::RequestContext, + ) -> Result<::Resp, crate::Error> { + match request { + ClientRequest::CallToolRequest(request) => { + if self.tool_router.has_route(request.params.name.as_ref()) + || !self.tool_router.transparent_when_not_found + { + let tool_call_context = crate::handler::server::tool::ToolCallContext::new( + self.service.as_ref(), + request.params, + context, + ); + let result = self.tool_router.call(tool_call_context).await?; + Ok(ServerResult::CallToolResult(result)) + } else { + self.service + .handle_request(ClientRequest::CallToolRequest(request), context) + .await + } + } + ClientRequest::ListToolsRequest(_) => { + let tools = self.tool_router.list_all(); + Ok(ServerResult::ListToolsResult(ListToolsResult { + tools, + next_cursor: None, + })) + } + rest => self.service.handle_request(rest, context).await, + } + } + + fn get_info(&self) -> ::Info { + self.service.get_info() + } +} diff --git a/crates/rmcp/src/handler/server/router/promt.rs b/crates/rmcp/src/handler/server/router/promt.rs new file mode 100644 index 00000000..e69de29b diff --git a/crates/rmcp/src/handler/server/router/tool.rs b/crates/rmcp/src/handler/server/router/tool.rs new file mode 100644 index 00000000..ed2602d7 --- /dev/null +++ b/crates/rmcp/src/handler/server/router/tool.rs @@ -0,0 +1,364 @@ +use std::any::{Any, TypeId}; +use std::borrow::Cow; +use std::collections::HashMap; +use std::sync::Arc; + +use futures::FutureExt; +use futures::future::BoxFuture; +use schemars::JsonSchema; + +use crate::model::{CallToolResult, Tool, ToolAnnotations}; + +use crate::handler::server::tool::{ + CallToolHandler, DynCallToolHandler, ToolCallContext, schema_for_type, +}; + +inventory::collect!(ToolRouteWithType); + +#[derive(Debug, Default)] +pub struct GlobalStaticRouters { + pub routers: + std::sync::OnceLock>>>, +} + +impl GlobalStaticRouters { + pub fn global() -> &'static Self { + static GLOBAL: GlobalStaticRouters = GlobalStaticRouters { + routers: std::sync::OnceLock::new(), + }; + &GLOBAL + } + pub async fn set(router: Arc>) -> Result<(), String> { + let routers = Self::global().routers.get_or_init(Default::default); + let mut routers_wg = routers.write().await; + if routers_wg.insert(TypeId::of::(), router).is_some() { + return Err("Router already exists".to_string()); + } + Ok(()) + } + pub async fn get() -> Arc> { + let routers = Self::global().routers.get_or_init(Default::default); + let routers_rg = routers.read().await; + if let Some(router) = routers_rg.get(&TypeId::of::()) { + return router + .clone() + .downcast::>() + .expect("Failed to downcast"); + } + { + drop(routers_rg); + } + let mut routers = routers.write().await; + match routers.entry(TypeId::of::()) { + std::collections::hash_map::Entry::Occupied(occupied) => occupied + .get() + .clone() + .downcast::>() + .expect("Failed to downcast"), + std::collections::hash_map::Entry::Vacant(vacant) => { + let mut router = ToolRouter::::default(); + for route in inventory::iter:: + .into_iter() + .filter(|r| r.type_id == TypeId::of::()) + { + if let Some(route) = route.downcast::() { + router.add_route(route.clone()); + } + } + let mut_ref = vacant.insert(Arc::new(router)); + mut_ref + .downcast_ref() + .cloned() + .expect("Failed to downcast after insert") + } + } + } +} + +pub struct ToolRouteWithType { + type_id: TypeId, + route: Box, +} + +impl ToolRouteWithType { + pub fn downcast(&self) -> Option<&ToolRoute> { + if self.type_id == TypeId::of::() { + self.route.downcast_ref::>() + } else { + None + } + } + pub fn from_tool_route(route: ToolRoute) -> Self { + Self { + type_id: TypeId::of::(), + route: Box::new(route), + } + } +} + +impl From> for ToolRouteWithType { + fn from(value: ToolRoute) -> Self { + Self::from_tool_route(value) + } +} + +pub struct ToolRoute { + #[allow(clippy::type_complexity)] + pub call: Arc>, + pub attr: crate::model::Tool, +} + +impl std::fmt::Debug for ToolRoute { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ToolRoute") + .field("name", &self.attr.name) + .field("description", &self.attr.description) + .field("input_schema", &self.attr.input_schema) + .finish() + } +} + +impl Clone for ToolRoute { + fn clone(&self) -> Self { + Self { + call: self.call.clone(), + attr: self.attr.clone(), + } + } +} + +impl ToolRoute { + pub fn new(attr: impl Into, call: C) -> Self + where + C: CallToolHandler + Send + Sync + Clone + 'static, + { + Self { + call: Arc::new(move |context: ToolCallContext| { + let call = call.clone(); + context.invoke(call).boxed() + }), + attr: attr.into(), + } + } + pub fn new_dyn(attr: impl Into, call: C) -> Self + where + C: for<'a> Fn( + ToolCallContext<'a, S>, + ) -> BoxFuture<'a, Result> + + Send + + Sync + + 'static, + { + Self { + call: Arc::new(call), + attr: attr.into(), + } + } + pub fn name(&self) -> &str { + &self.attr.name + } +} + +pub trait IntoToolRoute { + fn into_tool_route(self) -> ToolRoute; +} + +impl IntoToolRoute for (T, C) +where + S: Send + Sync + 'static, + C: CallToolHandler + Send + Sync + Clone + 'static, + T: Into, +{ + fn into_tool_route(self) -> ToolRoute { + ToolRoute::new(self.0.into(), self.1) + } +} + +impl IntoToolRoute for ToolRoute +where + S: Send + Sync + 'static, +{ + fn into_tool_route(self) -> ToolRoute { + self + } +} + +pub struct ToolAttrGenerateFunctionAdapter; +impl IntoToolRoute for F +where + S: Send + Sync + 'static, + F: Fn() -> ToolRoute, +{ + fn into_tool_route(self) -> ToolRoute { + (self)() + } +} + +pub trait CallToolHandlerExt: Sized +where + Self: CallToolHandler + Send + Sync + Clone + 'static, +{ + fn name(self, name: impl Into>) -> WithToolAttr; +} + +impl CallToolHandlerExt for C +where + C: CallToolHandler + Send + Sync + Clone + 'static, +{ + fn name(self, name: impl Into>) -> WithToolAttr { + WithToolAttr { + attr: Tool::new( + name.into(), + "", + schema_for_type::(), + ), + call: self, + _marker: std::marker::PhantomData, + } + } +} + +pub struct WithToolAttr +where + C: CallToolHandler + Send + Sync + Clone + 'static, +{ + pub attr: crate::model::Tool, + pub call: C, + pub _marker: std::marker::PhantomData, +} + +impl IntoToolRoute for WithToolAttr +where + C: CallToolHandler + Send + Sync + Clone + 'static, + S: Send + Sync + 'static, +{ + fn into_tool_route(self) -> ToolRoute { + ToolRoute::new(self.attr, self.call) + } +} + +impl WithToolAttr +where + C: CallToolHandler + Send + Sync + Clone + 'static, +{ + pub fn description(mut self, description: impl Into>) -> Self { + self.attr.description = Some(description.into()); + self + } + pub fn parameters(mut self) -> Self { + self.attr.input_schema = schema_for_type::().into(); + self + } + pub fn parameters_value(mut self, schema: serde_json::Value) -> Self { + self.attr.input_schema = crate::model::object(schema).into(); + self + } + pub fn annotation(mut self, annotation: impl Into) -> Self { + self.attr.annotations = Some(annotation.into()); + self + } +} +#[derive(Debug)] +pub struct ToolRouter { + #[allow(clippy::type_complexity)] + pub map: std::collections::HashMap, ToolRoute>, + + pub transparent_when_not_found: bool, +} + +impl Default for ToolRouter { + fn default() -> Self { + Self { + map: std::collections::HashMap::new(), + transparent_when_not_found: false, + } + } +} +impl Clone for ToolRouter { + fn clone(&self) -> Self { + Self { + map: self.map.clone(), + transparent_when_not_found: self.transparent_when_not_found, + } + } +} + +impl IntoIterator for ToolRouter { + type Item = ToolRoute; + type IntoIter = std::collections::hash_map::IntoValues, ToolRoute>; + + fn into_iter(self) -> Self::IntoIter { + self.map.into_values() + } +} + +impl ToolRouter +where + S: Send + Sync + 'static, +{ + pub fn new() -> Self { + Self { + map: std::collections::HashMap::new(), + transparent_when_not_found: false, + } + } + pub fn with_route(mut self, attr: crate::model::Tool, call: C) -> Self + where + C: CallToolHandler + Send + Sync + Clone + 'static, + { + self.add_route(ToolRoute::new(attr, call)); + self + } + + pub fn add_route(&mut self, item: ToolRoute) { + self.map.insert(item.attr.name.clone(), item); + } + + pub fn merge(&mut self, other: ToolRouter) { + for item in other.map.into_values() { + self.add_route(item); + } + } + + pub fn remove_route(&mut self, name: &str) { + self.map.remove(name); + } + pub fn has_route(&self, name: &str) -> bool { + self.map.contains_key(name) + } + pub async fn call( + &self, + context: ToolCallContext<'_, S>, + ) -> Result { + let item = self + .map + .get(context.name()) + .ok_or_else(|| crate::Error::invalid_params("tool not found", None))?; + (item.call)(context).await + } + + pub fn list_all(&self) -> Vec { + self.map.values().map(|item| item.attr.clone()).collect() + } +} + +impl std::ops::Add> for ToolRouter +where + S: Send + Sync + 'static, +{ + type Output = Self; + + fn add(mut self, other: ToolRouter) -> Self::Output { + self.merge(other); + self + } +} + +impl std::ops::AddAssign> for ToolRouter +where + S: Send + Sync + 'static, +{ + fn add_assign(&mut self, other: ToolRouter) { + self.merge(other); + } +} diff --git a/crates/rmcp/src/handler/server/tool.rs b/crates/rmcp/src/handler/server/tool.rs index acb44fd5..5dbf85d4 100644 --- a/crates/rmcp/src/handler/server/tool.rs +++ b/crates/rmcp/src/handler/server/tool.rs @@ -2,16 +2,17 @@ use std::{ any::TypeId, borrow::Cow, collections::HashMap, future::Ready, marker::PhantomData, sync::Arc, }; -use futures::future::BoxFuture; +use futures::future::{BoxFuture, FutureExt}; use schemars::JsonSchema; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use tokio_util::sync::CancellationToken; use crate::{ RoleServer, - model::{CallToolRequestParam, CallToolResult, ConstString, IntoContents, JsonObject}, + model::{CallToolRequestParam, CallToolResult, IntoContents, JsonObject}, service::RequestContext, }; + /// A shortcut for generating a JSON schema for a type. pub fn schema_for_type() -> JsonObject { let mut settings = schemars::r#gen::SchemaSettings::default(); @@ -63,16 +64,16 @@ pub fn parse_json_object(input: JsonObject) -> Result { - request_context: RequestContext, - service: &'service S, - name: Cow<'static, str>, - arguments: Option, +pub struct ToolCallContext<'s, S> { + pub request_context: RequestContext, + pub service: &'s S, + pub name: Cow<'static, str>, + pub arguments: Option, } -impl<'service, S> ToolCallContext<'service, S> { +impl<'s, S> ToolCallContext<'s, S> { pub fn new( - service: &'service S, + service: &'s S, CallToolRequestParam { name, arguments }: CallToolRequestParam, request_context: RequestContext, ) -> Self { @@ -91,10 +92,8 @@ impl<'service, S> ToolCallContext<'service, S> { } } -pub trait FromToolCallContextPart<'a, S>: Sized { - fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error>; +pub trait FromToolCallContextPart: Sized { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result; } pub trait IntoCallToolResult { @@ -162,16 +161,16 @@ impl IntoCallToolResult for Result { } } -pub trait CallToolHandler<'a, S, A> { - type Fut: Future> + Send + 'a; - fn call(self, context: ToolCallContext<'a, S>) -> Self::Fut; +pub trait CallToolHandler { + fn call( + self, + context: ToolCallContext<'_, S>, + ) -> BoxFuture<'_, Result>; } -pub type DynCallToolHandler = dyn Fn(ToolCallContext<'_, S>) -> BoxFuture<'_, Result> +pub type DynCallToolHandler = dyn for<'s> Fn(ToolCallContext<'s, S>) -> BoxFuture<'s, Result> + Send + Sync; -/// Parameter Extractor -pub struct Parameter(pub K, pub V); /// Parameter Extractor /// @@ -189,83 +188,25 @@ impl JsonSchema for Parameters

{ } } -/// Callee Extractor -pub struct Callee<'a, S>(pub &'a S); - -impl<'a, S> FromToolCallContextPart<'a, S> for CancellationToken { - fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { - Ok((context.request_context.ct.clone(), context)) - } -} - -impl<'a, S> FromToolCallContextPart<'a, S> for Callee<'a, S> { - fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { - Ok((Callee(context.service), context)) +impl FromToolCallContextPart for CancellationToken { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result { + Ok(context.request_context.ct.clone()) } } pub struct ToolName(pub Cow<'static, str>); -impl<'a, S> FromToolCallContextPart<'a, S> for ToolName { - fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { - Ok((Self(context.name.clone()), context)) +impl FromToolCallContextPart for ToolName { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result { + Ok(Self(context.name.clone())) } } -impl<'a, S> FromToolCallContextPart<'a, S> for &'a S { - fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { - Ok((context.service, context)) - } -} - -impl<'a, S, K, V> FromToolCallContextPart<'a, S> for Parameter -where - K: ConstString, - V: DeserializeOwned, -{ - fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { - let arguments = context - .arguments - .as_ref() - .ok_or(crate::Error::invalid_params( - format!("missing parameter {field}", field = K::VALUE), - None, - ))?; - let value = arguments.get(K::VALUE).ok_or(crate::Error::invalid_params( - format!("missing parameter {field}", field = K::VALUE), - None, - ))?; - let value: V = serde_json::from_value(value.clone()).map_err(|e| { - crate::Error::invalid_params( - format!( - "failed to deserialize parameter {field}: {error}", - field = K::VALUE, - error = e - ), - None, - ) - })?; - Ok((Parameter(K::default(), value), context)) - } -} - -impl<'a, S, P> FromToolCallContextPart<'a, S> for Parameters

+impl FromToolCallContextPart for Parameters

where P: DeserializeOwned, { - fn from_tool_call_context_part( - mut context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result { let arguments = context.arguments.take().unwrap_or_default(); let value: P = serde_json::from_value(serde_json::Value::Object(arguments)).map_err(|e| { @@ -274,37 +215,31 @@ where None, ) })?; - Ok((Parameters(value), context)) + Ok(Parameters(value)) } } -impl<'a, S> FromToolCallContextPart<'a, S> for JsonObject { - fn from_tool_call_context_part( - mut context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { +impl FromToolCallContextPart for JsonObject { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result { let object = context.arguments.take().unwrap_or_default(); - Ok((object, context)) + Ok(object) } } -impl<'a, S> FromToolCallContextPart<'a, S> for crate::model::Extensions { - fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { +impl FromToolCallContextPart for crate::model::Extensions { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result { let extensions = context.request_context.extensions.clone(); - Ok((extensions, context)) + Ok(extensions) } } pub struct Extension(pub T); -impl<'a, S, T> FromToolCallContextPart<'a, S> for Extension +impl FromToolCallContextPart for Extension where T: Send + Sync + 'static + Clone, { - fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result { let extension = context .request_context .extensions @@ -316,58 +251,52 @@ where None, ) })?; - Ok((Extension(extension), context)) + Ok(Extension(extension)) } } -impl<'a, S> FromToolCallContextPart<'a, S> for crate::Peer { - fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { +impl FromToolCallContextPart for crate::Peer { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result { let peer = context.request_context.peer.clone(); - Ok((peer, context)) + Ok(peer) } } -impl<'a, S> FromToolCallContextPart<'a, S> for crate::model::Meta { - fn from_tool_call_context_part( - mut context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { +impl FromToolCallContextPart for crate::model::Meta { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result { let mut meta = crate::model::Meta::default(); std::mem::swap(&mut meta, &mut context.request_context.meta); - Ok((meta, context)) + Ok(meta) } } pub struct RequestId(pub crate::model::RequestId); -impl<'a, S> FromToolCallContextPart<'a, S> for RequestId { - fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { - Ok((RequestId(context.request_context.id.clone()), context)) +impl FromToolCallContextPart for RequestId { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result { + Ok(RequestId(context.request_context.id.clone())) } } -impl<'a, S> FromToolCallContextPart<'a, S> for RequestContext { - fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { - Ok((context.request_context.clone(), context)) +impl FromToolCallContextPart for RequestContext { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result { + Ok(context.request_context.clone()) } } impl<'s, S> ToolCallContext<'s, S> { - pub fn invoke(self, h: H) -> H::Fut + pub fn invoke(self, h: H) -> BoxFuture<'s, Result> where - H: CallToolHandler<'s, S, A>, + H: CallToolHandler, { h.call(self) } } - #[allow(clippy::type_complexity)] -pub struct AsyncAdapter(PhantomData<(fn(P) -> Fut, fn(Fut) -> R)>); +pub struct AsyncAdapter(PhantomData fn(Fut) -> R>); pub struct SyncAdapter(PhantomData R>); +// #[allow(clippy::type_complexity)] +pub struct AsyncMethodAdapter(PhantomData R>); +pub struct SyncMethodAdapter(PhantomData R>); macro_rules! impl_for { ($($T: ident)*) => { @@ -382,174 +311,118 @@ macro_rules! impl_for { impl_for!([$($Tn)* $Tn_1] [$($Rest)*]); }; (@impl $($Tn: ident)*) => { - impl<'s, $($Tn,)* S, F, Fut, R> CallToolHandler<'s, S, AsyncAdapter<($($Tn,)*), Fut, R>> for F + impl<$($Tn,)* S, F, R> CallToolHandler> for F where $( - $Tn: FromToolCallContextPart<'s, S> + 's, + $Tn: FromToolCallContextPart , )* - F: FnOnce($($Tn,)*) -> Fut + Send + 's, - Fut: Future + Send + 's, - R: IntoCallToolResult + Send + 's, - S: Send + Sync, + F: FnOnce(&S, $($Tn,)*) -> BoxFuture<'_, R>, + + // Need RTN support here(I guess), https://github.com/rust-lang/rust/pull/138424 + // Fut: Future + Send + 'a, + R: IntoCallToolResult + Send + 'static, + S: Send + Sync + 'static, { - type Fut = IntoCallToolResultFut; - #[allow(unused_variables, non_snake_case)] + #[allow(unused_variables, non_snake_case, unused_mut)] fn call( self, - context: ToolCallContext<'s, S>, - ) -> Self::Fut { + mut context: ToolCallContext<'_, S>, + ) -> BoxFuture<'_, Result>{ $( - let result = $Tn::from_tool_call_context_part(context); - let ($Tn, context) = match result { - Ok((value, context)) => (value, context), - Err(e) => return IntoCallToolResultFut::Ready { - result: std::future::ready(Err(e)), - }, + let result = $Tn::from_tool_call_context_part(&mut context); + let $Tn = match result { + Ok(value) => value, + Err(e) => return std::future::ready(Err(e)).boxed(), }; )* - IntoCallToolResultFut::Pending { - fut: self($($Tn,)*), - _marker: PhantomData - } + let service = context.service; + let fut = self(service, $($Tn,)*); + async move { + let result = fut.await; + result.into_call_tool_result() + }.boxed() } } - impl<'s, $($Tn,)* S, F, R> CallToolHandler<'s, S, SyncAdapter<($($Tn,)*), R>> for F + impl<$($Tn,)* S, F, Fut, R> CallToolHandler> for F where $( - $Tn: FromToolCallContextPart<'s, S> + 's, + $Tn: FromToolCallContextPart , )* - F: FnOnce($($Tn,)*) -> R + Send + 's, - R: IntoCallToolResult + Send + 's, + F: FnOnce($($Tn,)*) -> Fut + Send + , + Fut: Future + Send + 'static, + R: IntoCallToolResult + Send + 'static, S: Send + Sync, { - type Fut = Ready>; - #[allow(unused_variables, non_snake_case)] + #[allow(unused_variables, non_snake_case, unused_mut)] fn call( self, - context: ToolCallContext<'s, S>, - ) -> Self::Fut { + mut context: ToolCallContext, + ) -> BoxFuture<'static, Result>{ $( - let result = $Tn::from_tool_call_context_part(context); - let ($Tn, context) = match result { - Ok((value, context)) => (value, context), - Err(e) => return std::future::ready(Err(e)), + let result = $Tn::from_tool_call_context_part(&mut context); + let $Tn = match result { + Ok(value) => value, + Err(e) => return std::future::ready(Err(e)).boxed(), }; )* - std::future::ready(self($($Tn,)*).into_call_tool_result()) + let fut = self($($Tn,)*); + async move { + let result = fut.await; + result.into_call_tool_result() + }.boxed() } } - }; -} -impl_for!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); -pub struct ToolBoxItem { - #[allow(clippy::type_complexity)] - pub call: Box>, - pub attr: crate::model::Tool, -} -impl ToolBoxItem { - pub fn new(attr: crate::model::Tool, call: C) -> Self - where - C: Fn(ToolCallContext<'_, S>) -> BoxFuture<'_, Result> - + Send - + Sync - + 'static, - { - Self { - call: Box::new(call), - attr, - } - } - pub fn name(&self) -> &str { - &self.attr.name - } -} - -#[derive(Default)] -pub struct ToolBox { - #[allow(clippy::type_complexity)] - pub map: std::collections::HashMap, ToolBoxItem>, -} - -impl ToolBox { - pub fn new() -> Self { - Self { - map: std::collections::HashMap::new(), + impl<$($Tn,)* S, F, R> CallToolHandler> for F + where + $( + $Tn: FromToolCallContextPart + , + )* + F: FnOnce(&S, $($Tn,)*) -> R + Send + , + R: IntoCallToolResult + Send + , + S: Send + Sync, + { + #[allow(unused_variables, non_snake_case, unused_mut)] + fn call( + self, + mut context: ToolCallContext, + ) -> BoxFuture<'static, Result> { + $( + let result = $Tn::from_tool_call_context_part(&mut context); + let $Tn = match result { + Ok(value) => value, + Err(e) => return std::future::ready(Err(e)).boxed(), + }; + )* + std::future::ready(self(context.service, $($Tn,)*).into_call_tool_result()).boxed() + } } - } - pub fn add(&mut self, item: ToolBoxItem) { - self.map.insert(item.attr.name.clone(), item); - } - pub fn remove(&mut self, name: &str) { - self.map.remove(name); - } - - pub async fn call( - &self, - context: ToolCallContext<'_, S>, - ) -> Result { - let item = self - .map - .get(context.name()) - .ok_or_else(|| crate::Error::invalid_params("tool not found", None))?; - (item.call)(context).await - } - - pub fn list(&self) -> Vec { - self.map.values().map(|item| item.attr.clone()).collect() - } -} - -#[cfg(feature = "macros")] -#[cfg_attr(docsrs, doc(cfg(feature = "macros")))] -#[macro_export] -macro_rules! tool_box { - (@pin_add $callee: ident, $attr: expr, $f: expr) => { - $callee.add(ToolBoxItem::new($attr, |context| Box::pin($f(context)))); - }; - ($server: ident { $($tool: ident),* $(,)?} ) => { - $crate::tool_box!($server { $($tool),* } tool_box); - }; - ($server: ident { $($tool: ident),* $(,)?} $tool_box: ident) => { - fn $tool_box() -> &'static $crate::handler::server::tool::ToolBox<$server> { - use $crate::handler::server::tool::{ToolBox, ToolBoxItem}; - static TOOL_BOX: std::sync::OnceLock> = std::sync::OnceLock::new(); - TOOL_BOX.get_or_init(|| { - let mut tool_box = ToolBox::new(); - $crate::paste!{ - $( - $crate::tool_box!(@pin_add tool_box, $server::[< $tool _tool_attr>](), $server::[<$tool _tool_call>]); - )* - } - tool_box - }) + impl<$($Tn,)* S, F, R> CallToolHandler> for F + where + $( + $Tn: FromToolCallContextPart + , + )* + F: FnOnce($($Tn,)*) -> R + Send + , + R: IntoCallToolResult + Send + , + S: Send + Sync, + { + #[allow(unused_variables, non_snake_case, unused_mut)] + fn call( + self, + mut context: ToolCallContext, + ) -> BoxFuture<'static, Result> { + $( + let result = $Tn::from_tool_call_context_part(&mut context); + let $Tn = match result { + Ok(value) => value, + Err(e) => return std::future::ready(Err(e)).boxed(), + }; + )* + std::future::ready(self($($Tn,)*).into_call_tool_result()).boxed() + } } }; - (@derive) => { - $crate::tool_box!(@derive tool_box); - }; - - (@derive $tool_box:ident) => { - async fn list_tools( - &self, - _: Option<$crate::model::PaginatedRequestParam>, - _: $crate::service::RequestContext<$crate::service::RoleServer>, - ) -> Result<$crate::model::ListToolsResult, $crate::Error> { - Ok($crate::model::ListToolsResult { - next_cursor: None, - tools: Self::tool_box().list(), - }) - } - - async fn call_tool( - &self, - call_tool_request_param: $crate::model::CallToolRequestParam, - context: $crate::service::RequestContext<$crate::service::RoleServer>, - ) -> Result<$crate::model::CallToolResult, $crate::Error> { - let context = $crate::handler::server::tool::ToolCallContext::new(self, call_tool_request_param, context); - Self::$tool_box().call(context).await - } - } } +impl_for!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); \ No newline at end of file diff --git a/crates/rmcp/src/lib.rs b/crates/rmcp/src/lib.rs index e4096b48..7b5c655c 100644 --- a/crates/rmcp/src/lib.rs +++ b/crates/rmcp/src/lib.rs @@ -15,7 +15,7 @@ //! as Claude Desktop or the Cursor IDE. //! //! For example, to implement a server that has a tool that can count, you would -//! make an object for that tool and add an implementation with the `#[tool(tool_box)]` macro: +//! make an object for that tool and add an implementation with the `#[tool_router]` macro: //! //! ```rust //! use std::sync::Arc; @@ -27,7 +27,7 @@ //! counter: Arc>, //! } //! -//! #[tool(tool_box)] +//! #[tool_router] //! impl Counter { //! fn new() -> Self { //! Self { @@ -120,7 +120,7 @@ pub mod transport; pub use paste::paste; #[cfg(all(feature = "macros", feature = "server"))] #[cfg_attr(docsrs, doc(cfg(all(feature = "macros", feature = "server"))))] -pub use rmcp_macros::tool; +pub use rmcp_macros::*; #[cfg(all(feature = "macros", feature = "server"))] #[cfg_attr(docsrs, doc(cfg(all(feature = "macros", feature = "server"))))] pub use schemars; diff --git a/crates/rmcp/src/model.rs b/crates/rmcp/src/model.rs index c6407ac7..addc0bf0 100644 --- a/crates/rmcp/src/model.rs +++ b/crates/rmcp/src/model.rs @@ -679,6 +679,17 @@ macro_rules! paginated_result { pub next_cursor: Option, pub $i_item: $t_item, } + + impl $t { + pub fn with_all_items( + items: $t_item, + ) -> Self { + Self { + next_cursor: None, + $i_item: items, + } + } + } }; } diff --git a/crates/rmcp/tests/common/calculator.rs b/crates/rmcp/tests/common/calculator.rs index e179f258..4f4fccee 100644 --- a/crates/rmcp/tests/common/calculator.rs +++ b/crates/rmcp/tests/common/calculator.rs @@ -1,7 +1,9 @@ +#![allow(dead_code)] use rmcp::{ ServerHandler, + handler::server::{router::tool::ToolRouter, tool::Parameters}, model::{ServerCapabilities, ServerInfo}, - schemars, tool, + schemars, tool, tool_router, }; #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] pub struct SumRequest { @@ -9,30 +11,46 @@ pub struct SumRequest { pub a: i32, pub b: i32, } -#[derive(Debug, Clone, Default)] -pub struct Calculator; -#[tool(tool_box)] + +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +pub struct SubRequest { + #[schemars(description = "the left hand side number")] + pub a: i32, + #[schemars(description = "the right hand side number")] + pub b: i32, +} +#[derive(Debug, Clone)] +pub struct Calculator { + tool_router: ToolRouter, +} + +impl Calculator { + pub fn new() -> Self { + Self { + tool_router: Self::tool_router(), + } + } +} + +impl Default for Calculator { + fn default() -> Self { + Self::new() + } +} + +#[tool_router] impl Calculator { #[tool(description = "Calculate the sum of two numbers")] - fn sum(&self, #[tool(aggr)] SumRequest { a, b }: SumRequest) -> String { + fn sum(&self, Parameters(SumRequest { a, b }): Parameters) -> String { (a + b).to_string() } #[tool(description = "Calculate the sub of two numbers")] - fn sub( - &self, - #[tool(param)] - #[schemars(description = "the left hand side number")] - a: i32, - #[tool(param)] - #[schemars(description = "the right hand side number")] - b: i32, - ) -> String { + fn sub(&self, Parameters(SubRequest { a, b }): Parameters) -> String { (a - b).to_string() } } -#[tool(tool_box)] impl ServerHandler for Calculator { fn get_info(&self) -> ServerInfo { ServerInfo { diff --git a/crates/rmcp/tests/test_complex_schema.rs b/crates/rmcp/tests/test_complex_schema.rs index b9370fce..6a38a178 100644 --- a/crates/rmcp/tests/test_complex_schema.rs +++ b/crates/rmcp/tests/test_complex_schema.rs @@ -1,4 +1,6 @@ -use rmcp::{Error as McpError, model::*, schemars, tool}; +use rmcp::{ + Error as McpError, handler::server::tool::Parameters, model::*, schemars, tool_router, tool, +}; use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)] @@ -24,7 +26,7 @@ pub struct ChatRequest { #[derive(Clone, Default)] pub struct Demo; -#[tool(tool_box)] +#[tool_router] impl Demo { pub fn new() -> Self { Self @@ -33,9 +35,9 @@ impl Demo { #[tool(description = "LLM")] async fn chat( &self, - #[tool(aggr)] chat_request: ChatRequest, + chat_request: Parameters, ) -> Result { - let content = Content::json(chat_request)?; + let content = Content::json(chat_request.0)?; Ok(CallToolResult::success(vec![content])) } } diff --git a/crates/rmcp/tests/test_tool_macro_annotations.rs b/crates/rmcp/tests/test_tool_macro_annotations.rs index 0b57dffd..00368af6 100644 --- a/crates/rmcp/tests/test_tool_macro_annotations.rs +++ b/crates/rmcp/tests/test_tool_macro_annotations.rs @@ -1,9 +1,11 @@ #[cfg(test)] mod tests { - use rmcp::{ServerHandler, tool}; + use rmcp::{ServerHandler, handler::server::router::tool::ToolRouter, tool}; #[derive(Debug, Clone, Default)] - pub struct AnnotatedServer {} + pub struct AnnotatedServer { + tool_router: ToolRouter, + } impl AnnotatedServer { // Tool with inline comments for documentation @@ -11,12 +13,9 @@ mod tests { /// This is used to test tool annotations #[tool( name = "direct-annotated-tool", - annotations = { - title: "Annotated Tool", - readOnlyHint: true - } + annotations(title = "Annotated Tool", read_only_hint = true) )] - pub async fn direct_annotated_tool(&self, #[tool(param)] input: String) -> String { + pub async fn direct_annotated_tool(&self, input: String) -> String { format!("Direct: {}", input) } } @@ -28,10 +27,7 @@ mod tests { context: rmcp::service::RequestContext, ) -> Result { let tcc = rmcp::handler::server::tool::ToolCallContext::new(self, request, context); - match tcc.name() { - "direct-annotated-tool" => Self::direct_annotated_tool_tool_call(tcc).await, - _ => Err(rmcp::Error::invalid_params("method not found", None)), - } + self.tool_router.call(tcc).await } } diff --git a/crates/rmcp/tests/test_tool_macros.rs b/crates/rmcp/tests/test_tool_macros.rs index 84bcac93..33cdaac5 100644 --- a/crates/rmcp/tests/test_tool_macros.rs +++ b/crates/rmcp/tests/test_tool_macros.rs @@ -1,12 +1,15 @@ //cargo test --test test_tool_macros --features "client server" - +#![allow(dead_code)] use std::sync::Arc; use rmcp::{ ClientHandler, ServerHandler, ServiceExt, - handler::server::tool::ToolCallContext, - model::{CallToolRequestParam, ClientInfo}, - tool, + handler::server::{ + router::tool::ToolRouter, + tool::{Parameters, ToolCallContext}, + }, + model::{CallToolRequestParam, ClientInfo, ListToolsResult}, + tool_router, tool, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -24,30 +27,39 @@ impl ServerHandler for Server { context: rmcp::service::RequestContext, ) -> Result { let tcc = ToolCallContext::new(self, request, context); - match tcc.name() { - "get-weather" => Self::get_weather_tool_call(tcc).await, - _ => Err(rmcp::Error::invalid_params("method not found", None)), - } + self.tool_router.call(tcc).await } } -#[derive(Debug, Clone, Default)] -pub struct Server {} +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct Server { + tool_router: ToolRouter, +} +impl Default for Server { + fn default() -> Self { + Self::new() + } +} + +#[tool_router] impl Server { + pub fn new() -> Self { + Self { + tool_router: Self::tool_router(), + } + } + /// This tool is used to get the weather of a city. - #[tool(name = "get-weather", description = "Get the weather of a city.", vis = )] - pub async fn get_weather(&self, #[tool(param)] city: String) -> String { + #[tool(name = "get-weather", description = "Get the weather of a city.")] + pub async fn get_weather(&self, city: Parameters) -> String { drop(city); "rain".to_string() } - #[tool(description = "Empty Parameter")] - async fn empty_param(&self) {} - #[tool(description = "Optional Parameter")] - async fn optional_param(&self, #[tool(param)] city: Option) -> String { - city.unwrap_or_default() - } + #[tool] + async fn empty_param(&self) {} } // define generic service trait @@ -68,13 +80,15 @@ impl DataService for MockDataService { #[derive(Debug, Clone)] pub struct GenericServer { data_service: Arc, + tool_router: ToolRouter, } -#[tool(tool_box)] +#[tool_router] impl GenericServer { pub fn new(data_service: DS) -> Self { Self { data_service: Arc::new(data_service), + tool_router: Self::tool_router(), } } @@ -83,16 +97,37 @@ impl GenericServer { self.data_service.get_data() } } -#[tool(tool_box)] -impl ServerHandler for GenericServer {} + +impl ServerHandler for GenericServer { + async fn call_tool( + &self, + request: CallToolRequestParam, + context: rmcp::service::RequestContext, + ) -> Result { + let tcc = ToolCallContext::new(self, request, context); + self.tool_router.call(tcc).await + } + async fn list_tools( + &self, + _request: Option, + _context: rmcp::service::RequestContext, + ) -> Result { + Ok(ListToolsResult::with_all_items(self.tool_router.list_all())) + } +} #[tokio::test] async fn test_tool_macros() { - let server = Server::default(); + let server = Server::new(); let _attr = Server::get_weather_tool_attr(); - let _get_weather_call_fn = Server::get_weather_tool_call; + let _get_weather_tool_attr_fn = Server::get_weather_tool_attr; let _get_weather_fn = Server::get_weather; - server.get_weather("harbin".into()).await; + server + .get_weather(Parameters(GetWeatherRequest { + city: "Harbin".into(), + date: "Yesterday".into(), + })) + .await; } #[tokio::test] @@ -108,14 +143,14 @@ async fn test_tool_macros_with_generics() { let mock_service = MockDataService; let server = GenericServer::new(mock_service); let _attr = GenericServer::::get_data_tool_attr(); - let _get_data_call_fn = GenericServer::::get_data_tool_call; + let _get_data_call_fn = GenericServer::::get_data; let _get_data_fn = GenericServer::::get_data; assert_eq!(server.get_data().await, "mock data"); } #[tokio::test] async fn test_tool_macros_with_optional_param() { - let _attr = Server::optional_param_tool_attr(); + let _attr = Server::get_weather_tool_attr(); // println!("{_attr:?}"); let attr_type = _attr .input_schema @@ -147,20 +182,40 @@ pub struct OptionalI64TestSchema { } // Dummy struct to host the test tool method -#[derive(Debug, Clone, Default)] -pub struct OptionalSchemaTester {} +#[derive(Debug, Clone)] +pub struct OptionalSchemaTester { + router: ToolRouter, +} + +impl Default for OptionalSchemaTester { + fn default() -> Self { + Self::new() + } +} + +impl OptionalSchemaTester { + pub fn new() -> Self { + Self { + router: Self::tool_router(), + } + } +} +#[tool_router] impl OptionalSchemaTester { // Dummy tool function using the test schema as an aggregated parameter #[tool(description = "A tool to test optional schema generation")] - async fn test_optional_aggr(&self, #[tool(aggr)] _req: OptionalFieldTestSchema) { + async fn test_optional(&self, _req: Parameters) { // Implementation doesn't matter for schema testing // Return type changed to () to satisfy IntoCallToolResult } // Tool function to test optional i64 handling #[tool(description = "A tool to test optional i64 schema generation")] - async fn test_optional_i64_aggr(&self, #[tool(aggr)] req: OptionalI64TestSchema) -> String { + async fn test_optional_i64( + &self, + Parameters(req): Parameters, + ) -> String { match req.count { Some(c) => format!("Received count: {}", c), None => "Received null count".to_string(), @@ -176,11 +231,7 @@ impl ServerHandler for OptionalSchemaTester { context: rmcp::service::RequestContext, ) -> Result { let tcc = ToolCallContext::new(self, request, context); - match tcc.name() { - "test_optional_aggr" => Self::test_optional_aggr_tool_call(tcc).await, - "test_optional_i64_aggr" => Self::test_optional_i64_aggr_tool_call(tcc).await, - _ => Err(rmcp::Error::invalid_params("method not found", None)), - } + self.router.call(tcc).await } } @@ -189,7 +240,7 @@ fn test_optional_field_schema_generation_via_macro() { // tests https://github.com/modelcontextprotocol/rust-sdk/issues/135 // Get the attributes generated by the #[tool] macro helper - let tool_attr = OptionalSchemaTester::test_optional_aggr_tool_attr(); + let tool_attr = OptionalSchemaTester::test_optional_tool_attr(); // Print the actual generated schema for debugging println!( @@ -257,7 +308,7 @@ async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { let (server_transport, client_transport) = tokio::io::duplex(4096); // Server setup - let server = OptionalSchemaTester::default(); + let server = OptionalSchemaTester::new(); let server_handle = tokio::spawn(async move { server.serve(server_transport).await?.waiting().await?; anyhow::Ok(()) diff --git a/crates/rmcp/tests/test_tool_routers.rs b/crates/rmcp/tests/test_tool_routers.rs new file mode 100644 index 00000000..e08437c1 --- /dev/null +++ b/crates/rmcp/tests/test_tool_routers.rs @@ -0,0 +1,74 @@ +use std::collections::HashMap; + +use futures::future::BoxFuture; +use rmcp::{ + ServerHandler, + handler::server::{ + router::tool::ToolRouter, + tool::{CallToolHandler, Parameters}, + }, +}; + +#[derive(Debug, Default)] +pub struct TestHandler { + pub _marker: std::marker::PhantomData, +} + +impl ServerHandler for TestHandler {} +#[derive(Debug, schemars::JsonSchema, serde::Deserialize, serde::Serialize)] +pub struct Request { + pub fields: HashMap, +} + +#[derive(Debug, schemars::JsonSchema, serde::Deserialize, serde::Serialize)] +pub struct Sum { + pub a: i32, + pub b: i32, +} + +#[rmcp::tool_router(router = test_router_1)] +impl TestHandler { + #[rmcp::tool] + async fn async_method(&self, Parameters(Request { fields }): Parameters) { + drop(fields) + } +} + +#[rmcp::tool_router(router = test_router_2)] +impl TestHandler { + #[rmcp::tool] + fn sync_method(&self, Parameters(Request { fields }): Parameters) { + drop(fields) + } +} + +#[rmcp::tool] +async fn async_function( + _callee: &TestHandler, + Parameters(Request { fields }): Parameters, +) { + drop(fields) +} + +#[rmcp::tool] +fn async_function2(_callee: &TestHandler) -> BoxFuture<'_, ()> { + Box::pin(async move {}) +} + +#[test] +fn test_tool_router() { + let test_tool_router: ToolRouter> = ToolRouter::>::new() + .with_route(async_function_tool_attr(), async_function) + .with_route(async_function2_tool_attr(), async_function2) + + TestHandler::<()>::test_router_1() + + TestHandler::<()>::test_router_2(); + let tools = test_tool_router.list_all(); + assert_eq!(tools.len(), 4); + assert_handler(TestHandler::<()>::async_method); +} + +fn assert_handler(_handler: H) +where + H: CallToolHandler, +{ +} diff --git a/crates/rmcp/tests/test_with_js.rs b/crates/rmcp/tests/test_with_js.rs index b00752ad..3f2761cd 100644 --- a/crates/rmcp/tests/test_with_js.rs +++ b/crates/rmcp/tests/test_with_js.rs @@ -96,7 +96,7 @@ async fn test_with_js_streamable_http_client() -> anyhow::Result<()> { let service: StreamableHttpService = StreamableHttpService::new( - || Ok(Calculator), + || Ok(Calculator::new()), Default::default(), StreamableHttpServerConfig { stateful_mode: true, diff --git a/docs/readme/README.zh-cn.md b/docs/readme/README.zh-cn.md index cc730c1d..6e6ff91d 100644 --- a/docs/readme/README.zh-cn.md +++ b/docs/readme/README.zh-cn.md @@ -97,7 +97,7 @@ pub struct SumRequest { pub struct Calculator; // create a static toolbox to store the tool attributes -#[tool(tool_box)] +#[tool_router] impl Calculator { // async function #[tool(description = "Calculate the sum of two numbers")] @@ -122,7 +122,7 @@ impl Calculator { } // impl call_tool and list_tool by querying static toolbox -#[tool(tool_box)] +#[tool_router] impl ServerHandler for Calculator { fn get_info(&self) -> ServerInfo { ServerInfo { diff --git a/examples/servers/src/common/calculator.rs b/examples/servers/src/common/calculator.rs index 68beecc0..84a1cef9 100644 --- a/examples/servers/src/common/calculator.rs +++ b/examples/servers/src/common/calculator.rs @@ -1,8 +1,10 @@ +#![allow(dead_code)] + use rmcp::{ ServerHandler, - handler::server::wrapper::Json, - model::{ServerCapabilities, ServerInfo}, - schemars, tool, + handler::server::{router::tool::ToolRouter, tool::Parameters, wrapper::Json}, + model::{ListToolsResult, ServerCapabilities, ServerInfo}, + schemars, tool, tool_router, }; #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] @@ -11,30 +13,33 @@ pub struct SumRequest { pub a: i32, pub b: i32, } + +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +pub struct SubRequest { + #[schemars(description = "the left hand side number")] + pub a: i32, + #[schemars(description = "the right hand side number")] + pub b: i32, +} + #[derive(Debug, Clone)] -pub struct Calculator; -#[tool(tool_box)] +pub struct Calculator { + tool_router: ToolRouter, +} + +#[tool_router] impl Calculator { #[tool(description = "Calculate the sum of two numbers")] - fn sum(&self, #[tool(aggr)] SumRequest { a, b }: SumRequest) -> String { + fn sum(&self, Parameters(SumRequest { a, b }): Parameters) -> String { (a + b).to_string() } #[tool(description = "Calculate the difference of two numbers")] - fn sub( - &self, - #[tool(param)] - #[schemars(description = "the left hand side number")] - a: i32, - #[tool(param)] - #[schemars(description = "the right hand side number")] - b: i32, - ) -> Json { + fn sub(&self, Parameters(SubRequest { a, b }): Parameters) -> Json { Json(a - b) } } -#[tool(tool_box)] impl ServerHandler for Calculator { fn get_info(&self) -> ServerInfo { ServerInfo { @@ -43,4 +48,22 @@ impl ServerHandler for Calculator { ..Default::default() } } + + async fn call_tool( + &self, + request: rmcp::model::CallToolRequestParam, + context: rmcp::service::RequestContext, + ) -> Result { + let tcc = rmcp::handler::server::tool::ToolCallContext::new(self, request, context); + self.tool_router.call(tcc).await + } + + async fn list_tools( + &self, + _request: Option, + _context: rmcp::service::RequestContext, + ) -> Result { + let items = self.tool_router.list_all(); + Ok(ListToolsResult::with_all_items(items)) + } } diff --git a/examples/servers/src/common/counter.rs b/examples/servers/src/common/counter.rs index 12aa8a4a..e4db5901 100644 --- a/examples/servers/src/common/counter.rs +++ b/examples/servers/src/common/counter.rs @@ -1,8 +1,9 @@ +#![allow(dead_code)] use std::sync::Arc; use rmcp::{ - Error as McpError, RoleServer, ServerHandler, const_string, model::*, schemars, - service::RequestContext, tool, + Error as McpError, RoleServer, ServerHandler, const_string, handler::server::tool::Parameters, + model::*, schemars, service::RequestContext, tool, tool_router, }; use serde_json::json; use tokio::sync::Mutex; @@ -18,7 +19,7 @@ pub struct Counter { counter: Arc>, } -#[tool(tool_box)] +#[tool_router] impl Counter { #[allow(dead_code)] pub fn new() -> Self { @@ -63,19 +64,16 @@ impl Counter { } #[tool(description = "Repeat what you say")] - fn echo( - &self, - #[tool(param)] - #[schemars(description = "Repeat what you say")] - saying: String, - ) -> Result { - Ok(CallToolResult::success(vec![Content::text(saying)])) + fn echo(&self, Parameters(object): Parameters) -> Result { + Ok(CallToolResult::success(vec![Content::text( + serde_json::Value::Object(object).to_string(), + )])) } #[tool(description = "Calculate the sum of two numbers")] fn sum( &self, - #[tool(aggr)] StructRequest { a, b }: StructRequest, + Parameters(StructRequest { a, b }): Parameters, ) -> Result { Ok(CallToolResult::success(vec![Content::text( (a + b).to_string(), @@ -83,7 +81,6 @@ impl Counter { } } const_string!(Echo = "echo"); -#[tool(tool_box)] impl ServerHandler for Counter { fn get_info(&self) -> ServerInfo { ServerInfo { diff --git a/examples/servers/src/common/generic_service.rs b/examples/servers/src/common/generic_service.rs index fc1d00ed..b5d1bd77 100644 --- a/examples/servers/src/common/generic_service.rs +++ b/examples/servers/src/common/generic_service.rs @@ -2,8 +2,9 @@ use std::sync::Arc; use rmcp::{ ServerHandler, + handler::server::{router::tool::ToolRouter, tool::Parameters}, model::{ServerCapabilities, ServerInfo}, - schemars, tool, + schemars, tool, tool_router, }; #[allow(dead_code)] @@ -40,14 +41,21 @@ impl DataService for MemoryDataService { pub struct GenericService { #[allow(dead_code)] data_service: Arc, + tool_router: ToolRouter, } -#[tool(tool_box)] +#[derive(Debug, schemars::JsonSchema, serde::Deserialize, serde::Serialize)] +pub struct SetDataRequest { + pub data: String, +} + +#[tool_router] impl GenericService { #[allow(dead_code)] pub fn new(data_service: DS) -> Self { Self { data_service: Arc::new(data_service), + tool_router: Self::tool_router(), } } @@ -57,13 +65,15 @@ impl GenericService { } #[tool(description = "set memory to service")] - pub async fn set_data(&self, #[tool(param)] data: String) -> String { + pub async fn set_data( + &self, + Parameters(SetDataRequest { data }): Parameters, + ) -> String { let new_data = data.clone(); format!("Current memory: {}", new_data) } } -#[tool(tool_box)] impl ServerHandler for GenericService { fn get_info(&self) -> ServerInfo { ServerInfo { @@ -72,4 +82,23 @@ impl ServerHandler for GenericService { ..Default::default() } } + + async fn call_tool( + &self, + request: rmcp::model::CallToolRequestParam, + context: rmcp::service::RequestContext, + ) -> Result { + let tcc = rmcp::handler::server::tool::ToolCallContext::new(self, request, context); + self.tool_router.call(tcc).await + } + + async fn list_tools( + &self, + _request: Option, + _context: rmcp::service::RequestContext, + ) -> Result { + Ok(rmcp::model::ListToolsResult::with_all_items( + self.tool_router.list_all(), + )) + } } From ce87084fcdfd0a14cb23f0d5ca3872242b6c32e3 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Mon, 16 Jun 2025 19:48:28 +0800 Subject: [PATCH 02/12] fix: fix fmt and build error --- crates/rmcp-macros/src/lib.rs | 2 +- crates/rmcp-macros/src/tool.rs | 3 +- crates/rmcp/Cargo.toml | 1 - crates/rmcp/src/handler/server.rs | 2 +- crates/rmcp/src/handler/server/router.rs | 3 +- crates/rmcp/src/handler/server/router/tool.rs | 106 ++---------------- crates/rmcp/src/handler/server/tool.rs | 2 +- crates/rmcp/tests/test_complex_schema.rs | 2 +- crates/rmcp/tests/test_tool_macros.rs | 2 +- examples/servers/src/common/calculator.rs | 6 + examples/servers/src/common/counter.rs | 26 ++++- examples/transport/src/common/calculator.rs | 68 ++++++++--- examples/transport/src/http_upgrade.rs | 2 +- examples/transport/src/tcp.rs | 2 +- examples/transport/src/unix_socket.rs | 2 +- examples/transport/src/websocket.rs | 2 +- examples/wasi/src/calculator.rs | 67 +++++++---- examples/wasi/src/lib.rs | 5 +- 18 files changed, 151 insertions(+), 152 deletions(-) diff --git a/crates/rmcp-macros/src/lib.rs b/crates/rmcp-macros/src/lib.rs index b2d6a7b0..f1e2a9b9 100644 --- a/crates/rmcp-macros/src/lib.rs +++ b/crates/rmcp-macros/src/lib.rs @@ -22,4 +22,4 @@ pub fn tool_router(attr: TokenStream, input: TokenStream) -> TokenStream { tool_router::tool_router(attr.into(), input.into()) .unwrap_or_else(|err| err.to_compile_error()) .into() -} \ No newline at end of file +} diff --git a/crates/rmcp-macros/src/tool.rs b/crates/rmcp-macros/src/tool.rs index d32f8189..c27e5271 100644 --- a/crates/rmcp-macros/src/tool.rs +++ b/crates/rmcp-macros/src/tool.rs @@ -1,5 +1,4 @@ -use darling::FromMeta; -use darling::ast::NestedMeta; +use darling::{FromMeta, ast::NestedMeta}; use proc_macro2::TokenStream; use quote::{ToTokens, format_ident, quote}; use syn::{Expr, Ident, ImplItemFn, ReturnType}; diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index cb144376..999d4ee5 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -63,7 +63,6 @@ http-body-util = { version = "0.1", optional = true } bytes = { version = "1", optional = true } # macro rmcp-macros = { version = "0.1", workspace = true, optional = true } -inventory = "0.3" [target.'cfg(not(all(target_family = "wasm", target_os = "unknown")))'.dependencies] chrono = { version = "0.4.38", features = ["serde"] } diff --git a/crates/rmcp/src/handler/server.rs b/crates/rmcp/src/handler/server.rs index e532927e..9773bd98 100644 --- a/crates/rmcp/src/handler/server.rs +++ b/crates/rmcp/src/handler/server.rs @@ -5,9 +5,9 @@ use crate::{ }; mod resource; +pub mod router; pub mod tool; pub mod wrapper; -pub mod router; impl Service for H { async fn handle_request( &self, diff --git a/crates/rmcp/src/handler/server/router.rs b/crates/rmcp/src/handler/server/router.rs index 926c7eb7..ea972f59 100644 --- a/crates/rmcp/src/handler/server/router.rs +++ b/crates/rmcp/src/handler/server/router.rs @@ -2,14 +2,13 @@ use std::sync::Arc; use tool::{IntoToolRoute, ToolRoute}; +use super::ServerHandler; use crate::{ RoleServer, Service, model::{ClientRequest, ListToolsResult, ServerResult}, service::NotificationContext, }; -use super::ServerHandler; - pub mod tool; pub struct Router { diff --git a/crates/rmcp/src/handler/server/router/tool.rs b/crates/rmcp/src/handler/server/router/tool.rs index ed2602d7..97db401b 100644 --- a/crates/rmcp/src/handler/server/router/tool.rs +++ b/crates/rmcp/src/handler/server/router/tool.rs @@ -1,107 +1,15 @@ -use std::any::{Any, TypeId}; -use std::borrow::Cow; -use std::collections::HashMap; -use std::sync::Arc; +use std::{borrow::Cow, sync::Arc}; -use futures::FutureExt; -use futures::future::BoxFuture; +use futures::{FutureExt, future::BoxFuture}; use schemars::JsonSchema; -use crate::model::{CallToolResult, Tool, ToolAnnotations}; - -use crate::handler::server::tool::{ - CallToolHandler, DynCallToolHandler, ToolCallContext, schema_for_type, +use crate::{ + handler::server::tool::{ + CallToolHandler, DynCallToolHandler, ToolCallContext, schema_for_type, + }, + model::{CallToolResult, Tool, ToolAnnotations}, }; -inventory::collect!(ToolRouteWithType); - -#[derive(Debug, Default)] -pub struct GlobalStaticRouters { - pub routers: - std::sync::OnceLock>>>, -} - -impl GlobalStaticRouters { - pub fn global() -> &'static Self { - static GLOBAL: GlobalStaticRouters = GlobalStaticRouters { - routers: std::sync::OnceLock::new(), - }; - &GLOBAL - } - pub async fn set(router: Arc>) -> Result<(), String> { - let routers = Self::global().routers.get_or_init(Default::default); - let mut routers_wg = routers.write().await; - if routers_wg.insert(TypeId::of::(), router).is_some() { - return Err("Router already exists".to_string()); - } - Ok(()) - } - pub async fn get() -> Arc> { - let routers = Self::global().routers.get_or_init(Default::default); - let routers_rg = routers.read().await; - if let Some(router) = routers_rg.get(&TypeId::of::()) { - return router - .clone() - .downcast::>() - .expect("Failed to downcast"); - } - { - drop(routers_rg); - } - let mut routers = routers.write().await; - match routers.entry(TypeId::of::()) { - std::collections::hash_map::Entry::Occupied(occupied) => occupied - .get() - .clone() - .downcast::>() - .expect("Failed to downcast"), - std::collections::hash_map::Entry::Vacant(vacant) => { - let mut router = ToolRouter::::default(); - for route in inventory::iter:: - .into_iter() - .filter(|r| r.type_id == TypeId::of::()) - { - if let Some(route) = route.downcast::() { - router.add_route(route.clone()); - } - } - let mut_ref = vacant.insert(Arc::new(router)); - mut_ref - .downcast_ref() - .cloned() - .expect("Failed to downcast after insert") - } - } - } -} - -pub struct ToolRouteWithType { - type_id: TypeId, - route: Box, -} - -impl ToolRouteWithType { - pub fn downcast(&self) -> Option<&ToolRoute> { - if self.type_id == TypeId::of::() { - self.route.downcast_ref::>() - } else { - None - } - } - pub fn from_tool_route(route: ToolRoute) -> Self { - Self { - type_id: TypeId::of::(), - route: Box::new(route), - } - } -} - -impl From> for ToolRouteWithType { - fn from(value: ToolRoute) -> Self { - Self::from_tool_route(value) - } -} - pub struct ToolRoute { #[allow(clippy::type_complexity)] pub call: Arc>, diff --git a/crates/rmcp/src/handler/server/tool.rs b/crates/rmcp/src/handler/server/tool.rs index 5dbf85d4..03206dc4 100644 --- a/crates/rmcp/src/handler/server/tool.rs +++ b/crates/rmcp/src/handler/server/tool.rs @@ -425,4 +425,4 @@ macro_rules! impl_for { } }; } -impl_for!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); \ No newline at end of file +impl_for!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); diff --git a/crates/rmcp/tests/test_complex_schema.rs b/crates/rmcp/tests/test_complex_schema.rs index 6a38a178..c199e7e6 100644 --- a/crates/rmcp/tests/test_complex_schema.rs +++ b/crates/rmcp/tests/test_complex_schema.rs @@ -1,5 +1,5 @@ use rmcp::{ - Error as McpError, handler::server::tool::Parameters, model::*, schemars, tool_router, tool, + Error as McpError, handler::server::tool::Parameters, model::*, schemars, tool, tool_router, }; use serde::{Deserialize, Serialize}; diff --git a/crates/rmcp/tests/test_tool_macros.rs b/crates/rmcp/tests/test_tool_macros.rs index 33cdaac5..606e5ae6 100644 --- a/crates/rmcp/tests/test_tool_macros.rs +++ b/crates/rmcp/tests/test_tool_macros.rs @@ -9,7 +9,7 @@ use rmcp::{ tool::{Parameters, ToolCallContext}, }, model::{CallToolRequestParam, ClientInfo, ListToolsResult}, - tool_router, tool, + tool, tool_router, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; diff --git a/examples/servers/src/common/calculator.rs b/examples/servers/src/common/calculator.rs index 84a1cef9..3629d133 100644 --- a/examples/servers/src/common/calculator.rs +++ b/examples/servers/src/common/calculator.rs @@ -29,6 +29,12 @@ pub struct Calculator { #[tool_router] impl Calculator { + pub fn new() -> Self { + Self { + tool_router: Self::tool_router(), + } + } + #[tool(description = "Calculate the sum of two numbers")] fn sum(&self, Parameters(SumRequest { a, b }): Parameters) -> String { (a + b).to_string() diff --git a/examples/servers/src/common/counter.rs b/examples/servers/src/common/counter.rs index e4db5901..7d3dd6c7 100644 --- a/examples/servers/src/common/counter.rs +++ b/examples/servers/src/common/counter.rs @@ -2,8 +2,12 @@ use std::sync::Arc; use rmcp::{ - Error as McpError, RoleServer, ServerHandler, const_string, handler::server::tool::Parameters, - model::*, schemars, service::RequestContext, tool, tool_router, + Error as McpError, RoleServer, ServerHandler, const_string, + handler::server::{router::tool::ToolRouter, tool::Parameters}, + model::*, + schemars, + service::RequestContext, + tool, tool_router, }; use serde_json::json; use tokio::sync::Mutex; @@ -17,6 +21,7 @@ pub struct StructRequest { #[derive(Clone)] pub struct Counter { counter: Arc>, + tool_router: ToolRouter, } #[tool_router] @@ -25,6 +30,7 @@ impl Counter { pub fn new() -> Self { Self { counter: Arc::new(Mutex::new(0)), + tool_router: Self::tool_router(), } } @@ -193,6 +199,22 @@ impl ServerHandler for Counter { }) } + async fn call_tool( + &self, + request: CallToolRequestParam, + context: RequestContext, + ) -> Result { + let tcc = rmcp::handler::server::tool::ToolCallContext::new(self, request, context); + self.tool_router.call(tcc).await + } + + async fn list_tools( + &self, + _request: Option, + _context: RequestContext, + ) -> Result { + Ok(ListToolsResult::with_all_items(self.tool_router.list_all())) + } async fn initialize( &self, _request: InitializeRequestParam, diff --git a/examples/transport/src/common/calculator.rs b/examples/transport/src/common/calculator.rs index 99b7314a..3629d133 100644 --- a/examples/transport/src/common/calculator.rs +++ b/examples/transport/src/common/calculator.rs @@ -1,4 +1,11 @@ -use rmcp::{ServerHandler, model::ServerInfo, schemars, tool, tool_box}; +#![allow(dead_code)] + +use rmcp::{ + ServerHandler, + handler::server::{router::tool::ToolRouter, tool::Parameters, wrapper::Json}, + model::{ListToolsResult, ServerCapabilities, ServerInfo}, + schemars, tool, tool_router, +}; #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] pub struct SumRequest { @@ -6,36 +13,63 @@ pub struct SumRequest { pub a: i32, pub b: i32, } + +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +pub struct SubRequest { + #[schemars(description = "the left hand side number")] + pub a: i32, + #[schemars(description = "the right hand side number")] + pub b: i32, +} + #[derive(Debug, Clone)] -pub struct Calculator; +pub struct Calculator { + tool_router: ToolRouter, +} + +#[tool_router] impl Calculator { + pub fn new() -> Self { + Self { + tool_router: Self::tool_router(), + } + } + #[tool(description = "Calculate the sum of two numbers")] - fn sum(&self, #[tool(aggr)] SumRequest { a, b }: SumRequest) -> String { + fn sum(&self, Parameters(SumRequest { a, b }): Parameters) -> String { (a + b).to_string() } - #[tool(description = "Calculate the sub of two numbers")] - fn sub( - &self, - #[tool(param)] - #[schemars(description = "the left hand side number")] - a: i32, - #[tool(param)] - #[schemars(description = "the right hand side number")] - b: i32, - ) -> String { - (a - b).to_string() + #[tool(description = "Calculate the difference of two numbers")] + fn sub(&self, Parameters(SubRequest { a, b }): Parameters) -> Json { + Json(a - b) } - - tool_box!(Calculator { sum, sub }); } impl ServerHandler for Calculator { - tool_box!(@derive); fn get_info(&self) -> ServerInfo { ServerInfo { instructions: Some("A simple calculator".into()), + capabilities: ServerCapabilities::builder().enable_tools().build(), ..Default::default() } } + + async fn call_tool( + &self, + request: rmcp::model::CallToolRequestParam, + context: rmcp::service::RequestContext, + ) -> Result { + let tcc = rmcp::handler::server::tool::ToolCallContext::new(self, request, context); + self.tool_router.call(tcc).await + } + + async fn list_tools( + &self, + _request: Option, + _context: rmcp::service::RequestContext, + ) -> Result { + let items = self.tool_router.list_all(); + Ok(ListToolsResult::with_all_items(items)) + } } diff --git a/examples/transport/src/http_upgrade.rs b/examples/transport/src/http_upgrade.rs index 1c0cf6ad..6a15add3 100644 --- a/examples/transport/src/http_upgrade.rs +++ b/examples/transport/src/http_upgrade.rs @@ -24,7 +24,7 @@ async fn main() -> anyhow::Result<()> { async fn http_server(req: Request) -> Result, hyper::Error> { tokio::spawn(async move { let upgraded = hyper::upgrade::on(req).await?; - let service = Calculator.serve(TokioIo::new(upgraded)).await?; + let service = Calculator::new().serve(TokioIo::new(upgraded)).await?; service.waiting().await?; anyhow::Result::<()>::Ok(()) }); diff --git a/examples/transport/src/tcp.rs b/examples/transport/src/tcp.rs index 72428fe6..683fb6cf 100644 --- a/examples/transport/src/tcp.rs +++ b/examples/transport/src/tcp.rs @@ -13,7 +13,7 @@ async fn server() -> anyhow::Result<()> { let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:8001").await?; while let Ok((stream, _)) = tcp_listener.accept().await { tokio::spawn(async move { - let server = serve_server(Calculator, stream).await?; + let server = serve_server(Calculator::new(), stream).await?; server.waiting().await?; anyhow::Ok(()) }); diff --git a/examples/transport/src/unix_socket.rs b/examples/transport/src/unix_socket.rs index a52b45a3..feeb2b87 100644 --- a/examples/transport/src/unix_socket.rs +++ b/examples/transport/src/unix_socket.rs @@ -14,7 +14,7 @@ async fn main() -> anyhow::Result<()> { while let Ok((stream, addr)) = unix_listener.accept().await { println!("Client connected: {:?}", addr); tokio::spawn(async move { - match serve_server(Calculator, stream).await { + match serve_server(Calculator::new(), stream).await { Ok(server) => { println!("Server initialized successfully"); if let Err(e) = server.waiting().await { diff --git a/examples/transport/src/websocket.rs b/examples/transport/src/websocket.rs index 0d0fec72..5ba23546 100644 --- a/examples/transport/src/websocket.rs +++ b/examples/transport/src/websocket.rs @@ -40,7 +40,7 @@ async fn start_server() -> anyhow::Result<()> { tokio::spawn(async move { let ws_stream = tokio_tungstenite::accept_async(stream).await?; let transport = WebsocketTransport::new_server(ws_stream); - let server = Calculator.serve(transport).await?; + let server = Calculator::new().serve(transport).await?; server.waiting().await?; Ok::<(), anyhow::Error>(()) }); diff --git a/examples/wasi/src/calculator.rs b/examples/wasi/src/calculator.rs index f1c35eea..3629d133 100644 --- a/examples/wasi/src/calculator.rs +++ b/examples/wasi/src/calculator.rs @@ -1,41 +1,52 @@ +#![allow(dead_code)] + use rmcp::{ ServerHandler, - model::{ServerCapabilities, ServerInfo}, - schemars, tool, tool_box, + handler::server::{router::tool::ToolRouter, tool::Parameters, wrapper::Json}, + model::{ListToolsResult, ServerCapabilities, ServerInfo}, + schemars, tool, tool_router, }; -#[derive(Debug, rmcp::serde::Deserialize, schemars::JsonSchema)] +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] pub struct SumRequest { #[schemars(description = "the left hand side number")] pub a: i32, pub b: i32, } + +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +pub struct SubRequest { + #[schemars(description = "the left hand side number")] + pub a: i32, + #[schemars(description = "the right hand side number")] + pub b: i32, +} + #[derive(Debug, Clone)] -pub struct Calculator; +pub struct Calculator { + tool_router: ToolRouter, +} + +#[tool_router] impl Calculator { + pub fn new() -> Self { + Self { + tool_router: Self::tool_router(), + } + } + #[tool(description = "Calculate the sum of two numbers")] - fn sum(&self, #[tool(aggr)] SumRequest { a, b }: SumRequest) -> String { + fn sum(&self, Parameters(SumRequest { a, b }): Parameters) -> String { (a + b).to_string() } - #[tool(description = "Calculate the sub of two numbers")] - fn sub( - &self, - #[tool(param)] - #[schemars(description = "the left hand side number")] - a: i32, - #[tool(param)] - #[schemars(description = "the right hand side number")] - b: i32, - ) -> String { - (a - b).to_string() + #[tool(description = "Calculate the difference of two numbers")] + fn sub(&self, Parameters(SubRequest { a, b }): Parameters) -> Json { + Json(a - b) } - - tool_box!(Calculator { sum, sub }); } impl ServerHandler for Calculator { - tool_box!(@derive); fn get_info(&self) -> ServerInfo { ServerInfo { instructions: Some("A simple calculator".into()), @@ -43,4 +54,22 @@ impl ServerHandler for Calculator { ..Default::default() } } + + async fn call_tool( + &self, + request: rmcp::model::CallToolRequestParam, + context: rmcp::service::RequestContext, + ) -> Result { + let tcc = rmcp::handler::server::tool::ToolCallContext::new(self, request, context); + self.tool_router.call(tcc).await + } + + async fn list_tools( + &self, + _request: Option, + _context: rmcp::service::RequestContext, + ) -> Result { + let items = self.tool_router.list_all(); + Ok(ListToolsResult::with_all_items(items)) + } } diff --git a/examples/wasi/src/lib.rs b/examples/wasi/src/lib.rs index 3b2904a8..2690cc73 100644 --- a/examples/wasi/src/lib.rs +++ b/examples/wasi/src/lib.rs @@ -112,7 +112,10 @@ impl wasi::exports::cli::run::Guest for TokioCliRunner { .with_writer(std::io::stderr) .with_ansi(false) .init(); - let server = calculator::Calculator.serve(wasi_io()).await.unwrap(); + let server = calculator::Calculator::new() + .serve(wasi_io()) + .await + .unwrap(); server.waiting().await.unwrap(); }); Ok(()) From fc35929dc7b9306136fd4481ff0000f87c22d826 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Mon, 16 Jun 2025 19:59:19 +0800 Subject: [PATCH 03/12] fix: fix test failure --- crates/rmcp/tests/test_tool_macros.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/rmcp/tests/test_tool_macros.rs b/crates/rmcp/tests/test_tool_macros.rs index 606e5ae6..5f982f7c 100644 --- a/crates/rmcp/tests/test_tool_macros.rs +++ b/crates/rmcp/tests/test_tool_macros.rs @@ -321,7 +321,7 @@ async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { // Test null case let result = client .call_tool(CallToolRequestParam { - name: "test_optional_i64_aggr".into(), + name: "test_optional_i64".into(), arguments: Some( serde_json::json!({ "count": null, @@ -349,7 +349,7 @@ async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { // Test Some case let some_result = client .call_tool(CallToolRequestParam { - name: "test_optional_i64_aggr".into(), + name: "test_optional_i64".into(), arguments: Some( serde_json::json!({ "count": 42, From f22164361a42f05c23739b64a14da2d6798e4c8b Mon Sep 17 00:00:00 2001 From: 4t145 Date: Tue, 17 Jun 2025 15:27:34 +0800 Subject: [PATCH 04/12] docs: documents for macros, fix ci --- crates/rmcp-macros/src/lib.rs | 106 ++++++++++++++++-- crates/rmcp-macros/src/tool.rs | 41 ++++++- crates/rmcp-macros/src/tool_handler.rs | 51 +++++++++ .../server/router/{promt.rs => prompt.rs} | 0 crates/rmcp/src/handler/server/tool.rs | 2 + crates/rmcp/src/lib.rs | 4 +- crates/rmcp/tests/test_tool_macros.rs | 24 +--- 7 files changed, 199 insertions(+), 29 deletions(-) create mode 100644 crates/rmcp-macros/src/tool_handler.rs rename crates/rmcp/src/handler/server/router/{promt.rs => prompt.rs} (100%) diff --git a/crates/rmcp-macros/src/lib.rs b/crates/rmcp-macros/src/lib.rs index f1e2a9b9..a5b5e979 100644 --- a/crates/rmcp-macros/src/lib.rs +++ b/crates/rmcp-macros/src/lib.rs @@ -1,15 +1,32 @@ #[allow(unused_imports)] use proc_macro::TokenStream; -// mod tool_inherite; mod tool; +mod tool_handler; mod tool_router; -// #[proc_macro_attribute] -// pub fn tool(attr: TokenStream, input: TokenStream) -> TokenStream { -// tool_inherite::tool(attr.into(), input.into()) -// .unwrap_or_else(|err| err.to_compile_error()) -// .into() -// } +/// # tool +/// +/// This macro is used to mark a function as a tool handler. +/// +/// This will generate a function that return the attribute of this tool, with type `rmcp::model::Tool`. +/// +/// ## Usage +/// +/// | feied | type | usage | +/// | :- | :- | :- | +/// | `name` | `String` | The name of the tool. If not provided, it defaults to the function name. | +/// | `description` | `String` | A description of the tool. The document of this function will be used. | +/// | `input_schema` | `Expr` | A JSON Schema object defining the expected parameters for the tool. If not provide, if will use the json schema of its argument with type `Parameters` | +/// | `annotations` | `ToolAnnotationsAttribute` | Additional tool information. Defaults to `None`. | +/// +/// ## Exmaple +/// +/// ```rust,ignore +/// #[tool(name = "my_tool", description = "This is my tool", annotations(title = "我的工具", read_only_hint = true))] +/// pub async fn my_tool(param: Parameters) { +/// // handling tool request +/// } +/// ``` #[proc_macro_attribute] pub fn tool(attr: TokenStream, input: TokenStream) -> TokenStream { tool::tool(attr.into(), input.into()) @@ -17,9 +34,84 @@ pub fn tool(attr: TokenStream, input: TokenStream) -> TokenStream { .into() } +/// # tool_router +/// +/// This macro is used to generate a tool router based on functions marked with `#[rmcp::tool]` in an implementation block. +/// +/// It creates a function that returns a `ToolRouter` instance. +/// +/// ## Usage +/// +/// | feied | type | usage | +/// | :- | :- | :- | +/// | `router` | `Ident` | The name of the router function to be generated. Defaults to `tool_router`. | +/// | `vis` | `Visibility` | The visibility of the generated router function. Defaults to empty. | +/// +/// ## Example +/// +/// ```rust,ignore +/// #[tool_router] +/// impl MyToolHandler { +/// #[tool] +/// pub fn my_tool() { +/// +/// } +/// +/// pub fn new() -> Self { +/// Self { +/// // the default name of tool router will be `tool_router` +/// tool_router: Self::tool_router(), +/// } +/// } +/// } +/// ``` +/// +/// Or specify the visibility and router name: +/// +/// ```rust,ignore +/// #[tool_router(router = my_tool_router, vis = pub)] +/// impl MyToolHandler { +/// #[tool] +/// pub fn my_tool() { +/// +/// } +/// } +/// ``` #[proc_macro_attribute] pub fn tool_router(attr: TokenStream, input: TokenStream) -> TokenStream { tool_router::tool_router(attr.into(), input.into()) .unwrap_or_else(|err| err.to_compile_error()) .into() } + + +/// # tool_handler +/// +/// This macro will generate the handler for `tool_call` and `list_tools` methods in the implementation block, by using an exsisting `ToolRouter` instance. +/// +/// ## Usage +/// +/// | field | type | usage | +/// | :- | :- | :- | +/// | `router` | `Expr` | The expression to access the `ToolRouter` instance. Defaults to `self.tool_router`. | +/// ## Example +/// ```rust,ignore +/// #[tool_handler] +/// impl ServerHandler for MyToolHandler { +/// // ...implement other handler +/// } +/// ``` +/// +/// or using a custom router expression: +/// ```rust,ignore +/// #[tool_handler(router = self.get_router().await)] +/// impl ServerHandler for MyToolHandler { +/// // ...implement other handler +/// } +/// ``` +#[proc_macro_attribute] +pub fn tool_handler(attr: TokenStream, input: TokenStream) -> TokenStream { + tool_handler::tool_hanlder(attr.into(), input.into()) + .unwrap_or_else(|err| err.to_compile_error()) + .into() +} diff --git a/crates/rmcp-macros/src/tool.rs b/crates/rmcp-macros/src/tool.rs index c27e5271..bf70d4df 100644 --- a/crates/rmcp-macros/src/tool.rs +++ b/crates/rmcp-macros/src/tool.rs @@ -246,8 +246,47 @@ mod test { drop(fields) } }; - let input = tool(attr, input)?; + let _input = tool(attr, input)?; Ok(()) } + + #[test] + fn test_doc_comment_description() -> syn::Result<()> { + let attr = quote! {}; // No explicit description + let input = quote! { + /// This is a test description from doc comments + /// with multiple lines + fn test_function(&self) -> Result<(), Error> { + Ok(()) + } + }; + let result = tool(attr, input)?; + + // The output should contain the description from doc comments + let result_str = result.to_string(); + assert!(result_str.contains("This is a test description from doc comments")); + assert!(result_str.contains("with multiple lines")); + + Ok(()) + } + + #[test] + fn test_explicit_description_priority() -> syn::Result<()> { + let attr = quote! { + description = "Explicit description has priority" + }; + let input = quote! { + /// Doc comment description that should be ignored + fn test_function(&self) -> Result<(), Error> { + Ok(()) + } + }; + let result = tool(attr, input)?; + + // The output should contain the explicit description + let result_str = result.to_string(); + assert!(result_str.contains("Explicit description has priority")); + Ok(()) + } } diff --git a/crates/rmcp-macros/src/tool_handler.rs b/crates/rmcp-macros/src/tool_handler.rs new file mode 100644 index 00000000..cf1bc6ff --- /dev/null +++ b/crates/rmcp-macros/src/tool_handler.rs @@ -0,0 +1,51 @@ +use darling::{FromMeta, ast::NestedMeta}; +use proc_macro2::TokenStream; +use quote::{ToTokens, quote}; +use syn::{Expr, ImplItem, ItemImpl}; + +#[derive(FromMeta)] +#[darling(default)] +pub struct ToolHanlderAttribute { + pub router: Expr, +} + +impl Default for ToolHanlderAttribute { + fn default() -> Self { + Self { + router: syn::parse2(quote! { + self.tool_router + }) + .unwrap(), + } + } +} + +pub fn tool_hanlder(attr: TokenStream, input: TokenStream) -> syn::Result { + let attr_args = NestedMeta::parse_meta_list(attr)?; + let ToolHanlderAttribute { router } = ToolHanlderAttribute::from_list(&attr_args)?; + let mut item_impl = syn::parse2::(input.clone())?; + let tool_call_fn = quote! { + async fn call_tool( + &self, + request: rmcp::model::CallToolRequestParam, + context: rmcp::service::RequestContext, + ) -> Result { + let tcc = ToolCallContext::new(self, request, context); + #router.call(tcc).await + } + }; + let tool_list_fn = quote! { + async fn list_tools( + &self, + _request: Option, + _context: rmcp::service::RequestContext, + ) -> Result { + Ok(ListToolsResult::with_all_items(#router.list_all())) + } + }; + let tool_call_fn = syn::parse2::(tool_call_fn)?; + let tool_list_fn = syn::parse2::(tool_list_fn)?; + item_impl.items.push(tool_call_fn); + item_impl.items.push(tool_list_fn); + Ok(item_impl.into_token_stream().into()) +} diff --git a/crates/rmcp/src/handler/server/router/promt.rs b/crates/rmcp/src/handler/server/router/prompt.rs similarity index 100% rename from crates/rmcp/src/handler/server/router/promt.rs rename to crates/rmcp/src/handler/server/router/prompt.rs diff --git a/crates/rmcp/src/handler/server/tool.rs b/crates/rmcp/src/handler/server/tool.rs index 03206dc4..2932c1c7 100644 --- a/crates/rmcp/src/handler/server/tool.rs +++ b/crates/rmcp/src/handler/server/tool.rs @@ -13,6 +13,8 @@ use crate::{ service::RequestContext, }; +pub use super::router::tool::{ToolRoute, ToolRouter}; + /// A shortcut for generating a JSON schema for a type. pub fn schema_for_type() -> JsonObject { let mut settings = schemars::r#gen::SchemaSettings::default(); diff --git a/crates/rmcp/src/lib.rs b/crates/rmcp/src/lib.rs index 7b5c655c..a3d1a414 100644 --- a/crates/rmcp/src/lib.rs +++ b/crates/rmcp/src/lib.rs @@ -19,12 +19,13 @@ //! //! ```rust //! use std::sync::Arc; -//! use rmcp::{Error as McpError, model::*, tool}; +//! use rmcp::{Error as McpError, model::*, tool, tool_router, handler::server::tool::ToolRouter}; //! use tokio::sync::Mutex; //! //! #[derive(Clone)] //! pub struct Counter { //! counter: Arc>, +//! tool_router: ToolRouter, //! } //! //! #[tool_router] @@ -32,6 +33,7 @@ //! fn new() -> Self { //! Self { //! counter: Arc::new(Mutex::new(0)), +//! tool_router: Self::tool_router(), //! } //! } //! diff --git a/crates/rmcp/tests/test_tool_macros.rs b/crates/rmcp/tests/test_tool_macros.rs index 5f982f7c..35574b9f 100644 --- a/crates/rmcp/tests/test_tool_macros.rs +++ b/crates/rmcp/tests/test_tool_macros.rs @@ -3,13 +3,10 @@ use std::sync::Arc; use rmcp::{ - ClientHandler, ServerHandler, ServiceExt, handler::server::{ router::tool::ToolRouter, tool::{Parameters, ToolCallContext}, - }, - model::{CallToolRequestParam, ClientInfo, ListToolsResult}, - tool, tool_router, + }, model::{CallToolRequestParam, ClientInfo, ListToolsResult}, tool, tool_handler, tool_router, ClientHandler, ServerHandler, ServiceExt }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -43,7 +40,7 @@ impl Default for Server { } } -#[tool_router] +#[tool_router(router = tool_router)] impl Server { pub fn new() -> Self { Self { @@ -98,22 +95,9 @@ impl GenericServer { } } +#[tool_handler] impl ServerHandler for GenericServer { - async fn call_tool( - &self, - request: CallToolRequestParam, - context: rmcp::service::RequestContext, - ) -> Result { - let tcc = ToolCallContext::new(self, request, context); - self.tool_router.call(tcc).await - } - async fn list_tools( - &self, - _request: Option, - _context: rmcp::service::RequestContext, - ) -> Result { - Ok(ListToolsResult::with_all_items(self.tool_router.list_all())) - } + } #[tokio::test] From da6427522b4c351114c08013dda1af6ef0fa7b0c Mon Sep 17 00:00:00 2001 From: 4t145 Date: Tue, 17 Jun 2025 15:32:48 +0800 Subject: [PATCH 05/12] fix: fix ci --- crates/rmcp-macros/src/lib.rs | 23 +++++++++++------------ crates/rmcp-macros/src/tool_handler.rs | 10 +++++----- crates/rmcp/src/handler/server/tool.rs | 3 +-- crates/rmcp/tests/test_tool_macros.rs | 9 +++++---- examples/wasi/src/calculator.rs | 6 ++++++ 5 files changed, 28 insertions(+), 23 deletions(-) diff --git a/crates/rmcp-macros/src/lib.rs b/crates/rmcp-macros/src/lib.rs index a5b5e979..ccbe1fb0 100644 --- a/crates/rmcp-macros/src/lib.rs +++ b/crates/rmcp-macros/src/lib.rs @@ -19,7 +19,7 @@ mod tool_router; /// | `input_schema` | `Expr` | A JSON Schema object defining the expected parameters for the tool. If not provide, if will use the json schema of its argument with type `Parameters` | /// | `annotations` | `ToolAnnotationsAttribute` | Additional tool information. Defaults to `None`. | /// -/// ## Exmaple +/// ## Example /// /// ```rust,ignore /// #[tool(name = "my_tool", description = "This is my tool", annotations(title = "我的工具", read_only_hint = true))] @@ -48,7 +48,7 @@ pub fn tool(attr: TokenStream, input: TokenStream) -> TokenStream { /// | `vis` | `Visibility` | The visibility of the generated router function. Defaults to empty. | /// /// ## Example -/// +/// /// ```rust,ignore /// #[tool_router] /// impl MyToolHandler { @@ -56,7 +56,7 @@ pub fn tool(attr: TokenStream, input: TokenStream) -> TokenStream { /// pub fn my_tool() { /// /// } -/// +/// /// pub fn new() -> Self { /// Self { /// // the default name of tool router will be `tool_router` @@ -65,9 +65,9 @@ pub fn tool(attr: TokenStream, input: TokenStream) -> TokenStream { /// } /// } /// ``` -/// +/// /// Or specify the visibility and router name: -/// +/// /// ```rust,ignore /// #[tool_router(router = my_tool_router, vis = pub)] /// impl MyToolHandler { @@ -84,13 +84,12 @@ pub fn tool_router(attr: TokenStream, input: TokenStream) -> TokenStream { .into() } - /// # tool_handler -/// -/// This macro will generate the handler for `tool_call` and `list_tools` methods in the implementation block, by using an exsisting `ToolRouter` instance. -/// +/// +/// This macro will generate the handler for `tool_call` and `list_tools` methods in the implementation block, by using an existing `ToolRouter` instance. +/// /// ## Usage -/// +/// /// | field | type | usage | /// | :- | :- | :- | /// | `router` | `Expr` | The expression to access the `ToolRouter` instance. Defaults to `self.tool_router`. | @@ -101,7 +100,7 @@ pub fn tool_router(attr: TokenStream, input: TokenStream) -> TokenStream { /// // ...implement other handler /// } /// ``` -/// +/// /// or using a custom router expression: /// ```rust,ignore /// #[tool_handler(router = self.get_router().await)] @@ -111,7 +110,7 @@ pub fn tool_router(attr: TokenStream, input: TokenStream) -> TokenStream { /// ``` #[proc_macro_attribute] pub fn tool_handler(attr: TokenStream, input: TokenStream) -> TokenStream { - tool_handler::tool_hanlder(attr.into(), input.into()) + tool_handler::tool_handler(attr.into(), input.into()) .unwrap_or_else(|err| err.to_compile_error()) .into() } diff --git a/crates/rmcp-macros/src/tool_handler.rs b/crates/rmcp-macros/src/tool_handler.rs index cf1bc6ff..d4bc2800 100644 --- a/crates/rmcp-macros/src/tool_handler.rs +++ b/crates/rmcp-macros/src/tool_handler.rs @@ -5,11 +5,11 @@ use syn::{Expr, ImplItem, ItemImpl}; #[derive(FromMeta)] #[darling(default)] -pub struct ToolHanlderAttribute { +pub struct ToolHandlerAttribute { pub router: Expr, } -impl Default for ToolHanlderAttribute { +impl Default for ToolHandlerAttribute { fn default() -> Self { Self { router: syn::parse2(quote! { @@ -20,9 +20,9 @@ impl Default for ToolHanlderAttribute { } } -pub fn tool_hanlder(attr: TokenStream, input: TokenStream) -> syn::Result { +pub fn tool_handler(attr: TokenStream, input: TokenStream) -> syn::Result { let attr_args = NestedMeta::parse_meta_list(attr)?; - let ToolHanlderAttribute { router } = ToolHanlderAttribute::from_list(&attr_args)?; + let ToolHandlerAttribute { router } = ToolHandlerAttribute::from_list(&attr_args)?; let mut item_impl = syn::parse2::(input.clone())?; let tool_call_fn = quote! { async fn call_tool( @@ -47,5 +47,5 @@ pub fn tool_hanlder(attr: TokenStream, input: TokenStream) -> syn::Result(tool_list_fn)?; item_impl.items.push(tool_call_fn); item_impl.items.push(tool_list_fn); - Ok(item_impl.into_token_stream().into()) + Ok(item_impl.into_token_stream()) } diff --git a/crates/rmcp/src/handler/server/tool.rs b/crates/rmcp/src/handler/server/tool.rs index 2932c1c7..544df523 100644 --- a/crates/rmcp/src/handler/server/tool.rs +++ b/crates/rmcp/src/handler/server/tool.rs @@ -7,14 +7,13 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use tokio_util::sync::CancellationToken; +pub use super::router::tool::{ToolRoute, ToolRouter}; use crate::{ RoleServer, model::{CallToolRequestParam, CallToolResult, IntoContents, JsonObject}, service::RequestContext, }; -pub use super::router::tool::{ToolRoute, ToolRouter}; - /// A shortcut for generating a JSON schema for a type. pub fn schema_for_type() -> JsonObject { let mut settings = schemars::r#gen::SchemaSettings::default(); diff --git a/crates/rmcp/tests/test_tool_macros.rs b/crates/rmcp/tests/test_tool_macros.rs index 35574b9f..9e7fde6a 100644 --- a/crates/rmcp/tests/test_tool_macros.rs +++ b/crates/rmcp/tests/test_tool_macros.rs @@ -3,10 +3,13 @@ use std::sync::Arc; use rmcp::{ + ClientHandler, ServerHandler, ServiceExt, handler::server::{ router::tool::ToolRouter, tool::{Parameters, ToolCallContext}, - }, model::{CallToolRequestParam, ClientInfo, ListToolsResult}, tool, tool_handler, tool_router, ClientHandler, ServerHandler, ServiceExt + }, + model::{CallToolRequestParam, ClientInfo, ListToolsResult}, + tool, tool_handler, tool_router, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -96,9 +99,7 @@ impl GenericServer { } #[tool_handler] -impl ServerHandler for GenericServer { - -} +impl ServerHandler for GenericServer {} #[tokio::test] async fn test_tool_macros() { diff --git a/examples/wasi/src/calculator.rs b/examples/wasi/src/calculator.rs index 3629d133..8c671c51 100644 --- a/examples/wasi/src/calculator.rs +++ b/examples/wasi/src/calculator.rs @@ -27,6 +27,12 @@ pub struct Calculator { tool_router: ToolRouter, } +impl Default for Calculator { + fn default() -> Self { + Self::new() + } +} + #[tool_router] impl Calculator { pub fn new() -> Self { From 86d5e50beb4dbc7494ae33b79a2928f1625ffa37 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Tue, 17 Jun 2025 15:36:32 +0800 Subject: [PATCH 06/12] fix: fix wrongly replaced documents --- crates/rmcp-macros/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/rmcp-macros/README.md b/crates/rmcp-macros/README.md index 873b89d4..98530c42 100644 --- a/crates/rmcp-macros/README.md +++ b/crates/rmcp-macros/README.md @@ -34,7 +34,7 @@ impl MyHandler { } #[tool] - fn tool(&self) -> Result { + fn tool2(&self) -> Result { // Tool 2 implementation } } From 080c3978c4780435b2c793e3cbbcf5b402d09fe0 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Tue, 17 Jun 2025 15:49:21 +0800 Subject: [PATCH 07/12] fix: remove useless file --- crates/rmcp-macros/src/tool_inherite.rs | 785 ------------------------ 1 file changed, 785 deletions(-) delete mode 100644 crates/rmcp-macros/src/tool_inherite.rs diff --git a/crates/rmcp-macros/src/tool_inherite.rs b/crates/rmcp-macros/src/tool_inherite.rs deleted file mode 100644 index 4c292ed0..00000000 --- a/crates/rmcp-macros/src/tool_inherite.rs +++ /dev/null @@ -1,785 +0,0 @@ -use std::collections::HashSet; - -use proc_macro2::TokenStream; -use quote::{ToTokens, quote}; -use serde_json::json; -use syn::{ - Expr, FnArg, Ident, ItemFn, ItemImpl, Lit, MetaList, PatType, Token, Type, Visibility, - parse::Parse, parse_quote, spanned::Spanned, -}; - -/// Stores tool annotation attributes -#[derive(Default, Clone)] -struct ToolAnnotationAttrs(pub serde_json::Map); - -impl Parse for ToolAnnotationAttrs { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let mut attrs = serde_json::Map::new(); - - while !input.is_empty() { - let key: Ident = input.parse()?; - input.parse::()?; - let value: Lit = input.parse()?; - let value = match value { - Lit::Str(s) => json!(s.value()), - Lit::Bool(b) => json!(b.value), - _ => { - return Err(syn::Error::new( - key.span(), - "annotations must be string or boolean literals", - )); - } - }; - attrs.insert(key.to_string(), value); - if input.is_empty() { - break; - } - input.parse::()?; - } - - Ok(ToolAnnotationAttrs(attrs)) - } -} - -#[derive(Default)] -struct ToolImplItemAttrs { - tool_box: Option>, -} - -impl Parse for ToolImplItemAttrs { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let mut tool_box = None; - while !input.is_empty() { - let key: Ident = input.parse()?; - match key.to_string().as_str() { - "tool_box" => { - tool_box = Some(None); - if input.lookahead1().peek(Token![=]) { - input.parse::()?; - let value: Ident = input.parse()?; - tool_box = Some(Some(value)); - } - } - _ => { - return Err(syn::Error::new(key.span(), "unknown attribute")); - } - } - if input.is_empty() { - break; - } - input.parse::()?; - } - - Ok(ToolImplItemAttrs { tool_box }) - } -} - -#[derive(Default)] -struct ToolFnItemAttrs { - name: Option, - description: Option, - vis: Option, - annotations: Option, -} - -impl Parse for ToolFnItemAttrs { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let mut name = None; - let mut description = None; - let mut vis = None; - let mut annotations = None; - - while !input.is_empty() { - let key: Ident = input.parse()?; - input.parse::()?; - match key.to_string().as_str() { - "name" => { - let value: Expr = input.parse()?; - name = Some(value); - } - "description" => { - let value: Expr = input.parse()?; - description = Some(value); - } - "vis" => { - let value: Visibility = input.parse()?; - vis = Some(value); - } - "annotations" => { - // Parse the annotations as a nested structure - let content; - syn::braced!(content in input); - let value = content.parse()?; - annotations = Some(value); - } - _ => { - return Err(syn::Error::new(key.span(), "unknown attribute")); - } - } - if input.is_empty() { - break; - } - input.parse::()?; - } - - Ok(ToolFnItemAttrs { - name, - description, - vis, - annotations, - }) - } -} - -struct ToolFnParamAttrs { - serde_meta: Vec, - schemars_meta: Vec, - ident: Ident, - rust_type: Box, -} - -impl ToTokens for ToolFnParamAttrs { - fn to_tokens(&self, tokens: &mut TokenStream) { - let ident = &self.ident; - let rust_type = &self.rust_type; - let serde_meta = &self.serde_meta; - let schemars_meta = &self.schemars_meta; - tokens.extend(quote! { - #(#[#serde_meta])* - #(#[#schemars_meta])* - pub #ident: #rust_type, - }); - } -} - -#[derive(Default)] - -enum ToolParams { - Aggregated { - rust_type: PatType, - }, - Params { - attrs: Vec, - }, - #[default] - NoParam, -} - -#[derive(Default)] -struct ToolAttrs { - fn_item: ToolFnItemAttrs, - params: ToolParams, -} -const TOOL_IDENT: &str = "tool"; -const SERDE_IDENT: &str = "serde"; -const SCHEMARS_IDENT: &str = "schemars"; -const PARAM_IDENT: &str = "param"; -const AGGREGATED_IDENT: &str = "aggr"; -const REQ_IDENT: &str = "req"; - -pub enum ParamMarker { - Param, - Aggregated, -} - -impl Parse for ParamMarker { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let ident: Ident = input.parse()?; - match ident.to_string().as_str() { - PARAM_IDENT => Ok(ParamMarker::Param), - AGGREGATED_IDENT | REQ_IDENT => Ok(ParamMarker::Aggregated), - _ => Err(syn::Error::new(ident.span(), "unknown attribute")), - } - } -} - -pub enum ToolItem { - Fn(ItemFn), - Impl(ItemImpl), -} - -impl Parse for ToolItem { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let lookahead = input.lookahead1(); - if lookahead.peek(Token![impl]) { - let item = input.parse::()?; - Ok(ToolItem::Impl(item)) - } else { - let item = input.parse::()?; - Ok(ToolItem::Fn(item)) - } - } -} - -// dispatch impl function item and impl block item -pub(crate) fn tool(attr: TokenStream, input: TokenStream) -> syn::Result { - let tool_item = syn::parse2::(input)?; - match tool_item { - ToolItem::Fn(item) => tool_fn_item(attr, item), - ToolItem::Impl(item) => tool_impl_item(attr, item), - } -} - -pub(crate) fn tool_impl_item(attr: TokenStream, mut input: ItemImpl) -> syn::Result { - let tool_impl_attr: ToolImplItemAttrs = syn::parse2(attr)?; - let tool_box_ident = tool_impl_attr.tool_box; - - // get all tool function ident - let mut tool_fn_idents = Vec::new(); - for item in &input.items { - if let syn::ImplItem::Fn(method) = item { - for attr in &method.attrs { - if attr.path().is_ident(TOOL_IDENT) { - tool_fn_idents.push(method.sig.ident.clone()); - } - } - } - } - - // handle different cases - if input.trait_.is_some() { - if let Some(ident) = tool_box_ident { - // check if there are generic parameters - if !input.generics.params.is_empty() { - // for trait implementation with generic parameters, directly use the already generated *_inner method - - // generate call_tool method - input.items.push(parse_quote! { - async fn call_tool( - &self, - request: rmcp::model::CallToolRequestParam, - context: rmcp::service::RequestContext, - ) -> Result { - self.call_tool_inner(request, context).await - } - }); - - // generate list_tools method - input.items.push(parse_quote! { - async fn list_tools( - &self, - request: Option, - context: rmcp::service::RequestContext, - ) -> Result { - self.list_tools_inner(request, context).await - } - }); - } else { - // if there are no generic parameters, add tool box derive - input.items.push(parse_quote!( - rmcp::tool_box!(@derive #ident); - )); - } - } else { - return Err(syn::Error::new( - proc_macro2::Span::call_site(), - "tool_box attribute is required for trait implementation", - )); - } - } else if let Some(ident) = tool_box_ident { - // if it is a normal impl block - if !input.generics.params.is_empty() { - // if there are generic parameters, not use tool_box! macro, but generate code directly - - // create call code for each tool function - let match_arms = tool_fn_idents.iter().map(|ident| { - let attr_fn = Ident::new(&format!("{}_tool_attr", ident), ident.span()); - let call_fn = Ident::new(&format!("{}_tool_call", ident), ident.span()); - quote! { - name if name == Self::#attr_fn().name => { - Self::#call_fn(tcc).await - } - } - }); - - let tool_attrs = tool_fn_idents.iter().map(|ident| { - let attr_fn = Ident::new(&format!("{}_tool_attr", ident), ident.span()); - quote! { Self::#attr_fn() } - }); - - // implement call_tool method - input.items.push(parse_quote! { - async fn call_tool_inner( - &self, - request: rmcp::model::CallToolRequestParam, - context: rmcp::service::RequestContext, - ) -> Result { - let tcc = rmcp::handler::server::tool::ToolCallContext::new(self, request, context); - match tcc.name() { - #(#match_arms,)* - _ => Err(rmcp::Error::invalid_params("tool not found", None)), - } - } - }); - - // implement list_tools method - input.items.push(parse_quote! { - async fn list_tools_inner( - &self, - _: Option, - _: rmcp::service::RequestContext, - ) -> Result { - Ok(rmcp::model::ListToolsResult { - next_cursor: None, - tools: vec![#(#tool_attrs),*], - }) - } - }); - } else { - // if there are no generic parameters, use the original tool_box! macro - let this_type_ident = &input.self_ty; - input.items.push(parse_quote!( - rmcp::tool_box!(#this_type_ident { - #(#tool_fn_idents),* - } #ident); - )); - } - } - - Ok(quote! { - #input - }) -} - -// extract doc line from attribute -fn extract_doc_line(attr: &syn::Attribute) -> Option { - if !attr.path().is_ident("doc") { - return None; - } - - let syn::Meta::NameValue(name_value) = &attr.meta else { - return None; - }; - - let syn::Expr::Lit(expr_lit) = &name_value.value else { - return None; - }; - - let syn::Lit::Str(lit_str) = &expr_lit.lit else { - return None; - }; - - let content = lit_str.value().trim().to_string(); - - (!content.is_empty()).then_some(content) -} - -pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Result { - let mut tool_macro_attrs = ToolAttrs::default(); - let args: ToolFnItemAttrs = syn::parse2(attr)?; - tool_macro_attrs.fn_item = args; - // let mut fommated_fn_args: Punctuated = Punctuated::new(); - let mut unextractable_args_indexes = HashSet::new(); - for (index, mut fn_arg) in input_fn.sig.inputs.iter_mut().enumerate() { - enum Caught { - Param(ToolFnParamAttrs), - Aggregated(PatType), - } - let mut caught = None; - match &mut fn_arg { - FnArg::Receiver(_) => { - continue; - } - FnArg::Typed(pat_type) => { - let mut serde_metas = Vec::new(); - let mut schemars_metas = Vec::new(); - let mut arg_ident = match pat_type.pat.as_ref() { - syn::Pat::Ident(pat_ident) => Some(pat_ident.ident.clone()), - _ => None, - }; - let raw_attrs: Vec<_> = pat_type.attrs.drain(..).collect(); - for attr in raw_attrs { - match &attr.meta { - syn::Meta::List(meta_list) => { - if meta_list.path.is_ident(TOOL_IDENT) { - let pat_type = pat_type.clone(); - let marker = meta_list.parse_args::()?; - match marker { - ParamMarker::Param => { - let Some(arg_ident) = arg_ident.take() else { - return Err(syn::Error::new( - proc_macro2::Span::call_site(), - "input param must have an ident as name", - )); - }; - caught.replace(Caught::Param(ToolFnParamAttrs { - serde_meta: Vec::new(), - schemars_meta: Vec::new(), - ident: arg_ident, - rust_type: pat_type.ty.clone(), - })); - } - ParamMarker::Aggregated => { - caught.replace(Caught::Aggregated(pat_type.clone())); - } - } - } else if meta_list.path.is_ident(SERDE_IDENT) { - serde_metas.push(meta_list.clone()); - } else if meta_list.path.is_ident(SCHEMARS_IDENT) { - schemars_metas.push(meta_list.clone()); - } else { - pat_type.attrs.push(attr); - } - } - _ => { - pat_type.attrs.push(attr); - } - } - } - match caught { - Some(Caught::Param(mut param)) => { - param.serde_meta = serde_metas; - param.schemars_meta = schemars_metas; - match &mut tool_macro_attrs.params { - ToolParams::Params { attrs } => { - attrs.push(param); - } - _ => { - tool_macro_attrs.params = ToolParams::Params { attrs: vec![param] }; - } - } - unextractable_args_indexes.insert(index); - } - Some(Caught::Aggregated(rust_type)) => { - if let ToolParams::Params { .. } = tool_macro_attrs.params { - return Err(syn::Error::new( - rust_type.span(), - "cannot mix aggregated and individual parameters", - )); - } - tool_macro_attrs.params = ToolParams::Aggregated { rust_type }; - unextractable_args_indexes.insert(index); - } - None => {} - } - } - } - } - - // input_fn.sig.inputs = fommated_fn_args; - let name = if let Some(expr) = tool_macro_attrs.fn_item.name { - expr - } else { - let fn_name = &input_fn.sig.ident; - parse_quote! { - stringify!(#fn_name) - } - }; - let tool_attr_fn_ident = Ident::new( - &format!("{}_tool_attr", input_fn.sig.ident), - proc_macro2::Span::call_site(), - ); - - // generate get tool attr function - let tool_attr_fn = { - let description = if let Some(expr) = tool_macro_attrs.fn_item.description { - // Use explicitly provided description if available - expr - } else { - // Try to extract documentation comments - let doc_content = input_fn - .attrs - .iter() - .filter_map(extract_doc_line) - .collect::>() - .join("\n"); - - parse_quote! { - #doc_content.trim().to_string() - } - }; - let schema = match &tool_macro_attrs.params { - ToolParams::Aggregated { rust_type } => { - let ty = &rust_type.ty; - let schema = quote! { - rmcp::handler::server::tool::cached_schema_for_type::<#ty>() - }; - schema - } - ToolParams::Params { attrs, .. } => { - let (param_type, temp_param_type_name) = - create_request_type(attrs, input_fn.sig.ident.to_string()); - let schema = quote! { - { - #param_type - rmcp::handler::server::tool::cached_schema_for_type::<#temp_param_type_name>() - } - }; - schema - } - ToolParams::NoParam => { - quote! { - rmcp::handler::server::tool::cached_schema_for_type::() - } - } - }; - let input_fn_attrs = &input_fn.attrs; - let input_fn_vis = &input_fn.vis; - - let annotations_code = if let Some(annotations) = &tool_macro_attrs.fn_item.annotations { - let annotations = - serde_json::to_string(&annotations.0).expect("failed to serialize annotations"); - quote! { - Some(serde_json::from_str::(&#annotations).expect("Could not parse tool annotations")) - } - } else { - quote! { None } - }; - - quote! { - #(#input_fn_attrs)* - #input_fn_vis fn #tool_attr_fn_ident() -> rmcp::model::Tool { - rmcp::model::Tool { - name: #name.into(), - description: Some(#description.into()), - input_schema: #schema.into(), - annotations: #annotations_code, - } - } - } - }; - - // generate wrapped tool function - let tool_call_fn = { - // wrapper function have the same sig: - // async fn #tool_tool_call(context: rmcp::handler::server::tool::ToolCallContext) - // -> std::result::Result - // - // and the block part should be like: - // { - // use rmcp::handler::server::tool::*; - // let t0 = ::from_tool_call_context_part(&mut context)?; - // let t1 = ::from_tool_call_context_part(&mut context)?; - // ... - // let tn = ::from_tool_call_context_part(&mut context)?; - // // for params - // ... expand helper types here - // let __rmcp_tool_req = rmcp::model::JsonObject::from_tool_call_context_part(&mut context)?; - // let __#TOOL_ToolCallParam { param_0, param_1, param_2, .. } = parse_json_object(__rmcp_tool_req)?; - // // for aggr - // let Parameters(aggr) = >::from_tool_call_context_part(&mut context)?; - // Self::#tool_ident(to, param_0, t1, param_1, ..., param_2, tn, aggr).await.into_call_tool_result() - // - // } - // - // - // - - // for receiver type, name it as __rmcp_tool_receiver - let is_async = input_fn.sig.asyncness.is_some(); - let receiver_ident = || Ident::new("__rmcp_tool_receiver", proc_macro2::Span::call_site()); - // generate the extraction part for trivial args - let trivial_args = input_fn - .sig - .inputs - .iter() - .enumerate() - .filter_map(|(index, arg)| { - if unextractable_args_indexes.contains(&index) { - None - } else { - // get ident/type pair - let line = match arg { - FnArg::Typed(pat_type) => { - let pat = &pat_type.pat; - let ty = &pat_type.ty; - quote! { - let #pat = <#ty>::from_tool_call_context_part(&mut context)?; - } - } - FnArg::Receiver(r) => { - let ty = r.ty.clone(); - let pat = receiver_ident(); - quote! { - let #pat = <#ty>::from_tool_call_context_part(&mut context)?; - } - } - }; - Some(line) - } - }); - let trivial_arg_extraction_part = quote! { - #(#trivial_args)* - }; - let processed_arg_extraction_part = match &mut tool_macro_attrs.params { - ToolParams::Aggregated { rust_type } => { - let PatType { pat, ty, .. } = rust_type; - quote! { - let Parameters(#pat) = >::from_tool_call_context_part(&mut context)?; - } - } - ToolParams::Params { attrs } => { - let (param_type, temp_param_type_name) = - create_request_type(attrs, input_fn.sig.ident.to_string()); - - let params_ident = attrs.iter().map(|attr| &attr.ident).collect::>(); - quote! { - #param_type - let __rmcp_tool_req = rmcp::model::JsonObject::from_tool_call_context_part(&mut context)?; - let #temp_param_type_name { - #(#params_ident,)* - } = parse_json_object(__rmcp_tool_req)?; - } - } - ToolParams::NoParam => { - quote! {} - } - }; - // generate the execution part - // has receiver? - let params = &input_fn - .sig - .inputs - .iter() - .map(|fn_arg| match fn_arg { - FnArg::Receiver(_) => { - let pat = receiver_ident(); - quote! { #pat } - } - FnArg::Typed(pat_type) => { - let pat = &pat_type.pat.clone(); - quote! { #pat } - } - }) - .collect::>(); - let raw_fn_ident = &input_fn.sig.ident; - let call = if is_async { - quote! { - Self::#raw_fn_ident(#(#params),*).await.into_call_tool_result() - } - } else { - quote! { - Self::#raw_fn_ident(#(#params),*).into_call_tool_result() - } - }; - // assemble the whole function - let tool_call_fn_ident = Ident::new( - &format!("{}_tool_call", input_fn.sig.ident), - proc_macro2::Span::call_site(), - ); - let raw_fn_vis = tool_macro_attrs - .fn_item - .vis - .as_ref() - .unwrap_or(&input_fn.vis); - let raw_fn_attr = &input_fn - .attrs - .iter() - .filter(|attr| !attr.path().is_ident(TOOL_IDENT)) - .collect::>(); - quote! { - #(#raw_fn_attr)* - #raw_fn_vis async fn #tool_call_fn_ident(context: rmcp::handler::server::tool::ToolCallContext) - -> std::result::Result { - use rmcp::handler::server::tool::*; - #trivial_arg_extraction_part - #processed_arg_extraction_part - #call - } - } - }; - Ok(quote! { - #tool_attr_fn - #tool_call_fn - #input_fn - }) -} - -fn create_request_type(attrs: &[ToolFnParamAttrs], tool_name: String) -> (TokenStream, Ident) { - let pascal_case_tool_name = tool_name.to_ascii_uppercase(); - let temp_param_type_name = Ident::new( - &format!("__{pascal_case_tool_name}ToolCallParam",), - proc_macro2::Span::call_site(), - ); - ( - quote! { - use rmcp::{serde, schemars}; - #[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)] - pub struct #temp_param_type_name { - #(#attrs)* - } - }, - temp_param_type_name, - ) -} - -#[cfg(test)] -mod test { - use super::*; - #[test] - fn test_tool_sync_macro() -> syn::Result<()> { - let attr = quote! { - name = "test_tool", - description = "test tool", - vis = - }; - let input = quote! { - fn sum(&self, #[tool(aggr)] req: StructRequest) -> Result { - Ok(CallToolResult::success(vec![Content::text((req.a + req.b).to_string())])) - } - }; - let input = tool(attr, input)?; - - println!("input: {:#}", input); - Ok(()) - } - - #[test] - fn test_trait_tool_macro() -> syn::Result<()> { - let attr = quote! { - tool_box = Calculator - }; - let input = quote! { - impl ServerHandler for Calculator { - #[tool] - fn get_info(&self) -> ServerInfo { - ServerInfo { - instructions: Some("A simple calculator".into()), - ..Default::default() - } - } - } - }; - let input = tool(attr, input)?; - - println!("input: {:#}", input); - Ok(()) - } - #[test] - fn test_doc_comment_description() -> syn::Result<()> { - let attr = quote! {}; // No explicit description - let input = quote! { - /// This is a test description from doc comments - /// with multiple lines - fn test_function(&self) -> Result<(), Error> { - Ok(()) - } - }; - let result = tool(attr, input)?; - - // The output should contain the description from doc comments - let result_str = result.to_string(); - assert!(result_str.contains("This is a test description from doc comments")); - assert!(result_str.contains("with multiple lines")); - - Ok(()) - } - #[test] - fn test_explicit_description_priority() -> syn::Result<()> { - let attr = quote! { - description = "Explicit description has priority" - }; - let input = quote! { - /// Doc comment description that should be ignored - fn test_function(&self) -> Result<(), Error> { - Ok(()) - assert!(result_str.contains("Explicit description has priority")); - } - }; - let result = tool(attr, input)?; - - // The output should contain the explicit description - let result_str = result.to_string(); - Ok(()) - } -} From 62c85f4ddf6da081fb56056896aca90e2436147f Mon Sep 17 00:00:00 2001 From: 4t145 Date: Wed, 18 Jun 2025 11:42:15 +0800 Subject: [PATCH 08/12] fix: change the parameter format for tool_router --- crates/rmcp-macros/src/tool_router.rs | 2 +- crates/rmcp/src/handler/server/router/tool.rs | 6 +++--- crates/rmcp/tests/test_tool_routers.rs | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/crates/rmcp-macros/src/tool_router.rs b/crates/rmcp-macros/src/tool_router.rs index 337cf00c..f6bc0edd 100644 --- a/crates/rmcp-macros/src/tool_router.rs +++ b/crates/rmcp-macros/src/tool_router.rs @@ -56,7 +56,7 @@ pub fn tool_router(attr: TokenStream, input: TokenStream) -> syn::Result(quote! { diff --git a/crates/rmcp/src/handler/server/router/tool.rs b/crates/rmcp/src/handler/server/router/tool.rs index 97db401b..bb5fda21 100644 --- a/crates/rmcp/src/handler/server/router/tool.rs +++ b/crates/rmcp/src/handler/server/router/tool.rs @@ -210,11 +210,11 @@ where transparent_when_not_found: false, } } - pub fn with_route(mut self, attr: crate::model::Tool, call: C) -> Self + pub fn with_route(mut self, route: R) -> Self where - C: CallToolHandler + Send + Sync + Clone + 'static, + R: IntoToolRoute, { - self.add_route(ToolRoute::new(attr, call)); + self.add_route(route.into_tool_route()); self } diff --git a/crates/rmcp/tests/test_tool_routers.rs b/crates/rmcp/tests/test_tool_routers.rs index e08437c1..846fe424 100644 --- a/crates/rmcp/tests/test_tool_routers.rs +++ b/crates/rmcp/tests/test_tool_routers.rs @@ -58,8 +58,8 @@ fn async_function2(_callee: &TestHandler) -> BoxFuture<'_, ()> { #[test] fn test_tool_router() { let test_tool_router: ToolRouter> = ToolRouter::>::new() - .with_route(async_function_tool_attr(), async_function) - .with_route(async_function2_tool_attr(), async_function2) + .with_route((async_function_tool_attr(), async_function)) + .with_route((async_function2_tool_attr(), async_function2)) + TestHandler::<()>::test_router_1() + TestHandler::<()>::test_router_2(); let tools = test_tool_router.list_all(); From ec9bc2f3bf8206a40baa4256250cf4c52bddfca4 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Wed, 18 Jun 2025 12:29:03 +0800 Subject: [PATCH 09/12] fix: update extract_doc_line to handle existing documentation and clean up unused code in server handler --- crates/rmcp-macros/src/tool.rs | 30 ++++++++++++++---------------- crates/rmcp/src/handler/server.rs | 13 ------------- 2 files changed, 14 insertions(+), 29 deletions(-) diff --git a/crates/rmcp-macros/src/tool.rs b/crates/rmcp-macros/src/tool.rs index bf70d4df..8d1fb495 100644 --- a/crates/rmcp-macros/src/tool.rs +++ b/crates/rmcp-macros/src/tool.rs @@ -90,7 +90,7 @@ fn none_expr() -> Expr { } // extract doc line from attribute -fn extract_doc_line(attr: &syn::Attribute) -> Option { +fn extract_doc_line(existing_docs: Option, attr: &syn::Attribute) -> Option { if !attr.path().is_ident("doc") { return None; } @@ -108,8 +108,16 @@ fn extract_doc_line(attr: &syn::Attribute) -> Option { }; let content = lit_str.value().trim().to_string(); - - (!content.is_empty()).then_some(content) + match (existing_docs, content) { + (Some(mut existing_docs), content) if !content.is_empty() => { + existing_docs.push('\n'); + existing_docs.push_str(&content); + Some(existing_docs) + } + (Some(existing_docs), _) => Some(existing_docs), + (None, content) if !content.is_empty() => Some(content), + _ => None, + } } pub fn tool(attr: TokenStream, input: TokenStream) -> syn::Result { @@ -186,19 +194,9 @@ pub fn tool(attr: TokenStream, input: TokenStream) -> syn::Result { }; let resolved_tool_attr = ResolvedToolAttribute { name: attribute.name.unwrap_or_else(|| fn_ident.to_string()), - description: attribute.description.or_else(|| { - let doc_content = fn_item - .attrs - .iter() - .filter_map(extract_doc_line) - .collect::>() - .join("\n"); - if doc_content.is_empty() { - None - } else { - Some(doc_content) - } - }), + description: attribute + .description + .or_else(|| fn_item.attrs.iter().fold(None, extract_doc_line)), input_schema: input_schema_expr, annotations: annotations_expr, }; diff --git a/crates/rmcp/src/handler/server.rs b/crates/rmcp/src/handler/server.rs index 9773bd98..ed74a807 100644 --- a/crates/rmcp/src/handler/server.rs +++ b/crates/rmcp/src/handler/server.rs @@ -185,19 +185,6 @@ pub trait ServerHandler: Sized + Send + Sync + 'static { request: CallToolRequestParam, context: RequestContext, ) -> impl Future> + Send + '_ { - // async move { - // let router = router::tool::GlobalStaticRouters::get::().await; - // router.call(tool::ToolCallContext { - // request_context: context, - // service: todo!(), - // name: todo!(), - // arguments: todo!(), - // }).await.map_err(|e| { - // tracing::error!("call tool error: {}", e); - // e - // }) - - // }; std::future::ready(Err(McpError::method_not_found::())) } fn list_tools( From 481bd03287117661efbf86ebc032f95bf8aceb73 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Mon, 23 Jun 2025 11:50:01 +0800 Subject: [PATCH 10/12] doc: update document for macro and examples --- crates/rmcp-macros/README.md | 144 +++++++++++++++--- crates/rmcp-macros/src/lib.rs | 56 ++++++- crates/rmcp-macros/src/tool_handler.rs | 4 +- .../rmcp/tests/test_tool_macro_annotations.rs | 15 +- crates/rmcp/tests/test_tool_macros.rs | 36 +---- docs/readme/README.zh-cn.md | 2 +- examples/servers/src/common/calculator.rs | 23 +-- examples/servers/src/common/counter.rs | 22 +-- .../servers/src/common/generic_service.rs | 22 +-- examples/transport/src/common/calculator.rs | 24 +-- examples/wasi/src/calculator.rs | 24 +-- justfile | 8 + 12 files changed, 213 insertions(+), 167 deletions(-) create mode 100644 justfile diff --git a/crates/rmcp-macros/README.md b/crates/rmcp-macros/README.md index 98530c42..03ef93cd 100644 --- a/crates/rmcp-macros/README.md +++ b/crates/rmcp-macros/README.md @@ -10,41 +10,149 @@ This library primarily provides the following macros: ## Usage -### Tool Macro +### tool -Mark a function as a tool: +This macro is used to mark a function as a tool handler. -```rust ignore -#[tool] -fn calculator(&self, #[tool(param)] a: i32, #[tool(param)] b: i32) -> Result { - // Implement tool functionality - Ok(CallToolResult::success(vec![Content::text((a + b).to_string())])) -} +This will generate a function that return the attribute of this tool, with type `rmcp::model::Tool`. + +#### Usage + +| feied | type | usage | +| :- | :- | :- | +| `name` | `String` | The name of the tool. If not provided, it defaults to the function name. | +| `description` | `String` | A description of the tool. The document of this function will be used. | +| `input_schema` | `Expr` | A JSON Schema object defining the expected parameters for the tool. If not provide, if will use the json schema of its argument with type `Parameters` | +| `annotations` | `ToolAnnotationsAttribute` | Additional tool information. Defaults to `None`. | +#### Example + +```rust +#[tool(name = "my_tool", description = "This is my tool", annotations(title = "我的工具", read_only_hint = true))] +pub async fn my_tool(param: Parameters) { + // handling tool request +} ``` -Use on an impl block to automatically register multiple tools: +### tool_router + +This macro is used to generate a tool router based on functions marked with `#[rmcp::tool]` in an implementation block. + +It creates a function that returns a `ToolRouter` instance. + +In most case, you need to add a field for handler to store the router information and initialize it when creating handler, or store it with a static variable. + +#### Usage + +| feied | type | usage | +| :- | :- | :- | +| `router` | `Ident` | The name of the router function to be generated. Defaults to `tool_router`. | +| `vis` | `Visibility` | The visibility of the generated router function. Defaults to empty. | -```rust ignore +#### Example + +```rust #[tool_router] -impl MyHandler { +impl MyToolHandler { #[tool] - fn tool1(&self) -> Result { - // Tool 1 implementation + pub fn my_tool() { + } - - #[tool] - fn tool2(&self) -> Result { - // Tool 2 implementation + + pub fn new() -> Self { + Self { + // the default name of tool router will be `tool_router` + tool_router: Self::tool_router(), + } } } ``` +Or specify the visibility and router name, which would be helpful when you want to combine multiple routers into one: + +```rust +mod a { + #[tool_router(router = tool_router_a, vis = pub)] + impl MyToolHandler { + #[tool] + fn my_tool_a() { + + } + } +} + +mod b { + #[tool_router(router = tool_router_b, vis = pub)] + impl MyToolHandler { + #[tool] + fn my_tool_b() { + + } + } +} + +impl MyToolHandler { + fn new() -> Self { + Self { + tool_router: self::tool_router_a() + self::tool_router_b(), + } + } +} + + +### tool_handler + +This macro will generate the handler for `tool_call` and `list_tools` methods in the implementation block, by using an existing `ToolRouter` instance. + +#### Usage + +| field | type | usage | +| :- | :- | :- | +| `router` | `Expr` | The expression to access the `ToolRouter` instance. Defaults to `self.tool_router`. | + +#### Example +```rust +#[tool_handler] +impl ServerHandler for MyToolHandler { + // ...implement other handler +} +``` + +or using a custom router expression: +```rust +#[tool_handler(router = self.get_router().await)] +impl ServerHandler for MyToolHandler { + // ...implement other handler +} +``` + +#### Explained +This macro will be expended to something like this: +```rust +impl ServerHandler for MyToolHandler { + async fn call_tool( + &self, + request: CallToolRequestParam, + context: RequestContext, + ) -> Result { + let tcc = ToolCallContext::new(self, request, context); + self.tool_router.call(tcc).await + } + + async fn list_tools( + &self, + _request: Option, + _context: RequestContext, + ) -> Result { + let items = self.tool_router.list_all(); + Ok(ListToolsResult::with_all_items(items)) + } +} +``` ## Advanced Features -- Support for parameter aggregation (`#[tool(aggr)]`) - Support for custom tool names and descriptions - Automatic generation of tool descriptions from documentation comments - JSON Schema generation for tool parameters diff --git a/crates/rmcp-macros/src/lib.rs b/crates/rmcp-macros/src/lib.rs index ccbe1fb0..ddd1e7d5 100644 --- a/crates/rmcp-macros/src/lib.rs +++ b/crates/rmcp-macros/src/lib.rs @@ -40,6 +40,7 @@ pub fn tool(attr: TokenStream, input: TokenStream) -> TokenStream { /// /// It creates a function that returns a `ToolRouter` instance. /// +/// In most case, you need to add a field for handler to store the router information and initialize it when creating handler, or store it with a static variable. /// ## Usage /// /// | feied | type | usage | @@ -66,14 +67,34 @@ pub fn tool(attr: TokenStream, input: TokenStream) -> TokenStream { /// } /// ``` /// -/// Or specify the visibility and router name: +/// Or specify the visibility and router name, which would be helpful when you want to combine multiple routers into one: /// /// ```rust,ignore -/// #[tool_router(router = my_tool_router, vis = pub)] +/// mod a { +/// #[tool_router(router = tool_router_a, vis = pub)] +/// impl MyToolHandler { +/// #[tool] +/// fn my_tool_a() { +/// +/// } +/// } +/// } +/// +/// mod b { +/// #[tool_router(router = tool_router_b, vis = pub)] +/// impl MyToolHandler { +/// #[tool] +/// fn my_tool_b() { +/// +/// } +/// } +/// } +/// /// impl MyToolHandler { -/// #[tool] -/// pub fn my_tool() { -/// +/// fn new() -> Self { +/// Self { +/// tool_router: self::tool_router_a() + self::tool_router_b(), +/// } /// } /// } /// ``` @@ -108,6 +129,31 @@ pub fn tool_router(attr: TokenStream, input: TokenStream) -> TokenStream { /// // ...implement other handler /// } /// ``` +/// +/// ## Explain +/// +/// This macro will be expended to something like this: +/// ```rust,ignore +/// impl ServerHandler for MyToolHandler { +/// async fn call_tool( +/// &self, +/// request: CallToolRequestParam, +/// context: RequestContext, +/// ) -> Result { +/// let tcc = ToolCallContext::new(self, request, context); +/// self.tool_router.call(tcc).await +/// } +/// +/// async fn list_tools( +/// &self, +/// _request: Option, +/// _context: RequestContext, +/// ) -> Result { +/// let items = self.tool_router.list_all(); +/// Ok(ListToolsResult::with_all_items(items)) +/// } +/// } +/// ``` #[proc_macro_attribute] pub fn tool_handler(attr: TokenStream, input: TokenStream) -> TokenStream { tool_handler::tool_handler(attr.into(), input.into()) diff --git a/crates/rmcp-macros/src/tool_handler.rs b/crates/rmcp-macros/src/tool_handler.rs index d4bc2800..c6b168e7 100644 --- a/crates/rmcp-macros/src/tool_handler.rs +++ b/crates/rmcp-macros/src/tool_handler.rs @@ -30,7 +30,7 @@ pub fn tool_handler(attr: TokenStream, input: TokenStream) -> syn::Result, ) -> Result { - let tcc = ToolCallContext::new(self, request, context); + let tcc = rmcp::handler::server::tool::ToolCallContext::new(self, request, context); #router.call(tcc).await } }; @@ -40,7 +40,7 @@ pub fn tool_handler(attr: TokenStream, input: TokenStream) -> syn::Result, _context: rmcp::service::RequestContext, ) -> Result { - Ok(ListToolsResult::with_all_items(#router.list_all())) + Ok(rmcp::model::ListToolsResult::with_all_items(#router.list_all())) } }; let tool_call_fn = syn::parse2::(tool_call_fn)?; diff --git a/crates/rmcp/tests/test_tool_macro_annotations.rs b/crates/rmcp/tests/test_tool_macro_annotations.rs index 00368af6..e945a10f 100644 --- a/crates/rmcp/tests/test_tool_macro_annotations.rs +++ b/crates/rmcp/tests/test_tool_macro_annotations.rs @@ -1,6 +1,6 @@ #[cfg(test)] mod tests { - use rmcp::{ServerHandler, handler::server::router::tool::ToolRouter, tool}; + use rmcp::{ServerHandler, handler::server::router::tool::ToolRouter, tool, tool_handler}; #[derive(Debug, Clone, Default)] pub struct AnnotatedServer { @@ -19,17 +19,8 @@ mod tests { format!("Direct: {}", input) } } - - impl ServerHandler for AnnotatedServer { - async fn call_tool( - &self, - request: rmcp::model::CallToolRequestParam, - context: rmcp::service::RequestContext, - ) -> Result { - let tcc = rmcp::handler::server::tool::ToolCallContext::new(self, request, context); - self.tool_router.call(tcc).await - } - } + #[tool_handler] + impl ServerHandler for AnnotatedServer {} #[test] fn test_direct_tool_attributes() { diff --git a/crates/rmcp/tests/test_tool_macros.rs b/crates/rmcp/tests/test_tool_macros.rs index 9e7fde6a..b7631ed5 100644 --- a/crates/rmcp/tests/test_tool_macros.rs +++ b/crates/rmcp/tests/test_tool_macros.rs @@ -4,11 +4,8 @@ use std::sync::Arc; use rmcp::{ ClientHandler, ServerHandler, ServiceExt, - handler::server::{ - router::tool::ToolRouter, - tool::{Parameters, ToolCallContext}, - }, - model::{CallToolRequestParam, ClientInfo, ListToolsResult}, + handler::server::{router::tool::ToolRouter, tool::Parameters}, + model::{CallToolRequestParam, ClientInfo}, tool, tool_handler, tool_router, }; use schemars::JsonSchema; @@ -20,16 +17,8 @@ pub struct GetWeatherRequest { pub date: String, } -impl ServerHandler for Server { - async fn call_tool( - &self, - request: rmcp::model::CallToolRequestParam, - context: rmcp::service::RequestContext, - ) -> Result { - let tcc = ToolCallContext::new(self, request, context); - self.tool_router.call(tcc).await - } -} +#[tool_handler(router = self.tool_router)] +impl ServerHandler for Server {} #[derive(Debug, Clone)] #[allow(dead_code)] @@ -169,7 +158,7 @@ pub struct OptionalI64TestSchema { // Dummy struct to host the test tool method #[derive(Debug, Clone)] pub struct OptionalSchemaTester { - router: ToolRouter, + tool_router: ToolRouter, } impl Default for OptionalSchemaTester { @@ -181,7 +170,7 @@ impl Default for OptionalSchemaTester { impl OptionalSchemaTester { pub fn new() -> Self { Self { - router: Self::tool_router(), + tool_router: Self::tool_router(), } } } @@ -207,18 +196,9 @@ impl OptionalSchemaTester { } } } - +#[tool_handler] // Implement ServerHandler to route tool calls for OptionalSchemaTester -impl ServerHandler for OptionalSchemaTester { - async fn call_tool( - &self, - request: rmcp::model::CallToolRequestParam, - context: rmcp::service::RequestContext, - ) -> Result { - let tcc = ToolCallContext::new(self, request, context); - self.router.call(tcc).await - } -} +impl ServerHandler for OptionalSchemaTester {} #[test] fn test_optional_field_schema_generation_via_macro() { diff --git a/docs/readme/README.zh-cn.md b/docs/readme/README.zh-cn.md index 6e6ff91d..0ae1745c 100644 --- a/docs/readme/README.zh-cn.md +++ b/docs/readme/README.zh-cn.md @@ -122,7 +122,7 @@ impl Calculator { } // impl call_tool and list_tool by querying static toolbox -#[tool_router] +#[tool_handler] impl ServerHandler for Calculator { fn get_info(&self) -> ServerInfo { ServerInfo { diff --git a/examples/servers/src/common/calculator.rs b/examples/servers/src/common/calculator.rs index 3629d133..3a3b3538 100644 --- a/examples/servers/src/common/calculator.rs +++ b/examples/servers/src/common/calculator.rs @@ -3,8 +3,8 @@ use rmcp::{ ServerHandler, handler::server::{router::tool::ToolRouter, tool::Parameters, wrapper::Json}, - model::{ListToolsResult, ServerCapabilities, ServerInfo}, - schemars, tool, tool_router, + model::{ServerCapabilities, ServerInfo}, + schemars, tool, tool_handler, tool_router, }; #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] @@ -46,6 +46,7 @@ impl Calculator { } } +#[tool_handler] impl ServerHandler for Calculator { fn get_info(&self) -> ServerInfo { ServerInfo { @@ -54,22 +55,4 @@ impl ServerHandler for Calculator { ..Default::default() } } - - async fn call_tool( - &self, - request: rmcp::model::CallToolRequestParam, - context: rmcp::service::RequestContext, - ) -> Result { - let tcc = rmcp::handler::server::tool::ToolCallContext::new(self, request, context); - self.tool_router.call(tcc).await - } - - async fn list_tools( - &self, - _request: Option, - _context: rmcp::service::RequestContext, - ) -> Result { - let items = self.tool_router.list_all(); - Ok(ListToolsResult::with_all_items(items)) - } } diff --git a/examples/servers/src/common/counter.rs b/examples/servers/src/common/counter.rs index 7d3dd6c7..c205e293 100644 --- a/examples/servers/src/common/counter.rs +++ b/examples/servers/src/common/counter.rs @@ -2,12 +2,12 @@ use std::sync::Arc; use rmcp::{ - Error as McpError, RoleServer, ServerHandler, const_string, + Error as McpError, RoleServer, ServerHandler, handler::server::{router::tool::ToolRouter, tool::Parameters}, model::*, schemars, service::RequestContext, - tool, tool_router, + tool, tool_handler, tool_router, }; use serde_json::json; use tokio::sync::Mutex; @@ -86,7 +86,7 @@ impl Counter { )])) } } -const_string!(Echo = "echo"); +#[tool_handler] impl ServerHandler for Counter { fn get_info(&self) -> ServerInfo { ServerInfo { @@ -199,22 +199,6 @@ impl ServerHandler for Counter { }) } - async fn call_tool( - &self, - request: CallToolRequestParam, - context: RequestContext, - ) -> Result { - let tcc = rmcp::handler::server::tool::ToolCallContext::new(self, request, context); - self.tool_router.call(tcc).await - } - - async fn list_tools( - &self, - _request: Option, - _context: RequestContext, - ) -> Result { - Ok(ListToolsResult::with_all_items(self.tool_router.list_all())) - } async fn initialize( &self, _request: InitializeRequestParam, diff --git a/examples/servers/src/common/generic_service.rs b/examples/servers/src/common/generic_service.rs index b5d1bd77..b629bc1c 100644 --- a/examples/servers/src/common/generic_service.rs +++ b/examples/servers/src/common/generic_service.rs @@ -4,7 +4,7 @@ use rmcp::{ ServerHandler, handler::server::{router::tool::ToolRouter, tool::Parameters}, model::{ServerCapabilities, ServerInfo}, - schemars, tool, tool_router, + schemars, tool, tool_handler, tool_router, }; #[allow(dead_code)] @@ -74,6 +74,7 @@ impl GenericService { } } +#[tool_handler] impl ServerHandler for GenericService { fn get_info(&self) -> ServerInfo { ServerInfo { @@ -82,23 +83,4 @@ impl ServerHandler for GenericService { ..Default::default() } } - - async fn call_tool( - &self, - request: rmcp::model::CallToolRequestParam, - context: rmcp::service::RequestContext, - ) -> Result { - let tcc = rmcp::handler::server::tool::ToolCallContext::new(self, request, context); - self.tool_router.call(tcc).await - } - - async fn list_tools( - &self, - _request: Option, - _context: rmcp::service::RequestContext, - ) -> Result { - Ok(rmcp::model::ListToolsResult::with_all_items( - self.tool_router.list_all(), - )) - } } diff --git a/examples/transport/src/common/calculator.rs b/examples/transport/src/common/calculator.rs index 3629d133..8d7fdd62 100644 --- a/examples/transport/src/common/calculator.rs +++ b/examples/transport/src/common/calculator.rs @@ -3,8 +3,8 @@ use rmcp::{ ServerHandler, handler::server::{router::tool::ToolRouter, tool::Parameters, wrapper::Json}, - model::{ListToolsResult, ServerCapabilities, ServerInfo}, - schemars, tool, tool_router, + model::{ServerCapabilities, ServerInfo}, + schemars, tool, tool_handler, tool_router, }; #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] @@ -45,7 +45,7 @@ impl Calculator { Json(a - b) } } - +#[tool_handler] impl ServerHandler for Calculator { fn get_info(&self) -> ServerInfo { ServerInfo { @@ -54,22 +54,4 @@ impl ServerHandler for Calculator { ..Default::default() } } - - async fn call_tool( - &self, - request: rmcp::model::CallToolRequestParam, - context: rmcp::service::RequestContext, - ) -> Result { - let tcc = rmcp::handler::server::tool::ToolCallContext::new(self, request, context); - self.tool_router.call(tcc).await - } - - async fn list_tools( - &self, - _request: Option, - _context: rmcp::service::RequestContext, - ) -> Result { - let items = self.tool_router.list_all(); - Ok(ListToolsResult::with_all_items(items)) - } } diff --git a/examples/wasi/src/calculator.rs b/examples/wasi/src/calculator.rs index 8c671c51..6f1d0a3f 100644 --- a/examples/wasi/src/calculator.rs +++ b/examples/wasi/src/calculator.rs @@ -3,8 +3,8 @@ use rmcp::{ ServerHandler, handler::server::{router::tool::ToolRouter, tool::Parameters, wrapper::Json}, - model::{ListToolsResult, ServerCapabilities, ServerInfo}, - schemars, tool, tool_router, + model::{ServerCapabilities, ServerInfo}, + schemars, tool, tool_handler, tool_router, }; #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] @@ -51,7 +51,7 @@ impl Calculator { Json(a - b) } } - +#[tool_handler] impl ServerHandler for Calculator { fn get_info(&self) -> ServerInfo { ServerInfo { @@ -60,22 +60,4 @@ impl ServerHandler for Calculator { ..Default::default() } } - - async fn call_tool( - &self, - request: rmcp::model::CallToolRequestParam, - context: rmcp::service::RequestContext, - ) -> Result { - let tcc = rmcp::handler::server::tool::ToolCallContext::new(self, request, context); - self.tool_router.call(tcc).await - } - - async fn list_tools( - &self, - _request: Option, - _context: rmcp::service::RequestContext, - ) -> Result { - let items = self.tool_router.list_all(); - Ok(ListToolsResult::with_all_items(items)) - } } diff --git a/justfile b/justfile new file mode 100644 index 00000000..10f78083 --- /dev/null +++ b/justfile @@ -0,0 +1,8 @@ +fmt: + cargo +nightly fmt --all + +check: + cargo clippy --all-targets --all-features -- -D warnings + +fix: + cargo clippy --fix --all-targets --all-features --allow-staged \ No newline at end of file From edd324769858f78e1d3c447c847842c4b5ee2f80 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Mon, 23 Jun 2025 12:19:35 +0800 Subject: [PATCH 11/12] doc: update readme and add contribute guide --- README.md | 9 +++++++-- crates/rmcp/README.md | 18 +----------------- docs/CONTRIBUTE.MD | 13 +++++++++++++ justfile | 8 ++++++-- 4 files changed, 27 insertions(+), 21 deletions(-) create mode 100644 docs/CONTRIBUTE.MD diff --git a/README.md b/README.md index 54ff970f..70a5da0d 100644 --- a/README.md +++ b/README.md @@ -124,5 +124,10 @@ See [oauth_support](docs/OAUTH_SUPPORT.md) for details. ## Related Projects - [containerd-mcp-server](https://github.com/jokemanfire/mcp-containerd) - A containerd-based MCP server implementation -## Development with Dev Container -See [docs/DEVCONTAINER.md](docs/DEVCONTAINER.md) for instructions on using Dev Container for development. +## Development + +### Tips for Contributers +See [docs/CONTRIBUTE.MD](docs/CONTRIBUTE.MD) to get some tips for contributing. + +### Using Dev Container +If you want to use dev container, see [docs/DEVCONTAINER.md](docs/DEVCONTAINER.md) for instructions on using Dev Container for development. \ No newline at end of file diff --git a/crates/rmcp/README.md b/crates/rmcp/README.md index 3130f0da..7c93dd24 100644 --- a/crates/rmcp/README.md +++ b/crates/rmcp/README.md @@ -53,7 +53,7 @@ impl Counter { } // Implement the server handler -#[tool_router] +#[tool_handler] impl rmcp::ServerHandler for Counter { fn get_info(&self) -> ServerInfo { ServerInfo { @@ -62,22 +62,6 @@ impl rmcp::ServerHandler for Counter { ..Default::default() } } - async fn call_tool( - &self, - request: CallToolRequestParam, - context: rmcp::service::RequestContext, - ) -> Result { - let tcc = ToolCallContext::new(self, request, context); - self.tool_router.call(tcc).await - } - - async fn list_tools( - &self, - _request: Option, - _context: rmcp::service::RequestContext, - ) -> Result { - Ok(ListToolsResult::with_all_items(self.tool_router.list_all())) - } } // Run the server diff --git a/docs/CONTRIBUTE.MD b/docs/CONTRIBUTE.MD new file mode 100644 index 00000000..95e3a834 --- /dev/null +++ b/docs/CONTRIBUTE.MD @@ -0,0 +1,13 @@ +# Discuss first +If you have a idea, make sure it is discussed before you make a PR. + +# Fmt And Clippy +You can use [just](https://github.com/casey/just) to help you fix your commit rapidly: +```shell +just fix +``` + +# How Can I Rewrite My Commit Message? +You can `git reset --soft upstream/main` and `git commit --forge`, this will merge your changes into one commit. + +Or you also can use git rebase. But we will still merge them into one commit when it is merged. \ No newline at end of file diff --git a/justfile b/justfile index 10f78083..970aa1a5 100644 --- a/justfile +++ b/justfile @@ -4,5 +4,9 @@ fmt: check: cargo clippy --all-targets --all-features -- -D warnings -fix: - cargo clippy --fix --all-targets --all-features --allow-staged \ No newline at end of file +fix: fmt + git add ./ + cargo clippy --fix --all-targets --all-features --allow-staged + +test: + cargo test --all-features \ No newline at end of file From 6cbfffb82cc053859c504fa269d75b6b2953c4a2 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Mon, 23 Jun 2025 14:10:54 +0800 Subject: [PATCH 12/12] fix: fix type --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 70a5da0d..3cb91ceb 100644 --- a/README.md +++ b/README.md @@ -126,7 +126,7 @@ See [oauth_support](docs/OAUTH_SUPPORT.md) for details. ## Development -### Tips for Contributers +### Tips for Contributors See [docs/CONTRIBUTE.MD](docs/CONTRIBUTE.MD) to get some tips for contributing. ### Using Dev Container