Skip to content

Commit 283ead8

Browse files
fix: #2008 Fix agent memory leak using weakref (#2014)
Co-authored-by: Kazuhiro Sera <[email protected]>
1 parent a9e63a3 commit 283ead8

File tree

6 files changed

+532
-46
lines changed

6 files changed

+532
-46
lines changed

src/agents/items.py

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

33
import abc
4-
from dataclasses import dataclass
4+
import weakref
5+
from dataclasses import dataclass, field
56
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, Union, cast
67

78
import pydantic
@@ -72,6 +73,9 @@
7273

7374
T = TypeVar("T", bound=Union[TResponseOutputItem, TResponseInputItem])
7475

76+
# Distinguish a missing dict entry from an explicit None value.
77+
_MISSING_ATTR_SENTINEL = object()
78+
7579

7680
@dataclass
7781
class RunItemBase(Generic[T], abc.ABC):
@@ -84,6 +88,49 @@ class RunItemBase(Generic[T], abc.ABC):
8488
(i.e. `openai.types.responses.ResponseInputItemParam`).
8589
"""
8690

91+
_agent_ref: weakref.ReferenceType[Agent[Any]] | None = field(
92+
init=False,
93+
repr=False,
94+
default=None,
95+
)
96+
97+
def __post_init__(self) -> None:
98+
# Store a weak reference so we can release the strong reference later if desired.
99+
self._agent_ref = weakref.ref(self.agent)
100+
101+
def __getattribute__(self, name: str) -> Any:
102+
if name == "agent":
103+
return self._get_agent_via_weakref("agent", "_agent_ref")
104+
return super().__getattribute__(name)
105+
106+
def release_agent(self) -> None:
107+
"""Release the strong reference to the agent while keeping a weak reference."""
108+
if "agent" not in self.__dict__:
109+
return
110+
agent = self.__dict__["agent"]
111+
if agent is None:
112+
return
113+
self._agent_ref = weakref.ref(agent) if agent is not None else None
114+
# Set to None instead of deleting so dataclass repr/asdict keep working.
115+
self.__dict__["agent"] = None
116+
117+
def _get_agent_via_weakref(self, attr_name: str, ref_name: str) -> Any:
118+
# Preserve the dataclass field so repr/asdict still read it, but lazily resolve the weakref
119+
# when the stored value is None (meaning release_agent already dropped the strong ref).
120+
# If the attribute was never overridden we fall back to the default descriptor chain.
121+
data = object.__getattribute__(self, "__dict__")
122+
value = data.get(attr_name, _MISSING_ATTR_SENTINEL)
123+
if value is _MISSING_ATTR_SENTINEL:
124+
return object.__getattribute__(self, attr_name)
125+
if value is not None:
126+
return value
127+
ref = object.__getattribute__(self, ref_name)
128+
if ref is not None:
129+
agent = ref()
130+
if agent is not None:
131+
return agent
132+
return None
133+
87134
def to_input_item(self) -> TResponseInputItem:
88135
"""Converts this item into an input item suitable for passing to the model."""
89136
if isinstance(self.raw_item, dict):
@@ -131,6 +178,48 @@ class HandoffOutputItem(RunItemBase[TResponseInputItem]):
131178

132179
type: Literal["handoff_output_item"] = "handoff_output_item"
133180

181+
_source_agent_ref: weakref.ReferenceType[Agent[Any]] | None = field(
182+
init=False,
183+
repr=False,
184+
default=None,
185+
)
186+
_target_agent_ref: weakref.ReferenceType[Agent[Any]] | None = field(
187+
init=False,
188+
repr=False,
189+
default=None,
190+
)
191+
192+
def __post_init__(self) -> None:
193+
super().__post_init__()
194+
# Maintain weak references so downstream code can release the strong references when safe.
195+
self._source_agent_ref = weakref.ref(self.source_agent)
196+
self._target_agent_ref = weakref.ref(self.target_agent)
197+
198+
def __getattribute__(self, name: str) -> Any:
199+
if name == "source_agent":
200+
# Provide lazy weakref access like the base `agent` field so HandoffOutputItem
201+
# callers keep seeing the original agent until GC occurs.
202+
return self._get_agent_via_weakref("source_agent", "_source_agent_ref")
203+
if name == "target_agent":
204+
# Same as above but for the target of the handoff.
205+
return self._get_agent_via_weakref("target_agent", "_target_agent_ref")
206+
return super().__getattribute__(name)
207+
208+
def release_agent(self) -> None:
209+
super().release_agent()
210+
if "source_agent" in self.__dict__:
211+
source_agent = self.__dict__["source_agent"]
212+
if source_agent is not None:
213+
self._source_agent_ref = weakref.ref(source_agent)
214+
# Preserve dataclass fields for repr/asdict while dropping strong refs.
215+
self.__dict__["source_agent"] = None
216+
if "target_agent" in self.__dict__:
217+
target_agent = self.__dict__["target_agent"]
218+
if target_agent is not None:
219+
self._target_agent_ref = weakref.ref(target_agent)
220+
# Preserve dataclass fields for repr/asdict while dropping strong refs.
221+
self.__dict__["target_agent"] = None
222+
134223

135224
ToolCallItemTypes: TypeAlias = Union[
136225
ResponseFunctionToolCall,

src/agents/result.py

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import abc
44
import asyncio
5+
import weakref
56
from collections.abc import AsyncIterator
67
from dataclasses import dataclass, field
78
from typing import TYPE_CHECKING, Any, Literal, cast
@@ -74,6 +75,35 @@ class RunResultBase(abc.ABC):
7475
def last_agent(self) -> Agent[Any]:
7576
"""The last agent that was run."""
7677

78+
def release_agents(self, *, release_new_items: bool = True) -> None:
79+
"""
80+
Release strong references to agents held by this result. After calling this method,
81+
accessing `item.agent` or `last_agent` may return `None` if the agent has been garbage
82+
collected. Callers can use this when they are done inspecting the result and want to
83+
eagerly drop any associated agent graph.
84+
"""
85+
if release_new_items:
86+
for item in self.new_items:
87+
release = getattr(item, "release_agent", None)
88+
if callable(release):
89+
release()
90+
self._release_last_agent_reference()
91+
92+
def __del__(self) -> None:
93+
try:
94+
# Fall back to releasing agents automatically in case the caller never invoked
95+
# `release_agents()` explicitly so GC of the RunResult drops the last strong reference.
96+
# We pass `release_new_items=False` so RunItems that the user intentionally keeps
97+
# continue exposing their originating agent until that agent itself is collected.
98+
self.release_agents(release_new_items=False)
99+
except Exception:
100+
# Avoid raising from __del__.
101+
pass
102+
103+
@abc.abstractmethod
104+
def _release_last_agent_reference(self) -> None:
105+
"""Release stored agent reference specific to the concrete result type."""
106+
77107
def final_output_as(self, cls: type[T], raise_if_incorrect_type: bool = False) -> T:
78108
"""A convenience method to cast the final output to a specific type. By default, the cast
79109
is only for the typechecker. If you set `raise_if_incorrect_type` to True, we'll raise a
@@ -111,11 +141,34 @@ def last_response_id(self) -> str | None:
111141
@dataclass
112142
class RunResult(RunResultBase):
113143
_last_agent: Agent[Any]
144+
_last_agent_ref: weakref.ReferenceType[Agent[Any]] | None = field(
145+
init=False,
146+
repr=False,
147+
default=None,
148+
)
149+
150+
def __post_init__(self) -> None:
151+
self._last_agent_ref = weakref.ref(self._last_agent)
114152

