Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/agents/_run_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ async def execute_function_tool_calls(
context_wrapper: RunContextWrapper[TContext],
config: RunConfig,
) -> list[FunctionToolResult]:
async def run_single_tool(
async def run_single_tool(
func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
) -> Any:
with function_span(func_tool.name) as span_fn:
Expand Down
26 changes: 16 additions & 10 deletions src/agents/tool_context.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from dataclasses import dataclass, field, fields
from typing import Any, Optional

from openai.types.responses import ResponseFunctionToolCall

from typing import Any, Optional
from .run_context import RunContextWrapper, TContext


def _assert_must_pass_tool_call_id() -> str:
raise ValueError("tool_call_id must be passed to ToolContext")
def _assert_must_pass_tool_name() -> str:
raise ValueError("Tool name must be passed")


def _assert_must_pass_tool_name() -> str:
raise ValueError("tool_name must be passed to ToolContext")
def _assert_must_pass_tool_call_id() -> str:
raise ValueError("Tool call ID must be passed")


@dataclass
Expand All @@ -24,6 +22,9 @@ class ToolContext(RunContextWrapper[TContext]):
tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id)
"""The ID of the tool call."""

arguments: Optional[str] = None
"""The raw JSON arguments string sent by the model for this tool call, if available."""

@classmethod
def from_agent_context(
cls,
Expand All @@ -34,9 +35,14 @@ def from_agent_context(
"""
Create a ToolContext from a RunContextWrapper.
"""
# Grab the names of the RunContextWrapper's init=True fields
base_values: dict[str, Any] = {
f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init
}
tool_name = tool_call.name if tool_call is not None else _assert_must_pass_tool_name()
return cls(tool_name=tool_name, tool_call_id=tool_call_id, **base_values)
tool_name = tool_call.function.name if tool_call is not None else _assert_must_pass_tool_name()
args = tool_call.function.arguments if tool_call is not None else None
return cls(
tool_name=tool_name,
tool_call_id=tool_call_id,
arguments=args,
**base_values,
)
77 changes: 77 additions & 0 deletions tests/test_tool_context_arg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import json
from dataclasses import fields
from types import SimpleNamespace
from typing import Optional

import pytest

from agents import function_tool
from agents.run_context import RunContextWrapper
from agents.tool_context import ToolContext


class FakeToolCall:
def __init__(self, name: str, arguments: Optional[str] = None):
self.name = name
self.arguments = arguments


def make_minimal_context_like_runcontext():
ctx = SimpleNamespace()
for f in fields(RunContextWrapper):
setattr(ctx, f.name, None)
return ctx


def test_from_agent_context_populates_arguments_and_names():
context_like = make_minimal_context_like_runcontext()
fake_call = FakeToolCall(name="my_tool", arguments='{"x": 1, "y": 2}')

tc: ToolContext = ToolContext.from_agent_context(
context_like, tool_call_id="c-1", tool_call=fake_call
)

assert tc.tool_name == "my_tool"
assert tc.tool_call_id == "c-1"
assert tc.arguments == '{"x": 1, "y": 2}'


def test_from_agent_context_raises_if_tool_name_missing():
context_like = make_minimal_context_like_runcontext()

with pytest.raises(ValueError, match="Tool name must"):
ToolContext.from_agent_context(context_like, tool_call_id="c-2", tool_call=None)


@pytest.mark.asyncio
async def test_function_tool_accepts_toolcontext_generic_argless():
def argless_with_context(ctx: ToolContext[str]) -> str:
return "ok"

tool = function_tool(argless_with_context)
assert tool.name == "argless_with_context"

ctx = ToolContext(context=None, tool_name="argless_with_context", tool_call_id="1")

result = await tool.on_invoke_tool(ctx, "")
assert result == "ok"

result = await tool.on_invoke_tool(ctx, '{"a": 1, "b": 2}')
assert result == "ok"


@pytest.mark.asyncio
async def test_function_tool_with_context_and_args_parsed():
class DummyCtx:
def __init__(self):
self.data = "xyz"

def with_ctx_and_name(ctx: ToolContext[DummyCtx], name: str) -> str:
return f"{name}_{ctx.context.data}"

tool = function_tool(with_ctx_and_name)
ctx = ToolContext(context=DummyCtx(), tool_name="with_ctx_and_name", tool_call_id="1")
payload = json.dumps({"name": "uzair"})
result = await tool.on_invoke_tool(ctx, payload)

assert result == "uzair_xyz"
Loading