diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6f775057..252d5e67 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 * Added a `ChatSnowflake()` class to interact with [Snowflake Cortex LLM](https://docs.snowflake.com/en/user-guide/snowflake-cortex/llm-functions). (#54)
 * Added a `ChatAuto()` class, allowing for configuration of chat providers and models via environment variables. (#38, thanks @mconflitti-pbc)
+* Added a `ToolResult()` class which allows you to: (1) control how results get formatted when sent to the model and (2) yield additional content to the user (i.e., the downstream consumer of a `.stream()` or `.chat()`) for display when the tool is called. (#69)
+* Added an `on_request` parameter to `.register_tool()`. When the tool is requested, this callback executes, and its result is yielded to the user. (#69)
+* Added a `Chat.on_tool_request()` method for registering a default tool request handler. (#69)
+
+
+### Changes
+
+* By default, tool results are formatted as a JSON string when sent to the model. (#69)
 
 ### Improvements
diff --git a/chatlas/__init__.py b/chatlas/__init__.py
index 29f70974..5ce20b58 100644
--- a/chatlas/__init__.py
+++ b/chatlas/__init__.py
@@ -13,7 +13,7 @@
 from ._provider import Provider
 from ._snowflake import ChatSnowflake
 from ._tokens import token_usage
-from ._tools import Tool
+from ._tools import Tool, ToolResult
 from ._turn import Turn
 
 try:
@@ -43,6 +43,7 @@
     "Provider",
     "token_usage",
     "Tool",
+    "ToolResult",
     "Turn",
     "types",
 )
diff --git a/chatlas/_anthropic.py b/chatlas/_anthropic.py
index 93b4cbf8..ada44a2f 100644
--- a/chatlas/_anthropic.py
+++ b/chatlas/_anthropic.py
@@ -472,12 +472,15 @@ def _as_content_block(content: Content) -> "ContentBlockParam":
                 "input": content.arguments,
             }
         elif isinstance(content, ContentToolResult):
-            return {
+            res: ToolResultBlockParam = {
                 "type": "tool_result",
                 "tool_use_id": content.id,
-                "content": content.get_final_value(),
                 "is_error": content.error is not None,
             }
+            # Anthropic supports non-text contents like ImageBlockParam
+            res["content"] = content.get_final_value()  # type: ignore
+            return res
+
         raise ValueError(f"Unknown content type: {type(content)}")
 
     @staticmethod
diff --git a/chatlas/_chat.py b/chatlas/_chat.py
index 6e1f493f..a5850e13 100644
--- a/chatlas/_chat.py
+++ b/chatlas/_chat.py
@@ -39,7 +39,7 @@
 )
 from ._logging import log_tool_error
 from ._provider import Provider
-from ._tools import Tool
+from ._tools import Stringable, Tool, ToolResult
 from ._turn import Turn, user_turn
 from ._typing_extensions import TypedDict
 from ._utils import html_escape, wrap_async
@@ -96,6 +96,9 @@ def __init__(
             "rich_console": {},
             "css_styles": {},
         }
+        self._on_tool_request_default: Optional[
+            Callable[[ContentToolRequest], Stringable]
+        ] = None
 
     def get_turns(
         self,
@@ -625,7 +628,7 @@ def stream(
         *args: Content | str,
         echo: Literal["text", "all", "none"] = "none",
         kwargs: Optional[SubmitInputArgsT] = None,
-    ) -> ChatResponse:
+    ) -> Generator[Stringable, None, None]:
         """
         Generate a response from the chat in a streaming fashion.
@@ -658,19 +661,19 @@ def stream(
             kwargs=kwargs,
         )
 
-        def wrapper() -> Generator[str, None, None]:
+        def wrapper() -> Generator[Stringable, None, None]:
             with display:
                 for chunk in generator:
                     yield chunk
 
-        return ChatResponse(wrapper())
+        return wrapper()
 
     async def stream_async(
         self,
         *args: Content | str,
         echo: Literal["text", "all", "none"] = "none",
         kwargs: Optional[SubmitInputArgsT] = None,
-    ) -> ChatResponseAsync:
+    ) -> AsyncGenerator[Stringable, None]:
         """
         Generate a response from the chat in a streaming fashion asynchronously.
 
@@ -695,7 +698,7 @@ async def stream_async(
 
         display = self._markdown_display(echo=echo)
 
-        async def wrapper() -> AsyncGenerator[str, None]:
+        async def wrapper() -> AsyncGenerator[Stringable, None]:
             with display:
                 async for chunk in self._chat_impl_async(
                     turn,
@@ -706,7 +709,7 @@ async def wrapper() -> AsyncGenerator[str, None]:
                 ):
                     yield chunk
 
-        return ChatResponseAsync(wrapper())
+        return wrapper()
 
     def extract_data(
         self,
@@ -831,6 +834,7 @@ def register_tool(
         self,
         func: Callable[..., Any] | Callable[..., Awaitable[Any]],
         *,
+        on_request: Optional[Callable[[ContentToolRequest], Stringable]] = None,
         model: Optional[type[BaseModel]] = None,
     ):
         """
@@ -900,6 +904,11 @@ def add(a: int, b: int) -> int:
         ----------
         func
             The function to be invoked when the tool is called.
+        on_request
+            A callable that will be passed a :class:`~chatlas.ContentToolRequest`
+            when the tool is requested. If provided and the callable returns a
+            stringable object, that value will be yielded to the user as part
+            of the response.
         model
             A Pydantic model that describes the input parameters for the function.
             If not provided, the model will be inferred from the function's type hints.
@@ -907,9 +916,37 @@ def add(a: int, b: int) -> int:
             Note that the name and docstring of the model takes precedence over the
             name and docstring of the function.
         """
-        tool = Tool(func, model=model)
+        tool = Tool(func, on_request=on_request, model=model)
         self._tools[tool.name] = tool
 
+    def on_tool_request(
+        self,
+        func: Callable[[ContentToolRequest], Stringable],
+    ):
+        """
+        Register a default function to be invoked when a tool is requested.
+
+        This function will be invoked if a tool is requested that does not have
+        a specific `on_request` function defined.
+
+        Parameters
+        ----------
+        func
+            A callable that will be passed a :class:`~chatlas.ContentToolRequest`
+            when the tool is requested. If provided and the callable returns a
+            stringable object, that value will be yielded to the user as part
+            of the response.
+ """ + self._on_tool_request_default = func + + def _on_tool_request(self, req: ContentToolRequest) -> Stringable | None: + tool_def = self._tools.get(req.name, None) + if tool_def and tool_def.on_request: + return tool_def.on_request(req) + if self._on_tool_request_default: + return self._on_tool_request_default(req) + return None + def export( self, filename: str | Path, @@ -1040,7 +1077,7 @@ def _chat_impl( display: MarkdownDisplay, stream: bool, kwargs: Optional[SubmitInputArgsT] = None, - ) -> Generator[str, None, None]: + ) -> Generator[Stringable, None, None]: user_turn_result: Turn | None = user_turn while user_turn_result is not None: for chunk in self._submit_turns( @@ -1051,7 +1088,24 @@ def _chat_impl( kwargs=kwargs, ): yield chunk - user_turn_result = self._invoke_tools() + + turn = self.get_last_turn(role="assistant") + assert turn is not None + user_turn_result = None + + results: list[ContentToolResult] = [] + for x in turn.contents: + if isinstance(x, ContentToolRequest): + req = self._on_tool_request(x) + if req is not None: + yield req + res = self._invoke_tool_request(x) + if res.result and res.result.user is not None: + yield res.result.user + results.append(res) + + if results: + user_turn_result = Turn("user", results) async def _chat_impl_async( self, @@ -1060,7 +1114,7 @@ async def _chat_impl_async( display: MarkdownDisplay, stream: bool, kwargs: Optional[SubmitInputArgsT] = None, - ) -> AsyncGenerator[str, None]: + ) -> AsyncGenerator[Stringable, None]: user_turn_result: Turn | None = user_turn while user_turn_result is not None: async for chunk in self._submit_turns_async( @@ -1071,7 +1125,24 @@ async def _chat_impl_async( kwargs=kwargs, ): yield chunk - user_turn_result = await self._invoke_tools_async() + + turn = self.get_last_turn(role="assistant") + assert turn is not None + user_turn_result = None + + results: list[ContentToolResult] = [] + for x in turn.contents: + if isinstance(x, ContentToolRequest): + req = self._on_tool_request(x) + if req is not None: + yield req + res = await self._invoke_tool_request_async(x) + if res.result and res.result.user is not None: + yield res.result.user + results.append(res) + + if results: + user_turn_result = Turn("user", results) def _submit_turns( self, @@ -1085,7 +1156,7 @@ def _submit_turns( if any(x._is_async for x in self._tools.values()): raise ValueError("Cannot use async tools in a synchronous chat") - def emit(text: str | Content): + def emit(text: Stringable): display.update(str(text)) emit("
\n\n") @@ -1148,7 +1219,7 @@ async def _submit_turns_async( data_model: type[BaseModel] | None = None, kwargs: Optional[SubmitInputArgsT] = None, ) -> AsyncGenerator[str, None]: - def emit(text: str | Content): + def emit(text: Stringable): display.update(str(text)) emit("
\n\n") @@ -1202,88 +1273,58 @@ def emit(text: str | Content): self._turns.extend([user_turn, turn]) - def _invoke_tools(self) -> Turn | None: - turn = self.get_last_turn() - if turn is None: - return None - - results: list[ContentToolResult] = [] - for x in turn.contents: - if isinstance(x, ContentToolRequest): - tool_def = self._tools.get(x.name, None) - func = tool_def.func if tool_def is not None else None - results.append(self._invoke_tool(func, x.arguments, x.id)) - - if not results: - return None - - return Turn("user", results) - - async def _invoke_tools_async(self) -> Turn | None: - turn = self.get_last_turn() - if turn is None: - return None - - results: list[ContentToolResult] = [] - for x in turn.contents: - if isinstance(x, ContentToolRequest): - tool_def = self._tools.get(x.name, None) - func = None - if tool_def: - if tool_def._is_async: - func = tool_def.func - else: - func = wrap_async(tool_def.func) - results.append(await self._invoke_tool_async(func, x.arguments, x.id)) + def _invoke_tool_request(self, x: ContentToolRequest) -> ContentToolResult: + tool_def = self._tools.get(x.name, None) + func = tool_def.func if tool_def is not None else None - if not results: - return None - - return Turn("user", results) - - @staticmethod - def _invoke_tool( - func: Callable[..., Any] | None, - arguments: object, - id_: str, - ) -> ContentToolResult: if func is None: - return ContentToolResult(id_, value=None, error="Unknown tool") + return ContentToolResult(x.id, result=None, error="Unknown tool") name = func.__name__ try: - if isinstance(arguments, dict): - result = func(**arguments) + if isinstance(x.arguments, dict): + result = func(**x.arguments) else: - result = func(arguments) + result = func(x.arguments) + + if not isinstance(result, ToolResult): + result = ToolResult(result) - return ContentToolResult(id_, value=result, error=None, name=name) + return ContentToolResult(x.id, result=result, error=None, name=name) except Exception as e: - log_tool_error(name, str(arguments), e) - return ContentToolResult(id_, value=None, error=str(e), name=name) + log_tool_error(name, str(x.arguments), e) + return ContentToolResult(x.id, result=None, error=str(e), name=name) - @staticmethod - async def _invoke_tool_async( - func: Callable[..., Awaitable[Any]] | None, - arguments: object, - id_: str, + async def _invoke_tool_request_async( + self, x: ContentToolRequest ) -> ContentToolResult: + tool_def = self._tools.get(x.name, None) + func = None + if tool_def: + if tool_def._is_async: + func = tool_def.func + else: + func = wrap_async(tool_def.func) + if func is None: - return ContentToolResult(id_, value=None, error="Unknown tool") + return ContentToolResult(x.id, result=None, error="Unknown tool") name = func.__name__ try: - if isinstance(arguments, dict): - result = await func(**arguments) + if isinstance(x.arguments, dict): + result = await func(**x.arguments) else: - result = await func(arguments) + result = await func(x.arguments) + + if not isinstance(result, ToolResult): + result = ToolResult(result) - return ContentToolResult(id_, value=result, error=None, name=name) + return ContentToolResult(x.id, result=result, error=None, name=name) except Exception as e: - log_tool_error(func.__name__, str(arguments), e) - return ContentToolResult(id_, value=None, error=str(e), name=name) + log_tool_error(func.__name__, str(x.arguments), e) + return ContentToolResult(x.id, result=None, error=str(e), name=name) def _markdown_display( self, echo: Literal["text", "all", "none"] @@ -1378,16 
@@ -1378,16 +1419,16 @@ class ChatResponse:
     still be retrieved (via the `content` attribute).
     """
 
-    def __init__(self, generator: Generator[str, None]):
+    def __init__(self, generator: Generator[Stringable, None]):
         self._generator = generator
-        self.content: str = ""
+        self.contents: list[Stringable] = []
 
-    def __iter__(self) -> Iterator[str]:
+    def __iter__(self) -> Iterator[Stringable]:
        return self
 
-    def __next__(self) -> str:
+    def __next__(self) -> Stringable:
         chunk = next(self._generator)
-        self.content += chunk  # Keep track of accumulated content
+        self.contents.append(chunk)
         return chunk
 
     def get_content(self) -> str:
@@ -1396,7 +1437,7 @@ def get_content(self) -> str:
         """
         for _ in self:
             pass
-        return self.content
+        return "".join(str(x) for x in self.contents)
 
     @property
     def consumed(self) -> bool:
@@ -1430,23 +1471,23 @@ class ChatResponseAsync:
     still be retrieved (via the `content` attribute).
     """
 
-    def __init__(self, generator: AsyncGenerator[str, None]):
+    def __init__(self, generator: AsyncGenerator[Stringable, None]):
         self._generator = generator
-        self.content: str = ""
+        self.contents: list[Stringable] = []
 
-    def __aiter__(self) -> AsyncIterator[str]:
+    def __aiter__(self) -> AsyncIterator[Stringable]:
         return self
 
-    async def __anext__(self) -> str:
+    async def __anext__(self) -> Stringable:
         chunk = await self._generator.__anext__()
-        self.content += chunk  # Keep track of accumulated content
+        self.contents.append(chunk)
         return chunk
 
     async def get_content(self) -> str:
         "Get the chat response content as a string."
         async for _ in self:
             pass
-        return self.content
+        return "".join(str(x) for x in self.contents)
 
     @property
     def consumed(self) -> bool:
diff --git a/chatlas/_content.py b/chatlas/_content.py
index e453a91b..620e2fbc 100644
--- a/chatlas/_content.py
+++ b/chatlas/_content.py
@@ -3,7 +3,15 @@
 import json
 from dataclasses import dataclass
 from pprint import pformat
-from typing import Any, Literal, Optional
+from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, cast
+
+if TYPE_CHECKING:
+    from ._tools import ToolResult
+
+
+class Stringable(Protocol):
+    def __str__(self) -> str: ...
+
 
 ImageContentTypes = Literal[
     "image/png",
@@ -195,20 +203,21 @@ class ContentToolResult(Content):
     """
 
     id: str
-    value: Any = None
+    result: Optional[ToolResult] = None
     name: Optional[str] = None
     error: Optional[str] = None
 
-    def _get_value(self, pretty: bool = False) -> str:
+    def _get_value(self, pretty: bool = False) -> Stringable:
         if self.error:
             return f"Tool calling failed with error: '{self.error}'"
+        result = cast("ToolResult", self.result)
         if not pretty:
-            return str(self.value)
+            return result.assistant
         try:
-            json_val = json.loads(self.value)  # type: ignore
+            json_val = json.loads(result.assistant)  # type: ignore
             return pformat(json_val, indent=2, sort_dicts=False)
         except:  # noqa: E722
-            return str(self.value)
+            return result.assistant
 
     # Primarily used for `echo="all"`...
     def __str__(self):
@@ -222,13 +231,14 @@ def _repr_markdown_(self):
 
     def __repr__(self, indent: int = 0):
         res = " " * indent
-        res += f"
"
 
     # The actual value to send to the model
-    def get_final_value(self) -> str:
+    def get_final_value(self) -> Stringable:
         return self._get_value()
diff --git a/chatlas/_google.py b/chatlas/_google.py
index 0113248f..c994e943 100644
--- a/chatlas/_google.py
+++ b/chatlas/_google.py
@@ -20,7 +20,7 @@
 from ._merge import merge_dicts
 from ._provider import Provider
 from ._tokens import tokens_log
-from ._tools import Tool
+from ._tools import Tool, ToolResult
 from ._turn import Turn, normalize_turns, user_turn
 
 if TYPE_CHECKING:
@@ -422,7 +422,7 @@ def _as_part_type(self, content: Content) -> "Part":
             if content.error:
                 resp = {"error": content.error}
             else:
-                resp = {"result": str(content.value)}
+                resp = {"result": str(content.get_final_value())}
             return Part(
                 # TODO: seems function response parts might need role='tool'???
                 # https://github.com/googleapis/python-genai/blame/c8cfef85c/README.md#L344
@@ -483,7 +483,7 @@ def _as_turn(
             contents.append(
                 ContentToolResult(
                     id=function_response.get("id") or name,
-                    value=function_response.get("response"),
+                    result=ToolResult(function_response.get("response")),
                     name=name,
                 )
             )
diff --git a/chatlas/_openai.py b/chatlas/_openai.py
index d7ea03dd..598ecdcb 100644
--- a/chatlas/_openai.py
+++ b/chatlas/_openai.py
@@ -483,8 +483,8 @@ def _as_message_param(turns: list[Turn]) -> list["ChatCompletionMessageParam"]:
             elif isinstance(x, ContentToolResult):
                 tool_results.append(
                     ChatCompletionToolMessageParam(
-                        # TODO: a tool could return an image!?!
-                        content=x.get_final_value(),
+                        # Currently, OpenAI only allows for text content in tool results
+                        content=cast(str, x.get_final_value()),
                         tool_call_id=x.id,
                         role="tool",
                     )
diff --git a/chatlas/_tools.py b/chatlas/_tools.py
index 359b985f..ccd5b381 100644
--- a/chatlas/_tools.py
+++ b/chatlas/_tools.py
@@ -1,18 +1,28 @@
 from __future__ import annotations
 
 import inspect
+import json
 import warnings
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Literal, Optional, Protocol
 
 from pydantic import BaseModel, Field, create_model
 
 from . import _utils
 
-__all__ = ("Tool",)
+__all__ = (
+    "Tool",
+    "ToolResult",
+)
 
 if TYPE_CHECKING:
     from openai.types.chat import ChatCompletionToolParam
 
+    from ._content import ContentToolRequest
+
+
+class Stringable(Protocol):
+    def __str__(self) -> str: ...
+
 
 class Tool:
     """
@@ -40,11 +50,81 @@ def __init__(
         func: Callable[..., Any] | Callable[..., Awaitable[Any]],
         *,
         model: Optional[type[BaseModel]] = None,
+        on_request: Optional[Callable[[ContentToolRequest], Stringable]] = None,
     ):
         self.func = func
         self._is_async = _utils.is_async_callable(func)
         self.schema = func_to_schema(func, model)
         self.name = self.schema["function"]["name"]
+        self.on_request = on_request
+
+
+class ToolResult:
+    """
+    A result from a tool invocation
+
+    Return an instance of this class from a tool function in order to:
+
+    1. Yield content for the user (i.e., the downstream consumer of a `.stream()` or `.chat()`)
+       to display.
+    2. Control how the tool result gets formatted for the model (i.e., the assistant).
+
+    Parameters
+    ----------
+    assistant
+        The tool result to send to the LLM (i.e., assistant). If the result is
+        not a string, `format_as` determines how the value is formatted
+        before sending it to the model.
+    user
+        A value to yield to the user (i.e., the consumer of a `.stream()`) when
+        the tool is called. If `None`, no value is yielded. This is primarily
+        useful for producing custom UI in the response output to indicate to the
+        user that a tool call has completed (for example, return Shiny UI here
+        when `.stream()`-ing inside a Shiny app).
+    format_as
+        How to format the `assistant` value for the model. The default,
+        `"auto"`, first attempts to format the value as a JSON string. If that
+        fails, it gets converted to a string via `str()`. To force
+        `json.dumps()` or `str()`, set to `"json"` or `"str"`. Finally,
+        `"as_is"` is useful for doing your own formatting and/or passing a
+        non-string value (e.g., a list or dict) straight to the model.
+        Non-string values are useful for tools that return images or other
+        'known' non-text content types.
+    """
+
+    def __init__(
+        self,
+        assistant: Stringable,
+        *,
+        user: Optional[Stringable] = None,
+        format_as: Literal["auto", "json", "str", "as_is"] = "auto",
+    ):
+        # TODO: if called within an active user session, perhaps we could
+        # provide a smart default here
+        self.user = user
+        self.assistant = self._format_value(assistant, format_as)
+        # TODO: we could consider adding an "emit value" -- that is, the thing to
+        # display when `echo="all"` is used. I imagine that might be useful for
+        # advanced users, but let's not worry about it until someone asks for it.
+        # self.emit = emit
+
+    def _format_value(self, value: Stringable, mode: str) -> Stringable:
+        if isinstance(value, str):
+            return value
+
+        if mode == "auto":
+            try:
+                return json.dumps(value)
+            except Exception:
+                return str(value)
+        elif mode == "json":
+            return json.dumps(value)
+        elif mode == "str":
+            return str(value)
+        elif mode == "as_is":
+            return value
+        else:
+            raise ValueError(f"Unknown format mode: {mode}")
 
 
 def func_to_schema(
diff --git a/tests/conftest.py b/tests/conftest.py
index 7c91f414..c3b0e0a8 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -3,10 +3,11 @@
 from typing import Awaitable, Callable
 
 import pytest
-from chatlas import Chat, Turn, content_image_file, content_image_url
 from PIL import Image
 from pydantic import BaseModel
 
+from chatlas import Chat, Turn, content_image_file, content_image_url
+
 ChatFun = Callable[..., Chat]
 
 
@@ -223,3 +224,8 @@ def assert_images_remote_error(chat_fun: ChatFun):
         chat.chat("What's in this image?", image_remote)
 
     assert len(chat.get_turns()) == 0
+
+
+@pytest.fixture
+def test_images_dir():
+    return Path(__file__).parent / "images"
diff --git a/tests/images/dice.png b/tests/images/dice.png
new file mode 100644
index 00000000..11eb8185
Binary files /dev/null and b/tests/images/dice.png differ
diff --git a/tests/test_chat.py b/tests/test_chat.py
index 4494db38..7b76b6d8 100644
--- a/tests/test_chat.py
+++ b/tests/test_chat.py
@@ -2,9 +2,10 @@
 import tempfile
 
 import pytest
-from chatlas import ChatOpenAI, Turn
 from pydantic import BaseModel
 
+from chatlas import ChatOpenAI, ToolResult, Turn
+
 
 def test_simple_batch_chat():
     chat = ChatOpenAI()
@@ -91,6 +92,76 @@ def test_basic_export(snapshot):
         assert snapshot == f.read()
 
 
+def test_tool_results():
+    chat = ChatOpenAI(system_prompt="Be very terse, not even punctuation.")
+
+    def get_date():
+        """Gets the current date"""
+        return ToolResult("2024-01-01", user=["Tool result..."])
+
+    chat.register_tool(get_date)
+    chat.on_tool_request(lambda req: [f"Requesting tool {req.name}..."])
+
+    results = []
+    for chunk in chat.stream("What's the date?"):
+        results.append(chunk)
+
+    # Make sure values haven't been str()'d yet
+    assert ["Requesting tool get_date..."] in results
+    assert ["Tool result..."] in results
+
+    response_str = "".join(str(chunk) for chunk in results)
+
+    assert "Requesting tool get_date..." in response_str
+    assert "Tool result..." in response_str
+    assert "2024-01-01" in response_str
+
+    chat.register_tool(get_date, on_request=lambda req: f"Calling {req.name}...")
+
+    response = chat.chat("What's the date?")
+    assert "Calling get_date..." in str(response)
+    assert "Requesting tool get_date..." not in str(response)
+    assert "Tool result..." in str(response)
+    assert "2024-01-01" in str(response)
+
+
+@pytest.mark.asyncio
+async def test_tool_results_async():
+    chat = ChatOpenAI(system_prompt="Be very terse, not even punctuation.")
+
+    async def get_date():
+        """Gets the current date"""
+        import asyncio
+
+        await asyncio.sleep(0.1)
+        return ToolResult("2024-01-01", user=["Tool result..."])
+
+    chat.register_tool(get_date)
+    chat.on_tool_request(lambda req: [f"Requesting tool {req.name}..."])
+
+    results = []
+    async for chunk in await chat.stream_async("What's the date?"):
+        results.append(chunk)
+
+    # Make sure values haven't been str()'d yet
+    assert ["Requesting tool get_date..."] in results
+    assert ["Tool result..."] in results
+
+    response_str = "".join(str(chunk) for chunk in results)
+
+    assert "Requesting tool get_date..." in response_str
+    assert "Tool result..." in response_str
+    assert "2024-01-01" in response_str
+
+    chat.register_tool(get_date, on_request=lambda req: [f"Calling {req.name}..."])
+
+    response = await chat.chat_async("What's the date?")
+    assert "Calling get_date..." in await response.get_content()
+    assert "Requesting tool get_date..." not in await response.get_content()
+    assert "Tool result..." in await response.get_content()
+    assert "2024-01-01" in await response.get_content()
+
+
 def test_extract_data():
     chat = ChatOpenAI()
diff --git a/tests/test_content_tools.py b/tests/test_content_tools.py
index 3197b463..05dfdae4 100644
--- a/tests/test_content_tools.py
+++ b/tests/test_content_tools.py
@@ -3,7 +3,7 @@
 import pytest
 
 from chatlas import ChatOpenAI
-from chatlas.types import ContentToolResult
+from chatlas.types import ContentToolRequest, ContentToolResult
 
 
 def test_register_tool():
@@ -106,24 +106,32 @@ def test_invoke_tool_returns_tool_result():
     def tool():
         return 1
 
-    res = chat._invoke_tool(tool, {}, id_="x")
+    chat.register_tool(tool)
+
+    res = chat._invoke_tool_request(
+        ContentToolRequest(id="x", name="tool", arguments={})
+    )
     assert isinstance(res, ContentToolResult)
     assert res.id == "x"
     assert res.error is None
-    assert res.value == 1
+    assert res.result.assistant == "1"
 
-    res = chat._invoke_tool(tool, {"x": 1}, id_="x")
+    res = chat._invoke_tool_request(
+        ContentToolRequest(id="x", name="tool", arguments={"x": 1})
+    )
     assert isinstance(res, ContentToolResult)
     assert res.id == "x"
     assert res.error is not None
     assert "got an unexpected keyword argument" in res.error
-    assert res.value is None
+    assert res.result is None
 
-    res = chat._invoke_tool(None, {"x": 1}, id_="x")
+    res = chat._invoke_tool_request(
+        ContentToolRequest(id="x", name="foo", arguments={"x": 1})
+    )
     assert isinstance(res, ContentToolResult)
     assert res.id == "x"
     assert res.error == "Unknown tool"
-    assert res.value is None
+    assert res.result is None
 
 
 @pytest.mark.asyncio
@@ -133,21 +141,29 @@ async def test_invoke_tool_returns_tool_result_async():
     async def tool():
         return 1
 
-    res = await chat._invoke_tool_async(tool, {}, id_="x")
+    chat.register_tool(tool)
+
+    res = await chat._invoke_tool_request_async(
+        ContentToolRequest(id="x", name="tool", arguments={})
+    )
     assert isinstance(res, ContentToolResult)
     assert res.id == "x"
     assert res.error is None
-    assert res.value == 1
+    assert res.result.assistant == "1"
 
-    res = await chat._invoke_tool_async(tool, {"x": 1}, id_="x")
+    res = await chat._invoke_tool_request_async(
+        ContentToolRequest(id="x", name="tool", arguments={"x": 1})
+    )
     assert isinstance(res, ContentToolResult)
     assert res.id == "x"
     assert res.error is not None
     assert "got an unexpected keyword argument" in res.error
-    assert res.value is None
+    assert res.result is None
 
-    res = await chat._invoke_tool_async(None, {"x": 1}, id_="x")
+    res = await chat._invoke_tool_request_async(
+        ContentToolRequest(id="x", name="foo", arguments={"x": 1})
+    )
     assert isinstance(res, ContentToolResult)
     assert res.id == "x"
     assert res.error == "Unknown tool"
-    assert res.value is None
+    assert res.result is None
diff --git a/tests/test_provider_anthropic.py b/tests/test_provider_anthropic.py
index 8bf8728c..7bc3cb08 100644
--- a/tests/test_provider_anthropic.py
+++ b/tests/test_provider_anthropic.py
@@ -1,4 +1,7 @@
+import base64
+
 import pytest
+
 from chatlas import ChatAnthropic
 
 from .conftest import (
@@ -94,3 +97,35 @@ def run_inlineassert():
     retryassert(run_inlineassert, retries=3)
 
     assert_images_remote_error(chat_fun)
+
+
+def test_anthropic_image_tool(test_images_dir):
+    from chatlas import ToolResult
+
+    def get_picture():
+        "Returns an image"
+        # Local copy of https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png
+        with open(test_images_dir / "dice.png", "rb") as image:
+            bytez = image.read()
+        res = [
+            {
+                "type": "image",
+                "source": {
+                    "type": "base64",
"media_type": "image/png", + "data": base64.b64encode(bytez).decode("utf-8"), + }, + } + ] + return ToolResult(res, format_as="as_is") + + chat = ChatAnthropic() + chat.register_tool(get_picture) + + res = chat.chat( + "You have a tool called 'get_picture' available to you. " + "When called, it returns an image. " + "Tell me what you see in the image." + ) + + assert "dice" in res.get_content()