1818#include < syclcompat/math.hpp>
1919#include < oneapi/mkl.hpp>
2020#include < map>
21+ #include < cassert>
2122
2223#include " ggml-sycl.h"
2324#include " ggml.h"
@@ -88,6 +89,16 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
8889 return device_type.str ();
8990}
9091
92+ template <typename Ts>
93+ struct matrix_info_t
94+ {
95+ oneapi::mkl::transpose transpose_info[2 ];
96+ Ts value_info[2 ];
97+ std::int64_t size_info[3 ];
98+ std::int64_t ld_info[3 ];
99+ std::int64_t groupsize_info;
100+ };
101+
91102namespace dpct
92103{
93104 typedef sycl::queue *queue_ptr;
@@ -1737,27 +1748,16 @@ namespace dpct
17371748 oneapi::mkl::transpose b_trans, int m, int n, int k,
17381749 const void *alpha, const void **a, int lda,
17391750 const void **b, int ldb, const void *beta, void **c,
1740- int ldc, int batch_size)
1751+ int ldc, int batch_size, matrix_info_t < double >* matrix_info )
17411752 {
1742- struct matrix_info_t
1743- {
1744- oneapi::mkl::transpose transpose_info[2 ];
1745- Ts value_info[2 ];
1746- std::int64_t size_info[3 ];
1747- std::int64_t ld_info[3 ];
1748- std::int64_t groupsize_info;
1749- };
17501753
17511754 Ts alpha_value = dpct::get_value (reinterpret_cast <const Ts *>(alpha), q);
17521755 Ts beta_value = dpct::get_value (reinterpret_cast <const Ts *>(beta), q);
17531756
1754- // ggml_backend_sycl_host_buffer_type()->alloc_buffer;
1755- auto tmp = ggml_backend_sycl_reg ( );
1756- std::cout << " this is WARIO " << tmp-> iface . get_name (tmp) << ' \n ' ;
1757+ // ::matrix_info_t<Ts> *matrix_info =
1758+ // (::matrix_info_t<Ts> *)std::malloc(sizeof(matrix_info_t<Ts>) );
1759+ // printf("test pointer %p alpha_value %f before\n", matrix_info, alpha_value) ;
17571760
1758-
1759- matrix_info_t *matrix_info =
1760- (matrix_info_t *)std::malloc (sizeof (matrix_info_t ));
17611761 matrix_info->transpose_info [0 ] = a_trans;
17621762 matrix_info->transpose_info [1 ] = b_trans;
17631763 matrix_info->value_info [0 ] = alpha_value;
@@ -1770,13 +1770,15 @@ namespace dpct
17701770 matrix_info->ld_info [2 ] = ldc;
17711771 matrix_info->groupsize_info = batch_size;
17721772
1773+ // printf("test pointer %p alpha_value %f\n", matrix_info, matrix_info->value_info[0]);;
1774+
17731775#ifdef GGML_SYCL_NVIDIA
17741776 sycl::event e = oneapi::mkl::blas::column_major::gemm_batch (
17751777 oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info ,
17761778 matrix_info->transpose_info + 1 , matrix_info->size_info , matrix_info->size_info + 1 ,
1777- matrix_info->size_info + 2 , matrix_info->value_info , reinterpret_cast <const Ta **>(a),
1779+ matrix_info->size_info + 2 , reinterpret_cast <Ts*>( matrix_info->value_info ) , reinterpret_cast <const Ta **>(a),
17781780 matrix_info->ld_info , reinterpret_cast <const Tb **>(b), matrix_info->ld_info + 1 ,
1779- matrix_info->value_info + 1 , reinterpret_cast <Tc **>(c), matrix_info->ld_info + 2 , 1 ,
1781+ reinterpret_cast <Ts*>( matrix_info->value_info + 1 ) , reinterpret_cast <Tc **>(c), matrix_info->ld_info + 2 , 1 ,
17801782 &(matrix_info->groupsize_info ));
17811783#else
17821784 sycl::event e = oneapi::mkl::blas::column_major::gemm_batch (
@@ -1786,11 +1788,14 @@ namespace dpct
17861788 matrix_info->ld_info + 1 , matrix_info->value_info + 1 , reinterpret_cast <Tc **>(c),
17871789 matrix_info->ld_info + 2 , 1 , &(matrix_info->groupsize_info ));
17881790#endif
1791+ // printf("gemm_launched\n");
17891792
1793+ /*
17901794 q.submit([&](sycl::handler &cgh)
17911795 {
17921796 cgh.depends_on(e);
17931797 cgh.host_task([=] { std::free(matrix_info); }); });
1798+ */
17941799 }
17951800
17961801 template <class Ta , class Tb , class Tc , class Ts >
@@ -2439,7 +2444,8 @@ namespace dpct
24392444 library_data_t a_type, int lda, const void *b[],
24402445 library_data_t b_type, int ldb, const void *beta,
24412446 void *c[], library_data_t c_type, int ldc,
2442- int batch_size, library_data_t scaling_type)
2447+ int batch_size, library_data_t scaling_type,
2448+ matrix_info_t <double >* matrix_info)
24432449 {
24442450 if (scaling_type == library_data_t ::real_float &&
24452451 c_type == library_data_t ::complex_float)
@@ -2451,7 +2457,6 @@ namespace dpct
24512457 {
24522458 scaling_type = library_data_t ::complex_double;
24532459 }
2454-
24552460 std::uint64_t key =
24562461 detail::get_type_combination_id (a_type, b_type, c_type, scaling_type);
24572462 switch (key)
@@ -2462,7 +2467,7 @@ namespace dpct
24622467 {
24632468 detail::gemm_batch_impl<float , float , float , float >(
24642469 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2465- batch_size);
2470+ batch_size, matrix_info );
24662471 break ;
24672472 }
24682473 case detail::get_type_combination_id (
@@ -2471,17 +2476,18 @@ namespace dpct
24712476 {
24722477 detail::gemm_batch_impl<double , double , double , double >(
24732478 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2474- batch_size);
2479+ batch_size, matrix_info );
24752480 break ;
24762481 }
2482+ /*
24772483 case detail::get_type_combination_id(
24782484 library_data_t::complex_float, library_data_t::complex_float,
24792485 library_data_t::complex_float, library_data_t::complex_float):
24802486 {
24812487 detail::gemm_batch_impl<std::complex<float>, std::complex<float>,
24822488 std::complex<float>, std::complex<float>>(
24832489 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2484- batch_size);
2490+ batch_size, matrix_info );
24852491 break;
24862492 }
24872493 case detail::get_type_combination_id(
@@ -2491,17 +2497,18 @@ namespace dpct
24912497 detail::gemm_batch_impl<std::complex<double>, std::complex<double>,
24922498 std::complex<double>, std::complex<double>>(
24932499 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2494- batch_size);
2500+ batch_size, matrix_info );
24952501 break;
24962502 }
2503+ */
24972504 case detail::get_type_combination_id (
24982505 library_data_t ::real_half, library_data_t ::real_half,
24992506 library_data_t ::real_half, library_data_t ::real_half):
25002507 {
25012508 detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
25022509 sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
25032510 a, lda, b, ldb, beta, c, ldc,
2504- batch_size);
2511+ batch_size, matrix_info );
25052512 break ;
25062513 }
25072514#ifdef __INTEL_MKL__
@@ -2512,7 +2519,7 @@ namespace dpct
25122519 detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
25132520 oneapi::mkl::bfloat16, float >(
25142521 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2515- batch_size);
2522+ batch_size, matrix_info );
25162523 break ;
25172524 }
25182525 case detail::get_type_combination_id (
@@ -2521,7 +2528,7 @@ namespace dpct
25212528 {
25222529 detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float ,
25232530 float >(q, a_trans, b_trans, m, n, k, alpha, a, lda,
2524- b, ldb, beta, c, ldc, batch_size);
2531+ b, ldb, beta, c, ldc, batch_size, matrix_info );
25252532 break ;
25262533 }
25272534#endif
@@ -2536,7 +2543,7 @@ namespace dpct
25362543 detail::gemm_batch_impl<std::int8_t , std::int8_t , std::int32_t ,
25372544 float >(q, a_trans, b_trans, m, n, k, &alpha_float,
25382545 a, lda, b, ldb, &beta_float, c, ldc,
2539- batch_size);
2546+ batch_size, matrix_info );
25402547 break ;
25412548 }
25422549 case detail::get_type_combination_id (
@@ -2545,7 +2552,7 @@ namespace dpct
25452552 {
25462553 detail::gemm_batch_impl<std::int8_t , std::int8_t , float , float >(
25472554 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2548- batch_size);
2555+ batch_size, matrix_info );
25492556 break ;
25502557 }
25512558 case detail::get_type_combination_id (
@@ -2554,7 +2561,7 @@ namespace dpct
25542561 {
25552562 detail::gemm_batch_impl<sycl::half, sycl::half, float , float >(
25562563 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2557- batch_size);
2564+ batch_size, matrix_info );
25582565 break ;
25592566 }
25602567 case detail::get_type_combination_id (
@@ -2569,7 +2576,7 @@ namespace dpct
25692576 sycl::half beta_half (beta_value);
25702577 detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
25712578 q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc,
2572- batch_size);
2579+ batch_size, matrix_info );
25732580 break ;
25742581 }
25752582 default :
0 commit comments