7
7
8
8
from agents import Agent , FunctionTool , ModelBehaviorError , RunContextWrapper , function_tool
9
9
from agents .tool import default_tool_error_function
10
+ from agents .tool_context import ToolContext
10
11
11
12
12
13
def argless_function () -> str :
@@ -18,11 +19,11 @@ async def test_argless_function():
18
19
tool = function_tool (argless_function )
19
20
assert tool .name == "argless_function"
20
21
21
- result = await tool .on_invoke_tool (RunContextWrapper ( None ), "" )
22
+ result = await tool .on_invoke_tool (ToolContext ( context = None , tool_call_id = "1" ), "" )
22
23
assert result == "ok"
23
24
24
25
25
- def argless_with_context (ctx : RunContextWrapper [str ]) -> str :
26
+ def argless_with_context (ctx : ToolContext [str ]) -> str :
26
27
return "ok"
27
28
28
29
@@ -31,11 +32,11 @@ async def test_argless_with_context():
31
32
tool = function_tool (argless_with_context )
32
33
assert tool .name == "argless_with_context"
33
34
34
- result = await tool .on_invoke_tool (RunContextWrapper (None ), "" )
35
+ result = await tool .on_invoke_tool (ToolContext (None , tool_call_id = "1" ), "" )
35
36
assert result == "ok"
36
37
37
38
# Extra JSON should not raise an error
38
- result = await tool .on_invoke_tool (RunContextWrapper (None ), '{"a": 1}' )
39
+ result = await tool .on_invoke_tool (ToolContext (None , tool_call_id = "1" ), '{"a": 1}' )
39
40
assert result == "ok"
40
41
41
42
@@ -48,15 +49,15 @@ async def test_simple_function():
48
49
tool = function_tool (simple_function , failure_error_function = None )
49
50
assert tool .name == "simple_function"
50
51
51
- result = await tool .on_invoke_tool (RunContextWrapper (None ), '{"a": 1}' )
52
+ result = await tool .on_invoke_tool (ToolContext (None , tool_call_id = "1" ), '{"a": 1}' )
52
53
assert result == 6
53
54
54
- result = await tool .on_invoke_tool (RunContextWrapper (None ), '{"a": 1, "b": 2}' )
55
+ result = await tool .on_invoke_tool (ToolContext (None , tool_call_id = "1" ), '{"a": 1, "b": 2}' )
55
56
assert result == 3
56
57
57
58
# Missing required argument should raise an error
58
59
with pytest .raises (ModelBehaviorError ):
59
- await tool .on_invoke_tool (RunContextWrapper (None ), "" )
60
+ await tool .on_invoke_tool (ToolContext (None , tool_call_id = "1" ), "" )
60
61
61
62
62
63
class Foo (BaseModel ):
@@ -84,7 +85,7 @@ async def test_complex_args_function():
84
85
"bar" : Bar (x = "hello" , y = 10 ),
85
86
}
86
87
)
87
- result = await tool .on_invoke_tool (RunContextWrapper (None ), valid_json )
88
+ result = await tool .on_invoke_tool (ToolContext (None , tool_call_id = "1" ), valid_json )
88
89
assert result == "6 hello10 hello"
89
90
90
91
valid_json = json .dumps (
@@ -93,7 +94,7 @@ async def test_complex_args_function():
93
94
"bar" : Bar (x = "hello" , y = 10 ),
94
95
}
95
96
)
96
- result = await tool .on_invoke_tool (RunContextWrapper (None ), valid_json )
97
+ result = await tool .on_invoke_tool (ToolContext (None , tool_call_id = "1" ), valid_json )
97
98
assert result == "3 hello10 hello"
98
99
99
100
valid_json = json .dumps (
@@ -103,12 +104,12 @@ async def test_complex_args_function():
103
104
"baz" : "world" ,
104
105
}
105
106
)
106
- result = await tool .on_invoke_tool (RunContextWrapper (None ), valid_json )
107
+ result = await tool .on_invoke_tool (ToolContext (None , tool_call_id = "1" ), valid_json )
107
108
assert result == "3 hello10 world"
108
109
109
110
# Missing required argument should raise an error
110
111
with pytest .raises (ModelBehaviorError ):
111
- await tool .on_invoke_tool (RunContextWrapper (None ), '{"foo": {"a": 1}}' )
112
+ await tool .on_invoke_tool (ToolContext (None , tool_call_id = "1" ), '{"foo": {"a": 1}}' )
112
113
113
114
114
115
def test_function_config_overrides ():
@@ -168,7 +169,7 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
168
169
assert tool .params_json_schema [key ] == value
169
170
assert tool .strict_json_schema
170
171
171
- result = await tool .on_invoke_tool (RunContextWrapper (None ), '{"data": "hello"}' )
172
+ result = await tool .on_invoke_tool (ToolContext (None , tool_call_id = "1" ), '{"data": "hello"}' )
172
173
assert result == "hello_done"
173
174
174
175
tool_not_strict = FunctionTool (
@@ -183,7 +184,7 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
183
184
assert "additionalProperties" not in tool_not_strict .params_json_schema
184
185
185
186
result = await tool_not_strict .on_invoke_tool (
186
- RunContextWrapper (None ), '{"data": "hello", "bar": "baz"}'
187
+ ToolContext (None , tool_call_id = "1" ), '{"data": "hello", "bar": "baz"}'
187
188
)
188
189
assert result == "hello_done"
189
190
@@ -194,7 +195,7 @@ def my_func(a: int, b: int = 5):
194
195
raise ValueError ("test" )
195
196
196
197
tool = function_tool (my_func )
197
- ctx = RunContextWrapper (None )
198
+ ctx = ToolContext (None , tool_call_id = "1" )
198
199
199
200
result = await tool .on_invoke_tool (ctx , "" )
200
201
assert "Invalid JSON" in str (result )
@@ -218,7 +219,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
218
219
return f"error_{ error .__class__ .__name__ } "
219
220
220
221
tool = function_tool (my_func , failure_error_function = custom_sync_error_function )
221
- ctx = RunContextWrapper (None )
222
+ ctx = ToolContext (None , tool_call_id = "1" )
222
223
223
224
result = await tool .on_invoke_tool (ctx , "" )
224
225
assert result == "error_ModelBehaviorError"
@@ -242,7 +243,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
242
243
return f"error_{ error .__class__ .__name__ } "
243
244
244
245
tool = function_tool (my_func , failure_error_function = custom_sync_error_function )
245
- ctx = RunContextWrapper (None )
246
+ ctx = ToolContext (None , tool_call_id = "1" )
246
247
247
248
result = await tool .on_invoke_tool (ctx , "" )
248
249
assert result == "error_ModelBehaviorError"
0 commit comments