From 92d9426895d409822f9ae259a630a7edd5353ca9 Mon Sep 17 00:00:00 2001 From: Stephan Lensky Date: Tue, 22 Apr 2025 18:00:47 -0400 Subject: [PATCH] Fix async callable object tools --- src/mcp/server/fastmcp/tools/base.py | 12 ++++- tests/server/fastmcp/test_tool_manager.py | 61 +++++++++++++++++++++++ 2 files changed, 72 insertions(+), 1 deletion(-) diff --git a/src/mcp/server/fastmcp/tools/base.py b/src/mcp/server/fastmcp/tools/base.py index 92a216f56..b42ca1bb8 100644 --- a/src/mcp/server/fastmcp/tools/base.py +++ b/src/mcp/server/fastmcp/tools/base.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import functools import inspect from collections.abc import Callable from typing import TYPE_CHECKING, Any, get_origin @@ -48,7 +49,7 @@ def from_function( raise ValueError("You must provide a name for lambda functions") func_doc = description or fn.__doc__ or "" - is_async = inspect.iscoroutinefunction(fn) + is_async = _is_async_callable(fn) if context_kwarg is None: sig = inspect.signature(fn) @@ -92,3 +93,12 @@ async def run( ) except Exception as e: raise ToolError(f"Error executing tool {self.name}: {e}") from e + + +def _is_async_callable(obj: Any) -> bool: + while isinstance(obj, functools.partial): + obj = obj.func + + return inspect.iscoroutinefunction(obj) or ( + callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None)) + ) diff --git a/tests/server/fastmcp/test_tool_manager.py b/tests/server/fastmcp/test_tool_manager.py index 8f52e3d85..51e63fa27 100644 --- a/tests/server/fastmcp/test_tool_manager.py +++ b/tests/server/fastmcp/test_tool_manager.py @@ -71,6 +71,39 @@ def create_user(user: UserInput, flag: bool) -> dict: assert "age" in tool.parameters["$defs"]["UserInput"]["properties"] assert "flag" in tool.parameters["properties"] + def test_add_callable_object(self): + """Test registering a callable object.""" + + class MyTool: + def __init__(self): + self.__name__ = "MyTool" + + def __call__(self, x: int) -> int: + return x * 2 + + manager = ToolManager() + tool = manager.add_tool(MyTool()) + assert tool.name == "MyTool" + assert tool.is_async is False + assert tool.parameters["properties"]["x"]["type"] == "integer" + + @pytest.mark.anyio + async def test_add_async_callable_object(self): + """Test registering an async callable object.""" + + class MyAsyncTool: + def __init__(self): + self.__name__ = "MyAsyncTool" + + async def __call__(self, x: int) -> int: + return x * 2 + + manager = ToolManager() + tool = manager.add_tool(MyAsyncTool()) + assert tool.name == "MyAsyncTool" + assert tool.is_async is True + assert tool.parameters["properties"]["x"]["type"] == "integer" + def test_add_invalid_tool(self): manager = ToolManager() with pytest.raises(AttributeError): @@ -137,6 +170,34 @@ async def double(n: int) -> int: result = await manager.call_tool("double", {"n": 5}) assert result == 10 + @pytest.mark.anyio + async def test_call_object_tool(self): + class MyTool: + def __init__(self): + self.__name__ = "MyTool" + + def __call__(self, x: int) -> int: + return x * 2 + + manager = ToolManager() + tool = manager.add_tool(MyTool()) + result = await tool.run({"x": 5}) + assert result == 10 + + @pytest.mark.anyio + async def test_call_async_object_tool(self): + class MyAsyncTool: + def __init__(self): + self.__name__ = "MyAsyncTool" + + async def __call__(self, x: int) -> int: + return x * 2 + + manager = ToolManager() + tool = manager.add_tool(MyAsyncTool()) + result = await tool.run({"x": 5}) + assert result == 10 + @pytest.mark.anyio async def test_call_tool_with_default_args(self): def add(a: int, b: int = 1) -> int: