3
3
import asyncio
4
4
import inspect
5
5
from dataclasses import dataclass , field
6
- from typing import Any , Callable , Generic , cast
6
+ from typing import Any , Callable , Generic , cast , get_args
7
7
8
- from openai .types .responses import ResponseCompletedEvent
8
+ from openai .types .responses import (
9
+ ResponseCompletedEvent ,
10
+ ResponseOutputItemAddedEvent ,
11
+ )
9
12
from openai .types .responses .response_prompt_param import (
10
13
ResponsePromptParam ,
11
14
)
40
43
OutputGuardrailResult ,
41
44
)
42
45
from .handoffs import Handoff , HandoffInputFilter , handoff
43
- from .items import ItemHelpers , ModelResponse , RunItem , TResponseInputItem
46
+ from .items import (
47
+ ItemHelpers ,
48
+ ModelResponse ,
49
+ RunItem ,
50
+ ToolCallItem ,
51
+ ToolCallItemTypes ,
52
+ TResponseInputItem ,
53
+ )
44
54
from .lifecycle import RunHooks
45
55
from .logger import logger
46
56
from .memory import Session
49
59
from .models .multi_provider import MultiProvider
50
60
from .result import RunResult , RunResultStreaming
51
61
from .run_context import RunContextWrapper , TContext
52
- from .stream_events import AgentUpdatedStreamEvent , RawResponsesStreamEvent
62
+ from .stream_events import AgentUpdatedStreamEvent , RawResponsesStreamEvent , RunItemStreamEvent
53
63
from .tool import Tool
54
64
from .tracing import Span , SpanError , agent_span , get_current_trace , trace
55
65
from .tracing .span_data import AgentSpanData
@@ -905,6 +915,8 @@ async def _run_single_turn_streamed(
905
915
all_tools : list [Tool ],
906
916
previous_response_id : str | None ,
907
917
) -> SingleStepResult :
918
+ emitted_tool_call_ids : set [str ] = set ()
919
+
908
920
if should_run_agent_start_hooks :
909
921
await asyncio .gather (
910
922
hooks .on_agent_start (context_wrapper , agent ),
@@ -984,6 +996,25 @@ async def _run_single_turn_streamed(
984
996
)
985
997
context_wrapper .usage .add (usage )
986
998
999
+ if isinstance (event , ResponseOutputItemAddedEvent ):
1000
+ output_item = event .item
1001
+
1002
+ if isinstance (output_item , _TOOL_CALL_TYPES ):
1003
+ call_id : str | None = getattr (
1004
+ output_item , "call_id" , getattr (output_item , "id" , None )
1005
+ )
1006
+
1007
+ if call_id and call_id not in emitted_tool_call_ids :
1008
+ emitted_tool_call_ids .add (call_id )
1009
+
1010
+ tool_item = ToolCallItem (
1011
+ raw_item = cast (ToolCallItemTypes , output_item ),
1012
+ agent = agent ,
1013
+ )
1014
+ streamed_result ._event_queue .put_nowait (
1015
+ RunItemStreamEvent (item = tool_item , name = "tool_called" )
1016
+ )
1017
+
987
1018
streamed_result ._event_queue .put_nowait (RawResponsesStreamEvent (data = event ))
988
1019
989
1020
# Call hook just after the model response is finalized.
@@ -995,9 +1026,10 @@ async def _run_single_turn_streamed(
995
1026
raise ModelBehaviorError ("Model did not produce a final response!" )
996
1027
997
1028
# 3. Now, we can process the turn as we do in the non-streaming case
998
- return await cls ._get_single_step_result_from_streamed_response (
1029
+ single_step_result = await cls ._get_single_step_result_from_response (
999
1030
agent = agent ,
1000
- streamed_result = streamed_result ,
1031
+ original_input = streamed_result .input ,
1032
+ pre_step_items = streamed_result .new_items ,
1001
1033
new_response = final_response ,
1002
1034
output_schema = output_schema ,
1003
1035
all_tools = all_tools ,
@@ -1008,6 +1040,34 @@ async def _run_single_turn_streamed(
1008
1040
tool_use_tracker = tool_use_tracker ,
1009
1041
)
1010
1042
1043
+ if emitted_tool_call_ids :
1044
+ import dataclasses as _dc
1045
+
1046
+ filtered_items = [
1047
+ item
1048
+ for item in single_step_result .new_step_items
1049
+ if not (
1050
+ isinstance (item , ToolCallItem )
1051
+ and (
1052
+ call_id := getattr (
1053
+ item .raw_item , "call_id" , getattr (item .raw_item , "id" , None )
1054
+ )
1055
+ )
1056
+ and call_id in emitted_tool_call_ids
1057
+ )
1058
+ ]
1059
+
1060
+ single_step_result_filtered = _dc .replace (
1061
+ single_step_result , new_step_items = filtered_items
1062
+ )
1063
+
1064
+ RunImpl .stream_step_result_to_queue (
1065
+ single_step_result_filtered , streamed_result ._event_queue
1066
+ )
1067
+ else :
1068
+ RunImpl .stream_step_result_to_queue (single_step_result , streamed_result ._event_queue )
1069
+ return single_step_result
1070
+
1011
1071
@classmethod
1012
1072
async def _run_single_turn (
1013
1073
cls ,
@@ -1397,9 +1457,11 @@ async def _save_result_to_session(
1397
1457
1398
1458
1399
1459
DEFAULT_AGENT_RUNNER = AgentRunner ()
1460
+ _TOOL_CALL_TYPES : tuple [type , ...] = get_args (ToolCallItemTypes )
1400
1461
1401
1462
1402
1463
def _copy_str_or_list (input : str | list [TResponseInputItem ]) -> str | list [TResponseInputItem ]:
1403
1464
if isinstance (input , str ):
1404
1465
return input
1405
1466
return input .copy ()
1467
+
0 commit comments