From 6afe843ba42eee9ba6c7dad2b604588347f9eee5 Mon Sep 17 00:00:00 2001 From: Carson Date: Wed, 5 Nov 2025 18:46:06 -0600 Subject: [PATCH] Add basic image generation support; introduce new ToolBuiltIn class --- chatlas/__init__.py | 3 +- chatlas/_chat.py | 59 +++++++++++++++++++------ chatlas/_content.py | 49 +++++++++++++++----- chatlas/_provider.py | 18 ++++---- chatlas/_provider_google.py | 36 ++++++++++++--- chatlas/_provider_openai.py | 45 ++++++++++++++++--- chatlas/_provider_openai_completions.py | 11 ++++- chatlas/_tools.py | 22 +++++++++ 8 files changed, 194 insertions(+), 49 deletions(-) diff --git a/chatlas/__init__.py b/chatlas/__init__.py index a3fa7d12..b2c482b1 100644 --- a/chatlas/__init__.py +++ b/chatlas/__init__.py @@ -36,7 +36,7 @@ from ._provider_portkey import ChatPortkey from ._provider_snowflake import ChatSnowflake from ._tokens import token_usage -from ._tools import Tool, ToolRejectError +from ._tools import Tool, ToolBuiltIn, ToolRejectError from ._turn import Turn try: @@ -88,6 +88,7 @@ "Provider", "token_usage", "Tool", + "ToolBuiltIn", "ToolRejectError", "Turn", "types", diff --git a/chatlas/_chat.py b/chatlas/_chat.py index f189572c..013b9ed6 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -48,7 +48,7 @@ from ._mcp_manager import MCPSessionManager from ._provider import ModelInfo, Provider, StandardModelParams, SubmitInputArgsT from ._tokens import compute_cost, get_token_pricing, tokens_log -from ._tools import Tool, ToolRejectError +from ._tools import Tool, ToolBuiltIn, ToolRejectError from ._turn import Turn, user_turn from ._typing_extensions import TypedDict, TypeGuard from ._utils import MISSING, MISSING_TYPE, html_escape, wrap_async @@ -131,7 +131,7 @@ def __init__( self.system_prompt = system_prompt self.kwargs_chat: SubmitInputArgsT = kwargs_chat or {} - self._tools: dict[str, Tool] = {} + self._tools: dict[str, Tool | ToolBuiltIn] = {} self._on_tool_request_callbacks = CallbackManager() self._on_tool_result_callbacks = CallbackManager() self._current_display: Optional[MarkdownDisplay] = None @@ -1866,7 +1866,7 @@ async def cleanup_mcp_tools(self, names: Optional[Sequence[str]] = None): def register_tool( self, - func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool, + func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool | "ToolBuiltIn", *, force: bool = False, name: Optional[str] = None, @@ -1960,7 +1960,11 @@ def add(a: int, b: int) -> int: ValueError If a tool with the same name already exists and `force` is `False`. """ - if isinstance(func, Tool): + if isinstance(func, ToolBuiltIn): + # ToolBuiltIn objects are stored directly without conversion + tool = func + tool_name = tool.name + elif isinstance(func, Tool): name = name or func.name annotations = annotations or func.annotations if model is not None: @@ -1968,23 +1972,27 @@ def add(a: int, b: int) -> int: func.func, name=name, model=model, annotations=annotations ) func = func.func + tool = Tool.from_func(func, name=name, model=model, annotations=annotations) + tool_name = tool.name + else: + tool = Tool.from_func(func, name=name, model=model, annotations=annotations) + tool_name = tool.name - tool = Tool.from_func(func, name=name, model=model, annotations=annotations) - if tool.name in self._tools and not force: + if tool_name in self._tools and not force: raise ValueError( - f"Tool with name '{tool.name}' is already registered. " + f"Tool with name '{tool_name}' is already registered. " "Set `force=True` to overwrite it." ) - self._tools[tool.name] = tool + self._tools[tool_name] = tool - def get_tools(self) -> list[Tool]: + def get_tools(self) -> list[Tool | ToolBuiltIn]: """ Get the list of registered tools. Returns ------- - list[Tool] - A list of `Tool` instances that are currently registered with the chat. + list[Tool | ToolBuiltIn] + A list of `Tool` or `ToolBuiltIn` instances that are currently registered with the chat. """ return list(self._tools.values()) @@ -2508,7 +2516,7 @@ def _submit_turns( data_model: type[BaseModel] | None = None, kwargs: Optional[SubmitInputArgsT] = None, ) -> Generator[str, None, None]: - if any(x._is_async for x in self._tools.values()): + if any(hasattr(x, "_is_async") and x._is_async for x in self._tools.values()): raise ValueError("Cannot use async tools in a synchronous chat") def emit(text: str | Content): @@ -2661,15 +2669,27 @@ def _collect_all_kwargs( def _invoke_tool(self, request: ContentToolRequest): tool = self._tools.get(request.name) - func = tool.func if tool is not None else None - if func is None: + if tool is None: yield self._handle_tool_error_result( request, error=RuntimeError("Unknown tool."), ) return + if isinstance(tool, ToolBuiltIn): + # Built-in tools are handled by the provider, not invoked directly + yield self._handle_tool_error_result( + request, + error=RuntimeError( + f"Built-in tool '{request.name}' cannot be invoked directly. " + "It should be handled by the provider." + ), + ) + return + + func = tool.func + # First, invoke the request callbacks. If a ToolRejectError is raised, # treat it like a tool failure (i.e., gracefully handle it). result: ContentToolResult | None = None @@ -2717,6 +2737,17 @@ async def _invoke_tool_async(self, request: ContentToolRequest): ) return + if isinstance(tool, ToolBuiltIn): + # Built-in tools are handled by the provider, not invoked directly + yield self._handle_tool_error_result( + request, + error=RuntimeError( + f"Built-in tool '{request.name}' cannot be invoked directly. " + "It should be handled by the provider." + ), + ) + return + if tool._is_async: func = tool.func else: diff --git a/chatlas/_content.py b/chatlas/_content.py index be17e775..82273296 100644 --- a/chatlas/_content.py +++ b/chatlas/_content.py @@ -11,7 +11,7 @@ from ._typing_extensions import TypedDict if TYPE_CHECKING: - from ._tools import Tool + from ._tools import Tool, ToolBuiltIn class ToolAnnotations(TypedDict, total=False): @@ -104,15 +104,28 @@ class ToolInfo(BaseModel): annotations: Optional[ToolAnnotations] = None @classmethod - def from_tool(cls, tool: "Tool") -> "ToolInfo": - """Create a ToolInfo from a Tool instance.""" - func_schema = tool.schema["function"] - return cls( - name=tool.name, - description=func_schema.get("description", ""), - parameters=func_schema.get("parameters", {}), - annotations=tool.annotations, - ) + def from_tool(cls, tool: "Tool | ToolBuiltIn") -> "ToolInfo": + """Create a ToolInfo from a Tool or ToolBuiltIn instance.""" + from ._tools import ToolBuiltIn + + if isinstance(tool, ToolBuiltIn): + # For built-in tools, extract info from the definition + defn = tool.definition + return cls( + name=tool.name, + description=defn.get("description", ""), + parameters=defn.get("parameters", {}), + annotations=None, + ) + else: + # For regular tools, extract from schema + func_schema = tool.schema["function"] + return cls( + name=tool.name, + description=func_schema.get("description", ""), + parameters=func_schema.get("parameters", {}), + annotations=tool.annotations, + ) ContentTypeEnum = Literal[ @@ -247,6 +260,22 @@ def __str__(self): def _repr_markdown_(self): return self.__str__() + def _repr_png_(self): + """Display PNG images directly in Jupyter notebooks.""" + if self.image_content_type == "image/png" and self.data: + import base64 + + return base64.b64decode(self.data) + return None + + def _repr_jpeg_(self): + """Display JPEG images directly in Jupyter notebooks.""" + if self.image_content_type == "image/jpeg" and self.data: + import base64 + + return base64.b64decode(self.data) + return None + def __repr__(self, indent: int = 0): n_bytes = len(self.data) if self.data else 0 return ( diff --git a/chatlas/_provider.py b/chatlas/_provider.py index 98347eb4..dfb7fa37 100644 --- a/chatlas/_provider.py +++ b/chatlas/_provider.py @@ -16,7 +16,7 @@ from pydantic import BaseModel from ._content import Content -from ._tools import Tool +from ._tools import Tool, ToolBuiltIn from ._turn import Turn from ._typing_extensions import NotRequired, TypedDict @@ -162,7 +162,7 @@ def chat_perform( *, stream: Literal[False], turns: list[Turn], - tools: dict[str, Tool], + tools: dict[str, Tool | ToolBuiltIn], data_model: Optional[type[BaseModel]], kwargs: SubmitInputArgsT, ) -> ChatCompletionT: ... @@ -174,7 +174,7 @@ def chat_perform( *, stream: Literal[True], turns: list[Turn], - tools: dict[str, Tool], + tools: dict[str, Tool | ToolBuiltIn], data_model: Optional[type[BaseModel]], kwargs: SubmitInputArgsT, ) -> Iterable[ChatCompletionChunkT]: ... @@ -185,7 +185,7 @@ def chat_perform( *, stream: bool, turns: list[Turn], - tools: dict[str, Tool], + tools: dict[str, Tool | ToolBuiltIn], data_model: Optional[type[BaseModel]], kwargs: SubmitInputArgsT, ) -> Iterable[ChatCompletionChunkT] | ChatCompletionT: ... @@ -197,7 +197,7 @@ async def chat_perform_async( *, stream: Literal[False], turns: list[Turn], - tools: dict[str, Tool], + tools: dict[str, Tool | ToolBuiltIn], data_model: Optional[type[BaseModel]], kwargs: SubmitInputArgsT, ) -> ChatCompletionT: ... @@ -209,7 +209,7 @@ async def chat_perform_async( *, stream: Literal[True], turns: list[Turn], - tools: dict[str, Tool], + tools: dict[str, Tool | ToolBuiltIn], data_model: Optional[type[BaseModel]], kwargs: SubmitInputArgsT, ) -> AsyncIterable[ChatCompletionChunkT]: ... @@ -220,7 +220,7 @@ async def chat_perform_async( *, stream: bool, turns: list[Turn], - tools: dict[str, Tool], + tools: dict[str, Tool | ToolBuiltIn], data_model: Optional[type[BaseModel]], kwargs: SubmitInputArgsT, ) -> AsyncIterable[ChatCompletionChunkT] | ChatCompletionT: ... @@ -259,7 +259,7 @@ def value_tokens( def token_count( self, *args: Content | str, - tools: dict[str, Tool], + tools: dict[str, Tool | ToolBuiltIn], data_model: Optional[type[BaseModel]], ) -> int: ... @@ -267,7 +267,7 @@ def token_count( async def token_count_async( self, *args: Content | str, - tools: dict[str, Tool], + tools: dict[str, Tool | ToolBuiltIn], data_model: Optional[type[BaseModel]], ) -> int: ... diff --git a/chatlas/_provider_google.py b/chatlas/_provider_google.py index eea999ee..08be744a 100644 --- a/chatlas/_provider_google.py +++ b/chatlas/_provider_google.py @@ -309,17 +309,25 @@ def _chat_perform_args( config.response_mime_type = "application/json" if tools: - config.tools = [ - GoogleTool( - function_declarations=[ + from ._tools import ToolBuiltIn + + function_declarations = [] + for tool in tools.values(): + if isinstance(tool, ToolBuiltIn): + # For built-in tools, pass the raw definition through + # This allows provider-specific tools like image generation + # Note: Google's API expects these in a specific format + continue # Built-in tools are not yet fully supported for Google + else: + function_declarations.append( FunctionDeclaration.from_callable( client=self._client._api_client, callable=tool.func, ) - for tool in tools.values() - ] - ) - ] + ) + + if function_declarations: + config.tools = [GoogleTool(function_declarations=function_declarations)] kwargs_full["config"] = config @@ -552,6 +560,20 @@ def _as_turn( ), ) ) + inline_data = part.get("inlineData") or part.get("inline_data") + if inline_data: + # Handle image generation responses + mime_type = inline_data.get("mimeType") or inline_data.get("mime_type") + data = inline_data.get("data") + if mime_type and data: + # Ensure data is a string (should be base64 encoded) + data_str = data if isinstance(data, str) else str(data) + contents.append( + ContentImageInline( + image_content_type=mime_type, # type: ignore + data=data_str, + ) + ) if isinstance(finish_reason, FinishReason): finish_reason = finish_reason.name diff --git a/chatlas/_provider_openai.py b/chatlas/_provider_openai.py index 255aa6bd..30b52ea1 100644 --- a/chatlas/_provider_openai.py +++ b/chatlas/_provider_openai.py @@ -205,11 +205,17 @@ def _chat_perform_args( **(kwargs or {}), } - tool_schemas = [tool.schema for tool in tools.values()] - if tool_schemas: - # Convert completion tool format to responses format - responses_tools: list["ToolParam"] = [] - for schema in tool_schemas: + from ._tools import ToolBuiltIn + + # Handle tools - both regular and built-in + responses_tools: list["ToolParam"] = [] + for tool in tools.values(): + if isinstance(tool, ToolBuiltIn): + # For built-in tools, pass the definition through directly + responses_tools.append(tool.definition) # type: ignore + else: + # Convert completion tool format to responses format + schema = tool.schema func = schema["function"] responses_tools.append( { @@ -220,8 +226,9 @@ def _chat_perform_args( "strict": func.get("strict", True), } ) - if responses_tools: - kwargs_full["tools"] = responses_tools + + if responses_tools: + kwargs_full["tools"] = responses_tools # Add structured data extraction if present if data_model is not None: @@ -328,6 +335,30 @@ def _response_as_turn(completion: Response, has_data_model: bool) -> Turn: extra=output.model_dump(), ) ) + + elif output.type == "image_generation_call": + # Handle image generation responses + # The output object should have 'output_format' and 'result' attributes + output_dict = output.model_dump() + output_format = output_dict.get("output_format", "png") + result = output_dict.get("result") + + if result: + # Map output format to MIME type + mime_type_map = { + "png": "image/png", + "jpeg": "image/jpeg", + "webp": "image/webp", + } + mime_type = mime_type_map.get(output_format, "image/png") + + contents.append( + ContentImageInline( + image_content_type=mime_type, # type: ignore + data=result, + ) + ) + else: raise ValueError(f"Unknown output type: {output.type}") diff --git a/chatlas/_provider_openai_completions.py b/chatlas/_provider_openai_completions.py index 99730a33..3e267a66 100644 --- a/chatlas/_provider_openai_completions.py +++ b/chatlas/_provider_openai_completions.py @@ -146,7 +146,16 @@ def _chat_perform_args( data_model: Optional[type[BaseModel]] = None, kwargs: Optional["SubmitInputArgs"] = None, ) -> "SubmitInputArgs": - tool_schemas = [tool.schema for tool in tools.values()] + from ._tools import ToolBuiltIn + + # Handle tools - both regular and built-in + tool_schemas = [] + for tool in tools.values(): + if isinstance(tool, ToolBuiltIn): + # For built-in tools, pass the definition through directly + tool_schemas.append(tool.definition) + else: + tool_schemas.append(tool.schema) kwargs_full: "SubmitInputArgs" = { "stream": stream, diff --git a/chatlas/_tools.py b/chatlas/_tools.py index b71ddf09..7c55b25a 100644 --- a/chatlas/_tools.py +++ b/chatlas/_tools.py @@ -25,6 +25,7 @@ __all__ = ( "Tool", + "ToolBuiltIn", "ToolRejectError", ) @@ -228,6 +229,27 @@ async def _call(**args: Any) -> AsyncGenerator[ContentToolResult, None]: ) +class ToolBuiltIn: + """ + Define a built-in provider-specific tool + + This class represents tools that are built into specific providers (like image + generation). Unlike regular Tool objects, ToolBuiltIn instances pass raw + provider-specific JSON directly through to the API. + + Parameters + ---------- + name + The name of the tool. + definition + The raw provider-specific tool definition as a dictionary. + """ + + def __init__(self, *, name: str, definition: dict[str, Any]): + self.name = name + self.definition = definition + + class ToolRejectError(Exception): """ Error to represent a tool call being rejected.