11"""Initialize Temporal OpenAI Agents overrides."""
22
3+ import dataclasses
34from contextlib import asynccontextmanager , contextmanager
45from datetime import timedelta
56from typing import AsyncIterator , Callable , Optional , Union
4344)
4445from temporalio .converter import (
4546 DataConverter ,
47+ DefaultPayloadConverter ,
4648)
4749from temporalio .worker import (
4850 Replayer ,
@@ -148,8 +150,11 @@ def stream_response(
148150 raise NotImplementedError ()
149151
150152
151- class _OpenAIPayloadConverter (PydanticPayloadConverter ):
153+ class OpenAIPayloadConverter (PydanticPayloadConverter ):
154+ """PayloadConverter for OpenAI agents."""
155+
152156 def __init__ (self ) -> None :
157+ """Initialize a payload converter."""
153158 super ().__init__ (ToJsonOptions (exclude_unset = True ))
154159
155160
@@ -250,6 +255,20 @@ def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None:
250255 """Set the next worker plugin"""
251256 self .next_worker_plugin = next
252257
258+ @staticmethod
259+ def _data_converter (converter : Optional [DataConverter ]) -> DataConverter :
260+ if converter is None :
261+ return DataConverter (payload_converter_class = OpenAIPayloadConverter )
262+ elif converter .payload_converter_class is DefaultPayloadConverter :
263+ return dataclasses .replace (
264+ converter , payload_converter_class = OpenAIPayloadConverter
265+ )
266+ elif not isinstance (converter .payload_converter , OpenAIPayloadConverter ):
267+ raise ValueError (
268+ "The payload converter must be of type OpenAIPayloadConverter."
269+ )
270+ return converter
271+
253272 def configure_client (self , config : ClientConfig ) -> ClientConfig :
254273 """Configure the Temporal client for OpenAI agents integration.
255274
@@ -262,9 +281,7 @@ def configure_client(self, config: ClientConfig) -> ClientConfig:
262281 Returns:
263282 The modified client configuration.
264283 """
265- config ["data_converter" ] = DataConverter (
266- payload_converter_class = _OpenAIPayloadConverter
267- )
284+ config ["data_converter" ] = self ._data_converter (config ["data_converter" ])
268285 return self .next_client_plugin .configure_client (config )
269286
270287 def configure_worker (self , config : WorkerConfig ) -> WorkerConfig :
@@ -310,9 +327,7 @@ def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig:
310327 config ["interceptors" ] = list (config .get ("interceptors" ) or []) + [
311328 OpenAIAgentsTracingInterceptor ()
312329 ]
313- config ["data_converter" ] = DataConverter (
314- payload_converter_class = _OpenAIPayloadConverter
315- )
330+ config ["data_converter" ] = self ._data_converter (config .get ("data_converter" ))
316331 return self .next_worker_plugin .configure_replayer (config )
317332
318333 @asynccontextmanager
0 commit comments