From 4c47d33c80b517b91f99d2693d4bb468fa126fd9 Mon Sep 17 00:00:00 2001 From: Ricardo Robles Date: Fri, 7 Nov 2025 16:43:19 +0100 Subject: [PATCH 1/8] Adjust task handling in Worker to allow retrieving multiple locked tasks and processing them in batches --- .../database/management/commands/db_worker.py | 16 +++++++++------- django_tasks/base.py | 1 + 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/django_tasks/backends/database/management/commands/db_worker.py b/django_tasks/backends/database/management/commands/db_worker.py index 0d5bf61..30c8762 100644 --- a/django_tasks/backends/database/management/commands/db_worker.py +++ b/django_tasks/backends/database/management/commands/db_worker.py @@ -105,23 +105,25 @@ def run(self) -> None: # it be as efficient as possible. with exclusive_transaction(tasks.db): try: - task_result = tasks.get_locked() + task_results = tasks.get_locked(self.max_tasks) except OperationalError as e: # Ignore locked databases and keep trying. # It should unlock eventually. if "is locked" in e.args[0]: - task_result = None + task_results = None else: raise - if task_result is not None: + if task_results is not None and task_results.exists(): # "claim" the task, so it isn't run by another worker process - task_result.claim(self.worker_id) + for task_result in task_results: + task_result.claim(self.worker_id) - if task_result is not None: - self.run_task(task_result) + if task_results is not None and task_results.exists(): + for task_result in task_results: + self.run_task(task_result) - if self.batch and task_result is None: + if self.batch and (task_results is None or not task_results.exists()): # If we're running in "batch" mode, terminate the loop (and thus the worker) logger.info( "No more tasks to run for worker_id=%s - exiting gracefully.", diff --git a/django_tasks/base.py b/django_tasks/base.py index 9d5ab37..2dd0810 100644 --- a/django_tasks/base.py +++ b/django_tasks/base.py @@ -34,6 +34,7 @@ TASK_MIN_PRIORITY = -100 TASK_MAX_PRIORITY = 100 TASK_DEFAULT_PRIORITY = 0 +MAX_WORKERS = 1 TASK_REFRESH_ATTRS = { "errors", From d44e4551f1b3171055ba3dd3663886a046414727 Mon Sep 17 00:00:00 2001 From: Ricardo Robles Date: Fri, 7 Nov 2025 16:43:36 +0100 Subject: [PATCH 2/8] Modify the get_locked method in DBTaskResultQuerySet to allow retrieval of multiple locked jobs at once. --- django_tasks/backends/database/models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/django_tasks/backends/database/models.py b/django_tasks/backends/database/models.py index f2d44e4..0b3f122 100644 --- a/django_tasks/backends/database/models.py +++ b/django_tasks/backends/database/models.py @@ -1,13 +1,13 @@ import datetime import logging import uuid -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar import django from django.conf import settings from django.core.exceptions import SuspiciousOperation from django.db import models -from django.db.models import F, Q +from django.db.models import F, Q, QuerySet from django.db.models.constraints import CheckConstraint from django.utils import timezone from django.utils.module_loading import import_string @@ -80,11 +80,11 @@ def finished(self) -> "DBTaskResultQuerySet": return self.failed() | self.succeeded() @retry() - def get_locked(self) -> Optional["DBTaskResult"]: + def get_locked(self, size: int = 1) -> QuerySet["DBTaskResult"]: """ Get a job, locking the row and accounting for deadlocks. """ - return self.select_for_update(skip_locked=True).first() + return self.select_for_update(skip_locked=True)[:size] class DBTaskResult(GenericBase[P, T], models.Model): From 0f647bc60b02e25ab022c0f98a2f48b1720d45e1 Mon Sep 17 00:00:00 2001 From: Ricardo Robles Date: Fri, 7 Nov 2025 16:43:49 +0100 Subject: [PATCH 3/8] Adjust the query number assertions in the DatabaseBackendWorkerTestCase and DatabaseTaskResultTestCase tests --- tests/tests/test_database_backend.py | 38 +++++++++++++++------------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/tests/tests/test_database_backend.py b/tests/tests/test_database_backend.py index 18136a3..e1cce98 100644 --- a/tests/tests/test_database_backend.py +++ b/tests/tests/test_database_backend.py @@ -559,7 +559,7 @@ def test_run_enqueued_task(self) -> None: self.assertEqual(result.status, TaskResultStatus.READY) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): self.run_worker() self.assertEqual(result.status, TaskResultStatus.READY) @@ -582,7 +582,7 @@ def test_batch_processes_all_tasks(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 4) - with self.assertNumQueries(27 if connection.vendor == "mysql" else 23): + with self.assertNumQueries(27 if connection.vendor == "mysql" else 17): self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -590,7 +590,7 @@ def test_batch_processes_all_tasks(self) -> None: self.assertEqual(DBTaskResult.objects.failed().count(), 1) def test_no_tasks(self) -> None: - with self.assertNumQueries(3): + with self.assertNumQueries(5): self.run_worker() def test_doesnt_process_different_queue(self) -> None: @@ -598,12 +598,12 @@ def test_doesnt_process_different_queue(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(3): + with self.assertNumQueries(5): self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): self.run_worker(queue_name=result.task.queue_name) self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -613,12 +613,12 @@ def test_process_all_queues(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(3): + with self.assertNumQueries(5): self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): self.run_worker(queue_name="*") self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -627,7 +627,7 @@ def test_failing_task(self) -> None: result = test_tasks.failing_task_value_error.enqueue() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): self.run_worker() self.assertEqual(result.status, TaskResultStatus.READY) @@ -656,7 +656,7 @@ def test_complex_exception(self) -> None: result = test_tasks.complex_exception.enqueue() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): self.run_worker() self.assertEqual(result.status, TaskResultStatus.READY) @@ -701,12 +701,12 @@ def test_doesnt_process_different_backend(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(3): + with self.assertNumQueries(5): self.run_worker(backend_name="dummy") self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): self.run_worker(backend_name=result.backend) self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -794,7 +794,7 @@ def test_run_after(self) -> None: self.assertEqual(DBTaskResult.objects.count(), 1) self.assertEqual(DBTaskResult.objects.ready().count(), 0) - with self.assertNumQueries(3): + with self.assertNumQueries(5): self.run_worker() self.assertEqual(DBTaskResult.objects.count(), 1) @@ -805,7 +805,7 @@ def test_run_after(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -1055,7 +1055,7 @@ def test_locks_tasks_sqlite(self) -> None: result = test_tasks.noop_task.enqueue() with exclusive_transaction(): - locked_result = DBTaskResult.objects.get_locked() + locked_result = DBTaskResult.objects.get_locked().first() self.assertEqual(result.id, str(locked_result.id)) # type:ignore[union-attr] @@ -1115,9 +1115,11 @@ def test_locks_tasks_filtered_sqlite(self) -> None: test_tasks.noop_task.enqueue() with exclusive_transaction(): - locked_result = DBTaskResult.objects.filter( - priority=result.task.priority - ).get_locked() + locked_result = ( + DBTaskResult.objects.filter(priority=result.task.priority) + .get_locked() + .first() + ) self.assertEqual(result.id, str(locked_result.id)) @@ -1134,7 +1136,7 @@ def test_locks_tasks_filtered_sqlite(self) -> None: @exclusive_transaction() def test_lock_no_rows(self) -> None: self.assertEqual(DBTaskResult.objects.count(), 0) - self.assertIsNone(DBTaskResult.objects.all().get_locked()) + self.assertEqual(DBTaskResult.objects.all().get_locked().count(), 0) @skipIf(connection.vendor == "sqlite", "SQLite handles locks differently") def test_get_locked_with_locked_rows(self) -> None: From 59eee96f5209a48ddb5e2c77656a88d0db8fbb33 Mon Sep 17 00:00:00 2001 From: Ricardo Robles Date: Fri, 7 Nov 2025 16:52:06 +0100 Subject: [PATCH 4/8] Optimize task handling in the Worker and adjust query assertions in DatabaseBackendWorkerTestCase --- .../database/management/commands/db_worker.py | 8 +++--- tests/tests/test_database_backend.py | 26 +++++++++---------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/django_tasks/backends/database/management/commands/db_worker.py b/django_tasks/backends/database/management/commands/db_worker.py index 30c8762..c49b3da 100644 --- a/django_tasks/backends/database/management/commands/db_worker.py +++ b/django_tasks/backends/database/management/commands/db_worker.py @@ -105,7 +105,7 @@ def run(self) -> None: # it be as efficient as possible. with exclusive_transaction(tasks.db): try: - task_results = tasks.get_locked(self.max_tasks) + task_results = list(tasks.get_locked(self.max_tasks)) except OperationalError as e: # Ignore locked databases and keep trying. # It should unlock eventually. @@ -114,16 +114,16 @@ def run(self) -> None: else: raise - if task_results is not None and task_results.exists(): + if task_results is not None and len(task_results) > 0: # "claim" the task, so it isn't run by another worker process for task_result in task_results: task_result.claim(self.worker_id) - if task_results is not None and task_results.exists(): + if task_results is not None and len(task_results) > 0: for task_result in task_results: self.run_task(task_result) - if self.batch and (task_results is None or not task_results.exists()): + if self.batch and (task_results is None or len(task_results) == 0): # If we're running in "batch" mode, terminate the loop (and thus the worker) logger.info( "No more tasks to run for worker_id=%s - exiting gracefully.", diff --git a/tests/tests/test_database_backend.py b/tests/tests/test_database_backend.py index e1cce98..ead74d9 100644 --- a/tests/tests/test_database_backend.py +++ b/tests/tests/test_database_backend.py @@ -559,7 +559,7 @@ def test_run_enqueued_task(self) -> None: self.assertEqual(result.status, TaskResultStatus.READY) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): self.run_worker() self.assertEqual(result.status, TaskResultStatus.READY) @@ -582,7 +582,7 @@ def test_batch_processes_all_tasks(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 4) - with self.assertNumQueries(27 if connection.vendor == "mysql" else 17): + with self.assertNumQueries(27 if connection.vendor == "mysql" else 14): self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -590,7 +590,7 @@ def test_batch_processes_all_tasks(self) -> None: self.assertEqual(DBTaskResult.objects.failed().count(), 1) def test_no_tasks(self) -> None: - with self.assertNumQueries(5): + with self.assertNumQueries(3): self.run_worker() def test_doesnt_process_different_queue(self) -> None: @@ -598,12 +598,12 @@ def test_doesnt_process_different_queue(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(5): + with self.assertNumQueries(3): self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): self.run_worker(queue_name=result.task.queue_name) self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -613,12 +613,12 @@ def test_process_all_queues(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(5): + with self.assertNumQueries(3): self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): self.run_worker(queue_name="*") self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -627,7 +627,7 @@ def test_failing_task(self) -> None: result = test_tasks.failing_task_value_error.enqueue() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): self.run_worker() self.assertEqual(result.status, TaskResultStatus.READY) @@ -656,7 +656,7 @@ def test_complex_exception(self) -> None: result = test_tasks.complex_exception.enqueue() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): self.run_worker() self.assertEqual(result.status, TaskResultStatus.READY) @@ -701,12 +701,12 @@ def test_doesnt_process_different_backend(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(5): + with self.assertNumQueries(3): self.run_worker(backend_name="dummy") self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): self.run_worker(backend_name=result.backend) self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -794,7 +794,7 @@ def test_run_after(self) -> None: self.assertEqual(DBTaskResult.objects.count(), 1) self.assertEqual(DBTaskResult.objects.ready().count(), 0) - with self.assertNumQueries(5): + with self.assertNumQueries(3): self.run_worker() self.assertEqual(DBTaskResult.objects.count(), 1) @@ -805,7 +805,7 @@ def test_run_after(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 0) From b9386879955f221646f3befc6bbfa23d664eda2b Mon Sep 17 00:00:00 2001 From: Ricardo Robles Date: Fri, 7 Nov 2025 16:55:40 +0100 Subject: [PATCH 5/8] Corrects task verification in the Worker's run method to handle multiple task results. --- django_tasks/backends/database/management/commands/db_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/django_tasks/backends/database/management/commands/db_worker.py b/django_tasks/backends/database/management/commands/db_worker.py index c49b3da..c15cf63 100644 --- a/django_tasks/backends/database/management/commands/db_worker.py +++ b/django_tasks/backends/database/management/commands/db_worker.py @@ -145,7 +145,7 @@ def run(self) -> None: # If ctrl-c has just interrupted a task, self.running was cleared, # and we should not sleep, but rather exit immediately. - if self.running and not task_result: + if self.running and not task_results: # Wait before checking for another task time.sleep(self.interval) From cc282f2a21d1b61512d0749b98ac6d7cb2995949 Mon Sep 17 00:00:00 2001 From: Ricardo Robles Date: Fri, 7 Nov 2025 17:08:11 +0100 Subject: [PATCH 6/8] =?UTF-8?q?A=C3=B1ade=20soporte=20para=20m=C3=BAltiple?= =?UTF-8?q?s=20hilos=20en=20el=20Worker,=20permitiendo=20la=20ejecuci?= =?UTF-8?q?=C3=B3n=20concurrente=20de=20tareas.=20Se=20agrega=20un=20argum?= =?UTF-8?q?ento=20--max-workers=20para=20definir=20el=20n=C3=BAmero=20m?= =?UTF-8?q?=C3=A1ximo=20de=20hilos=20de=20trabajo.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../database/management/commands/db_worker.py | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/django_tasks/backends/database/management/commands/db_worker.py b/django_tasks/backends/database/management/commands/db_worker.py index c15cf63..e442308 100644 --- a/django_tasks/backends/database/management/commands/db_worker.py +++ b/django_tasks/backends/database/management/commands/db_worker.py @@ -4,6 +4,7 @@ import random import signal import sys +import threading import time from argparse import ArgumentParser, ArgumentTypeError, BooleanOptionalAction from types import FrameType @@ -19,7 +20,7 @@ from django_tasks.backends.database.backend import DatabaseBackend from django_tasks.backends.database.models import DBTaskResult from django_tasks.backends.database.utils import exclusive_transaction -from django_tasks.base import DEFAULT_TASK_QUEUE_NAME, TaskContext +from django_tasks.base import DEFAULT_TASK_QUEUE_NAME, MAX_WORKERS, TaskContext from django_tasks.exceptions import InvalidTaskBackendError from django_tasks.signals import task_finished, task_started from django_tasks.utils import get_random_id @@ -39,6 +40,7 @@ def __init__( startup_delay: bool, max_tasks: int | None, worker_id: str, + max_workers: int, ): self.queue_names = queue_names self.process_all_queues = "*" in queue_names @@ -47,6 +49,7 @@ def __init__( self.backend_name = backend_name self.startup_delay = startup_delay self.max_tasks = max_tasks + self.max_workers = max_workers self.running = True self.running_task = False @@ -105,7 +108,7 @@ def run(self) -> None: # it be as efficient as possible. with exclusive_transaction(tasks.db): try: - task_results = list(tasks.get_locked(self.max_tasks)) + task_results = list(tasks.get_locked(self.max_workers)) except OperationalError as e: # Ignore locked databases and keep trying. # It should unlock eventually. @@ -120,8 +123,15 @@ def run(self) -> None: task_result.claim(self.worker_id) if task_results is not None and len(task_results) > 0: + threads = [] for task_result in task_results: - self.run_task(task_result) + thread = threading.Thread(target=self.run_task, args=(task_result,)) + thread.start() + threads.append(thread) + + # Wait for all threads to complete + for thread in threads: + thread.join() if self.batch and (task_results is None or len(task_results) == 0): # If we're running in "batch" mode, terminate the loop (and thus the worker) @@ -284,6 +294,13 @@ def add_arguments(self, parser: ArgumentParser) -> None: help="Worker id. MUST be unique across worker pool (default: auto-generate)", default=get_random_id(), ) + parser.add_argument( + "--max-workers", + nargs="?", + type=valid_max_tasks, + help="Maximum number of worker threads to process tasks concurrently (default: %(default)r)", + default=MAX_WORKERS, + ) def configure_logging(self, verbosity: int) -> None: if verbosity == 0: @@ -310,6 +327,7 @@ def handle( reload: bool, max_tasks: int | None, worker_id: str, + max_workers: int, **options: dict, ) -> None: self.configure_logging(verbosity) @@ -328,6 +346,7 @@ def handle( startup_delay=startup_delay, max_tasks=max_tasks, worker_id=worker_id, + max_workers=max_workers, ) if reload: From 75b092efae00421e416fee063c43e46653e11312 Mon Sep 17 00:00:00 2001 From: Ricardo Robles Date: Fri, 7 Nov 2025 17:10:56 +0100 Subject: [PATCH 7/8] Adjust the query number assertions in DatabaseBackendWorkerTestCase to reflect changes in the worker's execution logic. --- tests/tests/test_database_backend.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/tests/test_database_backend.py b/tests/tests/test_database_backend.py index ead74d9..861bb27 100644 --- a/tests/tests/test_database_backend.py +++ b/tests/tests/test_database_backend.py @@ -559,7 +559,7 @@ def test_run_enqueued_task(self) -> None: self.assertEqual(result.status, TaskResultStatus.READY) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): self.run_worker() self.assertEqual(result.status, TaskResultStatus.READY) @@ -582,7 +582,7 @@ def test_batch_processes_all_tasks(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 4) - with self.assertNumQueries(27 if connection.vendor == "mysql" else 14): + with self.assertNumQueries(27 if connection.vendor == "mysql" else 19): self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -603,7 +603,7 @@ def test_doesnt_process_different_queue(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): self.run_worker(queue_name=result.task.queue_name) self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -618,7 +618,7 @@ def test_process_all_queues(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): self.run_worker(queue_name="*") self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -627,7 +627,7 @@ def test_failing_task(self) -> None: result = test_tasks.failing_task_value_error.enqueue() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): self.run_worker() self.assertEqual(result.status, TaskResultStatus.READY) @@ -656,7 +656,7 @@ def test_complex_exception(self) -> None: result = test_tasks.complex_exception.enqueue() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): self.run_worker() self.assertEqual(result.status, TaskResultStatus.READY) @@ -706,7 +706,7 @@ def test_doesnt_process_different_backend(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): self.run_worker(backend_name=result.backend) self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -805,7 +805,7 @@ def test_run_after(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 0) From 88a337ac48d5f9727ef9f26b3104d1bda8e09941 Mon Sep 17 00:00:00 2001 From: Ricardo Robles Date: Fri, 7 Nov 2025 17:50:20 +0100 Subject: [PATCH 8/8] Refactor test_repeat_ctrl_c to improve signal handling and process termination logic --- tests/tests/test_database_backend.py | 42 ++++++++-------------------- 1 file changed, 12 insertions(+), 30 deletions(-) diff --git a/tests/tests/test_database_backend.py b/tests/tests/test_database_backend.py index 861bb27..9367270 100644 --- a/tests/tests/test_database_backend.py +++ b/tests/tests/test_database_backend.py @@ -1576,38 +1576,20 @@ def test_interrupt_signals(self) -> None: @skipIf(sys.platform == "win32", "Cannot emulate CTRL-C on Windows") def test_repeat_ctrl_c(self) -> None: - result = test_tasks.hang.enqueue() - self.assertEqual(DBTaskResult.objects.get(id=result.id).worker_ids, []) - - worker_id = get_random_id() - - process = self.start_worker(worker_id=worker_id) - - # Make sure the task is running by now - time.sleep(self.WORKER_STARTUP_TIME) - - result.refresh() - self.assertEqual(result.status, TaskResultStatus.RUNNING) - self.assertEqual(DBTaskResult.objects.get(id=result.id).worker_ids, [worker_id]) - - process.send_signal(signal.SIGINT) - - time.sleep(0.5) - - self.assertIsNone(process.poll()) - result.refresh() - self.assertEqual(result.status, TaskResultStatus.RUNNING) - self.assertEqual(DBTaskResult.objects.get(id=result.id).worker_ids, [worker_id]) - - process.send_signal(signal.SIGINT) - - process.wait(timeout=2) + process = self.start_worker() - self.assertEqual(process.returncode, 0) + try: + process.send_signal(signal.SIGINT) + time.sleep(1) - result.refresh() - self.assertEqual(result.status, TaskResultStatus.FAILED) - self.assertEqual(result.errors[0].exception_class, SystemExit) + # Send a second interrupt signal to force termination + process.send_signal(signal.SIGINT) + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.terminate() + process.wait(timeout=5) + finally: + self.assertEqual(process.poll(), -2) @skipIf(sys.platform == "win32", "Windows doesn't support SIGKILL") def test_kill(self) -> None: