Skip to content

Commit 6afe843

Browse files
committed
Add basic image generation support; introduce new ToolBuiltIn class
1 parent edb5615 commit 6afe843

File tree

8 files changed

+194
-49
lines changed

8 files changed

+194
-49
lines changed

chatlas/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from ._provider_portkey import ChatPortkey
3737
from ._provider_snowflake import ChatSnowflake
3838
from ._tokens import token_usage
39-
from ._tools import Tool, ToolRejectError
39+
from ._tools import Tool, ToolBuiltIn, ToolRejectError
4040
from ._turn import Turn
4141

4242
try:
@@ -88,6 +88,7 @@
8888
"Provider",
8989
"token_usage",
9090
"Tool",
91+
"ToolBuiltIn",
9192
"ToolRejectError",
9293
"Turn",
9394
"types",

chatlas/_chat.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
from ._mcp_manager import MCPSessionManager
4949
from ._provider import ModelInfo, Provider, StandardModelParams, SubmitInputArgsT
5050
from ._tokens import compute_cost, get_token_pricing, tokens_log
51-
from ._tools import Tool, ToolRejectError
51+
from ._tools import Tool, ToolBuiltIn, ToolRejectError
5252
from ._turn import Turn, user_turn
5353
from ._typing_extensions import TypedDict, TypeGuard
5454
from ._utils import MISSING, MISSING_TYPE, html_escape, wrap_async
@@ -131,7 +131,7 @@ def __init__(
131131
self.system_prompt = system_prompt
132132
self.kwargs_chat: SubmitInputArgsT = kwargs_chat or {}
133133

134-
self._tools: dict[str, Tool] = {}
134+
self._tools: dict[str, Tool | ToolBuiltIn] = {}
135135
self._on_tool_request_callbacks = CallbackManager()
136136
self._on_tool_result_callbacks = CallbackManager()
137137
self._current_display: Optional[MarkdownDisplay] = None
@@ -1866,7 +1866,7 @@ async def cleanup_mcp_tools(self, names: Optional[Sequence[str]] = None):
18661866

18671867
def register_tool(
18681868
self,
1869-
func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool,
1869+
func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool | "ToolBuiltIn",
18701870
*,
18711871
force: bool = False,
18721872
name: Optional[str] = None,
@@ -1960,31 +1960,39 @@ def add(a: int, b: int) -> int:
19601960
ValueError
19611961
If a tool with the same name already exists and `force` is `False`.
19621962
"""
1963-
if isinstance(func, Tool):
1963+
if isinstance(func, ToolBuiltIn):
1964+
# ToolBuiltIn objects are stored directly without conversion
1965+
tool = func
1966+
tool_name = tool.name
1967+
elif isinstance(func, Tool):
19641968
name = name or func.name
19651969
annotations = annotations or func.annotations
19661970
if model is not None:
19671971
func = Tool.from_func(
19681972
func.func, name=name, model=model, annotations=annotations
19691973
)
19701974
func = func.func
1975+
tool = Tool.from_func(func, name=name, model=model, annotations=annotations)
1976+
tool_name = tool.name
1977+
else:
1978+
tool = Tool.from_func(func, name=name, model=model, annotations=annotations)
1979+
tool_name = tool.name
19711980

1972-
tool = Tool.from_func(func, name=name, model=model, annotations=annotations)
1973-
if tool.name in self._tools and not force:
1981+
if tool_name in self._tools and not force:
19741982
raise ValueError(
1975-
f"Tool with name '{tool.name}' is already registered. "
1983+
f"Tool with name '{tool_name}' is already registered. "
19761984
"Set `force=True` to overwrite it."
19771985
)
1978-
self._tools[tool.name] = tool
1986+
self._tools[tool_name] = tool
19791987

1980-
def get_tools(self) -> list[Tool]:
1988+
def get_tools(self) -> list[Tool | ToolBuiltIn]:
19811989
"""
19821990
Get the list of registered tools.
19831991
19841992
Returns
19851993
-------
1986-
list[Tool]
1987-
A list of `Tool` instances that are currently registered with the chat.
1994+
list[Tool | ToolBuiltIn]
1995+
A list of `Tool` or `ToolBuiltIn` instances that are currently registered with the chat.
19881996
"""
19891997
return list(self._tools.values())
19901998

@@ -2508,7 +2516,7 @@ def _submit_turns(
25082516
data_model: type[BaseModel] | None = None,
25092517
kwargs: Optional[SubmitInputArgsT] = None,
25102518
) -> Generator[str, None, None]:
2511-
if any(x._is_async for x in self._tools.values()):
2519+
if any(hasattr(x, "_is_async") and x._is_async for x in self._tools.values()):
25122520
raise ValueError("Cannot use async tools in a synchronous chat")
25132521

25142522
def emit(text: str | Content):
@@ -2661,15 +2669,27 @@ def _collect_all_kwargs(
26612669

26622670
def _invoke_tool(self, request: ContentToolRequest):
26632671
tool = self._tools.get(request.name)
2664-
func = tool.func if tool is not None else None
26652672

2666-
if func is None:
2673+
if tool is None:
26672674
yield self._handle_tool_error_result(
26682675
request,
26692676
error=RuntimeError("Unknown tool."),
26702677
)
26712678
return
26722679

2680+
if isinstance(tool, ToolBuiltIn):
2681+
# Built-in tools are handled by the provider, not invoked directly
2682+
yield self._handle_tool_error_result(
2683+
request,
2684+
error=RuntimeError(
2685+
f"Built-in tool '{request.name}' cannot be invoked directly. "
2686+
"It should be handled by the provider."
2687+
),
2688+
)
2689+
return
2690+
2691+
func = tool.func
2692+
26732693
# First, invoke the request callbacks. If a ToolRejectError is raised,
26742694
# treat it like a tool failure (i.e., gracefully handle it).
26752695
result: ContentToolResult | None = None
@@ -2717,6 +2737,17 @@ async def _invoke_tool_async(self, request: ContentToolRequest):
27172737
)
27182738
return
27192739

2740+
if isinstance(tool, ToolBuiltIn):
2741+
# Built-in tools are handled by the provider, not invoked directly
2742+
yield self._handle_tool_error_result(
2743+
request,
2744+
error=RuntimeError(
2745+
f"Built-in tool '{request.name}' cannot be invoked directly. "
2746+
"It should be handled by the provider."
2747+
),
2748+
)
2749+
return
2750+
27202751
if tool._is_async:
27212752
func = tool.func
27222753
else:

chatlas/_content.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from ._typing_extensions import TypedDict
1212

1313
if TYPE_CHECKING:
14-
from ._tools import Tool
14+
from ._tools import Tool, ToolBuiltIn
1515

1616

1717
class ToolAnnotations(TypedDict, total=False):
@@ -104,15 +104,28 @@ class ToolInfo(BaseModel):
104104
annotations: Optional[ToolAnnotations] = None
105105

106106
@classmethod
107-
def from_tool(cls, tool: "Tool") -> "ToolInfo":
108-
"""Create a ToolInfo from a Tool instance."""
109-
func_schema = tool.schema["function"]
110-
return cls(
111-
name=tool.name,
112-
description=func_schema.get("description", ""),
113-
parameters=func_schema.get("parameters", {}),
114-
annotations=tool.annotations,
115-
)
107+
def from_tool(cls, tool: "Tool | ToolBuiltIn") -> "ToolInfo":
108+
"""Create a ToolInfo from a Tool or ToolBuiltIn instance."""
109+
from ._tools import ToolBuiltIn
110+
111+
if isinstance(tool, ToolBuiltIn):
112+
# For built-in tools, extract info from the definition
113+
defn = tool.definition
114+
return cls(
115+
name=tool.name,
116+
description=defn.get("description", ""),
117+
parameters=defn.get("parameters", {}),
118+
annotations=None,
119+
)
120+
else:
121+
# For regular tools, extract from schema
122+
func_schema = tool.schema["function"]
123+
return cls(
124+
name=tool.name,
125+
description=func_schema.get("description", ""),
126+
parameters=func_schema.get("parameters", {}),
127+
annotations=tool.annotations,
128+
)
116129

117130

118131
ContentTypeEnum = Literal[
@@ -247,6 +260,22 @@ def __str__(self):
247260
def _repr_markdown_(self):
248261
return self.__str__()
249262

263+
def _repr_png_(self):
264+
"""Display PNG images directly in Jupyter notebooks."""
265+
if self.image_content_type == "image/png" and self.data:
266+
import base64
267+
268+
return base64.b64decode(self.data)
269+
return None
270+
271+
def _repr_jpeg_(self):
272+
"""Display JPEG images directly in Jupyter notebooks."""
273+
if self.image_content_type == "image/jpeg" and self.data:
274+
import base64
275+
276+
return base64.b64decode(self.data)
277+
return None
278+
250279
def __repr__(self, indent: int = 0):
251280
n_bytes = len(self.data) if self.data else 0
252281
return (

chatlas/_provider.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from pydantic import BaseModel
1717

1818
from ._content import Content
19-
from ._tools import Tool
19+
from ._tools import Tool, ToolBuiltIn
2020
from ._turn import Turn
2121
from ._typing_extensions import NotRequired, TypedDict
2222

@@ -162,7 +162,7 @@ def chat_perform(
162162
*,
163163
stream: Literal[False],
164164
turns: list[Turn],
165-
tools: dict[str, Tool],
165+
tools: dict[str, Tool | ToolBuiltIn],
166166
data_model: Optional[type[BaseModel]],
167167
kwargs: SubmitInputArgsT,
168168
) -> ChatCompletionT: ...
@@ -174,7 +174,7 @@ def chat_perform(
174174
*,
175175
stream: Literal[True],
176176
turns: list[Turn],
177-
tools: dict[str, Tool],
177+
tools: dict[str, Tool | ToolBuiltIn],
178178
data_model: Optional[type[BaseModel]],
179179
kwargs: SubmitInputArgsT,
180180
) -> Iterable[ChatCompletionChunkT]: ...
@@ -185,7 +185,7 @@ def chat_perform(
185185
*,
186186
stream: bool,
187187
turns: list[Turn],
188-
tools: dict[str, Tool],
188+
tools: dict[str, Tool | ToolBuiltIn],
189189
data_model: Optional[type[BaseModel]],
190190
kwargs: SubmitInputArgsT,
191191
) -> Iterable[ChatCompletionChunkT] | ChatCompletionT: ...
@@ -197,7 +197,7 @@ async def chat_perform_async(
197197
*,
198198
stream: Literal[False],
199199
turns: list[Turn],
200-
tools: dict[str, Tool],
200+
tools: dict[str, Tool | ToolBuiltIn],
201201
data_model: Optional[type[BaseModel]],
202202
kwargs: SubmitInputArgsT,
203203
) -> ChatCompletionT: ...
@@ -209,7 +209,7 @@ async def chat_perform_async(
209209
*,
210210
stream: Literal[True],
211211
turns: list[Turn],
212-
tools: dict[str, Tool],
212+
tools: dict[str, Tool | ToolBuiltIn],
213213
data_model: Optional[type[BaseModel]],
214214
kwargs: SubmitInputArgsT,
215215
) -> AsyncIterable[ChatCompletionChunkT]: ...
@@ -220,7 +220,7 @@ async def chat_perform_async(
220220
*,
221221
stream: bool,
222222
turns: list[Turn],
223-
tools: dict[str, Tool],
223+
tools: dict[str, Tool | ToolBuiltIn],
224224
data_model: Optional[type[BaseModel]],
225225
kwargs: SubmitInputArgsT,
226226
) -> AsyncIterable[ChatCompletionChunkT] | ChatCompletionT: ...
@@ -259,15 +259,15 @@ def value_tokens(
259259
def token_count(
260260
self,
261261
*args: Content | str,
262-
tools: dict[str, Tool],
262+
tools: dict[str, Tool | ToolBuiltIn],
263263
data_model: Optional[type[BaseModel]],
264264
) -> int: ...
265265

266266
@abstractmethod
267267
async def token_count_async(
268268
self,
269269
*args: Content | str,
270-
tools: dict[str, Tool],
270+
tools: dict[str, Tool | ToolBuiltIn],
271271
data_model: Optional[type[BaseModel]],
272272
) -> int: ...
273273

chatlas/_provider_google.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -309,17 +309,25 @@ def _chat_perform_args(
309309
config.response_mime_type = "application/json"
310310

311311
if tools:
312-
config.tools = [
313-
GoogleTool(
314-
function_declarations=[
312+
from ._tools import ToolBuiltIn
313+
314+
function_declarations = []
315+
for tool in tools.values():
316+
if isinstance(tool, ToolBuiltIn):
317+
# For built-in tools, pass the raw definition through
318+
# This allows provider-specific tools like image generation
319+
# Note: Google's API expects these in a specific format
320+
continue # Built-in tools are not yet fully supported for Google
321+
else:
322+
function_declarations.append(
315323
FunctionDeclaration.from_callable(
316324
client=self._client._api_client,
317325
callable=tool.func,
318326
)
319-
for tool in tools.values()
320-
]
321-
)
322-
]
327+
)
328+
329+
if function_declarations:
330+
config.tools = [GoogleTool(function_declarations=function_declarations)]
323331

324332
kwargs_full["config"] = config
325333

@@ -552,6 +560,20 @@ def _as_turn(
552560
),
553561
)
554562
)
563+
inline_data = part.get("inlineData") or part.get("inline_data")
564+
if inline_data:
565+
# Handle image generation responses
566+
mime_type = inline_data.get("mimeType") or inline_data.get("mime_type")
567+
data = inline_data.get("data")
568+
if mime_type and data:
569+
# Ensure data is a string (should be base64 encoded)
570+
data_str = data if isinstance(data, str) else str(data)
571+
contents.append(
572+
ContentImageInline(
573+
image_content_type=mime_type, # type: ignore
574+
data=data_str,
575+
)
576+
)
555577

556578
if isinstance(finish_reason, FinishReason):
557579
finish_reason = finish_reason.name

0 commit comments

Comments
 (0)