From 80a94559c581132c2fa4e643caad70bc0e81a283 Mon Sep 17 00:00:00 2001
From: hipudding
Date: Thu, 25 Sep 2025 09:37:09 +0000
Subject: [PATCH] CANN: Update several operators to support FP16 data format

Many Ascend operators use FP16 internally for computation. If the input
data is FP32, it must first be cast to FP16 before the computation and
then cast back to FP32 afterwards, which introduces unnecessary cast
operations. Moreover, FP16 computation requires significantly less work
than FP32, leading to noticeable efficiency improvements.

In this change, `get_rows`, `rms_norm`, and `flash_attn_ext` are
extended to support multiple data types. Validation on the Qwen2 0.5B
model shows correct accuracy and about a 10% performance gain in
concurrent scenarios.

Co-authored-by: noemotiovon <757486878@qq.com>
---
 ggml/src/ggml-cann/aclnn_ops.cpp | 197 +++++++++++++++----------------
 1 file changed, 96 insertions(+), 101 deletions(-)

diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp
index 434023dd22ab3..240e8a1b2c025 100755
--- a/ggml/src/ggml-cann/aclnn_ops.cpp
+++ b/ggml/src/ggml-cann/aclnn_ops.cpp
@@ -894,14 +894,13 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
 }
 
 /**
- * @brief Get or expand a cached float32 tensor filled with a scalar value.
+ * @brief Get or expand a cached tensor filled with a scalar value.
  *
- * This function manages cached device memory for float32 tensors. If the current
+ * This function manages cached device memory for tensors. If the current
  * cache size is insufficient for the requested tensor shape, the old memory will
- * be released and new memory will be allocated. The allocated buffer is then
- * initialized either with zeros (when @p value == 0.0f) or with the given scalar
- * value using CANN operations. Finally, an aclTensor object is created from the
- * cached memory and returned.
+ * be released and new memory will be allocated. The allocated buffer is
+ * initialized with the given scalar value using CANN operations.
+ * Finally, an aclTensor object is created from the cached memory and returned.
  *
  * @param ctx The CANN backend context that manages device memory.
  * @param buffer A pointer to the cached device buffer (will be allocated
@@ -910,17 +909,19 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
  *        updated when the cache is expanded.
  * @param ne The tensor shape array (number of elements in each dimension).
  * @param nb The stride size for each dimension.
+ * @param dtype Data type of the cached tensor.
  * @param dims The number of tensor dimensions.
  * @param value The scalar value used to fill the tensor (supports zero
  *        initialization via memset or arbitrary values via fill_scalar).
  * @return An aclTensor pointer created from the cached buffer.
  */
-static aclTensor* get_f32_cache_acl_tensor(
+static aclTensor* get_cache_acl_tensor(
     ggml_backend_cann_context& ctx,
     void** buffer,
     int64_t &cache_element,
     int64_t* ne,
     size_t* nb,
+    ggml_type dtype,
     int64_t dims,
     float value) {
     // Calculate total number of elements
@@ -928,7 +929,7 @@ static aclTensor* get_f32_cache_acl_tensor(
     for (int i = 0; i < dims; i++) {
         n_element *= ne[i];
     }
-    size_t size = n_element * sizeof(float);
+    size_t size = n_element * ggml_type_size(dtype);
 
     // Allocate or expand cache if needed
     if (cache_element < n_element) {
@@ -941,19 +942,17 @@ static aclTensor* get_f32_cache_acl_tensor(
         cache_element = n_element;
 
         // Initialize cache
-        if (value == 0.0f) {
-            ACL_CHECK(aclrtMemsetAsync(*buffer, size, 0, size, ctx.stream()));
-        } else {
-            int64_t pool_ne[1] = { n_element };
-            size_t pool_nb[1] = { sizeof(float) };
-            aclTensor* acl_value = ggml_cann_create_tensor(
-                *buffer, ACL_FLOAT, sizeof(float), pool_ne, pool_nb, 1);
-            aclnn_fill_scalar(ctx, 1, acl_value);
-            ggml_cann_release_resources(ctx, acl_value);
-        }
+        int64_t pool_ne[1] = { n_element };
+        size_t pool_nb[1] = { ggml_type_size(dtype) };
+        aclTensor* acl_value = ggml_cann_create_tensor(
+            *buffer, ggml_cann_type_mapping(dtype), ggml_type_size(dtype),
+            pool_ne, pool_nb, 1);
+        aclnn_fill_scalar(ctx, value, acl_value);
+        ggml_cann_release_resources(ctx, acl_value);
     }
 
-    return ggml_cann_create_tensor(*buffer, ACL_FLOAT, sizeof(float), ne, nb, dims);
+    return ggml_cann_create_tensor(*buffer, ggml_cann_type_mapping(dtype),
+                                   ggml_type_size(dtype), ne, nb, dims);
 }
 
 void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -965,35 +964,39 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    // build gamma, one...
+    // build gamma.
     size_t acl_gamma_nb[GGML_MAX_DIMS];
-    acl_gamma_nb[0] = sizeof(float);
+    // gamma's type is the same as dst's type.
+    acl_gamma_nb[0] = ggml_type_size(dst->type);
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
         acl_gamma_nb[i] = acl_gamma_nb[i - 1] * src->ne[i - 1];
     }
-    aclTensor* acl_gamma = get_f32_cache_acl_tensor(
+    aclTensor* acl_gamma = get_cache_acl_tensor(
         ctx,
         &ctx.rms_norm_one_tensor_cache.cache,
         ctx.rms_norm_one_tensor_cache.size,
         src->ne, acl_gamma_nb,
+        dst->type,
         1,     // dims
         1.0f   // value
     );
 
-    // build rstd, zero...
+    // build rstd.
     int64_t acl_rstd_ne[] = {src->ne[1], src->ne[2], src->ne[3]};
     size_t acl_rstd_nb[GGML_MAX_DIMS - 1];
+    // rstd will always be F32.
     acl_rstd_nb[0] = sizeof(float);
     for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
         acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1];
     }
-    aclTensor* acl_rstd = get_f32_cache_acl_tensor(
+    aclTensor* acl_rstd = get_cache_acl_tensor(
         ctx,
         &ctx.rms_norm_zero_tensor_cache.cache,
         ctx.rms_norm_zero_tensor_cache.size,
         acl_rstd_ne, acl_rstd_nb,
+        GGML_TYPE_F32,
         GGML_MAX_DIMS - 1,
         0.0f   // value
     );
@@ -1765,33 +1768,35 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src0 = dst->src[0];  // src
     ggml_tensor* src1 = dst->src[1];  // index
 
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+
     switch (src0->type) {
-        case GGML_TYPE_F32: {
-            aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb,
-                                  dst->data, dst->ne, dst->nb,
-                                  src1, dst->type);
-            break;
-        }
-        case GGML_TYPE_F16: {
-            aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
-            ggml_cann_pool_alloc src_buffer_allocator(
-                ctx.pool(), ggml_nelements(src0) * sizeof(float));
-            void* src_trans_buffer = src_buffer_allocator.get();
-            size_t src_trans_nb[GGML_MAX_DIMS];
-            src_trans_nb[0] = sizeof(float);
-            for (int i = 1; i < GGML_MAX_DIMS; i++) {
-                src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
+        case GGML_TYPE_F16:
+        case GGML_TYPE_F32:
+            if(src0->type == dst->type) {
+                aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb,
+                                      dst->data, dst->ne, dst->nb,
+                                      src1, dst->type);
+            } else {
+                aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
+                ggml_cann_pool_alloc src_buffer_allocator(
+                    ctx.pool(), ggml_nelements(src0) * ggml_element_size(dst));
+                void* src_trans_buffer = src_buffer_allocator.get();
+                size_t src_trans_nb[GGML_MAX_DIMS];
+                src_trans_nb[0] = dst->nb[0];
+                for (int i = 1; i < GGML_MAX_DIMS; i++) {
+                    src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
+                }
+                aclTensor* src_trans_tensor = ggml_cann_create_tensor(
+                    src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
+                    src0->ne, src_trans_nb, GGML_MAX_DIMS);
+                aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
+                aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb,
+                                      dst->data, dst->ne, dst->nb,
+                                      src1, dst->type);
+                ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
             }
-            aclTensor* src_trans_tensor = ggml_cann_create_tensor(
-                src_trans_buffer, ACL_FLOAT, ggml_type_size(dst->type),
-                src0->ne, src_trans_nb, GGML_MAX_DIMS);
-            aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
-            aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb,
-                                  dst->data, dst->ne, dst->nb,
-                                  src1, dst->type);
-            ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
             break;
-        }
         case GGML_TYPE_Q8_0: {
            // add 1 dim for bcast mul.
             size_t weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1],
@@ -1799,7 +1804,6 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
             int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1],
                 *dequant_ne;
             int64_t scale_offset = 0;
-
             // [3,4,5,64] -> [3,4,5,2,32]
             weight_ne[0] = QK8_0;
             weight_ne[1] = src0->ne[0] / QK8_0;
@@ -1809,7 +1813,6 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                 weight_ne[i] = src0->ne[i - 1];
                 weight_nb[i] = weight_nb[i - 1] * weight_ne[i - 1];
             }
-
             // [3,4,5,64] -> [3,4,5,2,1]
             scale_ne[0] = 1;
             scale_ne[1] = src0->ne[0] / QK8_0;
@@ -1819,18 +1822,15 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                 scale_ne[i] = src0->ne[i - 1];
                 scale_nb[i] = scale_nb[i - 1] * scale_ne[i - 1];
             }
-
             // [3,4,5,64] -> [3,4,5,2,32]
             dequant_ne = weight_ne;
-            dequant_nb[0] = sizeof(float);
+            dequant_nb[0] = ggml_type_size(dst->type);
             for (int i = 1; i < GGML_MAX_DIMS + 1; i++) {
                 dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1];
             }
-
             scale_offset = ggml_nelements(src0) * sizeof(int8_t);
             ggml_cann_pool_alloc dequant_buffer_allocator(
-                ctx.pool(), ggml_nelements(src0) * sizeof(float));
-
+                ctx.pool(), ggml_nelements(src0) * ggml_type_size(dst->type));
             aclTensor* acl_weight_tensor = ggml_cann_create_tensor(
                 src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb,
                 GGML_MAX_DIMS + 1);
@@ -1838,16 +1838,14 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                 src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,
                 GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);
             aclTensor* dequant_tensor = ggml_cann_create_tensor(
-                dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float),
+                dequant_buffer_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
                 dequant_ne, dequant_nb, GGML_MAX_DIMS + 1);
-
             aclnn_mul(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor);
-            dequant_nb[0] = sizeof(float);
+            dequant_nb[0] = ggml_type_size(dst->type);
             dequant_ne = src0->ne;
             for (int i = 1; i < GGML_MAX_DIMS; i++) {
                 dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1];
             }
-
             aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(),
                                   dequant_ne, dequant_nb,
                                   dst->data, dst->ne, dst->nb,
@@ -1965,16 +1963,8 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
     // Only check env once.
     static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
     if (weight_to_nz && is_matmul_weight(weight)) {
-        int64_t acl_stride[2] = {1, transpose_ne[1]};
-
-        // Reverse ne.
-        std::reverse(transpose_ne, transpose_ne + n_dims);
-
-        std::vector storageDims = {transpose_ne[0], transpose_ne[1]};
-
-        acl_weight_tensor = aclCreateTensor(
-            transpose_ne, n_dims, ggml_cann_type_mapping(weight->type), acl_stride,
-            0, ACL_FORMAT_FRACTAL_NZ, storageDims.data(), 2, weight->data);
+        acl_weight_tensor =
+            ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ);
     } else {
         acl_weight_tensor =
             ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND);
@@ -3178,7 +3168,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         aclTensor* acl_src0_f16_tensor = nullptr;
         aclTensor* acl_src1_f16_tensor = nullptr;
         aclTensor* acl_src2_f16_tensor = nullptr;
-        aclTensor* acl_dst_f16_tensor = nullptr;
 
         // Step 1: cast the src0 (Query) to fp16 if needed
         ggml_cann_pool_alloc src0_f16_allocator(ctx.pool());
@@ -3216,22 +3205,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne,
                                                       src2_bsnd_nb, GGML_MAX_DIMS);
 
-        ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
-        void* out_f16_buffer = out_f16_allocator.alloc(
-            ggml_nelements(dst) * faElemSize);
-
-        int64_t* out_f16_ne = src0_bsnd_ne;
-        size_t out_f16_nb[GGML_MAX_DIMS];
-        out_f16_nb[0] = faElemSize;
-        for(int i = 1; i < GGML_MAX_DIMS; ++i){
-            out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
-        }
-
-        acl_dst_f16_tensor = ggml_cann_create_tensor(
-            out_f16_buffer, faDataType, faElemSize,
-            out_f16_ne, out_f16_nb, GGML_MAX_DIMS
-        );
-
         // Step 3: create the PSEShift tensor if needed
         // this tensor is considered as mask (f16) in the llama.cpp
         aclTensor* bcast_pse_tensor = nullptr;
@@ -3334,8 +3307,29 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         int64_t keyAntiquantMode = 0;
         int64_t valueAntiquantMode = 0;
 
-        // Step 5: launch the FusedInferAttentionScoreV2 kernel.
-        // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
+        GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+        aclTensor * fa_dst_tensor = nullptr;
+        aclTensor * acl_dst_tensor = nullptr;
+        ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
+        if (dst->type == GGML_TYPE_F32) {
+            void* out_f16_buffer = out_f16_allocator.alloc(
+                ggml_nelements(dst) * faElemSize);
+
+            int64_t* out_f16_ne = src0_bsnd_ne;
+            size_t out_f16_nb[GGML_MAX_DIMS];
+            out_f16_nb[0] = faElemSize;
+            for(int i = 1; i < GGML_MAX_DIMS; ++i){
+                out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
+            }
+
+            fa_dst_tensor = ggml_cann_create_tensor(
+                out_f16_buffer, faDataType, faElemSize,
+                out_f16_ne, out_f16_nb, GGML_MAX_DIMS
+            );
+        }
+        else {
+            fa_dst_tensor = ggml_cann_create_tensor(dst);
+        }
 
         GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
             acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
@@ -3357,23 +3351,24 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
             blockSize, antiquantMode,             // blockSize, antiquantMode
             softmaxLseFlag,                       // softmaxLseFlag
             keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode
-            acl_dst_f16_tensor,                   // attentionOut
+            fa_dst_tensor,                        // attentionOut
            nullptr                               // softmaxLse
         );
 
-        // Step 6: post-processing, permute and cast to f32
-        aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
-        // TODO: when dst is fp16, don't need cast
-        aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
-        ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
-                                    acl_src1_f16_tensor,
-                                    acl_src2_f16_tensor,
-                                    acl_dst_f16_tensor,
-                                    acl_dst_tensor);
-        if(src3 != nullptr){
-            ggml_cann_release_resources(ctx, bcast_pse_tensor);
+        if (dst->type == GGML_TYPE_F32) {
+            // Step 6: post-processing, permute and cast to f32
+            acl_dst_tensor = ggml_cann_create_tensor(dst);
+            aclnn_cast(ctx, fa_dst_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
         }
-    }else{
+
+        ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
+                                    acl_src1_f16_tensor,
+                                    acl_src2_f16_tensor,
+                                    fa_dst_tensor,
+                                    acl_dst_tensor,
+                                    bcast_pse_tensor);
+
+    } else {
         GGML_ABORT("Function is not implemented.");
     }
 }
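
Note: the snippet below is a minimal usage sketch of the generalized cache helper introduced above, mirroring the two call sites in the `ggml_cann_rms_norm` hunk; it is illustrative only and not applied by this patch, and it assumes the surrounding backend state (`ctx`, `src`, `dst`, `acl_gamma_nb`, `acl_rstd_ne`, `acl_rstd_nb`) is set up exactly as in that hunk. The gamma cache follows `dst->type`, so an FP16 graph needs no extra cast, while the rstd cache stays FP32.

    // gamma: cached in dst->type (FP16 or FP32) and filled with 1.0f,
    // so the RMS norm kernel can consume it directly without casting.
    aclTensor* acl_gamma = get_cache_acl_tensor(
        ctx,
        &ctx.rms_norm_one_tensor_cache.cache,
        ctx.rms_norm_one_tensor_cache.size,
        src->ne, acl_gamma_nb,
        dst->type,          // cache element type follows the output type
        1,                  // dims
        1.0f);              // fill value

    // rstd: always cached as FP32 and filled with 0.0f.
    aclTensor* acl_rstd = get_cache_acl_tensor(
        ctx,
        &ctx.rms_norm_zero_tensor_cache.cache,
        ctx.rms_norm_zero_tensor_cache.size,
        acl_rstd_ne, acl_rstd_nb,
        GGML_TYPE_F32,
        GGML_MAX_DIMS - 1,
        0.0f);              // fill value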