Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/handoffs/message_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def spanish_handoff_message_filter(handoff_message_data: HandoffInputData) -> Ha
else handoff_message_data.input_history
)

# or, you can use the HandoffInputData.clone(kwargs) method
return HandoffInputData(
input_history=history,
pre_handoff_items=tuple(handoff_message_data.pre_handoff_items),
Expand Down
1 change: 1 addition & 0 deletions examples/handoffs/message_filter_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def spanish_handoff_message_filter(handoff_message_data: HandoffInputData) -> Ha
else handoff_message_data.input_history
)

# or, you can use the HandoffInputData.clone(kwargs) method
return HandoffInputData(
input_history=history,
pre_handoff_items=tuple(handoff_message_data.pre_handoff_items),
Expand Down
3 changes: 3 additions & 0 deletions src/agents/_run_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,7 @@ async def execute_handoffs(
else original_input,
pre_handoff_items=tuple(pre_step_items),
new_items=tuple(new_step_items),
run_context=context_wrapper,
)
if not callable(input_filter):
_error_tracing.attach_error_to_span(
Expand All @@ -785,6 +786,8 @@ async def execute_handoffs(
)
raise UserError(f"Invalid input filter: {input_filter}")
filtered = input_filter(handoff_input_data)
if inspect.isawaitable(filtered):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to support this use case: #1238 (comment)

filtered = await filtered
if not isinstance(filtered, HandoffInputData):
_error_tracing.attach_error_to_span(
span_handoff,
Expand Down
1 change: 1 addition & 0 deletions src/agents/extensions/handoff_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def remove_all_tools(handoff_input_data: HandoffInputData) -> HandoffInputData:
input_history=filtered_history,
pre_handoff_items=filtered_pre_handoff_items,
new_items=filtered_new_items,
run_context=handoff_input_data.run_context,
)


Expand Down
20 changes: 18 additions & 2 deletions src/agents/handoffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import inspect
import json
from collections.abc import Awaitable
from dataclasses import dataclass
from dataclasses import dataclass, replace as dataclasses_replace
from typing import TYPE_CHECKING, Any, Callable, Generic, cast, overload

from pydantic import TypeAdapter
Expand Down Expand Up @@ -49,8 +49,24 @@ class HandoffInputData:
handoff and the tool output message representing the response from the handoff output.
"""

run_context: RunContextWrapper[Any] | None = None
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if a different name is preferred, happy to rename this

"""
The run context at the time the handoff was invoked.
Note that, since this property was added later on, it's optional for backwards compatibility.
"""

def clone(self, **kwargs: Any) -> HandoffInputData:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added clone method for convenience but if we don't want to have this, happy to revert this change

"""
Make a copy of the handoff input data, with the given arguments changed. For example, you
could do:
```
new_handoff_input_data = handoff_input_data.clone(new_items=())
```
"""
return dataclasses_replace(self, **kwargs)


HandoffInputFilter: TypeAlias = Callable[[HandoffInputData], HandoffInputData]
HandoffInputFilter: TypeAlias = Callable[[HandoffInputData], MaybeAwaitable[HandoffInputData]]
"""A function that filters the input data passed to the next agent."""


Expand Down
12 changes: 6 additions & 6 deletions tests/test_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def remove_new_items(handoff_input_data: HandoffInputData) -> HandoffInputData:
input_history=handoff_input_data.input_history,
pre_handoff_items=(),
new_items=(),
run_context=handoff_input_data.run_context,
)


Expand Down Expand Up @@ -262,7 +263,7 @@ async def test_handoff_filters():


@pytest.mark.asyncio
async def test_async_input_filter_fails():
async def test_async_input_filter_supported():
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we didn't intentionally block async functions, i believe changing this is a good addition

# DO NOT rename this without updating pyproject.toml

model = FakeModel()
Expand All @@ -274,7 +275,7 @@ async def test_async_input_filter_fails():
async def on_invoke_handoff(_ctx: RunContextWrapper[Any], _input: str) -> Agent[Any]:
return agent_1

async def invalid_input_filter(data: HandoffInputData) -> HandoffInputData:
async def async_input_filter(data: HandoffInputData) -> HandoffInputData:
return data # pragma: no cover

agent_2 = Agent[None](
Expand All @@ -287,8 +288,7 @@ async def invalid_input_filter(data: HandoffInputData) -> HandoffInputData:
input_json_schema={},
on_invoke_handoff=on_invoke_handoff,
agent_name=agent_1.name,
# Purposely ignoring the type error here to simulate invalid input
input_filter=invalid_input_filter, # type: ignore
input_filter=async_input_filter,
)
],
)
Expand All @@ -300,8 +300,8 @@ async def invalid_input_filter(data: HandoffInputData) -> HandoffInputData:
]
)

with pytest.raises(UserError):
await Runner.run(agent_2, input="user_message")
result = await Runner.run(agent_2, input="user_message")
assert result.final_output == "last"


@pytest.mark.asyncio
Expand Down
15 changes: 7 additions & 8 deletions tests/test_agent_runner_streamed.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ def remove_new_items(handoff_input_data: HandoffInputData) -> HandoffInputData:
input_history=handoff_input_data.input_history,
pre_handoff_items=(),
new_items=(),
run_context=handoff_input_data.run_context,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added run_context param in tests but the ones under examples do not have the param. this means backward compatibility is kept

)


Expand Down Expand Up @@ -281,7 +282,7 @@ async def test_handoff_filters():


@pytest.mark.asyncio
async def test_async_input_filter_fails():
async def test_async_input_filter_supported():
# DO NOT rename this without updating pyproject.toml

model = FakeModel()
Expand All @@ -293,7 +294,7 @@ async def test_async_input_filter_fails():
async def on_invoke_handoff(_ctx: RunContextWrapper[Any], _input: str) -> Agent[Any]:
return agent_1

async def invalid_input_filter(data: HandoffInputData) -> HandoffInputData:
async def async_input_filter(data: HandoffInputData) -> HandoffInputData:
return data # pragma: no cover

agent_2 = Agent[None](
Expand All @@ -306,8 +307,7 @@ async def invalid_input_filter(data: HandoffInputData) -> HandoffInputData:
input_json_schema={},
on_invoke_handoff=on_invoke_handoff,
agent_name=agent_1.name,
# Purposely ignoring the type error here to simulate invalid input
input_filter=invalid_input_filter, # type: ignore
input_filter=async_input_filter,
)
],
)
Expand All @@ -319,10 +319,9 @@ async def invalid_input_filter(data: HandoffInputData) -> HandoffInputData:
]
)

with pytest.raises(UserError):
result = Runner.run_streamed(agent_2, input="user_message")
async for _ in result.stream_events():
pass
result = Runner.run_streamed(agent_2, input="user_message")
async for _ in result.stream_events():
pass


@pytest.mark.asyncio
Expand Down
22 changes: 19 additions & 3 deletions tests/test_extension_filters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from openai.types.responses import ResponseOutputMessage, ResponseOutputText

from agents import Agent, HandoffInputData
from agents import Agent, HandoffInputData, RunContextWrapper
from agents.extensions.handoff_filters import remove_all_tools
from agents.items import (
HandoffOutputItem,
Expand Down Expand Up @@ -78,13 +78,23 @@ def _get_handoff_output_run_item(content: str) -> HandoffOutputItem:


def test_empty_data():
handoff_input_data = HandoffInputData(input_history=(), pre_handoff_items=(), new_items=())
handoff_input_data = HandoffInputData(
input_history=(),
pre_handoff_items=(),
new_items=(),
run_context=RunContextWrapper(context=()),
)
filtered_data = remove_all_tools(handoff_input_data)
assert filtered_data == handoff_input_data


def test_str_historyonly():
handoff_input_data = HandoffInputData(input_history="Hello", pre_handoff_items=(), new_items=())
handoff_input_data = HandoffInputData(
input_history="Hello",
pre_handoff_items=(),
new_items=(),
run_context=RunContextWrapper(context=()),
)
filtered_data = remove_all_tools(handoff_input_data)
assert filtered_data == handoff_input_data

Expand All @@ -94,6 +104,7 @@ def test_str_history_and_list():
input_history="Hello",
pre_handoff_items=(),
new_items=(_get_message_output_run_item("Hello"),),
run_context=RunContextWrapper(context=()),
)
filtered_data = remove_all_tools(handoff_input_data)
assert filtered_data == handoff_input_data
Expand All @@ -104,6 +115,7 @@ def test_list_history_and_list():
input_history=(_get_message_input_item("Hello"),),
pre_handoff_items=(_get_message_output_run_item("123"),),
new_items=(_get_message_output_run_item("World"),),
run_context=RunContextWrapper(context=()),
)
filtered_data = remove_all_tools(handoff_input_data)
assert filtered_data == handoff_input_data
Expand All @@ -121,6 +133,7 @@ def test_removes_tools_from_history():
_get_message_output_run_item("123"),
),
new_items=(_get_message_output_run_item("World"),),
run_context=RunContextWrapper(context=()),
)
filtered_data = remove_all_tools(handoff_input_data)
assert len(filtered_data.input_history) == 2
Expand All @@ -136,6 +149,7 @@ def test_removes_tools_from_new_items():
_get_message_output_run_item("Hello"),
_get_tool_output_run_item("World"),
),
run_context=RunContextWrapper(context=()),
)
filtered_data = remove_all_tools(handoff_input_data)
assert len(filtered_data.input_history) == 0
Expand All @@ -158,6 +172,7 @@ def test_removes_tools_from_new_items_and_history():
_get_message_output_run_item("Hello"),
_get_tool_output_run_item("World"),
),
run_context=RunContextWrapper(context=()),
)
filtered_data = remove_all_tools(handoff_input_data)
assert len(filtered_data.input_history) == 2
Expand All @@ -181,6 +196,7 @@ def test_removes_handoffs_from_history():
_get_tool_output_run_item("World"),
_get_handoff_output_run_item("World"),
),
run_context=RunContextWrapper(context=()),
)
filtered_data = remove_all_tools(handoff_input_data)
assert len(filtered_data.input_history) == 1
Expand Down
5 changes: 5 additions & 0 deletions tests/test_handoff_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,15 @@ def test_handoff_input_data():
input_history="",
pre_handoff_items=(),
new_items=(),
run_context=RunContextWrapper(context=()),
)
assert get_len(data) == 1

data = HandoffInputData(
input_history=({"role": "user", "content": "foo"},),
pre_handoff_items=(),
new_items=(),
run_context=RunContextWrapper(context=()),
)
assert get_len(data) == 1

Expand All @@ -238,6 +240,7 @@ def test_handoff_input_data():
),
pre_handoff_items=(),
new_items=(),
run_context=RunContextWrapper(context=()),
)
assert get_len(data) == 2

Expand All @@ -251,6 +254,7 @@ def test_handoff_input_data():
message_item("bar", agent),
message_item("baz", agent),
),
run_context=RunContextWrapper(context=()),
)
assert get_len(data) == 5

Expand All @@ -264,6 +268,7 @@ def test_handoff_input_data():
message_item("baz", agent),
message_item("qux", agent),
),
run_context=RunContextWrapper(context=()),
)

assert get_len(data) == 5
Expand Down