115153
@property
116154
def last_agent(self) -> Agent[Any]:
117155
"""The last agent that was run."""
118-
return self._last_agent
156+
agent = cast("Agent[Any] | None", self.__dict__.get("_last_agent"))
157+
if agent is not None:
158+
return agent
159+
if self._last_agent_ref:
160+
agent = self._last_agent_ref()
161+
if agent is not None:
162+
return agent
163+
raise AgentsException("Last agent reference is no longer available.")
164+
165+
def _release_last_agent_reference(self) -> None:
166+
agent = cast("Agent[Any] | None", self.__dict__.get("_last_agent"))
167+
if agent is None:
168+
return
169+
self._last_agent_ref = weakref.ref(agent)
170+
# Preserve dataclass field so repr/asdict continue to succeed.
171+
self.__dict__["_last_agent"] = None
119172

120173
def __str__(self) -> str:
121174
return pretty_print_result(self)
@@ -150,6 +203,12 @@ class RunResultStreaming(RunResultBase):
150203
is_complete: bool = False
151204
"""Whether the agent has finished running."""
152205

206+
_current_agent_ref: weakref.ReferenceType[Agent[Any]] | None = field(
207+
init=False,
208+
repr=False,
209+
default=None,
210+
)
211+
153212
# Queues that the background run_loop writes to
154213
_event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] = field(
155214
default_factory=asyncio.Queue, repr=False
@@ -167,12 +226,30 @@ class RunResultStreaming(RunResultBase):
167226
# Soft cancel state
168227
_cancel_mode: Literal["none", "immediate", "after_turn"] = field(default="none", repr=False)
169228

229+
def __post_init__(self) -> None:
230+
self._current_agent_ref = weakref.ref(self.current_agent)
231+
170232
@property
171233
def last_agent(self) -> Agent[Any]:
172234
"""The last agent that was run. Updates as the agent run progresses, so the true last agent
173235
is only available after the agent run is complete.
174236
"""
175-
return self.current_agent
237+
agent = cast("Agent[Any] | None", self.__dict__.get("current_agent"))
238+
if agent is not None:
239+
return agent
240+
if self._current_agent_ref:
241+
agent = self._current_agent_ref()
242+
if agent is not None:
243+
return agent
244+
raise AgentsException("Last agent reference is no longer available.")
245+
246+
def _release_last_agent_reference(self) -> None:
247+
agent = cast("Agent[Any] | None", self.__dict__.get("current_agent"))
248+
if agent is None:
249+
return
250+
self._current_agent_ref = weakref.ref(agent)
251+
# Preserve dataclass field so repr/asdict continue to succeed.
252+
self.__dict__["current_agent"] = None
176253

177254
def cancel(self, mode: Literal["immediate", "after_turn"] = "immediate") -> None:
178255
"""Cancel the streaming run.

src/agents/run.py

Lines changed: 51 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -684,51 +684,60 @@ async def run(
684684
tool_input_guardrail_results.extend(turn_result.tool_input_guardrail_results)
685685
tool_output_guardrail_results.extend(turn_result.tool_output_guardrail_results)
686686

687-
if isinstance(turn_result.next_step, NextStepFinalOutput):
688-
output_guardrail_results = await self._run_output_guardrails(
689-
current_agent.output_guardrails + (run_config.output_guardrails or []),
690-
current_agent,
691-
turn_result.next_step.output,
692-
context_wrapper,
693-
)
694-
result = RunResult(
695-
input=original_input,
696-
new_items=generated_items,
697-
raw_responses=model_responses,
698-
final_output=turn_result.next_step.output,
699-
_last_agent=current_agent,
700-
input_guardrail_results=input_guardrail_results,
701-
output_guardrail_results=output_guardrail_results,
702-
tool_input_guardrail_results=tool_input_guardrail_results,
703-
tool_output_guardrail_results=tool_output_guardrail_results,
704-
context_wrapper=context_wrapper,
705-
)
706-
if not any(
707-
guardrail_result.output.tripwire_triggered
708-
for guardrail_result in input_guardrail_results
709-
):
710-
await self._save_result_to_session(
711-
session, [], turn_result.new_step_items
687+
try:
688+
if isinstance(turn_result.next_step, NextStepFinalOutput):
689+
output_guardrail_results = await self._run_output_guardrails(
690+
current_agent.output_guardrails
691+
+ (run_config.output_guardrails or []),
692+
current_agent,
693+
turn_result.next_step.output,
694+
context_wrapper,
695+
)
696+
result = RunResult(
697+
input=original_input,
698+
new_items=generated_items,
699+
raw_responses=model_responses,
700+
final_output=turn_result.next_step.output,
701+
_last_agent=current_agent,
702+
input_guardrail_results=input_guardrail_results,
703+
output_guardrail_results=output_guardrail_results,
704+
tool_input_guardrail_results=tool_input_guardrail_results,
705+
tool_output_guardrail_results=tool_output_guardrail_results,
706+
context_wrapper=context_wrapper,
712707
)
708+
if not any(
709+
guardrail_result.output.tripwire_triggered
710+
for guardrail_result in input_guardrail_results
711+
):
712+
await self._save_result_to_session(
713+
session, [], turn_result.new_step_items
714+
)
713715

714-
return result
715-
elif isinstance(turn_result.next_step, NextStepHandoff):
716-
current_agent = cast(Agent[TContext], turn_result.next_step.new_agent)
717-
current_span.finish(reset_current=True)
718-
current_span = None
719-
should_run_agent_start_hooks = True
720-
elif isinstance(turn_result.next_step, NextStepRunAgain):
721-
if not any(
722-
guardrail_result.output.tripwire_triggered
723-
for guardrail_result in input_guardrail_results
724-
):
725-
await self._save_result_to_session(
726-
session, [], turn_result.new_step_items
716+
return result
717+
elif isinstance(turn_result.next_step, NextStepHandoff):
718+
current_agent = cast(Agent[TContext], turn_result.next_step.new_agent)
719+
current_span.finish(reset_current=True)
720+
current_span = None
721+
should_run_agent_start_hooks = True
722+
elif isinstance(turn_result.next_step, NextStepRunAgain):
723+
if not any(
724+
guardrail_result.output.tripwire_triggered
725+
for guardrail_result in input_guardrail_results
726+
):
727+
await self._save_result_to_session(
728+
session, [], turn_result.new_step_items
729+
)
730+
else:
731+
raise AgentsException(
732+
f"Unknown next step type: {type(turn_result.next_step)}"
727733
)
728-
else:
729-
raise AgentsException(
730-
f"Unknown next step type: {type(turn_result.next_step)}"
731-
)
734+
finally:
735+
# RunImpl.execute_tools_and_side_effects returns a SingleStepResult that
736+
# stores direct references to the `pre_step_items` and `new_step_items`
737+
# lists it manages internally. Clear them here so the next turn does not
738+
# hold on to items from previous turns and to avoid leaking agent refs.
739+
turn_result.pre_step_items.clear()
740+
turn_result.new_step_items.clear()
732741
except AgentsException as exc:
733742
exc.run_data = RunErrorDetails(
734743
input=original_input,

tests/test_agent_memory_leak.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from __future__ import annotations
2+
3+
import gc
4+
import weakref
5+
6+
import pytest
7+
from openai.types.responses import ResponseOutputMessage, ResponseOutputText
8+
9+
from agents import Agent, Runner
10+
from tests.fake_model import FakeModel
11+
12+
13+
def _make_message(text: str) -> ResponseOutputMessage:
14+
return ResponseOutputMessage(
15+
id="msg-1",
16+
content=[ResponseOutputText(annotations=[], text=text, type="output_text")],
17+
role="assistant",
18+
status="completed",
19+
type="message",
20+
)
21+
22+
23+
@pytest.mark.asyncio
24+
async def test_agent_is_released_after_run() -> None:
25+
fake_model = FakeModel(initial_output=[_make_message("Paris")])
26+
agent = Agent(name="leak-test-agent", instructions="Answer questions.", model=fake_model)
27+
agent_ref = weakref.ref(agent)
28+
29+
# Running the agent should not leave behind strong references once the result goes out of scope.
30+
await Runner.run(agent, "What is the capital of France?")
31+
32+
del agent
33+
gc.collect()
34+
35+
assert agent_ref() is None

0 commit comments

Comments
 (0)