
Commit f9a2f6a

fix accuracy bug when enabling multi-stream. (#191)
* fix the thread pool used for multi-stream execution
* fix a race on memory allocation
* update the way the thread pool is initialized
1 parent 59232b9 commit f9a2f6a

4 files changed: 27 additions & 16 deletions

4 files changed

+27
-16
lines changed

nlp_toolkit/backends/neural_engine/executor/include/memory_allocator.hpp

Lines changed: 2 additions & 0 deletions
@@ -193,6 +193,8 @@ class MemoryAllocator {
   }

   static void* GetMemory(size_t size, const int life_count) {
+    static std::mutex getmem_lock;
+    std::lock_guard<std::mutex> lock(getmem_lock);
     if (size == 0) {
       LOG(INFO) << "please set the tensor size...";
       return nullptr;
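
The fix serializes MemoryAllocator::GetMemory with a function-local static mutex: C++11 guarantees the static is initialized exactly once, so every stream contends on the same lock. A minimal sketch of the pattern, with a toy life-count table standing in for the allocator's real bookkeeping (names here are illustrative, not the engine's actual internals):

#include <cstdlib>
#include <mutex>
#include <unordered_map>

// Toy stand-in for the allocator's shared state; the real MemoryAllocator
// tracks buffers together with their remaining life counts.
static std::unordered_map<void*, int> life_table;

static void* GetMemoryGuarded(std::size_t size, int life_count) {
  static std::mutex getmem_lock;    // initialized once, shared by all callers
  std::lock_guard<std::mutex> lock(getmem_lock);
  if (size == 0) return nullptr;
  void* buf = std::malloc(size);
  life_table[buf] = life_count;     // concurrent map writes were the race
  return buf;
}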

nlp_toolkit/backends/neural_engine/executor/include/operator.hpp

Lines changed: 3 additions & 0 deletions
@@ -18,6 +18,7 @@
 #include <algorithm>
 #include <string>
 #include <vector>
+#include <mutex>

 #include "common.hpp"
 #include "operator_registry.hpp"
@@ -60,6 +61,8 @@ class Operator {
                        const vector<Tensor*>& output) = 0;

   inline void unref_tensors(const vector<Tensor*>& input) {
+    static std::mutex unref_lock;
+    std::lock_guard<std::mutex> lock(unref_lock);
     for (size_t i = 0; i < input.size(); ++i) {
       auto status = input[i]->unref_data();
       // (TODO) maybe check the tensors
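
The same pattern guards the tensor unref path: without it, two streams finishing operators that share an input tensor can interleave their reference-count updates. A hedged sketch with a toy tensor type (the real class is executor::Tensor, whose unref_data() manages buffer lifetime):

#include <mutex>
#include <vector>

struct ToyTensor {                           // illustrative stand-in for executor::Tensor
  int ref_count = 1;
  int unref_data() { return --ref_count; }   // plain int decrement: not thread-safe alone
};

inline void unref_tensors_guarded(const std::vector<ToyTensor*>& input) {
  static std::mutex unref_lock;              // one lock shared by every operator
  std::lock_guard<std::mutex> lock(unref_lock);
  for (size_t i = 0; i < input.size(); ++i) {
    int status = input[i]->unref_data();     // decrement now happens under the lock
    (void)status;                            // real code may free the buffer at zero
  }
}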

nlp_toolkit/backends/neural_engine/executor/include/thread_pool.hpp

Lines changed: 12 additions & 12 deletions
@@ -43,7 +43,7 @@ class ThreadPool {
   std::queue<Task> tasks;
   // sync of multi stream
   std::mutex tasks_lock;
-  std::condition_variable task_cond_var;
+  std::condition_variable task_cond_var, cond_finish_;
   std::atomic<unsigned int> idle_thread_num;
   std::atomic<unsigned int> work_thread_num;
   // is thread pool stoped
@@ -54,7 +54,7 @@ class ThreadPool {
     unsigned int index = pool.size();
     stoped.emplace_back(false);
     idle_thread_num++;
-    pool.emplace_back([this, index] {
+    pool.emplace_back(std::thread([this, index] {
       while (true) {
         // capature the task
         std::function<void()> task;
@@ -67,28 +67,33 @@ class ThreadPool {
           idle_thread_num--;
           return;
         }
+        idle_thread_num--; work_thread_num++;
         task = std::move(this->tasks.front());
         this->tasks.pop();
       }

       {
-        idle_thread_num--, work_thread_num++;
         task();  // run the task
-        idle_thread_num++, work_thread_num--;
+        std::unique_lock<std::mutex> lock(this->tasks_lock);
+        idle_thread_num++; work_thread_num--;
+        cond_finish_.notify_one();
       }
     }
-    });
+    }));
   }

 public:
   inline ThreadPool() {
     work_thread_num = 0;
+    idle_thread_num = 0;
     pool_stoped = false;
   }
   // wait for all threads to finish and stop all threads
   inline ~ThreadPool() {
+    std::unique_lock<std::mutex> lock(tasks_lock);
     for (auto item : stoped) item = true;
     task_cond_var.notify_all();  // wake up all thread to run
+    lock.unlock();
     for (auto& th : pool) {
       if (th.joinable()) th.join();
     }
@@ -161,13 +166,8 @@ class ThreadPool {
   bool hasStopedPool() { return pool_stoped; }
   // wait for all tasks to be completed
   void waitAllTaskRunOver() {
-    while (true) {
-      if (work_thread_num == 0) {
-        return;
-      } else {
-        std::this_thread::yield();
-      }
-    }
+    std::unique_lock<std::mutex> lock(tasks_lock);
+    cond_finish_.wait(lock, [this]{ return tasks.empty() && (work_thread_num == 0); });
   }
 };
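
This is the heart of the accuracy fix. The old waitAllTaskRunOver() spun on work_thread_num, which could read zero in the window after a task was dequeued but before the worker bumped the counter, and it never checked the queue, so Forward could continue while stream tasks were still pending or in flight. The new code does the idle/work bookkeeping under tasks_lock and blocks on a second condition variable until the queue is empty and no worker is busy. A compact, self-contained sketch of that pattern (MiniPool is illustrative, not the engine's ThreadPool):

#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

class MiniPool {
  std::mutex m;
  std::condition_variable task_cv, finish_cv;
  std::queue<std::function<void()>> tasks;
  std::vector<std::thread> workers;
  unsigned working = 0;
  bool stop = false;

 public:
  explicit MiniPool(unsigned n) {
    for (unsigned i = 0; i < n; ++i) {
      workers.emplace_back([this] {
        for (;;) {
          std::function<void()> task;
          {
            std::unique_lock<std::mutex> lk(m);
            task_cv.wait(lk, [this] { return stop || !tasks.empty(); });
            if (stop && tasks.empty()) return;
            ++working;                   // claimed under the lock: no empty-queue window
            task = std::move(tasks.front());
            tasks.pop();
          }
          task();
          {
            std::lock_guard<std::mutex> lk(m);
            --working;
            finish_cv.notify_one();      // wake anyone blocked in waitAll()
          }
        }
      });
    }
  }
  void commit(std::function<void()> t) {
    { std::lock_guard<std::mutex> lk(m); tasks.push(std::move(t)); }
    task_cv.notify_one();
  }
  void waitAll() {                       // blocks instead of spinning on an atomic
    std::unique_lock<std::mutex> lk(m);
    finish_cv.wait(lk, [this] { return tasks.empty() && working == 0; });
  }
  ~MiniPool() {
    { std::lock_guard<std::mutex> lk(m); stop = true; }
    task_cv.notify_all();
    for (auto& w : workers) w.join();
  }
};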

nlp_toolkit/backends/neural_engine/executor/src/model.cpp

Lines changed: 10 additions & 4 deletions
@@ -93,6 +93,16 @@ void Model::Init(const ModelConfig& conf) {
         multi_stream_tasks_.insert({i, StringToNum<int64_t>(it->second)});
       }
     }
+    auto max_tasks = std::max_element(multi_stream_tasks_.begin(), multi_stream_tasks_.end(),
+                                      [](const std::pair<int, int64_t>& a, const std::pair<int, int64_t>& b)
+                                          -> bool { return a.second < b.second; });
+    int tp_max_threads = max_tasks->second + (max_tasks->second & 1);
+    int total_available_threads = omp_get_num_procs();
+    tp_max_threads = tp_max_threads > total_available_threads ?
+                     total_available_threads : tp_max_threads;
+    tp.begin(tp_max_threads);
+    LOG(INFO) << "Thread pool is initialized with " << tp_max_threads << " threads. (" <<
+              "Total avaiable threads: " << total_available_threads << ")";
   }

   engine_profiling_ = (getenv("ENGINE_PROFILING") != NULL);  // profiling env
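
Init now sizes the pool once: take the largest per-stream task count, round it up to an even number with n + (n & 1), and cap it at the processor count, so the pool is never resized during Forward. A standalone illustration of the arithmetic (using std::thread::hardware_concurrency() in place of omp_get_num_procs() so it builds without OpenMP):

#include <algorithm>
#include <cstdio>
#include <thread>

int main() {
  int max_stream_tasks = 5;                        // e.g. the largest entry in multi_stream_tasks_
  int tp_max_threads = max_stream_tasks + (max_stream_tasks & 1);     // 5 -> 6: round up to even
  int total = static_cast<int>(std::thread::hardware_concurrency());  // stand-in for omp_get_num_procs()
  tp_max_threads = std::min(tp_max_threads, total);                   // never oversubscribe the machine
  std::printf("pool size: %d (of %d procs)\n", tp_max_threads, total);
  return 0;
}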
@@ -280,15 +290,13 @@ vector<Tensor>& Model::Forward(vector<Tensor>& input_data) {
     for (int i = 0; i < operators_.size(); ++i) {
       LOG(INFO) << "operator " << operators_[i]->name() << " gonna forward with type " << operators_[i]->type();
       if (multi_stream_flag && multi_stream_tasks_.find(i) != multi_stream_tasks_.end()) {
-        tp.resize(thread_count);
         float start = Time("start");
         tp.commitTask(std::bind(&executor::Dispatcher::Forward, operators_[i], input_vecs_[i], output_vecs_[i]));
         float end = Time("end");
         operators_[i]->set_latency(end - start);
         LOG(INFO) << "operator: " << operators_[i]->name() << ", latency: " << end - start << " ms";
         if (thread_count >= multi_stream_tasks_[i]) {
           tp.waitAllTaskRunOver();
-          tp.close();
           thread_count = 0;
         }
         thread_count++;
@@ -304,11 +312,9 @@ vector<Tensor>& Model::Forward(vector<Tensor>& input_data) {
     for (int i = 0; i < operators_.size(); ++i) {
       LOG(INFO) << "operator " << operators_[i]->name() << " gonna forward with type " << operators_[i]->type();
       if (multi_stream_flag && multi_stream_tasks_.find(i) != multi_stream_tasks_.end()) {
-        tp.resize(thread_count);
         tp.commitTask(std::bind(&executor::Dispatcher::Forward, operators_[i], input_vecs_[i], output_vecs_[i]));
         if (thread_count >= multi_stream_tasks_[i]) {
           tp.waitAllTaskRunOver();
-          tp.close();
           thread_count = 0;
         }
         thread_count++;
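
With the pool built once in Init, Forward no longer calls tp.resize() before each commit or tp.close() after each drain; it just commits tasks and blocks in waitAllTaskRunOver() whenever a stream's task quota is reached. A sketch of that dispatch shape, reusing the MiniPool from the thread_pool.hpp notes above (operator tasks reduced to plain callables):

#include <functional>
#include <vector>

// Commit up to `stream_tasks` jobs, then drain the pool and start the next
// batch; mirrors the multi-stream branch of Model::Forward after the change.
void run_multi_stream(MiniPool& pool, const std::vector<std::function<void()>>& ops,
                      int stream_tasks) {
  int thread_count = 0;
  for (const auto& op : ops) {
    pool.commit(op);                 // was: tp.resize(...) then commitTask(...)
    if (++thread_count >= stream_tasks) {
      pool.waitAll();                // was: waitAllTaskRunOver() then tp.close()
      thread_count = 0;
    }
  }
  pool.waitAll();                    // drain any tail tasks
}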
