Skip to content

Commit e78336f

Browse files
committed
fix: address issues around resuming run state with conversation history
1 parent def65ac commit e78336f

File tree

1 file changed

+72
-3
lines changed

1 file changed

+72
-3
lines changed

src/agents/run.py

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -715,12 +715,17 @@ async def run(
715715
# Check if we're resuming from a RunState
716716
is_resumed_state = isinstance(input, RunState)
717717
run_state: RunState[TContext] | None = None
718+
prepared_input: str | list[TResponseInputItem]
718719

719720
if is_resumed_state:
720721
# Resuming from a saved state
721722
run_state = cast(RunState[TContext], input)
722723
original_user_input = run_state._original_input
723-
prepared_input = run_state._original_input
724+
725+
if isinstance(run_state._original_input, list):
726+
prepared_input = self._merge_provider_data_in_items(run_state._original_input)
727+
else:
728+
prepared_input = run_state._original_input
724729

725730
# Override context with the state's context if not provided
726731
if context is None and run_state._context is not None:
@@ -790,6 +795,9 @@ async def run(
790795
# If resuming from an interrupted state, execute approved tools first
791796
if is_resumed_state and run_state is not None and run_state._current_step is not None:
792797
if isinstance(run_state._current_step, NextStepInterruption):
798+
# Track items before executing approved tools
799+
items_before_execution = len(generated_items)
800+
793801
# We're resuming from an interruption - execute approved tools
794802
await self._execute_approved_tools(
795803
agent=current_agent,
@@ -799,6 +807,16 @@ async def run(
799807
run_config=run_config,
800808
hooks=hooks,
801809
)
810+
811+
# Save the newly executed tool outputs to the session
812+
new_tool_outputs: list[RunItem] = [
813+
item
814+
for item in generated_items[items_before_execution:]
815+
if item.type == "tool_call_output_item"
816+
]
817+
if new_tool_outputs and session is not None:
818+
await self._save_result_to_session(session, [], new_tool_outputs)
819+
802820
# Clear the current step since we've handled it
803821
run_state._current_step = None
804822

@@ -1089,7 +1107,14 @@ def run_streamed(
10891107

10901108
if is_resumed_state:
10911109
run_state = cast(RunState[TContext], input)
1092-
input_for_result = run_state._original_input
1110+
1111+
if isinstance(run_state._original_input, list):
1112+
input_for_result = AgentRunner._merge_provider_data_in_items(
1113+
run_state._original_input
1114+
)
1115+
else:
1116+
input_for_result = run_state._original_input
1117+
10931118
# Use context from RunState if not provided
10941119
if context is None and run_state._context is not None:
10951120
context = run_state._context.context
@@ -1289,6 +1314,9 @@ async def _start_streaming(
12891314
# If resuming from an interrupted state, execute approved tools first
12901315
if run_state is not None and run_state._current_step is not None:
12911316
if isinstance(run_state._current_step, NextStepInterruption):
1317+
# Track items before executing approved tools
1318+
items_before_execution = len(streamed_result.new_items)
1319+
12921320
# We're resuming from an interruption - execute approved tools
12931321
await cls._execute_approved_tools_static(
12941322
agent=current_agent,
@@ -1298,6 +1326,16 @@ async def _start_streaming(
12981326
run_config=run_config,
12991327
hooks=hooks,
13001328
)
1329+
1330+
# Save the newly executed tool outputs to the session
1331+
new_tool_outputs: list[RunItem] = [
1332+
item
1333+
for item in streamed_result.new_items[items_before_execution:]
1334+
if item.type == "tool_call_output_item"
1335+
]
1336+
if new_tool_outputs and session is not None:
1337+
await cls._save_result_to_session(session, [], new_tool_outputs)
1338+
13011339
# Clear the current step since we've handled it
13021340
run_state._current_step = None
13031341

@@ -1568,6 +1606,8 @@ async def _run_single_turn_streamed(
15681606
input_item = item.to_input_item()
15691607
input.append(input_item)
15701608

1609+
input = cls._merge_provider_data_in_items(input)
1610+
15711611
# THIS IS THE RESOLVED CONFLICT BLOCK
15721612
filtered = await cls._maybe_filter_model_input(
15731613
agent=agent,
@@ -1907,6 +1947,8 @@ async def _run_single_turn(
19071947
input_item = generated_item.to_input_item()
19081948
input.append(input_item)
19091949

1950+
input = cls._merge_provider_data_in_items(input)
1951+
19101952
new_response = await cls._get_new_response(
19111953
agent,
19121954
system_prompt,
@@ -2241,6 +2283,30 @@ def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
22412283

22422284
return run_config.model_provider.get_model(agent.model)
22432285

2286+
@classmethod
2287+
def _merge_provider_data_in_items(
2288+
cls, items: list[TResponseInputItem]
2289+
) -> list[TResponseInputItem]:
2290+
"""Remove providerData fields from items."""
2291+
result = []
2292+
for item in items:
2293+
if isinstance(item, dict):
2294+
merged_item = dict(item)
2295+
# Pop both possible keys (providerData and provider_data)
2296+
provider_data = merged_item.pop("providerData", None)
2297+
if provider_data is None:
2298+
provider_data = merged_item.pop("provider_data", None)
2299+
# Merge contents if providerData exists and is a dict
2300+
if isinstance(provider_data, dict):
2301+
# Merge provider_data contents, with existing fields taking precedence
2302+
for key, value in provider_data.items():
2303+
if key not in merged_item:
2304+
merged_item[key] = value
2305+
result.append(cast(TResponseInputItem, merged_item))
2306+
else:
2307+
result.append(item)
2308+
return result
2309+
22442310
@classmethod
22452311
async def _prepare_input_with_session(
22462312
cls,
@@ -2264,6 +2330,7 @@ async def _prepare_input_with_session(
22642330

22652331
# Get previous conversation history
22662332
history = await session.get_items()
2333+
history = cls._merge_provider_data_in_items(history)
22672334

22682335
# Convert input to list format
22692336
new_input_list = ItemHelpers.input_to_new_input_list(input)
@@ -2273,7 +2340,9 @@ async def _prepare_input_with_session(
22732340
elif callable(session_input_callback):
22742341
res = session_input_callback(history, new_input_list)
22752342
if inspect.isawaitable(res):
2276-
return await res
2343+
res = await res
2344+
if isinstance(res, list):
2345+
res = cls._merge_provider_data_in_items(res)
22772346
return res
22782347
else:
22792348
raise UserError(

0 commit comments

Comments
 (0)