Commit e71329b

vulkan: sort graph to allow more parallel execution

Add a backend proc to allow the backend to modify the graph. The vulkan implementation looks at which nodes depend on each other and greedily reorders them to group together nodes that don't depend on each other. It only reorders the nodes; it doesn't change the contents of any of them. With ggml-org#15489, this reduces the number of synchronizations needed.

1 parent c4df49a
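The greedy grouping described above can be illustrated with a simplified, standalone sketch. The `ToyNode` type and `greedy_reorder` helper are illustrative only (they are not ggml structures), and the sketch omits the view-node handling, fusion exceptions, and small-batch heuristic of the actual Vulkan implementation shown further down:

```cpp
// Toy model of the greedy reordering: nodes are identified by index and each
// node lists the indices of the earlier nodes it reads from. Nodes that appear
// close together and don't depend on each other are grouped so they could be
// dispatched back to back.
#include <algorithm>
#include <cstdio>
#include <vector>

struct ToyNode {
    const char *     name;
    std::vector<int> deps; // indices of earlier nodes this node reads from
};

std::vector<int> greedy_reorder(const std::vector<ToyNode> & nodes) {
    const int n = (int) nodes.size();
    std::vector<int>  order;
    std::vector<bool> used(n, false);
    int first_unused = 0;

    while (first_unused < n) {
        // start a new set with the next node that hasn't been emitted yet
        std::vector<int> current_set = { first_unused };

        // grab any later node whose not-yet-emitted predecessors are not among
        // the nodes between first_unused and j; such a node can join the set
        for (int j = first_unused + 1; j < n; ++j) {
            if (used[j]) {
                continue;
            }
            bool ok = true;
            for (int c = first_unused; c < j && ok; ++c) {
                if (!used[c] && std::find(nodes[j].deps.begin(), nodes[j].deps.end(), c) != nodes[j].deps.end()) {
                    ok = false;
                }
            }
            if (ok) {
                current_set.push_back(j);
            }
        }

        for (int c : current_set) {
            order.push_back(c);
            used[c] = true;
        }
        while (first_unused < n && used[first_unused]) {
            first_unused++;
        }
    }
    return order;
}

int main() {
    // two independent branches (A->B and C->D) followed by a node that combines them
    std::vector<ToyNode> nodes = {
        { "A", {} },        // 0
        { "B", { 0 } },     // 1: depends on A
        { "C", {} },        // 2: independent of A/B
        { "D", { 2 } },     // 3: depends on C
        { "E", { 1, 3 } },  // 4: combines both branches
    };
    for (int i : greedy_reorder(nodes)) {
        printf("%s ", nodes[i].name);
    }
    printf("\n"); // prints: A C B D E
}
```

Nodes A and C have no dependency on each other, so they end up adjacent in the output order; the same goes for B and D.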

File tree: 3 files changed, +152 -0 lines

ggml/src/ggml-backend-impl.h (3 additions, 0 deletions)

@@ -114,6 +114,9 @@ extern "C" {
         void (*event_record)(ggml_backend_t backend, ggml_backend_event_t event);
         // wait for an event on on a different stream
         void (*event_wait)  (ggml_backend_t backend, ggml_backend_event_t event);
+
+        // (optional) sort/optimize the nodes in the graph
+        void (*optimize_graph) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
     };

     struct ggml_backend {
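Backends opt in by filling the new, optional member of their `ggml_backend_i` table; backends that leave it NULL are skipped by the wrapper added in ggml-backend.cpp below. A minimal sketch of what such a hook could look like for a hypothetical backend (the `my_backend_*` name is illustrative, not part of this commit):

```cpp
#include "ggml-backend-impl.h"   // ggml_backend_t, ggml_backend_i
#include "ggml-impl.h"           // struct ggml_cgraph (internal definition)

// The hook may reorder cgraph->nodes in place, but must not add, remove,
// or rewrite the nodes themselves (mirroring what the Vulkan pass does).
void my_backend_optimize_graph(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
    (void) backend;
    if (cgraph->n_nodes < 2) {
        return; // nothing to reorder
    }
    // ... inspect dependencies and permute cgraph->nodes[0..n_nodes) here ...
}

// wired into the backend's interface table next to the existing entries, e.g.:
//     /* .event_wait     = */ my_backend_event_wait,
//     /* .optimize_graph = */ my_backend_optimize_graph,
```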

ggml/src/ggml-backend.cpp (19 additions, 0 deletions)

@@ -463,6 +463,13 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event)
     backend->iface.event_wait(backend, event);
 }

+void ggml_backend_optimize_graph(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    GGML_ASSERT(backend);
+    if (backend->iface.optimize_graph != NULL) {
+        backend->iface.optimize_graph(backend, cgraph);
+    }
+}
+
 // Backend device

 const char * ggml_backend_dev_name(ggml_backend_dev_t device) {

@@ -1702,6 +1709,16 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
     return true;
 }

+void ggml_backend_sched_optimize_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+    GGML_ASSERT(sched);
+    // Run through each backend, before splitting, giving a chance to optimize.
+    // Would be better to have each backend optimize its own split, but sched->graph
+    // gets out of sync.
+    for (int i = 0; i < sched->n_backends; i++) {
+        ggml_backend_optimize_graph(sched->backends[i], graph);
+    }
+}
+
 bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
     GGML_ASSERT(sched);
     GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs);

@@ -1710,6 +1727,8 @@ bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgra
     sched->cur_copy = sched->next_copy;
     sched->next_copy = (sched->next_copy + 1) % sched->n_copies;

+    ggml_backend_sched_optimize_graph(sched, graph);
+
     ggml_backend_sched_split_graph(sched, graph);

     if (!ggml_backend_sched_alloc_splits(sched)) {
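With these two helpers, `ggml_backend_sched_alloc_graph` gives every backend a chance to reorder the whole graph in place just before it is split into per-backend splits. A rough usage sketch for a single Vulkan backend follows; the setup and the direct call are for illustration only (the function is internal to ggml-backend.cpp and not declared in a public header, so the forward declaration below is an assumption of this sketch; in normal use the scheduler calls it for you):

```cpp
#include "ggml.h"
#include "ggml-backend.h"
#include "ggml-vulkan.h"

// not in a public header; forward-declared here only for this sketch
void ggml_backend_optimize_graph(ggml_backend_t backend, struct ggml_cgraph * cgraph);

int main() {
    ggml_backend_t backend = ggml_backend_vk_init(0);

    struct ggml_init_params params = {
        /* .mem_size   = */ 16 * 1024 * 1024,
        /* .mem_buffer = */ NULL,
        /* .no_alloc   = */ true,
    };
    struct ggml_context * ctx = ggml_init(params);

    // two independent matmuls feeding one add: the kind of pattern the
    // Vulkan pass can group so both matmuls run without a barrier between them
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);
    struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 1);
    struct ggml_tensor * y = ggml_add(ctx, ggml_mul_mat(ctx, a, x), ggml_mul_mat(ctx, b, x));

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);

    // no-op for backends whose iface.optimize_graph is NULL
    ggml_backend_optimize_graph(backend, gf);

    ggml_free(ctx);
    ggml_backend_free(backend);
}
```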

ggml/src/ggml-vulkan/ggml-vulkan.cpp (130 additions, 0 deletions)

@@ -582,6 +582,7 @@ struct vk_device_struct {
     bool disable_fusion;
     bool disable_host_visible_vidmem;
     bool allow_sysmem_fallback;
+    bool disable_optimize_graph;

 #ifdef GGML_VULKAN_MEMORY_DEBUG
     std::unique_ptr<vk_memory_logger> memory_logger;

@@ -3502,6 +3503,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
        const char* GGML_VK_ALLOW_SYSMEM_FALLBACK = getenv("GGML_VK_ALLOW_SYSMEM_FALLBACK");
        device->allow_sysmem_fallback = GGML_VK_ALLOW_SYSMEM_FALLBACK != nullptr;

+        const char* GGML_VK_DISABLE_OPTIMIZE_GRAPH = getenv("GGML_VK_DISABLE_OPTIMIZE_GRAPH");
+        device->disable_optimize_graph = GGML_VK_DISABLE_OPTIMIZE_GRAPH != nullptr;
+
        bool fp16_storage = false;
        bool fp16_compute = false;
        bool maintenance4_support = false;
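As with the other `GGML_VK_*` switches read here, the new pass can be disabled by defining `GGML_VK_DISABLE_OPTIMIZE_GRAPH` (to any value) before the device is created, which is handy when benchmarking the effect of the reordering. A small sketch doing this programmatically on a POSIX system; normally you would just export the variable in the shell:

```cpp
#include <cstdlib>        // setenv (POSIX)
#include "ggml-vulkan.h"

int main() {
    // the code above only checks that getenv() returns non-NULL,
    // so any value disables the reordering pass
    setenv("GGML_VK_DISABLE_OPTIMIZE_GRAPH", "1", /*overwrite=*/1);

    // the flag is read while the Vulkan device is initialized
    ggml_backend_t backend = ggml_backend_vk_init(0);
    // ... build and compute graphs as usual ...
    ggml_backend_free(backend);
}
```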
@@ -11633,6 +11637,131 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
     UNUSED(backend);
 }

+// Sort the graph for improved parallelism.
+static void ggml_vk_optimize_graph(ggml_backend_t backend, struct ggml_cgraph * graph)
+{
+    VK_LOG_DEBUG("ggml_vk_optimize_graph(" << graph->n_nodes << " nodes)");
+    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
+
+    if (ctx->device->disable_optimize_graph) {
+        return;
+    }
+
+    auto const &is_empty = [](ggml_tensor * node) -> bool {
+        return node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
+    };
+
+    auto const &is_src_of = [](const ggml_tensor *dst, const ggml_tensor *src) -> bool {
+        for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
+            if (dst->src[s] == src) {
+                return true;
+            }
+        }
+        // implicit dependency if they view the same tensor
+        const ggml_tensor *dst2 = dst->view_src ? dst->view_src : dst;
+        const ggml_tensor *src2 = src->view_src ? src->view_src : src;
+        if (dst2 == src2) {
+            return true;
+        }
+        return false;
+    };
+
+    // This function tries to reorder the graph to allow nodes to run in parallel.
+    // This helps with small batches, but for large batches it's a slowdown, probably
+    // due to cache contention. So only reorder if the majority of nodes have few rows.
+    int num_small_nodes = 0;
+    int num_counted_nodes = 0;
+    for (int i = 0; i < graph->n_nodes; ++i) {
+        if (!is_empty(graph->nodes[i]) &&
+            graph->nodes[i]->op != GGML_OP_SET_ROWS) {
+            if (ggml_nrows(graph->nodes[i]) <= 8) {
+                num_small_nodes++;
+            }
+            num_counted_nodes++;
+        }
+    }
+    if (num_small_nodes < num_counted_nodes / 2) {
+        return;
+    }
+
+    std::vector<ggml_tensor *> new_order;
+    std::vector<bool> used(graph->n_nodes, false);
+    int first_unused = 0;
+    while (first_unused < graph->n_nodes) {
+        std::vector<int> current_set;
+
+        // First, grab the next unused node.
+        current_set.push_back(first_unused);
+
+        // Loop through the next N nodes. Grab any that don't depend on other nodes that
+        // haven't already been run. Nodes that have already been run have used[i] set
+        // to true. Allow nodes that depend on the previous node if it's a fusion pattern
+        // that we support (e.g. RMS_NORM + MUL).
+        // This first pass only grabs "real" (non-view) nodes. Second pass grabs view nodes.
+        // The goal is to not interleave real and view nodes in a way that breaks fusion.
+        const int NUM_TO_CHECK = 20;
+        for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) {
+            if (used[j]) {
+                continue;
+            }
+            if (is_empty(graph->nodes[j])) {
+                continue;
+            }
+            bool ok = true;
+            for (int c = first_unused; c < j; ++c) {
+                if (!used[c] &&
+                    is_src_of(graph->nodes[j], graph->nodes[c]) &&
+                    !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL)) {
+                    ok = false;
+                    break;
+                }
+            }
+            if (ok) {
+                current_set.push_back(j);
+            }
+        }
+        // Second pass grabs view nodes.
+        // Skip this if it would break a fusion optimization (don't split up add->rms_norm or add->add).
+        if (graph->nodes[current_set.back()]->op != GGML_OP_ADD) {
+            for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) {
+                if (used[j]) {
+                    continue;
+                }
+                if (!is_empty(graph->nodes[j])) {
+                    continue;
+                }
+                bool ok = true;
+                for (int c = first_unused; c < j; ++c) {
+                    bool c_in_current_set = std::find(current_set.begin(), current_set.end(), c) != current_set.end();
+                    // skip views whose srcs haven't been processed.
+                    if (!used[c] &&
+                        is_src_of(graph->nodes[j], graph->nodes[c]) &&
+                        !c_in_current_set) {
+                        ok = false;
+                        break;
+                    }
+                }
+                if (ok) {
+                    current_set.push_back(j);
+                }
+            }
+        }
+
+        // Push the current set into new_order
+        for (auto c : current_set) {
+            new_order.push_back(graph->nodes[c]);
+            used[c] = true;
+        }
+        while (first_unused < graph->n_nodes && used[first_unused]) {
+            first_unused++;
+        }
+    }
+    // Replace the graph with the new order.
+    for (size_t i = 0; i < graph->n_nodes; ++i) {
+        graph->nodes[i] = new_order[i];
+    }
+}
+
 // TODO: enable async and synchronize
 static ggml_backend_i ggml_backend_vk_interface = {
     /* .get_name = */ ggml_backend_vk_name,
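To make the effect concrete: for a node order A, B, C, D where B reads A, D reads C, C reads neither A nor B, and none of them are view/reshape-type nodes, the first greedy set collects A and then C (none of the not-yet-used nodes ahead of C are its sources), the second set collects B and D, and the final order is A, C, B, D. Combined with the synchronization changes referenced in ggml-org#15489, the adjacent independent nodes can then be dispatched without a barrier between them. The RMS_NORM + MUL exception and the trailing-ADD check exist so that this regrouping does not separate node pairs the Vulkan backend would otherwise fuse.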
@@ -11648,6 +11777,7 @@ static ggml_backend_i ggml_backend_vk_interface = {
     /* .graph_compute = */ ggml_backend_vk_graph_compute,
     /* .event_record = */ NULL,
     /* .event_wait = */ NULL,
+    /* .optimize_graph = */ ggml_vk_optimize_graph,
 };

 static ggml_guid_t ggml_backend_vk_guid() {
