diff --git a/src/provider/provider_cuda.c b/src/provider/provider_cuda.c index 85d6d9b798..4d7a265164 100644 --- a/src/provider/provider_cuda.c +++ b/src/provider/provider_cuda.c @@ -51,6 +51,8 @@ typedef struct cu_ops_t { CUresult (*cuGetErrorName)(CUresult error, const char **pStr); CUresult (*cuGetErrorString)(CUresult error, const char **pStr); + CUresult (*cuCtxGetCurrent)(CUcontext *pctx); + CUresult (*cuCtxSetCurrent)(CUcontext ctx); } cu_ops_t; static cu_ops_t g_cu_ops; @@ -117,11 +119,16 @@ static void init_cu_global_state(void) { utils_get_symbol_addr(0, "cuGetErrorName", lib_name); *(void **)&g_cu_ops.cuGetErrorString = utils_get_symbol_addr(0, "cuGetErrorString", lib_name); + *(void **)&g_cu_ops.cuCtxGetCurrent = + utils_get_symbol_addr(0, "cuCtxGetCurrent", lib_name); + *(void **)&g_cu_ops.cuCtxSetCurrent = + utils_get_symbol_addr(0, "cuCtxSetCurrent", lib_name); if (!g_cu_ops.cuMemGetAllocationGranularity || !g_cu_ops.cuMemAlloc || !g_cu_ops.cuMemAllocHost || !g_cu_ops.cuMemAllocManaged || !g_cu_ops.cuMemFree || !g_cu_ops.cuMemFreeHost || - !g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString) { + !g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString || + !g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxSetCurrent) { LOG_ERR("Required CUDA symbols not found."); Init_cu_global_state_failed = true; } @@ -190,6 +197,31 @@ static void cu_memory_provider_finalize(void *provider) { umf_ba_global_free(provider); } +/* + * This function is used by the CUDA provider to make sure that + * the required context is set. If the current context is + * not the required one, it will be saved in restore_ctx. + */ +static inline umf_result_t set_context(CUcontext required_ctx, + CUcontext *restore_ctx) { + CUcontext current_ctx = NULL; + CUresult cu_result = g_cu_ops.cuCtxGetCurrent(¤t_ctx); + if (cu_result != CUDA_SUCCESS) { + LOG_ERR("cuCtxGetCurrent() failed."); + return cu2umf_result(cu_result); + } + *restore_ctx = current_ctx; + if (current_ctx != required_ctx) { + cu_result = g_cu_ops.cuCtxSetCurrent(required_ctx); + if (cu_result != CUDA_SUCCESS) { + LOG_ERR("cuCtxSetCurrent() failed."); + return cu2umf_result(cu_result); + } + } + + return UMF_RESULT_SUCCESS; +} + static umf_result_t cu_memory_provider_alloc(void *provider, size_t size, size_t alignment, void **resultPtr) { @@ -205,6 +237,14 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size, return UMF_RESULT_ERROR_NOT_SUPPORTED; } + // Remember current context and set the one from the provider + CUcontext restore_ctx = NULL; + umf_result_t umf_result = set_context(cu_provider->context, &restore_ctx); + if (umf_result != UMF_RESULT_SUCCESS) { + LOG_ERR("Failed to set CUDA context, ret = %d", umf_result); + return umf_result; + } + CUresult cu_result = CUDA_SUCCESS; switch (cu_provider->memory_type) { case UMF_MEMORY_TYPE_HOST: { @@ -224,17 +264,29 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size, // this shouldn't happen as we check the memory_type settings during // the initialization LOG_ERR("unsupported USM memory type"); + assert(false); return UMF_RESULT_ERROR_UNKNOWN; } + umf_result = set_context(restore_ctx, &restore_ctx); + if (umf_result != UMF_RESULT_SUCCESS) { + LOG_ERR("Failed to restore CUDA context, ret = %d", umf_result); + } + + umf_result = cu2umf_result(cu_result); + if (umf_result != UMF_RESULT_SUCCESS) { + LOG_ERR("Failed to allocate memory, cu_result = %d, ret = %d", + cu_result, umf_result); + return umf_result; + } + // check the alignment if (alignment > 0 && ((uintptr_t)(*resultPtr) % alignment) != 0) { cu_memory_provider_free(provider, *resultPtr, size); LOG_ERR("unsupported alignment size"); return UMF_RESULT_ERROR_INVALID_ALIGNMENT; } - - return cu2umf_result(cu_result); + return umf_result; } static umf_result_t cu_memory_provider_free(void *provider, void *ptr, diff --git a/test/providers/cuda_helpers.cpp b/test/providers/cuda_helpers.cpp index 366efc197d..2b332dde18 100644 --- a/test/providers/cuda_helpers.cpp +++ b/test/providers/cuda_helpers.cpp @@ -17,6 +17,7 @@ struct libcu_ops { CUresult (*cuInit)(unsigned int flags); CUresult (*cuCtxCreate)(CUcontext *pctx, unsigned int flags, CUdevice dev); CUresult (*cuCtxDestroy)(CUcontext ctx); + CUresult (*cuCtxGetCurrent)(CUcontext *pctx); CUresult (*cuDeviceGet)(CUdevice *device, int ordinal); CUresult (*cuMemAlloc)(CUdeviceptr *dptr, size_t size); CUresult (*cuMemFree)(CUdeviceptr dptr); @@ -26,7 +27,9 @@ struct libcu_ops { CUresult (*cuMemFreeHost)(void *p); CUresult (*cuMemsetD32)(CUdeviceptr dstDevice, unsigned int pattern, size_t size); - CUresult (*cuMemcpyDtoH)(void *dstHost, CUdeviceptr srcDevice, size_t size); + CUresult (*cuMemcpy)(CUdeviceptr dst, CUdeviceptr src, size_t size); + CUresult (*cuPointerGetAttribute)(void *data, CUpointer_attribute attribute, + CUdeviceptr ptr); CUresult (*cuPointerGetAttributes)(unsigned int numAttributes, CUpointer_attribute *attributes, void **data, CUdeviceptr ptr); @@ -74,6 +77,12 @@ int InitCUDAOps() { fprintf(stderr, "cuCtxDestroy symbol not found in %s\n", lib_name); return -1; } + *(void **)&libcu_ops.cuCtxGetCurrent = + utils_get_symbol_addr(cuDlHandle.get(), "cuCtxGetCurrent", lib_name); + if (libcu_ops.cuCtxGetCurrent == nullptr) { + fprintf(stderr, "cuCtxGetCurrent symbol not found in %s\n", lib_name); + return -1; + } *(void **)&libcu_ops.cuDeviceGet = utils_get_symbol_addr(cuDlHandle.get(), "cuDeviceGet", lib_name); if (libcu_ops.cuDeviceGet == nullptr) { @@ -116,10 +125,17 @@ int InitCUDAOps() { fprintf(stderr, "cuMemsetD32_v2 symbol not found in %s\n", lib_name); return -1; } - *(void **)&libcu_ops.cuMemcpyDtoH = - utils_get_symbol_addr(cuDlHandle.get(), "cuMemcpyDtoH_v2", lib_name); - if (libcu_ops.cuMemcpyDtoH == nullptr) { - fprintf(stderr, "cuMemcpyDtoH_v2 symbol not found in %s\n", lib_name); + *(void **)&libcu_ops.cuMemcpy = + utils_get_symbol_addr(cuDlHandle.get(), "cuMemcpy", lib_name); + if (libcu_ops.cuMemcpy == nullptr) { + fprintf(stderr, "cuMemcpy symbol not found in %s\n", lib_name); + return -1; + } + *(void **)&libcu_ops.cuPointerGetAttribute = utils_get_symbol_addr( + cuDlHandle.get(), "cuPointerGetAttribute", lib_name); + if (libcu_ops.cuPointerGetAttribute == nullptr) { + fprintf(stderr, "cuPointerGetAttribute symbol not found in %s\n", + lib_name); return -1; } *(void **)&libcu_ops.cuPointerGetAttributes = utils_get_symbol_addr( @@ -140,6 +156,7 @@ int InitCUDAOps() { libcu_ops.cuInit = cuInit; libcu_ops.cuCtxCreate = cuCtxCreate; libcu_ops.cuCtxDestroy = cuCtxDestroy; + libcu_ops.cuCtxGetCurrent = cuCtxGetCurrent; libcu_ops.cuDeviceGet = cuDeviceGet; libcu_ops.cuMemAlloc = cuMemAlloc; libcu_ops.cuMemAllocHost = cuMemAllocHost; @@ -147,7 +164,8 @@ int InitCUDAOps() { libcu_ops.cuMemFree = cuMemFree; libcu_ops.cuMemFreeHost = cuMemFreeHost; libcu_ops.cuMemsetD32 = cuMemsetD32; - libcu_ops.cuMemcpyDtoH = cuMemcpyDtoH; + libcu_ops.cuMemcpy = cuMemcpy; + libcu_ops.cuPointerGetAttribute = cuPointerGetAttribute; libcu_ops.cuPointerGetAttributes = cuPointerGetAttributes; return 0; @@ -193,9 +211,10 @@ int cuda_copy(CUcontext context, CUdevice device, void *dst_ptr, void *src_ptr, (void)device; int ret = 0; - CUresult res = libcu_ops.cuMemcpyDtoH(dst_ptr, (CUdeviceptr)src_ptr, size); + CUresult res = + libcu_ops.cuMemcpy((CUdeviceptr)dst_ptr, (CUdeviceptr)src_ptr, size); if (res != CUDA_SUCCESS) { - fprintf(stderr, "cuMemcpyDtoH() failed!\n"); + fprintf(stderr, "cuMemcpy() failed!\n"); return -1; } @@ -230,6 +249,29 @@ umf_usm_memory_type_t get_mem_type(CUcontext context, void *ptr) { return UMF_MEMORY_TYPE_UNKNOWN; } +CUcontext get_mem_context(void *ptr) { + CUcontext context; + CUresult res = libcu_ops.cuPointerGetAttribute( + &context, CU_POINTER_ATTRIBUTE_CONTEXT, (CUdeviceptr)ptr); + if (res != CUDA_SUCCESS) { + fprintf(stderr, "cuPointerGetAttribute() failed!\n"); + return nullptr; + } + + return context; +} + +CUcontext get_current_context() { + CUcontext context; + CUresult res = libcu_ops.cuCtxGetCurrent(&context); + if (res != CUDA_SUCCESS) { + fprintf(stderr, "cuCtxGetCurrent() failed!\n"); + return nullptr; + } + + return context; +} + UTIL_ONCE_FLAG cuda_init_flag; int InitResult; void init_cuda_once() { diff --git a/test/providers/cuda_helpers.h b/test/providers/cuda_helpers.h index 3227fc9c59..5e42153bb7 100644 --- a/test/providers/cuda_helpers.h +++ b/test/providers/cuda_helpers.h @@ -26,6 +26,10 @@ int cuda_copy(CUcontext context, CUdevice device, void *dst_ptr, void *src_ptr, umf_usm_memory_type_t get_mem_type(CUcontext context, void *ptr); +CUcontext get_mem_context(void *ptr); + +CUcontext get_current_context(); + cuda_memory_provider_params_t create_cuda_prov_params(umf_usm_memory_type_t memory_type); diff --git a/test/providers/provider_cuda.cpp b/test/providers/provider_cuda.cpp index f563d45c8a..4bdbbba73b 100644 --- a/test/providers/provider_cuda.cpp +++ b/test/providers/provider_cuda.cpp @@ -21,10 +21,8 @@ using namespace umf_test; class CUDAMemoryAccessor : public MemoryAccessor { public: - void init(CUcontext hContext, CUdevice hDevice) { - hDevice_ = hDevice; - hContext_ = hContext; - } + CUDAMemoryAccessor(CUcontext hContext, CUdevice hDevice) + : hDevice_(hDevice), hContext_(hContext) {} void fill(void *ptr, size_t size, const void *pattern, size_t pattern_size) { @@ -53,7 +51,7 @@ class CUDAMemoryAccessor : public MemoryAccessor { }; using CUDAProviderTestParams = - std::tuple; + std::tuple; struct umfCUDAProviderTest : umf_test::test, @@ -62,23 +60,12 @@ struct umfCUDAProviderTest void SetUp() override { test::SetUp(); - auto [memory_type, accessor] = this->GetParam(); - params = create_cuda_prov_params(memory_type); + auto [cuda_params, accessor] = this->GetParam(); + params = cuda_params; memAccessor = accessor; - if (memory_type == UMF_MEMORY_TYPE_DEVICE) { - ((CUDAMemoryAccessor *)memAccessor) - ->init((CUcontext)params.cuda_context_handle, - params.cuda_device_handle); - } } - void TearDown() override { - if (params.cuda_context_handle) { - int ret = destroy_context((CUcontext)params.cuda_context_handle); - ASSERT_EQ(ret, 0); - } - test::TearDown(); - } + void TearDown() override { test::TearDown(); } cuda_memory_provider_params_t params; MemoryAccessor *memAccessor = nullptr; @@ -87,6 +74,7 @@ struct umfCUDAProviderTest TEST_P(umfCUDAProviderTest, basic) { const size_t size = 1024 * 8; const uint32_t pattern = 0xAB; + CUcontext expected_current_context = get_current_context(); // create CUDA provider umf_memory_provider_handle_t provider = nullptr; @@ -113,6 +101,12 @@ TEST_P(umfCUDAProviderTest, basic) { // use the allocated memory - fill it with a 0xAB pattern memAccessor->fill(ptr, size, &pattern, sizeof(pattern)); + CUcontext actual_mem_context = get_mem_context(ptr); + ASSERT_EQ(actual_mem_context, (CUcontext)params.cuda_context_handle); + + CUcontext actual_current_context = get_current_context(); + ASSERT_EQ(actual_current_context, expected_current_context); + umf_usm_memory_type_t memoryTypeActual = get_mem_type((CUcontext)params.cuda_context_handle, ptr); ASSERT_EQ(memoryTypeActual, params.memory_type); @@ -132,6 +126,7 @@ TEST_P(umfCUDAProviderTest, basic) { } TEST_P(umfCUDAProviderTest, allocInvalidSize) { + CUcontext expected_current_context = get_current_context(); // create CUDA provider umf_memory_provider_handle_t provider = nullptr; umf_result_t umf_result = @@ -151,32 +146,32 @@ TEST_P(umfCUDAProviderTest, allocInvalidSize) { ASSERT_EQ(umf_result, UMF_RESULT_ERROR_INVALID_ARGUMENT); } - // destroy context and try to alloc some memory - destroy_context((CUcontext)params.cuda_context_handle); - params.cuda_context_handle = 0; - umf_result = umfMemoryProviderAlloc(provider, 128, 0, &ptr); - ASSERT_EQ(umf_result, UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC); - - const char *message; - int32_t error; - umfMemoryProviderGetLastNativeError(provider, &message, &error); - ASSERT_EQ(error, CUDA_ERROR_INVALID_CONTEXT); - const char *expected_message = - "CUDA_ERROR_INVALID_CONTEXT - invalid device context"; - ASSERT_EQ(strncmp(message, expected_message, strlen(expected_message)), 0); + CUcontext actual_current_context = get_current_context(); + ASSERT_EQ(actual_current_context, expected_current_context); + + umfMemoryProviderDestroy(provider); } // TODO add tests that mixes CUDA Memory Provider and Disjoint Pool -CUDAMemoryAccessor cuAccessor; +cuda_memory_provider_params_t cuParams_device_memory = + create_cuda_prov_params(UMF_MEMORY_TYPE_DEVICE); +cuda_memory_provider_params_t cuParams_shared_memory = + create_cuda_prov_params(UMF_MEMORY_TYPE_SHARED); +cuda_memory_provider_params_t cuParams_host_memory = + create_cuda_prov_params(UMF_MEMORY_TYPE_HOST); + +CUDAMemoryAccessor + cuAccessor((CUcontext)cuParams_device_memory.cuda_context_handle, + (CUdevice)cuParams_device_memory.cuda_device_handle); HostMemoryAccessor hostAccessor; INSTANTIATE_TEST_SUITE_P( umfCUDAProviderTestSuite, umfCUDAProviderTest, ::testing::Values( - CUDAProviderTestParams{UMF_MEMORY_TYPE_DEVICE, &cuAccessor}, - CUDAProviderTestParams{UMF_MEMORY_TYPE_SHARED, &hostAccessor}, - CUDAProviderTestParams{UMF_MEMORY_TYPE_HOST, &hostAccessor})); + CUDAProviderTestParams{cuParams_device_memory, &cuAccessor}, + CUDAProviderTestParams{cuParams_shared_memory, &hostAccessor}, + CUDAProviderTestParams{cuParams_host_memory, &hostAccessor})); // TODO: add IPC API GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(umfIpcTest); @@ -185,5 +180,5 @@ INSTANTIATE_TEST_SUITE_P(umfCUDAProviderTestSuite, umfIpcTest, ::testing::Values(ipcTestParams{ umfProxyPoolOps(), nullptr, umfCUDAMemoryProviderOps(), - &cuParams_device_memory, &l0Accessor})); + &cuParams_device_memory, &cuAccessor})); */