diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index f1b740785914e..48ab40e678073 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -93,8 +93,11 @@ extern "C" { GGML_API void ggml_backend_synchronize(ggml_backend_t backend); + GGML_API bool ggml_backend_supports_graph_plan(ggml_backend_t backend); + GGML_API bool ggml_backend_supports_graph_plan_update(ggml_backend_t backend); GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph); GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan); + GGML_API void ggml_backend_graph_plan_update(ggml_backend_t backend, ggml_backend_graph_plan_t plan, const struct ggml_cgraph * cgraph); GGML_API enum ggml_status ggml_backend_graph_plan_compute (ggml_backend_t backend, ggml_backend_graph_plan_t plan); GGML_API enum ggml_status ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph); diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index ff9135fe2d878..718a688454c9a 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -327,6 +327,18 @@ void ggml_backend_synchronize(ggml_backend_t backend) { backend->iface.synchronize(backend); } +bool ggml_backend_supports_graph_plan(ggml_backend_t backend) { + GGML_ASSERT(backend); + + return (bool) backend->iface.graph_plan_create; +} + +bool ggml_backend_supports_graph_plan_update(ggml_backend_t backend) { + GGML_ASSERT(backend); + + return (bool) backend->iface.graph_plan_update; +} + ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) { GGML_ASSERT(backend); GGML_ASSERT(backend->iface.graph_plan_create != NULL); @@ -341,6 +353,13 @@ void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_pla backend->iface.graph_plan_free(backend, plan); } +void ggml_backend_graph_plan_update(ggml_backend_t backend, ggml_backend_graph_plan_t plan, const struct ggml_cgraph* cgraph) { + GGML_ASSERT(backend); + GGML_ASSERT(backend->iface.graph_plan_update != NULL); + + backend->iface.graph_plan_update(backend, plan, cgraph); +} + enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { GGML_ASSERT(backend); GGML_ASSERT(backend->iface.graph_plan_compute != NULL); @@ -675,6 +694,11 @@ struct ggml_backend_sched_split { struct ggml_cgraph graph; }; +struct ggml_backend_sched_plan { + int backend_id; + ggml_backend_graph_plan_t plan; +}; + struct ggml_backend_sched { bool is_reset; // true if the scheduler has been reset since the last graph split bool is_alloc; @@ -704,6 +728,12 @@ struct ggml_backend_sched { int n_splits; int splits_capacity; + // graph plans + struct ggml_backend_sched_plan * plans; + int n_plans; + int plans_capacity; + bool plan_needs_update; + // pipeline parallelism support int n_copies; int cur_copy; @@ -908,6 +938,16 @@ static void ggml_backend_sched_set_if_supported(ggml_backend_sched_t sched, stru } } +static void ggml_backend_sched_free_plans(ggml_backend_sched_t sched) { + for (int i = 0; i < sched->n_plans; i++) { + ggml_backend_t backend = sched->backends[sched->plans[i].backend_id]; + if (ggml_backend_supports_graph_plan(backend)) { + ggml_backend_graph_plan_free(backend, sched->plans[i].plan); + } + } + sched->n_plans = 0; +} + // assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend void 
ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { // reset splits @@ -1372,6 +1412,7 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra assert(graph_copy->size > graph_copy->n_leafs); graph_copy->leafs[graph_copy->n_leafs++] = leaf; } + sched->plan_needs_update = true; } static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) { @@ -1413,6 +1454,62 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) { return true; } +static void ggml_backend_sched_update_plans(ggml_backend_sched_t sched) { + // create graph plans + if (sched->plan_needs_update) { + bool create_new_plans; + if (sched->n_plans == sched->n_splits) { + create_new_plans = false; + for (int i = 0; i < sched->n_splits; i++) { + if (sched->splits[i].backend_id != sched->plans[i].backend_id) { + create_new_plans = true; + break; + } + } + } else { + create_new_plans = true; + } + if (create_new_plans) { + // free previous and recreate new plans + ggml_backend_sched_free_plans(sched); + if (sched->plans_capacity < sched->n_splits) { + while (sched->plans_capacity < sched->n_splits) { + sched->plans_capacity *= 2; + } + sched->plans = (ggml_backend_sched_plan *) realloc( + sched->plans, sched->plans_capacity * sizeof(struct ggml_backend_sched_plan)); + GGML_ASSERT(sched->plans); + } + sched->n_plans = sched->n_splits; + for (int i = 0; i < sched->n_splits; i++) { + ggml_backend_t backend = sched->backends[sched->splits[i].backend_id]; + sched->plans[i].backend_id = sched->splits[i].backend_id; + if (ggml_backend_supports_graph_plan(backend)) { + sched->plans[i].plan = ggml_backend_graph_plan_create(backend, &sched->splits[i].graph); + } else { + sched->plans[i].plan = nullptr; + } + } + } else { + // update existing plans + for (int i = 0; i < sched->n_splits; i++) { + ggml_backend_t backend = sched->backends[sched->splits[i].backend_id]; + if (ggml_backend_supports_graph_plan(backend)) { + if (ggml_backend_supports_graph_plan_update(backend)) { + ggml_backend_graph_plan_update(backend, sched->plans[i].plan, &sched->splits[i].graph); + } else { + ggml_backend_graph_plan_free(backend, sched->plans[i].plan); + sched->plans[i].plan = ggml_backend_graph_plan_create(backend, &sched->splits[i].graph); + } + } else { + sched->plans[i].plan = nullptr; + } + } + } + sched->plan_needs_update = false; + } +} + static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) { GGML_ASSERT(sched); struct ggml_backend_sched_split * splits = sched->splits; @@ -1421,6 +1518,8 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s std::vector ids; std::vector used_ids; + ggml_backend_sched_update_plans(sched); + for (int split_id = 0; split_id < sched->n_splits; split_id++) { struct ggml_backend_sched_split * split = &splits[split_id]; int split_backend_id = split->backend_id; @@ -1550,7 +1649,12 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s } if (!sched->callback_eval) { - enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph); + enum ggml_status ec; + if (ggml_backend_supports_graph_plan(split_backend) && sched->plans[split_id].plan) { + ec = ggml_backend_graph_plan_compute(split_backend, sched->plans[split_id].plan); + } else { + ec = ggml_backend_graph_compute_async(split_backend, &split->graph); + } if (ec != GGML_STATUS_SUCCESS) { return ec; } @@ -1637,6 +1741,10 @@ ggml_backend_sched_t ggml_backend_sched_new( 
     sched->splits = (ggml_backend_sched_split *) calloc(initial_splits_capacity, sizeof(sched->splits[0]));
     sched->splits_capacity = initial_splits_capacity;
 
+    const int initial_plans_capacity = 16;
+    sched->plans = (ggml_backend_sched_plan *) calloc(initial_plans_capacity, sizeof(sched->plans[0]));
+    sched->plans_capacity = initial_plans_capacity;
+
     for (int b = 0; b < n_backends; b++) {
         sched->backends[b] = backends[b];
         sched->bufts[b] = bufts ? bufts[b] : ggml_backend_get_default_buffer_type(backends[b]);
@@ -1670,6 +1778,8 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
     ggml_free(sched->ctx);
     ggml_hash_set_free(&sched->hash_set);
     free(sched->splits);
+    ggml_backend_sched_free_plans(sched);
+    free(sched->plans);
     free(sched->hv_tensor_backend_ids);
     free(sched->hv_tensor_copies);
     free(sched->node_backend_ids);
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 41ff89c4d6922..a6ee559eb2ca2 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -936,13 +936,12 @@ struct ggml_cuda_graph {
     }
     cudaGraph_t graph = nullptr;
     cudaGraphExec_t instance = nullptr;
+    const ggml_cgraph * cgraph;
     size_t num_nodes = 0;
     std::vector<cudaGraphNode_t> nodes;
     std::vector<cudaKernelNodeParams> params;
-    bool disable_due_to_gpu_arch = false;
-    bool disable_due_to_too_many_updates = false;
-    bool disable_due_to_failed_graph_capture = false;
     int number_consecutive_updates = 0;
+    int number_consecutive_computes = 0;
     std::vector<ggml_graph_node_properties> ggml_graph_properties;
 #endif
 };
@@ -955,7 +954,12 @@ struct ggml_backend_cuda_context {
     cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
     cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
-    std::unique_ptr<ggml_cuda_graph> cuda_graph;
+#ifdef USE_CUDA_GRAPH
+    bool cuda_graph_initialized = false;
+    bool disable_graph_due_to_env = false;
+    bool disable_graph_due_to_gpu_arch = false;
+    bool disable_graph_due_to_too_many_updates = false;
+#endif
 
     explicit ggml_backend_cuda_context(int device) :
         device(device),
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index f5a6a751acfd5..d5d70af273552 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2240,7 +2240,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
                                 nb1, nb2, nb3, stream);
 }
 
-static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
+static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_graph, struct ggml_tensor * dst) {
     // why is this here instead of mul_mat?
if (dst->src[0] != nullptr && ggml_backend_buft_is_cuda_split(dst->src[0]->buffer->buft)) { ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device); @@ -2495,7 +2495,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_op_sum_rows(ctx, dst); break; case GGML_OP_MEAN: - ggml_cuda_op_mean(ctx, dst); + ggml_cuda_op_mean(ctx, cuda_graph, dst); break; case GGML_OP_SSM_CONV: ggml_cuda_op_ssm_conv(ctx, dst); @@ -2642,8 +2642,8 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { } #ifdef USE_CUDA_GRAPH -static bool check_node_graph_compatibility(ggml_cgraph * cgraph, - bool use_cuda_graph) { +static bool check_node_graph_compatibility(const ggml_cgraph * cgraph) { + bool use_cuda_graph = true; // Loop over nodes in GGML graph to obtain info needed for CUDA graph @@ -2753,45 +2753,49 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra return true; } -static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { +static void update_cuda_graph_properties(ggml_cuda_graph * cuda_graph, const ggml_cgraph * cgraph) { + cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes); + for (int i = 0; i < cgraph->n_nodes; i++) { + set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_graph->ggml_graph_properties[i]); + } +} +static bool is_cuda_graph_update_required(ggml_cuda_graph * cuda_graph, const ggml_cgraph * cgraph) { bool cuda_graph_update_required = false; - if (cuda_ctx->cuda_graph->instance == nullptr) { + if (cuda_graph->instance == nullptr) { cuda_graph_update_required = true; } // Check if the graph size has changed - if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) { + if (cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) { cuda_graph_update_required = true; - cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes); + cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes); } // Loop over nodes in GGML graph to determine if CUDA graph update is required - // and store properties to allow this comparison for the next token for (int i = 0; i < cgraph->n_nodes; i++) { bool has_matching_properties = true; if (!cuda_graph_update_required) { - has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); + has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_graph->ggml_graph_properties[i]); } if (!has_matching_properties) { cuda_graph_update_required = true; } - set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); } return cuda_graph_update_required; } -static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { +static void update_cuda_graph_executable(ggml_cuda_graph * cuda_graph) { #if CUDART_VERSION >= 12000 cudaGraphExecUpdateResultInfo result_info; - cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); + cudaError_t stat = cudaGraphExecUpdate(cuda_graph->instance, cuda_graph->graph, &result_info); #else cudaGraphNode_t errorNode; cudaGraphExecUpdateResult result_info; - cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info); + cudaError_t stat = cudaGraphExecUpdate(cuda_graph->instance, cuda_graph->graph, &errorNode, &result_info); #endif // CUDART_VERSION >= 12000 if (stat == 
cudaErrorGraphExecUpdateFailure) { @@ -2802,9 +2806,9 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { // The pre-existing graph exec cannot be updated due to violated constraints // so instead clear error and re-instantiate (void)cudaGetLastError(); - CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance)); - cuda_ctx->cuda_graph->instance = nullptr; - CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); + CUDA_CHECK(cudaGraphExecDestroy(cuda_graph->instance)); + cuda_graph->instance = nullptr; + CUDA_CHECK(cudaGraphInstantiate(&cuda_graph->instance, cuda_graph->graph, NULL, NULL, 0)); } else { GGML_ASSERT(stat == cudaSuccess); } @@ -2924,233 +2928,272 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, return false; } -static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, - bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) { +static void evaluate_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cuda_graph * cuda_graph, const ggml_cgraph * cgraph) { // flag used to determine whether it is an integrated_gpu const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; - while (!graph_evaluated_or_captured) { - // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. - // With the use of CUDA graphs, the execution will be performed by the graph launch. - if (!use_cuda_graph || cuda_graph_update_required) { - - for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_tensor * node = cgraph->nodes[i]; - - if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { - continue; - } - - static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); - if (!disable_fusion) { - - if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) { - ggml_tensor * weights = cgraph->nodes[i+8]; - ggml_tensor * selected_experts = cgraph->nodes[i+3]; - ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true, - /*delayed softmax*/ false); - i += 8; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) { - ggml_tensor * weights = cgraph->nodes[i+4]; - ggml_tensor * selected_experts = cgraph->nodes[i+3]; - ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false, - /*delayed softmax*/ false); - i += 4; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, - ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) { - ggml_tensor * weights = cgraph->nodes[i + 5]; - ggml_tensor * ids = cgraph->nodes[i + 1]; + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; - ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false, - /*delayed_softmax*/ true); - i += 5; - continue; - } + if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { + continue; + } - if (node->op == GGML_OP_ADD) { - int n_fuse = 0; - ggml_op ops[8]; - std::fill(ops, ops + 8, GGML_OP_ADD); + static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); + if (!disable_fusion) { - for (; 
n_fuse <= 6; ++n_fuse){ - if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) { - break; - } - if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) { - break; - } - if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) { - break; - } - } + if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) { + ggml_tensor * weights = cgraph->nodes[i+8]; + ggml_tensor * selected_experts = cgraph->nodes[i+3]; + ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true, + /*delayed softmax*/ false); + i += 8; + continue; + } - n_fuse++; + if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) { + ggml_tensor * weights = cgraph->nodes[i+4]; + ggml_tensor * selected_experts = cgraph->nodes[i+3]; + ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false, + /*delayed softmax*/ false); + i += 4; + continue; + } - if (n_fuse > 1) { - for (int j = 0; j < n_fuse - 1; ++j) { - node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; - } - cgraph->nodes[i + n_fuse - 1]->data = node->data; - ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse); - i += n_fuse - 1; + if (ggml_cuda_can_fuse(cgraph, i, + ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) { + ggml_tensor * weights = cgraph->nodes[i + 5]; + ggml_tensor * ids = cgraph->nodes[i + 1]; - continue; - } - } + ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false, + /*delayed_softmax*/ true); + i += 5; + continue; + } + if (node->op == GGML_OP_ADD) { + int n_fuse = 0; + ggml_op ops[8]; + std::fill(ops, ops + 8, GGML_OP_ADD); - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) { - ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); - i += 2; - continue; + for (; n_fuse <= 6; ++n_fuse){ + if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) { + break; } - - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) { - ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]); - i++; - continue; + if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) { + break; } - - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) { - i += 2; - ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node); - continue; + if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) { + break; } } -#ifndef NDEBUG - assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); - for (int j = 0; j < GGML_MAX_SRC; j++) { - if (node->src[j] != nullptr) { - assert(node->src[j]->buffer); - assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || - ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft) || (integrated && ggml_backend_buft_is_cuda_host(node->src[j]->buffer->buft))); + + n_fuse++; + + if (n_fuse > 1) { + for (int j = 0; j < n_fuse - 1; ++j) { + node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; } - } -#else - GGML_UNUSED(integrated); -#endif // NDEBUG + cgraph->nodes[i + n_fuse - 1]->data = node->data; + ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse); + i += n_fuse - 1; - bool ok = ggml_cuda_compute_forward(*cuda_ctx, node); - if (!ok) { - GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + continue; } - GGML_ASSERT(ok); } - } 
-#ifdef USE_CUDA_GRAPH - if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture - if (cuda_ctx->cuda_graph->graph != nullptr) { - CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph)); - cuda_ctx->cuda_graph->graph = nullptr; - } - CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph)); - graph_evaluated_or_captured = true; // CUDA graph has been captured + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) { + ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); + i += 2; + continue; + } - std::lock_guard lock(ggml_cuda_lock); - if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) { - ggml_cuda_lock_cv.notify_all(); + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) { + ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]); + i++; + continue; } - } else { - graph_evaluated_or_captured = true; // ggml graph has been directly evaluated - } - } - if (use_cuda_graph) { - if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph. - CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) { + i += 2; + ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node); + continue; + } } - if (cuda_graph_update_required) { // Update graph executable - update_cuda_graph_executable(cuda_ctx); +#ifndef NDEBUG + assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); + for (int j = 0; j < GGML_MAX_SRC; j++) { + if (node->src[j] != nullptr) { + assert(node->src[j]->buffer); + assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || + ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft) || (integrated && ggml_backend_buft_is_cuda_host(node->src[j]->buffer->buft))); + } } - // Launch graph - CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream())); #else - graph_evaluated_or_captured = true; -#endif // USE_CUDA_GRAPH + GGML_UNUSED(integrated); +#endif // NDEBUG + + bool ok = ggml_cuda_compute_forward(*cuda_ctx, cuda_graph, node); + if (!ok) { + GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + } + GGML_ASSERT(ok); } } -static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; +#ifdef USE_CUDA_GRAPH +static void capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cuda_graph * cuda_graph, const ggml_cgraph * cgraph) { + // Start CUDA graph capture + { + std::lock_guard lock(ggml_cuda_lock); + ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed); + } - ggml_cuda_set_device(cuda_ctx->device); + CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); -#ifdef USE_CUDA_GRAPH - static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); + evaluate_cuda_graph(cuda_ctx, cuda_graph, cgraph); + + if (cuda_graph->graph != nullptr) { + CUDA_CHECK(cudaGraphDestroy(cuda_graph->graph)); + cuda_graph->graph = nullptr; + } + + CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_graph->graph)); + + if (cuda_graph->instance == nullptr) { + CUDA_CHECK(cudaGraphInstantiate(&cuda_graph->instance, 
cuda_graph->graph, NULL, NULL, 0)); + } else { + update_cuda_graph_executable(cuda_graph); + } - // Objects required for CUDA Graph - if (cuda_ctx->cuda_graph == nullptr) { - cuda_ctx->cuda_graph.reset(new ggml_cuda_graph()); + std::lock_guard lock(ggml_cuda_lock); + if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) { + ggml_cuda_lock_cv.notify_all(); } +} +static bool should_use_cuda_graph(ggml_backend_cuda_context * cuda_ctx, const struct ggml_cgraph * cgraph) { bool use_cuda_graph = true; - bool cuda_graph_update_required = false; - if (cuda_ctx->cuda_graph->graph == nullptr) { + if (!cuda_ctx->cuda_graph_initialized) { + cuda_ctx->disable_graph_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); + if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) { - cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true; + cuda_ctx->disable_graph_due_to_gpu_arch = true; #ifndef NDEBUG GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__); #endif } + cuda_ctx->cuda_graph_initialized = true; } // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly, // or previous graph capture failure. // Also disable for multi-gpu for now. TO DO investigate - if (disable_cuda_graphs_due_to_env - || cuda_ctx->cuda_graph->disable_due_to_gpu_arch - || cuda_ctx->cuda_graph->disable_due_to_too_many_updates - || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) { + if (cuda_ctx->disable_graph_due_to_env || cuda_ctx->disable_graph_due_to_gpu_arch || + cuda_ctx->disable_graph_due_to_too_many_updates) { use_cuda_graph = false; } if (use_cuda_graph) { - cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph); + use_cuda_graph = check_node_graph_compatibility(cgraph); + } - use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph); + return use_cuda_graph; +} - // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates. 
- if (use_cuda_graph && cuda_graph_update_required) { - cuda_ctx->cuda_graph->number_consecutive_updates++; - } else { - cuda_ctx->cuda_graph->number_consecutive_updates = 0; - } +static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph * cgraph) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; - if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) { - cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true; -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__); -#endif + ggml_cuda_set_device(cuda_ctx->device); + + ggml_cuda_graph * cuda_graph = new ggml_cuda_graph(); + + cuda_graph->cgraph = cgraph; + + if (should_use_cuda_graph(cuda_ctx, cgraph)) { + capture_cuda_graph(cuda_ctx, cuda_graph, cgraph); + update_cuda_graph_properties(cuda_graph, cgraph); + } + + return cuda_graph; +} + +static void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + delete (ggml_cuda_graph *) plan; + + GGML_UNUSED(backend); +} + +static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_backend_graph_plan_t plan, const ggml_cgraph * cgraph) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; + + ggml_cuda_set_device(cuda_ctx->device); + + ggml_cuda_graph * cuda_graph = (ggml_cuda_graph *) plan; + + cuda_graph->cgraph = cgraph; + + bool use_cuda_graph = should_use_cuda_graph(cuda_ctx, cgraph); + bool cuda_graph_update_required = false; + + // check if we are doing a graph update + if (cuda_graph->instance == nullptr && use_cuda_graph // no graph -> graph + || cuda_graph->instance != nullptr && !use_cuda_graph // graph -> no graph + || use_cuda_graph && is_cuda_graph_update_required(cuda_graph, cgraph)) { // graph property mismatch + cuda_graph->number_consecutive_updates++; + if (cuda_graph->number_consecutive_updates >= 4) { + cuda_ctx->disable_graph_due_to_too_many_updates = true; + use_cuda_graph = false; + } else { + cuda_graph_update_required = true; } + cuda_graph->number_consecutive_computes = 0; } if (use_cuda_graph && cuda_graph_update_required) { - // Start CUDA graph capture - { - std::lock_guard lock(ggml_cuda_lock); - ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed); + capture_cuda_graph(cuda_ctx, cuda_graph, cgraph); + update_cuda_graph_properties(cuda_graph, cgraph); + } else if (!use_cuda_graph) { + if (cuda_graph->instance != nullptr) { + CUDA_CHECK(cudaGraphExecDestroy(cuda_graph->instance)); + } + if (cuda_graph->graph != nullptr) { + CUDA_CHECK(cudaGraphDestroy(cuda_graph->graph)); } + cuda_graph->instance = nullptr; + cuda_graph->graph = nullptr; + } +} + +static enum ggml_status ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; + + ggml_cuda_graph * cuda_graph = (ggml_cuda_graph *) plan; - CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); + ggml_cuda_set_device(cuda_ctx->device); + + cuda_graph->number_consecutive_computes++; + if (cuda_graph->number_consecutive_computes > 1) { + cuda_graph->number_consecutive_updates = 0; } -#else - bool use_cuda_graph = false; - bool cuda_graph_update_required = false; -#endif // USE_CUDA_GRAPH + if (cuda_graph->instance) { + CUDA_CHECK(cudaGraphLaunch(cuda_graph->instance, cuda_ctx->stream())); + } else { + 
evaluate_cuda_graph(cuda_ctx, cuda_graph, cuda_graph->cgraph); + } + return GGML_STATUS_SUCCESS; +} +#endif + +static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; - bool graph_evaluated_or_captured = false; + ggml_cuda_set_device(cuda_ctx->device); - evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required); + evaluate_cuda_graph(cuda_ctx, nullptr, cgraph); return GGML_STATUS_SUCCESS; } @@ -3187,10 +3230,17 @@ static const ggml_backend_i ggml_backend_cuda_interface = { /* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async, /* .cpy_tensor_async = */ ggml_backend_cuda_cpy_tensor_async, /* .synchronize = */ ggml_backend_cuda_synchronize, +#ifdef USE_CUDA_GRAPH + /* .graph_plan_create = */ ggml_backend_cuda_graph_plan_create, + /* .graph_plan_free = */ ggml_backend_cuda_graph_plan_free, + /* .graph_plan_update = */ ggml_backend_cuda_graph_plan_update, + /* .graph_plan_compute = */ ggml_backend_cuda_graph_plan_compute, +#else /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ NULL, +#endif /* .graph_compute = */ ggml_backend_cuda_graph_compute, /* .event_record = */ ggml_backend_cuda_event_record, /* .event_wait = */ ggml_backend_cuda_event_wait, diff --git a/ggml/src/ggml-cuda/mean.cu b/ggml/src/ggml-cuda/mean.cu index 347abc18660ca..6e91f236b958a 100644 --- a/ggml/src/ggml-cuda/mean.cu +++ b/ggml/src/ggml-cuda/mean.cu @@ -10,7 +10,7 @@ template __global__ void divide_by_count(T * result, size_t count) *result /= static_cast(count); } -void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { +void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_graph, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *) src0->data; float * dst_d = (float *) dst->data; @@ -33,14 +33,12 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { #ifdef USE_CUDA_GRAPH // CUDA_GRAPHS_DISABLED ((ncols > 65536) && - ((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) || - ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates || - ctx.cuda_graph->disable_due_to_failed_graph_capture)) || + ((cuda_graph && cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) || + ctx.disable_graph_due_to_env || ctx.disable_graph_due_to_gpu_arch || ctx.disable_graph_due_to_too_many_updates)) || // CUDA_GRAPHS ENABLED ((ncols > 32768) && - !((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) || - ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates || - ctx.cuda_graph->disable_due_to_failed_graph_capture))) { + !((cuda_graph && cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) || + ctx.disable_graph_due_to_env || ctx.disable_graph_due_to_gpu_arch || ctx.disable_graph_due_to_too_many_updates))) { #else (ncols > 65536)) { #endif // USE_CUDA_GRAPH diff --git a/ggml/src/ggml-cuda/mean.cuh b/ggml/src/ggml-cuda/mean.cuh index 2b9b10433438e..14dca24736281 100644 --- a/ggml/src/ggml-cuda/mean.cuh +++ b/ggml/src/ggml-cuda/mean.cuh @@ -1,3 +1,3 @@ #include "common.cuh" -void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void 
ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_graph, ggml_tensor * dst);
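
For reference, a minimal caller-side sketch of the graph-plan API added in ggml-backend.h above. It is not part of the patch: run_with_plan, backend, and gf are hypothetical placeholders, and real callers normally go through ggml_backend_sched, which keeps one plan per split as shown in ggml_backend_sched_update_plans.

// Usage sketch (not part of the patch): drives the new plan API directly.
// `backend` and `gf` are assumed to have been created by the caller with the
// usual ggml backend / graph-building calls.
#include "ggml-backend.h"

static enum ggml_status run_with_plan(ggml_backend_t backend, struct ggml_cgraph * gf) {
    if (!ggml_backend_supports_graph_plan(backend)) {
        // no plan support (e.g. CUDA built without USE_CUDA_GRAPH): plain compute
        return ggml_backend_graph_compute(backend, gf);
    }

    // capture/instantiate once ...
    ggml_backend_graph_plan_t plan = ggml_backend_graph_plan_create(backend, gf);
    enum ggml_status status = ggml_backend_graph_plan_compute(backend, plan);

    // ... and on later iterations, after the graph has been rebuilt in place,
    // refresh the existing plan instead of recreating it when the backend allows it
    if (status == GGML_STATUS_SUCCESS && ggml_backend_supports_graph_plan_update(backend)) {
        ggml_backend_graph_plan_update(backend, plan, gf);
        status = ggml_backend_graph_plan_compute(backend, plan);
    }

    ggml_backend_graph_plan_free(backend, plan);
    return status;
}

This mirrors what the scheduler does per split: reuse a plan via graph_plan_update while the split/backend assignment is unchanged, otherwise free it and create a new one.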