|
48 | 48 | from ._mcp_manager import MCPSessionManager |
49 | 49 | from ._provider import ModelInfo, Provider, StandardModelParams, SubmitInputArgsT |
50 | 50 | from ._tokens import compute_cost, get_token_pricing, tokens_log |
51 | | -from ._tools import Tool, ToolRejectError |
| 51 | +from ._tools import Tool, ToolBuiltIn, ToolRejectError |
52 | 52 | from ._turn import Turn, user_turn |
53 | 53 | from ._typing_extensions import TypedDict, TypeGuard |
54 | 54 | from ._utils import MISSING, MISSING_TYPE, html_escape, wrap_async |
@@ -131,7 +131,7 @@ def __init__( |
131 | 131 | self.system_prompt = system_prompt |
132 | 132 | self.kwargs_chat: SubmitInputArgsT = kwargs_chat or {} |
133 | 133 |
|
134 | | - self._tools: dict[str, Tool] = {} |
| 134 | + self._tools: dict[str, Tool | ToolBuiltIn] = {} |
135 | 135 | self._on_tool_request_callbacks = CallbackManager() |
136 | 136 | self._on_tool_result_callbacks = CallbackManager() |
137 | 137 | self._current_display: Optional[MarkdownDisplay] = None |
@@ -1866,7 +1866,7 @@ async def cleanup_mcp_tools(self, names: Optional[Sequence[str]] = None): |
1866 | 1866 |
|
1867 | 1867 | def register_tool( |
1868 | 1868 | self, |
1869 | | - func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool, |
| 1869 | + func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool | "ToolBuiltIn", |
1870 | 1870 | *, |
1871 | 1871 | force: bool = False, |
1872 | 1872 | name: Optional[str] = None, |
@@ -1960,31 +1960,39 @@ def add(a: int, b: int) -> int: |
1960 | 1960 | ValueError |
1961 | 1961 | If a tool with the same name already exists and `force` is `False`. |
1962 | 1962 | """ |
1963 | | - if isinstance(func, Tool): |
| 1963 | + if isinstance(func, ToolBuiltIn): |
| 1964 | + # ToolBuiltIn objects are stored directly without conversion |
| 1965 | + tool = func |
| 1966 | + tool_name = tool.name |
| 1967 | + elif isinstance(func, Tool): |
1964 | 1968 | name = name or func.name |
1965 | 1969 | annotations = annotations or func.annotations |
1966 | 1970 | if model is not None: |
1967 | 1971 | func = Tool.from_func( |
1968 | 1972 | func.func, name=name, model=model, annotations=annotations |
1969 | 1973 | ) |
1970 | 1974 | func = func.func |
| 1975 | + tool = Tool.from_func(func, name=name, model=model, annotations=annotations) |
| 1976 | + tool_name = tool.name |
| 1977 | + else: |
| 1978 | + tool = Tool.from_func(func, name=name, model=model, annotations=annotations) |
| 1979 | + tool_name = tool.name |
1971 | 1980 |
|
1972 | | - tool = Tool.from_func(func, name=name, model=model, annotations=annotations) |
1973 | | - if tool.name in self._tools and not force: |
| 1981 | + if tool_name in self._tools and not force: |
1974 | 1982 | raise ValueError( |
1975 | | - f"Tool with name '{tool.name}' is already registered. " |
| 1983 | + f"Tool with name '{tool_name}' is already registered. " |
1976 | 1984 | "Set `force=True` to overwrite it." |
1977 | 1985 | ) |
1978 | | - self._tools[tool.name] = tool |
| 1986 | + self._tools[tool_name] = tool |
1979 | 1987 |
|
1980 | | - def get_tools(self) -> list[Tool]: |
| 1988 | + def get_tools(self) -> list[Tool | ToolBuiltIn]: |
1981 | 1989 | """ |
1982 | 1990 | Get the list of registered tools. |
1983 | 1991 |
|
1984 | 1992 | Returns |
1985 | 1993 | ------- |
1986 | | - list[Tool] |
1987 | | - A list of `Tool` instances that are currently registered with the chat. |
| 1994 | + list[Tool | ToolBuiltIn] |
| 1995 | + A list of `Tool` or `ToolBuiltIn` instances that are currently registered with the chat. |
1988 | 1996 | """ |
1989 | 1997 | return list(self._tools.values()) |
1990 | 1998 |
|
@@ -2508,7 +2516,7 @@ def _submit_turns( |
2508 | 2516 | data_model: type[BaseModel] | None = None, |
2509 | 2517 | kwargs: Optional[SubmitInputArgsT] = None, |
2510 | 2518 | ) -> Generator[str, None, None]: |
2511 | | - if any(x._is_async for x in self._tools.values()): |
| 2519 | + if any(hasattr(x, "_is_async") and x._is_async for x in self._tools.values()): |
2512 | 2520 | raise ValueError("Cannot use async tools in a synchronous chat") |
2513 | 2521 |
|
2514 | 2522 | def emit(text: str | Content): |
@@ -2661,15 +2669,27 @@ def _collect_all_kwargs( |
2661 | 2669 |
|
2662 | 2670 | def _invoke_tool(self, request: ContentToolRequest): |
2663 | 2671 | tool = self._tools.get(request.name) |
2664 | | - func = tool.func if tool is not None else None |
2665 | 2672 |
|
2666 | | - if func is None: |
| 2673 | + if tool is None: |
2667 | 2674 | yield self._handle_tool_error_result( |
2668 | 2675 | request, |
2669 | 2676 | error=RuntimeError("Unknown tool."), |
2670 | 2677 | ) |
2671 | 2678 | return |
2672 | 2679 |
|
| 2680 | + if isinstance(tool, ToolBuiltIn): |
| 2681 | + # Built-in tools are handled by the provider, not invoked directly |
| 2682 | + yield self._handle_tool_error_result( |
| 2683 | + request, |
| 2684 | + error=RuntimeError( |
| 2685 | + f"Built-in tool '{request.name}' cannot be invoked directly. " |
| 2686 | + "It should be handled by the provider." |
| 2687 | + ), |
| 2688 | + ) |
| 2689 | + return |
| 2690 | + |
| 2691 | + func = tool.func |
| 2692 | + |
2673 | 2693 | # First, invoke the request callbacks. If a ToolRejectError is raised, |
2674 | 2694 | # treat it like a tool failure (i.e., gracefully handle it). |
2675 | 2695 | result: ContentToolResult | None = None |
@@ -2717,6 +2737,17 @@ async def _invoke_tool_async(self, request: ContentToolRequest): |
2717 | 2737 | ) |
2718 | 2738 | return |
2719 | 2739 |
|
| 2740 | + if isinstance(tool, ToolBuiltIn): |
| 2741 | + # Built-in tools are handled by the provider, not invoked directly |
| 2742 | + yield self._handle_tool_error_result( |
| 2743 | + request, |
| 2744 | + error=RuntimeError( |
| 2745 | + f"Built-in tool '{request.name}' cannot be invoked directly. " |
| 2746 | + "It should be handled by the provider." |
| 2747 | + ), |
| 2748 | + ) |
| 2749 | + return |
| 2750 | + |
2720 | 2751 | if tool._is_async: |
2721 | 2752 | func = tool.func |
2722 | 2753 | else: |
|
0 commit comments