From 1e59892a704d9da9a9911fcecd46073dd1b7838a Mon Sep 17 00:00:00 2001
From: Tristan Rice
Date: Wed, 5 Feb 2025 11:45:15 -0800
Subject: [PATCH] MonitoredQueue: fail fast when subprocess exits

---
 torchft/multiprocessing.py      | 91 ++++++++++++++++
 torchft/multiprocessing_test.py | 48 +++++++++
 torchft/process_group.py        | 178 +++++++++++++++++---------------
 torchft/process_group_test.py   |  4 +-
 4 files changed, 235 insertions(+), 86 deletions(-)
 create mode 100644 torchft/multiprocessing.py
 create mode 100644 torchft/multiprocessing_test.py

diff --git a/torchft/multiprocessing.py b/torchft/multiprocessing.py
new file mode 100644
index 00000000..273e820e
--- /dev/null
+++ b/torchft/multiprocessing.py
@@ -0,0 +1,91 @@
+import queue
+import time
+from datetime import timedelta
+from typing import Union
+
+import torch.multiprocessing as mp
+
+
+class _MonitoredQueue:
+    def __init__(
+        self,
+        p: mp.Process,
+        q: mp.Queue,
+        poll_interval: timedelta = timedelta(seconds=1),
+    ) -> None:
+        """
+        Args:
+            p: process to monitor
+            q: queue to monitor
+            poll_interval: interval to poll the Process health when calling get/put
+        """
+        self._p = p
+        self._q = q
+        self._poll_interval_s: float = poll_interval.total_seconds()
+
+    def get(self, timeout: Union[float, timedelta]) -> object:
+        """
+        Get an item from the queue. If the process is not alive, raise RuntimeError.
+        If the queue is empty, wait for up to timeout seconds for an item to be
+        available. If no item is available after timeout seconds, raise TimeoutError.
+
+        Args:
+            timeout: timeout in seconds
+        """
+
+        if isinstance(timeout, timedelta):
+            timeout = timeout.total_seconds()
+
+        start = time.perf_counter()
+        while True:
+            elapsed = time.perf_counter() - start
+            if elapsed > timeout:
+                raise TimeoutError(f"queue.get() timed out after {timeout} seconds")
+            if not self._p.is_alive():
+                raise RuntimeError(f"process is not alive {self._p.exitcode}")
+
+            try:
+                v = self._q.get(timeout=self._poll_interval_s)
+                break
+            except queue.Empty:
+                continue
+
+        if isinstance(v, Exception):
+            raise v
+        return v
+
+    def put(self, obj: object, timeout: Union[float, timedelta]) -> None:
+        """
+        Put an item into the queue. If the process is not alive, raise RuntimeError.
+        If the queue is full, wait for up to timeout seconds for space to become
+        available. If it is still full after timeout seconds, raise TimeoutError.
+
+        If an exception is put into the queue, it will be raised when calling get().
+ + Args: + obj: object to put into the queue + timeout: timeout in seconds + """ + if isinstance(timeout, timedelta): + timeout = timeout.total_seconds() + + start = time.perf_counter() + while True: + elapsed = time.perf_counter() - start + if elapsed > timeout: + raise TimeoutError(f"queue.put() timed out after {timeout} seconds") + if not self._p.is_alive(): + raise RuntimeError(f"process is not alive {self._p.exitcode}") + + try: + self._q.put(obj, timeout=self._poll_interval_s) + break + except queue.Full: + continue + + def close(self) -> None: + self._q.close() + + def closed(self) -> bool: + # pyre-ignore[16]: no attribute _closed + return self._q._closed diff --git a/torchft/multiprocessing_test.py b/torchft/multiprocessing_test.py new file mode 100644 index 00000000..47459e2f --- /dev/null +++ b/torchft/multiprocessing_test.py @@ -0,0 +1,48 @@ +from unittest import TestCase + +import torch.multiprocessing as mp + +from torchft.multiprocessing import _MonitoredQueue + + +def queue_get(q: mp.Queue) -> None: + q.get() + + +def queue_put(q: mp.Queue) -> None: + q.put(1) + + +class MultiprocessingTest(TestCase): + def test_monitored_queue_put(self) -> None: + ctx = mp.get_context("fork") + q = ctx.Queue(maxsize=1) + p = ctx.Process(target=queue_get, args=(q,), daemon=True) + p.start() + + mq = _MonitoredQueue(p, q) + mq.put(1, timeout=10) + mq.put(1, timeout=10) + with self.assertRaisesRegex(RuntimeError, "process is not alive 0"): + mq.put(1, timeout=10) + + with self.assertRaisesRegex(TimeoutError, "timed out after 0.0 seconds"): + mq.put(1, timeout=0.0) + + mq.close() + + def test_monitored_queue_get(self) -> None: + ctx = mp.get_context("fork") + q = ctx.Queue(maxsize=1) + p = ctx.Process(target=queue_put, args=(q,), daemon=True) + p.start() + + mq = _MonitoredQueue(p, q) + self.assertEqual(mq.get(timeout=10), 1) + with self.assertRaisesRegex(RuntimeError, "process is not alive 0"): + mq.get(timeout=10) + + with self.assertRaisesRegex(TimeoutError, "timed out after 0.0 seconds"): + mq.get(timeout=0.0) + + mq.close() diff --git a/torchft/process_group.py b/torchft/process_group.py index 4790352e..540633b3 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -17,7 +17,6 @@ """ import logging -import queue import threading from contextlib import contextmanager, nullcontext from dataclasses import dataclass @@ -63,6 +62,8 @@ from torch.futures import Future from torch.utils._pytree import tree_any +from torchft.multiprocessing import _MonitoredQueue + if TYPE_CHECKING: from torchft.manager import Manager @@ -77,28 +78,6 @@ T = TypeVar("T") -def _get(q: mp.Queue, timeout: Union[float, timedelta]) -> object: - """ - Gets an item from a queue with a timeout. If the timeout is exceeded then - a TimeoutError is raised. - - If an exception is returned from the queue then it is raised. 
- - Args: - q: queue to get from - timeout: timeout in seconds - """ - if isinstance(timeout, timedelta): - timeout = timeout.total_seconds() - try: - v = q.get(timeout=timeout) - except queue.Empty as e: - raise TimeoutError(f"queue.get() timed out after {timeout} seconds") from e - if isinstance(v, Exception): - raise v - return v - - def create_store_client(store_addr: str) -> Store: """ Creates a PrefixStore(TCPStore(...)) client from an address in the format: @@ -573,31 +552,15 @@ class _BabyWork(Work): def __init__( self, pg: "ProcessGroupBaby", - tx: mp.Queue, - rx: mp.Queue, op_id: int, - timeout: float, ) -> None: super().__init__() self._pg = pg - self._tx = tx - self._rx = rx self._op_id = op_id - self._timeout = timeout def wait(self, timeout: Optional[timedelta] = None) -> bool: - self._pg._assert_alive() - - self._tx.put(("wait", self._op_id), timeout=self._timeout) - op_id, event = cast( - Tuple[int, Optional[torch.cuda.Event]], - _get(self._rx, timeout or self._timeout), - ) - assert op_id == self._op_id - if event is not None: - event.wait() - return True + return self._pg._wait(self._op_id, timeout) def synchronize(self) -> None: # TODO: No one seems to use this and NCCL wait already only waits the @@ -609,7 +572,7 @@ def get_future(self) -> Future[object]: return self._pg._get_future(self._op_id) def __del__(self) -> None: - self._tx.put(("del", self._op_id), timeout=self._timeout) + self._pg._del(self._op_id) def _is_any_cuda(obj: object) -> bool: @@ -649,9 +612,9 @@ def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None: self._world_size = -1 self._p: Optional[mp.Process] = None - self._tx: Optional[mp.Queue] = None - self._rx: Optional[mp.Queue] = None - self._future_queue: Optional[mp.Queue] = None + self._tx: Optional[_MonitoredQueue] = None + self._rx: Optional[_MonitoredQueue] = None + self._future_queue: Optional[_MonitoredQueue] = None self._future_thread: Optional[threading.Thread] = None self._futures: Dict[int, Future[object]] = {} self._futures_lock = threading.Lock() @@ -661,60 +624,80 @@ def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None: self._timeout: float = timeout - def configure(self, store_addr: str, rank: int, world_size: int) -> None: - if self._p is not None: - self._p.kill() + def shutdown(self) -> None: + """ + Shutdown the process group. This will kill the underlying process and + close all queues. - self._world_size = world_size + This is a no-op if the process group is already shutdown. + + ProcessGroup can be reconfigured after shutdown. + """ if self._tx is not None: self._tx.close() if self._rx is not None: self._rx.close() - if self._future_queue is not None: + + future_queue = self._future_queue + if future_queue is not None: # wait for the future thread to exit and then close the queue - self._future_queue.put(_QUEUE_CLOSE) - assert self._future_thread is not None - self._future_thread.join(timeout=10.0) - # pyre-ignore[16]: optional value is checked above - if self._future_thread.is_alive(): + future_queue.put(_QUEUE_CLOSE, timeout=timedelta(seconds=10.0)) + + future_thread = self._future_thread + assert future_thread is not None + future_thread.join(timeout=10.0) + if future_thread.is_alive(): raise RuntimeError("future thread did not exit") - # pyre-ignore[16]: optional value is checked above - self._future_queue.close() + + future_queue.close() + + # Kill after closing queues to avoid log spam. 
+        if self._p is not None:
+            self._p.kill()
+
+    def configure(self, store_addr: str, rank: int, world_size: int) -> None:
+        self._world_size = world_size
+
+        self.shutdown()
 
         ctx = mp.get_context("spawn")
-        self._tx = ctx.Queue()
-        self._rx = rx = ctx.Queue()
+        tx = ctx.Queue()
+        rx = ctx.Queue()
+        future_queue = ctx.Queue()
+
+        self._p = p = ctx.Process(
+            target=self._worker,
+            args=(
+                store_addr,
+                rank,
+                world_size,
+                tx,
+                rx,
+                future_queue,
+            ),
+            daemon=True,
+        )
+        p.start()
+
+        self._tx = tx = _MonitoredQueue(p, tx)
+        self._rx = rx = _MonitoredQueue(p, rx)
+        self._future_queue = future_queue = _MonitoredQueue(p, future_queue)
 
         # futures need thread to fire callbacks
-        self._future_queue = ctx.Queue()
         # this lock needs to be held when manipulating _futures
         self._futures_lock = threading.Lock()
         self._futures = {}
         self._future_thread = threading.Thread(
             target=self._future_handler,
-            args=(self._future_queue,),
+            args=(future_queue,),
             daemon=True,
         )
         self._future_thread.start()
 
-        self._p = ctx.Process(
-            target=self._worker,
-            args=(
-                store_addr,
-                rank,
-                world_size,
-                self._tx,
-                self._rx,
-                self._future_queue,
-            ),
-            daemon=True,
-        )
-        self._p.start()
-
         # fetch the status of the PG init
-        # if an exception was returned _get will throw
-        assert _get(rx, self._timeout) is None
+        # if an exception was returned get will throw
+        assert rx.get(self._timeout) is None
 
     @classmethod
     def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
@@ -829,17 +812,21 @@ def callback(fut: Future[object]) -> None:
                     raise ValueError(f"unknown cmd: {cmd}")
 
             except Exception as e:
-                logger.exception("worker errored")
+                logger.exception(f"worker errored: {e}")
                 tx.put(e)
                 raise
 
-    def _future_handler(self, future_queue: mp.Queue) -> None:
+    def _future_handler(self, future_queue: _MonitoredQueue) -> None:
         try:
             while True:
-                cmd = future_queue.get()
+                try:
+                    # timeout doesn't really matter here
+                    cmd = future_queue.get(timeout=timedelta(seconds=10.0))
+                except TimeoutError:
+                    continue
                 if cmd == _QUEUE_CLOSE:
                     break
-                op_id, mode, data = cmd
+                op_id, mode, data = cast(Tuple[int, str, object], cmd)
                 with self._futures_lock:
                     fut = self._futures[op_id]
                     del self._futures[op_id]
@@ -862,10 +849,33 @@ def _get_future(self, op_id: int) -> Future[object]:
             self._tx.put(("future", op_id), timeout=self._timeout)
 
         assert self._rx is not None
-        assert _get(self._rx, self._timeout) == op_id
+        assert self._rx.get(self._timeout) == op_id
         # TODO: return correct tensor instead of None
         return fut
 
+    def _wait(self, op_id: int, timeout: Optional[timedelta] = None) -> bool:
+        self._assert_alive()
+
+        assert self._tx is not None
+        self._tx.put(("wait", op_id), timeout=self._timeout)
+
+        assert self._rx is not None
+        resp_op_id, event = cast(
+            Tuple[int, Optional[torch.cuda.Event]],
+            self._rx.get(timeout or self._timeout),
+        )
+        assert resp_op_id == op_id
+        if event is not None:
+            event.wait()
+
+        return True
+
+    def _del(self, op_id: int) -> None:
+        self._assert_alive()
+
+        assert self._tx is not None
+        self._tx.put(("del", op_id), timeout=self._timeout)
+
     def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
         self._assert_alive()
 
@@ -899,10 +909,10 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
                 timeout=self._timeout,
            )
 
-            op_id = _get(rx, self._timeout)
+            op_id = rx.get(self._timeout)
             assert isinstance(op_id, int), f"invalid return {op_id}"
 
-        return _BabyWork(pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout)
+        return _BabyWork(pg=self, op_id=op_id)
 
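
Note: the reworked configure()/shutdown() path above is what allows a baby process group to be torn down and brought back up repeatedly. The snippet below is a minimal sketch of that lifecycle, outside this diff, with illustrative values only; the local TCPStore setup and the single-rank allreduce mirror the existing torchft tests rather than anything this patch adds.

import torch
from torch.distributed import ReduceOp, TCPStore

from torchft.process_group import ProcessGroupBabyGloo


if __name__ == "__main__":
    # local store used only for rendezvous of the single-rank group
    store = TCPStore(
        host_name="localhost", port=0, is_master=True, wait_for_workers=False
    )

    pg = ProcessGroupBabyGloo()
    pg.configure(f"localhost:{store.port}/prefix0", rank=0, world_size=1)

    t = torch.ones(3)
    pg.allreduce([t], ReduceOp.SUM).wait()

    # configure() now shuts the previous worker down first, so the old
    # subprocess and its monitored queues are gone before the new ones start.
    # A fresh store prefix gives the new worker a clean rendezvous namespace.
    pg.configure(f"localhost:{store.port}/prefix1", rank=0, world_size=1)
    pg.allreduce([t], ReduceOp.SUM).wait()

    pg.shutdown()
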
     def _assert_alive(self) -> None:
         """
@@ -968,7 +978,7 @@ def num_active_work(self) -> int:
             self._tx.put(("num_active_work",), timeout=self._timeout)
 
         assert self._rx is not None
-        return cast(int, _get(self._rx, self._timeout))
+        return cast(int, self._rx.get(self._timeout))
 
 
 @dataclass
diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py
index f7656259..e49f5a47 100644
--- a/torchft/process_group_test.py
+++ b/torchft/process_group_test.py
@@ -238,14 +238,14 @@ def test_reconfigure_baby_process_group(self) -> None:
         assert future_thread_1 is not None
         self.assertFalse(future_thread_1.is_alive())
         assert future_queue_1 is not None
-        self.assertTrue(future_queue_1._closed)  # pyre-ignore[16]: no attribute _closed
+        self.assertTrue(future_queue_1.closed())
         assert p_1 is not None
         self.assertFalse(p_1.is_alive())
 
         assert future_thread_2 is not None
         self.assertTrue(future_thread_2.is_alive())
         assert future_queue_2 is not None
-        self.assertFalse(future_queue_2._closed)
+        self.assertFalse(future_queue_2.closed())
         assert p_2 is not None
         self.assertTrue(p_2.is_alive())
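
For completeness, a standalone sketch (not part of the patch; the child target, sleep, and timeouts are illustrative) of the behavior the subject line describes: when the worker subprocess dies, _MonitoredQueue.get raises within roughly one poll interval instead of blocking for the caller's full timeout the way a bare mp.Queue.get(timeout=...) would.

import sys
import time

import torch.multiprocessing as mp

from torchft.multiprocessing import _MonitoredQueue


def crashing_worker(q: mp.Queue) -> None:
    # simulate a worker that dies before ever producing a result
    time.sleep(1.0)
    sys.exit(1)


if __name__ == "__main__":
    ctx = mp.get_context("fork")
    q = ctx.Queue()
    p = ctx.Process(target=crashing_worker, args=(q,), daemon=True)
    p.start()

    mq = _MonitoredQueue(p, q)
    try:
        # a plain q.get(timeout=60.0) would hang here for the full minute
        mq.get(timeout=60.0)
    except RuntimeError as e:
        # raised within about one poll interval of the child exiting,
        # e.g. "process is not alive 1"
        print(f"fail fast: {e}")
    finally:
        mq.close()

put() behaves the same way, which is what lets ProcessGroupBaby surface a crashed worker immediately instead of waiting out its own timeout.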