Skip to content

Commit d16b14c

Browse files
authored
Fix global threads count var update (#923)
* Update global variable indicating the number of background threads from main thread to ensure synchronization upon creating onnx global structure
1 parent 6ae0bdf commit d16b14c

File tree

3 files changed

+49
-5
lines changed

3 files changed

+49
-5
lines changed

src/execution/background_workers.c

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,16 @@ int pthread_setname_np(const char *name);
2121
#endif
2222
#endif
2323

24-
uintptr_t BGWorkersCounter; // Total number of BG threads running currently.
25-
pthread_key_t ThreadIdKey; // Key to hold thread id in its local storage.
24+
uintptr_t LastThreadId; // Last number given as thread id for BG threads running currently.
25+
pthread_key_t ThreadIdKey; // Key to hold thread id in its local storage.
26+
unsigned int BGWorkersCount; // Total number of BG threads spawned.
2627

2728
/**
2829
* @brief Save the id for some working thread in thread local storage.
2930
*/
3031
static void _BGWorker_SaveThreadId() {
3132
// Let the current thread have the next available id, and increase the counter.
32-
long id_value = __atomic_add_fetch(&BGWorkersCounter, 1, __ATOMIC_RELAXED);
33+
long id_value = __atomic_add_fetch(&LastThreadId, 1, __ATOMIC_RELAXED);
3334
// Convert the id value to a pointer and store it the thread local storage.
3435
// First id is 1, so we won't confuse with NULL (which is the error return value)
3536
pthread_setspecific(ThreadIdKey, (const void *)id_value);
@@ -291,7 +292,7 @@ long BGWorker_GetThreadId() {
291292
return (long)(thread_id)-1;
292293
}
293294

294-
uintptr_t BGWorker_GetThreadsCount() { return BGWorkersCounter; }
295+
uintptr_t BGWorker_GetThreadsCount() { return BGWorkersCount; }
295296

296297
void *BGWorker_ThreadMain(void *arg) {
297298
_BGWorker_SaveThreadId();

src/execution/run_queue_info.c

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
#include "backends/backends.h"
44
#include "background_workers.h"
55

6+
extern unsigned int BGWorkersCount;
7+
68
RunQueueInfo *RunQueue_Create(const char *device_str) {
79

810
size_t device_str_len = strlen(device_str);
@@ -22,7 +24,7 @@ RunQueueInfo *RunQueue_Create(const char *device_str) {
2224
return NULL;
2325
}
2426

25-
// Create worker threads.
27+
// Create worker threads, update the global counter.
2628
for (int i = 0; i < Config_GetNumThreadsPerQueue(); i++) {
2729
pthread_t thread;
2830
if (pthread_create(&thread, NULL, BGWorker_ThreadMain, run_queue_info) != 0) {
@@ -32,6 +34,7 @@ RunQueueInfo *RunQueue_Create(const char *device_str) {
3234
}
3335
run_queue_info->threads = array_append(run_queue_info->threads, thread);
3436
}
37+
BGWorkersCount += Config_GetNumThreadsPerQueue();
3538

3639
// Add the new device worker threads to onnx run sessions tracking.
3740
if (RAI_backends.onnx.add_new_device_cb) {

tests/flow/tests_onnx.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import sys
33
import os
44
import subprocess
5+
import psutil
56
import redis
67
from includes import *
78
from RLTest import Env
@@ -554,6 +555,45 @@ def test_multiple_devices(self):
554555
self.env.assertEqual(backends_info['ai_onnxruntime_maximum_run_sessions_number'],
555556
str(len(devices)*self.threads_per_queue))
556557

558+
# Stress test to validate that we have no race condition between the creation of the onnx global array (from
559+
# the main threads) that contains an entry for every worker thread, and the background thread that runs the
560+
# session and access this global array.
561+
def test_synchronization(self):
562+
if self.env.isCluster() or self.env.useSlaves or VALGRIND == 1:
563+
self.env.debugPrint("skipping {} on cluster/slaves/valgrind modes".format(sys._getframe().f_code.co_name), force=True)
564+
return
565+
566+
model_pb = load_file_content('mul_1.onnx')
567+
568+
def launch_redis_and_run_onnx(con, proc_id, pipes):
569+
my_pipe = pipes[proc_id]
570+
port = 6380 + proc_id # Let every subprocess run on a fresh port.
571+
redis_server = subprocess.Popen(['redis-server', '--port', str(port),
572+
'--loadmodule', f'{ROOT}/install-{DEVICE.lower()}/redisai.so',
573+
'--logfile', f'{self.env.logDir}/test_onnx_kill_switch_synchronization-{port}.log',
574+
'--dir', f'{self.env.logDir}',
575+
'--dbfilename', f'test_onnx_kill_switch_synchronization-{port}.rdb'])
576+
# Wait until redis-server is up and ready to accept connections.
577+
while len([c for c in psutil.net_connections("tcp")
578+
if c.pid == redis_server.pid and c.laddr.port == port]) == 0:
579+
time.sleep(1)
580+
# Create a connection to Redis that immediately loads and execute onnx model. This is for testing that
581+
# there was a proper synchronization - otherwise, execution might cause a server crash.
582+
r = redis.Redis(host='localhost', port=port)
583+
r.flushall()
584+
r.execute_command('AI.MODELSTORE', 'mul{1}', 'ONNX', 'CPU', 'BLOB', model_pb)
585+
r.execute_command('AI.TENSORSET', 'a{1}', 'FLOAT', 3, 2, 'VALUES', 1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
586+
r.execute_command('AI.MODELEXECUTE', 'mul{1}', 'INPUTS', 1, 'a{1}', 'OUTPUTS', 1, 'b{1}')
587+
my_pipe.send(1) # To indicate that the flow was executed with success.
588+
redis_server.kill()
589+
590+
num_parallel_clients = 50
591+
parent_end_pipes, children_end_pipes = get_parent_children_pipes(num_parallel_clients)
592+
run_test_multiproc(self.env, '{1}', num_parallel_clients, launch_redis_and_run_onnx,
593+
args=(children_end_pipes, ))
594+
# Assert that all sub-processes have finished successfully.
595+
self.env.assertEqual(sum([p.recv() for p in parent_end_pipes]), num_parallel_clients)
596+
557597

558598
def test_forbidden_external_initializers(env):
559599
if not TEST_ONNX:

0 commit comments

Comments
 (0)