@@ -2847,6 +2847,24 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
     return false;
 }
 
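+// Per-path capability checks: each mul_mat code path reports which quantized
+// types it can consume in the reordered (qs-then-scales) layout.
+// The generic dequantize path currently only handles reordered Q4_0.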
+inline bool ggml_sycl_supports_reorder_dequantize(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return true;
+        default:
+            return false;
+    }
+}
+
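+// The dequantize_mul_mat_vec (DMMV) path likewise only handles reordered Q4_0.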
+inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return true;
+        default:
+            return false;
+    }
+}
+
 inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
@@ -2884,7 +2902,7 @@ static void reorder_qw(char *data_device, const int ncols, const int nrows,
     GGML_ASSERT((size % sizeof(block_q4_0) == 0));
     GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
     int offset_blks = offset / sizeof(block_q4_0);
-    auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;;
+    auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;
     auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
 
     stream->parallel_for(
@@ -2912,17 +2930,19 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
     reorder_qw(data_device, ncols, nrows, size, 0, stream);
 }
 
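+// Returns true when it is worth rewriting src0 of this MUL_MAT into the
+// reordered layout: the optimization is enabled, the device profits from it,
+// and src1 is not batched (ne[2] == ne[3] == 1).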
+static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {
+    return !g_ggml_sycl_disable_optimize &&  // allow optimize, controlled by the GGML_SYCL_DISABLE_OPT env var
+           ctx.opt_feature.reorder &&        // only on devices where the reordered layout performs well
+           dst->op == GGML_OP_MUL_MAT &&     // currently limited to MUL_MAT over supported Q4_0 cases
+           dst->src[1]->ne[2] == 1 && dst->src[1]->ne[3] == 1;
+}
+
 /*
 * This function can be called when the OP (mul_mat) supports the reorder optimization.
 */
 static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1,
                             ggml_tensor * dst) {
-    if (!g_ggml_sycl_disable_optimize && // allow optimize, controlled by $GGML_SYCL_DISABLE_OPT
-        ctx->opt_feature.reorder &&      // allow this device due to good perf, skip the devices with bad perf.
-        dst->op == GGML_OP_MUL_MAT &&    // limit to some supported cases of Q4_0, to do for more cases.
-        src0->type == GGML_TYPE_Q4_0 &&
-        src1->ne[2] == 1 && src1->ne[3] == 1) {
-
+    if (should_reorder_tensor(*ctx, dst)) {
         ggml_tensor_extra_gpu* extra = (ggml_tensor_extra_gpu*)src0->extra;
         if (!extra) return; // can only happen in the CI/UT permute case.
 
@@ -2975,21 +2995,16 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
     use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
 #endif // SYCL_USE_XMX
 
-    const bool reorder = static_cast<ggml_tensor_extra_gpu *>(dst->src[0]->extra) &&
-                         static_cast<ggml_tensor_extra_gpu *>(dst->src[0]->extra)->optimized_feature.reorder;
 
     // mmvq path is faster in the CUDA backend.
     if (!g_ggml_sycl_disable_mmvq && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda
         // Dispatch becomes obscure with the reorder: when the reorder optimization
         // is enabled, MMVQ takes precedence over DMMV, so the current if-else
         // implementation requires disabling DMMV when both conditions are met.
-        || (reorder && ggml_sycl_supports_reorder_mmvq(src0->type)))) {
+        || (should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) {
         use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
     }
 
-    // TODO: Romain
-    printf("\n\n** mul_mat use_dequantize_mul_mat_vec=%d use_mul_mat_vec_q=%d use_mul_mat_q=%d reorder=%d split=%d m=%ld n=%ld k=%ld batch0=%ld batch1=%ld\n", use_dequantize_mul_mat_vec, use_mul_mat_vec_q, use_mul_mat_q, reorder, split, src0->ne[1], src1->ne[1], src0->ne[0], src0->ne[3], src1->ne[3]);
-
     if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         // TODO: Refactor and cleanup of mul mat dispatching.
         if (src0->ne[3] == 1 && src1->ne[3] == 1) {
@@ -3008,19 +3023,24 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
         ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
         constexpr bool convert_src1_to_q8_1 = false;
-        opt_for_reorder(&ctx, src0, src1, dst); // the OP function in this branch support reorder.
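+        // Skip the reorder unless the DMMV kernel can consume the reordered layout.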
+        if (ggml_sycl_supports_reorder_dmmv(src0->type)) {
+            opt_for_reorder(&ctx, src0, src1, dst);
+        }
         ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
-        // save_tensor_txt("1/dst_1.txt", (float*) dst->data, src0->ne[1], sizeof(float), ctx.stream());
     } else if (use_mul_mat_vec_q) {
         constexpr bool convert_src1_to_q8_1 = true;
-        opt_for_reorder(&ctx, src0, src1, dst); // the OP function in this branch support reorder.
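+        // Same gating for MMVQ: reorder only for types its kernel supports.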
+        if (ggml_sycl_supports_reorder_mmvq(src0->type)) {
+            opt_for_reorder(&ctx, src0, src1, dst);
+        }
         ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
     } else if (use_mul_mat_q) {
         constexpr bool convert_src1_to_q8_1 = true;
         ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
     } else {
         constexpr bool convert_src1_to_q8_1 = false;
-        opt_for_reorder(&ctx, src0, src1, dst); // the OP function in this branch support reorder.
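+        // Fallback path: gate the reorder on the generic dequantize support check.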
+        if (ggml_sycl_supports_reorder_dequantize(src0->type)) {
+            opt_for_reorder(&ctx, src0, src1, dst); // the OP function in this branch supports reorder.
+        }
         ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
     }
     GGML_SYCL_DEBUG("call %s done\n", __func__);