From 0afea98ef059c27a70f2535eeff89197dbb7a495 Mon Sep 17 00:00:00 2001 From: nscipione Date: Tue, 14 Jan 2025 16:07:13 +0000 Subject: [PATCH 1/6] Implement host pool for matrix_info Creating a new memory pool on the host to store memory location for matrix_info needed to launch gemm_batch from oneMKL/oneMath. Removing complex support in gemm_batch since it is not used in llama.cpp Signed-off-by: nscipione --- ggml/src/ggml-sycl/common.hpp | 15 +++++ ggml/src/ggml-sycl/dpct/helper.hpp | 93 +++++++++++------------------- ggml/src/ggml-sycl/ggml-sycl.cpp | 90 ++++++++++++++++++++++++++++- 3 files changed, 137 insertions(+), 61 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index e9500f3a1682b..91b432da922dd 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -333,8 +333,12 @@ struct ggml_backend_sycl_context { // pool std::unique_ptr pools[GGML_SYCL_MAX_DEVICES]; + std::unique_ptr host_pools[GGML_SYCL_MAX_DEVICES]; + static std::unique_ptr new_pool_for_device(queue_ptr qptr, int device); + static std::unique_ptr new_pool_for_host(queue_ptr qptr, int device); + ggml_sycl_pool & pool(int device) { if (pools[device] == nullptr) { pools[device] = new_pool_for_device(stream(device,0), device); @@ -345,6 +349,17 @@ struct ggml_backend_sycl_context { ggml_sycl_pool & pool() { return pool(device); } + + ggml_sycl_pool & host_pool(int device) { + if (host_pools[device] == nullptr) { + host_pools[device] = new_pool_for_host(stream(device,0), device); + } + return *host_pools[device]; + } + + ggml_sycl_pool & host_pool() { + return host_pool(device); + } }; // common device functions diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp index e167948e7a3f0..645f681d88321 100644 --- a/ggml/src/ggml-sycl/dpct/helper.hpp +++ b/ggml/src/ggml-sycl/dpct/helper.hpp @@ -18,9 +18,16 @@ #include #include #include +#include +#include "ggml-sycl.h" #include "ggml.h" +#include "ggml-backend.h" +#include "ggml-backend-impl.h" +#include "ggml-alloc.h" +#include "ggml-impl.h" + #if defined(__linux__) #include #elif defined(_WIN64) @@ -82,6 +89,16 @@ inline std::string get_device_backend_and_type(const sycl::device &device) { return device_type.str(); } +template +struct matrix_info_t +{ + oneapi::mkl::transpose transpose_info[2]; + Ts value_info[2]; + std::int64_t size_info[3]; + std::int64_t ld_info[3]; + std::int64_t groupsize_info; +}; + namespace dpct { typedef sycl::queue *queue_ptr; @@ -1731,22 +1748,12 @@ namespace dpct oneapi::mkl::transpose b_trans, int m, int n, int k, const void *alpha, const void **a, int lda, const void **b, int ldb, const void *beta, void **c, - int ldc, int batch_size) + int ldc, int batch_size, matrix_info_t* matrix_info) { - struct matrix_info_t - { - oneapi::mkl::transpose transpose_info[2]; - Ts value_info[2]; - std::int64_t size_info[3]; - std::int64_t ld_info[3]; - std::int64_t groupsize_info; - }; Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); - matrix_info_t *matrix_info = - (matrix_info_t *)std::malloc(sizeof(matrix_info_t)); matrix_info->transpose_info[0] = a_trans; matrix_info->transpose_info[1] = b_trans; matrix_info->value_info[0] = alpha_value; @@ -1763,23 +1770,19 @@ namespace dpct sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( oneapi::mkl::backend_selector{ q }, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1, - matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast(a), + matrix_info->size_info + 2, reinterpret_cast(matrix_info->value_info), reinterpret_cast(a), matrix_info->ld_info, reinterpret_cast(b), matrix_info->ld_info + 1, - matrix_info->value_info + 1, reinterpret_cast(c), matrix_info->ld_info + 2, 1, + reinterpret_cast(matrix_info->value_info+1), reinterpret_cast(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); #else sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info, - matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info, + matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast(matrix_info->value_info), reinterpret_cast(a), matrix_info->ld_info, reinterpret_cast(b), - matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast(c), + matrix_info->ld_info + 1, reinterpret_cast(matrix_info->value_info + 1), reinterpret_cast(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); #endif - q.submit([&](sycl::handler &cgh) - { - cgh.depends_on(e); - cgh.host_task([=] { std::free(matrix_info); }); }); } template @@ -2428,19 +2431,9 @@ namespace dpct library_data_t a_type, int lda, const void *b[], library_data_t b_type, int ldb, const void *beta, void *c[], library_data_t c_type, int ldc, - int batch_size, library_data_t scaling_type) + int batch_size, library_data_t scaling_type, + matrix_info_t* matrix_info) { - if (scaling_type == library_data_t::real_float && - c_type == library_data_t::complex_float) - { - scaling_type = library_data_t::complex_float; - } - else if (scaling_type == library_data_t::real_double && - c_type == library_data_t::complex_double) - { - scaling_type = library_data_t::complex_double; - } - std::uint64_t key = detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); switch (key) @@ -2451,7 +2444,7 @@ namespace dpct { detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + batch_size, matrix_info); break; } case detail::get_type_combination_id( @@ -2460,27 +2453,7 @@ namespace dpct { detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_float, library_data_t::complex_float, - library_data_t::complex_float, library_data_t::complex_float): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_double, library_data_t::complex_double, - library_data_t::complex_double, library_data_t::complex_double): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + batch_size, matrix_info); break; } case detail::get_type_combination_id( @@ -2490,7 +2463,7 @@ namespace dpct detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + batch_size, matrix_info); break; } #ifdef __INTEL_MKL__ @@ -2501,7 +2474,7 @@ namespace dpct detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + batch_size, matrix_info); break; } case detail::get_type_combination_id( @@ -2510,7 +2483,7 @@ namespace dpct { detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, - b, ldb, beta, c, ldc, batch_size); + b, ldb, beta, c, ldc, batch_size, matrix_info); break; } #endif @@ -2525,7 +2498,7 @@ namespace dpct detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, - batch_size); + batch_size, matrix_info); break; } case detail::get_type_combination_id( @@ -2534,7 +2507,7 @@ namespace dpct { detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + batch_size, matrix_info); break; } case detail::get_type_combination_id( @@ -2543,7 +2516,7 @@ namespace dpct { detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + batch_size, matrix_info); break; } case detail::get_type_combination_id( @@ -2558,7 +2531,7 @@ namespace dpct sycl::half beta_half(beta_value); detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, - batch_size); + batch_size, matrix_info); break; } default: diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 037c8093eef30..d89286d890d49 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -37,6 +37,7 @@ #include "ggml-backend-impl.h" #include "ggml-sycl/backend.hpp" +#include "ggml-sycl/dpct/helper.hpp" #include "ggml-sycl/presets.hpp" #include "ggml-sycl/gemm.hpp" @@ -1173,6 +1174,92 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { } }; +struct ggml_sycl_pool_host : public ggml_sycl_pool { + + int device; + queue_ptr qptr; + + inline static int counter{0}; + struct ggml_sycl_buffer { + void * ptr = nullptr; + size_t size = 0; + }; + + // Set arbitrarly to 64 + static constexpr int MAX_POOL_SIZE{64}; + std::vector buffer_pool = std::vector(MAX_POOL_SIZE); + size_t pool_size = 0; + + explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : + qptr(qptr_), + device(device_) { + } + + ~ggml_sycl_pool_host() { + for (int i = 0; i < MAX_POOL_SIZE; ++i) { + ggml_sycl_buffer & b = buffer_pool[i]; + if (b.ptr != nullptr) { + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr))); + b.ptr = nullptr; + pool_size -= b.size; + b.size = 0; + } + } + counter = 0; + } + + void * alloc(size_t size, size_t * actual_size) override { + if ( counter == MAX_POOL_SIZE){ + ggml_sycl_buffer b = buffer_pool[0]; + size_t look_ahead_size = (size_t) (1.05 * size); + void *ptr = b.ptr; + *actual_size = b.size; + counter = 1; + return ptr; + } + ggml_sycl_buffer& b = buffer_pool[counter]; + + if (b.ptr == nullptr) { + void * ptr; + + SYCL_CHECK( + CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_host( + size, *qptr))); + if (!ptr) { + GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size); + return nullptr; + } + pool_size += size; + *actual_size = size; + counter = counter + 1; + return ptr; + } + else if (b.ptr != nullptr) { + ++counter; + b.size = size; + return b.ptr; + } + } + + void free(void * ptr, size_t size) override { + // if the pool is not completed add the pointer to it in place of the first nullptr found. + // Otherwise do nothing, pointers will be freed once the pool is deallocated. + for (int i = 0; i < MAX_POOL_SIZE; ++i) { + ggml_sycl_buffer& b = buffer_pool[i]; + if (b.ptr == nullptr) { + b.ptr = ptr; + b.size = size; + return; + } + } + } +}; + +std::unique_ptr ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) { + // return pool for the host to speed up memory management + return std::unique_ptr(new ggml_sycl_pool_host(qptr, device)); +} + std::unique_ptr ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) { // TBD: NO VMM support // if (ggml_sycl_info().devices[device].vmm) { @@ -3363,6 +3450,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, ggml_sycl_pool_alloc ptrs_src(ctx.pool(), 2*ne23); ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23); + ggml_sycl_pool_alloc> matrix_info(ctx.host_pool(),1); sycl::range<3> block_dims(1, ne12, ne13); /* @@ -3398,7 +3486,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, (const void **)(ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta, (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, - cu_compute_type))); + cu_compute_type, (matrix_info_t*)matrix_info.get()))); } } catch (sycl::exception const &exc) { From ee11dea6d6dcb6f2d705aa164770c970956c11a9 Mon Sep 17 00:00:00 2001 From: nscipione Date: Wed, 15 Jan 2025 10:05:48 +0100 Subject: [PATCH 2/6] Remove unnecessary headers and cast Signed-off-by: nscipione --- ggml/src/ggml-sycl/dpct/helper.hpp | 7 ------- ggml/src/ggml-sycl/ggml-sycl.cpp | 3 +-- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp index 645f681d88321..b66073f482287 100644 --- a/ggml/src/ggml-sycl/dpct/helper.hpp +++ b/ggml/src/ggml-sycl/dpct/helper.hpp @@ -18,16 +18,9 @@ #include #include #include -#include -#include "ggml-sycl.h" #include "ggml.h" -#include "ggml-backend.h" -#include "ggml-backend-impl.h" -#include "ggml-alloc.h" -#include "ggml-impl.h" - #if defined(__linux__) #include #elif defined(_WIN64) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index d89286d890d49..5592e119ca15a 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -37,7 +37,6 @@ #include "ggml-backend-impl.h" #include "ggml-sycl/backend.hpp" -#include "ggml-sycl/dpct/helper.hpp" #include "ggml-sycl/presets.hpp" #include "ggml-sycl/gemm.hpp" @@ -3486,7 +3485,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, (const void **)(ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta, (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, - cu_compute_type, (matrix_info_t*)matrix_info.get()))); + cu_compute_type, matrix_info.get()))); } } catch (sycl::exception const &exc) { From 6b776392586bfc6c16cee171c2b153812a9062f6 Mon Sep 17 00:00:00 2001 From: nscipione Date: Wed, 15 Jan 2025 14:47:40 +0100 Subject: [PATCH 3/6] Reorder member variable to avoid warning on initialization Signed-off-by: nscipione --- ggml/src/ggml-sycl/ggml-sycl.cpp | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 5592e119ca15a..e57a72afa7a01 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1175,8 +1175,8 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { struct ggml_sycl_pool_host : public ggml_sycl_pool { - int device; queue_ptr qptr; + int device; inline static int counter{0}; struct ggml_sycl_buffer { @@ -1189,10 +1189,7 @@ struct ggml_sycl_pool_host : public ggml_sycl_pool { std::vector buffer_pool = std::vector(MAX_POOL_SIZE); size_t pool_size = 0; - explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : - qptr(qptr_), - device(device_) { - } + explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {} ~ggml_sycl_pool_host() { for (int i = 0; i < MAX_POOL_SIZE; ++i) { From b0f14c5c2a463b0335eaa804ea2649cff89a6362 Mon Sep 17 00:00:00 2001 From: nscipione Date: Wed, 15 Jan 2025 15:13:44 +0000 Subject: [PATCH 4/6] Formatting Signed-off-by: nscipione --- ggml/src/ggml-sycl/common.hpp | 6 +- ggml/src/ggml-sycl/dpct/helper.hpp | 95 ++++++++++++------------------ ggml/src/ggml-sycl/ggml-sycl.cpp | 71 ++++++++++------------ 3 files changed, 71 insertions(+), 101 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 91b432da922dd..abad847ca8199 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -352,14 +352,12 @@ struct ggml_backend_sycl_context { ggml_sycl_pool & host_pool(int device) { if (host_pools[device] == nullptr) { - host_pools[device] = new_pool_for_host(stream(device,0), device); + host_pools[device] = new_pool_for_host(stream(device, 0), device); } return *host_pools[device]; } - ggml_sycl_pool & host_pool() { - return host_pool(device); - } + ggml_sycl_pool & host_pool() { return host_pool(device); } }; // common device functions diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp index b66073f482287..c96395be61312 100644 --- a/ggml/src/ggml-sycl/dpct/helper.hpp +++ b/ggml/src/ggml-sycl/dpct/helper.hpp @@ -82,14 +82,12 @@ inline std::string get_device_backend_and_type(const sycl::device &device) { return device_type.str(); } -template -struct matrix_info_t -{ +template struct matrix_info_t { oneapi::mkl::transpose transpose_info[2]; - Ts value_info[2]; - std::int64_t size_info[3]; - std::int64_t ld_info[3]; - std::int64_t groupsize_info; + Ts value_info[2]; + std::int64_t size_info[3]; + std::int64_t ld_info[3]; + std::int64_t groupsize_info; }; namespace dpct @@ -1737,13 +1735,10 @@ namespace dpct }; template - inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void **a, int lda, - const void **b, int ldb, const void *beta, void **c, - int ldc, int batch_size, matrix_info_t* matrix_info) - { - + inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, + int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b, + int ldb, const void * beta, void ** c, int ldc, int batch_size, + matrix_info_t * matrix_info) { Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); @@ -1763,19 +1758,18 @@ namespace dpct sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( oneapi::mkl::backend_selector{ q }, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1, - matrix_info->size_info + 2, reinterpret_cast(matrix_info->value_info), reinterpret_cast(a), - matrix_info->ld_info, reinterpret_cast(b), matrix_info->ld_info + 1, - reinterpret_cast(matrix_info->value_info+1), reinterpret_cast(c), matrix_info->ld_info + 2, 1, - &(matrix_info->groupsize_info)); + matrix_info->size_info + 2, reinterpret_cast(matrix_info->value_info), + reinterpret_cast(a), matrix_info->ld_info, reinterpret_cast(b), + matrix_info->ld_info + 1, reinterpret_cast(matrix_info->value_info + 1), + reinterpret_cast(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); #else sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info, - matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast(matrix_info->value_info), + matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast(matrix_info->value_info), reinterpret_cast(a), matrix_info->ld_info, reinterpret_cast(b), - matrix_info->ld_info + 1, reinterpret_cast(matrix_info->value_info + 1), reinterpret_cast(c), - matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); + matrix_info->ld_info + 1, reinterpret_cast(matrix_info->value_info + 1), + reinterpret_cast(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); #endif - } template @@ -2418,15 +2412,11 @@ namespace dpct /// \param [in] ldc Leading dimension of C. /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. /// \param [in] scaling_type Data type of the scaling factors. - inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void *a[], - library_data_t a_type, int lda, const void *b[], - library_data_t b_type, int ldb, const void *beta, - void *c[], library_data_t c_type, int ldc, - int batch_size, library_data_t scaling_type, - matrix_info_t* matrix_info) - { + inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, + int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda, + const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[], + library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type, + matrix_info_t * matrix_info) { std::uint64_t key = detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); switch (key) @@ -2435,28 +2425,24 @@ namespace dpct library_data_t::real_float, library_data_t::real_float, library_data_t::real_float, library_data_t::real_float): { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size, matrix_info); + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, + beta, c, ldc, batch_size, matrix_info); break; } case detail::get_type_combination_id( library_data_t::real_double, library_data_t::real_double, library_data_t::real_double, library_data_t::real_double): { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size, matrix_info); + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, + beta, c, ldc, batch_size, matrix_info); break; } case detail::get_type_combination_id( library_data_t::real_half, library_data_t::real_half, library_data_t::real_half, library_data_t::real_half): { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, - a, lda, b, ldb, beta, c, ldc, - batch_size, matrix_info); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } #ifdef __INTEL_MKL__ @@ -2464,19 +2450,16 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float): { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size, matrix_info); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } case detail::get_type_combination_id( library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float, library_data_t::real_float): { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, - b, ldb, beta, c, ldc, batch_size, matrix_info); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } #endif @@ -2488,10 +2471,9 @@ namespace dpct dpct::get_value(reinterpret_cast(alpha), q); float beta_float = dpct::get_value(reinterpret_cast(beta), q); - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, &alpha_float, - a, lda, b, ldb, &beta_float, c, ldc, - batch_size, matrix_info); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size, + matrix_info); break; } case detail::get_type_combination_id( @@ -2499,8 +2481,7 @@ namespace dpct library_data_t::real_float, library_data_t::real_float): { detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size, matrix_info); + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } case detail::get_type_combination_id( @@ -2508,8 +2489,7 @@ namespace dpct library_data_t::real_float, library_data_t::real_float): { detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size, matrix_info); + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } case detail::get_type_combination_id( @@ -2523,8 +2503,7 @@ namespace dpct sycl::half alpha_half(alpha_value); sycl::half beta_half(beta_value); detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, - batch_size, matrix_info); + q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info); break; } default: diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index e57a72afa7a01..c6757f29f95eb 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1174,20 +1174,20 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { }; struct ggml_sycl_pool_host : public ggml_sycl_pool { - queue_ptr qptr; - int device; + int device; + + inline static int counter{ 0 }; - inline static int counter{0}; struct ggml_sycl_buffer { - void * ptr = nullptr; + void * ptr = nullptr; size_t size = 0; }; // Set arbitrarly to 64 - static constexpr int MAX_POOL_SIZE{64}; + static constexpr int MAX_POOL_SIZE{ 64 }; std::vector buffer_pool = std::vector(MAX_POOL_SIZE); - size_t pool_size = 0; + size_t pool_size = 0; explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {} @@ -1205,32 +1205,29 @@ struct ggml_sycl_pool_host : public ggml_sycl_pool { } void * alloc(size_t size, size_t * actual_size) override { - if ( counter == MAX_POOL_SIZE){ - ggml_sycl_buffer b = buffer_pool[0]; - size_t look_ahead_size = (size_t) (1.05 * size); - void *ptr = b.ptr; - *actual_size = b.size; - counter = 1; + if (counter == MAX_POOL_SIZE) { + ggml_sycl_buffer b = buffer_pool[0]; + size_t look_ahead_size = (size_t) (1.05 * size); + void * ptr = b.ptr; + *actual_size = b.size; + counter = 1; return ptr; } - ggml_sycl_buffer& b = buffer_pool[counter]; + ggml_sycl_buffer & b = buffer_pool[counter]; if (b.ptr == nullptr) { - void * ptr; - - SYCL_CHECK( - CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_host( - size, *qptr))); - if (!ptr) { - GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size); - return nullptr; - } - pool_size += size; - *actual_size = size; - counter = counter + 1; - return ptr; - } - else if (b.ptr != nullptr) { + void * ptr; + + SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr))); + if (!ptr) { + GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size); + return nullptr; + } + pool_size += size; + *actual_size = size; + counter = counter + 1; + return ptr; + } else if (b.ptr != nullptr) { ++counter; b.size = size; return b.ptr; @@ -1241,9 +1238,9 @@ struct ggml_sycl_pool_host : public ggml_sycl_pool { // if the pool is not completed add the pointer to it in place of the first nullptr found. // Otherwise do nothing, pointers will be freed once the pool is deallocated. for (int i = 0; i < MAX_POOL_SIZE; ++i) { - ggml_sycl_buffer& b = buffer_pool[i]; + ggml_sycl_buffer & b = buffer_pool[i]; if (b.ptr == nullptr) { - b.ptr = ptr; + b.ptr = ptr; b.size = size; return; } @@ -3446,7 +3443,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, ggml_sycl_pool_alloc ptrs_src(ctx.pool(), 2*ne23); ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23); - ggml_sycl_pool_alloc> matrix_info(ctx.host_pool(),1); + ggml_sycl_pool_alloc> matrix_info(ctx.host_pool(), 1); sycl::range<3> block_dims(1, ne12, ne13); /* @@ -3475,14 +3472,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, }); } SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( - *main_stream, oneapi::mkl::transpose::trans, - oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, - (const void **)(ptrs_src.get() + 0 * ne23), - dpct::library_data_t::real_half, nb01 / nb00, - (const void **)(ptrs_src.get() + 1 * ne23), - dpct::library_data_t::real_half, nb11 / nb10, beta, - (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, - cu_compute_type, matrix_info.get()))); + *main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, + (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00, + (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta, + (void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get()))); } } catch (sycl::exception const &exc) { From d8956a4ce38c06289f684fd56c10546b71e4de6d Mon Sep 17 00:00:00 2001 From: nscipione Date: Wed, 15 Jan 2025 16:53:38 +0100 Subject: [PATCH 5/6] Remove unused variable Signed-off-by: nscipione --- ggml/src/ggml-sycl/ggml-sycl.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index c6757f29f95eb..75f1b0f103a6b 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1207,7 +1207,6 @@ struct ggml_sycl_pool_host : public ggml_sycl_pool { void * alloc(size_t size, size_t * actual_size) override { if (counter == MAX_POOL_SIZE) { ggml_sycl_buffer b = buffer_pool[0]; - size_t look_ahead_size = (size_t) (1.05 * size); void * ptr = b.ptr; *actual_size = b.size; counter = 1; From 963b6850751ef6a17c4c58eb6b20d2226fd1ea6c Mon Sep 17 00:00:00 2001 From: nscipione Date: Thu, 16 Jan 2025 11:22:56 +0100 Subject: [PATCH 6/6] Address PR review feedback - remove warning Signed-off-by: nscipione --- ggml/src/ggml-sycl/ggml-sycl.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 75f1b0f103a6b..1723888ae9455 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1226,7 +1226,7 @@ struct ggml_sycl_pool_host : public ggml_sycl_pool { *actual_size = size; counter = counter + 1; return ptr; - } else if (b.ptr != nullptr) { + } else { ++counter; b.size = size; return b.ptr;