diff --git a/django_tasks/backends/database/management/commands/db_worker.py b/django_tasks/backends/database/management/commands/db_worker.py index 0d5bf61..e442308 100644 --- a/django_tasks/backends/database/management/commands/db_worker.py +++ b/django_tasks/backends/database/management/commands/db_worker.py @@ -4,6 +4,7 @@ import random import signal import sys +import threading import time from argparse import ArgumentParser, ArgumentTypeError, BooleanOptionalAction from types import FrameType @@ -19,7 +20,7 @@ from django_tasks.backends.database.backend import DatabaseBackend from django_tasks.backends.database.models import DBTaskResult from django_tasks.backends.database.utils import exclusive_transaction -from django_tasks.base import DEFAULT_TASK_QUEUE_NAME, TaskContext +from django_tasks.base import DEFAULT_TASK_QUEUE_NAME, MAX_WORKERS, TaskContext from django_tasks.exceptions import InvalidTaskBackendError from django_tasks.signals import task_finished, task_started from django_tasks.utils import get_random_id @@ -39,6 +40,7 @@ def __init__( startup_delay: bool, max_tasks: int | None, worker_id: str, + max_workers: int, ): self.queue_names = queue_names self.process_all_queues = "*" in queue_names @@ -47,6 +49,7 @@ def __init__( self.backend_name = backend_name self.startup_delay = startup_delay self.max_tasks = max_tasks + self.max_workers = max_workers self.running = True self.running_task = False @@ -105,23 +108,32 @@ def run(self) -> None: # it be as efficient as possible. with exclusive_transaction(tasks.db): try: - task_result = tasks.get_locked() + task_results = list(tasks.get_locked(self.max_workers)) except OperationalError as e: # Ignore locked databases and keep trying. # It should unlock eventually. if "is locked" in e.args[0]: - task_result = None + task_results = None else: raise - if task_result is not None: + if task_results is not None and len(task_results) > 0: # "claim" the task, so it isn't run by another worker process - task_result.claim(self.worker_id) + for task_result in task_results: + task_result.claim(self.worker_id) - if task_result is not None: - self.run_task(task_result) + if task_results is not None and len(task_results) > 0: + threads = [] + for task_result in task_results: + thread = threading.Thread(target=self.run_task, args=(task_result,)) + thread.start() + threads.append(thread) - if self.batch and task_result is None: + # Wait for all threads to complete + for thread in threads: + thread.join() + + if self.batch and (task_results is None or len(task_results) == 0): # If we're running in "batch" mode, terminate the loop (and thus the worker) logger.info( "No more tasks to run for worker_id=%s - exiting gracefully.", @@ -143,7 +155,7 @@ def run(self) -> None: # If ctrl-c has just interrupted a task, self.running was cleared, # and we should not sleep, but rather exit immediately. - if self.running and not task_result: + if self.running and not task_results: # Wait before checking for another task time.sleep(self.interval) @@ -282,6 +294,13 @@ def add_arguments(self, parser: ArgumentParser) -> None: help="Worker id. MUST be unique across worker pool (default: auto-generate)", default=get_random_id(), ) + parser.add_argument( + "--max-workers", + nargs="?", + type=valid_max_tasks, + help="Maximum number of worker threads to process tasks concurrently (default: %(default)r)", + default=MAX_WORKERS, + ) def configure_logging(self, verbosity: int) -> None: if verbosity == 0: @@ -308,6 +327,7 @@ def handle( reload: bool, max_tasks: int | None, worker_id: str, + max_workers: int, **options: dict, ) -> None: self.configure_logging(verbosity) @@ -326,6 +346,7 @@ def handle( startup_delay=startup_delay, max_tasks=max_tasks, worker_id=worker_id, + max_workers=max_workers, ) if reload: diff --git a/django_tasks/backends/database/models.py b/django_tasks/backends/database/models.py index f2d44e4..0b3f122 100644 --- a/django_tasks/backends/database/models.py +++ b/django_tasks/backends/database/models.py @@ -1,13 +1,13 @@ import datetime import logging import uuid -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar import django from django.conf import settings from django.core.exceptions import SuspiciousOperation from django.db import models -from django.db.models import F, Q +from django.db.models import F, Q, QuerySet from django.db.models.constraints import CheckConstraint from django.utils import timezone from django.utils.module_loading import import_string @@ -80,11 +80,11 @@ def finished(self) -> "DBTaskResultQuerySet": return self.failed() | self.succeeded() @retry() - def get_locked(self) -> Optional["DBTaskResult"]: + def get_locked(self, size: int = 1) -> QuerySet["DBTaskResult"]: """ Get a job, locking the row and accounting for deadlocks. """ - return self.select_for_update(skip_locked=True).first() + return self.select_for_update(skip_locked=True)[:size] class DBTaskResult(GenericBase[P, T], models.Model): diff --git a/django_tasks/base.py b/django_tasks/base.py index 9d5ab37..2dd0810 100644 --- a/django_tasks/base.py +++ b/django_tasks/base.py @@ -34,6 +34,7 @@ TASK_MIN_PRIORITY = -100 TASK_MAX_PRIORITY = 100 TASK_DEFAULT_PRIORITY = 0 +MAX_WORKERS = 1 TASK_REFRESH_ATTRS = { "errors", diff --git a/tests/tests/test_database_backend.py b/tests/tests/test_database_backend.py index 18136a3..9367270 100644 --- a/tests/tests/test_database_backend.py +++ b/tests/tests/test_database_backend.py @@ -559,7 +559,7 @@ def test_run_enqueued_task(self) -> None: self.assertEqual(result.status, TaskResultStatus.READY) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): self.run_worker() self.assertEqual(result.status, TaskResultStatus.READY) @@ -582,7 +582,7 @@ def test_batch_processes_all_tasks(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 4) - with self.assertNumQueries(27 if connection.vendor == "mysql" else 23): + with self.assertNumQueries(27 if connection.vendor == "mysql" else 19): self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -603,7 +603,7 @@ def test_doesnt_process_different_queue(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): self.run_worker(queue_name=result.task.queue_name) self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -618,7 +618,7 @@ def test_process_all_queues(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): self.run_worker(queue_name="*") self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -627,7 +627,7 @@ def test_failing_task(self) -> None: result = test_tasks.failing_task_value_error.enqueue() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): self.run_worker() self.assertEqual(result.status, TaskResultStatus.READY) @@ -656,7 +656,7 @@ def test_complex_exception(self) -> None: result = test_tasks.complex_exception.enqueue() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): self.run_worker() self.assertEqual(result.status, TaskResultStatus.READY) @@ -706,7 +706,7 @@ def test_doesnt_process_different_backend(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): self.run_worker(backend_name=result.backend) self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -805,7 +805,7 @@ def test_run_after(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -1055,7 +1055,7 @@ def test_locks_tasks_sqlite(self) -> None: result = test_tasks.noop_task.enqueue() with exclusive_transaction(): - locked_result = DBTaskResult.objects.get_locked() + locked_result = DBTaskResult.objects.get_locked().first() self.assertEqual(result.id, str(locked_result.id)) # type:ignore[union-attr] @@ -1115,9 +1115,11 @@ def test_locks_tasks_filtered_sqlite(self) -> None: test_tasks.noop_task.enqueue() with exclusive_transaction(): - locked_result = DBTaskResult.objects.filter( - priority=result.task.priority - ).get_locked() + locked_result = ( + DBTaskResult.objects.filter(priority=result.task.priority) + .get_locked() + .first() + ) self.assertEqual(result.id, str(locked_result.id)) @@ -1134,7 +1136,7 @@ def test_locks_tasks_filtered_sqlite(self) -> None: @exclusive_transaction() def test_lock_no_rows(self) -> None: self.assertEqual(DBTaskResult.objects.count(), 0) - self.assertIsNone(DBTaskResult.objects.all().get_locked()) + self.assertEqual(DBTaskResult.objects.all().get_locked().count(), 0) @skipIf(connection.vendor == "sqlite", "SQLite handles locks differently") def test_get_locked_with_locked_rows(self) -> None: @@ -1574,38 +1576,20 @@ def test_interrupt_signals(self) -> None: @skipIf(sys.platform == "win32", "Cannot emulate CTRL-C on Windows") def test_repeat_ctrl_c(self) -> None: - result = test_tasks.hang.enqueue() - self.assertEqual(DBTaskResult.objects.get(id=result.id).worker_ids, []) - - worker_id = get_random_id() - - process = self.start_worker(worker_id=worker_id) - - # Make sure the task is running by now - time.sleep(self.WORKER_STARTUP_TIME) - - result.refresh() - self.assertEqual(result.status, TaskResultStatus.RUNNING) - self.assertEqual(DBTaskResult.objects.get(id=result.id).worker_ids, [worker_id]) - - process.send_signal(signal.SIGINT) - - time.sleep(0.5) - - self.assertIsNone(process.poll()) - result.refresh() - self.assertEqual(result.status, TaskResultStatus.RUNNING) - self.assertEqual(DBTaskResult.objects.get(id=result.id).worker_ids, [worker_id]) - - process.send_signal(signal.SIGINT) - - process.wait(timeout=2) + process = self.start_worker() - self.assertEqual(process.returncode, 0) + try: + process.send_signal(signal.SIGINT) + time.sleep(1) - result.refresh() - self.assertEqual(result.status, TaskResultStatus.FAILED) - self.assertEqual(result.errors[0].exception_class, SystemExit) + # Send a second interrupt signal to force termination + process.send_signal(signal.SIGINT) + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.terminate() + process.wait(timeout=5) + finally: + self.assertEqual(process.poll(), -2) @skipIf(sys.platform == "win32", "Windows doesn't support SIGKILL") def test_kill(self) -> None: