-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Fix #1238 by enhancing HandoffInputData and enable passing async functions #1302
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ | |
import inspect | ||
import json | ||
from collections.abc import Awaitable | ||
from dataclasses import dataclass | ||
from dataclasses import dataclass, replace as dataclasses_replace | ||
from typing import TYPE_CHECKING, Any, Callable, Generic, cast, overload | ||
|
||
from pydantic import TypeAdapter | ||
|
@@ -49,8 +49,24 @@ class HandoffInputData: | |
handoff and the tool output message representing the response from the handoff output. | ||
""" | ||
|
||
run_context: RunContextWrapper[Any] | None = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if a different name is preferred, happy to rename this |
||
""" | ||
The run context at the time the handoff was invoked. | ||
Note that, since this property was added later on, it's optional for backwards compatibility. | ||
""" | ||
|
||
def clone(self, **kwargs: Any) -> HandoffInputData: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added clone method for convenience but if we don't want to have this, happy to revert this change |
||
""" | ||
Make a copy of the handoff input data, with the given arguments changed. For example, you | ||
could do: | ||
``` | ||
new_handoff_input_data = handoff_input_data.clone(new_items=()) | ||
``` | ||
""" | ||
return dataclasses_replace(self, **kwargs) | ||
|
||
|
||
HandoffInputFilter: TypeAlias = Callable[[HandoffInputData], HandoffInputData] | ||
HandoffInputFilter: TypeAlias = Callable[[HandoffInputData], MaybeAwaitable[HandoffInputData]] | ||
"""A function that filters the input data passed to the next agent.""" | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -224,6 +224,7 @@ def remove_new_items(handoff_input_data: HandoffInputData) -> HandoffInputData: | |
input_history=handoff_input_data.input_history, | ||
pre_handoff_items=(), | ||
new_items=(), | ||
run_context=handoff_input_data.run_context, | ||
) | ||
|
||
|
||
|
@@ -262,7 +263,7 @@ async def test_handoff_filters(): | |
|
||
|
||
@pytest.mark.asyncio | ||
async def test_async_input_filter_fails(): | ||
async def test_async_input_filter_supported(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if we didn't intentionally block async functions, i believe changing this is a good addition |
||
# DO NOT rename this without updating pyproject.toml | ||
|
||
model = FakeModel() | ||
|
@@ -274,7 +275,7 @@ async def test_async_input_filter_fails(): | |
async def on_invoke_handoff(_ctx: RunContextWrapper[Any], _input: str) -> Agent[Any]: | ||
return agent_1 | ||
|
||
async def invalid_input_filter(data: HandoffInputData) -> HandoffInputData: | ||
async def async_input_filter(data: HandoffInputData) -> HandoffInputData: | ||
return data # pragma: no cover | ||
|
||
agent_2 = Agent[None]( | ||
|
@@ -287,8 +288,7 @@ async def invalid_input_filter(data: HandoffInputData) -> HandoffInputData: | |
input_json_schema={}, | ||
on_invoke_handoff=on_invoke_handoff, | ||
agent_name=agent_1.name, | ||
# Purposely ignoring the type error here to simulate invalid input | ||
input_filter=invalid_input_filter, # type: ignore | ||
input_filter=async_input_filter, | ||
) | ||
], | ||
) | ||
|
@@ -300,8 +300,8 @@ async def invalid_input_filter(data: HandoffInputData) -> HandoffInputData: | |
] | ||
) | ||
|
||
with pytest.raises(UserError): | ||
await Runner.run(agent_2, input="user_message") | ||
result = await Runner.run(agent_2, input="user_message") | ||
assert result.final_output == "last" | ||
|
||
|
||
@pytest.mark.asyncio | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -241,6 +241,7 @@ def remove_new_items(handoff_input_data: HandoffInputData) -> HandoffInputData: | |
input_history=handoff_input_data.input_history, | ||
pre_handoff_items=(), | ||
new_items=(), | ||
run_context=handoff_input_data.run_context, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added run_context param in tests but the ones under examples do not have the param. this means backward compatibility is kept |
||
) | ||
|
||
|
||
|
@@ -281,7 +282,7 @@ async def test_handoff_filters(): | |
|
||
|
||
@pytest.mark.asyncio | ||
async def test_async_input_filter_fails(): | ||
async def test_async_input_filter_supported(): | ||
# DO NOT rename this without updating pyproject.toml | ||
|
||
model = FakeModel() | ||
|
@@ -293,7 +294,7 @@ async def test_async_input_filter_fails(): | |
async def on_invoke_handoff(_ctx: RunContextWrapper[Any], _input: str) -> Agent[Any]: | ||
return agent_1 | ||
|
||
async def invalid_input_filter(data: HandoffInputData) -> HandoffInputData: | ||
async def async_input_filter(data: HandoffInputData) -> HandoffInputData: | ||
return data # pragma: no cover | ||
|
||
agent_2 = Agent[None]( | ||
|
@@ -306,8 +307,7 @@ async def invalid_input_filter(data: HandoffInputData) -> HandoffInputData: | |
input_json_schema={}, | ||
on_invoke_handoff=on_invoke_handoff, | ||
agent_name=agent_1.name, | ||
# Purposely ignoring the type error here to simulate invalid input | ||
input_filter=invalid_input_filter, # type: ignore | ||
input_filter=async_input_filter, | ||
) | ||
], | ||
) | ||
|
@@ -319,10 +319,9 @@ async def invalid_input_filter(data: HandoffInputData) -> HandoffInputData: | |
] | ||
) | ||
|
||
with pytest.raises(UserError): | ||
result = Runner.run_streamed(agent_2, input="user_message") | ||
async for _ in result.stream_events(): | ||
pass | ||
result = Runner.run_streamed(agent_2, input="user_message") | ||
async for _ in result.stream_events(): | ||
pass | ||
|
||
|
||
@pytest.mark.asyncio | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
to support this use case: #1238 (comment)