diff --git a/README.md b/README.md index 54ff970f..3cb91ceb 100644 --- a/README.md +++ b/README.md @@ -124,5 +124,10 @@ See [oauth_support](docs/OAUTH_SUPPORT.md) for details. ## Related Projects - [containerd-mcp-server](https://github.com/jokemanfire/mcp-containerd) - A containerd-based MCP server implementation -## Development with Dev Container -See [docs/DEVCONTAINER.md](docs/DEVCONTAINER.md) for instructions on using Dev Container for development. +## Development + +### Tips for Contributors +See [docs/CONTRIBUTE.MD](docs/CONTRIBUTE.MD) to get some tips for contributing. + +### Using Dev Container +If you want to use dev container, see [docs/DEVCONTAINER.md](docs/DEVCONTAINER.md) for instructions on using Dev Container for development. \ No newline at end of file diff --git a/crates/rmcp-macros/Cargo.toml b/crates/rmcp-macros/Cargo.toml index 9afa8b69..c7f85684 100644 --- a/crates/rmcp-macros/Cargo.toml +++ b/crates/rmcp-macros/Cargo.toml @@ -19,7 +19,7 @@ syn = {version = "2", features = ["full"]} quote = "1" proc-macro2 = "1" serde_json = "1.0" - +darling = { version = "0.20" } [features] [dev-dependencies] \ No newline at end of file diff --git a/crates/rmcp-macros/README.md b/crates/rmcp-macros/README.md index 223e164f..03ef93cd 100644 --- a/crates/rmcp-macros/README.md +++ b/crates/rmcp-macros/README.md @@ -10,41 +10,149 @@ This library primarily provides the following macros: ## Usage -### Tool Macro +### tool -Mark a function as a tool: +This macro is used to mark a function as a tool handler. -```rust ignore -#[tool] -fn calculator(&self, #[tool(param)] a: i32, #[tool(param)] b: i32) -> Result { - // Implement tool functionality - Ok(CallToolResult::success(vec![Content::text((a + b).to_string())])) -} +This will generate a function that return the attribute of this tool, with type `rmcp::model::Tool`. + +#### Usage + +| feied | type | usage | +| :- | :- | :- | +| `name` | `String` | The name of the tool. If not provided, it defaults to the function name. | +| `description` | `String` | A description of the tool. The document of this function will be used. | +| `input_schema` | `Expr` | A JSON Schema object defining the expected parameters for the tool. If not provide, if will use the json schema of its argument with type `Parameters` | +| `annotations` | `ToolAnnotationsAttribute` | Additional tool information. Defaults to `None`. | +#### Example + +```rust +#[tool(name = "my_tool", description = "This is my tool", annotations(title = "我的工具", read_only_hint = true))] +pub async fn my_tool(param: Parameters) { + // handling tool request +} ``` -Use on an impl block to automatically register multiple tools: +### tool_router + +This macro is used to generate a tool router based on functions marked with `#[rmcp::tool]` in an implementation block. + +It creates a function that returns a `ToolRouter` instance. + +In most case, you need to add a field for handler to store the router information and initialize it when creating handler, or store it with a static variable. + +#### Usage + +| feied | type | usage | +| :- | :- | :- | +| `router` | `Ident` | The name of the router function to be generated. Defaults to `tool_router`. | +| `vis` | `Visibility` | The visibility of the generated router function. Defaults to empty. | -```rust ignore -#[tool(tool_box)] -impl MyHandler { +#### Example + +```rust +#[tool_router] +impl MyToolHandler { #[tool] - fn tool1(&self) -> Result { - // Tool 1 implementation + pub fn my_tool() { + } - - #[tool] - fn tool2(&self) -> Result { - // Tool 2 implementation + + pub fn new() -> Self { + Self { + // the default name of tool router will be `tool_router` + tool_router: Self::tool_router(), + } } } ``` +Or specify the visibility and router name, which would be helpful when you want to combine multiple routers into one: + +```rust +mod a { + #[tool_router(router = tool_router_a, vis = pub)] + impl MyToolHandler { + #[tool] + fn my_tool_a() { + + } + } +} + +mod b { + #[tool_router(router = tool_router_b, vis = pub)] + impl MyToolHandler { + #[tool] + fn my_tool_b() { + + } + } +} + +impl MyToolHandler { + fn new() -> Self { + Self { + tool_router: self::tool_router_a() + self::tool_router_b(), + } + } +} + + +### tool_handler + +This macro will generate the handler for `tool_call` and `list_tools` methods in the implementation block, by using an existing `ToolRouter` instance. + +#### Usage + +| field | type | usage | +| :- | :- | :- | +| `router` | `Expr` | The expression to access the `ToolRouter` instance. Defaults to `self.tool_router`. | + +#### Example +```rust +#[tool_handler] +impl ServerHandler for MyToolHandler { + // ...implement other handler +} +``` + +or using a custom router expression: +```rust +#[tool_handler(router = self.get_router().await)] +impl ServerHandler for MyToolHandler { + // ...implement other handler +} +``` + +#### Explained +This macro will be expended to something like this: +```rust +impl ServerHandler for MyToolHandler { + async fn call_tool( + &self, + request: CallToolRequestParam, + context: RequestContext, + ) -> Result { + let tcc = ToolCallContext::new(self, request, context); + self.tool_router.call(tcc).await + } + + async fn list_tools( + &self, + _request: Option, + _context: RequestContext, + ) -> Result { + let items = self.tool_router.list_all(); + Ok(ListToolsResult::with_all_items(items)) + } +} +``` ## Advanced Features -- Support for parameter aggregation (`#[tool(aggr)]`) - Support for custom tool names and descriptions - Automatic generation of tool descriptions from documentation comments - JSON Schema generation for tool parameters diff --git a/crates/rmcp-macros/src/lib.rs b/crates/rmcp-macros/src/lib.rs index ffe44ec5..ddd1e7d5 100644 --- a/crates/rmcp-macros/src/lib.rs +++ b/crates/rmcp-macros/src/lib.rs @@ -2,10 +2,161 @@ use proc_macro::TokenStream; mod tool; - +mod tool_handler; +mod tool_router; +/// # tool +/// +/// This macro is used to mark a function as a tool handler. +/// +/// This will generate a function that return the attribute of this tool, with type `rmcp::model::Tool`. +/// +/// ## Usage +/// +/// | feied | type | usage | +/// | :- | :- | :- | +/// | `name` | `String` | The name of the tool. If not provided, it defaults to the function name. | +/// | `description` | `String` | A description of the tool. The document of this function will be used. | +/// | `input_schema` | `Expr` | A JSON Schema object defining the expected parameters for the tool. If not provide, if will use the json schema of its argument with type `Parameters` | +/// | `annotations` | `ToolAnnotationsAttribute` | Additional tool information. Defaults to `None`. | +/// +/// ## Example +/// +/// ```rust,ignore +/// #[tool(name = "my_tool", description = "This is my tool", annotations(title = "我的工具", read_only_hint = true))] +/// pub async fn my_tool(param: Parameters) { +/// // handling tool request +/// } +/// ``` #[proc_macro_attribute] pub fn tool(attr: TokenStream, input: TokenStream) -> TokenStream { tool::tool(attr.into(), input.into()) .unwrap_or_else(|err| err.to_compile_error()) .into() } + +/// # tool_router +/// +/// This macro is used to generate a tool router based on functions marked with `#[rmcp::tool]` in an implementation block. +/// +/// It creates a function that returns a `ToolRouter` instance. +/// +/// In most case, you need to add a field for handler to store the router information and initialize it when creating handler, or store it with a static variable. +/// ## Usage +/// +/// | feied | type | usage | +/// | :- | :- | :- | +/// | `router` | `Ident` | The name of the router function to be generated. Defaults to `tool_router`. | +/// | `vis` | `Visibility` | The visibility of the generated router function. Defaults to empty. | +/// +/// ## Example +/// +/// ```rust,ignore +/// #[tool_router] +/// impl MyToolHandler { +/// #[tool] +/// pub fn my_tool() { +/// +/// } +/// +/// pub fn new() -> Self { +/// Self { +/// // the default name of tool router will be `tool_router` +/// tool_router: Self::tool_router(), +/// } +/// } +/// } +/// ``` +/// +/// Or specify the visibility and router name, which would be helpful when you want to combine multiple routers into one: +/// +/// ```rust,ignore +/// mod a { +/// #[tool_router(router = tool_router_a, vis = pub)] +/// impl MyToolHandler { +/// #[tool] +/// fn my_tool_a() { +/// +/// } +/// } +/// } +/// +/// mod b { +/// #[tool_router(router = tool_router_b, vis = pub)] +/// impl MyToolHandler { +/// #[tool] +/// fn my_tool_b() { +/// +/// } +/// } +/// } +/// +/// impl MyToolHandler { +/// fn new() -> Self { +/// Self { +/// tool_router: self::tool_router_a() + self::tool_router_b(), +/// } +/// } +/// } +/// ``` +#[proc_macro_attribute] +pub fn tool_router(attr: TokenStream, input: TokenStream) -> TokenStream { + tool_router::tool_router(attr.into(), input.into()) + .unwrap_or_else(|err| err.to_compile_error()) + .into() +} + +/// # tool_handler +/// +/// This macro will generate the handler for `tool_call` and `list_tools` methods in the implementation block, by using an existing `ToolRouter` instance. +/// +/// ## Usage +/// +/// | field | type | usage | +/// | :- | :- | :- | +/// | `router` | `Expr` | The expression to access the `ToolRouter` instance. Defaults to `self.tool_router`. | +/// ## Example +/// ```rust,ignore +/// #[tool_handler] +/// impl ServerHandler for MyToolHandler { +/// // ...implement other handler +/// } +/// ``` +/// +/// or using a custom router expression: +/// ```rust,ignore +/// #[tool_handler(router = self.get_router().await)] +/// impl ServerHandler for MyToolHandler { +/// // ...implement other handler +/// } +/// ``` +/// +/// ## Explain +/// +/// This macro will be expended to something like this: +/// ```rust,ignore +/// impl ServerHandler for MyToolHandler { +/// async fn call_tool( +/// &self, +/// request: CallToolRequestParam, +/// context: RequestContext, +/// ) -> Result { +/// let tcc = ToolCallContext::new(self, request, context); +/// self.tool_router.call(tcc).await +/// } +/// +/// async fn list_tools( +/// &self, +/// _request: Option, +/// _context: RequestContext, +/// ) -> Result { +/// let items = self.tool_router.list_all(); +/// Ok(ListToolsResult::with_all_items(items)) +/// } +/// } +/// ``` +#[proc_macro_attribute] +pub fn tool_handler(attr: TokenStream, input: TokenStream) -> TokenStream { + tool_handler::tool_handler(attr.into(), input.into()) + .unwrap_or_else(|err| err.to_compile_error()) + .into() +} diff --git a/crates/rmcp-macros/src/tool.rs b/crates/rmcp-macros/src/tool.rs index 956ab320..8d1fb495 100644 --- a/crates/rmcp-macros/src/tool.rs +++ b/crates/rmcp-macros/src/tool.rs @@ -1,348 +1,96 @@ -use std::collections::HashSet; - +use darling::{FromMeta, ast::NestedMeta}; use proc_macro2::TokenStream; -use quote::{ToTokens, quote}; -use serde_json::json; -use syn::{ - Expr, FnArg, Ident, ItemFn, ItemImpl, Lit, MetaList, PatType, Token, Type, Visibility, - parse::Parse, parse_quote, spanned::Spanned, -}; - -/// Stores tool annotation attributes -#[derive(Default, Clone)] -struct ToolAnnotationAttrs(pub serde_json::Map); - -impl Parse for ToolAnnotationAttrs { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let mut attrs = serde_json::Map::new(); - - while !input.is_empty() { - let key: Ident = input.parse()?; - input.parse::()?; - let value: Lit = input.parse()?; - let value = match value { - Lit::Str(s) => json!(s.value()), - Lit::Bool(b) => json!(b.value), - _ => { - return Err(syn::Error::new( - key.span(), - "annotations must be string or boolean literals", - )); - } - }; - attrs.insert(key.to_string(), value); - if input.is_empty() { - break; - } - input.parse::()?; - } - - Ok(ToolAnnotationAttrs(attrs)) - } -} - -#[derive(Default)] -struct ToolImplItemAttrs { - tool_box: Option>, -} - -impl Parse for ToolImplItemAttrs { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let mut tool_box = None; - while !input.is_empty() { - let key: Ident = input.parse()?; - match key.to_string().as_str() { - "tool_box" => { - tool_box = Some(None); - if input.lookahead1().peek(Token![=]) { - input.parse::()?; - let value: Ident = input.parse()?; - tool_box = Some(Some(value)); - } - } - _ => { - return Err(syn::Error::new(key.span(), "unknown attribute")); - } - } - if input.is_empty() { - break; - } - input.parse::()?; - } - - Ok(ToolImplItemAttrs { tool_box }) - } +use quote::{ToTokens, format_ident, quote}; +use syn::{Expr, Ident, ImplItemFn, ReturnType}; +#[derive(FromMeta, Default, Debug)] +#[darling(default)] +pub struct ToolAttribute { + /// The name of the tool + pub name: Option, + pub description: Option, + /// A JSON Schema object defining the expected parameters for the tool + pub input_schema: Option, + /// Optional additional tool information. + pub annotations: Option, } -#[derive(Default)] -struct ToolFnItemAttrs { - name: Option, - description: Option, - vis: Option, - annotations: Option, +pub struct ResolvedToolAttribute { + pub name: String, + pub description: Option, + pub input_schema: Expr, + pub annotations: Expr, } -impl Parse for ToolFnItemAttrs { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let mut name = None; - let mut description = None; - let mut vis = None; - let mut annotations = None; - - while !input.is_empty() { - let key: Ident = input.parse()?; - input.parse::()?; - match key.to_string().as_str() { - "name" => { - let value: Expr = input.parse()?; - name = Some(value); - } - "description" => { - let value: Expr = input.parse()?; - description = Some(value); - } - "vis" => { - let value: Visibility = input.parse()?; - vis = Some(value); - } - "annotations" => { - // Parse the annotations as a nested structure - let content; - syn::braced!(content in input); - let value = content.parse()?; - annotations = Some(value); - } - _ => { - return Err(syn::Error::new(key.span(), "unknown attribute")); - } - } - if input.is_empty() { - break; - } - input.parse::()?; - } - - Ok(ToolFnItemAttrs { +impl ResolvedToolAttribute { + pub fn into_fn(self, fn_ident: Ident) -> syn::Result { + let Self { name, description, - vis, + input_schema, annotations, - }) - } -} - -struct ToolFnParamAttrs { - serde_meta: Vec, - schemars_meta: Vec, - ident: Ident, - rust_type: Box, -} - -impl ToTokens for ToolFnParamAttrs { - fn to_tokens(&self, tokens: &mut TokenStream) { - let ident = &self.ident; - let rust_type = &self.rust_type; - let serde_meta = &self.serde_meta; - let schemars_meta = &self.schemars_meta; - tokens.extend(quote! { - #(#[#serde_meta])* - #(#[#schemars_meta])* - pub #ident: #rust_type, - }); - } -} - -#[derive(Default)] - -enum ToolParams { - Aggregated { - rust_type: PatType, - }, - Params { - attrs: Vec, - }, - #[default] - NoParam, -} - -#[derive(Default)] -struct ToolAttrs { - fn_item: ToolFnItemAttrs, - params: ToolParams, -} -const TOOL_IDENT: &str = "tool"; -const SERDE_IDENT: &str = "serde"; -const SCHEMARS_IDENT: &str = "schemars"; -const PARAM_IDENT: &str = "param"; -const AGGREGATED_IDENT: &str = "aggr"; -const REQ_IDENT: &str = "req"; - -pub enum ParamMarker { - Param, - Aggregated, -} - -impl Parse for ParamMarker { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let ident: Ident = input.parse()?; - match ident.to_string().as_str() { - PARAM_IDENT => Ok(ParamMarker::Param), - AGGREGATED_IDENT | REQ_IDENT => Ok(ParamMarker::Aggregated), - _ => Err(syn::Error::new(ident.span(), "unknown attribute")), - } - } -} - -pub enum ToolItem { - Fn(ItemFn), - Impl(ItemImpl), -} - -impl Parse for ToolItem { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let lookahead = input.lookahead1(); - if lookahead.peek(Token![impl]) { - let item = input.parse::()?; - Ok(ToolItem::Impl(item)) + } = self; + let description = if let Some(description) = description { + quote! { Some(#description.into()) } } else { - let item = input.parse::()?; - Ok(ToolItem::Fn(item)) - } - } -} - -// dispatch impl function item and impl block item -pub(crate) fn tool(attr: TokenStream, input: TokenStream) -> syn::Result { - let tool_item = syn::parse2::(input)?; - match tool_item { - ToolItem::Fn(item) => tool_fn_item(attr, item), - ToolItem::Impl(item) => tool_impl_item(attr, item), - } -} - -pub(crate) fn tool_impl_item(attr: TokenStream, mut input: ItemImpl) -> syn::Result { - let tool_impl_attr: ToolImplItemAttrs = syn::parse2(attr)?; - let tool_box_ident = tool_impl_attr.tool_box; - - // get all tool function ident - let mut tool_fn_idents = Vec::new(); - for item in &input.items { - if let syn::ImplItem::Fn(method) = item { - for attr in &method.attrs { - if attr.path().is_ident(TOOL_IDENT) { - tool_fn_idents.push(method.sig.ident.clone()); + quote! { None } + }; + let tokens = quote! { + pub fn #fn_ident() -> rmcp::model::Tool { + rmcp::model::Tool { + name: #name.into(), + description: #description, + input_schema: #input_schema, + annotations: #annotations, } } - } + }; + syn::parse2::(tokens) } +} - // handle different cases - if input.trait_.is_some() { - if let Some(ident) = tool_box_ident { - // check if there are generic parameters - if !input.generics.params.is_empty() { - // for trait implementation with generic parameters, directly use the already generated *_inner method - - // generate call_tool method - input.items.push(parse_quote! { - async fn call_tool( - &self, - request: rmcp::model::CallToolRequestParam, - context: rmcp::service::RequestContext, - ) -> Result { - self.call_tool_inner(request, context).await - } - }); - - // generate list_tools method - input.items.push(parse_quote! { - async fn list_tools( - &self, - request: Option, - context: rmcp::service::RequestContext, - ) -> Result { - self.list_tools_inner(request, context).await - } - }); - } else { - // if there are no generic parameters, add tool box derive - input.items.push(parse_quote!( - rmcp::tool_box!(@derive #ident); - )); - } - } else { - return Err(syn::Error::new( - proc_macro2::Span::call_site(), - "tool_box attribute is required for trait implementation", - )); - } - } else if let Some(ident) = tool_box_ident { - // if it is a normal impl block - if !input.generics.params.is_empty() { - // if there are generic parameters, not use tool_box! macro, but generate code directly - - // create call code for each tool function - let match_arms = tool_fn_idents.iter().map(|ident| { - let attr_fn = Ident::new(&format!("{}_tool_attr", ident), ident.span()); - let call_fn = Ident::new(&format!("{}_tool_call", ident), ident.span()); - quote! { - name if name == Self::#attr_fn().name => { - Self::#call_fn(tcc).await - } - } - }); - - let tool_attrs = tool_fn_idents.iter().map(|ident| { - let attr_fn = Ident::new(&format!("{}_tool_attr", ident), ident.span()); - quote! { Self::#attr_fn() } - }); - - // implement call_tool method - input.items.push(parse_quote! { - async fn call_tool_inner( - &self, - request: rmcp::model::CallToolRequestParam, - context: rmcp::service::RequestContext, - ) -> Result { - let tcc = rmcp::handler::server::tool::ToolCallContext::new(self, request, context); - match tcc.name() { - #(#match_arms,)* - _ => Err(rmcp::Error::invalid_params("tool not found", None)), - } - } - }); - - // implement list_tools method - input.items.push(parse_quote! { - async fn list_tools_inner( - &self, - _: Option, - _: rmcp::service::RequestContext, - ) -> Result { - Ok(rmcp::model::ListToolsResult { - next_cursor: None, - tools: vec![#(#tool_attrs),*], - }) - } - }); - } else { - // if there are no generic parameters, use the original tool_box! macro - let this_type_ident = &input.self_ty; - input.items.push(parse_quote!( - rmcp::tool_box!(#this_type_ident { - #(#tool_fn_idents),* - } #ident); - )); - } - } +#[derive(FromMeta, Debug, Default)] +#[darling(default)] +pub struct ToolAnnotationsAttribute { + /// A human-readable title for the tool. + pub title: Option, + + /// If true, the tool does not modify its environment. + /// + /// Default: false + pub read_only_hint: Option, + + /// If true, the tool may perform destructive updates to its environment. + /// If false, the tool performs only additive updates. + /// + /// (This property is meaningful only when `readOnlyHint == false`) + /// + /// Default: true + /// A human-readable description of the tool's purpose. + pub destructive_hint: Option, + + /// If true, calling the tool repeatedly with the same arguments + /// will have no additional effect on the its environment. + /// + /// (This property is meaningful only when `readOnlyHint == false`) + /// + /// Default: false. + pub idempotent_hint: Option, + + /// If true, this tool may interact with an "open world" of external + /// entities. If false, the tool's domain of interaction is closed. + /// For example, the world of a web search tool is open, whereas that + /// of a memory tool is not. + /// + /// Default: true + pub open_world_hint: Option, +} - Ok(quote! { - #input - }) +fn none_expr() -> Expr { + syn::parse2::(quote! { None }).unwrap() } // extract doc line from attribute -fn extract_doc_line(attr: &syn::Attribute) -> Option { +fn extract_doc_line(existing_docs: Option, attr: &syn::Attribute) -> Option { if !attr.path().is_ident("doc") { return None; } @@ -360,391 +108,147 @@ fn extract_doc_line(attr: &syn::Attribute) -> Option { }; let content = lit_str.value().trim().to_string(); - - (!content.is_empty()).then_some(content) -} - -pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Result { - let mut tool_macro_attrs = ToolAttrs::default(); - let args: ToolFnItemAttrs = syn::parse2(attr)?; - tool_macro_attrs.fn_item = args; - // let mut fommated_fn_args: Punctuated = Punctuated::new(); - let mut unextractable_args_indexes = HashSet::new(); - for (index, mut fn_arg) in input_fn.sig.inputs.iter_mut().enumerate() { - enum Caught { - Param(ToolFnParamAttrs), - Aggregated(PatType), - } - let mut caught = None; - match &mut fn_arg { - FnArg::Receiver(_) => { - continue; - } - FnArg::Typed(pat_type) => { - let mut serde_metas = Vec::new(); - let mut schemars_metas = Vec::new(); - let mut arg_ident = match pat_type.pat.as_ref() { - syn::Pat::Ident(pat_ident) => Some(pat_ident.ident.clone()), - _ => None, - }; - let raw_attrs: Vec<_> = pat_type.attrs.drain(..).collect(); - for attr in raw_attrs { - match &attr.meta { - syn::Meta::List(meta_list) => { - if meta_list.path.is_ident(TOOL_IDENT) { - let pat_type = pat_type.clone(); - let marker = meta_list.parse_args::()?; - match marker { - ParamMarker::Param => { - let Some(arg_ident) = arg_ident.take() else { - return Err(syn::Error::new( - proc_macro2::Span::call_site(), - "input param must have an ident as name", - )); - }; - caught.replace(Caught::Param(ToolFnParamAttrs { - serde_meta: Vec::new(), - schemars_meta: Vec::new(), - ident: arg_ident, - rust_type: pat_type.ty.clone(), - })); - } - ParamMarker::Aggregated => { - caught.replace(Caught::Aggregated(pat_type.clone())); - } - } - } else if meta_list.path.is_ident(SERDE_IDENT) { - serde_metas.push(meta_list.clone()); - } else if meta_list.path.is_ident(SCHEMARS_IDENT) { - schemars_metas.push(meta_list.clone()); - } else { - pat_type.attrs.push(attr); - } - } - _ => { - pat_type.attrs.push(attr); - } - } - } - match caught { - Some(Caught::Param(mut param)) => { - param.serde_meta = serde_metas; - param.schemars_meta = schemars_metas; - match &mut tool_macro_attrs.params { - ToolParams::Params { attrs } => { - attrs.push(param); - } - _ => { - tool_macro_attrs.params = ToolParams::Params { attrs: vec![param] }; - } - } - unextractable_args_indexes.insert(index); - } - Some(Caught::Aggregated(rust_type)) => { - if let ToolParams::Params { .. } = tool_macro_attrs.params { - return Err(syn::Error::new( - rust_type.span(), - "cannot mix aggregated and individual parameters", - )); - } - tool_macro_attrs.params = ToolParams::Aggregated { rust_type }; - unextractable_args_indexes.insert(index); - } - None => {} - } - } + match (existing_docs, content) { + (Some(mut existing_docs), content) if !content.is_empty() => { + existing_docs.push('\n'); + existing_docs.push_str(&content); + Some(existing_docs) } + (Some(existing_docs), _) => Some(existing_docs), + (None, content) if !content.is_empty() => Some(content), + _ => None, } +} - // input_fn.sig.inputs = fommated_fn_args; - let name = if let Some(expr) = tool_macro_attrs.fn_item.name { - expr +pub fn tool(attr: TokenStream, input: TokenStream) -> syn::Result { + let attribute = if attr.is_empty() { + Default::default() } else { - let fn_name = &input_fn.sig.ident; - parse_quote! { - stringify!(#fn_name) - } + let attr_args = NestedMeta::parse_meta_list(attr)?; + ToolAttribute::from_list(&attr_args)? }; - let tool_attr_fn_ident = Ident::new( - &format!("{}_tool_attr", input_fn.sig.ident), - proc_macro2::Span::call_site(), - ); + let mut fn_item = syn::parse2::(input.clone())?; + let fn_ident = &fn_item.sig.ident; - // generate get tool attr function - let tool_attr_fn = { - let description = if let Some(expr) = tool_macro_attrs.fn_item.description { - // Use explicitly provided description if available - expr - } else { - // Try to extract documentation comments - let doc_content = input_fn - .attrs - .iter() - .filter_map(extract_doc_line) - .collect::>() - .join("\n"); - - parse_quote! { - #doc_content.trim().to_string() - } - }; - let schema = match &tool_macro_attrs.params { - ToolParams::Aggregated { rust_type } => { - let ty = &rust_type.ty; - let schema = quote! { - rmcp::handler::server::tool::cached_schema_for_type::<#ty>() - }; - schema - } - ToolParams::Params { attrs, .. } => { - let (param_type, temp_param_type_name) = - create_request_type(attrs, input_fn.sig.ident.to_string()); - let schema = quote! { + let tool_attr_fn_ident = format_ident!("{}_tool_attr", fn_ident); + let input_schema_expr = if let Some(input_schema) = attribute.input_schema { + input_schema + } else { + // try to find some parameters wrapper in the function + let params_ty = fn_item.sig.inputs.iter().find_map(|input| { + if let syn::FnArg::Typed(pat_type) = input { + if let syn::Type::Path(type_path) = &*pat_type.ty { + if type_path + .path + .segments + .last() + .is_some_and(|type_name| type_name.ident == "Parameters") { - #param_type - rmcp::handler::server::tool::cached_schema_for_type::<#temp_param_type_name>() + return Some(pat_type.ty.clone()); } - }; - schema - } - ToolParams::NoParam => { - quote! { - rmcp::handler::server::tool::cached_schema_for_type::() } } - }; - let input_fn_attrs = &input_fn.attrs; - let input_fn_vis = &input_fn.vis; - - let annotations_code = if let Some(annotations) = &tool_macro_attrs.fn_item.annotations { - let annotations = - serde_json::to_string(&annotations.0).expect("failed to serialize annotations"); - quote! { - Some(serde_json::from_str::(&#annotations).expect("Could not parse tool annotations")) - } + None + }); + if let Some(params_ty) = params_ty { + // if found, use the Parameters schema + syn::parse2::(quote! { + rmcp::handler::server::tool::cached_schema_for_type::<#params_ty>() + })? } else { - quote! { None } - }; - - quote! { - #(#input_fn_attrs)* - #input_fn_vis fn #tool_attr_fn_ident() -> rmcp::model::Tool { - rmcp::model::Tool { - name: #name.into(), - description: Some(#description.into()), - input_schema: #schema.into(), - annotations: #annotations_code, - } - } + // if not found, use the default EmptyObject schema + syn::parse2::(quote! { + rmcp::handler::server::tool::cached_schema_for_type::() + })? } }; - - // generate wrapped tool function - let tool_call_fn = { - // wrapper function have the same sig: - // async fn #tool_tool_call(context: rmcp::handler::server::tool::ToolCallContext<'_, Self>) - // -> std::result::Result - // - // and the block part should be like: - // { - // use rmcp::handler::server::tool::*; - // let (t0, context) = ::from_tool_call_context_part(context)?; - // let (t1, context) = ::from_tool_call_context_part(context)?; - // ... - // let (tn, context) = ::from_tool_call_context_part(context)?; - // // for params - // ... expand helper types here - // let (__rmcp_tool_req, context) = rmcp::model::JsonObject::from_tool_call_context_part(context)?; - // let __#TOOL_ToolCallParam { param_0, param_1, param_2, .. } = parse_json_object(__rmcp_tool_req)?; - // // for aggr - // let (Parameters(aggr), context) = >::from_tool_call_context_part(context)?; - // Self::#tool_ident(to, param_0, t1, param_1, ..., param_2, tn, aggr).await.into_call_tool_result() - // - // } - // - // - // - - // for receiver type, name it as __rmcp_tool_receiver - let is_async = input_fn.sig.asyncness.is_some(); - let receiver_ident = || Ident::new("__rmcp_tool_receiver", proc_macro2::Span::call_site()); - // generate the extraction part for trivial args - let trivial_args = input_fn - .sig - .inputs - .iter() - .enumerate() - .filter_map(|(index, arg)| { - if unextractable_args_indexes.contains(&index) { - None - } else { - // get ident/type pair - let line = match arg { - FnArg::Typed(pat_type) => { - let pat = &pat_type.pat; - let ty = &pat_type.ty; - quote! { - let (#pat, context) = <#ty>::from_tool_call_context_part(context)?; - } - } - FnArg::Receiver(r) => { - let ty = r.ty.clone(); - let pat = receiver_ident(); - quote! { - let (#pat, context) = <#ty>::from_tool_call_context_part(context)?; - } - } - }; - Some(line) - } - }); - let trivial_arg_extraction_part = quote! { - #(#trivial_args)* - }; - let processed_arg_extraction_part = match &mut tool_macro_attrs.params { - ToolParams::Aggregated { rust_type } => { - let PatType { pat, ty, .. } = rust_type; - quote! { - let (Parameters(#pat), context) = >::from_tool_call_context_part(context)?; - } - } - ToolParams::Params { attrs } => { - let (param_type, temp_param_type_name) = - create_request_type(attrs, input_fn.sig.ident.to_string()); - - let params_ident = attrs.iter().map(|attr| &attr.ident).collect::>(); - quote! { - #param_type - let (__rmcp_tool_req, context) = rmcp::model::JsonObject::from_tool_call_context_part(context)?; - let #temp_param_type_name { - #(#params_ident,)* - } = parse_json_object(__rmcp_tool_req)?; - } - } - ToolParams::NoParam => { - quote! {} - } + let annotations_expr = if let Some(annotations) = attribute.annotations { + let ToolAnnotationsAttribute { + title, + read_only_hint, + destructive_hint, + idempotent_hint, + open_world_hint, + } = annotations; + fn wrap_option(x: Option) -> TokenStream { + x.map(|x| quote! {Some(#x.into())}) + .unwrap_or(quote! { None }) + } + let title = wrap_option(title); + let read_only_hint = wrap_option(read_only_hint); + let destructive_hint = wrap_option(destructive_hint); + let idempotent_hint = wrap_option(idempotent_hint); + let open_world_hint = wrap_option(open_world_hint); + let token_stream = quote! { + Some(rmcp::model::ToolAnnotations { + title: #title, + read_only_hint: #read_only_hint, + destructive_hint: #destructive_hint, + idempotent_hint: #idempotent_hint, + open_world_hint: #open_world_hint, + }) }; - // generate the execution part - // has receiver? - let params = &input_fn - .sig - .inputs - .iter() - .map(|fn_arg| match fn_arg { - FnArg::Receiver(_) => { - let pat = receiver_ident(); - quote! { #pat } + syn::parse2::(token_stream)? + } else { + none_expr() + }; + let resolved_tool_attr = ResolvedToolAttribute { + name: attribute.name.unwrap_or_else(|| fn_ident.to_string()), + description: attribute + .description + .or_else(|| fn_item.attrs.iter().fold(None, extract_doc_line)), + input_schema: input_schema_expr, + annotations: annotations_expr, + }; + let tool_attr_fn = resolved_tool_attr.into_fn(tool_attr_fn_ident)?; + // modify the the input function + if fn_item.sig.asyncness.is_some() { + // 1. remove asyncness from sig + // 2. make return type: `std::pin::Pin + Send + '_>>` + // 3. make body: { Box::pin(async move { #body }) } + let new_output = syn::parse2::({ + match &fn_item.sig.output { + syn::ReturnType::Default => { + quote! { -> std::pin::Pin + Send + '_>> } } - FnArg::Typed(pat_type) => { - let pat = &pat_type.pat.clone(); - quote! { #pat } + syn::ReturnType::Type(_, ty) => { + quote! { -> std::pin::Pin + Send + '_>> } } - }) - .collect::>(); - let raw_fn_ident = &input_fn.sig.ident; - let call = if is_async { - quote! { - Self::#raw_fn_ident(#(#params),*).await.into_call_tool_result() - } - } else { - quote! { - Self::#raw_fn_ident(#(#params),*).into_call_tool_result() - } - }; - // assemble the whole function - let tool_call_fn_ident = Ident::new( - &format!("{}_tool_call", input_fn.sig.ident), - proc_macro2::Span::call_site(), - ); - let raw_fn_vis = tool_macro_attrs - .fn_item - .vis - .as_ref() - .unwrap_or(&input_fn.vis); - let raw_fn_attr = &input_fn - .attrs - .iter() - .filter(|attr| !attr.path().is_ident(TOOL_IDENT)) - .collect::>(); - quote! { - #(#raw_fn_attr)* - #raw_fn_vis async fn #tool_call_fn_ident(context: rmcp::handler::server::tool::ToolCallContext<'_, Self>) - -> std::result::Result { - use rmcp::handler::server::tool::*; - #trivial_arg_extraction_part - #processed_arg_extraction_part - #call } - } - }; + })?; + let prev_block = &fn_item.block; + let new_block = syn::parse2::(quote! { + { Box::pin(async move #prev_block ) } + })?; + fn_item.sig.asyncness = None; + fn_item.sig.output = new_output; + fn_item.block = new_block; + } Ok(quote! { #tool_attr_fn - #tool_call_fn - #input_fn + #fn_item }) } -fn create_request_type(attrs: &[ToolFnParamAttrs], tool_name: String) -> (TokenStream, Ident) { - let pascal_case_tool_name = tool_name.to_ascii_uppercase(); - let temp_param_type_name = Ident::new( - &format!("__{pascal_case_tool_name}ToolCallParam",), - proc_macro2::Span::call_site(), - ); - ( - quote! { - use rmcp::{serde, schemars}; - #[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)] - pub struct #temp_param_type_name { - #(#attrs)* - } - }, - temp_param_type_name, - ) -} - #[cfg(test)] mod test { use super::*; - #[test] - fn test_tool_sync_macro() -> syn::Result<()> { - let attr = quote! { - name = "test_tool", - description = "test tool", - vis = - }; - let input = quote! { - fn sum(&self, #[tool(aggr)] req: StructRequest) -> Result { - Ok(CallToolResult::success(vec![Content::text((req.a + req.b).to_string())])) - } - }; - let input = tool(attr, input)?; - - println!("input: {:#}", input); - Ok(()) - } - #[test] fn test_trait_tool_macro() -> syn::Result<()> { let attr = quote! { - tool_box = Calculator + name = "direct-annotated-tool", + annotations(title = "Annotated Tool", read_only_hint = true) }; let input = quote! { - impl ServerHandler for Calculator { - #[tool] - fn get_info(&self) -> ServerInfo { - ServerInfo { - instructions: Some("A simple calculator".into()), - ..Default::default() - } - } + async fn async_method(&self, Parameters(Request { fields }): Parameters) { + drop(fields) } }; - let input = tool(attr, input)?; + let _input = tool(attr, input)?; - println!("input: {:#}", input); Ok(()) } + #[test] fn test_doc_comment_description() -> syn::Result<()> { let attr = quote! {}; // No explicit description @@ -764,6 +268,7 @@ mod test { Ok(()) } + #[test] fn test_explicit_description_priority() -> syn::Result<()> { let attr = quote! { diff --git a/crates/rmcp-macros/src/tool_handler.rs b/crates/rmcp-macros/src/tool_handler.rs new file mode 100644 index 00000000..c6b168e7 --- /dev/null +++ b/crates/rmcp-macros/src/tool_handler.rs @@ -0,0 +1,51 @@ +use darling::{FromMeta, ast::NestedMeta}; +use proc_macro2::TokenStream; +use quote::{ToTokens, quote}; +use syn::{Expr, ImplItem, ItemImpl}; + +#[derive(FromMeta)] +#[darling(default)] +pub struct ToolHandlerAttribute { + pub router: Expr, +} + +impl Default for ToolHandlerAttribute { + fn default() -> Self { + Self { + router: syn::parse2(quote! { + self.tool_router + }) + .unwrap(), + } + } +} + +pub fn tool_handler(attr: TokenStream, input: TokenStream) -> syn::Result { + let attr_args = NestedMeta::parse_meta_list(attr)?; + let ToolHandlerAttribute { router } = ToolHandlerAttribute::from_list(&attr_args)?; + let mut item_impl = syn::parse2::(input.clone())?; + let tool_call_fn = quote! { + async fn call_tool( + &self, + request: rmcp::model::CallToolRequestParam, + context: rmcp::service::RequestContext, + ) -> Result { + let tcc = rmcp::handler::server::tool::ToolCallContext::new(self, request, context); + #router.call(tcc).await + } + }; + let tool_list_fn = quote! { + async fn list_tools( + &self, + _request: Option, + _context: rmcp::service::RequestContext, + ) -> Result { + Ok(rmcp::model::ListToolsResult::with_all_items(#router.list_all())) + } + }; + let tool_call_fn = syn::parse2::(tool_call_fn)?; + let tool_list_fn = syn::parse2::(tool_list_fn)?; + item_impl.items.push(tool_call_fn); + item_impl.items.push(tool_list_fn); + Ok(item_impl.into_token_stream()) +} diff --git a/crates/rmcp-macros/src/tool_router.rs b/crates/rmcp-macros/src/tool_router.rs new file mode 100644 index 00000000..f6bc0edd --- /dev/null +++ b/crates/rmcp-macros/src/tool_router.rs @@ -0,0 +1,90 @@ +//! ```ignore +//! #[rmcp::tool_router(router)] +//! impl Handler { +//! +//! } +//! ``` +//! + +use darling::{FromMeta, ast::NestedMeta}; +use proc_macro2::TokenStream; +use quote::{ToTokens, format_ident, quote}; +use syn::{Ident, ImplItem, ItemImpl, Visibility}; + +#[derive(FromMeta)] +#[darling(default)] +pub struct ToolRouterAttribute { + pub router: Ident, + pub vis: Option, +} + +impl Default for ToolRouterAttribute { + fn default() -> Self { + Self { + router: format_ident!("tool_router"), + vis: None, + } + } +} + +pub fn tool_router(attr: TokenStream, input: TokenStream) -> syn::Result { + let attr_args = NestedMeta::parse_meta_list(attr)?; + let ToolRouterAttribute { router, vis } = ToolRouterAttribute::from_list(&attr_args)?; + let mut item_impl = syn::parse2::(input.clone())?; + // find all function marked with `#[rmcp::tool]` + let tool_attr_fns: Vec<_> = item_impl + .items + .iter() + .filter_map(|item| { + if let syn::ImplItem::Fn(fn_item) = item { + fn_item + .attrs + .iter() + .any(|attr| { + attr.path() + .segments + .last() + .is_some_and(|seg| seg.ident == "tool") + }) + .then_some(&fn_item.sig.ident) + } else { + None + } + }) + .collect(); + let mut routers = vec![]; + for handler in tool_attr_fns { + let tool_attr_fn_ident = format_ident!("{handler}_tool_attr"); + routers.push(quote! { + .with_route((Self::#tool_attr_fn_ident(), Self::#handler)) + }) + } + let router_fn = syn::parse2::(quote! { + #vis fn #router() -> rmcp::handler::server::router::tool::ToolRouter { + rmcp::handler::server::router::tool::ToolRouter::::new() + #(#routers)* + } + })?; + item_impl.items.push(router_fn); + Ok(item_impl.into_token_stream()) +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn test_router_attr() -> Result<(), Box> { + let attr = quote! { + router = test_router, + }; + let attr_args = NestedMeta::parse_meta_list(attr)?; + let ToolRouterAttribute { router, vis } = ToolRouterAttribute::from_list(&attr_args)?; + println!("router: {}", router); + if let Some(vis) = vis { + println!("visibility: {}", vis.to_token_stream()); + } else { + println!("visibility: None"); + } + Ok(()) + } +} diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 9754f12f..999d4ee5 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -63,7 +63,6 @@ http-body-util = { version = "0.1", optional = true } bytes = { version = "1", optional = true } # macro rmcp-macros = { version = "0.1", workspace = true, optional = true } - [target.'cfg(not(all(target_family = "wasm", target_os = "unknown")))'.dependencies] chrono = { version = "0.4.38", features = ["serde"] } diff --git a/crates/rmcp/README.md b/crates/rmcp/README.md index b57c9814..7c93dd24 100644 --- a/crates/rmcp/README.md +++ b/crates/rmcp/README.md @@ -15,20 +15,22 @@ wait for the first release. Creating a server with tools is simple using the `#[tool]` macro: ```rust, ignore -use rmcp::{Error as McpError, ServiceExt, model::*, tool, transport::stdio}; +use rmcp::{Error as McpError, ServiceExt, model::*, tool, tool_router, transport::stdio, handler::server::tool::ToolCallContext, handler::server::router::tool::ToolRouter}; use std::sync::Arc; use tokio::sync::Mutex; #[derive(Clone)] pub struct Counter { counter: Arc>, + tool_router: ToolRouter, } -#[tool(tool_box)] +#[tool_router] impl Counter { fn new() -> Self { Self { counter: Arc::new(Mutex::new(0)), + tool_router: Self::tool_router(), } } @@ -51,7 +53,7 @@ impl Counter { } // Implement the server handler -#[tool(tool_box)] +#[tool_handler] impl rmcp::ServerHandler for Counter { fn get_info(&self) -> ServerInfo { ServerInfo { diff --git a/crates/rmcp/src/handler/server.rs b/crates/rmcp/src/handler/server.rs index 52b63832..ed74a807 100644 --- a/crates/rmcp/src/handler/server.rs +++ b/crates/rmcp/src/handler/server.rs @@ -5,6 +5,7 @@ use crate::{ }; mod resource; +pub mod router; pub mod tool; pub mod wrapper; impl Service for H { diff --git a/crates/rmcp/src/handler/server/router.rs b/crates/rmcp/src/handler/server/router.rs new file mode 100644 index 00000000..ea972f59 --- /dev/null +++ b/crates/rmcp/src/handler/server/router.rs @@ -0,0 +1,96 @@ +use std::sync::Arc; + +use tool::{IntoToolRoute, ToolRoute}; + +use super::ServerHandler; +use crate::{ + RoleServer, Service, + model::{ClientRequest, ListToolsResult, ServerResult}, + service::NotificationContext, +}; + +pub mod tool; + +pub struct Router { + pub tool_router: tool::ToolRouter, + pub service: Arc, +} + +impl Router +where + S: ServerHandler, +{ + pub fn new(service: S) -> Self { + Self { + tool_router: tool::ToolRouter::new(), + service: Arc::new(service), + } + } + + pub fn with_tool(mut self, route: R) -> Self + where + R: IntoToolRoute, + { + self.tool_router.add_route(route.into_tool_route()); + self + } + + pub fn with_tools(mut self, routes: impl IntoIterator>) -> Self { + for route in routes { + self.tool_router.add_route(route); + } + self + } +} + +impl Service for Router +where + S: ServerHandler, +{ + async fn handle_notification( + &self, + notification: ::PeerNot, + context: NotificationContext, + ) -> Result<(), crate::Error> { + self.service + .handle_notification(notification, context) + .await + } + async fn handle_request( + &self, + request: ::PeerReq, + context: crate::service::RequestContext, + ) -> Result<::Resp, crate::Error> { + match request { + ClientRequest::CallToolRequest(request) => { + if self.tool_router.has_route(request.params.name.as_ref()) + || !self.tool_router.transparent_when_not_found + { + let tool_call_context = crate::handler::server::tool::ToolCallContext::new( + self.service.as_ref(), + request.params, + context, + ); + let result = self.tool_router.call(tool_call_context).await?; + Ok(ServerResult::CallToolResult(result)) + } else { + self.service + .handle_request(ClientRequest::CallToolRequest(request), context) + .await + } + } + ClientRequest::ListToolsRequest(_) => { + let tools = self.tool_router.list_all(); + Ok(ServerResult::ListToolsResult(ListToolsResult { + tools, + next_cursor: None, + })) + } + rest => self.service.handle_request(rest, context).await, + } + } + + fn get_info(&self) -> ::Info { + self.service.get_info() + } +} diff --git a/crates/rmcp/src/handler/server/router/prompt.rs b/crates/rmcp/src/handler/server/router/prompt.rs new file mode 100644 index 00000000..e69de29b diff --git a/crates/rmcp/src/handler/server/router/tool.rs b/crates/rmcp/src/handler/server/router/tool.rs new file mode 100644 index 00000000..bb5fda21 --- /dev/null +++ b/crates/rmcp/src/handler/server/router/tool.rs @@ -0,0 +1,272 @@ +use std::{borrow::Cow, sync::Arc}; + +use futures::{FutureExt, future::BoxFuture}; +use schemars::JsonSchema; + +use crate::{ + handler::server::tool::{ + CallToolHandler, DynCallToolHandler, ToolCallContext, schema_for_type, + }, + model::{CallToolResult, Tool, ToolAnnotations}, +}; + +pub struct ToolRoute { + #[allow(clippy::type_complexity)] + pub call: Arc>, + pub attr: crate::model::Tool, +} + +impl std::fmt::Debug for ToolRoute { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ToolRoute") + .field("name", &self.attr.name) + .field("description", &self.attr.description) + .field("input_schema", &self.attr.input_schema) + .finish() + } +} + +impl Clone for ToolRoute { + fn clone(&self) -> Self { + Self { + call: self.call.clone(), + attr: self.attr.clone(), + } + } +} + +impl ToolRoute { + pub fn new(attr: impl Into, call: C) -> Self + where + C: CallToolHandler + Send + Sync + Clone + 'static, + { + Self { + call: Arc::new(move |context: ToolCallContext| { + let call = call.clone(); + context.invoke(call).boxed() + }), + attr: attr.into(), + } + } + pub fn new_dyn(attr: impl Into, call: C) -> Self + where + C: for<'a> Fn( + ToolCallContext<'a, S>, + ) -> BoxFuture<'a, Result> + + Send + + Sync + + 'static, + { + Self { + call: Arc::new(call), + attr: attr.into(), + } + } + pub fn name(&self) -> &str { + &self.attr.name + } +} + +pub trait IntoToolRoute { + fn into_tool_route(self) -> ToolRoute; +} + +impl IntoToolRoute for (T, C) +where + S: Send + Sync + 'static, + C: CallToolHandler + Send + Sync + Clone + 'static, + T: Into, +{ + fn into_tool_route(self) -> ToolRoute { + ToolRoute::new(self.0.into(), self.1) + } +} + +impl IntoToolRoute for ToolRoute +where + S: Send + Sync + 'static, +{ + fn into_tool_route(self) -> ToolRoute { + self + } +} + +pub struct ToolAttrGenerateFunctionAdapter; +impl IntoToolRoute for F +where + S: Send + Sync + 'static, + F: Fn() -> ToolRoute, +{ + fn into_tool_route(self) -> ToolRoute { + (self)() + } +} + +pub trait CallToolHandlerExt: Sized +where + Self: CallToolHandler + Send + Sync + Clone + 'static, +{ + fn name(self, name: impl Into>) -> WithToolAttr; +} + +impl CallToolHandlerExt for C +where + C: CallToolHandler + Send + Sync + Clone + 'static, +{ + fn name(self, name: impl Into>) -> WithToolAttr { + WithToolAttr { + attr: Tool::new( + name.into(), + "", + schema_for_type::(), + ), + call: self, + _marker: std::marker::PhantomData, + } + } +} + +pub struct WithToolAttr +where + C: CallToolHandler + Send + Sync + Clone + 'static, +{ + pub attr: crate::model::Tool, + pub call: C, + pub _marker: std::marker::PhantomData, +} + +impl IntoToolRoute for WithToolAttr +where + C: CallToolHandler + Send + Sync + Clone + 'static, + S: Send + Sync + 'static, +{ + fn into_tool_route(self) -> ToolRoute { + ToolRoute::new(self.attr, self.call) + } +} + +impl WithToolAttr +where + C: CallToolHandler + Send + Sync + Clone + 'static, +{ + pub fn description(mut self, description: impl Into>) -> Self { + self.attr.description = Some(description.into()); + self + } + pub fn parameters(mut self) -> Self { + self.attr.input_schema = schema_for_type::().into(); + self + } + pub fn parameters_value(mut self, schema: serde_json::Value) -> Self { + self.attr.input_schema = crate::model::object(schema).into(); + self + } + pub fn annotation(mut self, annotation: impl Into) -> Self { + self.attr.annotations = Some(annotation.into()); + self + } +} +#[derive(Debug)] +pub struct ToolRouter { + #[allow(clippy::type_complexity)] + pub map: std::collections::HashMap, ToolRoute>, + + pub transparent_when_not_found: bool, +} + +impl Default for ToolRouter { + fn default() -> Self { + Self { + map: std::collections::HashMap::new(), + transparent_when_not_found: false, + } + } +} +impl Clone for ToolRouter { + fn clone(&self) -> Self { + Self { + map: self.map.clone(), + transparent_when_not_found: self.transparent_when_not_found, + } + } +} + +impl IntoIterator for ToolRouter { + type Item = ToolRoute; + type IntoIter = std::collections::hash_map::IntoValues, ToolRoute>; + + fn into_iter(self) -> Self::IntoIter { + self.map.into_values() + } +} + +impl ToolRouter +where + S: Send + Sync + 'static, +{ + pub fn new() -> Self { + Self { + map: std::collections::HashMap::new(), + transparent_when_not_found: false, + } + } + pub fn with_route(mut self, route: R) -> Self + where + R: IntoToolRoute, + { + self.add_route(route.into_tool_route()); + self + } + + pub fn add_route(&mut self, item: ToolRoute) { + self.map.insert(item.attr.name.clone(), item); + } + + pub fn merge(&mut self, other: ToolRouter) { + for item in other.map.into_values() { + self.add_route(item); + } + } + + pub fn remove_route(&mut self, name: &str) { + self.map.remove(name); + } + pub fn has_route(&self, name: &str) -> bool { + self.map.contains_key(name) + } + pub async fn call( + &self, + context: ToolCallContext<'_, S>, + ) -> Result { + let item = self + .map + .get(context.name()) + .ok_or_else(|| crate::Error::invalid_params("tool not found", None))?; + (item.call)(context).await + } + + pub fn list_all(&self) -> Vec { + self.map.values().map(|item| item.attr.clone()).collect() + } +} + +impl std::ops::Add> for ToolRouter +where + S: Send + Sync + 'static, +{ + type Output = Self; + + fn add(mut self, other: ToolRouter) -> Self::Output { + self.merge(other); + self + } +} + +impl std::ops::AddAssign> for ToolRouter +where + S: Send + Sync + 'static, +{ + fn add_assign(&mut self, other: ToolRouter) { + self.merge(other); + } +} diff --git a/crates/rmcp/src/handler/server/tool.rs b/crates/rmcp/src/handler/server/tool.rs index ec0f38e4..9fbe5b94 100644 --- a/crates/rmcp/src/handler/server/tool.rs +++ b/crates/rmcp/src/handler/server/tool.rs @@ -2,16 +2,18 @@ use std::{ any::TypeId, borrow::Cow, collections::HashMap, future::Ready, marker::PhantomData, sync::Arc, }; -use futures::future::BoxFuture; +use futures::future::{BoxFuture, FutureExt}; use schemars::JsonSchema; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use tokio_util::sync::CancellationToken; +pub use super::router::tool::{ToolRoute, ToolRouter}; use crate::{ RoleServer, - model::{CallToolRequestParam, CallToolResult, ConstString, IntoContents, JsonObject}, + model::{CallToolRequestParam, CallToolResult, IntoContents, JsonObject}, service::RequestContext, }; + /// A shortcut for generating a JSON schema for a type. pub fn schema_for_type() -> JsonObject { // explicitly to align json schema version to official specifications. @@ -62,16 +64,16 @@ pub fn parse_json_object(input: JsonObject) -> Result { - request_context: RequestContext, - service: &'service S, - name: Cow<'static, str>, - arguments: Option, +pub struct ToolCallContext<'s, S> { + pub request_context: RequestContext, + pub service: &'s S, + pub name: Cow<'static, str>, + pub arguments: Option, } -impl<'service, S> ToolCallContext<'service, S> { +impl<'s, S> ToolCallContext<'s, S> { pub fn new( - service: &'service S, + service: &'s S, CallToolRequestParam { name, arguments }: CallToolRequestParam, request_context: RequestContext, ) -> Self { @@ -90,10 +92,8 @@ impl<'service, S> ToolCallContext<'service, S> { } } -pub trait FromToolCallContextPart<'a, S>: Sized { - fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error>; +pub trait FromToolCallContextPart: Sized { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result; } pub trait IntoCallToolResult { @@ -161,16 +161,16 @@ impl IntoCallToolResult for Result { } } -pub trait CallToolHandler<'a, S, A> { - type Fut: Future> + Send + 'a; - fn call(self, context: ToolCallContext<'a, S>) -> Self::Fut; +pub trait CallToolHandler { + fn call( + self, + context: ToolCallContext<'_, S>, + ) -> BoxFuture<'_, Result>; } -pub type DynCallToolHandler = dyn Fn(ToolCallContext<'_, S>) -> BoxFuture<'_, Result> +pub type DynCallToolHandler = dyn for<'s> Fn(ToolCallContext<'s, S>) -> BoxFuture<'s, Result> + Send + Sync; -/// Parameter Extractor -pub struct Parameter(pub K, pub V); /// Parameter Extractor /// @@ -188,83 +188,25 @@ impl JsonSchema for Parameters

