diff --git a/AUTHORS.rst b/AUTHORS.rst
index d08f44875b4..3c7bb7a18e1 100644
--- a/AUTHORS.rst
+++ b/AUTHORS.rst
@@ -58,6 +58,7 @@ Contributors
 * Filip Vavera -- napoleon todo directive
 * Glenn Matthews -- python domain signature improvements
 * Gregory Szorc -- performance improvements
+* Grzegorz Bokota -- parallelism improvements
 * Henrique Bastos -- SVG support for graphviz extension
 * Hernan Grecco -- search improvements
 * Hong Xu -- svg support in imgmath extension and various bug fixes
diff --git a/CHANGES.rst b/CHANGES.rst
index 791038e5e15..ce9a038100a 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -58,6 +58,8 @@ Features added
   The location of the cache directory must not be relied upon externally,
   as it may change without notice or warning in future releases.
   Patch by Adam Turner.
+* #13738: Improve parallelism by dynamically assigning tasks to workers.
+  Patch by Grzegorz Bokota.
 
 Bugs fixed
 ----------
diff --git a/sphinx/builders/__init__.py b/sphinx/builders/__init__.py
index 2dd972ecfe0..c6bd5d52a56 100644
--- a/sphinx/builders/__init__.py
+++ b/sphinx/builders/__init__.py
@@ -624,6 +624,7 @@ def merge(docs: list[str], otherenv: bytes) -> None:
         for chunk in chunks:
             tasks.add_task(read_process, chunk, merge)
 
+        tasks.start()
         # make sure all threads have finished
         tasks.join()
         logger.info('')
@@ -820,6 +821,7 @@ def on_chunk_done(args: list[tuple[str, nodes.document]], result: None) -> None:
                 arg.append((docname, doctree))
             tasks.add_task(write_process, arg, on_chunk_done)
 
+        tasks.start()
         # make sure all threads have finished
         tasks.join()
         logger.info('')
diff --git a/sphinx/util/parallel.py b/sphinx/util/parallel.py
index 3dd5e574c58..fd5c3d226ad 100644
--- a/sphinx/util/parallel.py
+++ b/sphinx/util/parallel.py
@@ -3,9 +3,9 @@
 from __future__ import annotations
 
 import os
+import queue
 import time
 import traceback
-from math import sqrt
 from typing import TYPE_CHECKING
 
 try:
@@ -20,8 +20,12 @@
 
 if TYPE_CHECKING:
     from collections.abc import Callable, Sequence
+    from logging import LogRecord
     from typing import Any
 
+    InQueueArg = tuple[int, Callable[..., Any], Any]
+    OutQueueArg = tuple[int, tuple[bool, list[LogRecord], Any]]
+
 logger = logging.getLogger(__name__)
 
 # our parallel functionality only works for the forking Process
@@ -47,6 +51,9 @@ def add_task(
         if result_func:
             result_func(res)
 
+    def start(self) -> None:
+        pass
+
     def join(self) -> None:
         pass
 
@@ -70,24 +77,8 @@ def __init__(self, nproc: int) -> None:
         self._pworking = 0
         # task number of each subprocess
         self._taskid = 0
-
-    def _process(
-        self, pipe: Any, func: Callable[[Any], Any] | Callable[[], Any], arg: Any
-    ) -> None:
-        try:
-            collector = logging.LogCollector()
-            with collector.collect():
-                if arg is None:
-                    ret = func()  # type: ignore[call-arg]
-                else:
-                    ret = func(arg)  # type: ignore[call-arg]
-            failed = False
-        except BaseException as err:
-            failed = True
-            errmsg = traceback.format_exception_only(err.__class__, err)[0].strip()
-            ret = (errmsg, traceback.format_exc())
-        logging.convert_serializable(collector.logs)
-        pipe.send((failed, collector.logs, ret))
+        self._args_queue: multiprocessing.Queue[Any] = multiprocessing.Queue()
+        self._result_queue: multiprocessing.Queue[Any] = multiprocessing.Queue()
 
     def add_task(
         self,
@@ -99,71 +90,103 @@
         self._taskid += 1
         self._result_funcs[tid] = result_func or (lambda arg, result: None)
         self._args[tid] = arg
-        precv, psend = multiprocessing.Pipe(False)
-        context: Any = multiprocessing.get_context('fork')
-        proc = context.Process(target=self._process, args=(psend, task_func, arg))
-        self._procs[tid] = proc
-        self._precvs_waiting[tid] = precv
-        try:
-            self._join_one()
-        except Exception:
-            # shutdown other child processes on failure
-            # (e.g. OSError: Failed to allocate memory)
-            self.terminate()
+        self._args_queue.put((tid, task_func, arg))
+
+    def start(self) -> None:
+        # start the worker processes
+        for i in range(self._pworking, self.nproc + self._pworking):
+            proc = multiprocessing.Process(
+                target=process_data_chunks,
+                args=(self._args_queue, self._result_queue),
+                name=f'SphinxParallelWorker-{i}',
+            )
+            self._procs[i] = proc
+            self._pworking += 1
+            proc.start()
 
     def join(self) -> None:
         try:
             while self._pworking:
-                if not self._join_one():
-                    time.sleep(0.02)
+                while not self._result_queue.empty():
+                    tid, result = self._result_queue.get_nowait()
+                    if tid in self._result_funcs:
+                        exc, logs, res = result
+                        if exc:
+                            raise SphinxParallelError(*res)
+                        for log in logs:
+                            logger.handle(log)
+                        self._result_funcs[tid](self._args.pop(tid), res)
+                    else:
+                        raise SphinxParallelError(
+                            message=f'Result function for task {tid} not found. '
+                            f'This is a bug in Sphinx.',
+                            traceback='',
+                        )
+                for num, proc in list(self._procs.items()):
+                    if not proc.is_alive():
+                        self._procs.pop(num)
+                        self._pworking -= 1
+                if self._pworking:
+                    time.sleep(0.02)
         finally:
             # shutdown other child processes on failure
             self.terminate()
 
     def terminate(self) -> None:
-        for tid in list(self._precvs):
+        for tid in list(self._procs):
             self._procs[tid].terminate()
             self._result_funcs.pop(tid)
             self._procs.pop(tid)
-            self._precvs.pop(tid)
             self._pworking -= 1
 
-    def _join_one(self) -> bool:
-        joined_any = False
-        for tid, pipe in self._precvs.items():
-            if pipe.poll():
-                exc, logs, result = pipe.recv()
-                if exc:
-                    raise SphinxParallelError(*result)
-                for log in logs:
-                    logger.handle(log)
-                self._result_funcs.pop(tid)(self._args.pop(tid), result)
-                self._procs[tid].join()
-                self._precvs.pop(tid)
-                self._pworking -= 1
-                joined_any = True
+        # clear queues to avoid memory leaks
+        while not self._args_queue.empty():
+            try:
+                self._args_queue.get_nowait()
+            except queue.Empty:
                 break
 
-        while self._precvs_waiting and self._pworking < self.nproc:
-            newtid, newprecv = self._precvs_waiting.popitem()
-            self._precvs[newtid] = newprecv
-            self._procs[newtid].start()
-            self._pworking += 1
-
-        return joined_any
+        # clear result queue to avoid memory leaks
+        while not self._result_queue.empty():
+            try:
+                self._result_queue.get_nowait()
+            except queue.Empty:
+                break
 
 
 def make_chunks(arguments: Sequence[str], nproc: int, maxbatch: int = 10) -> list[Any]:
     # determine how many documents to read in one go
     nargs = len(arguments)
     chunksize = nargs // nproc
-    if chunksize >= maxbatch:
-        # try to improve batch size vs. number of batches
-        chunksize = int(sqrt(nargs / nproc * maxbatch))
-    if chunksize == 0:
-        chunksize = 1
+    chunksize = max(min(chunksize, maxbatch), 1)
     nchunks, rest = divmod(nargs, chunksize)
     if rest:
         nchunks += 1
     # partition documents in "chunks" that will be written by one Process
     return [arguments[i * chunksize : (i + 1) * chunksize] for i in range(nchunks)]
+
+
+def process_data_chunks(
+    queue_in: multiprocessing.Queue[InQueueArg],
+    queue_out: multiprocessing.Queue[OutQueueArg],
+) -> None:
+    """Run tasks from the input queue and put results on the output queue."""
+    while True:
+        try:
+            task_id, func, arg = queue_in.get_nowait()
+        except queue.Empty:
+            break
+        collector = logging.LogCollector()
+        try:
+            with collector.collect():
+                if arg is not None:
+                    result = func(arg)
+                else:
+                    result = func()
+            failed = False
+        except BaseException as err:
+            failed = True
+            errmsg = traceback.format_exception_only(err.__class__, err)[0].strip()
+            result = (errmsg, traceback.format_exc())
+        logging.convert_serializable(collector.logs)
+        queue_out.put((task_id, (failed, collector.logs, result)))
diff --git a/tests/test_util/test_util_logging.py b/tests/test_util/test_util_logging.py
index c21434a8414..1703f49b0ca 100644
--- a/tests/test_util/test_util_logging.py
+++ b/tests/test_util/test_util_logging.py
@@ -347,6 +347,7 @@ def child_process():
 
     tasks = ParallelTasks(1)
     tasks.add_task(child_process)
+    tasks.start()
     tasks.join()
     assert 'message1' in app.status.getvalue()
     assert 'index.txt: WARNING: message2' in app.warning.getvalue()
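With this patch, a ParallelTasks caller enqueues work with add_task(), launches the workers with start(), and collects results with join(); SerialTasks gains a no-op start() so both classes keep the same interface. Below is a minimal sketch of that lifecycle, not part of the patch: the square/collect helpers and nproc=4 are illustrative only, and parallel execution still assumes a platform where Sphinx's fork-based parallelism is available.

    from sphinx.util.parallel import ParallelTasks

    def square(n: int) -> int:
        # runs inside a worker process, which pulls tasks from the shared queue
        return n * n

    def collect(arg: int, result: int) -> None:
        # runs in the parent process when join() dequeues this task's result
        print(f'{arg} squared is {result}')

    tasks = ParallelTasks(nproc=4)
    for n in range(10):
        tasks.add_task(square, n, collect)  # queues the task; no process starts yet
    tasks.start()  # start the worker processes; each pulls tasks via get_nowait()
    tasks.join()   # forward results and log records, then terminate the workers

Because every worker takes the next (tid, func, arg) tuple from one shared queue, a fast worker keeps picking up new tasks instead of idling once its share of the work is done; this is the dynamic assignment of tasks described in the changelog entry above, replacing the old scheme of pre-assigning one chunk per forked process.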