diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6f775057..252d5e67 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Added a `ChatSnowflake()` class to interact with [Snowflake Cortex LLM](https://docs.snowflake.com/en/user-guide/snowflake-cortex/llm-functions). (#54)
* Added a `ChatAuto()` class, allowing for configuration of chat providers and models via environment variables. (#38, thanks @mconflitti-pbc)
+* Added a `ToolResult()` class which allow for: (1) control how results get formatted when sent to the model and (2) yield additional content to the user (i.e., the downstream consumer of a `.stream()` or `.chat()`) for display when the tool is called. (#69)
+* Added a `on_request` parameter to `.register_tool()`. When tool is requested, this callback executes, and the result is yielded to the user. (#69)
+* Added a `Chat.on_tool_request()` method for registering a default tool request handler. (#69)
+
+
+### Changes
+
+* By default, tool results are formatted as a JSON string when sent to the model. (#69)
### Improvements
diff --git a/chatlas/__init__.py b/chatlas/__init__.py
index 29f70974..5ce20b58 100644
--- a/chatlas/__init__.py
+++ b/chatlas/__init__.py
@@ -13,7 +13,7 @@
from ._provider import Provider
from ._snowflake import ChatSnowflake
from ._tokens import token_usage
-from ._tools import Tool
+from ._tools import Tool, ToolResult
from ._turn import Turn
try:
@@ -43,6 +43,7 @@
"Provider",
"token_usage",
"Tool",
+ "ToolResult",
"Turn",
"types",
)
diff --git a/chatlas/_anthropic.py b/chatlas/_anthropic.py
index 93b4cbf8..ada44a2f 100644
--- a/chatlas/_anthropic.py
+++ b/chatlas/_anthropic.py
@@ -472,12 +472,15 @@ def _as_content_block(content: Content) -> "ContentBlockParam":
"input": content.arguments,
}
elif isinstance(content, ContentToolResult):
- return {
+ res: ToolResultBlockParam = {
"type": "tool_result",
"tool_use_id": content.id,
- "content": content.get_final_value(),
"is_error": content.error is not None,
}
+ # Anthropic supports non-text contents like ImageBlockParam
+ res["content"] = content.get_final_value() # type: ignore
+ return res
+
raise ValueError(f"Unknown content type: {type(content)}")
@staticmethod
diff --git a/chatlas/_chat.py b/chatlas/_chat.py
index 6e1f493f..a5850e13 100644
--- a/chatlas/_chat.py
+++ b/chatlas/_chat.py
@@ -39,7 +39,7 @@
)
from ._logging import log_tool_error
from ._provider import Provider
-from ._tools import Tool
+from ._tools import Stringable, Tool, ToolResult
from ._turn import Turn, user_turn
from ._typing_extensions import TypedDict
from ._utils import html_escape, wrap_async
@@ -96,6 +96,9 @@ def __init__(
"rich_console": {},
"css_styles": {},
}
+ self._on_tool_request_default: Optional[
+ Callable[[ContentToolRequest], Stringable]
+ ] = None
def get_turns(
self,
@@ -625,7 +628,7 @@ def stream(
*args: Content | str,
echo: Literal["text", "all", "none"] = "none",
kwargs: Optional[SubmitInputArgsT] = None,
- ) -> ChatResponse:
+ ) -> Generator[Stringable, None, None]:
"""
Generate a response from the chat in a streaming fashion.
@@ -658,19 +661,19 @@ def stream(
kwargs=kwargs,
)
- def wrapper() -> Generator[str, None, None]:
+ def wrapper() -> Generator[Stringable, None, None]:
with display:
for chunk in generator:
yield chunk
- return ChatResponse(wrapper())
+ return wrapper()
async def stream_async(
self,
*args: Content | str,
echo: Literal["text", "all", "none"] = "none",
kwargs: Optional[SubmitInputArgsT] = None,
- ) -> ChatResponseAsync:
+ ) -> AsyncGenerator[Stringable, None]:
"""
Generate a response from the chat in a streaming fashion asynchronously.
@@ -695,7 +698,7 @@ async def stream_async(
display = self._markdown_display(echo=echo)
- async def wrapper() -> AsyncGenerator[str, None]:
+ async def wrapper() -> AsyncGenerator[Stringable, None]:
with display:
async for chunk in self._chat_impl_async(
turn,
@@ -706,7 +709,7 @@ async def wrapper() -> AsyncGenerator[str, None]:
):
yield chunk
- return ChatResponseAsync(wrapper())
+ return wrapper()
def extract_data(
self,
@@ -831,6 +834,7 @@ def register_tool(
self,
func: Callable[..., Any] | Callable[..., Awaitable[Any]],
*,
+ on_request: Optional[Callable[[ContentToolRequest], Stringable]] = None,
model: Optional[type[BaseModel]] = None,
):
"""
@@ -900,6 +904,11 @@ def add(a: int, b: int) -> int:
----------
func
The function to be invoked when the tool is called.
+ on_request
+ A callable that will be passed a :class:`~chatlas.ContentToolRequest`
+ when the tool is requested. If defined, and the callable returns a
+ stringable object, that value will be yielded to the chat as a part
+ of the response.
model
A Pydantic model that describes the input parameters for the function.
If not provided, the model will be inferred from the function's type hints.
@@ -907,9 +916,37 @@ def add(a: int, b: int) -> int:
Note that the name and docstring of the model takes precedence over the
name and docstring of the function.
"""
- tool = Tool(func, model=model)
+ tool = Tool(func, on_request=on_request, model=model)
self._tools[tool.name] = tool
+ def on_tool_request(
+ self,
+ func: Callable[[ContentToolRequest], Stringable],
+ ):
+ """
+ Register a default function to be invoked when a tool is requested.
+
+ This function will be invoked if a tool is requested that does not have
+ a specific `on_request` function defined.
+
+ Parameters
+ ----------
+ func
+ A callable that will be passed a :class:`~chatlas.ContentToolRequest`
+ when the tool is requested. If defined, and the callable returns a
+ stringable object, that value will be yielded to the chat as a part
+ of the response.
+ """
+ self._on_tool_request_default = func
+
+ def _on_tool_request(self, req: ContentToolRequest) -> Stringable | None:
+ tool_def = self._tools.get(req.name, None)
+ if tool_def and tool_def.on_request:
+ return tool_def.on_request(req)
+ if self._on_tool_request_default:
+ return self._on_tool_request_default(req)
+ return None
+
def export(
self,
filename: str | Path,
@@ -1040,7 +1077,7 @@ def _chat_impl(
display: MarkdownDisplay,
stream: bool,
kwargs: Optional[SubmitInputArgsT] = None,
- ) -> Generator[str, None, None]:
+ ) -> Generator[Stringable, None, None]:
user_turn_result: Turn | None = user_turn
while user_turn_result is not None:
for chunk in self._submit_turns(
@@ -1051,7 +1088,24 @@ def _chat_impl(
kwargs=kwargs,
):
yield chunk
- user_turn_result = self._invoke_tools()
+
+ turn = self.get_last_turn(role="assistant")
+ assert turn is not None
+ user_turn_result = None
+
+ results: list[ContentToolResult] = []
+ for x in turn.contents:
+ if isinstance(x, ContentToolRequest):
+ req = self._on_tool_request(x)
+ if req is not None:
+ yield req
+ res = self._invoke_tool_request(x)
+ if res.result and res.result.user is not None:
+ yield res.result.user
+ results.append(res)
+
+ if results:
+ user_turn_result = Turn("user", results)
async def _chat_impl_async(
self,
@@ -1060,7 +1114,7 @@ async def _chat_impl_async(
display: MarkdownDisplay,
stream: bool,
kwargs: Optional[SubmitInputArgsT] = None,
- ) -> AsyncGenerator[str, None]:
+ ) -> AsyncGenerator[Stringable, None]:
user_turn_result: Turn | None = user_turn
while user_turn_result is not None:
async for chunk in self._submit_turns_async(
@@ -1071,7 +1125,24 @@ async def _chat_impl_async(
kwargs=kwargs,
):
yield chunk
- user_turn_result = await self._invoke_tools_async()
+
+ turn = self.get_last_turn(role="assistant")
+ assert turn is not None
+ user_turn_result = None
+
+ results: list[ContentToolResult] = []
+ for x in turn.contents:
+ if isinstance(x, ContentToolRequest):
+ req = self._on_tool_request(x)
+ if req is not None:
+ yield req
+ res = await self._invoke_tool_request_async(x)
+ if res.result and res.result.user is not None:
+ yield res.result.user
+ results.append(res)
+
+ if results:
+ user_turn_result = Turn("user", results)
def _submit_turns(
self,
@@ -1085,7 +1156,7 @@ def _submit_turns(
if any(x._is_async for x in self._tools.values()):
raise ValueError("Cannot use async tools in a synchronous chat")
- def emit(text: str | Content):
+ def emit(text: Stringable):
display.update(str(text))
emit("
\n\n")
@@ -1148,7 +1219,7 @@ async def _submit_turns_async(
data_model: type[BaseModel] | None = None,
kwargs: Optional[SubmitInputArgsT] = None,
) -> AsyncGenerator[str, None]:
- def emit(text: str | Content):
+ def emit(text: Stringable):
display.update(str(text))
emit("
\n\n")
@@ -1202,88 +1273,58 @@ def emit(text: str | Content):
self._turns.extend([user_turn, turn])
- def _invoke_tools(self) -> Turn | None:
- turn = self.get_last_turn()
- if turn is None:
- return None
-
- results: list[ContentToolResult] = []
- for x in turn.contents:
- if isinstance(x, ContentToolRequest):
- tool_def = self._tools.get(x.name, None)
- func = tool_def.func if tool_def is not None else None
- results.append(self._invoke_tool(func, x.arguments, x.id))
-
- if not results:
- return None
-
- return Turn("user", results)
-
- async def _invoke_tools_async(self) -> Turn | None:
- turn = self.get_last_turn()
- if turn is None:
- return None
-
- results: list[ContentToolResult] = []
- for x in turn.contents:
- if isinstance(x, ContentToolRequest):
- tool_def = self._tools.get(x.name, None)
- func = None
- if tool_def:
- if tool_def._is_async:
- func = tool_def.func
- else:
- func = wrap_async(tool_def.func)
- results.append(await self._invoke_tool_async(func, x.arguments, x.id))
+ def _invoke_tool_request(self, x: ContentToolRequest) -> ContentToolResult:
+ tool_def = self._tools.get(x.name, None)
+ func = tool_def.func if tool_def is not None else None
- if not results:
- return None
-
- return Turn("user", results)
-
- @staticmethod
- def _invoke_tool(
- func: Callable[..., Any] | None,
- arguments: object,
- id_: str,
- ) -> ContentToolResult:
if func is None:
- return ContentToolResult(id_, value=None, error="Unknown tool")
+ return ContentToolResult(x.id, result=None, error="Unknown tool")
name = func.__name__
try:
- if isinstance(arguments, dict):
- result = func(**arguments)
+ if isinstance(x.arguments, dict):
+ result = func(**x.arguments)
else:
- result = func(arguments)
+ result = func(x.arguments)
+
+ if not isinstance(result, ToolResult):
+ result = ToolResult(result)
- return ContentToolResult(id_, value=result, error=None, name=name)
+ return ContentToolResult(x.id, result=result, error=None, name=name)
except Exception as e:
- log_tool_error(name, str(arguments), e)
- return ContentToolResult(id_, value=None, error=str(e), name=name)
+ log_tool_error(name, str(x.arguments), e)
+ return ContentToolResult(x.id, result=None, error=str(e), name=name)
- @staticmethod
- async def _invoke_tool_async(
- func: Callable[..., Awaitable[Any]] | None,
- arguments: object,
- id_: str,
+ async def _invoke_tool_request_async(
+ self, x: ContentToolRequest
) -> ContentToolResult:
+ tool_def = self._tools.get(x.name, None)
+ func = None
+ if tool_def:
+ if tool_def._is_async:
+ func = tool_def.func
+ else:
+ func = wrap_async(tool_def.func)
+
if func is None:
- return ContentToolResult(id_, value=None, error="Unknown tool")
+ return ContentToolResult(x.id, result=None, error="Unknown tool")
name = func.__name__
try:
- if isinstance(arguments, dict):
- result = await func(**arguments)
+ if isinstance(x.arguments, dict):
+ result = await func(**x.arguments)
else:
- result = await func(arguments)
+ result = await func(x.arguments)
+
+ if not isinstance(result, ToolResult):
+ result = ToolResult(result)
- return ContentToolResult(id_, value=result, error=None, name=name)
+ return ContentToolResult(x.id, result=result, error=None, name=name)
except Exception as e:
- log_tool_error(func.__name__, str(arguments), e)
- return ContentToolResult(id_, value=None, error=str(e), name=name)
+ log_tool_error(func.__name__, str(x.arguments), e)
+ return ContentToolResult(x.id, result=None, error=str(e), name=name)
def _markdown_display(
self, echo: Literal["text", "all", "none"]
@@ -1378,16 +1419,16 @@ class ChatResponse:
still be retrieved (via the `content` attribute).
"""
- def __init__(self, generator: Generator[str, None]):
+ def __init__(self, generator: Generator[Stringable, None]):
self._generator = generator
- self.content: str = ""
+ self.contents: list[Stringable] = []
- def __iter__(self) -> Iterator[str]:
+ def __iter__(self) -> Iterator[Stringable]:
return self
- def __next__(self) -> str:
+ def __next__(self) -> Stringable:
chunk = next(self._generator)
- self.content += chunk # Keep track of accumulated content
+ self.contents.append(chunk)
return chunk
def get_content(self) -> str:
@@ -1396,7 +1437,7 @@ def get_content(self) -> str:
"""
for _ in self:
pass
- return self.content
+ return "".join(str(x) for x in self.contents)
@property
def consumed(self) -> bool:
@@ -1430,23 +1471,23 @@ class ChatResponseAsync:
still be retrieved (via the `content` attribute).
"""
- def __init__(self, generator: AsyncGenerator[str, None]):
+ def __init__(self, generator: AsyncGenerator[Stringable, None]):
self._generator = generator
- self.content: str = ""
+ self.contents: list[Stringable] = []
- def __aiter__(self) -> AsyncIterator[str]:
+ def __aiter__(self) -> AsyncIterator[Stringable]:
return self
- async def __anext__(self) -> str:
+ async def __anext__(self) -> Stringable:
chunk = await self._generator.__anext__()
- self.content += chunk # Keep track of accumulated content
+ self.contents.append(chunk)
return chunk
async def get_content(self) -> str:
"Get the chat response content as a string."
async for _ in self:
pass
- return self.content
+ return "".join(str(x) for x in self.contents)
@property
def consumed(self) -> bool:
diff --git a/chatlas/_content.py b/chatlas/_content.py
index e453a91b..620e2fbc 100644
--- a/chatlas/_content.py
+++ b/chatlas/_content.py
@@ -3,7 +3,15 @@
import json
from dataclasses import dataclass
from pprint import pformat
-from typing import Any, Literal, Optional
+from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, cast
+
+if TYPE_CHECKING:
+ from ._tools import ToolResult
+
+
+class Stringable(Protocol):
+ def __str__(self) -> str: ...
+
ImageContentTypes = Literal[
"image/png",
@@ -195,20 +203,21 @@ class ContentToolResult(Content):
"""
id: str
- value: Any = None
+ result: Optional[ToolResult] = None
name: Optional[str] = None
error: Optional[str] = None
- def _get_value(self, pretty: bool = False) -> str:
+ def _get_value(self, pretty: bool = False) -> Stringable:
if self.error:
return f"Tool calling failed with error: '{self.error}'"
+ result = cast("ToolResult", self.result)
if not pretty:
- return str(self.value)
+ return result.assistant
try:
- json_val = json.loads(self.value) # type: ignore
+ json_val = json.loads(result.assistant) # type: ignore
return pformat(json_val, indent=2, sort_dicts=False)
except: # noqa: E722
- return str(self.value)
+ return result.assistant
# Primarily used for `echo="all"`...
def __str__(self):
@@ -222,13 +231,14 @@ def _repr_markdown_(self):
def __repr__(self, indent: int = 0):
res = " " * indent
- res += f""
# The actual value to send to the model
- def get_final_value(self) -> str:
+ def get_final_value(self) -> Stringable:
return self._get_value()
diff --git a/chatlas/_google.py b/chatlas/_google.py
index 0113248f..c994e943 100644
--- a/chatlas/_google.py
+++ b/chatlas/_google.py
@@ -20,7 +20,7 @@
from ._merge import merge_dicts
from ._provider import Provider
from ._tokens import tokens_log
-from ._tools import Tool
+from ._tools import Tool, ToolResult
from ._turn import Turn, normalize_turns, user_turn
if TYPE_CHECKING:
@@ -422,7 +422,7 @@ def _as_part_type(self, content: Content) -> "Part":
if content.error:
resp = {"error": content.error}
else:
- resp = {"result": str(content.value)}
+ resp = {"result": str(content.result)}
return Part(
# TODO: seems function response parts might need role='tool'???
# https://github.com/googleapis/python-genai/blame/c8cfef85c/README.md#L344
@@ -483,7 +483,7 @@ def _as_turn(
contents.append(
ContentToolResult(
id=function_response.get("id") or name,
- value=function_response.get("response"),
+ result=ToolResult(function_response.get("response")),
name=name,
)
)
diff --git a/chatlas/_openai.py b/chatlas/_openai.py
index d7ea03dd..598ecdcb 100644
--- a/chatlas/_openai.py
+++ b/chatlas/_openai.py
@@ -483,8 +483,8 @@ def _as_message_param(turns: list[Turn]) -> list["ChatCompletionMessageParam"]:
elif isinstance(x, ContentToolResult):
tool_results.append(
ChatCompletionToolMessageParam(
- # TODO: a tool could return an image!?!
- content=x.get_final_value(),
+ # Currently, OpenAI only allows for text content in tool results
+ content=cast(str, x.get_final_value()),
tool_call_id=x.id,
role="tool",
)
diff --git a/chatlas/_tools.py b/chatlas/_tools.py
index 359b985f..ccd5b381 100644
--- a/chatlas/_tools.py
+++ b/chatlas/_tools.py
@@ -1,18 +1,28 @@
from __future__ import annotations
import inspect
+import json
import warnings
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Literal, Optional, Protocol
from pydantic import BaseModel, Field, create_model
from . import _utils
-__all__ = ("Tool",)
+__all__ = (
+ "Tool",
+ "ToolResult",
+)
if TYPE_CHECKING:
from openai.types.chat import ChatCompletionToolParam
+ from ._content import ContentToolRequest
+
+
+class Stringable(Protocol):
+ def __str__(self) -> str: ...
+
class Tool:
"""
@@ -40,11 +50,81 @@ def __init__(
func: Callable[..., Any] | Callable[..., Awaitable[Any]],
*,
model: Optional[type[BaseModel]] = None,
+ on_request: Optional[Callable[[ContentToolRequest], Stringable]] = None,
):
self.func = func
self._is_async = _utils.is_async_callable(func)
self.schema = func_to_schema(func, model)
self.name = self.schema["function"]["name"]
+ self.on_request = on_request
+
+
+class ToolResult:
+ """
+ A result from a tool invocation
+
+ Return an instance of this class from a tool function in order to:
+
+ 1. Yield content for the user (i.e., the downstream consumer of a `.stream()` or `.chat()`)
+ to display.
+ 2. Control how the tool result gets formatted for the model (i.e., the assistant).
+
+ Parameters
+ ----------
+ assistant
+ The tool result to send to the llm (i.e., assistant). If the result is
+ not a string, `format_as` determines how to the value is formatted
+ before sending it to the model.
+ user
+ A value to yield to the user (i.e., the consumer of a `.stream()`) when
+ the tool is called. If `None`, no value is yielded. This is primarily
+ useful for producing custom UI in the response output to indicate to the
+ user that a tool call has completed (for example, return shiny UI here
+ when `.stream()`-ing inside a shiny app).
+ format_as
+ How to format the `assistant` value for the model. The default,
+ `"auto"`, first attempts to format the value as a JSON string. If that
+ fails, it gets converted to a string via `str()`. To force
+ `json.dumps()` or `str()`, set to `"json"` or `"str"`. Finally,
+ `"as_is"` is useful for doing your own formatting and/or passing a
+ non-string value (e.g., a list or dict) straight to the model.
+ Non-string values are useful for tools that return images or other
+ 'known' non-text content types.
+ """
+
+ def __init__(
+ self,
+ assistant: Stringable,
+ *,
+ user: Optional[Stringable] = None,
+ format_as: Literal["auto", "json", "str", "as_is"] = "auto",
+ ):
+ # TODO: if called when an active user session, perhaps we could
+ # provide a smart default here
+ self.user = user
+ self.assistant = self._format_value(assistant, format_as)
+ # TODO: we could consider adding an "emit value" -- that is, the thing to
+ # display when `echo="all"` is used. I imagine that might be useful for
+ # advanced users, but let's not worry about it until someone asks for it.
+ # self.emit = emit
+
+ def _format_value(self, value: Stringable, mode: str) -> Stringable:
+ if isinstance(value, str):
+ return value
+
+ if mode == "auto":
+ try:
+ return json.dumps(value)
+ except Exception:
+ return str(value)
+ elif mode == "json":
+ return json.dumps(value)
+ elif mode == "str":
+ return str(value)
+ elif mode == "as_is":
+ return value
+ else:
+ raise ValueError(f"Unknown format mode: {mode}")
def func_to_schema(
diff --git a/tests/conftest.py b/tests/conftest.py
index 7c91f414..c3b0e0a8 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -3,10 +3,11 @@
from typing import Awaitable, Callable
import pytest
-from chatlas import Chat, Turn, content_image_file, content_image_url
from PIL import Image
from pydantic import BaseModel
+from chatlas import Chat, Turn, content_image_file, content_image_url
+
ChatFun = Callable[..., Chat]
@@ -223,3 +224,8 @@ def assert_images_remote_error(chat_fun: ChatFun):
chat.chat("What's in this image?", image_remote)
assert len(chat.get_turns()) == 0
+
+
+@pytest.fixture
+def test_images_dir():
+ return Path(__file__).parent / "images"
diff --git a/tests/images/dice.png b/tests/images/dice.png
new file mode 100644
index 00000000..11eb8185
Binary files /dev/null and b/tests/images/dice.png differ
diff --git a/tests/test_chat.py b/tests/test_chat.py
index 4494db38..7b76b6d8 100644
--- a/tests/test_chat.py
+++ b/tests/test_chat.py
@@ -2,9 +2,10 @@
import tempfile
import pytest
-from chatlas import ChatOpenAI, Turn
from pydantic import BaseModel
+from chatlas import ChatOpenAI, ToolResult, Turn
+
def test_simple_batch_chat():
chat = ChatOpenAI()
@@ -91,6 +92,76 @@ def test_basic_export(snapshot):
assert snapshot == f.read()
+def test_tool_results():
+ chat = ChatOpenAI(system_prompt="Be very terse, not even punctuation.")
+
+ def get_date():
+ """Gets the current date"""
+ return ToolResult("2024-01-01", user=["Tool result..."])
+
+ chat.register_tool(get_date)
+ chat.on_tool_request(lambda req: [f"Requesting tool {req.name}..."])
+
+ results = []
+ for chunk in chat.stream("What's the date?"):
+ results.append(chunk)
+
+ # Make sure values haven't been str()'d yet
+ assert ["Requesting tool get_date..."] in results
+ assert ["Tool result..."] in results
+
+ response_str = "".join(str(chunk) for chunk in results)
+
+ assert "Requesting tool get_date..." in response_str
+ assert "Tool result..." in response_str
+ assert "2024-01-01" in response_str
+
+ chat.register_tool(get_date, on_request=lambda req: f"Calling {req.name}...")
+
+ response = chat.chat("What's the date?")
+ assert "Calling get_date..." in str(response)
+ assert "Requesting tool get_date..." not in str(response)
+ assert "Tool result..." in str(response)
+ assert "2024-01-01" in str(response)
+
+
+@pytest.mark.asyncio
+async def test_tool_results_async():
+ chat = ChatOpenAI(system_prompt="Be very terse, not even punctuation.")
+
+ async def get_date():
+ """Gets the current date"""
+ import asyncio
+
+ await asyncio.sleep(0.1)
+ return ToolResult("2024-01-01", user=["Tool result..."])
+
+ chat.register_tool(get_date)
+ chat.on_tool_request(lambda req: [f"Requesting tool {req.name}..."])
+
+ results = []
+ async for chunk in await chat.stream_async("What's the date?"):
+ results.append(chunk)
+
+ # Make sure values haven't been str()'d yet
+ assert ["Requesting tool get_date..."] in results
+ assert ["Tool result..."] in results
+
+ response_str = "".join(str(chunk) for chunk in results)
+
+ assert "Requesting tool get_date..." in response_str
+ assert "Tool result..." in response_str
+ assert "2024-01-01" in response_str
+
+ chat.register_tool(get_date, on_request=lambda req: [f"Calling {req.name}..."])
+
+ response = await chat.chat_async("What's the date?")
+ assert "Calling get_date..." in await response.get_content()
+ assert "Requesting tool get_date..." not in await response.get_content()
+ assert "Tool result..." in await response.get_content()
+ assert "2024-01-01" in await response.get_content()
+
+
def test_extract_data():
chat = ChatOpenAI()
diff --git a/tests/test_content_tools.py b/tests/test_content_tools.py
index 3197b463..05dfdae4 100644
--- a/tests/test_content_tools.py
+++ b/tests/test_content_tools.py
@@ -3,7 +3,7 @@
import pytest
from chatlas import ChatOpenAI
-from chatlas.types import ContentToolResult
+from chatlas.types import ContentToolRequest, ContentToolResult
def test_register_tool():
@@ -106,24 +106,32 @@ def test_invoke_tool_returns_tool_result():
def tool():
return 1
- res = chat._invoke_tool(tool, {}, id_="x")
+ chat.register_tool(tool)
+
+ res = chat._invoke_tool_request(
+ ContentToolRequest(id="x", name="tool", arguments={})
+ )
assert isinstance(res, ContentToolResult)
assert res.id == "x"
assert res.error is None
- assert res.value == 1
+ assert res.result.assistant == "1"
- res = chat._invoke_tool(tool, {"x": 1}, id_="x")
+ res = chat._invoke_tool_request(
+ ContentToolRequest(id="x", name="tool", arguments={"x": 1})
+ )
assert isinstance(res, ContentToolResult)
assert res.id == "x"
assert res.error is not None
assert "got an unexpected keyword argument" in res.error
- assert res.value is None
+ assert res.result is None
- res = chat._invoke_tool(None, {"x": 1}, id_="x")
+ res = chat._invoke_tool_request(
+ ContentToolRequest(id="x", name="foo", arguments={"x": 1})
+ )
assert isinstance(res, ContentToolResult)
assert res.id == "x"
assert res.error == "Unknown tool"
- assert res.value is None
+ assert res.result is None
@pytest.mark.asyncio
@@ -133,21 +141,29 @@ async def test_invoke_tool_returns_tool_result_async():
async def tool():
return 1
- res = await chat._invoke_tool_async(tool, {}, id_="x")
+ chat.register_tool(tool)
+
+ res = await chat._invoke_tool_request_async(
+ ContentToolRequest(id="x", name="tool", arguments={})
+ )
assert isinstance(res, ContentToolResult)
assert res.id == "x"
assert res.error is None
- assert res.value == 1
+ assert res.result.assistant == "1"
- res = await chat._invoke_tool_async(tool, {"x": 1}, id_="x")
+ res = await chat._invoke_tool_request_async(
+ ContentToolRequest(id="x", name="tool", arguments={"x": 1})
+ )
assert isinstance(res, ContentToolResult)
assert res.id == "x"
assert res.error is not None
assert "got an unexpected keyword argument" in res.error
- assert res.value is None
+ assert res.result is None
- res = await chat._invoke_tool_async(None, {"x": 1}, id_="x")
+ res = await chat._invoke_tool_request_async(
+ ContentToolRequest(id="x", name="foo", arguments={"x": 1})
+ )
assert isinstance(res, ContentToolResult)
assert res.id == "x"
assert res.error == "Unknown tool"
- assert res.value is None
+ assert res.result is None
diff --git a/tests/test_provider_anthropic.py b/tests/test_provider_anthropic.py
index 8bf8728c..7bc3cb08 100644
--- a/tests/test_provider_anthropic.py
+++ b/tests/test_provider_anthropic.py
@@ -1,4 +1,7 @@
+import base64
+
import pytest
+
from chatlas import ChatAnthropic
from .conftest import (
@@ -94,3 +97,35 @@ def run_inlineassert():
retryassert(run_inlineassert, retries=3)
assert_images_remote_error(chat_fun)
+
+
+def test_anthropic_image_tool(test_images_dir):
+ from chatlas import ToolResult
+
+ def get_picture():
+ "Returns an image"
+ # Local copy of https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png
+ with open(test_images_dir / "dice.png", "rb") as image:
+ bytez = image.read()
+ res = [
+ {
+ "type": "image",
+ "source": {
+ "type": "base64",
+ "media_type": "image/png",
+ "data": base64.b64encode(bytez).decode("utf-8"),
+ },
+ }
+ ]
+ return ToolResult(res, format_as="as_is")
+
+ chat = ChatAnthropic()
+ chat.register_tool(get_picture)
+
+ res = chat.chat(
+ "You have a tool called 'get_picture' available to you. "
+ "When called, it returns an image. "
+ "Tell me what you see in the image."
+ )
+
+ assert "dice" in res.get_content()