{ } } -/// Callee Extractor -pub struct Callee<'a, S>(pub &'a S); - -impl<'a, S> FromToolCallContextPart<'a, S> for CancellationToken { - fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { - Ok((context.request_context.ct.clone(), context)) - } -} - -impl<'a, S> FromToolCallContextPart<'a, S> for Callee<'a, S> { - fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { - Ok((Callee(context.service), context)) +impl FromToolCallContextPart for CancellationToken { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result { + Ok(context.request_context.ct.clone()) } } pub struct ToolName(pub Cow<'static, str>); -impl<'a, S> FromToolCallContextPart<'a, S> for ToolName { - fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { - Ok((Self(context.name.clone()), context)) - } -} - -impl<'a, S> FromToolCallContextPart<'a, S> for &'a S { - fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { - Ok((context.service, context)) - } -} - -impl<'a, S, K, V> FromToolCallContextPart<'a, S> for Parameter -where - K: ConstString, - V: DeserializeOwned, -{ - fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { - let arguments = context - .arguments - .as_ref() - .ok_or(crate::Error::invalid_params( - format!("missing parameter {field}", field = K::VALUE), - None, - ))?; - let value = arguments.get(K::VALUE).ok_or(crate::Error::invalid_params( - format!("missing parameter {field}", field = K::VALUE), - None, - ))?; - let value: V = serde_json::from_value(value.clone()).map_err(|e| { - crate::Error::invalid_params( - format!( - "failed to deserialize parameter {field}: {error}", - field = K::VALUE, - error = e - ), - None, - ) - })?; - Ok((Parameter(K::default(), value), context)) +impl FromToolCallContextPart for ToolName { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result { + Ok(Self(context.name.clone())) } } -impl<'a, S, P> FromToolCallContextPart<'a, S> for Parameters

