@@ -582,6 +582,7 @@ struct vk_device_struct {
     bool disable_fusion;
     bool disable_host_visible_vidmem;
     bool allow_sysmem_fallback;
+    bool disable_optimize_graph;
 
 #ifdef GGML_VULKAN_MEMORY_DEBUG
     std::unique_ptr<vk_memory_logger> memory_logger;
@@ -3502,6 +3503,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
     const char* GGML_VK_ALLOW_SYSMEM_FALLBACK = getenv("GGML_VK_ALLOW_SYSMEM_FALLBACK");
     device->allow_sysmem_fallback = GGML_VK_ALLOW_SYSMEM_FALLBACK != nullptr;
 
+    const char* GGML_VK_DISABLE_OPTIMIZE_GRAPH = getenv("GGML_VK_DISABLE_OPTIMIZE_GRAPH");
+    device->disable_optimize_graph = GGML_VK_DISABLE_OPTIMIZE_GRAPH != nullptr;
+
     bool fp16_storage = false;
     bool fp16_compute = false;
     bool maintenance4_support = false;
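Note: like the existing GGML_VK_* toggles above it, the new variable is presence-only. device->disable_optimize_graph is set whenever GGML_VK_DISABLE_OPTIMIZE_GRAPH exists in the environment, regardless of its value, so running with GGML_VK_DISABLE_OPTIMIZE_GRAPH=0 disables the sort just as GGML_VK_DISABLE_OPTIMIZE_GRAPH=1 does.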
@@ -11633,6 +11637,131 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
     UNUSED(backend);
 }
 
+// Sort the graph for improved parallelism.
+static void ggml_vk_optimize_graph(ggml_backend_t backend, struct ggml_cgraph * graph)
+{
+    VK_LOG_DEBUG("ggml_vk_optimize_graph(" << graph->n_nodes << " nodes)");
+    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
+
+    if (ctx->device->disable_optimize_graph) {
+        return;
+    }
+
+    auto const &is_empty = [](ggml_tensor * node) -> bool {
+        return node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
+    };
+
+    auto const &is_src_of = [](const ggml_tensor *dst, const ggml_tensor *src) -> bool {
+        for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
+            if (dst->src[s] == src) {
+                return true;
+            }
+        }
+        // implicit dependency if they view the same tensor
+        const ggml_tensor *dst2 = dst->view_src ? dst->view_src : dst;
+        const ggml_tensor *src2 = src->view_src ? src->view_src : src;
+        if (dst2 == src2) {
+            return true;
+        }
+        return false;
+    };
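Note: is_src_of is deliberately conservative. Beyond explicit src edges, any two nodes whose view_src chains bottom out in the same tensor are kept in their original order, whether or not an actual read/write hazard exists. A minimal self-contained sketch of the aliasing half of the check; the Tensor struct below is a hypothetical stand-in for ggml_tensor, keeping only the fields the check reads:

    #include <cstdio>

    struct Tensor {
        const Tensor *src[2]   = {nullptr, nullptr}; // explicit inputs (GGML_MAX_SRC is larger in ggml)
        const Tensor *view_src = nullptr;            // underlying tensor if this node is a view
    };

    static bool is_src_of(const Tensor *dst, const Tensor *src) {
        for (int s = 0; s < 2; ++s) {
            if (dst->src[s] == src) return true;
        }
        // implicit dependency: both alias the same underlying tensor
        const Tensor *dst2 = dst->view_src ? dst->view_src : dst;
        const Tensor *src2 = src->view_src ? src->view_src : src;
        return dst2 == src2;
    }

    int main() {
        Tensor base;
        Tensor view_a; view_a.view_src = &base;
        Tensor view_b; view_b.view_src = &base;
        // No explicit edge between the two views, but reordering them past
        // each other could be unsafe, so they are treated as dependent.
        printf("%d\n", is_src_of(&view_a, &view_b)); // prints 1
        return 0;
    }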
+
+    // This function tries to reorder the graph to allow nodes to run in parallel.
+    // This helps with small batches, but for large batches it's a slowdown, probably
+    // due to cache contention. So only reorder if the majority of nodes have few rows.
+    int num_small_nodes = 0;
+    int num_counted_nodes = 0;
+    for (int i = 0; i < graph->n_nodes; ++i) {
+        if (!is_empty(graph->nodes[i]) &&
+            graph->nodes[i]->op != GGML_OP_SET_ROWS) {
+            if (ggml_nrows(graph->nodes[i]) <= 8) {
+                num_small_nodes++;
+            }
+            num_counted_nodes++;
+        }
+    }
+    if (num_small_nodes < num_counted_nodes / 2) {
+        return;
+    }
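For illustration, this gate reduces to a majority-small test over the counted (non-empty, non-SET_ROWS) nodes. A hypothetical standalone sketch, with plain row counts standing in for ggml_nrows():

    #include <cstdio>
    #include <vector>

    // Returns true when the reorder would run: at least half (integer
    // division, as above) of the counted nodes have <= 8 rows.
    static bool should_reorder(const std::vector<long> &rows_per_node) {
        int num_small = 0;
        for (long r : rows_per_node) {
            if (r <= 8) num_small++;
        }
        int counted = (int) rows_per_node.size();
        return num_small >= counted / 2;
    }

    int main() {
        // Single-token decode: every node has one row -> reorder runs.
        printf("%d\n", should_reorder(std::vector<long>(30, 1)));   // 1
        // Large prompt batch: rows well above 8 -> keep original order.
        printf("%d\n", should_reorder(std::vector<long>(30, 512))); // 0
        return 0;
    }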
+
+    std::vector<ggml_tensor *> new_order;
+    std::vector<bool> used(graph->n_nodes, false);
+    int first_unused = 0;
+    while (first_unused < graph->n_nodes) {
+        std::vector<int> current_set;
+
+        // First, grab the next unused node.
+        current_set.push_back(first_unused);
+
+        // Loop through the next N nodes. Grab any that don't depend on other nodes that
+        // haven't already been run. Nodes that have already been run have used[i] set
+        // to true. Allow nodes that depend on the previous node if it's a fusion pattern
+        // that we support (e.g. RMS_NORM + MUL).
+        // This first pass only grabs "real" (non-view) nodes. A second pass grabs view nodes.
+        // The goal is to not interleave real and view nodes in a way that breaks fusion.
+        const int NUM_TO_CHECK = 20;
+        for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) {
+            if (used[j]) {
+                continue;
+            }
+            if (is_empty(graph->nodes[j])) {
+                continue;
+            }
+            bool ok = true;
+            for (int c = first_unused; c < j; ++c) {
+                if (!used[c] &&
+                    is_src_of(graph->nodes[j], graph->nodes[c]) &&
+                    !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL)) {
+                    ok = false;
+                    break;
+                }
+            }
+            if (ok) {
+                current_set.push_back(j);
+            }
+        }
+        // Second pass grabs view nodes.
+        // Skip this if it would break a fusion optimization (don't split up add->rms_norm or add->add).
+        if (graph->nodes[current_set.back()]->op != GGML_OP_ADD) {
+            for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) {
+                if (used[j]) {
+                    continue;
+                }
+                if (!is_empty(graph->nodes[j])) {
+                    continue;
+                }
+                bool ok = true;
+                for (int c = first_unused; c < j; ++c) {
+                    bool c_in_current_set = std::find(current_set.begin(), current_set.end(), c) != current_set.end();
+                    // skip views whose srcs haven't been processed.
+                    if (!used[c] &&
+                        is_src_of(graph->nodes[j], graph->nodes[c]) &&
+                        !c_in_current_set) {
+                        ok = false;
+                        break;
+                    }
+                }
+                if (ok) {
+                    current_set.push_back(j);
+                }
+            }
+        }
+
+        // Push the current set into new_order
+        for (auto c : current_set) {
+            new_order.push_back(graph->nodes[c]);
+            used[c] = true;
+        }
+        while (first_unused < graph->n_nodes && used[first_unused]) {
+            first_unused++;
+        }
+    }
+    // Replace the graph with the new order.
+    for (int i = 0; i < graph->n_nodes; ++i) {
+        graph->nodes[i] = new_order[i];
+    }
+}
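For illustration, here is a hypothetical, stripped-down model of the greedy loop above: only the explicit dependency check, with the view second pass and the RMS_NORM+MUL fusion exception omitted. dep[j] lists the indices that node j reads from:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
        // Toy graph: nodes 0 and 2 are independent; 1 reads 0; 3 reads 2.
        std::vector<std::vector<int>> dep = { {}, {0}, {}, {2} };
        const int n = (int) dep.size();
        const int NUM_TO_CHECK = 20;

        std::vector<int> new_order;
        std::vector<bool> used(n, false);
        int first_unused = 0;
        while (first_unused < n) {
            std::vector<int> current_set = { first_unused };
            // Pull forward any node whose pending predecessors it doesn't read.
            for (int j = first_unused + 1; j < std::min(first_unused + NUM_TO_CHECK, n); ++j) {
                if (used[j]) continue;
                bool ok = true;
                for (int c = first_unused; c < j; ++c) {
                    bool j_reads_c = std::find(dep[j].begin(), dep[j].end(), c) != dep[j].end();
                    if (!used[c] && j_reads_c) { ok = false; break; }
                }
                if (ok) current_set.push_back(j);
            }
            for (int c : current_set) { new_order.push_back(c); used[c] = true; }
            while (first_unused < n && used[first_unused]) first_unused++;
        }
        for (int c : new_order) printf("%d ", c); // prints: 0 2 1 3
        printf("\n");
        return 0;
    }

In the toy graph, node 1 (which reads node 0) originally sits between the two independent chain heads 0 and 2; the sort pulls 0 and 2 together (output: 0 2 1 3), which is exactly the kind of adjacency of independent work that the comment at the top of the function is after.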
+
 // TODO: enable async and synchronize
 static ggml_backend_i ggml_backend_vk_interface = {
     /* .get_name = */ ggml_backend_vk_name,
@@ -11648,6 +11777,7 @@ static ggml_backend_i ggml_backend_vk_interface = {
     /* .graph_compute = */ ggml_backend_vk_graph_compute,
     /* .event_record = */ NULL,
     /* .event_wait = */ NULL,
+    /* .optimize_graph = */ ggml_vk_optimize_graph,
 };
 
 static ggml_guid_t ggml_backend_vk_guid() {
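The iface table entry is what exposes the sort to the rest of ggml: with optimize_graph now non-NULL, generic backend code can run the hook before graph execution. A minimal sketch of such a call site, assuming the standard ggml_backend layout with an iface member holding this table:

    // Hypothetical caller: invoke the backend's optimize hook if it provides one.
    if (backend->iface.optimize_graph != NULL) {
        backend->iface.optimize_graph(backend, graph);
    }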