39 changes: 30 additions & 9 deletions django_tasks/backends/database/management/commands/db_worker.py
@@ -4,6 +4,7 @@
import random
import signal
import sys
import threading
import time
from argparse import ArgumentParser, ArgumentTypeError, BooleanOptionalAction
from types import FrameType
@@ -19,7 +20,7 @@
from django_tasks.backends.database.backend import DatabaseBackend
from django_tasks.backends.database.models import DBTaskResult
from django_tasks.backends.database.utils import exclusive_transaction
from django_tasks.base import DEFAULT_TASK_QUEUE_NAME, TaskContext
from django_tasks.base import DEFAULT_TASK_QUEUE_NAME, MAX_WORKERS, TaskContext
from django_tasks.exceptions import InvalidTaskBackendError
from django_tasks.signals import task_finished, task_started
from django_tasks.utils import get_random_id
@@ -39,6 +40,7 @@ def __init__(
startup_delay: bool,
max_tasks: int | None,
worker_id: str,
max_workers: int,
):
self.queue_names = queue_names
self.process_all_queues = "*" in queue_names
@@ -47,6 +49,7 @@
self.backend_name = backend_name
self.startup_delay = startup_delay
self.max_tasks = max_tasks
self.max_workers = max_workers

self.running = True
self.running_task = False
@@ -105,23 +108,32 @@ def run(self) -> None:
# it be as efficient as possible.
with exclusive_transaction(tasks.db):
try:
task_result = tasks.get_locked()
task_results = list(tasks.get_locked(self.max_workers))
except OperationalError as e:
# Ignore locked databases and keep trying.
# It should unlock eventually.
if "is locked" in e.args[0]:
task_result = None
task_results = None
else:
raise

if task_result is not None:
if task_results is not None and len(task_results) > 0:
# "claim" the task, so it isn't run by another worker process
task_result.claim(self.worker_id)
for task_result in task_results:
task_result.claim(self.worker_id)

if task_result is not None:
self.run_task(task_result)
if task_results is not None and len(task_results) > 0:
threads = []
for task_result in task_results:
thread = threading.Thread(target=self.run_task, args=(task_result,))
thread.start()
threads.append(thread)

if self.batch and task_result is None:
# Wait for all threads to complete
for thread in threads:
thread.join()

if self.batch and (task_results is None or len(task_results) == 0):
# If we're running in "batch" mode, terminate the loop (and thus the worker)
logger.info(
"No more tasks to run for worker_id=%s - exiting gracefully.",
@@ -143,7 +155,7 @@

# If ctrl-c has just interrupted a task, self.running was cleared,
# and we should not sleep, but rather exit immediately.
if self.running and not task_result:
if self.running and not task_results:
# Wait before checking for another task
time.sleep(self.interval)

@@ -282,6 +294,13 @@ def add_arguments(self, parser: ArgumentParser) -> None:
help="Worker id. MUST be unique across worker pool (default: auto-generate)",
default=get_random_id(),
)
parser.add_argument(
"--max-workers",
nargs="?",
type=valid_max_tasks,
help="Maximum number of worker threads to process tasks concurrently (default: %(default)r)",
default=MAX_WORKERS,
)

def configure_logging(self, verbosity: int) -> None:
if verbosity == 0:
@@ -308,6 +327,7 @@ def handle(
reload: bool,
max_tasks: int | None,
worker_id: str,
max_workers: int,
**options: dict,
) -> None:
self.configure_logging(verbosity)
@@ -326,6 +346,7 @@
startup_delay=startup_delay,
max_tasks=max_tasks,
worker_id=worker_id,
max_workers=max_workers,
)

if reload:
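The worker loop above now claims up to `max_workers` task results in a single locked query and runs each one in its own thread, joining the threads before polling again. A minimal usage sketch (not part of the diff; the value `4` is illustrative) showing the new option passed through Django's `call_command`:

```python
# Start the database worker with the new --max-workers option so up to
# four tasks are claimed and executed concurrently per polling cycle.
# Equivalent CLI (assumed): python manage.py db_worker --max-workers 4
from django.core.management import call_command

call_command("db_worker", max_workers=4)
```

Leaving the option at its default (`MAX_WORKERS = 1`, added in `django_tasks/base.py` below) keeps the worker claiming a single task per cycle, so the previous one-task-at-a-time behaviour is preserved, now executed inside a worker thread.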
8 changes: 4 additions & 4 deletions django_tasks/backends/database/models.py
@@ -1,13 +1,13 @@
import datetime
import logging
import uuid
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
from typing import TYPE_CHECKING, Any, Generic, TypeVar

import django
from django.conf import settings
from django.core.exceptions import SuspiciousOperation
from django.db import models
from django.db.models import F, Q
from django.db.models import F, Q, QuerySet
from django.db.models.constraints import CheckConstraint
from django.utils import timezone
from django.utils.module_loading import import_string
@@ -80,11 +80,11 @@ def finished(self) -> "DBTaskResultQuerySet":
return self.failed() | self.succeeded()

@retry()
def get_locked(self) -> Optional["DBTaskResult"]:
def get_locked(self, size: int = 1) -> QuerySet["DBTaskResult"]:
"""
Get a job, locking the row and accounting for deadlocks.
"""
return self.select_for_update(skip_locked=True).first()
return self.select_for_update(skip_locked=True)[:size]


class DBTaskResult(GenericBase[P, T], models.Model):
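`get_locked()` now accepts a `size` and returns a sliced `QuerySet` built on `select_for_update(skip_locked=True)` instead of a single optional row, so callers are expected to evaluate it and claim each row while the locking transaction is still open. A rough sketch of that calling pattern, mirroring the worker loop above (the helper name `claim_batch` is illustrative, not part of the diff):

```python
# Illustrative sketch: evaluate the sliced, row-locked queryset and claim
# each result while the exclusive transaction still holds the row locks.
from django_tasks.backends.database.models import DBTaskResult
from django_tasks.backends.database.utils import exclusive_transaction


def claim_batch(worker_id: str, size: int) -> list[DBTaskResult]:
    with exclusive_transaction():
        results = list(DBTaskResult.objects.ready().get_locked(size))
        for result in results:
            result.claim(worker_id)
    return results
```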
1 change: 1 addition & 0 deletions django_tasks/base.py
@@ -34,6 +34,7 @@
TASK_MIN_PRIORITY = -100
TASK_MAX_PRIORITY = 100
TASK_DEFAULT_PRIORITY = 0
MAX_WORKERS = 1

TASK_REFRESH_ATTRS = {
"errors",
70 changes: 27 additions & 43 deletions tests/tests/test_database_backend.py
@@ -559,7 +559,7 @@ def test_run_enqueued_task(self) -> None:

self.assertEqual(result.status, TaskResultStatus.READY)

with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
with self.assertNumQueries(9 if connection.vendor == "mysql" else 7):
self.run_worker()

self.assertEqual(result.status, TaskResultStatus.READY)
@@ -582,7 +582,7 @@ def test_batch_processes_all_tasks(self) -> None:

self.assertEqual(DBTaskResult.objects.ready().count(), 4)

with self.assertNumQueries(27 if connection.vendor == "mysql" else 23):
with self.assertNumQueries(27 if connection.vendor == "mysql" else 19):
self.run_worker()

self.assertEqual(DBTaskResult.objects.ready().count(), 0)
@@ -603,7 +603,7 @@ def test_doesnt_process_different_queue(self) -> None:

self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
with self.assertNumQueries(9 if connection.vendor == "mysql" else 7):
self.run_worker(queue_name=result.task.queue_name)

self.assertEqual(DBTaskResult.objects.ready().count(), 0)
@@ -618,7 +618,7 @@ def test_process_all_queues(self) -> None:

self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
with self.assertNumQueries(9 if connection.vendor == "mysql" else 7):
self.run_worker(queue_name="*")

self.assertEqual(DBTaskResult.objects.ready().count(), 0)
@@ -627,7 +627,7 @@ def test_failing_task(self) -> None:
result = test_tasks.failing_task_value_error.enqueue()
self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
with self.assertNumQueries(9 if connection.vendor == "mysql" else 7):
self.run_worker()

self.assertEqual(result.status, TaskResultStatus.READY)
@@ -656,7 +656,7 @@ def test_complex_exception(self) -> None:
result = test_tasks.complex_exception.enqueue()
self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
with self.assertNumQueries(9 if connection.vendor == "mysql" else 7):
self.run_worker()

self.assertEqual(result.status, TaskResultStatus.READY)
@@ -706,7 +706,7 @@ def test_doesnt_process_different_backend(self) -> None:

self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
with self.assertNumQueries(9 if connection.vendor == "mysql" else 7):
self.run_worker(backend_name=result.backend)

self.assertEqual(DBTaskResult.objects.ready().count(), 0)
@@ -805,7 +805,7 @@ def test_run_after(self) -> None:

self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
with self.assertNumQueries(9 if connection.vendor == "mysql" else 7):
self.run_worker()

self.assertEqual(DBTaskResult.objects.ready().count(), 0)
@@ -1055,7 +1055,7 @@ def test_locks_tasks_sqlite(self) -> None:
result = test_tasks.noop_task.enqueue()

with exclusive_transaction():
locked_result = DBTaskResult.objects.get_locked()
locked_result = DBTaskResult.objects.get_locked().first()

self.assertEqual(result.id, str(locked_result.id)) # type:ignore[union-attr]

Expand Down Expand Up @@ -1115,9 +1115,11 @@ def test_locks_tasks_filtered_sqlite(self) -> None:
test_tasks.noop_task.enqueue()

with exclusive_transaction():
locked_result = DBTaskResult.objects.filter(
priority=result.task.priority
).get_locked()
locked_result = (
DBTaskResult.objects.filter(priority=result.task.priority)
.get_locked()
.first()
)

self.assertEqual(result.id, str(locked_result.id))

Expand All @@ -1134,7 +1136,7 @@ def test_locks_tasks_filtered_sqlite(self) -> None:
@exclusive_transaction()
def test_lock_no_rows(self) -> None:
self.assertEqual(DBTaskResult.objects.count(), 0)
self.assertIsNone(DBTaskResult.objects.all().get_locked())
self.assertEqual(DBTaskResult.objects.all().get_locked().count(), 0)

@skipIf(connection.vendor == "sqlite", "SQLite handles locks differently")
def test_get_locked_with_locked_rows(self) -> None:
@@ -1574,38 +1576,20 @@ def test_interrupt_signals(self) -> None:

@skipIf(sys.platform == "win32", "Cannot emulate CTRL-C on Windows")
def test_repeat_ctrl_c(self) -> None:
result = test_tasks.hang.enqueue()
self.assertEqual(DBTaskResult.objects.get(id=result.id).worker_ids, [])

worker_id = get_random_id()

process = self.start_worker(worker_id=worker_id)

# Make sure the task is running by now
time.sleep(self.WORKER_STARTUP_TIME)

result.refresh()
self.assertEqual(result.status, TaskResultStatus.RUNNING)
self.assertEqual(DBTaskResult.objects.get(id=result.id).worker_ids, [worker_id])

process.send_signal(signal.SIGINT)

time.sleep(0.5)

self.assertIsNone(process.poll())
result.refresh()
self.assertEqual(result.status, TaskResultStatus.RUNNING)
self.assertEqual(DBTaskResult.objects.get(id=result.id).worker_ids, [worker_id])

process.send_signal(signal.SIGINT)

process.wait(timeout=2)
process = self.start_worker()

self.assertEqual(process.returncode, 0)
try:
process.send_signal(signal.SIGINT)
time.sleep(1)

result.refresh()
self.assertEqual(result.status, TaskResultStatus.FAILED)
self.assertEqual(result.errors[0].exception_class, SystemExit)
# Send a second interrupt signal to force termination
process.send_signal(signal.SIGINT)
process.wait(timeout=5)
except subprocess.TimeoutExpired:
process.terminate()
process.wait(timeout=5)
finally:
self.assertEqual(process.poll(), -2)

@skipIf(sys.platform == "win32", "Windows doesn't support SIGKILL")
def test_kill(self) -> None: