Use queues to execute tasks with more parallelism #13738

Draft · wants to merge 4 commits into base: master

1 change: 1 addition & 0 deletions AUTHORS.rst
@@ -58,6 +58,7 @@ Contributors
* Filip Vavera -- napoleon todo directive
* Glenn Matthews -- python domain signature improvements
* Gregory Szorc -- performance improvements
* Grzegorz Bokota -- parallelism improvements
* Henrique Bastos -- SVG support for graphviz extension
* Hernan Grecco -- search improvements
* Hong Xu -- svg support in imgmath extension and various bug fixes
2 changes: 2 additions & 0 deletions CHANGES.rst
@@ -58,6 +58,8 @@ Features added
The location of the cache directory must not be relied upon externally,
as it may change without notice or warning in future releases.
Patch by Adam Turner.
* #13738: Improve parallelism by dynamically assigning tasks to workers.
Patch by Grzegorz Bokota.

Bugs fixed
----------
2 changes: 2 additions & 0 deletions sphinx/builders/__init__.py
@@ -624,6 +624,7 @@ def merge(docs: list[str], otherenv: bytes) -> None:
for chunk in chunks:
tasks.add_task(read_process, chunk, merge)

tasks.start()
# make sure all threads have finished
tasks.join()
logger.info('')
@@ -820,6 +821,7 @@ def on_chunk_done(args: list[tuple[str, nodes.document]], result: None) -> None:
arg.append((docname, doctree))
tasks.add_task(write_process, arg, on_chunk_done)

tasks.start()
# make sure all threads have finished
tasks.join()
logger.info('')
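
With this change, the builder queues all chunks up front with add_task(), then calls the new start() to launch the worker pool and join() to wait for it to drain. A minimal sketch of the new call sequence, assuming the ParallelTasks API as modified in this PR (read_chunk and on_done are illustrative stand-ins for the real read/write callbacks):

from sphinx.util.parallel import ParallelTasks, make_chunks

def read_chunk(docnames):  # illustrative task function, runs in a worker process
    return [name.upper() for name in docnames]

def on_done(docnames, result):  # illustrative result callback, runs in the parent
    print(docnames, '->', result)

if __name__ == '__main__':
    nproc = 4
    tasks = ParallelTasks(nproc)
    for chunk in make_chunks(['index', 'usage', 'api', 'changes'], nproc):
        tasks.add_task(read_chunk, chunk, on_done)
    tasks.start()  # new in this PR: workers start only after every task is queued
    tasks.join()   # drain the result queue and shut the workers down
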
143 changes: 83 additions & 60 deletions sphinx/util/parallel.py
@@ -3,9 +3,9 @@
from __future__ import annotations

import os
import queue
import time
import traceback
from math import sqrt
from typing import TYPE_CHECKING

try:
@@ -20,8 +20,12 @@

if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from logging import LogRecord
from typing import Any

InQueueArg = tuple[int, Callable[..., Any], Any]
OutQueueArg = tuple[int, tuple[bool, list[LogRecord], Any]]

logger = logging.getLogger(__name__)

# our parallel functionality only works for the forking Process
@@ -47,6 +51,9 @@ def add_task(
if result_func:
result_func(res)

def start(self) -> None:
pass

def join(self) -> None:
pass

@@ -70,24 +77,8 @@ def __init__(self, nproc: int) -> None:
self._pworking = 0
# task number of each subprocess
self._taskid = 0

def _process(
self, pipe: Any, func: Callable[[Any], Any] | Callable[[], Any], arg: Any
) -> None:
try:
collector = logging.LogCollector()
with collector.collect():
if arg is None:
ret = func() # type: ignore[call-arg]
else:
ret = func(arg) # type: ignore[call-arg]
failed = False
except BaseException as err:
failed = True
errmsg = traceback.format_exception_only(err.__class__, err)[0].strip()
ret = (errmsg, traceback.format_exc())
logging.convert_serializable(collector.logs)
pipe.send((failed, collector.logs, ret))
self._args_queue: multiprocessing.Queue[Any] = multiprocessing.Queue()
self._result_queue: multiprocessing.Queue[Any] = multiprocessing.Queue()

def add_task(
self,
@@ -99,71 +90,103 @@ def add_task(
self._taskid += 1
self._result_funcs[tid] = result_func or (lambda arg, result: None)
self._args[tid] = arg
precv, psend = multiprocessing.Pipe(False)
context: Any = multiprocessing.get_context('fork')
proc = context.Process(target=self._process, args=(psend, task_func, arg))
self._procs[tid] = proc
self._precvs_waiting[tid] = precv
try:
self._join_one()
except Exception:
# shutdown other child processes on failure
# (e.g. OSError: Failed to allocate memory)
self.terminate()
self._args_queue.put((tid, task_func, arg))

def start(self) -> None:
# start the worker processes
for i in range(self._pworking, self.nproc + self._pworking):
proc = multiprocessing.Process(
target=process_data_chunks,
args=(self._args_queue, self._result_queue),
name=f'SphinxParallelWorker-{i}',
)
self._procs[i] = proc
self._pworking += 1
proc.start()

def join(self) -> None:
try:
while self._pworking:
if not self._join_one():
time.sleep(0.02)
while not self._result_queue.empty():
tid, result = self._result_queue.get_nowait()
if tid in self._result_funcs:
exc, logs, res = result
if exc:
raise SphinxParallelError(*res)
for log in logs:
logger.handle(log)
self._result_funcs[tid](self._args.pop(tid), res)
else:
raise SphinxParallelError(
message=f'Result function for task {tid} not found. '
f'This is a bug in Sphinx.',
traceback='',
)
for num, proc in list(self._procs.items()):
if not proc.is_alive():
self._procs.pop(num)
self._pworking -= 1
if self._pworking:
time.sleep(0.02)
finally:
# shutdown other child processes on failure
self.terminate()

def terminate(self) -> None:
for tid in list(self._precvs):
for tid in list(self._procs):
self._procs[tid].terminate()
self._result_funcs.pop(tid)
self._procs.pop(tid)
self._precvs.pop(tid)
self._pworking -= 1

def _join_one(self) -> bool:
joined_any = False
for tid, pipe in self._precvs.items():
if pipe.poll():
exc, logs, result = pipe.recv()
if exc:
raise SphinxParallelError(*result)
for log in logs:
logger.handle(log)
self._result_funcs.pop(tid)(self._args.pop(tid), result)
self._procs[tid].join()
self._precvs.pop(tid)
self._pworking -= 1
joined_any = True
# clear queues to avoid memory leaks
while not self._args_queue.empty():
try:
self._args_queue.get_nowait()
except queue.Empty:
break

while self._precvs_waiting and self._pworking < self.nproc:
newtid, newprecv = self._precvs_waiting.popitem()
self._precvs[newtid] = newprecv
self._procs[newtid].start()
self._pworking += 1

return joined_any
# clear result queue to avoid memory leaks
while not self._result_queue.empty():
try:
self._result_queue.get_nowait()
except queue.Empty:
break


def make_chunks(arguments: Sequence[str], nproc: int, maxbatch: int = 10) -> list[Any]:
# determine how many documents to read in one go
nargs = len(arguments)
chunksize = nargs // nproc
if chunksize >= maxbatch:
# try to improve batch size vs. number of batches
chunksize = int(sqrt(nargs / nproc * maxbatch))
if chunksize == 0:
chunksize = 1
chunksize = max(min(chunksize, maxbatch), 1)
nchunks, rest = divmod(nargs, chunksize)
if rest:
nchunks += 1
# partition documents in "chunks" that will be written by one Process
return [arguments[i * chunksize : (i + 1) * chunksize] for i in range(nchunks)]


def process_data_chunks(
queue_in: multiprocessing.Queue[InQueueArg],
queue_out: multiprocessing.Queue[OutQueueArg],
) -> None:
"""Process data chunks from the queue using the given function."""
while True:
try:
task_id, func, arg = queue_in.get_nowait()
except queue.Empty:
break
collector = logging.LogCollector()
try:
with collector.collect():
if arg is not None:
result = func(arg)
else:
result = func()
failed = False
except BaseException as err:
failed = True
errmsg = traceback.format_exception_only(err.__class__, err)[0].strip()
result = (errmsg, traceback.format_exc())
logging.convert_serializable(collector.logs)
queue_out.put((task_id, (failed, collector.logs, result)))
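
The worker loop above pulls (task id, function, argument) triples from a shared queue until it is empty, so an idle worker picks up whatever task is pending next instead of owning a fixed, pre-computed slice of the work; this is the dynamic assignment the changelog entry refers to. A self-contained sketch of the same pattern outside Sphinx, using one sentinel per worker to signal shutdown rather than get_nowait() (all names here are illustrative):

import multiprocessing

def worker(task_queue, result_queue):
    # Keep pulling tasks until a sentinel (None) arrives, so any idle worker
    # takes the next pending task rather than a pre-assigned batch.
    while True:
        item = task_queue.get()
        if item is None:
            break
        task_id, value = item
        result_queue.put((task_id, value * value))

if __name__ == '__main__':
    nproc = 4
    task_queue = multiprocessing.Queue()
    result_queue = multiprocessing.Queue()
    for task_id, value in enumerate(range(10)):
        task_queue.put((task_id, value))
    for _ in range(nproc):
        task_queue.put(None)  # one sentinel per worker
    procs = [
        multiprocessing.Process(target=worker, args=(task_queue, result_queue))
        for _ in range(nproc)
    ]
    for proc in procs:
        proc.start()
    results = {}
    for _ in range(10):
        task_id, value = result_queue.get()  # collect results before joining workers
        results[task_id] = value
    for proc in procs:
        proc.join()
    print(results)
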
1 change: 1 addition & 0 deletions tests/test_util/test_util_logging.py
@@ -347,6 +347,7 @@ def child_process():

tasks = ParallelTasks(1)
tasks.add_task(child_process)
tasks.start()
tasks.join()
assert 'message1' in app.status.getvalue()
assert 'index.txt: WARNING: message2' in app.warning.getvalue()