Commit f60c19d

[FEATURE][ML] Fix possible deadlock in thread pool shutdown (#343)
1 parent 939da41 commit f60c19d

File tree

5 files changed: +119 -35 lines changed

include/core/CConcurrentQueue.h

Lines changed: 9 additions & 2 deletions

@@ -58,9 +58,11 @@ class CConcurrentQueue final : private CNonCopyable {
     }
 
     //! Pop an item out of the queue, this returns none if an item isn't available
-    TOptional tryPop() {
+    //! or the pop isn't allowed
+    template<typename PREDICATE>
+    TOptional tryPop(PREDICATE allowed) {
         std::unique_lock<std::mutex> lock(m_Mutex);
-        if (m_Queue.empty()) {
+        if (m_Queue.empty() || allowed(m_Queue.front()) == false) {
             return boost::none;
         }
 
@@ -72,6 +74,9 @@ class CConcurrentQueue final : private CNonCopyable {
         return result;
     }
 
+    //! Pop an item out of the queue, this returns none if an item isn't available
+    TOptional tryPop() { return this->tryPop(always); }
+
     //! Push a copy of \p item onto the queue, this blocks if the queue is full which
     //! means it can deadlock if no one consumes items (implementor's responsibility)
     void push(const T& item) {
@@ -150,6 +155,8 @@ class CConcurrentQueue final : private CNonCopyable {
         }
     }
 
+    static bool always(const T&) { return true; }
+
 private:
     //! The internal queue
     boost::circular_buffer<T> m_Queue;
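
For context, a minimal usage sketch (not part of the commit) of the new predicate-gated tryPop. It assumes a queue instantiated as CConcurrentQueue<int, 50> whose TOptional is boost::optional<int>, as the diff above suggests; the element type and the consumeEvenOnly helper are invented for illustration.

// Hypothetical usage sketch, not from the commit: popping only when the front
// item satisfies a predicate. The element type (int) and the helper name are
// illustrative assumptions.
#include <core/CConcurrentQueue.h>

#include <boost/optional.hpp>

void consumeEvenOnly(ml::core::CConcurrentQueue<int, 50>& queue) {
    // The predicate is applied to the item at the front of the queue; if it
    // returns false the item stays queued and boost::none is returned instead.
    auto item = queue.tryPop([](const int& value) { return value % 2 == 0; });
    if (item != boost::none) {
        // ... process *item ...
    }
}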

include/core/CStaticThreadPool.h

Lines changed: 18 additions & 5 deletions

@@ -49,7 +49,7 @@ class CORE_EXPORT CStaticThreadPool {
     //! and is suitable for our use case where we don't need to guaranty that this
     //! always returns immediately and instead want to exert back pressure on the
     //! thread scheduling tasks if the pool can't keep up.
-    void schedule(TTask&& task);
+    void schedule(std::packaged_task<boost::any()>&& task);
 
     //! Executes the specified function in the thread pool.
     void schedule(std::function<void()>&& f);
@@ -61,14 +61,27 @@ class CORE_EXPORT CStaticThreadPool {
     void busy(bool busy);
 
 private:
-    using TOptionalTask = boost::optional<TTask>;
-    using TTaskQueue = CConcurrentQueue<TTask, 50>;
-    using TTaskQueueVec = std::vector<TTaskQueue>;
+    using TOptionalSize = boost::optional<std::size_t>;
+    class CWrappedTask {
+    public:
+        explicit CWrappedTask(TTask&& task, TOptionalSize threadId = boost::none);
+
+        bool executableOnThread(std::size_t id) const;
+        void operator()();
+
+    private:
+        TTask m_Task;
+        TOptionalSize m_ThreadId;
+    };
+    using TOptionalTask = boost::optional<CWrappedTask>;
+    using TWrappedTaskQueue = CConcurrentQueue<CWrappedTask, 50>;
+    using TWrappedTaskQueueVec = std::vector<TWrappedTaskQueue>;
     using TThreadVec = std::vector<std::thread>;
 
 private:
     void shutdown();
     void worker(std::size_t id);
+    void drainQueuesWithoutBlocking();
 
 private:
     // This doesn't have to be atomic because it is always only set to true,
@@ -77,7 +90,7 @@ class CORE_EXPORT CStaticThreadPool {
     bool m_Done = false;
     std::atomic_bool m_Busy;
     std::atomic<std::uint64_t> m_Cursor;
-    TTaskQueueVec m_TaskQueues;
+    TWrappedTaskQueueVec m_TaskQueues;
     TThreadVec m_Pool;
 };
 }
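
As an aside, the following standalone sketch (not the library's code) illustrates the idea behind CWrappedTask's optional thread affinity: a task carrying a thread id may only be run by that worker, while an unbound task may run anywhere. The class name CAffinityTask and the use of std::function are assumptions made so the snippet compiles on its own.

// Standalone sketch of the thread-affinity idea; names and types are
// illustrative, not the library's.
#include <cstddef>
#include <functional>

#include <boost/optional.hpp>

class CAffinityTask {
public:
    explicit CAffinityTask(std::function<void()> task,
                           boost::optional<std::size_t> threadId = boost::none)
        : m_Task{std::move(task)}, m_ThreadId{threadId} {}

    // A worker with the given id may run this task only if the task is unbound
    // or bound to that id. This is what lets shutdown push one "done" task per
    // queue and guarantee each worker consumes exactly one of them.
    bool executableOnThread(std::size_t id) const {
        return m_ThreadId == boost::none || *m_ThreadId == id;
    }

    void operator()() { m_Task(); }

private:
    std::function<void()> m_Task;
    boost::optional<std::size_t> m_ThreadId;
};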

lib/core/CStaticThreadPool.cc

Lines changed: 64 additions & 21 deletions

@@ -6,6 +6,8 @@
 
 #include <core/CStaticThreadPool.h>
 
+#include <chrono>
+
 namespace ml {
 namespace core {
 namespace {
@@ -21,7 +23,7 @@ CStaticThreadPool::CStaticThreadPool(std::size_t size)
     m_Pool.reserve(m_TaskQueues.size());
     for (std::size_t id = 0; id < m_TaskQueues.size(); ++id) {
         try {
-            m_Pool.emplace_back([&, id] { worker(id); });
+            m_Pool.emplace_back([this, id] { this->worker(id); });
         } catch (...) {
             this->shutdown();
             throw;
@@ -33,18 +35,19 @@ CStaticThreadPool::~CStaticThreadPool() {
     this->shutdown();
 }
 
-void CStaticThreadPool::schedule(TTask&& task) {
+void CStaticThreadPool::schedule(TTask&& task_) {
     // Only block if every queue is full.
     std::size_t size{m_TaskQueues.size()};
     std::size_t i{m_Cursor.load()};
     std::size_t end{i + size};
+    CWrappedTask task{std::forward<TTask>(task_)};
     for (/**/; i < end; ++i) {
-        if (m_TaskQueues[i % size].tryPush(std::forward<TTask>(task))) {
+        if (m_TaskQueues[i % size].tryPush(std::move(task))) {
             break;
         }
     }
     if (i == end) {
-        m_TaskQueues[i % size].push(std::forward<TTask>(task));
+        m_TaskQueues[i % size].push(std::move(task));
     }
     m_Cursor.store(i + 1);
 }
@@ -65,33 +68,38 @@ void CStaticThreadPool::busy(bool value) {
 }
 
 void CStaticThreadPool::shutdown() {
-    // Signal to each thread that it is finished.
-    for (auto& queue : m_TaskQueues) {
-        queue.push(TTask{[&] {
+
+    // Drain the queues before starting to shut down in order to maximise throughput.
+    this->drainQueuesWithoutBlocking();
+
+    // Signal to each thread that it is finished. We bind each task to a thread so
+    // so each thread executes exactly one shutdown task.
+    for (std::size_t id = 0; id < m_TaskQueues.size(); ++id) {
+        TTask done{[&] {
             m_Done = true;
             return boost::any{};
-        }});
+        }};
+        m_TaskQueues[id].push(CWrappedTask{std::move(done), id});
     }
+
     for (auto& thread : m_Pool) {
         if (thread.joinable()) {
            thread.join();
        }
    }
+
    m_TaskQueues.clear();
    m_Pool.clear();
 }
 
 void CStaticThreadPool::worker(std::size_t id) {
 
-    auto noThrowExecute = [](TOptionalTask& task) {
-        try {
-            (*task)();
-        } catch (const std::future_error& e) {
-            LOG_ERROR(<< "Failed executing packaged task: '" << e.code() << "' "
-                      << "with error '" << e.what() << "'");
-        }
+    auto ifAllowed = [id](const CWrappedTask& task) {
+        return task.executableOnThread(id);
     };
 
+    TOptionalTask task;
+
     while (m_Done == false) {
         // We maintain "worker count" queues and each worker has an affinity to a
         // different queue. We don't immediately block if the worker's "queue" is
@@ -101,9 +109,8 @@ void CStaticThreadPool::worker(std::size_t id) {
         // workers on queue reads.
 
         std::size_t size{m_TaskQueues.size()};
-        TOptionalTask task;
         for (std::size_t i = 0; i < size; ++i) {
-            task = m_TaskQueues[(id + i) % size].tryPop();
+            task = m_TaskQueues[(id + i) % size].tryPop(ifAllowed);
             if (task != boost::none) {
                 break;
             }
@@ -112,12 +119,48 @@ void CStaticThreadPool::worker(std::size_t id) {
             task = m_TaskQueues[id].pop();
         }
 
-        noThrowExecute(task);
+        (*task)();
+
+        // In the typical situation that the thread(s) adding tasks to the queues can
+        // do this much faster than the threads consuming them, all queues will be full
+        // and the producer(s) will be waiting to add a task as each one is consumed.
+        // By switching to work on a new queue here we minimise contention between the
+        // producers and consumers. Testing on bare metal (OSX) the overhead per task
+        // dropped from around 2.2 microseconds to 1.5 microseconds by yielding here.
+        std::this_thread::yield();
+    }
+}
+
+void CStaticThreadPool::drainQueuesWithoutBlocking() {
+    TOptionalTask task;
+    auto popTask = [&] {
+        for (auto& queue : m_TaskQueues) {
+            task = queue.tryPop();
+            if (task != boost::none) {
+                (*task)();
+                return true;
+            }
+        }
+        return false;
+    };
+    while (popTask()) {
     }
+}
+
+CStaticThreadPool::CWrappedTask::CWrappedTask(TTask&& task, TOptionalSize threadId)
+    : m_Task{std::forward<TTask>(task)}, m_ThreadId{threadId} {
+}
+
+bool CStaticThreadPool::CWrappedTask::executableOnThread(std::size_t id) const {
+    return m_ThreadId == boost::none || *m_ThreadId == id;
+}
 
-    // Drain this thread's queue before exiting.
-    for (auto task = m_TaskQueues[id].tryPop(); task; task = m_TaskQueues[id].tryPop()) {
-        noThrowExecute(task);
+void CStaticThreadPool::CWrappedTask::operator()() {
+    try {
+        m_Task();
+    } catch (const std::future_error& e) {
+        LOG_ERROR(<< "Failed executing packaged task: '" << e.code() << "' "
+                  << "with error '" << e.what() << "'");
     }
 }
 }
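
A usage sketch (not from the commit) of the resulting behaviour: work is scheduled through the std::function overload shown in the header diff, and destroying the pool triggers shutdown(), which drains the queues without blocking and then posts one thread-bound "done" task per worker so no worker is left blocked on pop(). The runSomeWork helper is invented for illustration.

// Illustrative sketch only; assumes the public API shown in the diffs above.
#include <core/CStaticThreadPool.h>

#include <atomic>

int runSomeWork() {
    std::atomic<int> counter{0};
    {
        ml::core::CStaticThreadPool pool{4};
        for (int i = 0; i < 1000; ++i) {
            pool.schedule([&counter] { counter.fetch_add(1); });
        }
        // Leaving this scope runs ~CStaticThreadPool -> shutdown(): the queues
        // are drained, then one "done" task bound to each worker's id is
        // pushed, so every worker observes m_Done and exits before join().
    }
    return counter.load();
}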

lib/core/unittest/CStaticThreadPoolTest.cc

Lines changed: 26 additions & 6 deletions

@@ -105,12 +105,10 @@ void CStaticThreadPoolTest::testThroughputStability() {
 
     CPPUNIT_ASSERT_EQUAL(2000u, counter.load());
 
-    // The best we can achieve is 2000ms ignoring all overheads. In fact, there will
-    // be imbalance in the queues when the pool shuts down which is then performed
-    // single threaded. Also there are other overheads.
+    // The best we can achieve is 2000ms ignoring all overheads.
     std::uint64_t totalTime{totalTimeWatch.stop()};
     LOG_DEBUG(<< "Total time = " << totalTime);
-    //CPPUNIT_ASSERT(totalTime <= 2600);
+    //CPPUNIT_ASSERT(totalTime <= 2400);
 }
 
 void CStaticThreadPoolTest::testManyTasksThroughput() {
@@ -149,7 +147,26 @@ void CStaticThreadPoolTest::testManyTasksThroughput() {
     //CPPUNIT_ASSERT(totalTime <= 780);
 }
 
-void CStaticThreadPoolTest::testExceptions() {
+void CStaticThreadPoolTest::testSchedulingOverhead() {
+
+    // Test the overhead per task is less than 1.6 microseconds.
+
+    core::CStaticThreadPool pool{4};
+
+    core::CStopWatch watch{true};
+    for (std::size_t i = 0; i < 1000000; ++i) {
+        if (i % 100000 == 0) {
+            LOG_DEBUG(<< i);
+        }
+        pool.schedule([]() {});
+    }
+
+    double overhead{static_cast<double>(watch.stop()) / 1000.0};
+    LOG_DEBUG(<< "Total time = " << overhead);
+    //CPPUNIT_ASSERT(overhead < 1.6);
+}
+
+void CStaticThreadPoolTest::testWithExceptions() {
 
     // Check we don't deadlock we don't kill worker threads if we do stupid things.
 
@@ -184,7 +201,10 @@ CppUnit::Test* CStaticThreadPoolTest::suite() {
         "CStaticThreadPoolTest::testManyTasksThroughput",
         &CStaticThreadPoolTest::testManyTasksThroughput));
     suiteOfTests->addTest(new CppUnit::TestCaller<CStaticThreadPoolTest>(
-        "CStaticThreadPoolTest::testExceptions", &CStaticThreadPoolTest::testExceptions));
+        "CStaticThreadPoolTest::testSchedulingOverhead",
+        &CStaticThreadPoolTest::testSchedulingOverhead));
+    suiteOfTests->addTest(new CppUnit::TestCaller<CStaticThreadPoolTest>(
+        "CStaticThreadPoolTest::testWithExceptions", &CStaticThreadPoolTest::testWithExceptions));
 
     return suiteOfTests;
 }
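
On the new test's arithmetic: assuming core::CStopWatch::stop() reports elapsed milliseconds, scheduling 1,000,000 empty tasks gives a per-task overhead of elapsed_ms x 1000 us/ms / 1,000,000 tasks = elapsed_ms / 1000 us, which is the quantity the commented-out 1.6 us bound would check.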

lib/core/unittest/CStaticThreadPoolTest.h

Lines changed: 2 additions & 1 deletion

@@ -14,7 +14,8 @@ class CStaticThreadPoolTest : public CppUnit::TestFixture {
     void testScheduleDelayMinimisation();
     void testThroughputStability();
     void testManyTasksThroughput();
-    void testExceptions();
+    void testSchedulingOverhead();
+    void testWithExceptions();
 
     static CppUnit::Test* suite();
 };
