101 changes: 95 additions & 6 deletions src/agents/run.py
@@ -4,7 +4,7 @@
import copy
import inspect
from dataclasses import dataclass, field
from typing import Any, Generic, cast
from typing import Any, Callable, Generic, cast

from openai.types.responses import ResponseCompletedEvent
from openai.types.responses.response_prompt_param import (
@@ -56,6 +56,7 @@
from .tracing.span_data import AgentSpanData
from .usage import Usage
from .util import _coro, _error_tracing
from .util._types import MaybeAwaitable

DEFAULT_MAX_TURNS = 10

@@ -81,6 +82,27 @@ def get_default_agent_runner() -> AgentRunner:
return DEFAULT_AGENT_RUNNER


@dataclass
class ModelInputData:
"""Container for the data that will be sent to the model."""

input: list[TResponseInputItem]
instructions: str | None


@dataclass
class CallModelData(Generic[TContext]):
"""Data passed to `RunConfig.call_model_input_filter` prior to model call."""

model_data: ModelInputData
agent: Agent[TContext]
context: TContext | None


# Type alias for the optional input filter callback
CallModelInputFilter = Callable[[CallModelData[Any]], MaybeAwaitable[ModelInputData]]


@dataclass
class RunConfig:
"""Configures settings for the entire agent run."""
@@ -139,6 +161,16 @@ class RunConfig:
An optional dictionary of additional metadata to include with the trace.
"""

call_model_input_filter: CallModelInputFilter | None = None
"""
Optional callback that is invoked immediately before calling the model. It receives the current
agent, the run context, and the model input (instructions and input items), and must return a
possibly modified `ModelInputData` to use for the model call.

This allows you to edit the input sent to the model, e.g. to trim it to stay within a token
limit or to add a system prompt to the input.
"""


class RunOptions(TypedDict, Generic[TContext]):
"""Arguments for ``AgentRunner`` methods."""
@@ -593,6 +625,47 @@ def run_streamed(
)
return streamed_result

@classmethod
async def _maybe_filter_model_input(
cls,
*,
agent: Agent[TContext],
run_config: RunConfig,
context_wrapper: RunContextWrapper[TContext],
input_items: list[TResponseInputItem],
system_instructions: str | None,
) -> ModelInputData:
"""Apply optional call_model_input_filter to modify model input.

Returns a `ModelInputData` that will be sent to the model.
"""
effective_instructions = system_instructions
effective_input: list[TResponseInputItem] = input_items

if run_config.call_model_input_filter is None:
return ModelInputData(input=effective_input, instructions=effective_instructions)

try:
model_input = ModelInputData(
input=copy.deepcopy(effective_input),
instructions=effective_instructions,
)
filter_payload: CallModelData[TContext] = CallModelData(
model_data=model_input,
agent=agent,
context=context_wrapper.context,
)
maybe_updated = run_config.call_model_input_filter(filter_payload)
updated = await maybe_updated if inspect.isawaitable(maybe_updated) else maybe_updated
if not isinstance(updated, ModelInputData):
raise UserError("call_model_input_filter must return a ModelInputData instance")
return updated
except Exception as e:
_error_tracing.attach_error_to_current_span(
SpanError(message="Error in call_model_input_filter", data={"error": str(e)})
)
raise

@classmethod
async def _run_input_guardrails_with_queue(
cls,
@@ -863,10 +936,18 @@ async def _run_single_turn_streamed(
input = ItemHelpers.input_to_new_input_list(streamed_result.input)
input.extend([item.to_input_item() for item in streamed_result.new_items])

filtered = await cls._maybe_filter_model_input(
agent=agent,
run_config=run_config,
context_wrapper=context_wrapper,
input_items=input,
system_instructions=system_prompt,
)

# 1. Stream the output events
async for event in model.stream_response(
system_prompt,
input,
filtered.instructions,
filtered.input,
model_settings,
all_tools,
output_schema,
@@ -1034,7 +1115,6 @@ async def _get_single_step_result_from_streamed_response(
run_config: RunConfig,
tool_use_tracker: AgentToolUseTracker,
) -> SingleStepResult:

original_input = streamed_result.input
pre_step_items = streamed_result.new_items
event_queue = streamed_result._event_queue
@@ -1161,13 +1241,22 @@ async def _get_new_response(
previous_response_id: str | None,
prompt_config: ResponsePromptParam | None,
) -> ModelResponse:
# Allow user to modify model input right before the call, if configured
filtered = await cls._maybe_filter_model_input(
agent=agent,
run_config=run_config,
context_wrapper=context_wrapper,
input_items=input,
system_instructions=system_prompt,
)

model = cls._get_model(agent, run_config)
model_settings = agent.model_settings.resolve(run_config.model_settings)
model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)

new_response = await model.get_response(
system_instructions=system_prompt,
input=input,
system_instructions=filtered.instructions,
input=filtered.input,
model_settings=model_settings,
tools=all_tools,
output_schema=output_schema,
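A minimal usage sketch for the `call_model_input_filter` hook added above, based on the `RunConfig`, `CallModelData`, and `ModelInputData` API in this diff. The trimming rule and the `MAX_ITEMS` / `keep_recent_items` names are illustrative, not part of the SDK; because the filter may return a `MaybeAwaitable`, an `async def` filter works the same way, as the streamed tests below exercise.

# Illustrative sketch only: trims the conversation before each model call.
from __future__ import annotations

import asyncio
from typing import Any

from agents import Agent, RunConfig, Runner
from agents.run import CallModelData, ModelInputData

MAX_ITEMS = 20  # hypothetical budget standing in for a real token limit


def keep_recent_items(data: CallModelData[Any]) -> ModelInputData:
    # Keep only the most recent input items so the request stays within the budget;
    # instructions are passed through unchanged.
    mi = data.model_data
    return ModelInputData(input=list(mi.input)[-MAX_ITEMS:], instructions=mi.instructions)


async def main() -> None:
    # Running this for real requires a configured model / API key.
    agent = Agent(name="assistant", instructions="Be concise.")
    result = await Runner.run(
        agent,
        input="Hello!",
        run_config=RunConfig(call_model_input_filter=keep_recent_items),
    )
    print(result.final_output)


if __name__ == "__main__":
    asyncio.run(main())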
79 changes: 79 additions & 0 deletions tests/test_call_model_input_filter.py
@@ -0,0 +1,79 @@
from __future__ import annotations

from typing import Any

import pytest

from agents import Agent, RunConfig, Runner, UserError
from agents.run import CallModelData, ModelInputData

from .fake_model import FakeModel
from .test_responses import get_text_input_item, get_text_message


@pytest.mark.asyncio
async def test_call_model_input_filter_sync_non_streamed() -> None:
model = FakeModel()
agent = Agent(name="test", model=model)

# Prepare model output
model.set_next_output([get_text_message("ok")])

def filter_fn(data: CallModelData[Any]) -> ModelInputData:
mi = data.model_data
new_input = list(mi.input) + [get_text_input_item("added-sync")]
return ModelInputData(input=new_input, instructions="filtered-sync")

await Runner.run(
agent,
input="start",
run_config=RunConfig(call_model_input_filter=filter_fn),
)

assert model.last_turn_args["system_instructions"] == "filtered-sync"
assert isinstance(model.last_turn_args["input"], list)
assert len(model.last_turn_args["input"]) == 2
assert model.last_turn_args["input"][-1]["content"] == "added-sync"


@pytest.mark.asyncio
async def test_call_model_input_filter_async_streamed() -> None:
model = FakeModel()
agent = Agent(name="test", model=model)

# Prepare model output
model.set_next_output([get_text_message("ok")])

async def filter_fn(data: CallModelData[Any]) -> ModelInputData:
mi = data.model_data
new_input = list(mi.input) + [get_text_input_item("added-async")]
return ModelInputData(input=new_input, instructions="filtered-async")

result = Runner.run_streamed(
agent,
input="start",
run_config=RunConfig(call_model_input_filter=filter_fn),
)
async for _ in result.stream_events():
pass

assert model.last_turn_args["system_instructions"] == "filtered-async"
assert isinstance(model.last_turn_args["input"], list)
assert len(model.last_turn_args["input"]) == 2
assert model.last_turn_args["input"][-1]["content"] == "added-async"


@pytest.mark.asyncio
async def test_call_model_input_filter_invalid_return_type_raises() -> None:
model = FakeModel()
agent = Agent(name="test", model=model)

def invalid_filter(_data: CallModelData[Any]):
return "bad"

with pytest.raises(UserError):
await Runner.run(
agent,
input="start",
run_config=RunConfig(call_model_input_filter=invalid_filter),
)
107 changes: 107 additions & 0 deletions tests/test_call_model_input_filter_unit.py
@@ -0,0 +1,107 @@
from __future__ import annotations

import sys
from pathlib import Path
from typing import Any

import pytest
from openai.types.responses import ResponseOutputMessage, ResponseOutputText

# Make the repository's test helpers importable from this unit test
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "tests"))
from fake_model import FakeModel # type: ignore

# Import directly from submodules to avoid heavy __init__ side effects
from agents.agent import Agent
from agents.exceptions import UserError
from agents.run import CallModelData, ModelInputData, RunConfig, Runner


@pytest.mark.asyncio
async def test_call_model_input_filter_sync_non_streamed_unit() -> None:
model = FakeModel()
agent = Agent(name="test", model=model)

model.set_next_output(
[
ResponseOutputMessage(
id="1",
type="message",
role="assistant",
content=[ResponseOutputText(text="ok", type="output_text", annotations=[])],
status="completed",
)
]
)

def filter_fn(data: CallModelData[Any]) -> ModelInputData:
mi = data.model_data
new_input = list(mi.input) + [
{"content": "added-sync", "role": "user"}
] # pragma: no cover - trivial
return ModelInputData(input=new_input, instructions="filtered-sync")

await Runner.run(
agent,
input="start",
run_config=RunConfig(call_model_input_filter=filter_fn),
)

assert model.last_turn_args["system_instructions"] == "filtered-sync"
assert isinstance(model.last_turn_args["input"], list)
assert len(model.last_turn_args["input"]) == 2
assert model.last_turn_args["input"][-1]["content"] == "added-sync"


@pytest.mark.asyncio
async def test_call_model_input_filter_async_streamed_unit() -> None:
model = FakeModel()
agent = Agent(name="test", model=model)

model.set_next_output(
[
ResponseOutputMessage(
id="1",
type="message",
role="assistant",
content=[ResponseOutputText(text="ok", type="output_text", annotations=[])],
status="completed",
)
]
)

async def filter_fn(data: CallModelData[Any]) -> ModelInputData:
mi = data.model_data
new_input = list(mi.input) + [
{"content": "added-async", "role": "user"}
] # pragma: no cover - trivial
return ModelInputData(input=new_input, instructions="filtered-async")

result = Runner.run_streamed(
agent,
input="start",
run_config=RunConfig(call_model_input_filter=filter_fn),
)
async for _ in result.stream_events():
pass

assert model.last_turn_args["system_instructions"] == "filtered-async"
assert isinstance(model.last_turn_args["input"], list)
assert len(model.last_turn_args["input"]) == 2
assert model.last_turn_args["input"][-1]["content"] == "added-async"


@pytest.mark.asyncio
async def test_call_model_input_filter_invalid_return_type_raises_unit() -> None:
model = FakeModel()
agent = Agent(name="test", model=model)

def invalid_filter(_data: CallModelData[Any]):
return "bad"

with pytest.raises(UserError):
await Runner.run(
agent,
input="start",
run_config=RunConfig(call_model_input_filter=invalid_filter),
)