@@ -12615,6 +12615,9 @@ static ggml_backend_buffer_type_i ggml_backend_sycl_buffer_type_interface = {
 };
 
 ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
+
     GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
 
     if (device>=ggml_sycl_info().device_count or device<0) {
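The lock added at the top of ggml_backend_sycl_buffer_type() serializes the whole lookup, so the one-time initialization of the function-local static array further down cannot race when several backends are created from different threads. A minimal sketch of that pattern, using simplified stand-in types rather than the real ggml structs:

#include <mutex>
#include <string>

// Illustrative stand-ins; not the real ggml types.
struct buffer_type { std::string name; };
constexpr int MAX_DEVICES = 16;   // placeholder for GGML_SYCL_MAX_DEVICES

buffer_type * get_buffer_type(int device) {
    // Function-local mutex held for the whole call: the first caller performs
    // the one-time initialization, later callers only read the finished array.
    static std::mutex mutex;
    std::lock_guard<std::mutex> lock(mutex);

    static buffer_type types[MAX_DEVICES];
    static bool initialized = false;
    if (!initialized) {
        for (int i = 0; i < MAX_DEVICES; i++) {
            types[i] = { "SYCL" + std::to_string(i) };
        }
        initialized = true;
    }
    return &types[device];
}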
@@ -12640,31 +12643,6 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
     return &ggml_backend_sycl_buffer_types[device];
 }
 
-ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
-    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
-
-    int device = ctx->device;
-    if (device>=ggml_sycl_info().device_count or device<0) {
-        printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
-            device, ggml_sycl_info().device_count-1);
-        GGML_ASSERT(device<ggml_sycl_info().device_count);
-    }
-    static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_types[GGML_SYCL_MAX_DEVICES];
-
-    static bool ggml_backend_sycl_buffer_type_initialized = false;
-
-    if (!ggml_backend_sycl_buffer_type_initialized) {
-        for (int i = 0; i < ggml_sycl_info().device_count; i++) {
-            ggml_backend_sycl_buffer_types[i] = {
-                /* .iface = */ ggml_backend_sycl_buffer_type_interface,
-                /* .context = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), ctx->stream(i, 0)},
-            };
-        }
-        ggml_backend_sycl_buffer_type_initialized = true;
-    }
-    return &ggml_backend_sycl_buffer_types[device];
-}
-
 // sycl split buffer type
 static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_SYCL_MAX_DEVICES> & tensor_split, int id) {
     const int64_t nrows = ggml_nrows(tensor);
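One plausible reading of why the ggml_backend_sycl_context overload is deleted: its one-time initialization baked ctx->stream(i, 0) from whichever context ran it first into the static buffer-type array, so any later caller holding a different context silently reused the first caller's streams. The int-device overload has no such hidden coupling. A hedged reduction of that hazard, with hypothetical names:

// Hypothetical sketch: a function-static that captures the first caller's context.
struct context { int id; };
struct buffer_type_ctx { const context * owner = nullptr; };

buffer_type_ctx * get_buffer_type(const context * ctx) {
    static buffer_type_ctx cached;
    static bool initialized = false;
    if (!initialized) {
        cached.owner = ctx;   // the *first* caller's context is baked in forever
        initialized = true;
    }
    return &cached;           // a later caller with a different ctx still sees the first one
}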
@@ -13016,6 +12994,9 @@ static ggml_backend_buffer_type_i ggml_backend_sycl_split_buffer_type_interface
 };
 
 GGML_CALL ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split) {
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
+
     GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_split_buffer_type\n");
     ggml_check_sycl();
     // FIXME: this is not thread safe
@@ -13123,16 +13104,17 @@ GGML_CALL static void ggml_backend_sycl_free(ggml_backend_t backend) {
 
 GGML_CALL static ggml_backend_buffer_type_t ggml_backend_sycl_get_default_buffer_type(ggml_backend_t backend) {
     ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
-    return ggml_backend_sycl_buffer_type(sycl_ctx);
+    return ggml_backend_sycl_buffer_type(sycl_ctx->device);
 }
 
 GGML_CALL static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend,
                                                          ggml_tensor *tensor,
                                                          const void *data, size_t offset,
                                                          size_t size) try {
     ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
-    GGML_ASSERT(tensor->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx) && "unsupported buffer type");
-    GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
+    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+    GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
     const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
     SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
         (char *)tensor->data + offset, data, size).wait()));
@@ -13148,8 +13130,9 @@ GGML_CALL static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend,
                                                          void *data, size_t offset,
                                                          size_t size) try {
     ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
-    GGML_ASSERT(tensor->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx) && "unsupported buffer type");
-    GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
+    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+    GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
     const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
     SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
         data, (const char *)tensor->data + offset, size).wait()));
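Both async copies now look through view tensors before checking the buffer type: a view owns no storage of its own, its bytes live in the buffer of view_src, so that is the buffer whose type has to match (the old GGML_ASSERT on tensor->backend is dropped alongside). A small self-contained sketch of the view_src resolution, with simplified stand-in structs rather than the real ggml types:

#include <cassert>

// Simplified stand-ins for ggml_tensor / ggml_backend_buffer; not the real structs.
struct buffer { int buft; };
struct tensor {
    buffer * buf      = nullptr;   // buffer this tensor was allocated in
    tensor * view_src = nullptr;   // non-null when the tensor is a view into another tensor
};

// Mirror of the `tensor->view_src ? tensor->view_src->buffer : tensor->buffer` pattern:
// the bytes of a view live in the parent's buffer, so that is the buffer to validate.
buffer * backing_buffer(const tensor * t) {
    return t->view_src ? t->view_src->buf : t->buf;
}

int main() {
    buffer b { 7 };
    tensor parent; parent.buf = &b;
    tensor view;   view.view_src = &parent;
    assert(backing_buffer(&view) == &b);   // the view resolves to the parent's buffer
}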
@@ -13164,7 +13147,7 @@ GGML_CALL static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend,
                                                          const ggml_tensor *src,
                                                          ggml_tensor *dst) try {
     ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
-    if (dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx) && ggml_backend_buffer_is_sycl(src->buffer)) {
+    if (dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && ggml_backend_buffer_is_sycl(src->buffer)) {
         /*
         DPCT1009:215: SYCL uses exceptions to report errors and does not use the
         error codes. The original code was commented out and a warning string
@@ -13208,10 +13191,10 @@ GGML_CALL static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t back
             continue;
         }
 #ifndef NDEBUG
-        assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx));
+        assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
         for (int j = 0; j < GGML_MAX_SRC; j++) {
             if (node->src[j] != nullptr) {
-                assert(node->src[j]->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx));
+                assert(node->src[j]->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
             }
         }
 #endif
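The graph-compute path keeps the same sanity check, only routed through the device index: in debug builds every node and every non-null source must live in this backend's buffer type. A rough sketch of that debug-only validation loop, with illustrative types in place of ggml_cgraph:

#include <cassert>
#include <vector>

// Illustrative node type; the real code walks a ggml_cgraph and compares buffer types.
struct node_t {
    int buft;
    std::vector<const node_t *> src;
};

void validate_graph(const std::vector<node_t> & nodes, int expected_buft) {
#ifndef NDEBUG
    for (const node_t & n : nodes) {
        assert(n.buft == expected_buft);
        for (const node_t * s : n.src) {
            if (s != nullptr) {
                assert(s->buft == expected_buft);
            }
        }
    }
#else
    (void) nodes; (void) expected_buft;   // the checks compile out in release builds
#endif
}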