Skip to content

Commit 9265d6a

Browse files
committed
Allow modifying the input sent to the model
1 parent a17625e commit 9265d6a

File tree

1 file changed

+95
-6
lines changed

1 file changed

+95
-6
lines changed

src/agents/run.py

Lines changed: 95 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import copy
55
import inspect
66
from dataclasses import dataclass, field
7-
from typing import Any, Generic, cast
7+
from typing import Any, Callable, Generic, cast
88

99
from openai.types.responses import ResponseCompletedEvent
1010
from openai.types.responses.response_prompt_param import (
@@ -56,6 +56,7 @@
5656
from .tracing.span_data import AgentSpanData
5757
from .usage import Usage
5858
from .util import _coro, _error_tracing
59+
from .util._types import MaybeAwaitable
5960

6061
DEFAULT_MAX_TURNS = 10
6162

@@ -81,6 +82,27 @@ def get_default_agent_runner() -> AgentRunner:
8182
return DEFAULT_AGENT_RUNNER
8283

8384

85+
@dataclass
86+
class ModelInputData:
87+
"""Container for the data that will be sent to the model."""
88+
89+
input: list[TResponseInputItem]
90+
instructions: str | None
91+
92+
93+
@dataclass
94+
class CallModelData(Generic[TContext]):
95+
"""Data passed to `RunConfig.call_model_input_filter` prior to model call."""
96+
97+
model_data: ModelInputData
98+
agent: Agent[TContext]
99+
context: TContext | None
100+
101+
102+
# Type alias for the optional input filter callback
103+
CallModelInputFilter = Callable[[CallModelData[Any]], MaybeAwaitable[ModelInputData]]
104+
105+
84106
@dataclass
85107
class RunConfig:
86108
"""Configures settings for the entire agent run."""
@@ -139,6 +161,16 @@ class RunConfig:
139161
An optional dictionary of additional metadata to include with the trace.
140162
"""
141163

164+
call_model_input_filter: CallModelInputFilter | None = None
165+
"""
166+
Optional callback that is invoked immediately before calling the model. It receives the current
167+
agent, context and the model input (instructions and input items), and must return a possibly
168+
modified `ModelInputData` to use for the model call.
169+
170+
This allows you to edit the input sent to the model e.g. to stay within a token limit.
171+
For example, you can use this to add a system prompt to the input.
172+
"""
173+
142174

143175
class RunOptions(TypedDict, Generic[TContext]):
144176
"""Arguments for ``AgentRunner`` methods."""
@@ -593,6 +625,47 @@ def run_streamed(
593625
)
594626
return streamed_result
595627

628+
@classmethod
629+
async def _maybe_filter_model_input(
630+
cls,
631+
*,
632+
agent: Agent[TContext],
633+
run_config: RunConfig,
634+
context_wrapper: RunContextWrapper[TContext],
635+
input_items: list[TResponseInputItem],
636+
system_instructions: str | None,
637+
) -> ModelInputData:
638+
"""Apply optional call_model_input_filter to modify model input.
639+
640+
Returns a `ModelInputData` that will be sent to the model.
641+
"""
642+
effective_instructions = system_instructions
643+
effective_input: list[TResponseInputItem] = input_items
644+
645+
if run_config.call_model_input_filter is None:
646+
return ModelInputData(input=effective_input, instructions=effective_instructions)
647+
648+
try:
649+
model_input = ModelInputData(
650+
input=copy.deepcopy(effective_input),
651+
instructions=effective_instructions,
652+
)
653+
filter_payload: CallModelData[TContext] = CallModelData(
654+
model_data=model_input,
655+
agent=agent,
656+
context=context_wrapper.context,
657+
)
658+
maybe_updated = run_config.call_model_input_filter(filter_payload)
659+
updated = await maybe_updated if inspect.isawaitable(maybe_updated) else maybe_updated
660+
if not isinstance(updated, ModelInputData):
661+
raise UserError("call_model_input_filter must return a ModelInputData instance")
662+
return updated
663+
except Exception as e:
664+
_error_tracing.attach_error_to_current_span(
665+
SpanError(message="Error in call_model_input_filter", data={"error": str(e)})
666+
)
667+
raise
668+
596669
@classmethod
597670
async def _run_input_guardrails_with_queue(
598671
cls,
@@ -863,10 +936,18 @@ async def _run_single_turn_streamed(
863936
input = ItemHelpers.input_to_new_input_list(streamed_result.input)
864937
input.extend([item.to_input_item() for item in streamed_result.new_items])
865938

939+
filtered = await cls._maybe_filter_model_input(
940+
agent=agent,
941+
run_config=run_config,
942+
context_wrapper=context_wrapper,
943+
input_items=input,
944+
system_instructions=system_prompt,
945+
)
946+
866947
# 1. Stream the output events
867948
async for event in model.stream_response(
868-
system_prompt,
869-
input,
949+
filtered.instructions,
950+
filtered.input,
870951
model_settings,
871952
all_tools,
872953
output_schema,
@@ -1034,7 +1115,6 @@ async def _get_single_step_result_from_streamed_response(
10341115
run_config: RunConfig,
10351116
tool_use_tracker: AgentToolUseTracker,
10361117
) -> SingleStepResult:
1037-
10381118
original_input = streamed_result.input
10391119
pre_step_items = streamed_result.new_items
10401120
event_queue = streamed_result._event_queue
@@ -1161,13 +1241,22 @@ async def _get_new_response(
11611241
previous_response_id: str | None,
11621242
prompt_config: ResponsePromptParam | None,
11631243
) -> ModelResponse:
1244+
# Allow user to modify model input right before the call, if configured
1245+
filtered = await cls._maybe_filter_model_input(
1246+
agent=agent,
1247+
run_config=run_config,
1248+
context_wrapper=context_wrapper,
1249+
input_items=input,
1250+
system_instructions=system_prompt,
1251+
)
1252+
11641253
model = cls._get_model(agent, run_config)
11651254
model_settings = agent.model_settings.resolve(run_config.model_settings)
11661255
model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)
11671256

11681257
new_response = await model.get_response(
1169-
system_instructions=system_prompt,
1170-
input=input,
1258+
system_instructions=filtered.instructions,
1259+
input=filtered.input,
11711260
model_settings=model_settings,
11721261
tools=all_tools,
11731262
output_schema=output_schema,

0 commit comments

Comments
 (0)