// Review notes for this patch (OpenCL adapter: kernel.cpp and kernel.hpp).
//
// kernel.cpp
//  * Defines CL_KERNEL_REGISTER_COUNT_INTEL as 0x425B when the macro is not
//    already defined.
//  * mapURKernelInfoToCL: UR_KERNEL_INFO_SPILL_MEM_SIZE now maps to
//    CL_KERNEL_SPILL_MEM_SIZE_INTEL and UR_KERNEL_INFO_NUM_REGS maps to
//    CL_KERNEL_REGISTER_COUNT_INTEL. Before the patch both case labels fell
//    through to the default `return -1`.
//  * urKernelGetInfo: the dedicated NUM_REGS and SPILL_MEM_SIZE cases that
//    returned UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION are removed, so both
//    properties now reach the default branch. That branch gains a mapping of
//    CL_INVALID_VALUE to UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION.
//    NOTE(review): the new CL_INVALID_VALUE check sits after the existing
//    `pPropValue && CheckPropSize != propSize` check, and CheckPropSize starts
//    at 0. If the query (its call is not visible in these hunks) fails without
//    writing a size, a caller that passes a buffer would presumably get
//    UR_RESULT_ERROR_INVALID_SIZE instead -- confirm whether the new check
//    should come first.
//  * urKernelSetArgPointer: no longer queries CL_KERNEL_CONTEXT and looks up
//    the extension function on every call. It uses the
//    clSetKernelArgMemPointerINTEL pointer stored on the kernel handle and
//    returns UR_RESULT_ERROR_UNSUPPORTED_FEATURE when that pointer is null.
//    The removed code only made the call when FuncPtr was non-null and
//    otherwise fell through to UR_RESULT_SUCCESS.
//
// kernel.hpp
//  * Adds #include "adapter.hpp" and a clSetKernelArgMemPointerINTEL member
//    initialised to nullptr.
//  * The ur_kernel_handle_t_ constructor calls cl_ext::getExtFuncFromContext
//    to fill that member.
//    NOTE(review): the result of that call is not checked in the shown code;
//    presumably the member stays nullptr on failure -- confirm.
//
// NOTE(review): `static_cast(URPropName)`, `static_cast(argIndex)` and
// `std::atomic RefCount` look like they lost their `<...>` template arguments
// when this patch text was extracted, and the original line breaks are
// missing -- verify against the original patch before applying.
diff --git a/unified-runtime/source/adapters/opencl/kernel.cpp b/unified-runtime/source/adapters/opencl/kernel.cpp index 9116ae70d2176..62fcabbf48efa 100644 --- a/unified-runtime/source/adapters/opencl/kernel.cpp +++ b/unified-runtime/source/adapters/opencl/kernel.cpp @@ -100,8 +100,13 @@ urKernelSetArgLocal(ur_kernel_handle_t hKernel, uint32_t argIndex, return UR_RESULT_SUCCESS; } -static cl_int mapURKernelInfoToCL(ur_kernel_info_t URPropName) { +// Querying the number of registers that a kernel uses is supported unofficially +// on some devices. +#ifndef CL_KERNEL_REGISTER_COUNT_INTEL +#define CL_KERNEL_REGISTER_COUNT_INTEL 0x425B +#endif +static cl_int mapURKernelInfoToCL(ur_kernel_info_t URPropName) { switch (static_cast(URPropName)) { case UR_KERNEL_INFO_FUNCTION_NAME: return CL_KERNEL_FUNCTION_NAME; @@ -115,9 +120,10 @@ static cl_int mapURKernelInfoToCL(ur_kernel_info_t URPropName) { return CL_KERNEL_PROGRAM; case UR_KERNEL_INFO_ATTRIBUTES: return CL_KERNEL_ATTRIBUTES; - // NUM_REGS doesn't have a CL equivalent - case UR_KERNEL_INFO_NUM_REGS: case UR_KERNEL_INFO_SPILL_MEM_SIZE: + return CL_KERNEL_SPILL_MEM_SIZE_INTEL; + case UR_KERNEL_INFO_NUM_REGS: + return CL_KERNEL_REGISTER_COUNT_INTEL; default: return -1; } @@ -132,10 +138,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(ur_kernel_handle_t hKernel, UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet); switch (propName) { - // OpenCL doesn't have a way to support this. 
- case UR_KERNEL_INFO_NUM_REGS: { - return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION; - } case UR_KERNEL_INFO_PROGRAM: { return ReturnValue(hKernel->Program); } @@ -145,9 +147,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(ur_kernel_handle_t hKernel, case UR_KERNEL_INFO_REFERENCE_COUNT: { return ReturnValue(hKernel->getReferenceCount()); } - case UR_KERNEL_INFO_SPILL_MEM_SIZE: { - return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION; - } default: { size_t CheckPropSize = 0; cl_int ClResult = @@ -156,6 +155,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(ur_kernel_handle_t hKernel, if (pPropValue && CheckPropSize != propSize) { return UR_RESULT_ERROR_INVALID_SIZE; } + if (ClResult == CL_INVALID_VALUE) { + return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION; + } CL_RETURN_ON_FAILURE(ClResult); if (pPropSizeRet) { *pPropSizeRet = CheckPropSize; @@ -428,25 +430,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgPointer( ur_kernel_handle_t hKernel, uint32_t argIndex, const ur_kernel_arg_pointer_properties_t *, const void *pArgValue) { - cl_context CLContext; - CL_RETURN_ON_FAILURE(clGetKernelInfo(hKernel->CLKernel, CL_KERNEL_CONTEXT, - sizeof(cl_context), &CLContext, - nullptr)); - - clSetKernelArgMemPointerINTEL_fn FuncPtr = nullptr; - UR_RETURN_ON_FAILURE( - cl_ext::getExtFuncFromContext( - CLContext, - ur::cl::getAdapter()->fnCache.clSetKernelArgMemPointerINTELCache, - cl_ext::SetKernelArgMemPointerName, &FuncPtr)); - - if (FuncPtr) { - CL_RETURN_ON_FAILURE( - FuncPtr(hKernel->CLKernel, static_cast(argIndex), pArgValue)); + if (hKernel->clSetKernelArgMemPointerINTEL == nullptr) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; } + CL_RETURN_ON_FAILURE(hKernel->clSetKernelArgMemPointerINTEL( + hKernel->CLKernel, static_cast(argIndex), pArgValue)); + return UR_RESULT_SUCCESS; } + UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle( ur_kernel_handle_t hKernel, ur_native_handle_t *phNativeKernel) { diff --git 
a/unified-runtime/source/adapters/opencl/kernel.hpp b/unified-runtime/source/adapters/opencl/kernel.hpp index 2b3c2dbe8464b..f94b11f87c3e8 100644 --- a/unified-runtime/source/adapters/opencl/kernel.hpp +++ b/unified-runtime/source/adapters/opencl/kernel.hpp @@ -9,6 +9,7 @@ //===----------------------------------------------------------------------===// #pragma once +#include "adapter.hpp" #include "common.hpp" #include "context.hpp" #include "program.hpp" @@ -22,6 +23,7 @@ struct ur_kernel_handle_t_ { ur_context_handle_t Context; std::atomic RefCount = 0; bool IsNativeHandleOwned = true; + clSetKernelArgMemPointerINTEL_fn clSetKernelArgMemPointerINTEL = nullptr; ur_kernel_handle_t_(native_type Kernel, ur_program_handle_t Program, ur_context_handle_t Context) @@ -29,6 +31,11 @@ struct ur_kernel_handle_t_ { RefCount = 1; urProgramRetain(Program); urContextRetain(Context); + + cl_ext::getExtFuncFromContext( + Context->CLContext, + ur::cl::getAdapter()->fnCache.clSetKernelArgMemPointerINTELCache, + cl_ext::SetKernelArgMemPointerName, &clSetKernelArgMemPointerINTEL); } ~ur_kernel_handle_t_() {