Skip to content
Open
66 changes: 58 additions & 8 deletions openhands-sdk/openhands/sdk/utils/async_executor.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
"""Reusable async-to-sync execution utility."""

import asyncio
import concurrent.futures
import inspect
import threading
from collections.abc import Callable
import time
from collections.abc import Callable, Coroutine
from typing import Any

from openhands.sdk.logger import get_logger


logger = get_logger(__name__)


class AsyncExecutor:
"""
Expand All @@ -21,12 +28,26 @@ def __init__(self):
self._loop: asyncio.AbstractEventLoop | None = None
self._thread: threading.Thread | None = None
self._lock = threading.Lock()
self._shutdown = threading.Event()

def _ensure_loop(self) -> asyncio.AbstractEventLoop:
def _safe_submit_on_loop(self, coro: Coroutine) -> concurrent.futures.Future:
"""Ensure the background event loop is running."""
with self._lock:
if self._shutdown.is_set():
raise RuntimeError("AsyncExecutor has been shut down")

if self._loop is not None:
return self._loop
if self._loop.is_running():
return asyncio.run_coroutine_threadsafe(coro, self._loop)

logger.warning(
"The loop is not empty, but it is not in a running state. "
"Under normal circumstances, this should not happen."
)
try:
self._loop.close()
except RuntimeError as e:
logger.warning(f"Failed to close inactive loop: {e}")

loop = asyncio.new_event_loop()

Expand All @@ -39,15 +60,22 @@ def _runner():

# Wait for loop to start
while not loop.is_running():
pass
time.sleep(0.01)

self._loop = loop
self._thread = t
return loop
return asyncio.run_coroutine_threadsafe(coro, self._loop)

def _shutdown_loop(self) -> None:
"""Shutdown the background event loop."""
if self._shutdown.is_set():
logger.info("AsyncExecutor has been shut down")
return

with self._lock:
if self._shutdown.is_set():
return
self._shutdown.set()
loop, t = self._loop, self._thread
self._loop = None
self._thread = None
Expand All @@ -59,6 +87,20 @@ def _shutdown_loop(self) -> None:
pass
if t and t.is_alive():
t.join(timeout=1.0)
if t.is_alive():
logger.warning("AsyncExecutor thread did not terminate gracefully")

if loop and not loop.is_closed():
try:
if loop.is_running():
tasks = asyncio.all_tasks(loop)
for task in tasks:
if not task.done():
task.cancel()

loop.close()
except RuntimeError as e:
logger.warning(f"Failed to close event loop: {e}")

def run_async(
self,
Expand All @@ -83,16 +125,24 @@ def run_async(
TypeError: If awaitable_or_fn is not a coroutine or async function
asyncio.TimeoutError: If the operation times out
"""
if self._shutdown.is_set():
raise RuntimeError("AsyncExecutor has been shut down")
if inspect.iscoroutine(awaitable_or_fn):
coro = awaitable_or_fn
elif inspect.iscoroutinefunction(awaitable_or_fn):
coro = awaitable_or_fn(*args, **kwargs)
else:
raise TypeError("run_async expects a coroutine or async function")

loop = self._ensure_loop()
fut = asyncio.run_coroutine_threadsafe(coro, loop)
return fut.result(timeout)
fut = self._safe_submit_on_loop(coro)

try:
return fut.result(timeout)
except TimeoutError:
fut.cancel()
raise
except concurrent.futures.CancelledError:
raise

def close(self):
"""Close the async executor and cleanup resources."""
Expand Down
Loading