diff --git a/src/mistralai/extra/run/tools.py b/src/mistralai/extra/run/tools.py index 81fec665..dac6d948 100644 --- a/src/mistralai/extra/run/tools.py +++ b/src/mistralai/extra/run/tools.py @@ -22,7 +22,8 @@ from mistralai.models import ( FunctionResultEntry, FunctionTool, - Function, + FunctionToolTypedDict, + ToolTypedDict, FunctionCallEntry, ) @@ -34,14 +35,14 @@ class RunFunction: name: str callable: Callable - tool: FunctionTool + tool: Union[FunctionToolTypedDict, ToolTypedDict] @dataclass class RunCoroutine: name: str awaitable: Callable - tool: FunctionTool + tool: Union[FunctionToolTypedDict, ToolTypedDict] @dataclass @@ -140,8 +141,8 @@ def _get_function_parameters( return schema -def create_tool_call(func: Callable) -> FunctionTool: - """Parse a function docstring / type annotations to create a FunctionTool.""" +def create_tool_call(func: Callable) -> Union[FunctionToolTypedDict, ToolTypedDict]: + """Parse a function docstring / type annotations to create a FunctionToolTypedDict or a ToolTypedDict.""" name = func.__name__ # Inspect and parse the docstring of the function @@ -165,19 +166,15 @@ def create_tool_call(func: Callable) -> FunctionTool: params_from_sig = list(sig.parameters.values()) type_hints = get_type_hints(func, include_extras=True, localns=None, globalns=None) - return FunctionTool( - type="function", - function=Function( - name=name, - description=_get_function_description(docstring_sections), - parameters=_get_function_parameters( - docstring_sections=docstring_sections, - params_from_sig=params_from_sig, - type_hints=type_hints, - ), - strict=True, - ), - ) + return { + "type": "function", + "function": { + "name": name, + "description": _get_function_description(docstring_sections), + "parameters": _get_function_parameters(docstring_sections, params_from_sig, type_hints), + "strict": True, + }, + } async def create_function_result(