From e4b781667fe85eea86437a4a8cc0016152d7e81b Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Tue, 26 Aug 2025 21:15:13 +0900 Subject: [PATCH 01/13] Fix #1564 Add conversations API support --- src/agents/extensions/models/litellm_model.py | 6 +++-- src/agents/models/interface.py | 4 +++ src/agents/models/openai_chatcompletions.py | 6 +++-- src/agents/models/openai_responses.py | 19 +++++++++---- src/agents/run.py | 27 +++++++++++++++++++ tests/fake_model.py | 8 ++++-- tests/models/test_kwargs_functionality.py | 1 + .../test_litellm_chatcompletions_stream.py | 4 +++ tests/test_agent_prompt.py | 2 ++ tests/test_extra_headers.py | 1 + tests/test_openai_chatcompletions.py | 4 +++ tests/test_openai_chatcompletions_stream.py | 4 +++ tests/test_reasoning_content.py | 3 +++ tests/test_responses_tracing.py | 18 ++++++++----- tests/voice/test_workflow.py | 2 ++ 15 files changed, 92 insertions(+), 17 deletions(-) diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py index fca172fff..b20c673af 100644 --- a/src/agents/extensions/models/litellm_model.py +++ b/src/agents/extensions/models/litellm_model.py @@ -82,7 +82,8 @@ async def get_response( output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, - previous_response_id: str | None, + previous_response_id: str | None = None, # unused + conversation_id: str | None = None, # unused prompt: Any | None = None, ) -> ModelResponse: with generation_span( @@ -171,7 +172,8 @@ async def stream_response( output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, - previous_response_id: str | None, + previous_response_id: str | None = None, # unused + conversation_id: str | None = None, # unused prompt: Any | None = None, ) -> AsyncIterator[TResponseStreamEvent]: with generation_span( diff --git a/src/agents/models/interface.py b/src/agents/models/interface.py index 5a185806c..f25934780 100644 --- a/src/agents/models/interface.py +++ b/src/agents/models/interface.py @@ -48,6 +48,7 @@ async def get_response( tracing: ModelTracing, *, previous_response_id: str | None, + conversation_id: str | None, prompt: ResponsePromptParam | None, ) -> ModelResponse: """Get a response from the model. @@ -62,6 +63,7 @@ async def get_response( tracing: Tracing configuration. previous_response_id: the ID of the previous response. Generally not used by the model, except for the OpenAI Responses API. + conversation_id: The ID of the stored conversation, if any. prompt: The prompt config to use for the model. Returns: @@ -81,6 +83,7 @@ def stream_response( tracing: ModelTracing, *, previous_response_id: str | None, + conversation_id: str | None, prompt: ResponsePromptParam | None, ) -> AsyncIterator[TResponseStreamEvent]: """Stream a response from the model. @@ -95,6 +98,7 @@ def stream_response( tracing: Tracing configuration. previous_response_id: the ID of the previous response. Generally not used by the model, except for the OpenAI Responses API. + conversation_id: The ID of the stored conversation, if any. prompt: The prompt config to use for the model. 
Returns: diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index c6d1d7d22..f4d75d833 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -55,7 +55,8 @@ async def get_response( output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, - previous_response_id: str | None, + previous_response_id: str | None = None, # unused + conversation_id: str | None = None, # unused prompt: ResponsePromptParam | None = None, ) -> ModelResponse: with generation_span( @@ -142,7 +143,8 @@ async def stream_response( output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, - previous_response_id: str | None, + previous_response_id: str | None = None, # unused + conversation_id: str | None = None, # unused prompt: ResponsePromptParam | None = None, ) -> AsyncIterator[TResponseStreamEvent]: """ diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index 6405bd586..a30da0fc5 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -74,7 +74,8 @@ async def get_response( output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, - previous_response_id: str | None, + previous_response_id: str | None = None, + conversation_id: str | None = None, prompt: ResponsePromptParam | None = None, ) -> ModelResponse: with response_span(disabled=tracing.is_disabled()) as span_response: @@ -86,7 +87,8 @@ async def get_response( tools, output_schema, handoffs, - previous_response_id, + previous_response_id=previous_response_id, + conversation_id=conversation_id, stream=False, prompt=prompt, ) @@ -149,7 +151,8 @@ async def stream_response( output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, - previous_response_id: str | None, + previous_response_id: str | None = None, # unused + conversation_id: str | None = None, # unused prompt: ResponsePromptParam | None = None, ) -> AsyncIterator[ResponseStreamEvent]: """ @@ -164,7 +167,8 @@ async def stream_response( tools, output_schema, handoffs, - previous_response_id, + previous_response_id=previous_response_id, + conversation_id=conversation_id, stream=True, prompt=prompt, ) @@ -202,6 +206,7 @@ async def _fetch_response( output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], previous_response_id: str | None, + conversation_id: str | None, stream: Literal[True], prompt: ResponsePromptParam | None = None, ) -> AsyncStream[ResponseStreamEvent]: ... @@ -216,6 +221,7 @@ async def _fetch_response( output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], previous_response_id: str | None, + conversation_id: str | None, stream: Literal[False], prompt: ResponsePromptParam | None = None, ) -> Response: ... 
@@ -228,7 +234,8 @@ async def _fetch_response( tools: list[Tool], output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], - previous_response_id: str | None, + previous_response_id: str | None = None, + conversation_id: str | None = None, stream: Literal[True] | Literal[False] = False, prompt: ResponsePromptParam | None = None, ) -> Response | AsyncStream[ResponseStreamEvent]: @@ -264,6 +271,7 @@ async def _fetch_response( f"Tool choice: {tool_choice}\n" f"Response format: {response_format}\n" f"Previous response id: {previous_response_id}\n" + f"Conversation id: {conversation_id}\n" ) extra_args = dict(model_settings.extra_args or {}) @@ -277,6 +285,7 @@ async def _fetch_response( return await self._client.responses.create( previous_response_id=self._non_null_or_not_given(previous_response_id), + conversation=self._non_null_or_not_given(conversation_id), instructions=self._non_null_or_not_given(system_instructions), model=self.model, input=list_input, diff --git a/src/agents/run.py b/src/agents/run.py index 727927b08..b34d3fc64 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -208,6 +208,9 @@ class RunOptions(TypedDict, Generic[TContext]): previous_response_id: NotRequired[str | None] """The ID of the previous response, if any.""" + conversation_id: NotRequired[str | None] + """The ID of the stored conversation, if any.""" + session: NotRequired[Session | None] """The session for the run.""" @@ -224,6 +227,7 @@ async def run( hooks: RunHooks[TContext] | None = None, run_config: RunConfig | None = None, previous_response_id: str | None = None, + conversation_id: str | None = None, session: Session | None = None, ) -> RunResult: """Run a workflow starting at the given agent. The agent will run in a loop until a final @@ -248,6 +252,7 @@ async def run( run_config: Global settings for the entire agent run. previous_response_id: The ID of the previous response, if using OpenAI models via the Responses API, this allows you to skip passing in input from the previous turn. + conversation_id: The ID of the stored conversation, if any. Returns: A run result containing all the inputs, guardrail results and the output of the last agent. Agents may perform handoffs, so we don't know the specific type of the output. @@ -261,6 +266,7 @@ async def run( hooks=hooks, run_config=run_config, previous_response_id=previous_response_id, + conversation_id=conversation_id, session=session, ) @@ -275,6 +281,7 @@ def run_sync( hooks: RunHooks[TContext] | None = None, run_config: RunConfig | None = None, previous_response_id: str | None = None, + conversation_id: str | None = None, session: Session | None = None, ) -> RunResult: """Run a workflow synchronously, starting at the given agent. Note that this just wraps the @@ -302,6 +309,7 @@ def run_sync( run_config: Global settings for the entire agent run. previous_response_id: The ID of the previous response, if using OpenAI models via the Responses API, this allows you to skip passing in input from the previous turn. + conversation_id: The ID of the stored conversation, if any. Returns: A run result containing all the inputs, guardrail results and the output of the last agent. Agents may perform handoffs, so we don't know the specific type of the output. 
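To make the parameter documented above concrete, here is a minimal usage sketch of the API this hunk introduces. The "conv_abc123" value is a placeholder for the ID of a conversation created beforehand via the OpenAI Conversations API; nothing else is assumed beyond what this patch adds:

    import asyncio

    from agents import Agent, Runner


    async def main() -> None:
        agent = Agent(name="Assistant", instructions="Reply very concisely.")

        # "conv_abc123" is a placeholder conversation ID. The run reads the
        # stored history from that conversation and appends this turn's output
        # items back to it, so no manual input threading is needed.
        result = await Runner.run(
            agent,
            "What city is the Golden Gate Bridge in?",
            conversation_id="conv_abc123",
        )
        print(result.final_output)


    asyncio.run(main())
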
@@ -315,6 +323,7 @@ def run_sync( hooks=hooks, run_config=run_config, previous_response_id=previous_response_id, + conversation_id=conversation_id, session=session, ) @@ -328,6 +337,7 @@ def run_streamed( hooks: RunHooks[TContext] | None = None, run_config: RunConfig | None = None, previous_response_id: str | None = None, + conversation_id: str | None = None, session: Session | None = None, ) -> RunResultStreaming: """Run a workflow starting at the given agent in streaming mode. The returned result object @@ -353,6 +363,7 @@ def run_streamed( run_config: Global settings for the entire agent run. previous_response_id: The ID of the previous response, if using OpenAI models via the Responses API, this allows you to skip passing in input from the previous turn. + conversation_id: The ID of the stored conversation, if any. Returns: A result object that contains data about the run, as well as a method to stream events. """ @@ -365,6 +376,7 @@ def run_streamed( hooks=hooks, run_config=run_config, previous_response_id=previous_response_id, + conversation_id=conversation_id, session=session, ) @@ -386,6 +398,7 @@ async def run( hooks = kwargs.get("hooks") run_config = kwargs.get("run_config") previous_response_id = kwargs.get("previous_response_id") + conversation_id = kwargs.get("conversation_id") session = kwargs.get("session") if hooks is None: hooks = RunHooks[Any]() @@ -478,6 +491,7 @@ async def run( should_run_agent_start_hooks=should_run_agent_start_hooks, tool_use_tracker=tool_use_tracker, previous_response_id=previous_response_id, + conversation_id=conversation_id, ), ) else: @@ -492,6 +506,7 @@ async def run( should_run_agent_start_hooks=should_run_agent_start_hooks, tool_use_tracker=tool_use_tracker, previous_response_id=previous_response_id, + conversation_id=conversation_id, ) should_run_agent_start_hooks = False @@ -558,6 +573,7 @@ def run_sync( hooks = kwargs.get("hooks") run_config = kwargs.get("run_config") previous_response_id = kwargs.get("previous_response_id") + conversation_id = kwargs.get("conversation_id") session = kwargs.get("session") return asyncio.get_event_loop().run_until_complete( @@ -570,6 +586,7 @@ def run_sync( hooks=hooks, run_config=run_config, previous_response_id=previous_response_id, + conversation_id=conversation_id, ) ) @@ -584,6 +601,7 @@ def run_streamed( hooks = kwargs.get("hooks") run_config = kwargs.get("run_config") previous_response_id = kwargs.get("previous_response_id") + conversation_id = kwargs.get("conversation_id") session = kwargs.get("session") if hooks is None: @@ -638,6 +656,7 @@ def run_streamed( context_wrapper=context_wrapper, run_config=run_config, previous_response_id=previous_response_id, + conversation_id=conversation_id, session=session, ) ) @@ -738,6 +757,7 @@ async def _start_streaming( context_wrapper: RunContextWrapper[TContext], run_config: RunConfig, previous_response_id: str | None, + conversation_id: str | None, session: Session | None, ): if streamed_result.trace: @@ -821,6 +841,7 @@ async def _start_streaming( tool_use_tracker, all_tools, previous_response_id, + conversation_id, ) should_run_agent_start_hooks = False @@ -923,6 +944,7 @@ async def _run_single_turn_streamed( tool_use_tracker: AgentToolUseTracker, all_tools: list[Tool], previous_response_id: str | None, + conversation_id: str | None, ) -> SingleStepResult: emitted_tool_call_ids: set[str] = set() @@ -983,6 +1005,7 @@ async def _run_single_turn_streamed( run_config.tracing_disabled, run_config.trace_include_sensitive_data ), 
previous_response_id=previous_response_id, + conversation_id=conversation_id, prompt=prompt_config, ): if isinstance(event, ResponseCompletedEvent): @@ -1091,6 +1114,7 @@ async def _run_single_turn( should_run_agent_start_hooks: bool, tool_use_tracker: AgentToolUseTracker, previous_response_id: str | None, + conversation_id: str | None, ) -> SingleStepResult: # Ensure we run the hooks before anything else if should_run_agent_start_hooks: @@ -1124,6 +1148,7 @@ async def _run_single_turn( run_config, tool_use_tracker, previous_response_id, + conversation_id, prompt_config, ) @@ -1318,6 +1343,7 @@ async def _get_new_response( run_config: RunConfig, tool_use_tracker: AgentToolUseTracker, previous_response_id: str | None, + conversation_id: str | None, prompt_config: ResponsePromptParam | None, ) -> ModelResponse: # Allow user to modify model input right before the call, if configured @@ -1352,6 +1378,7 @@ async def _get_new_response( run_config.tracing_disabled, run_config.trace_include_sensitive_data ), previous_response_id=previous_response_id, + conversation_id=conversation_id, prompt=prompt_config, ) # If the agent has hooks, we need to call them after the LLM call diff --git a/tests/fake_model.py b/tests/fake_model.py index 6c1377e6d..7de629448 100644 --- a/tests/fake_model.py +++ b/tests/fake_model.py @@ -61,6 +61,7 @@ async def get_response( tracing: ModelTracing, *, previous_response_id: str | None, + conversation_id: str | None, prompt: Any | None, ) -> ModelResponse: self.last_turn_args = { @@ -70,6 +71,7 @@ async def get_response( "tools": tools, "output_schema": output_schema, "previous_response_id": previous_response_id, + "conversation_id": conversation_id, } with generation_span(disabled=not self.tracing_enabled) as span: @@ -103,8 +105,9 @@ async def stream_response( handoffs: list[Handoff], tracing: ModelTracing, *, - previous_response_id: str | None, - prompt: Any | None, + previous_response_id: str | None = None, + conversation_id: str | None = None, + prompt: Any | None = None, ) -> AsyncIterator[TResponseStreamEvent]: self.last_turn_args = { "system_instructions": system_instructions, @@ -113,6 +116,7 @@ async def stream_response( "tools": tools, "output_schema": output_schema, "previous_response_id": previous_response_id, + "conversation_id": conversation_id, } with generation_span(disabled=not self.tracing_enabled) as span: output = self.get_next_output() diff --git a/tests/models/test_kwargs_functionality.py b/tests/models/test_kwargs_functionality.py index 210610a02..941fdc68d 100644 --- a/tests/models/test_kwargs_functionality.py +++ b/tests/models/test_kwargs_functionality.py @@ -47,6 +47,7 @@ async def fake_acompletion(model, messages=None, **kwargs): handoffs=[], tracing=ModelTracing.DISABLED, previous_response_id=None, + conversation_id=None, ) # Verify that all kwargs were passed through diff --git a/tests/models/test_litellm_chatcompletions_stream.py b/tests/models/test_litellm_chatcompletions_stream.py index bd38f8759..d8b79d542 100644 --- a/tests/models/test_litellm_chatcompletions_stream.py +++ b/tests/models/test_litellm_chatcompletions_stream.py @@ -90,6 +90,7 @@ async def patched_fetch_response(self, *args, **kwargs): handoffs=[], tracing=ModelTracing.DISABLED, previous_response_id=None, + conversation_id=None, prompt=None, ): output_events.append(event) @@ -183,6 +184,7 @@ async def patched_fetch_response(self, *args, **kwargs): handoffs=[], tracing=ModelTracing.DISABLED, previous_response_id=None, + conversation_id=None, prompt=None, ): 
output_events.append(event) @@ -273,6 +275,7 @@ async def patched_fetch_response(self, *args, **kwargs): handoffs=[], tracing=ModelTracing.DISABLED, previous_response_id=None, + conversation_id=None, prompt=None, ): output_events.append(event) @@ -389,6 +392,7 @@ async def patched_fetch_response(self, *args, **kwargs): handoffs=[], tracing=ModelTracing.DISABLED, previous_response_id=None, + conversation_id=None, prompt=None, ): output_events.append(event) diff --git a/tests/test_agent_prompt.py b/tests/test_agent_prompt.py index 010717d66..3d5ed5a3f 100644 --- a/tests/test_agent_prompt.py +++ b/tests/test_agent_prompt.py @@ -24,6 +24,7 @@ async def get_response( tracing, *, previous_response_id, + conversation_id, prompt, ): # Record the prompt that the agent resolved and passed in. @@ -37,6 +38,7 @@ async def get_response( handoffs, tracing, previous_response_id=previous_response_id, + conversation_id=conversation_id, prompt=prompt, ) diff --git a/tests/test_extra_headers.py b/tests/test_extra_headers.py index a6af30077..c6672374b 100644 --- a/tests/test_extra_headers.py +++ b/tests/test_extra_headers.py @@ -95,6 +95,7 @@ def __init__(self): handoffs=[], tracing=ModelTracing.DISABLED, previous_response_id=None, + conversation_id=None, ) assert "extra_headers" in called_kwargs assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value" diff --git a/tests/test_openai_chatcompletions.py b/tests/test_openai_chatcompletions.py index 6291418f6..d52d89b47 100644 --- a/tests/test_openai_chatcompletions.py +++ b/tests/test_openai_chatcompletions.py @@ -77,6 +77,7 @@ async def patched_fetch_response(self, *args, **kwargs): handoffs=[], tracing=ModelTracing.DISABLED, previous_response_id=None, + conversation_id=None, prompt=None, ) # Should have produced exactly one output message with one text part @@ -129,6 +130,7 @@ async def patched_fetch_response(self, *args, **kwargs): handoffs=[], tracing=ModelTracing.DISABLED, previous_response_id=None, + conversation_id=None, prompt=None, ) assert len(resp.output) == 1 @@ -182,6 +184,7 @@ async def patched_fetch_response(self, *args, **kwargs): handoffs=[], tracing=ModelTracing.DISABLED, previous_response_id=None, + conversation_id=None, prompt=None, ) # Expect a message item followed by a function tool call item. 
@@ -224,6 +227,7 @@ async def patched_fetch_response(self, *args, **kwargs): handoffs=[], tracing=ModelTracing.DISABLED, previous_response_id=None, + conversation_id=None, prompt=None, ) assert resp.output == [] diff --git a/tests/test_openai_chatcompletions_stream.py b/tests/test_openai_chatcompletions_stream.py index cbb3c5dae..947816f01 100644 --- a/tests/test_openai_chatcompletions_stream.py +++ b/tests/test_openai_chatcompletions_stream.py @@ -90,6 +90,7 @@ async def patched_fetch_response(self, *args, **kwargs): handoffs=[], tracing=ModelTracing.DISABLED, previous_response_id=None, + conversation_id=None, prompt=None, ): output_events.append(event) @@ -183,6 +184,7 @@ async def patched_fetch_response(self, *args, **kwargs): handoffs=[], tracing=ModelTracing.DISABLED, previous_response_id=None, + conversation_id=None, prompt=None, ): output_events.append(event) @@ -273,6 +275,7 @@ async def patched_fetch_response(self, *args, **kwargs): handoffs=[], tracing=ModelTracing.DISABLED, previous_response_id=None, + conversation_id=None, prompt=None, ): output_events.append(event) @@ -390,6 +393,7 @@ async def patched_fetch_response(self, *args, **kwargs): handoffs=[], tracing=ModelTracing.DISABLED, previous_response_id=None, + conversation_id=None, prompt=None, ): output_events.append(event) diff --git a/tests/test_reasoning_content.py b/tests/test_reasoning_content.py index 69e9a7d0c..a64fdaf15 100644 --- a/tests/test_reasoning_content.py +++ b/tests/test_reasoning_content.py @@ -129,6 +129,7 @@ async def patched_fetch_response(self, *args, **kwargs): handoffs=[], tracing=ModelTracing.DISABLED, previous_response_id=None, + conversation_id=None, prompt=None, ): output_events.append(event) @@ -216,6 +217,7 @@ async def patched_fetch_response(self, *args, **kwargs): handoffs=[], tracing=ModelTracing.DISABLED, previous_response_id=None, + conversation_id=None, prompt=None, ) @@ -270,6 +272,7 @@ async def patched_fetch_response(self, *args, **kwargs): handoffs=[], tracing=ModelTracing.DISABLED, previous_response_id=None, + conversation_id=None, prompt=None, ): output_events.append(event) diff --git a/tests/test_responses_tracing.py b/tests/test_responses_tracing.py index fe63e8ecb..a2d9b3c3d 100644 --- a/tests/test_responses_tracing.py +++ b/tests/test_responses_tracing.py @@ -69,7 +69,8 @@ async def dummy_fetch_response( tools, output_schema, handoffs, - prev_response_id, + previous_response_id, + conversation_id, stream, prompt, ): @@ -114,7 +115,8 @@ async def dummy_fetch_response( tools, output_schema, handoffs, - prev_response_id, + previous_response_id, + conversation_id, stream, prompt, ): @@ -157,7 +159,8 @@ async def dummy_fetch_response( tools, output_schema, handoffs, - prev_response_id, + previous_response_id, + conversation_id, stream, prompt, ): @@ -197,7 +200,8 @@ async def dummy_fetch_response( tools, output_schema, handoffs, - prev_response_id, + previous_response_id, + conversation_id, stream, prompt, ): @@ -251,7 +255,8 @@ async def dummy_fetch_response( tools, output_schema, handoffs, - prev_response_id, + previous_response_id, + conversation_id, stream, prompt, ): @@ -304,7 +309,8 @@ async def dummy_fetch_response( tools, output_schema, handoffs, - prev_response_id, + previous_response_id, + conversation_id, stream, prompt, ): diff --git a/tests/voice/test_workflow.py b/tests/voice/test_workflow.py index 611e6f255..94d87b994 100644 --- a/tests/voice/test_workflow.py +++ b/tests/voice/test_workflow.py @@ -55,6 +55,7 @@ async def get_response( tracing: ModelTracing, *, 
previous_response_id: str | None, + conversation_id: str | None, prompt: Any | None, ) -> ModelResponse: raise NotImplementedError("Not implemented") @@ -70,6 +71,7 @@ async def stream_response( tracing: ModelTracing, *, previous_response_id: str | None, + conversation_id: str | None, prompt: Any | None, ) -> AsyncIterator[TResponseStreamEvent]: output = self.get_next_output() From f54bdb05c185d0728618a3a094f7c42ae1cd3821 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Tue, 26 Aug 2025 21:24:27 +0900 Subject: [PATCH 02/13] fix mypy errors --- examples/reasoning_content/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/reasoning_content/main.py b/examples/reasoning_content/main.py index 9da2a5690..e83c0d4d4 100644 --- a/examples/reasoning_content/main.py +++ b/examples/reasoning_content/main.py @@ -47,6 +47,7 @@ async def stream_with_reasoning_content(): handoffs=[], tracing=ModelTracing.DISABLED, previous_response_id=None, + conversation_id=None, prompt=None, ): if event.type == "response.reasoning_summary_text.delta": @@ -83,6 +84,7 @@ async def get_response_with_reasoning_content(): handoffs=[], tracing=ModelTracing.DISABLED, previous_response_id=None, + conversation_id=None, prompt=None, ) From 32e7c5723c5ede3a18087bd2620fc6c2313a1705 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 28 Aug 2025 13:06:40 +0900 Subject: [PATCH 03/13] Update src/agents/run.py Co-authored-by: Rohan Mehta --- src/agents/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agents/run.py b/src/agents/run.py index b34d3fc64..146c0c465 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -252,7 +252,7 @@ async def run( run_config: Global settings for the entire agent run. previous_response_id: The ID of the previous response, if using OpenAI models via the Responses API, this allows you to skip passing in input from the previous turn. - conversation_id: The ID of the stored conversation, if any. + conversation_id: The conversation ID (https://platform.openai.com/docs/guides/conversation-state?api-mode=responses). If provided, the conversation will be used to read and write items. Every agent will have access to the conversation history so far, and it's output items will be written to the conversation. We recommend only using this if you are exclusively using OpenAI models; other model providers don't write to the Conversation object, so you'll end up having partial conversations stored. Returns: A run result containing all the inputs, guardrail results and the output of the last agent. Agents may perform handoffs, so we don't know the specific type of the output. From c919f6e3ed54449d84d94257a0b45a1bd14b5b37 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 28 Aug 2025 13:10:50 +0900 Subject: [PATCH 04/13] fix lint --- src/agents/run.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/agents/run.py b/src/agents/run.py index 146c0c465..742917b87 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -252,7 +252,13 @@ async def run( run_config: Global settings for the entire agent run. previous_response_id: The ID of the previous response, if using OpenAI models via the Responses API, this allows you to skip passing in input from the previous turn. - conversation_id: The conversation ID (https://platform.openai.com/docs/guides/conversation-state?api-mode=responses). If provided, the conversation will be used to read and write items. 
Every agent will have access to the conversation history so far, and it's output items will be written to the conversation. We recommend only using this if you are exclusively using OpenAI models; other model providers don't write to the Conversation object, so you'll end up having partial conversations stored.
+            conversation_id: The conversation ID (https://platform.openai.com/docs/guides/conversation-state?api-mode=responses).
+                If provided, the conversation will be used to read and write items.
+                Every agent will have access to the conversation history so far,
+                and its output items will be written to the conversation.
+                We recommend only using this if you are exclusively using OpenAI models;
+                other model providers don't write to the Conversation object,
+                so you'll end up having partial conversations stored.
 
         Returns:
             A run result containing all the inputs, guardrail results and the output of the
             last agent. Agents may perform handoffs, so we don't know the specific type of the
             output.
@@ -1500,4 +1506,3 @@ def _copy_str_or_list(input: str | list[TResponseInputItem]) -> str | list[TResp
     if isinstance(input, str):
         return input
     return input.copy()
-

From be053ff4e127b0ea87dc6b30b9cf8cc3ec33c4e3 Mon Sep 17 00:00:00 2001
From: Kazuhiro Sera
Date: Thu, 28 Aug 2025 19:36:36 +0900
Subject: [PATCH 05/13] Add OpenAISession

---
 examples/memory/openai_session_example.py | 78 +++++
 .../sqlalchemy_session_example.py | 36 ++-
 .../sqlite_session_example.py} | 0
 src/agents/__init__.py | 4 +-
 src/agents/memory/__init__.py | 11 +-
 src/agents/memory/openai_session.py | 79 +++++
 src/agents/memory/session.py | 270 -----------------
 src/agents/memory/sqlite_session.py | 275 ++++++++++++++++++
 8 files changed, 476 insertions(+), 277 deletions(-)
 create mode 100644 examples/memory/openai_session_example.py
 rename examples/{basic => memory}/sqlalchemy_session_example.py (50%)
 rename examples/{basic/session_example.py => memory/sqlite_session_example.py} (100%)
 create mode 100644 src/agents/memory/openai_session.py
 create mode 100644 src/agents/memory/sqlite_session.py

diff --git a/examples/memory/openai_session_example.py b/examples/memory/openai_session_example.py
new file mode 100644
index 000000000..f2a0cfff6
--- /dev/null
+++ b/examples/memory/openai_session_example.py
@@ -0,0 +1,78 @@
+"""
+Example demonstrating session memory functionality.
+
+This example shows how to use session memory to maintain conversation history
+across multiple agent runs without manually handling .to_input_list().
+"""
+
+import asyncio
+
+from agents import Agent, OpenAISession, Runner
+
+
+async def main():
+    # Create an agent
+    agent = Agent(
+        name="Assistant",
+        instructions="Reply very concisely.",
+    )
+
+    # Create a session instance that will persist across runs
+    session = OpenAISession()
+
+    print("=== Session Example ===")
+    print("The agent will remember previous messages automatically.\n")
+
+    # First turn
+    print("First turn:")
+    print("User: What city is the Golden Gate Bridge in?")
+    result = await Runner.run(
+        agent,
+        "What city is the Golden Gate Bridge in?",
+        session=session,
+    )
+    print(f"Assistant: {result.final_output}")
+    print()
+
+    # Second turn - the agent will remember the previous conversation
+    print("Second turn:")
+    print("User: What state is it in?")
+    result = await Runner.run(agent, "What state is it in?", session=session)
+    print(f"Assistant: {result.final_output}")
+    print()
+
+    # Third turn - continuing the conversation
+    print("Third turn:")
+    print("User: What's the population of that state?")
+    result = await Runner.run(
+        agent,
+        "What's the population of that state?",
+        session=session,
+    )
+    print(f"Assistant: {result.final_output}")
+    print()
+
+    print("=== Conversation Complete ===")
+    print("Notice how the agent remembered the context from previous turns!")
+    print("Sessions automatically handle conversation history.")
+
+    # Demonstrate the limit parameter - get only the latest 2 items
+    print("\n=== Latest Items Demo ===")
+    latest_items = await session.get_items(limit=2)
+    # print(latest_items)
+    print("Latest 2 items:")
+    for i, msg in enumerate(latest_items, 1):
+        role = msg.get("role", "unknown")
+        content = msg.get("content", "")
+        print(f"  {i}. {role}: {content}")
+
+    print(f"\nFetched {len(latest_items)} out of the total conversation history.")
+
+    # Get all items to show the difference
+    all_items = await session.get_items()
+    # print(all_items)
+    print(f"Total items in session: {len(all_items)}")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/examples/basic/sqlalchemy_session_example.py b/examples/memory/sqlalchemy_session_example.py
similarity index 50%
rename from examples/basic/sqlalchemy_session_example.py
rename to examples/memory/sqlalchemy_session_example.py
index 2aec270f5..84a6c754f 100644
--- a/examples/basic/sqlalchemy_session_example.py
+++ b/examples/memory/sqlalchemy_session_example.py
@@ -20,28 +20,56 @@ async def main():
         create_tables=True,
     )
 
-    print("=== SQLAlchemySession Example ===")
+    print("=== Session Example ===")
     print("The agent will remember previous messages automatically.\n")
 
     # First turn
+    print("First turn:")
     print("User: What city is the Golden Gate Bridge in?")
     result = await Runner.run(
         agent,
         "What city is the Golden Gate Bridge in?",
         session=session,
     )
-    print(f"Assistant: {result.final_output}\n")
+    print(f"Assistant: {result.final_output}")
+    print()
 
     # Second turn - the agent will remember the previous conversation
+    print("Second turn:")
     print("User: What state is it in?")
+    result = await Runner.run(agent, "What state is it in?", session=session)
+    print(f"Assistant: {result.final_output}")
+    print()
+
+    # Third turn - continuing the conversation
+    print("Third turn:")
+    print("User: What's the population of that state?")
     result = await Runner.run(
         agent,
-        "What state is it in?",
+        "What's the population of that state?",
         session=session,
     )
-    print(f"Assistant: {result.final_output}\n")
+    print(f"Assistant: {result.final_output}")
+    print()
 
     print("=== Conversation Complete ===")
+    print("Notice how the agent remembered the context from previous turns!")
+    print("Sessions automatically handle conversation history.")
+
+    # Demonstrate the limit parameter - get only the latest 2 items
+    print("\n=== Latest Items Demo ===")
+    latest_items = await session.get_items(limit=2)
+    print("Latest 2 items:")
+    for i, msg in enumerate(latest_items, 1):
+        role = msg.get("role", "unknown")
+        content = msg.get("content", "")
+        print(f"  {i}. {role}: {content}")
+
+    print(f"\nFetched {len(latest_items)} out of the total conversation history.")
+
+    # Get all items to show the difference
+    all_items = await session.get_items()
+    print(f"Total items in session: {len(all_items)}")
 
 
 if __name__ == "__main__":
diff --git a/examples/basic/session_example.py b/examples/memory/sqlite_session_example.py
similarity index 100%
rename from examples/basic/session_example.py
rename to examples/memory/sqlite_session_example.py
diff --git a/src/agents/__init__.py b/src/agents/__init__.py
index 02830bb29..010b00aa2 100644
--- a/src/agents/__init__.py
+++ b/src/agents/__init__.py
@@ -46,7 +46,7 @@
     TResponseInputItem,
 )
 from .lifecycle import AgentHooks, RunHooks
-from .memory import Session, SQLiteSession
+from .memory import OpenAISession, Session, SessionABC, SQLiteSession
 from .model_settings import ModelSettings
 from .models.interface import Model, ModelProvider, ModelTracing
 from .models.multi_provider import MultiProvider
@@ -221,7 +221,9 @@ def enable_verbose_stdout_logging():
     "RunHooks",
     "AgentHooks",
     "Session",
+    "SessionABC",
     "SQLiteSession",
+    "OpenAISession",
     "RunContextWrapper",
     "TContext",
     "RunErrorDetails",
diff --git a/src/agents/memory/__init__.py b/src/agents/memory/__init__.py
index 059ca57ab..95efea4d3 100644
--- a/src/agents/memory/__init__.py
+++ b/src/agents/memory/__init__.py
@@ -1,3 +1,10 @@
-from .session import Session, SQLiteSession
+from .openai_session import OpenAISession
+from .session import Session, SessionABC
+from .sqlite_session import SQLiteSession
 
-__all__ = ["Session", "SQLiteSession"]
+__all__ = [
+    "Session",
+    "SessionABC",
+    "SQLiteSession",
+    "OpenAISession",
+]
diff --git a/src/agents/memory/openai_session.py b/src/agents/memory/openai_session.py
new file mode 100644
index 000000000..0cc773fb4
--- /dev/null
+++ b/src/agents/memory/openai_session.py
@@ -0,0 +1,79 @@
+from openai import AsyncOpenAI
+
+from agents.models._openai_shared import get_default_openai_client
+
+from ..items import TResponseInputItem
+from .session import SessionABC
+
+
+async def start_openai_session(openai_client: AsyncOpenAI | None = None) -> str:
+    _openai_client = openai_client
+    if openai_client is None:
+        _openai_client = get_default_openai_client() or AsyncOpenAI()
+
+    response = await _openai_client.conversations.create(items=[])
+    return response.id
+
+
+class OpenAISession(SessionABC):
+    def __init__(
+        self,
+        session_id: str | None = None,
+        openai_client: AsyncOpenAI | None = None,
+    ):
+        self.session_id = session_id
+        self.openai_client = openai_client
+        if self.openai_client is None:
+            self.openai_client = get_default_openai_client() or AsyncOpenAI()
+
+    async def _ensure_session_id(self) -> None:
+        if self.session_id is None:
+            self.session_id = await start_openai_session(self.openai_client)
+
+    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
+        await self._ensure_session_id()
+
+        all_items = []
+        if limit is None:
+            async for item in self.openai_client.conversations.items.list(
+                conversation_id=self.session_id,
+                order="asc",
+            ):
+
all_items.append(item.model_dump()) + else: + async for item in self.openai_client.conversations.items.list( + conversation_id=self.session_id, + limit=limit, + order="desc", + ): + all_items.append(item.model_dump()) + if limit is not None and len(all_items) >= limit: + break + all_items.reverse() + + return all_items + + async def add_items(self, items: list[TResponseInputItem]) -> None: + await self._ensure_session_id() + await self.openai_client.conversations.items.create( + conversation_id=self.session_id, + items=items, + ) + + async def pop_item(self) -> TResponseInputItem | None: + await self._ensure_session_id() + items = await self.get_items(limit=1) + if not items: + return None + await self.openai_client.conversations.items.delete( + conversation_id=self.session_id, + item_id=items[0].id, + ) + return items[0] + + async def clear_session(self) -> None: + await self._ensure_session_id() + await self.openai_client.conversations.delete( + conversation_id=self.session_id, + ) + self.session_id = None diff --git a/src/agents/memory/session.py b/src/agents/memory/session.py index 8db0971eb..9c85af6dd 100644 --- a/src/agents/memory/session.py +++ b/src/agents/memory/session.py @@ -1,11 +1,6 @@ from __future__ import annotations -import asyncio -import json -import sqlite3 -import threading from abc import ABC, abstractmethod -from pathlib import Path from typing import TYPE_CHECKING, Protocol, runtime_checkable if TYPE_CHECKING: @@ -102,268 +97,3 @@ async def pop_item(self) -> TResponseInputItem | None: async def clear_session(self) -> None: """Clear all items for this session.""" ... - - -class SQLiteSession(SessionABC): - """SQLite-based implementation of session storage. - - This implementation stores conversation history in a SQLite database. - By default, uses an in-memory database that is lost when the process ends. - For persistent storage, provide a file path. - """ - - def __init__( - self, - session_id: str, - db_path: str | Path = ":memory:", - sessions_table: str = "agent_sessions", - messages_table: str = "agent_messages", - ): - """Initialize the SQLite session. - - Args: - session_id: Unique identifier for the conversation session - db_path: Path to the SQLite database file. Defaults to ':memory:' (in-memory database) - sessions_table: Name of the table to store session metadata. Defaults to - 'agent_sessions' - messages_table: Name of the table to store message data. 
Defaults to 'agent_messages' - """ - self.session_id = session_id - self.db_path = db_path - self.sessions_table = sessions_table - self.messages_table = messages_table - self._local = threading.local() - self._lock = threading.Lock() - - # For in-memory databases, we need a shared connection to avoid thread isolation - # For file databases, we use thread-local connections for better concurrency - self._is_memory_db = str(db_path) == ":memory:" - if self._is_memory_db: - self._shared_connection = sqlite3.connect(":memory:", check_same_thread=False) - self._shared_connection.execute("PRAGMA journal_mode=WAL") - self._init_db_for_connection(self._shared_connection) - else: - # For file databases, initialize the schema once since it persists - init_conn = sqlite3.connect(str(self.db_path), check_same_thread=False) - init_conn.execute("PRAGMA journal_mode=WAL") - self._init_db_for_connection(init_conn) - init_conn.close() - - def _get_connection(self) -> sqlite3.Connection: - """Get a database connection.""" - if self._is_memory_db: - # Use shared connection for in-memory database to avoid thread isolation - return self._shared_connection - else: - # Use thread-local connections for file databases - if not hasattr(self._local, "connection"): - self._local.connection = sqlite3.connect( - str(self.db_path), - check_same_thread=False, - ) - self._local.connection.execute("PRAGMA journal_mode=WAL") - assert isinstance(self._local.connection, sqlite3.Connection), ( - f"Expected sqlite3.Connection, got {type(self._local.connection)}" - ) - return self._local.connection - - def _init_db_for_connection(self, conn: sqlite3.Connection) -> None: - """Initialize the database schema for a specific connection.""" - conn.execute( - f""" - CREATE TABLE IF NOT EXISTS {self.sessions_table} ( - session_id TEXT PRIMARY KEY, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ) - """ - ) - - conn.execute( - f""" - CREATE TABLE IF NOT EXISTS {self.messages_table} ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - session_id TEXT NOT NULL, - message_data TEXT NOT NULL, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (session_id) REFERENCES {self.sessions_table} (session_id) - ON DELETE CASCADE - ) - """ - ) - - conn.execute( - f""" - CREATE INDEX IF NOT EXISTS idx_{self.messages_table}_session_id - ON {self.messages_table} (session_id, created_at) - """ - ) - - conn.commit() - - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: - """Retrieve the conversation history for this session. - - Args: - limit: Maximum number of items to retrieve. If None, retrieves all items. - When specified, returns the latest N items in chronological order. - - Returns: - List of input items representing the conversation history - """ - - def _get_items_sync(): - conn = self._get_connection() - with self._lock if self._is_memory_db else threading.Lock(): - if limit is None: - # Fetch all items in chronological order - cursor = conn.execute( - f""" - SELECT message_data FROM {self.messages_table} - WHERE session_id = ? - ORDER BY created_at ASC - """, - (self.session_id,), - ) - else: - # Fetch the latest N items in chronological order - cursor = conn.execute( - f""" - SELECT message_data FROM {self.messages_table} - WHERE session_id = ? - ORDER BY created_at DESC - LIMIT ? 
- """, - (self.session_id, limit), - ) - - rows = cursor.fetchall() - - # Reverse to get chronological order when using DESC - if limit is not None: - rows = list(reversed(rows)) - - items = [] - for (message_data,) in rows: - try: - item = json.loads(message_data) - items.append(item) - except json.JSONDecodeError: - # Skip invalid JSON entries - continue - - return items - - return await asyncio.to_thread(_get_items_sync) - - async def add_items(self, items: list[TResponseInputItem]) -> None: - """Add new items to the conversation history. - - Args: - items: List of input items to add to the history - """ - if not items: - return - - def _add_items_sync(): - conn = self._get_connection() - - with self._lock if self._is_memory_db else threading.Lock(): - # Ensure session exists - conn.execute( - f""" - INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?) - """, - (self.session_id,), - ) - - # Add items - message_data = [(self.session_id, json.dumps(item)) for item in items] - conn.executemany( - f""" - INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?) - """, - message_data, - ) - - # Update session timestamp - conn.execute( - f""" - UPDATE {self.sessions_table} - SET updated_at = CURRENT_TIMESTAMP - WHERE session_id = ? - """, - (self.session_id,), - ) - - conn.commit() - - await asyncio.to_thread(_add_items_sync) - - async def pop_item(self) -> TResponseInputItem | None: - """Remove and return the most recent item from the session. - - Returns: - The most recent item if it exists, None if the session is empty - """ - - def _pop_item_sync(): - conn = self._get_connection() - with self._lock if self._is_memory_db else threading.Lock(): - # Use DELETE with RETURNING to atomically delete and return the most recent item - cursor = conn.execute( - f""" - DELETE FROM {self.messages_table} - WHERE id = ( - SELECT id FROM {self.messages_table} - WHERE session_id = ? 
- ORDER BY created_at DESC - LIMIT 1 - ) - RETURNING message_data - """, - (self.session_id,), - ) - - result = cursor.fetchone() - conn.commit() - - if result: - message_data = result[0] - try: - item = json.loads(message_data) - return item - except json.JSONDecodeError: - # Return None for corrupted JSON entries (already deleted) - return None - - return None - - return await asyncio.to_thread(_pop_item_sync) - - async def clear_session(self) -> None: - """Clear all items for this session.""" - - def _clear_session_sync(): - conn = self._get_connection() - with self._lock if self._is_memory_db else threading.Lock(): - conn.execute( - f"DELETE FROM {self.messages_table} WHERE session_id = ?", - (self.session_id,), - ) - conn.execute( - f"DELETE FROM {self.sessions_table} WHERE session_id = ?", - (self.session_id,), - ) - conn.commit() - - await asyncio.to_thread(_clear_session_sync) - - def close(self) -> None: - """Close the database connection.""" - if self._is_memory_db: - if hasattr(self, "_shared_connection"): - self._shared_connection.close() - else: - if hasattr(self._local, "connection"): - self._local.connection.close() diff --git a/src/agents/memory/sqlite_session.py b/src/agents/memory/sqlite_session.py new file mode 100644 index 000000000..2c2386ec7 --- /dev/null +++ b/src/agents/memory/sqlite_session.py @@ -0,0 +1,275 @@ +from __future__ import annotations + +import asyncio +import json +import sqlite3 +import threading +from pathlib import Path + +from ..items import TResponseInputItem +from .session import SessionABC + + +class SQLiteSession(SessionABC): + """SQLite-based implementation of session storage. + + This implementation stores conversation history in a SQLite database. + By default, uses an in-memory database that is lost when the process ends. + For persistent storage, provide a file path. + """ + + def __init__( + self, + session_id: str, + db_path: str | Path = ":memory:", + sessions_table: str = "agent_sessions", + messages_table: str = "agent_messages", + ): + """Initialize the SQLite session. + + Args: + session_id: Unique identifier for the conversation session + db_path: Path to the SQLite database file. Defaults to ':memory:' (in-memory database) + sessions_table: Name of the table to store session metadata. Defaults to + 'agent_sessions' + messages_table: Name of the table to store message data. 
Defaults to 'agent_messages' + """ + self.session_id = session_id + self.db_path = db_path + self.sessions_table = sessions_table + self.messages_table = messages_table + self._local = threading.local() + self._lock = threading.Lock() + + # For in-memory databases, we need a shared connection to avoid thread isolation + # For file databases, we use thread-local connections for better concurrency + self._is_memory_db = str(db_path) == ":memory:" + if self._is_memory_db: + self._shared_connection = sqlite3.connect(":memory:", check_same_thread=False) + self._shared_connection.execute("PRAGMA journal_mode=WAL") + self._init_db_for_connection(self._shared_connection) + else: + # For file databases, initialize the schema once since it persists + init_conn = sqlite3.connect(str(self.db_path), check_same_thread=False) + init_conn.execute("PRAGMA journal_mode=WAL") + self._init_db_for_connection(init_conn) + init_conn.close() + + def _get_connection(self) -> sqlite3.Connection: + """Get a database connection.""" + if self._is_memory_db: + # Use shared connection for in-memory database to avoid thread isolation + return self._shared_connection + else: + # Use thread-local connections for file databases + if not hasattr(self._local, "connection"): + self._local.connection = sqlite3.connect( + str(self.db_path), + check_same_thread=False, + ) + self._local.connection.execute("PRAGMA journal_mode=WAL") + assert isinstance(self._local.connection, sqlite3.Connection), ( + f"Expected sqlite3.Connection, got {type(self._local.connection)}" + ) + return self._local.connection + + def _init_db_for_connection(self, conn: sqlite3.Connection) -> None: + """Initialize the database schema for a specific connection.""" + conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.sessions_table} ( + session_id TEXT PRIMARY KEY, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + + conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.messages_table} ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + message_data TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (session_id) REFERENCES {self.sessions_table} (session_id) + ON DELETE CASCADE + ) + """ + ) + + conn.execute( + f""" + CREATE INDEX IF NOT EXISTS idx_{self.messages_table}_session_id + ON {self.messages_table} (session_id, created_at) + """ + ) + + conn.commit() + + async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + """Retrieve the conversation history for this session. + + Args: + limit: Maximum number of items to retrieve. If None, retrieves all items. + When specified, returns the latest N items in chronological order. + + Returns: + List of input items representing the conversation history + """ + + def _get_items_sync(): + conn = self._get_connection() + with self._lock if self._is_memory_db else threading.Lock(): + if limit is None: + # Fetch all items in chronological order + cursor = conn.execute( + f""" + SELECT message_data FROM {self.messages_table} + WHERE session_id = ? + ORDER BY created_at ASC + """, + (self.session_id,), + ) + else: + # Fetch the latest N items in chronological order + cursor = conn.execute( + f""" + SELECT message_data FROM {self.messages_table} + WHERE session_id = ? + ORDER BY created_at DESC + LIMIT ? 
+ """, + (self.session_id, limit), + ) + + rows = cursor.fetchall() + + # Reverse to get chronological order when using DESC + if limit is not None: + rows = list(reversed(rows)) + + items = [] + for (message_data,) in rows: + try: + item = json.loads(message_data) + items.append(item) + except json.JSONDecodeError: + # Skip invalid JSON entries + continue + + return items + + return await asyncio.to_thread(_get_items_sync) + + async def add_items(self, items: list[TResponseInputItem]) -> None: + """Add new items to the conversation history. + + Args: + items: List of input items to add to the history + """ + if not items: + return + + def _add_items_sync(): + conn = self._get_connection() + + with self._lock if self._is_memory_db else threading.Lock(): + # Ensure session exists + conn.execute( + f""" + INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?) + """, + (self.session_id,), + ) + + # Add items + message_data = [(self.session_id, json.dumps(item)) for item in items] + conn.executemany( + f""" + INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?) + """, + message_data, + ) + + # Update session timestamp + conn.execute( + f""" + UPDATE {self.sessions_table} + SET updated_at = CURRENT_TIMESTAMP + WHERE session_id = ? + """, + (self.session_id,), + ) + + conn.commit() + + await asyncio.to_thread(_add_items_sync) + + async def pop_item(self) -> TResponseInputItem | None: + """Remove and return the most recent item from the session. + + Returns: + The most recent item if it exists, None if the session is empty + """ + + def _pop_item_sync(): + conn = self._get_connection() + with self._lock if self._is_memory_db else threading.Lock(): + # Use DELETE with RETURNING to atomically delete and return the most recent item + cursor = conn.execute( + f""" + DELETE FROM {self.messages_table} + WHERE id = ( + SELECT id FROM {self.messages_table} + WHERE session_id = ? 
+ ORDER BY created_at DESC + LIMIT 1 + ) + RETURNING message_data + """, + (self.session_id,), + ) + + result = cursor.fetchone() + conn.commit() + + if result: + message_data = result[0] + try: + item = json.loads(message_data) + return item + except json.JSONDecodeError: + # Return None for corrupted JSON entries (already deleted) + return None + + return None + + return await asyncio.to_thread(_pop_item_sync) + + async def clear_session(self) -> None: + """Clear all items for this session.""" + + def _clear_session_sync(): + conn = self._get_connection() + with self._lock if self._is_memory_db else threading.Lock(): + conn.execute( + f"DELETE FROM {self.messages_table} WHERE session_id = ?", + (self.session_id,), + ) + conn.execute( + f"DELETE FROM {self.sessions_table} WHERE session_id = ?", + (self.session_id,), + ) + conn.commit() + + await asyncio.to_thread(_clear_session_sync) + + def close(self) -> None: + """Close the database connection.""" + if self._is_memory_db: + if hasattr(self, "_shared_connection"): + self._shared_connection.close() + else: + if hasattr(self._local, "connection"): + self._local.connection.close() From 6b5277f2dc012cec3859d56e5133d5a8a5250338 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 28 Aug 2025 19:47:25 +0900 Subject: [PATCH 06/13] fix --- src/agents/memory/openai_session.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/src/agents/memory/openai_session.py b/src/agents/memory/openai_session.py index 0cc773fb4..9d7bb3f14 100644 --- a/src/agents/memory/openai_session.py +++ b/src/agents/memory/openai_session.py @@ -1,3 +1,5 @@ +from typing import Optional + from openai import AsyncOpenAI from agents.models._openai_shared import get_default_openai_client @@ -6,12 +8,12 @@ from .session import SessionABC -async def start_openai_session(openai_client: AsyncOpenAI | None = None) -> str: +async def start_openai_session(openai_client: Optional[AsyncOpenAI] = None) -> str: _openai_client = openai_client if openai_client is None: _openai_client = get_default_openai_client() or AsyncOpenAI() - response = await _openai_client.conversations.create(items=[]) + response = await _openai_client.conversations.create(items=[]) # type: ignore [union-attr] return response.id @@ -19,9 +21,10 @@ class OpenAISession(SessionABC): def __init__( self, session_id: str | None = None, - openai_client: AsyncOpenAI | None = None, + openai_client: Optional[AsyncOpenAI] = None, ): - self.session_id = session_id + # this implementation allows to set this value later + self.session_id = session_id # type: ignore self.openai_client = openai_client if self.openai_client is None: self.openai_client = get_default_openai_client() or AsyncOpenAI() @@ -35,27 +38,29 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: all_items = [] if limit is None: - async for item in self.openai_client.conversations.items.list( + async for item in self.openai_client.conversations.items.list( # type: ignore [union-attr] conversation_id=self.session_id, order="asc", ): + # calling model_dump() to make this serializable all_items.append(item.model_dump()) else: - async for item in self.openai_client.conversations.items.list( + async for item in self.openai_client.conversations.items.list( # type: ignore [union-attr] conversation_id=self.session_id, limit=limit, order="desc", ): + # calling model_dump() to make this serializable all_items.append(item.model_dump()) if limit is not None and len(all_items) >= limit: break 
all_items.reverse() - return all_items + return all_items # type: ignore async def add_items(self, items: list[TResponseInputItem]) -> None: await self._ensure_session_id() - await self.openai_client.conversations.items.create( + await self.openai_client.conversations.items.create( # type: ignore [union-attr] conversation_id=self.session_id, items=items, ) @@ -65,15 +70,15 @@ async def pop_item(self) -> TResponseInputItem | None: items = await self.get_items(limit=1) if not items: return None - await self.openai_client.conversations.items.delete( + await self.openai_client.conversations.items.delete( # type: ignore [union-attr] conversation_id=self.session_id, - item_id=items[0].id, + item_id=str(items[0]["id"]), # type: ignore ) return items[0] async def clear_session(self) -> None: await self._ensure_session_id() - await self.openai_client.conversations.delete( + await self.openai_client.conversations.delete( # type: ignore [union-attr] conversation_id=self.session_id, ) - self.session_id = None + self.session_id = None # type: ignore From ab847ff64e6ad53da3ca9c8b8eeeeea3e1027c5e Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 28 Aug 2025 19:49:04 +0900 Subject: [PATCH 07/13] fix --- src/agents/memory/openai_session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agents/memory/openai_session.py b/src/agents/memory/openai_session.py index 9d7bb3f14..521019aba 100644 --- a/src/agents/memory/openai_session.py +++ b/src/agents/memory/openai_session.py @@ -20,7 +20,7 @@ async def start_openai_session(openai_client: Optional[AsyncOpenAI] = None) -> s class OpenAISession(SessionABC): def __init__( self, - session_id: str | None = None, + session_id: Optional[str] = None, openai_client: Optional[AsyncOpenAI] = None, ): # this implementation allows to set this value later From 8577c17f9e276ddeb98b9db73ba1afd5167e48ce Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 28 Aug 2025 19:51:20 +0900 Subject: [PATCH 08/13] fix --- src/agents/memory/openai_session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agents/memory/openai_session.py b/src/agents/memory/openai_session.py index 521019aba..236c6a882 100644 --- a/src/agents/memory/openai_session.py +++ b/src/agents/memory/openai_session.py @@ -33,7 +33,7 @@ async def _ensure_session_id(self) -> None: if self.session_id is None: self.session_id = await start_openai_session(self.openai_client) - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items(self, limit: Optional[int] = None) -> list[TResponseInputItem]: await self._ensure_session_id() all_items = [] From 038ea0059bc580373aaf2467549d649876d7e22e Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 28 Aug 2025 19:53:16 +0900 Subject: [PATCH 09/13] fix --- src/agents/memory/openai_session.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/agents/memory/openai_session.py b/src/agents/memory/openai_session.py index 236c6a882..8e47c5904 100644 --- a/src/agents/memory/openai_session.py +++ b/src/agents/memory/openai_session.py @@ -1,3 +1,4 @@ +from token import OP from typing import Optional from openai import AsyncOpenAI @@ -65,7 +66,7 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: items=items, ) - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item(self) -> Optional[TResponseInputItem]: await self._ensure_session_id() items = await self.get_items(limit=1) if not items: From 1f24ff51e9fb200c7bcb9dce8e9089905a512dc5 
From: Kazuhiro Sera
Date: Thu, 28 Aug 2025 19:55:29 +0900
Subject: [PATCH 10/13] fix

---
 src/agents/memory/openai_session.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/agents/memory/openai_session.py b/src/agents/memory/openai_session.py
index 8e47c5904..d007398f3 100644
--- a/src/agents/memory/openai_session.py
+++ b/src/agents/memory/openai_session.py
@@ -1,4 +1,3 @@
-from token import OP
 from typing import Optional

 from openai import AsyncOpenAI

From e9fcdd06a0806fc4ce710ef18843301f7a070bf0 Mon Sep 17 00:00:00 2001
From: Kazuhiro Sera
Date: Fri, 29 Aug 2025 09:13:41 +0900
Subject: [PATCH 11/13] rename

---
 examples/memory/openai_session_example.py                  | 4 ++--
 src/agents/__init__.py                                     | 4 ++--
 src/agents/memory/__init__.py                              | 4 ++--
 .../{openai_session.py => openai_conversations_session.py} | 7 ++++---
 4 files changed, 10 insertions(+), 9 deletions(-)
 rename src/agents/memory/{openai_session.py => openai_conversations_session.py} (92%)

diff --git a/examples/memory/openai_session_example.py b/examples/memory/openai_session_example.py
index f2a0cfff6..9254195b3 100644
--- a/examples/memory/openai_session_example.py
+++ b/examples/memory/openai_session_example.py
@@ -7,7 +7,7 @@

 import asyncio

-from agents import Agent, OpenAISession, Runner
+from agents import Agent, OpenAIConversationsSession, Runner


 async def main():
@@ -18,7 +18,7 @@ async def main():
     )

     # Create a session instance that will persist across runs
-    session = OpenAISession()
+    session = OpenAIConversationsSession()

     print("=== Session Example ===")
     print("The agent will remember previous messages automatically.\n")
diff --git a/src/agents/__init__.py b/src/agents/__init__.py
index 010b00aa2..3a8260f29 100644
--- a/src/agents/__init__.py
+++ b/src/agents/__init__.py
@@ -46,7 +46,7 @@
     TResponseInputItem,
 )
 from .lifecycle import AgentHooks, RunHooks
-from .memory import OpenAISession, Session, SessionABC, SQLiteSession
+from .memory import OpenAIConversationsSession, Session, SessionABC, SQLiteSession
 from .model_settings import ModelSettings
 from .models.interface import Model, ModelProvider, ModelTracing
 from .models.multi_provider import MultiProvider
@@ -223,7 +223,7 @@ def enable_verbose_stdout_logging():
     "Session",
     "SessionABC",
     "SQLiteSession",
-    "OpenAISession",
+    "OpenAIConversationsSession",
     "RunContextWrapper",
     "TContext",
     "RunErrorDetails",
diff --git a/src/agents/memory/__init__.py b/src/agents/memory/__init__.py
index 95efea4d3..eeb2ace5d 100644
--- a/src/agents/memory/__init__.py
+++ b/src/agents/memory/__init__.py
@@ -1,4 +1,4 @@
-from .openai_session import OpenAISession
+from .openai_conversations_session import OpenAIConversationsSession
 from .session import Session, SessionABC
 from .sqlite_session import SQLiteSession

@@ -6,5 +6,5 @@
     "Session",
     "SessionABC",
     "SQLiteSession",
-    "OpenAISession",
+    "OpenAIConversationsSession",
 ]
diff --git a/src/agents/memory/openai_session.py b/src/agents/memory/openai_conversations_session.py
similarity index 92%
rename from src/agents/memory/openai_session.py
rename to src/agents/memory/openai_conversations_session.py
index d007398f3..5500d59c1 100644
--- a/src/agents/memory/openai_session.py
+++ b/src/agents/memory/openai_conversations_session.py
@@ -8,7 +8,7 @@
 from .session import SessionABC


-async def start_openai_session(openai_client: Optional[AsyncOpenAI] = None) -> str:
+async def start_openai_conversations_session(openai_client: Optional[AsyncOpenAI] = None) -> str:
     _openai_client = openai_client
     if openai_client is None:
         _openai_client = get_default_openai_client() or AsyncOpenAI()
@@ -17,9 +17,10 @@ async def start_openai_session(openai_client: Optional[AsyncOpenAI] = None) -> s
     return response.id


-class OpenAISession(SessionABC):
+class OpenAIConversationsSession(SessionABC):
     def __init__(
         self,
+        *,
         session_id: Optional[str] = None,
         openai_client: Optional[AsyncOpenAI] = None,
     ):
@@ -31,7 +32,7 @@ def __init__(

     async def _ensure_session_id(self) -> None:
         if self.session_id is None:
-            self.session_id = await start_openai_session(self.openai_client)
+            self.session_id = await start_openai_conversations_session(self.openai_client)

     async def get_items(self, limit: Optional[int] = None) -> list[TResponseInputItem]:
         await self._ensure_session_id()

From 0f1c99da62803d600286b27b5c38d9ea15e004a9 Mon Sep 17 00:00:00 2001
From: Kazuhiro Sera
Date: Fri, 29 Aug 2025 09:24:40 +0900
Subject: [PATCH 12/13] fix type ignores

---
 .../memory/openai_conversations_session.py | 39 +++++++++++--------
 1 file changed, 23 insertions(+), 16 deletions(-)

diff --git a/src/agents/memory/openai_conversations_session.py b/src/agents/memory/openai_conversations_session.py
index 5500d59c1..3786b9e57 100644
--- a/src/agents/memory/openai_conversations_session.py
+++ b/src/agents/memory/openai_conversations_session.py
@@ -9,14 +9,19 @@


 async def start_openai_conversations_session(openai_client: Optional[AsyncOpenAI] = None) -> str:
-    _openai_client = openai_client
+    _maybe_openai_client = openai_client
     if openai_client is None:
-        _openai_client = get_default_openai_client() or AsyncOpenAI()
+        _maybe_openai_client = get_default_openai_client() or AsyncOpenAI()
+    # this can never be None here
+    _openai_client: AsyncOpenAI = _maybe_openai_client  # type: ignore [assignment]

-    response = await _openai_client.conversations.create(items=[])  # type: ignore [union-attr]
+    response = await _openai_client.conversations.create(items=[])
     return response.id


+_EMPTY_SESSION_ID = ""
+
+
 class OpenAIConversationsSession(SessionABC):
     def __init__(
         self,
@@ -25,13 +30,15 @@ def __init__(
         openai_client: Optional[AsyncOpenAI] = None,
     ):
         # this implementation allows to set this value later
-        self.session_id = session_id  # type: ignore
-        self.openai_client = openai_client
-        if self.openai_client is None:
-            self.openai_client = get_default_openai_client() or AsyncOpenAI()
+        self.session_id = session_id or _EMPTY_SESSION_ID
+        _openai_client = openai_client
+        if _openai_client is None:
+            _openai_client = get_default_openai_client() or AsyncOpenAI()
+        # this can never be None here
+        self.openai_client: AsyncOpenAI = _openai_client

     async def _ensure_session_id(self) -> None:
-        if self.session_id is None:
+        if self.session_id == _EMPTY_SESSION_ID:
             self.session_id = await start_openai_conversations_session(self.openai_client)

     async def get_items(self, limit: Optional[int] = None) -> list[TResponseInputItem]:
@@ -39,14 +46,14 @@ async def get_items(self, limit: Optional[int] = None) -> list[TResponseInputIte
         all_items = []
         if limit is None:
-            async for item in self.openai_client.conversations.items.list(  # type: ignore [union-attr]
+            async for item in self.openai_client.conversations.items.list(
                 conversation_id=self.session_id,
                 order="asc",
             ):
                 # calling model_dump() to make this serializable
                 all_items.append(item.model_dump())
         else:
-            async for item in self.openai_client.conversations.items.list(  # type: ignore [union-attr]
+            async for item in self.openai_client.conversations.items.list(
                 conversation_id=self.session_id,
                 limit=limit,
                 order="desc",
@@ -61,7 +68,7 @@ async def get_items(self, limit: Optional[int] = None) -> list[TResponseInputIte

     async def add_items(self, items: list[TResponseInputItem]) -> None:
         await self._ensure_session_id()
-        await self.openai_client.conversations.items.create(  # type: ignore [union-attr]
+        await self.openai_client.conversations.items.create(
             conversation_id=self.session_id,
             items=items,
         )
@@ -71,15 +78,15 @@ async def pop_item(self) -> Optional[TResponseInputItem]:
         items = await self.get_items(limit=1)
         if not items:
             return None
-        await self.openai_client.conversations.items.delete(  # type: ignore [union-attr]
-            conversation_id=self.session_id,
-            item_id=str(items[0]["id"]),  # type: ignore
+        item_id: str = str(items[0]["id"])  # type: ignore [typeddict-item]
+        await self.openai_client.conversations.items.delete(
+            conversation_id=self.session_id, item_id=item_id
         )
         return items[0]

     async def clear_session(self) -> None:
         await self._ensure_session_id()
-        await self.openai_client.conversations.delete(  # type: ignore [union-attr]
+        await self.openai_client.conversations.delete(
             conversation_id=self.session_id,
         )
-        self.session_id = None  # type: ignore
+        self.session_id = _EMPTY_SESSION_ID

From d9cd39b622438258a4c1877c6c7cbfc63adf6eb3 Mon Sep 17 00:00:00 2001
From: Kazuhiro Sera
Date: Fri, 29 Aug 2025 10:36:59 +0900
Subject: [PATCH 13/13] fix

---
 src/agents/memory/openai_conversations_session.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/agents/memory/openai_conversations_session.py b/src/agents/memory/openai_conversations_session.py
index 3786b9e57..a4a373b58 100644
--- a/src/agents/memory/openai_conversations_session.py
+++ b/src/agents/memory/openai_conversations_session.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from __future__ import annotations

 from openai import AsyncOpenAI

@@ -8,7 +8,7 @@
 from .session import SessionABC


-async def start_openai_conversations_session(openai_client: Optional[AsyncOpenAI] = None) -> str:
+async def start_openai_conversations_session(openai_client: AsyncOpenAI | None = None) -> str:
     _maybe_openai_client = openai_client
     if openai_client is None:
         _maybe_openai_client = get_default_openai_client() or AsyncOpenAI()
@@ -26,8 +26,8 @@ class OpenAIConversationsSession(SessionABC):
     def __init__(
         self,
         *,
-        session_id: Optional[str] = None,
-        openai_client: Optional[AsyncOpenAI] = None,
+        session_id: str | None = None,
+        openai_client: AsyncOpenAI | None = None,
     ):
         # this implementation allows to set this value later
         self.session_id = session_id or _EMPTY_SESSION_ID
@@ -41,7 +41,7 @@ async def _ensure_session_id(self) -> None:
         if self.session_id == _EMPTY_SESSION_ID:
             self.session_id = await start_openai_conversations_session(self.openai_client)

-    async def get_items(self, limit: Optional[int] = None) -> list[TResponseInputItem]:
+    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
         await self._ensure_session_id()

         all_items = []
@@ -73,7 +73,7 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
             items=items,
         )

-    async def pop_item(self) -> Optional[TResponseInputItem]:
+    async def pop_item(self) -> TResponseInputItem | None:
         await self._ensure_session_id()
         items = await self.get_items(limit=1)
         if not items:
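
With the full series applied, OpenAIConversationsSession can back a Runner session or be driven through the Session protocol directly. The following is a minimal usage sketch, not part of the patches: it assumes an OPENAI_API_KEY is configured, that Runner.run(..., session=...) accepts the session the way examples/memory/openai_session_example.py uses it, and that the agent instructions and prompts are placeholders.

import asyncio

from agents import Agent, OpenAIConversationsSession, Runner


async def main():
    agent = Agent(name="Assistant", instructions="Reply very concisely.")

    # No session_id passed: _ensure_session_id() creates the conversation
    # lazily on first use via start_openai_conversations_session().
    session = OpenAIConversationsSession()

    result = await Runner.run(agent, "Hi, remember the number 42.", session=session)
    print(result.final_output)

    # The second turn reads the first turn's items from the stored conversation.
    result = await Runner.run(agent, "What number did I ask you to remember?", session=session)
    print(result.final_output)

    # The Session protocol can also be driven directly.
    await session.add_items([{"role": "user", "content": "one more note"}])
    print(await session.get_items(limit=1))  # newest items, returned in ascending order
    print(await session.pop_item())  # removes and returns the newest item, or None
    await session.clear_session()  # deletes the conversation and resets session_id


if __name__ == "__main__":
    asyncio.run(main())

Note that pop_item targets the most recent item: get_items(limit=1) lists with order="desc" and then reverses, so items[0] is the newest item in the conversation, which is the right target for undo-style flows.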