Skip to content

Commit 66df149

Browse files
committed
fix: addressing edge cases when resuming (continued)
1 parent adc910c commit 66df149

File tree

3 files changed

+227
-34
lines changed

3 files changed

+227
-34
lines changed

src/agents/run.py

Lines changed: 102 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,13 @@ def prepare_input(
161161

162162
# On first call (when there are no generated items yet), include the original input
163163
if not generated_items:
164-
input_items.extend(ItemHelpers.input_to_new_input_list(original_input))
164+
# Normalize original_input items to ensure field names are in snake_case
165+
# (items from RunState deserialization may have camelCase)
166+
raw_input_list = ItemHelpers.input_to_new_input_list(original_input)
167+
# Filter out function_call items that don't have corresponding function_call_output
168+
# (API requires every function_call to have a function_call_output)
169+
filtered_input_list = AgentRunner._filter_incomplete_function_calls(raw_input_list)
170+
input_items.extend(AgentRunner._normalize_input_items(filtered_input_list))
165171

166172
# First, collect call_ids from tool_call_output_item items
167173
# (completed tool calls with outputs) and build a map of
@@ -724,8 +730,8 @@ async def run(
724730
original_user_input = run_state._original_input
725731
# Normalize items to remove top-level providerData (API doesn't accept it there)
726732
if isinstance(original_user_input, list):
727-
prepared_input: str | list[TResponseInputItem] = (
728-
AgentRunner._normalize_input_items(original_user_input)
733+
prepared_input: str | list[TResponseInputItem] = AgentRunner._normalize_input_items(
734+
original_user_input
729735
)
730736
else:
731737
prepared_input = original_user_input
@@ -820,8 +826,7 @@ async def run(
820826
if session is not None and generated_items:
821827
# Save tool_call_output_item items (the outputs)
822828
tool_output_items: list[RunItem] = [
823-
item for item in generated_items
824-
if item.type == "tool_call_output_item"
829+
item for item in generated_items if item.type == "tool_call_output_item"
825830
]
826831
# Also find and save the corresponding function_call items
827832
# (they might not be in session if the run was interrupted before saving)
@@ -982,7 +987,7 @@ async def run(
982987
)
983988
if call_id in output_call_ids and item not in items_to_save:
984989
items_to_save.append(item)
985-
990+
986991
# Don't save original_user_input again - it was already saved at the start
987992
await self._save_result_to_session(session, [], items_to_save)
988993

@@ -1356,9 +1361,12 @@ async def _start_streaming(
13561361
# state's input, causing duplicate items.
13571362
if run_state is not None:
13581363
# Resuming from state - normalize items to remove top-level providerData
1364+
# and filter incomplete function_call pairs
13591365
if isinstance(starting_input, list):
1366+
# Filter incomplete function_call pairs before normalizing
1367+
filtered = AgentRunner._filter_incomplete_function_calls(starting_input)
13601368
prepared_input: str | list[TResponseInputItem] = (
1361-
AgentRunner._normalize_input_items(starting_input)
1369+
AgentRunner._normalize_input_items(filtered)
13621370
)
13631371
else:
13641372
prepared_input = starting_input
@@ -2332,20 +2340,82 @@ def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
23322340

23332341
return run_config.model_provider.get_model(agent.model)
23342342

2343+
@staticmethod
2344+
def _filter_incomplete_function_calls(
2345+
items: list[TResponseInputItem],
2346+
) -> list[TResponseInputItem]:
2347+
"""Filter out function_call items that don't have corresponding function_call_output.
2348+
2349+
The OpenAI API requires every function_call in an assistant message to have a
2350+
corresponding function_call_output (tool message). This function ensures only
2351+
complete pairs are included to prevent API errors.
2352+
2353+
IMPORTANT: This only filters incomplete function_call items. All other items
2354+
(messages, complete function_call pairs, etc.) are preserved to maintain
2355+
conversation history integrity.
2356+
2357+
Args:
2358+
items: List of input items to filter
2359+
2360+
Returns:
2361+
Filtered list with only complete function_call pairs. All non-function_call
2362+
items and complete function_call pairs are preserved.
2363+
"""
2364+
# First pass: collect call_ids from function_call_output/function_call_result items
2365+
completed_call_ids: set[str] = set()
2366+
for item in items:
2367+
if isinstance(item, dict):
2368+
item_type = item.get("type")
2369+
# Handle both API format (function_call_output) and
2370+
# protocol format (function_call_result)
2371+
if item_type in ("function_call_output", "function_call_result"):
2372+
call_id = item.get("call_id") or item.get("callId")
2373+
if call_id and isinstance(call_id, str):
2374+
completed_call_ids.add(call_id)
2375+
2376+
# Second pass: only include function_call items that have corresponding outputs
2377+
filtered: list[TResponseInputItem] = []
2378+
for item in items:
2379+
if isinstance(item, dict):
2380+
item_type = item.get("type")
2381+
if item_type == "function_call":
2382+
call_id = item.get("call_id") or item.get("callId")
2383+
# Only include if there's a corresponding
2384+
# function_call_output/function_call_result
2385+
if call_id and call_id in completed_call_ids:
2386+
filtered.append(item)
2387+
else:
2388+
# Include all non-function_call items
2389+
filtered.append(item)
2390+
else:
2391+
# Include non-dict items as-is
2392+
filtered.append(item)
2393+
2394+
return filtered
2395+
23352396
@staticmethod
23362397
def _normalize_input_items(items: list[TResponseInputItem]) -> list[TResponseInputItem]:
2337-
"""Normalize input items by removing top-level providerData/provider_data.
2338-
2398+
"""Normalize input items by removing top-level providerData/provider_data
2399+
and normalizing field names (callId -> call_id).
2400+
23392401
The OpenAI API doesn't accept providerData at the top level of input items.
23402402
providerData should only be in content where it belongs. This function removes
23412403
top-level providerData while preserving it in content.
2342-
2404+
2405+
Also normalizes field names from camelCase (callId) to snake_case (call_id)
2406+
to match API expectations.
2407+
2408+
Normalizes item types: converts 'function_call_result' to 'function_call_output'
2409+
to match API expectations.
2410+
23432411
Args:
23442412
items: List of input items to normalize
2345-
2413+
23462414
Returns:
23472415
Normalized list of input items
23482416
"""
2417+
from .run_state import _normalize_field_names
2418+
23492419
normalized: list[TResponseInputItem] = []
23502420
for item in items:
23512421
if isinstance(item, dict):
@@ -2355,6 +2425,18 @@ def _normalize_input_items(items: list[TResponseInputItem]) -> list[TResponseInp
23552425
# The API doesn't accept providerData at the top level of input items
23562426
normalized_item.pop("providerData", None)
23572427
normalized_item.pop("provider_data", None)
2428+
# Normalize item type: API expects 'function_call_output',
2429+
# not 'function_call_result'
2430+
item_type = normalized_item.get("type")
2431+
if item_type == "function_call_result":
2432+
normalized_item["type"] = "function_call_output"
2433+
item_type = "function_call_output"
2434+
# Remove invalid fields based on item type
2435+
# function_call_output items should not have 'name' field
2436+
if item_type == "function_call_output":
2437+
normalized_item.pop("name", None)
2438+
# Normalize field names (callId -> call_id, responseId -> response_id)
2439+
normalized_item = _normalize_field_names(normalized_item)
23582440
normalized.append(cast(TResponseInputItem, normalized_item))
23592441
else:
23602442
# For non-dict items, keep as-is (they should already be in correct format)
@@ -2401,10 +2483,14 @@ async def _prepare_input_with_session(
24012483
f"Invalid `session_input_callback` value: {session_input_callback}. "
24022484
"Choose between `None` or a custom callable function."
24032485
)
2404-
2486+
2487+
# Filter incomplete function_call pairs before normalizing
2488+
# (API requires every function_call to have a function_call_output)
2489+
filtered = cls._filter_incomplete_function_calls(merged)
2490+
24052491
# Normalize items to remove top-level providerData and deduplicate by ID
2406-
normalized = cls._normalize_input_items(merged)
2407-
2492+
normalized = cls._normalize_input_items(filtered)
2493+
24082494
# Deduplicate items by ID to prevent sending duplicate items to the API
24092495
# This can happen when resuming from state and items are already in the session
24102496
seen_ids: set[str] = set()
@@ -2416,13 +2502,13 @@ async def _prepare_input_with_session(
24162502
item_id = cast(str | None, item.get("id"))
24172503
elif hasattr(item, "id"):
24182504
item_id = cast(str | None, getattr(item, "id", None))
2419-
2505+
24202506
# Only add items we haven't seen before (or items without IDs)
24212507
if item_id is None or item_id not in seen_ids:
24222508
deduplicated.append(item)
24232509
if item_id:
24242510
seen_ids.add(item_id)
2425-
2511+
24262512
return deduplicated
24272513

24282514
@classmethod

src/agents/run_state.py

Lines changed: 81 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,6 @@ class RunState(Generic[TContext, TAgent]):
4848
_current_turn: int = 0
4949
"""Current turn number in the conversation."""
5050

51-
_current_turn_persisted_item_count: int = 0
52-
"""Tracks how many generated run items from this turn were already persisted to session.
53-
54-
When saving to session, we slice off only new entries. When a turn is interrupted
55-
(e.g., awaiting tool approval) and later resumed, we rewind this counter before
56-
continuing so pending tool outputs still get stored.
57-
"""
58-
5951
_current_agent: TAgent | None = None
6052
"""The agent currently handling the conversation."""
6153

@@ -250,13 +242,63 @@ def to_json(self) -> dict[str, Any]:
250242
}
251243
model_responses.append(response_dict)
252244

245+
# Normalize and camelize originalInput if it's a list of items
246+
# Convert API format to protocol format to match TypeScript schema
247+
# Protocol expects function_call_result (not function_call_output)
248+
original_input_serialized = self._original_input
249+
if isinstance(original_input_serialized, list):
250+
# First pass: build a map of call_id -> function_call name
251+
# to help convert function_call_output to function_call_result
252+
call_id_to_name: dict[str, str] = {}
253+
for item in original_input_serialized:
254+
if isinstance(item, dict):
255+
item_type = item.get("type")
256+
call_id = item.get("call_id") or item.get("callId")
257+
name = item.get("name")
258+
if item_type == "function_call" and call_id and name:
259+
call_id_to_name[call_id] = name
260+
261+
normalized_items = []
262+
for item in original_input_serialized:
263+
if isinstance(item, dict):
264+
# Create a copy to avoid modifying the original
265+
normalized_item = dict(item)
266+
# Remove session/conversation metadata fields that shouldn't be in originalInput
267+
# These are not part of the input protocol schema
268+
normalized_item.pop("id", None)
269+
normalized_item.pop("created_at", None)
270+
# Remove top-level providerData/provider_data (protocol allows it but
271+
# we remove it for cleaner serialization)
272+
normalized_item.pop("providerData", None)
273+
normalized_item.pop("provider_data", None)
274+
# Convert API format to protocol format
275+
# API uses function_call_output, protocol uses function_call_result
276+
item_type = normalized_item.get("type")
277+
call_id = normalized_item.get("call_id") or normalized_item.get("callId")
278+
if item_type == "function_call_output":
279+
# Convert to protocol format: function_call_result
280+
normalized_item["type"] = "function_call_result"
281+
# Protocol format requires status field (default to 'completed')
282+
if "status" not in normalized_item:
283+
normalized_item["status"] = "completed"
284+
# Protocol format requires name field
285+
# Look it up from the corresponding function_call if missing
286+
if "name" not in normalized_item and call_id:
287+
normalized_item["name"] = call_id_to_name.get(call_id, "")
288+
# Normalize field names to camelCase for JSON (call_id -> callId)
289+
normalized_item = self._camelize_field_names(normalized_item)
290+
normalized_items.append(normalized_item)
291+
else:
292+
normalized_items.append(item)
293+
original_input_serialized = normalized_items
294+
253295
result = {
254296
"$schemaVersion": CURRENT_SCHEMA_VERSION,
255297
"currentTurn": self._current_turn,
256298
"currentAgent": {
257299
"name": self._current_agent.name,
258300
},
259-
"originalInput": self._original_input,
301+
"originalInput": original_input_serialized,
260302
"modelResponses": model_responses,
261303
"context": {
262304
"usage": {
@@ -345,7 +387,6 @@ def to_json(self) -> dict[str, Any]:
345387
if self._last_processed_response
346388
else None
347389
)
348-
result["currentTurnPersistedItemCount"] = self._current_turn_persisted_item_count
349390
result["trace"] = None
350391

351392
return result
@@ -571,18 +612,29 @@ async def from_string(
571612
context.usage = usage
572613
context._rebuild_approvals(context_data.get("approvals", {}))
573614

615+
# Normalize originalInput to remove providerData fields that may have been
616+
# included by TypeScript serialization. These fields are metadata and should
617+
# not be sent to the API.
618+
original_input_raw = state_json["originalInput"]
619+
if isinstance(original_input_raw, list):
620+
# Normalize each item in the list to remove providerData fields
621+
normalized_original_input = [
622+
_normalize_field_names(item) if isinstance(item, dict) else item
623+
for item in original_input_raw
624+
]
625+
else:
626+
# If it's a string, use it as-is
627+
normalized_original_input = original_input_raw
628+
574629
# Create the RunState instance
575630
state = RunState(
576631
context=context,
577-
original_input=state_json["originalInput"],
632+
original_input=normalized_original_input,
578633
starting_agent=current_agent,
579634
max_turns=state_json["maxTurns"],
580635
)
581636

582637
state._current_turn = state_json["currentTurn"]
583-
state._current_turn_persisted_item_count = state_json.get(
584-
"currentTurnPersistedItemCount", 0
585-
)
586638

587639
# Reconstruct model responses
588640
state._model_responses = _deserialize_model_responses(state_json.get("modelResponses", []))
@@ -679,18 +731,29 @@ async def from_json(
679731
context.usage = usage
680732
context._rebuild_approvals(context_data.get("approvals", {}))
681733

734+
# Normalize originalInput to remove providerData fields that may have been
735+
# included by TypeScript serialization. These fields are metadata and should
736+
# not be sent to the API.
737+
original_input_raw = state_json["originalInput"]
738+
if isinstance(original_input_raw, list):
739+
# Normalize each item in the list to remove providerData fields
740+
normalized_original_input = [
741+
_normalize_field_names(item) if isinstance(item, dict) else item
742+
for item in original_input_raw
743+
]
744+
else:
745+
# If it's a string, use it as-is
746+
normalized_original_input = original_input_raw
747+
682748
# Create the RunState instance
683749
state = RunState(
684750
context=context,
685-
original_input=state_json["originalInput"],
751+
original_input=normalized_original_input,
686752
starting_agent=current_agent,
687753
max_turns=state_json["maxTurns"],
688754
)
689755

690756
state._current_turn = state_json["currentTurn"]
691-
state._current_turn_persisted_item_count = state_json.get(
692-
"currentTurnPersistedItemCount", 0
693-
)
694757

695758
# Reconstruct model responses
696759
state._model_responses = _deserialize_model_responses(state_json.get("modelResponses", []))

0 commit comments

Comments
 (0)