+impl FromToolCallContextPart for Parameters

where P: DeserializeOwned, { - fn from_tool_call_context_part( - mut context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result { let arguments = context.arguments.take().unwrap_or_default(); let value: P = serde_json::from_value(serde_json::Value::Object(arguments)).map_err(|e| { @@ -273,37 +215,31 @@ where None, ) })?; - Ok((Parameters(value), context)) + Ok(Parameters(value)) } } -impl<'a, S> FromToolCallContextPart<'a, S> for JsonObject { - fn from_tool_call_context_part( - mut context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { +impl FromToolCallContextPart for JsonObject { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result { let object = context.arguments.take().unwrap_or_default(); - Ok((object, context)) + Ok(object) } } -impl<'a, S> FromToolCallContextPart<'a, S> for crate::model::Extensions { - fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { +impl FromToolCallContextPart for crate::model::Extensions { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result { let extensions = context.request_context.extensions.clone(); - Ok((extensions, context)) + Ok(extensions) } } pub struct Extension(pub T); -impl<'a, S, T> FromToolCallContextPart<'a, S> for Extension +impl FromToolCallContextPart for Extension where T: Send + Sync + 'static + Clone, { - fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result { let extension = context .request_context .extensions @@ -315,58 +251,52 @@ where None, ) })?; - Ok((Extension(extension), context)) + Ok(Extension(extension)) } } -impl<'a, S> FromToolCallContextPart<'a, S> for crate::Peer { - fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { +impl FromToolCallContextPart for crate::Peer { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result { let peer = context.request_context.peer.clone(); - Ok((peer, context)) + Ok(peer) } } -impl<'a, S> FromToolCallContextPart<'a, S> for crate::model::Meta { - fn from_tool_call_context_part( - mut context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { +impl FromToolCallContextPart for crate::model::Meta { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result { let mut meta = crate::model::Meta::default(); std::mem::swap(&mut meta, &mut context.request_context.meta); - Ok((meta, context)) + Ok(meta) } } pub struct RequestId(pub crate::model::RequestId); -impl<'a, S> FromToolCallContextPart<'a, S> for RequestId { - fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { - Ok((RequestId(context.request_context.id.clone()), context)) +impl FromToolCallContextPart for RequestId { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result { + Ok(RequestId(context.request_context.id.clone())) } } -impl<'a, S> FromToolCallContextPart<'a, S> for RequestContext { - fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { - Ok((context.request_context.clone(), context)) +impl FromToolCallContextPart for RequestContext { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result { + Ok(context.request_context.clone()) } } impl<'s, S> ToolCallContext<'s, S> { - pub fn invoke(self, h: H) -> H::Fut + pub fn invoke(self, h: H) -> BoxFuture<'s, Result> where - H: CallToolHandler<'s, S, A>, + H: CallToolHandler, { h.call(self) } } - #[allow(clippy::type_complexity)] -pub struct AsyncAdapter(PhantomData<(fn(P) -> Fut, fn(Fut) -> R)>); +pub struct AsyncAdapter(PhantomData fn(Fut) -> R>); pub struct SyncAdapter(PhantomData R>); +// #[allow(clippy::type_complexity)] +pub struct AsyncMethodAdapter(PhantomData R>); +pub struct SyncMethodAdapter(PhantomData R>); macro_rules! impl_for { ($($T: ident)*) => { @@ -381,174 +311,118 @@ macro_rules! impl_for { impl_for!([$($Tn)* $Tn_1] [$($Rest)*]); }; (@impl $($Tn: ident)*) => { - impl<'s, $($Tn,)* S, F, Fut, R> CallToolHandler<'s, S, AsyncAdapter<($($Tn,)*), Fut, R>> for F + impl<$($Tn,)* S, F, R> CallToolHandler> for F where $( - $Tn: FromToolCallContextPart<'s, S> + 's, + $Tn: FromToolCallContextPart , )* - F: FnOnce($($Tn,)*) -> Fut + Send + 's, - Fut: Future + Send + 's, - R: IntoCallToolResult + Send + 's, - S: Send + Sync, + F: FnOnce(&S, $($Tn,)*) -> BoxFuture<'_, R>, + + // Need RTN support here(I guess), https://github.com/rust-lang/rust/pull/138424 + // Fut: Future + Send + 'a, + R: IntoCallToolResult + Send + 'static, + S: Send + Sync + 'static, { - type Fut = IntoCallToolResultFut; - #[allow(unused_variables, non_snake_case)] + #[allow(unused_variables, non_snake_case, unused_mut)] fn call( self, - context: ToolCallContext<'s, S>, - ) -> Self::Fut { + mut context: ToolCallContext<'_, S>, + ) -> BoxFuture<'_, Result>{ $( - let result = $Tn::from_tool_call_context_part(context); - let ($Tn, context) = match result { - Ok((value, context)) => (value, context), - Err(e) => return IntoCallToolResultFut::Ready { - result: std::future::ready(Err(e)), - }, + let result = $Tn::from_tool_call_context_part(&mut context); + let $Tn = match result { + Ok(value) => value, + Err(e) => return std::future::ready(Err(e)).boxed(), }; )* - IntoCallToolResultFut::Pending { - fut: self($($Tn,)*), - _marker: PhantomData - } + let service = context.service; + let fut = self(service, $($Tn,)*); + async move { + let result = fut.await; + result.into_call_tool_result() + }.boxed() } } - impl<'s, $($Tn,)* S, F, R> CallToolHandler<'s, S, SyncAdapter<($($Tn,)*), R>> for F + impl<$($Tn,)* S, F, Fut, R> CallToolHandler> for F where $( - $Tn: FromToolCallContextPart<'s, S> + 's, + $Tn: FromToolCallContextPart , )* - F: FnOnce($($Tn,)*) -> R + Send + 's, - R: IntoCallToolResult + Send + 's, + F: FnOnce($($Tn,)*) -> Fut + Send + , + Fut: Future + Send + 'static, + R: IntoCallToolResult + Send + 'static, S: Send + Sync, { - type Fut = Ready>; - #[allow(unused_variables, non_snake_case)] + #[allow(unused_variables, non_snake_case, unused_mut)] fn call( self, - context: ToolCallContext<'s, S>, - ) -> Self::Fut { + mut context: ToolCallContext, + ) -> BoxFuture<'static, Result>{ $( - let result = $Tn::from_tool_call_context_part(context); - let ($Tn, context) = match result { - Ok((value, context)) => (value, context), - Err(e) => return std::future::ready(Err(e)), + let result = $Tn::from_tool_call_context_part(&mut context); + let $Tn = match result { + Ok(value) => value, + Err(e) => return std::future::ready(Err(e)).boxed(), }; )* - std::future::ready(self($($Tn,)*).into_call_tool_result()) + let fut = self($($Tn,)*); + async move { + let result = fut.await; + result.into_call_tool_result() + }.boxed() } } - }; -} -impl_for!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); -pub struct ToolBoxItem { - #[allow(clippy::type_complexity)] - pub call: Box>, - pub attr: crate::model::Tool, -} -impl ToolBoxItem { - pub fn new(attr: crate::model::Tool, call: C) -> Self - where - C: Fn(ToolCallContext<'_, S>) -> BoxFuture<'_, Result> - + Send - + Sync - + 'static, - { - Self { - call: Box::new(call), - attr, - } - } - pub fn name(&self) -> &str { - &self.attr.name - } -} - -#[derive(Default)] -pub struct ToolBox { - #[allow(clippy::type_complexity)] - pub map: std::collections::HashMap, ToolBoxItem>, -} - -impl ToolBox { - pub fn new() -> Self { - Self { - map: std::collections::HashMap::new(), + impl<$($Tn,)* S, F, R> CallToolHandler> for F + where + $( + $Tn: FromToolCallContextPart + , + )* + F: FnOnce(&S, $($Tn,)*) -> R + Send + , + R: IntoCallToolResult + Send + , + S: Send + Sync, + { + #[allow(unused_variables, non_snake_case, unused_mut)] + fn call( + self, + mut context: ToolCallContext, + ) -> BoxFuture<'static, Result> { + $( + let result = $Tn::from_tool_call_context_part(&mut context); + let $Tn = match result { + Ok(value) => value, + Err(e) => return std::future::ready(Err(e)).boxed(), + }; + )* + std::future::ready(self(context.service, $($Tn,)*).into_call_tool_result()).boxed() + } } - } - pub fn add(&mut self, item: ToolBoxItem) { - self.map.insert(item.attr.name.clone(), item); - } - pub fn remove(&mut self, name: &str) { - self.map.remove(name); - } - - pub async fn call( - &self, - context: ToolCallContext<'_, S>, - ) -> Result { - let item = self - .map - .get(context.name()) - .ok_or_else(|| crate::Error::invalid_params("tool not found", None))?; - (item.call)(context).await - } - - pub fn list(&self) -> Vec { - self.map.values().map(|item| item.attr.clone()).collect() - } -} - -#[cfg(feature = "macros")] -#[cfg_attr(docsrs, doc(cfg(feature = "macros")))] -#[macro_export] -macro_rules! tool_box { - (@pin_add $callee: ident, $attr: expr, $f: expr) => { - $callee.add(ToolBoxItem::new($attr, |context| Box::pin($f(context)))); - }; - ($server: ident { $($tool: ident),* $(,)?} ) => { - $crate::tool_box!($server { $($tool),* } tool_box); - }; - ($server: ident { $($tool: ident),* $(,)?} $tool_box: ident) => { - fn $tool_box() -> &'static $crate::handler::server::tool::ToolBox<$server> { - use $crate::handler::server::tool::{ToolBox, ToolBoxItem}; - static TOOL_BOX: std::sync::OnceLock> = std::sync::OnceLock::new(); - TOOL_BOX.get_or_init(|| { - let mut tool_box = ToolBox::new(); - $crate::paste!{ - $( - $crate::tool_box!(@pin_add tool_box, $server::[< $tool _tool_attr>](), $server::[<$tool _tool_call>]); - )* - } - tool_box - }) + impl<$($Tn,)* S, F, R> CallToolHandler> for F + where + $( + $Tn: FromToolCallContextPart + , + )* + F: FnOnce($($Tn,)*) -> R + Send + , + R: IntoCallToolResult + Send + , + S: Send + Sync, + { + #[allow(unused_variables, non_snake_case, unused_mut)] + fn call( + self, + mut context: ToolCallContext, + ) -> BoxFuture<'static, Result> { + $( + let result = $Tn::from_tool_call_context_part(&mut context); + let $Tn = match result { + Ok(value) => value, + Err(e) => return std::future::ready(Err(e)).boxed(), + }; + )* + std::future::ready(self($($Tn,)*).into_call_tool_result()).boxed() + } } }; - (@derive) => { - $crate::tool_box!(@derive tool_box); - }; - - (@derive $tool_box:ident) => { - async fn list_tools( - &self, - _: Option<$crate::model::PaginatedRequestParam>, - _: $crate::service::RequestContext<$crate::service::RoleServer>, - ) -> Result<$crate::model::ListToolsResult, $crate::Error> { - Ok($crate::model::ListToolsResult { - next_cursor: None, - tools: Self::tool_box().list(), - }) - } - - async fn call_tool( - &self, - call_tool_request_param: $crate::model::CallToolRequestParam, - context: $crate::service::RequestContext<$crate::service::RoleServer>, - ) -> Result<$crate::model::CallToolResult, $crate::Error> { - let context = $crate::handler::server::tool::ToolCallContext::new(self, call_tool_request_param, context); - Self::$tool_box().call(context).await - } - } } +impl_for!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); diff --git a/crates/rmcp/src/lib.rs b/crates/rmcp/src/lib.rs index e4096b48..a3d1a414 100644 --- a/crates/rmcp/src/lib.rs +++ b/crates/rmcp/src/lib.rs @@ -15,23 +15,25 @@ //! as Claude Desktop or the Cursor IDE. //! //! For example, to implement a server that has a tool that can count, you would -//! make an object for that tool and add an implementation with the `#[tool(tool_box)]` macro: +//! make an object for that tool and add an implementation with the `#[tool_router]` macro: //! //! ```rust //! use std::sync::Arc; -//! use rmcp::{Error as McpError, model::*, tool}; +//! use rmcp::{Error as McpError, model::*, tool, tool_router, handler::server::tool::ToolRouter}; //! use tokio::sync::Mutex; //! //! #[derive(Clone)] //! pub struct Counter { //! counter: Arc>, +//! tool_router: ToolRouter, //! } //! -//! #[tool(tool_box)] +//! #[tool_router] //! impl Counter { //! fn new() -> Self { //! Self { //! counter: Arc::new(Mutex::new(0)), +//! tool_router: Self::tool_router(), //! } //! } //! @@ -120,7 +122,7 @@ pub mod transport; pub use paste::paste; #[cfg(all(feature = "macros", feature = "server"))] #[cfg_attr(docsrs, doc(cfg(all(feature = "macros", feature = "server"))))] -pub use rmcp_macros::tool; +pub use rmcp_macros::*; #[cfg(all(feature = "macros", feature = "server"))] #[cfg_attr(docsrs, doc(cfg(all(feature = "macros", feature = "server"))))] pub use schemars; diff --git a/crates/rmcp/src/model.rs b/crates/rmcp/src/model.rs index c6407ac7..addc0bf0 100644 --- a/crates/rmcp/src/model.rs +++ b/crates/rmcp/src/model.rs @@ -679,6 +679,17 @@ macro_rules! paginated_result { pub next_cursor: Option, pub $i_item: $t_item, } + + impl $t { + pub fn with_all_items( + items: $t_item, + ) -> Self { + Self { + next_cursor: None, + $i_item: items, + } + } + } }; } diff --git a/crates/rmcp/tests/common/calculator.rs b/crates/rmcp/tests/common/calculator.rs index e179f258..4f4fccee 100644 --- a/crates/rmcp/tests/common/calculator.rs +++ b/crates/rmcp/tests/common/calculator.rs @@ -1,7 +1,9 @@ +#![allow(dead_code)] use rmcp::{ ServerHandler, + handler::server::{router::tool::ToolRouter, tool::Parameters}, model::{ServerCapabilities, ServerInfo}, - schemars, tool, + schemars, tool, tool_router, }; #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] pub struct SumRequest { @@ -9,30 +11,46 @@ pub struct SumRequest { pub a: i32, pub b: i32, } -#[derive(Debug, Clone, Default)] -pub struct Calculator; -#[tool(tool_box)] + +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +pub struct SubRequest { + #[schemars(description = "the left hand side number")] + pub a: i32, + #[schemars(description = "the right hand side number")] + pub b: i32, +} +#[derive(Debug, Clone)] +pub struct Calculator { + tool_router: ToolRouter, +} + +impl Calculator { + pub fn new() -> Self { + Self { + tool_router: Self::tool_router(), + } + } +} + +impl Default for Calculator { + fn default() -> Self { + Self::new() + } +} + +#[tool_router] impl Calculator { #[tool(description = "Calculate the sum of two numbers")] - fn sum(&self, #[tool(aggr)] SumRequest { a, b }: SumRequest) -> String { + fn sum(&self, Parameters(SumRequest { a, b }): Parameters) -> String { (a + b).to_string() } #[tool(description = "Calculate the sub of two numbers")] - fn sub( - &self, - #[tool(param)] - #[schemars(description = "the left hand side number")] - a: i32, - #[tool(param)] - #[schemars(description = "the right hand side number")] - b: i32, - ) -> String { + fn sub(&self, Parameters(SubRequest { a, b }): Parameters) -> String { (a - b).to_string() } } -#[tool(tool_box)] impl ServerHandler for Calculator { fn get_info(&self) -> ServerInfo { ServerInfo { diff --git a/crates/rmcp/tests/test_complex_schema.rs b/crates/rmcp/tests/test_complex_schema.rs index b9370fce..c199e7e6 100644 --- a/crates/rmcp/tests/test_complex_schema.rs +++ b/crates/rmcp/tests/test_complex_schema.rs @@ -1,4 +1,6 @@ -use rmcp::{Error as McpError, model::*, schemars, tool}; +use rmcp::{ + Error as McpError, handler::server::tool::Parameters, model::*, schemars, tool, tool_router, +}; use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)] @@ -24,7 +26,7 @@ pub struct ChatRequest { #[derive(Clone, Default)] pub struct Demo; -#[tool(tool_box)] +#[tool_router] impl Demo { pub fn new() -> Self { Self @@ -33,9 +35,9 @@ impl Demo { #[tool(description = "LLM")] async fn chat( &self, - #[tool(aggr)] chat_request: ChatRequest, + chat_request: Parameters, ) -> Result { - let content = Content::json(chat_request)?; + let content = Content::json(chat_request.0)?; Ok(CallToolResult::success(vec![content])) } } diff --git a/crates/rmcp/tests/test_tool_macro_annotations.rs b/crates/rmcp/tests/test_tool_macro_annotations.rs index 0b57dffd..e945a10f 100644 --- a/crates/rmcp/tests/test_tool_macro_annotations.rs +++ b/crates/rmcp/tests/test_tool_macro_annotations.rs @@ -1,9 +1,11 @@ #[cfg(test)] mod tests { - use rmcp::{ServerHandler, tool}; + use rmcp::{ServerHandler, handler::server::router::tool::ToolRouter, tool, tool_handler}; #[derive(Debug, Clone, Default)] - pub struct AnnotatedServer {} + pub struct AnnotatedServer { + tool_router: ToolRouter, + } impl AnnotatedServer { // Tool with inline comments for documentation @@ -11,29 +13,14 @@ mod tests { /// This is used to test tool annotations #[tool( name = "direct-annotated-tool", - annotations = { - title: "Annotated Tool", - readOnlyHint: true - } + annotations(title = "Annotated Tool", read_only_hint = true) )] - pub async fn direct_annotated_tool(&self, #[tool(param)] input: String) -> String { + pub async fn direct_annotated_tool(&self, input: String) -> String { format!("Direct: {}", input) } } - - impl ServerHandler for AnnotatedServer { - async fn call_tool( - &self, - request: rmcp::model::CallToolRequestParam, - context: rmcp::service::RequestContext, - ) -> Result { - let tcc = rmcp::handler::server::tool::ToolCallContext::new(self, request, context); - match tcc.name() { - "direct-annotated-tool" => Self::direct_annotated_tool_tool_call(tcc).await, - _ => Err(rmcp::Error::invalid_params("method not found", None)), - } - } - } + #[tool_handler] + impl ServerHandler for AnnotatedServer {} #[test] fn test_direct_tool_attributes() { diff --git a/crates/rmcp/tests/test_tool_macros.rs b/crates/rmcp/tests/test_tool_macros.rs index 84bcac93..b7631ed5 100644 --- a/crates/rmcp/tests/test_tool_macros.rs +++ b/crates/rmcp/tests/test_tool_macros.rs @@ -1,12 +1,12 @@ //cargo test --test test_tool_macros --features "client server" - +#![allow(dead_code)] use std::sync::Arc; use rmcp::{ ClientHandler, ServerHandler, ServiceExt, - handler::server::tool::ToolCallContext, + handler::server::{router::tool::ToolRouter, tool::Parameters}, model::{CallToolRequestParam, ClientInfo}, - tool, + tool, tool_handler, tool_router, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -17,37 +17,38 @@ pub struct GetWeatherRequest { pub date: String, } -impl ServerHandler for Server { - async fn call_tool( - &self, - request: rmcp::model::CallToolRequestParam, - context: rmcp::service::RequestContext, - ) -> Result { - let tcc = ToolCallContext::new(self, request, context); - match tcc.name() { - "get-weather" => Self::get_weather_tool_call(tcc).await, - _ => Err(rmcp::Error::invalid_params("method not found", None)), - } - } +#[tool_handler(router = self.tool_router)] +impl ServerHandler for Server {} + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct Server { + tool_router: ToolRouter, } -#[derive(Debug, Clone, Default)] -pub struct Server {} +impl Default for Server { + fn default() -> Self { + Self::new() + } +} +#[tool_router(router = tool_router)] impl Server { + pub fn new() -> Self { + Self { + tool_router: Self::tool_router(), + } + } + /// This tool is used to get the weather of a city. - #[tool(name = "get-weather", description = "Get the weather of a city.", vis = )] - pub async fn get_weather(&self, #[tool(param)] city: String) -> String { + #[tool(name = "get-weather", description = "Get the weather of a city.")] + pub async fn get_weather(&self, city: Parameters) -> String { drop(city); "rain".to_string() } - #[tool(description = "Empty Parameter")] - async fn empty_param(&self) {} - #[tool(description = "Optional Parameter")] - async fn optional_param(&self, #[tool(param)] city: Option) -> String { - city.unwrap_or_default() - } + #[tool] + async fn empty_param(&self) {} } // define generic service trait @@ -68,13 +69,15 @@ impl DataService for MockDataService { #[derive(Debug, Clone)] pub struct GenericServer { data_service: Arc, + tool_router: ToolRouter, } -#[tool(tool_box)] +#[tool_router] impl GenericServer { pub fn new(data_service: DS) -> Self { Self { data_service: Arc::new(data_service), + tool_router: Self::tool_router(), } } @@ -83,16 +86,22 @@ impl GenericServer { self.data_service.get_data() } } -#[tool(tool_box)] + +#[tool_handler] impl ServerHandler for GenericServer {} #[tokio::test] async fn test_tool_macros() { - let server = Server::default(); + let server = Server::new(); let _attr = Server::get_weather_tool_attr(); - let _get_weather_call_fn = Server::get_weather_tool_call; + let _get_weather_tool_attr_fn = Server::get_weather_tool_attr; let _get_weather_fn = Server::get_weather; - server.get_weather("harbin".into()).await; + server + .get_weather(Parameters(GetWeatherRequest { + city: "Harbin".into(), + date: "Yesterday".into(), + })) + .await; } #[tokio::test] @@ -108,14 +117,14 @@ async fn test_tool_macros_with_generics() { let mock_service = MockDataService; let server = GenericServer::new(mock_service); let _attr = GenericServer::::get_data_tool_attr(); - let _get_data_call_fn = GenericServer::::get_data_tool_call; + let _get_data_call_fn = GenericServer::::get_data; let _get_data_fn = GenericServer::::get_data; assert_eq!(server.get_data().await, "mock data"); } #[tokio::test] async fn test_tool_macros_with_optional_param() { - let _attr = Server::optional_param_tool_attr(); + let _attr = Server::get_weather_tool_attr(); // println!("{_attr:?}"); let attr_type = _attr .input_schema @@ -147,49 +156,56 @@ pub struct OptionalI64TestSchema { } // Dummy struct to host the test tool method -#[derive(Debug, Clone, Default)] -pub struct OptionalSchemaTester {} +#[derive(Debug, Clone)] +pub struct OptionalSchemaTester { + tool_router: ToolRouter, +} + +impl Default for OptionalSchemaTester { + fn default() -> Self { + Self::new() + } +} + +impl OptionalSchemaTester { + pub fn new() -> Self { + Self { + tool_router: Self::tool_router(), + } + } +} +#[tool_router] impl OptionalSchemaTester { // Dummy tool function using the test schema as an aggregated parameter #[tool(description = "A tool to test optional schema generation")] - async fn test_optional_aggr(&self, #[tool(aggr)] _req: OptionalFieldTestSchema) { + async fn test_optional(&self, _req: Parameters) { // Implementation doesn't matter for schema testing // Return type changed to () to satisfy IntoCallToolResult } // Tool function to test optional i64 handling #[tool(description = "A tool to test optional i64 schema generation")] - async fn test_optional_i64_aggr(&self, #[tool(aggr)] req: OptionalI64TestSchema) -> String { + async fn test_optional_i64( + &self, + Parameters(req): Parameters, + ) -> String { match req.count { Some(c) => format!("Received count: {}", c), None => "Received null count".to_string(), } } } - +#[tool_handler] // Implement ServerHandler to route tool calls for OptionalSchemaTester -impl ServerHandler for OptionalSchemaTester { - async fn call_tool( - &self, - request: rmcp::model::CallToolRequestParam, - context: rmcp::service::RequestContext, - ) -> Result { - let tcc = ToolCallContext::new(self, request, context); - match tcc.name() { - "test_optional_aggr" => Self::test_optional_aggr_tool_call(tcc).await, - "test_optional_i64_aggr" => Self::test_optional_i64_aggr_tool_call(tcc).await, - _ => Err(rmcp::Error::invalid_params("method not found", None)), - } - } -} +impl ServerHandler for OptionalSchemaTester {} #[test] fn test_optional_field_schema_generation_via_macro() { // tests https://github.com/modelcontextprotocol/rust-sdk/issues/135 // Get the attributes generated by the #[tool] macro helper - let tool_attr = OptionalSchemaTester::test_optional_aggr_tool_attr(); + let tool_attr = OptionalSchemaTester::test_optional_tool_attr(); // Print the actual generated schema for debugging println!( @@ -257,7 +273,7 @@ async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { let (server_transport, client_transport) = tokio::io::duplex(4096); // Server setup - let server = OptionalSchemaTester::default(); + let server = OptionalSchemaTester::new(); let server_handle = tokio::spawn(async move { server.serve(server_transport).await?.waiting().await?; anyhow::Ok(()) @@ -270,7 +286,7 @@ async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { // Test null case let result = client .call_tool(CallToolRequestParam { - name: "test_optional_i64_aggr".into(), + name: "test_optional_i64".into(), arguments: Some( serde_json::json!({ "count": null, @@ -298,7 +314,7 @@ async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { // Test Some case let some_result = client .call_tool(CallToolRequestParam { - name: "test_optional_i64_aggr".into(), + name: "test_optional_i64".into(), arguments: Some( serde_json::json!({ "count": 42, diff --git a/crates/rmcp/tests/test_tool_routers.rs b/crates/rmcp/tests/test_tool_routers.rs new file mode 100644 index 00000000..846fe424 --- /dev/null +++ b/crates/rmcp/tests/test_tool_routers.rs @@ -0,0 +1,74 @@ +use std::collections::HashMap; + +use futures::future::BoxFuture; +use rmcp::{ + ServerHandler, + handler::server::{ + router::tool::ToolRouter, + tool::{CallToolHandler, Parameters}, + }, +}; + +#[derive(Debug, Default)] +pub struct TestHandler { + pub _marker: std::marker::PhantomData, +} + +impl ServerHandler for TestHandler {} +#[derive(Debug, schemars::JsonSchema, serde::Deserialize, serde::Serialize)] +pub struct Request { + pub fields: HashMap, +} + +#[derive(Debug, schemars::JsonSchema, serde::Deserialize, serde::Serialize)] +pub struct Sum { + pub a: i32, + pub b: i32, +} + +#[rmcp::tool_router(router = test_router_1)] +impl TestHandler { + #[rmcp::tool] + async fn async_method(&self, Parameters(Request { fields }): Parameters) { + drop(fields) + } +} + +#[rmcp::tool_router(router = test_router_2)] +impl TestHandler { + #[rmcp::tool] + fn sync_method(&self, Parameters(Request { fields }): Parameters) { + drop(fields) + } +} + +#[rmcp::tool] +async fn async_function( + _callee: &TestHandler, + Parameters(Request { fields }): Parameters, +) { + drop(fields) +} + +#[rmcp::tool] +fn async_function2(_callee: &TestHandler) -> BoxFuture<'_, ()> { + Box::pin(async move {}) +} + +#[test] +fn test_tool_router() { + let test_tool_router: ToolRouter> = ToolRouter::>::new() + .with_route((async_function_tool_attr(), async_function)) + .with_route((async_function2_tool_attr(), async_function2)) + + TestHandler::<()>::test_router_1() + + TestHandler::<()>::test_router_2(); + let tools = test_tool_router.list_all(); + assert_eq!(tools.len(), 4); + assert_handler(TestHandler::<()>::async_method); +} + +fn assert_handler(_handler: H) +where + H: CallToolHandler, +{ +} diff --git a/crates/rmcp/tests/test_with_js.rs b/crates/rmcp/tests/test_with_js.rs index b00752ad..3f2761cd 100644 --- a/crates/rmcp/tests/test_with_js.rs +++ b/crates/rmcp/tests/test_with_js.rs @@ -96,7 +96,7 @@ async fn test_with_js_streamable_http_client() -> anyhow::Result<()> { let service: StreamableHttpService = StreamableHttpService::new( - || Ok(Calculator), + || Ok(Calculator::new()), Default::default(), StreamableHttpServerConfig { stateful_mode: true, diff --git a/docs/CONTRIBUTE.MD b/docs/CONTRIBUTE.MD new file mode 100644 index 00000000..95e3a834 --- /dev/null +++ b/docs/CONTRIBUTE.MD @@ -0,0 +1,13 @@ +# Discuss first +If you have a idea, make sure it is discussed before you make a PR. + +# Fmt And Clippy +You can use [just](https://github.com/casey/just) to help you fix your commit rapidly: +```shell +just fix +``` + +# How Can I Rewrite My Commit Message? +You can `git reset --soft upstream/main` and `git commit --forge`, this will merge your changes into one commit. + +Or you also can use git rebase. But we will still merge them into one commit when it is merged. \ No newline at end of file diff --git a/docs/readme/README.zh-cn.md b/docs/readme/README.zh-cn.md index cc730c1d..0ae1745c 100644 --- a/docs/readme/README.zh-cn.md +++ b/docs/readme/README.zh-cn.md @@ -97,7 +97,7 @@ pub struct SumRequest { pub struct Calculator; // create a static toolbox to store the tool attributes -#[tool(tool_box)] +#[tool_router] impl Calculator { // async function #[tool(description = "Calculate the sum of two numbers")] @@ -122,7 +122,7 @@ impl Calculator { } // impl call_tool and list_tool by querying static toolbox -#[tool(tool_box)] +#[tool_handler] impl ServerHandler for Calculator { fn get_info(&self) -> ServerInfo { ServerInfo { diff --git a/examples/servers/src/common/calculator.rs b/examples/servers/src/common/calculator.rs index 68beecc0..3a3b3538 100644 --- a/examples/servers/src/common/calculator.rs +++ b/examples/servers/src/common/calculator.rs @@ -1,8 +1,10 @@ +#![allow(dead_code)] + use rmcp::{ ServerHandler, - handler::server::wrapper::Json, + handler::server::{router::tool::ToolRouter, tool::Parameters, wrapper::Json}, model::{ServerCapabilities, ServerInfo}, - schemars, tool, + schemars, tool, tool_handler, tool_router, }; #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] @@ -11,30 +13,40 @@ pub struct SumRequest { pub a: i32, pub b: i32, } + +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +pub struct SubRequest { + #[schemars(description = "the left hand side number")] + pub a: i32, + #[schemars(description = "the right hand side number")] + pub b: i32, +} + #[derive(Debug, Clone)] -pub struct Calculator; -#[tool(tool_box)] +pub struct Calculator { + tool_router: ToolRouter, +} + +#[tool_router] impl Calculator { + pub fn new() -> Self { + Self { + tool_router: Self::tool_router(), + } + } + #[tool(description = "Calculate the sum of two numbers")] - fn sum(&self, #[tool(aggr)] SumRequest { a, b }: SumRequest) -> String { + fn sum(&self, Parameters(SumRequest { a, b }): Parameters) -> String { (a + b).to_string() } #[tool(description = "Calculate the difference of two numbers")] - fn sub( - &self, - #[tool(param)] - #[schemars(description = "the left hand side number")] - a: i32, - #[tool(param)] - #[schemars(description = "the right hand side number")] - b: i32, - ) -> Json { + fn sub(&self, Parameters(SubRequest { a, b }): Parameters) -> Json { Json(a - b) } } -#[tool(tool_box)] +#[tool_handler] impl ServerHandler for Calculator { fn get_info(&self) -> ServerInfo { ServerInfo { diff --git a/examples/servers/src/common/counter.rs b/examples/servers/src/common/counter.rs index 12aa8a4a..c205e293 100644 --- a/examples/servers/src/common/counter.rs +++ b/examples/servers/src/common/counter.rs @@ -1,8 +1,13 @@ +#![allow(dead_code)] use std::sync::Arc; use rmcp::{ - Error as McpError, RoleServer, ServerHandler, const_string, model::*, schemars, - service::RequestContext, tool, + Error as McpError, RoleServer, ServerHandler, + handler::server::{router::tool::ToolRouter, tool::Parameters}, + model::*, + schemars, + service::RequestContext, + tool, tool_handler, tool_router, }; use serde_json::json; use tokio::sync::Mutex; @@ -16,14 +21,16 @@ pub struct StructRequest { #[derive(Clone)] pub struct Counter { counter: Arc>, + tool_router: ToolRouter, } -#[tool(tool_box)] +#[tool_router] impl Counter { #[allow(dead_code)] pub fn new() -> Self { Self { counter: Arc::new(Mutex::new(0)), + tool_router: Self::tool_router(), } } @@ -63,27 +70,23 @@ impl Counter { } #[tool(description = "Repeat what you say")] - fn echo( - &self, - #[tool(param)] - #[schemars(description = "Repeat what you say")] - saying: String, - ) -> Result { - Ok(CallToolResult::success(vec![Content::text(saying)])) + fn echo(&self, Parameters(object): Parameters) -> Result { + Ok(CallToolResult::success(vec![Content::text( + serde_json::Value::Object(object).to_string(), + )])) } #[tool(description = "Calculate the sum of two numbers")] fn sum( &self, - #[tool(aggr)] StructRequest { a, b }: StructRequest, + Parameters(StructRequest { a, b }): Parameters, ) -> Result { Ok(CallToolResult::success(vec![Content::text( (a + b).to_string(), )])) } } -const_string!(Echo = "echo"); -#[tool(tool_box)] +#[tool_handler] impl ServerHandler for Counter { fn get_info(&self) -> ServerInfo { ServerInfo { diff --git a/examples/servers/src/common/generic_service.rs b/examples/servers/src/common/generic_service.rs index fc1d00ed..b629bc1c 100644 --- a/examples/servers/src/common/generic_service.rs +++ b/examples/servers/src/common/generic_service.rs @@ -2,8 +2,9 @@ use std::sync::Arc; use rmcp::{ ServerHandler, + handler::server::{router::tool::ToolRouter, tool::Parameters}, model::{ServerCapabilities, ServerInfo}, - schemars, tool, + schemars, tool, tool_handler, tool_router, }; #[allow(dead_code)] @@ -40,14 +41,21 @@ impl DataService for MemoryDataService { pub struct GenericService { #[allow(dead_code)] data_service: Arc, + tool_router: ToolRouter, } -#[tool(tool_box)] +#[derive(Debug, schemars::JsonSchema, serde::Deserialize, serde::Serialize)] +pub struct SetDataRequest { + pub data: String, +} + +#[tool_router] impl GenericService { #[allow(dead_code)] pub fn new(data_service: DS) -> Self { Self { data_service: Arc::new(data_service), + tool_router: Self::tool_router(), } } @@ -57,13 +65,16 @@ impl GenericService { } #[tool(description = "set memory to service")] - pub async fn set_data(&self, #[tool(param)] data: String) -> String { + pub async fn set_data( + &self, + Parameters(SetDataRequest { data }): Parameters, + ) -> String { let new_data = data.clone(); format!("Current memory: {}", new_data) } } -#[tool(tool_box)] +#[tool_handler] impl ServerHandler for GenericService { fn get_info(&self) -> ServerInfo { ServerInfo { diff --git a/examples/transport/src/common/calculator.rs b/examples/transport/src/common/calculator.rs index 99b7314a..8d7fdd62 100644 --- a/examples/transport/src/common/calculator.rs +++ b/examples/transport/src/common/calculator.rs @@ -1,4 +1,11 @@ -use rmcp::{ServerHandler, model::ServerInfo, schemars, tool, tool_box}; +#![allow(dead_code)] + +use rmcp::{ + ServerHandler, + handler::server::{router::tool::ToolRouter, tool::Parameters, wrapper::Json}, + model::{ServerCapabilities, ServerInfo}, + schemars, tool, tool_handler, tool_router, +}; #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] pub struct SumRequest { @@ -6,35 +13,44 @@ pub struct SumRequest { pub a: i32, pub b: i32, } + +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +pub struct SubRequest { + #[schemars(description = "the left hand side number")] + pub a: i32, + #[schemars(description = "the right hand side number")] + pub b: i32, +} + #[derive(Debug, Clone)] -pub struct Calculator; +pub struct Calculator { + tool_router: ToolRouter, +} + +#[tool_router] impl Calculator { + pub fn new() -> Self { + Self { + tool_router: Self::tool_router(), + } + } + #[tool(description = "Calculate the sum of two numbers")] - fn sum(&self, #[tool(aggr)] SumRequest { a, b }: SumRequest) -> String { + fn sum(&self, Parameters(SumRequest { a, b }): Parameters) -> String { (a + b).to_string() } - #[tool(description = "Calculate the sub of two numbers")] - fn sub( - &self, - #[tool(param)] - #[schemars(description = "the left hand side number")] - a: i32, - #[tool(param)] - #[schemars(description = "the right hand side number")] - b: i32, - ) -> String { - (a - b).to_string() + #[tool(description = "Calculate the difference of two numbers")] + fn sub(&self, Parameters(SubRequest { a, b }): Parameters) -> Json { + Json(a - b) } - - tool_box!(Calculator { sum, sub }); } - +#[tool_handler] impl ServerHandler for Calculator { - tool_box!(@derive); fn get_info(&self) -> ServerInfo { ServerInfo { instructions: Some("A simple calculator".into()), + capabilities: ServerCapabilities::builder().enable_tools().build(), ..Default::default() } } diff --git a/examples/transport/src/http_upgrade.rs b/examples/transport/src/http_upgrade.rs index 1c0cf6ad..6a15add3 100644 --- a/examples/transport/src/http_upgrade.rs +++ b/examples/transport/src/http_upgrade.rs @@ -24,7 +24,7 @@ async fn main() -> anyhow::Result<()> { async fn http_server(req: Request) -> Result, hyper::Error> { tokio::spawn(async move { let upgraded = hyper::upgrade::on(req).await?; - let service = Calculator.serve(TokioIo::new(upgraded)).await?; + let service = Calculator::new().serve(TokioIo::new(upgraded)).await?; service.waiting().await?; anyhow::Result::<()>::Ok(()) }); diff --git a/examples/transport/src/tcp.rs b/examples/transport/src/tcp.rs index 72428fe6..683fb6cf 100644 --- a/examples/transport/src/tcp.rs +++ b/examples/transport/src/tcp.rs @@ -13,7 +13,7 @@ async fn server() -> anyhow::Result<()> { let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:8001").await?; while let Ok((stream, _)) = tcp_listener.accept().await { tokio::spawn(async move { - let server = serve_server(Calculator, stream).await?; + let server = serve_server(Calculator::new(), stream).await?; server.waiting().await?; anyhow::Ok(()) }); diff --git a/examples/transport/src/unix_socket.rs b/examples/transport/src/unix_socket.rs index a52b45a3..feeb2b87 100644 --- a/examples/transport/src/unix_socket.rs +++ b/examples/transport/src/unix_socket.rs @@ -14,7 +14,7 @@ async fn main() -> anyhow::Result<()> { while let Ok((stream, addr)) = unix_listener.accept().await { println!("Client connected: {:?}", addr); tokio::spawn(async move { - match serve_server(Calculator, stream).await { + match serve_server(Calculator::new(), stream).await { Ok(server) => { println!("Server initialized successfully"); if let Err(e) = server.waiting().await { diff --git a/examples/transport/src/websocket.rs b/examples/transport/src/websocket.rs index 0d0fec72..5ba23546 100644 --- a/examples/transport/src/websocket.rs +++ b/examples/transport/src/websocket.rs @@ -40,7 +40,7 @@ async fn start_server() -> anyhow::Result<()> { tokio::spawn(async move { let ws_stream = tokio_tungstenite::accept_async(stream).await?; let transport = WebsocketTransport::new_server(ws_stream); - let server = Calculator.serve(transport).await?; + let server = Calculator::new().serve(transport).await?; server.waiting().await?; Ok::<(), anyhow::Error>(()) }); diff --git a/examples/wasi/src/calculator.rs b/examples/wasi/src/calculator.rs index f1c35eea..6f1d0a3f 100644 --- a/examples/wasi/src/calculator.rs +++ b/examples/wasi/src/calculator.rs @@ -1,41 +1,58 @@ +#![allow(dead_code)] + use rmcp::{ ServerHandler, + handler::server::{router::tool::ToolRouter, tool::Parameters, wrapper::Json}, model::{ServerCapabilities, ServerInfo}, - schemars, tool, tool_box, + schemars, tool, tool_handler, tool_router, }; -#[derive(Debug, rmcp::serde::Deserialize, schemars::JsonSchema)] +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] pub struct SumRequest { #[schemars(description = "the left hand side number")] pub a: i32, pub b: i32, } + +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +pub struct SubRequest { + #[schemars(description = "the left hand side number")] + pub a: i32, + #[schemars(description = "the right hand side number")] + pub b: i32, +} + #[derive(Debug, Clone)] -pub struct Calculator; +pub struct Calculator { + tool_router: ToolRouter, +} + +impl Default for Calculator { + fn default() -> Self { + Self::new() + } +} + +#[tool_router] impl Calculator { + pub fn new() -> Self { + Self { + tool_router: Self::tool_router(), + } + } + #[tool(description = "Calculate the sum of two numbers")] - fn sum(&self, #[tool(aggr)] SumRequest { a, b }: SumRequest) -> String { + fn sum(&self, Parameters(SumRequest { a, b }): Parameters) -> String { (a + b).to_string() } - #[tool(description = "Calculate the sub of two numbers")] - fn sub( - &self, - #[tool(param)] - #[schemars(description = "the left hand side number")] - a: i32, - #[tool(param)] - #[schemars(description = "the right hand side number")] - b: i32, - ) -> String { - (a - b).to_string() + #[tool(description = "Calculate the difference of two numbers")] + fn sub(&self, Parameters(SubRequest { a, b }): Parameters) -> Json { + Json(a - b) } - - tool_box!(Calculator { sum, sub }); } - +#[tool_handler] impl ServerHandler for Calculator { - tool_box!(@derive); fn get_info(&self) -> ServerInfo { ServerInfo { instructions: Some("A simple calculator".into()), diff --git a/examples/wasi/src/lib.rs b/examples/wasi/src/lib.rs index 3b2904a8..2690cc73 100644 --- a/examples/wasi/src/lib.rs +++ b/examples/wasi/src/lib.rs @@ -112,7 +112,10 @@ impl wasi::exports::cli::run::Guest for TokioCliRunner { .with_writer(std::io::stderr) .with_ansi(false) .init(); - let server = calculator::Calculator.serve(wasi_io()).await.unwrap(); + let server = calculator::Calculator::new() + .serve(wasi_io()) + .await + .unwrap(); server.waiting().await.unwrap(); }); Ok(()) diff --git a/justfile b/justfile new file mode 100644 index 00000000..970aa1a5 --- /dev/null +++ b/justfile @@ -0,0 +1,12 @@ +fmt: + cargo +nightly fmt --all + +check: + cargo clippy --all-targets --all-features -- -D warnings + +fix: fmt + git add ./ + cargo clippy --fix --all-targets --all-features --allow-staged + +test: + cargo test --all-features \ No newline at end of file