From d04849c986b18e6b59a8f4f8d9b39ba5d6a3d1b6 Mon Sep 17 00:00:00 2001 From: wang yize Date: Fri, 6 Jun 2025 13:34:55 +0800 Subject: [PATCH 1/8] Add cudagraph --- cpp/neuralnet/trtbackend.cpp | 216 +++++++++++++++++++---------------- 1 file changed, 119 insertions(+), 97 deletions(-) diff --git a/cpp/neuralnet/trtbackend.cpp b/cpp/neuralnet/trtbackend.cpp index c6df9f251..805885241 100644 --- a/cpp/neuralnet/trtbackend.cpp +++ b/cpp/neuralnet/trtbackend.cpp @@ -1089,7 +1089,10 @@ struct ComputeHandle { unique_ptr runtime; unique_ptr engine; unique_ptr exec; - +#ifdef TENSORRT_CUDA_GRAPH + vector cudaGraphs; + vector cudaGraphExecs; +#endif ComputeHandle( Logger* logger, const cudaDeviceProp* prop, @@ -1546,15 +1549,15 @@ struct InputBuffers { size_t scoreValueResultBufferBytes; size_t ownershipResultBufferBytes; - unique_ptr maskInputs; // Host pointer - unique_ptr spatialInputs; // Host pointer - unique_ptr globalInputs; // Host pointer - unique_ptr metaInputs; // Host pointer - unique_ptr policyPassResults; // Host pointer - unique_ptr policyResults; // Host pointer - unique_ptr valueResults; // Host pointer - unique_ptr scoreValueResults; // Host pointer - unique_ptr ownershipResults; // Host pointer + float* maskInputs; // Host pointer + float* spatialInputs; // Host pointer + float* globalInputs; // Host pointer + float* metaInputs; // Host pointer + float* policyPassResults; // Host pointer + float* policyResults; // Host pointer + float* valueResults; // Host pointer + float* scoreValueResults; // Host pointer + float* ownershipResults; // Host pointer InputBuffers(const LoadedModel* loadedModel, int maxBatchSz, int nnXLen, int nnYLen) { const ModelDesc& m = loadedModel->modelDesc; @@ -1602,15 +1605,15 @@ struct InputBuffers { scoreValueResultBufferBytes = maxBatchSize * singleScoreValueResultBytes; ownershipResultBufferBytes = maxBatchSize * singleOwnershipResultBytes; - maskInputs = make_unique(maxBatchSize * singleMaskElts); - spatialInputs = make_unique(maxBatchSize * singleInputElts); - globalInputs = make_unique(maxBatchSize * singleInputGlobalElts); - metaInputs = make_unique(maxBatchSize * singleInputMetaElts); - policyPassResults = make_unique(maxBatchSize * singlePolicyPassResultElts); - policyResults = make_unique(maxBatchSize * singlePolicyResultElts); - valueResults = make_unique(maxBatchSize * singleValueResultElts); - scoreValueResults = make_unique(maxBatchSize * singleScoreValueResultElts); - ownershipResults = make_unique(maxBatchSize * singleOwnershipResultElts); + cudaMallocHost((void**)&maskInputs, maxBatchSize * singleMaskElts * sizeof(float)); + cudaMallocHost((void**)&spatialInputs, maxBatchSize * singleInputElts * sizeof(float)); + cudaMallocHost((void**)&globalInputs, maxBatchSize * singleInputGlobalElts * sizeof(float)); + cudaMallocHost((void**)&metaInputs, maxBatchSize * singleInputMetaElts * sizeof(float)); + cudaMallocHost((void**)&policyPassResults, maxBatchSize * singlePolicyPassResultElts * sizeof(float)); + cudaMallocHost((void**)&policyResults, maxBatchSize * singlePolicyResultElts * sizeof(float)); + cudaMallocHost((void**)&valueResults, maxBatchSize * singleValueResultElts * sizeof(float)); + cudaMallocHost((void**)&scoreValueResults, maxBatchSize * singleScoreValueResultElts * sizeof(float)); + cudaMallocHost((void**)&ownershipResults, maxBatchSize * singleOwnershipResultElts * sizeof(float)); } InputBuffers() = delete; @@ -1656,7 +1659,6 @@ void NeuralNet::getOutput( const float* rowSpatial = inputBufs[nIdx]->rowSpatialBuf.data(); const float* 
rowMeta = inputBufs[nIdx]->rowMetaBuf.data(); const bool hasRowMeta = inputBufs[nIdx]->hasRowMeta; - copy(rowGlobal, rowGlobal + numGlobalFeatures, rowGlobalInput); std::copy(rowGlobal,rowGlobal+numGlobalFeatures,rowGlobalInput); if(numMetaFeatures > 0) { testAssert(rowMeta != NULL); @@ -1696,89 +1698,109 @@ void NeuralNet::getOutput( const int numPolicyChannels = inputBuffers->singlePolicyPassResultElts; assert(inputBuffers->singlePolicyResultElts == numPolicyChannels * nnXLen * nnYLen); - // Transfers from host memory to device memory are asynchronous with respect to the host - CUDA_ERR( - "getOutput", - cudaMemcpyAsync( - gpuHandle->getBuffer("InputMask"), - inputBuffers->maskInputs.get(), - inputBuffers->singleMaskBytes * batchSize, - cudaMemcpyHostToDevice)); - CUDA_ERR( - "getOutput", - cudaMemcpyAsync( - gpuHandle->getBuffer("InputSpatial"), - inputBuffers->spatialInputs.get(), - inputBuffers->singleInputBytes * batchSize, - cudaMemcpyHostToDevice)); - CUDA_ERR( - "getOutput", - cudaMemcpyAsync( - gpuHandle->getBuffer("InputGlobal"), - inputBuffers->globalInputs.get(), - inputBuffers->singleInputGlobalBytes * batchSize, - cudaMemcpyHostToDevice)); - if(numMetaFeatures > 0) { - CUDA_ERR( - "getOutput", - cudaMemcpyAsync( - gpuHandle->getBuffer("InputMeta"), - inputBuffers->metaInputs.get(), - inputBuffers->singleInputMetaBytes * batchSize, - cudaMemcpyHostToDevice)); +#ifdef TENSORRT_CUDA_GRAPH + if(gpuHandle->cudaGraphExecs.empty()) { + gpuHandle->cudaGraphs.resize(inputBuffers->maxBatchSize + 1); + gpuHandle->cudaGraphExecs.resize(inputBuffers->maxBatchSize + 1); } + auto& graph = gpuHandle->cudaGraphs[batchSize]; + auto& instance = gpuHandle->cudaGraphExecs[batchSize]; + if(instance == nullptr) { // First evaluation with current batchsize. Initialize cuda graph +#endif - auto maskInputDims = gpuHandle->getBufferDynamicShape("InputMask", batchSize); - auto spatialInputDims = gpuHandle->getBufferDynamicShape("InputSpatial", batchSize); - auto globalInputDims = gpuHandle->getBufferDynamicShape("InputGlobal", batchSize); + auto maskInputDims = gpuHandle->getBufferDynamicShape("InputMask", batchSize); + auto spatialInputDims = gpuHandle->getBufferDynamicShape("InputSpatial", batchSize); + auto globalInputDims = gpuHandle->getBufferDynamicShape("InputGlobal", batchSize); - gpuHandle->exec->setInputShape("InputMask", maskInputDims); - gpuHandle->exec->setInputShape("InputSpatial", spatialInputDims); - gpuHandle->exec->setInputShape("InputGlobal", globalInputDims); + gpuHandle->exec->setInputShape("InputMask", maskInputDims); + gpuHandle->exec->setInputShape("InputSpatial", spatialInputDims); + gpuHandle->exec->setInputShape("InputGlobal", globalInputDims); - if(numMetaFeatures > 0) { - auto metaInputDims = gpuHandle->getBufferDynamicShape("InputMeta", batchSize); - gpuHandle->exec->setInputShape("InputMeta", metaInputDims); - } + if(numMetaFeatures > 0) { + auto metaInputDims = gpuHandle->getBufferDynamicShape("InputMeta", batchSize); + gpuHandle->exec->setInputShape("InputMeta", metaInputDims); + } +#ifdef TENSORRT_CUDA_GRAPH + gpuHandle->exec->enqueueV3(cudaStreamPerThread); // Warm up + cudaStreamBeginCapture( + cudaStreamPerThread, cudaStreamCaptureModeThreadLocal); // In case other server thread is also capturing. 
+#endif + // Transfers from host memory to device memory are asynchronous with respect to the host + CUDA_ERR( + "getOutput", + cudaMemcpyAsync( + gpuHandle->getBuffer("InputMask"), + inputBuffers->maskInputs, + inputBuffers->singleMaskBytes * batchSize, + cudaMemcpyHostToDevice)); + CUDA_ERR( + "getOutput", + cudaMemcpyAsync( + gpuHandle->getBuffer("InputSpatial"), + inputBuffers->spatialInputs, + inputBuffers->singleInputBytes * batchSize, + cudaMemcpyHostToDevice)); + CUDA_ERR( + "getOutput", + cudaMemcpyAsync( + gpuHandle->getBuffer("InputGlobal"), + inputBuffers->globalInputs, + inputBuffers->singleInputGlobalBytes * batchSize, + cudaMemcpyHostToDevice)); + if(numMetaFeatures > 0) { + CUDA_ERR( + "getOutput", + cudaMemcpyAsync( + gpuHandle->getBuffer("InputMeta"), + inputBuffers->metaInputs, + inputBuffers->singleInputMetaBytes * batchSize, + cudaMemcpyHostToDevice)); + } - gpuHandle->exec->enqueueV3(cudaStreamPerThread); - - CUDA_ERR( - "getOutput", - cudaMemcpy( - inputBuffers->policyPassResults.get(), - gpuHandle->getBuffer("OutputPolicyPass"), - inputBuffers->singlePolicyPassResultBytes * batchSize, - cudaMemcpyDeviceToHost)); - CUDA_ERR( - "getOutput", - cudaMemcpy( - inputBuffers->policyResults.get(), - gpuHandle->getBuffer("OutputPolicy"), - inputBuffers->singlePolicyResultBytes * batchSize, - cudaMemcpyDeviceToHost)); - CUDA_ERR( - "getOutput", - cudaMemcpy( - inputBuffers->valueResults.get(), - gpuHandle->getBuffer("OutputValue"), - inputBuffers->singleValueResultBytes * batchSize, - cudaMemcpyDeviceToHost)); - CUDA_ERR( - "getOutput", - cudaMemcpy( - inputBuffers->scoreValueResults.get(), - gpuHandle->getBuffer("OutputScoreValue"), - inputBuffers->singleScoreValueResultBytes * batchSize, - cudaMemcpyDeviceToHost)); - CUDA_ERR( - "getOutput", - cudaMemcpy( - inputBuffers->ownershipResults.get(), - gpuHandle->getBuffer("OutputOwnership"), - inputBuffers->singleOwnershipResultBytes * batchSize, - cudaMemcpyDeviceToHost)); + gpuHandle->exec->enqueueV3(cudaStreamPerThread); + CUDA_ERR( + "getOutput", + cudaMemcpy( + inputBuffers->policyPassResults.get(), + gpuHandle->getBuffer("OutputPolicyPass"), + inputBuffers->singlePolicyPassResultBytes * batchSize, + cudaMemcpyDeviceToHost)); + CUDA_ERR( + "getOutput", + cudaMemcpy( + inputBuffers->policyResults.get(), + gpuHandle->getBuffer("OutputPolicy"), + inputBuffers->singlePolicyResultBytes * batchSize, + cudaMemcpyDeviceToHost)); + CUDA_ERR( + "getOutput", + cudaMemcpy( + inputBuffers->valueResults.get(), + gpuHandle->getBuffer("OutputValue"), + inputBuffers->singleValueResultBytes * batchSize, + cudaMemcpyDeviceToHost)); + CUDA_ERR( + "getOutput", + cudaMemcpy( + inputBuffers->scoreValueResults.get(), + gpuHandle->getBuffer("OutputScoreValue"), + inputBuffers->singleScoreValueResultBytes * batchSize, + cudaMemcpyDeviceToHost)); + CUDA_ERR( + "getOutput", + cudaMemcpy( + inputBuffers->ownershipResults.get(), + gpuHandle->getBuffer("OutputOwnership"), + inputBuffers->singleOwnershipResultBytes * batchSize, + cudaMemcpyDeviceToHost)); +#ifdef TENSORRT_CUDA_GRAPH + cudaStreamEndCapture(cudaStreamPerThread, &graph); + cudaGraphInstantiate(&instance, graph, 0); + } // if (instance == nullptr) + cudaGraphLaunch(instance, cudaStreamPerThread); +#endif + cudaStreamSynchronize(cudaStreamPerThread); gpuHandle->printDebugOutput(batchSize); gpuHandle->trtErrorRecorder.clear(); From 5336a81b351e171af92c8f3c36ab457b88baae50 Mon Sep 17 00:00:00 2001 From: wang yize Date: Fri, 6 Jun 2025 13:44:26 +0800 Subject: [PATCH 2/8] Update cmakelist for 
USE_TENSORRT_CUDA_GRAPH --- cpp/CMakeLists.txt | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 515abd2ae..bc6fb420a 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -40,7 +40,9 @@ set(USE_AVX2 0 CACHE BOOL "Compile with AVX2") set(USE_BIGGER_BOARDS_EXPENSIVE 0 CACHE BOOL "Allow boards up to size 50. Compiling with this will use more memory and slow down KataGo, even when playing on boards of size 19.") set(USE_CACHE_TENSORRT_PLAN 0 CACHE BOOL "Use TENSORRT plan cache. May use a lot of disk space. Only applies when USE_BACKEND is TENSORRT.") +set(USE_TENSORRT_CUDA_GRAPH 0 CACHE BOOL "Use cudaGraph when using TENSORRT backend.") mark_as_advanced(USE_CACHE_TENSORRT_PLAN) +mark_as_advanced(USE_TENSORRT_CUDA_GRAPH) #--------------------------- NEURAL NET BACKEND ------------------------------------------------------------------------ @@ -89,10 +91,14 @@ elseif(USE_BACKEND STREQUAL "TENSORRT") ) if(USE_CACHE_TENSORRT_PLAN AND (NOT BUILD_DISTRIBUTED)) message(STATUS "-DUSE_CACHE_TENSORRT_PLAN is set, using TENSORRT plan cache.") - add_compile_definitions(CACHE_TENSORRT_PLAN) + add_compile_definitions(CACHE_TENSORRT_PLAN) elseif(USE_CACHE_TENSORRT_PLAN AND BUILD_DISTRIBUTED) message(FATAL_ERROR "Combining USE_CACHE_TENSORRT_PLAN with BUILD_DISTRIBUTED is not supported - it would consume excessive disk space and might worsen performance every time models are updated. Use only one at a time in a given build of KataGo.") endif() + if(USE_TENSORRT_CUDA_GRAPH) + message(STATUS "-DUSE_TENSORRT_CUDA_GRAPH is set, using cuda graph at inference") + add_compile_definitions(TENSORRT_CUDA_GRAPH) + endif() elseif(USE_BACKEND STREQUAL "METAL") message(STATUS "-DUSE_BACKEND=METAL, using Metal backend.") if(NOT "${CMAKE_GENERATOR}" STREQUAL "Ninja") From 8242d03fcba1ac911a11e91e1e6b9e818d8288c6 Mon Sep 17 00:00:00 2001 From: wang yize Date: Fri, 6 Jun 2025 13:49:03 +0800 Subject: [PATCH 3/8] fix compile error --- cpp/neuralnet/trtbackend.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cpp/neuralnet/trtbackend.cpp b/cpp/neuralnet/trtbackend.cpp index 805885241..d736e542c 100644 --- a/cpp/neuralnet/trtbackend.cpp +++ b/cpp/neuralnet/trtbackend.cpp @@ -1762,35 +1762,35 @@ void NeuralNet::getOutput( CUDA_ERR( "getOutput", cudaMemcpy( - inputBuffers->policyPassResults.get(), + inputBuffers->policyPassResults, gpuHandle->getBuffer("OutputPolicyPass"), inputBuffers->singlePolicyPassResultBytes * batchSize, cudaMemcpyDeviceToHost)); CUDA_ERR( "getOutput", cudaMemcpy( - inputBuffers->policyResults.get(), + inputBuffers->policyResults, gpuHandle->getBuffer("OutputPolicy"), inputBuffers->singlePolicyResultBytes * batchSize, cudaMemcpyDeviceToHost)); CUDA_ERR( "getOutput", cudaMemcpy( - inputBuffers->valueResults.get(), + inputBuffers->valueResults, gpuHandle->getBuffer("OutputValue"), inputBuffers->singleValueResultBytes * batchSize, cudaMemcpyDeviceToHost)); CUDA_ERR( "getOutput", cudaMemcpy( - inputBuffers->scoreValueResults.get(), + inputBuffers->scoreValueResults, gpuHandle->getBuffer("OutputScoreValue"), inputBuffers->singleScoreValueResultBytes * batchSize, cudaMemcpyDeviceToHost)); CUDA_ERR( "getOutput", cudaMemcpy( - inputBuffers->ownershipResults.get(), + inputBuffers->ownershipResults, gpuHandle->getBuffer("OutputOwnership"), inputBuffers->singleOwnershipResultBytes * batchSize, cudaMemcpyDeviceToHost)); From 476d17f9912d05fc2bec37d41326bceac896daf4 Mon Sep 17 00:00:00 2001 From: wang yize Date: Fri, 
6 Jun 2025 14:01:54 +0800 Subject: [PATCH 4/8] Must use async io when capturing --- cpp/neuralnet/trtbackend.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cpp/neuralnet/trtbackend.cpp b/cpp/neuralnet/trtbackend.cpp index d736e542c..90cddc94a 100644 --- a/cpp/neuralnet/trtbackend.cpp +++ b/cpp/neuralnet/trtbackend.cpp @@ -1761,35 +1761,35 @@ void NeuralNet::getOutput( CUDA_ERR( "getOutput", - cudaMemcpy( + cudaMemcpyAsync( inputBuffers->policyPassResults, gpuHandle->getBuffer("OutputPolicyPass"), inputBuffers->singlePolicyPassResultBytes * batchSize, cudaMemcpyDeviceToHost)); CUDA_ERR( "getOutput", - cudaMemcpy( + cudaMemcpyAsync( inputBuffers->policyResults, gpuHandle->getBuffer("OutputPolicy"), inputBuffers->singlePolicyResultBytes * batchSize, cudaMemcpyDeviceToHost)); CUDA_ERR( "getOutput", - cudaMemcpy( + cudaMemcpyAsync( inputBuffers->valueResults, gpuHandle->getBuffer("OutputValue"), inputBuffers->singleValueResultBytes * batchSize, cudaMemcpyDeviceToHost)); CUDA_ERR( "getOutput", - cudaMemcpy( + cudaMemcpyAsync( inputBuffers->scoreValueResults, gpuHandle->getBuffer("OutputScoreValue"), inputBuffers->singleScoreValueResultBytes * batchSize, cudaMemcpyDeviceToHost)); CUDA_ERR( "getOutput", - cudaMemcpy( + cudaMemcpyAsync( inputBuffers->ownershipResults, gpuHandle->getBuffer("OutputOwnership"), inputBuffers->singleOwnershipResultBytes * batchSize, From a2d6411e16b57744ee98ad62512bdf26318f46a5 Mon Sep 17 00:00:00 2001 From: wang yize Date: Sat, 7 Jun 2025 19:52:29 +0800 Subject: [PATCH 5/8] Fix cudaMemcpyAsync not using cudaStreamPerThread --- cpp/neuralnet/trtbackend.cpp | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/cpp/neuralnet/trtbackend.cpp b/cpp/neuralnet/trtbackend.cpp index 90cddc94a..2013aebe5 100644 --- a/cpp/neuralnet/trtbackend.cpp +++ b/cpp/neuralnet/trtbackend.cpp @@ -1732,21 +1732,24 @@ void NeuralNet::getOutput( gpuHandle->getBuffer("InputMask"), inputBuffers->maskInputs, inputBuffers->singleMaskBytes * batchSize, - cudaMemcpyHostToDevice)); + cudaMemcpyHostToDevice, + cudaStreamPerThread)); CUDA_ERR( "getOutput", cudaMemcpyAsync( gpuHandle->getBuffer("InputSpatial"), inputBuffers->spatialInputs, inputBuffers->singleInputBytes * batchSize, - cudaMemcpyHostToDevice)); + cudaMemcpyHostToDevice, + cudaStreamPerThread)); CUDA_ERR( "getOutput", cudaMemcpyAsync( gpuHandle->getBuffer("InputGlobal"), inputBuffers->globalInputs, inputBuffers->singleInputGlobalBytes * batchSize, - cudaMemcpyHostToDevice)); + cudaMemcpyHostToDevice, + cudaStreamPerThread)); if(numMetaFeatures > 0) { CUDA_ERR( "getOutput", @@ -1754,7 +1757,8 @@ void NeuralNet::getOutput( gpuHandle->getBuffer("InputMeta"), inputBuffers->metaInputs, inputBuffers->singleInputMetaBytes * batchSize, - cudaMemcpyHostToDevice)); + cudaMemcpyHostToDevice, + cudaStreamPerThread)); } gpuHandle->exec->enqueueV3(cudaStreamPerThread); @@ -1765,35 +1769,40 @@ void NeuralNet::getOutput( inputBuffers->policyPassResults, gpuHandle->getBuffer("OutputPolicyPass"), inputBuffers->singlePolicyPassResultBytes * batchSize, - cudaMemcpyDeviceToHost)); + cudaMemcpyDeviceToHost, + cudaStreamPerThread)); CUDA_ERR( "getOutput", cudaMemcpyAsync( inputBuffers->policyResults, gpuHandle->getBuffer("OutputPolicy"), inputBuffers->singlePolicyResultBytes * batchSize, - cudaMemcpyDeviceToHost)); + cudaMemcpyDeviceToHost, + cudaStreamPerThread)); CUDA_ERR( "getOutput", cudaMemcpyAsync( inputBuffers->valueResults, gpuHandle->getBuffer("OutputValue"), 
inputBuffers->singleValueResultBytes * batchSize, - cudaMemcpyDeviceToHost)); + cudaMemcpyDeviceToHost, + cudaStreamPerThread)); CUDA_ERR( "getOutput", cudaMemcpyAsync( inputBuffers->scoreValueResults, gpuHandle->getBuffer("OutputScoreValue"), inputBuffers->singleScoreValueResultBytes * batchSize, - cudaMemcpyDeviceToHost)); + cudaMemcpyDeviceToHos, + cudaStreamPerThread)); CUDA_ERR( "getOutput", cudaMemcpyAsync( inputBuffers->ownershipResults, gpuHandle->getBuffer("OutputOwnership"), inputBuffers->singleOwnershipResultBytes * batchSize, - cudaMemcpyDeviceToHost)); + cudaMemcpyDeviceToHost, + cudaStreamPerThread)); #ifdef TENSORRT_CUDA_GRAPH cudaStreamEndCapture(cudaStreamPerThread, &graph); cudaGraphInstantiate(&instance, graph, 0); From 1758eb2a36e5371c69db807aecf99894521f7892 Mon Sep 17 00:00:00 2001 From: wang yize Date: Sat, 7 Jun 2025 19:52:29 +0800 Subject: [PATCH 6/8] Fix cudaMemcpyAsync not using cudaStreamPerThread --- cpp/neuralnet/trtbackend.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/neuralnet/trtbackend.cpp b/cpp/neuralnet/trtbackend.cpp index 2013aebe5..b05c5cccc 100644 --- a/cpp/neuralnet/trtbackend.cpp +++ b/cpp/neuralnet/trtbackend.cpp @@ -1793,7 +1793,7 @@ void NeuralNet::getOutput( inputBuffers->scoreValueResults, gpuHandle->getBuffer("OutputScoreValue"), inputBuffers->singleScoreValueResultBytes * batchSize, - cudaMemcpyDeviceToHos, + cudaMemcpyDeviceToHost, cudaStreamPerThread)); CUDA_ERR( "getOutput", From 84f5510787f290350e6dd248849cba97b721384d Mon Sep 17 00:00:00 2001 From: wang yize Date: Mon, 9 Jun 2025 13:17:09 +0800 Subject: [PATCH 7/8] Fix multi-thread-single-gpu initialization seg fault when not using cuda plan cache. --- cpp/neuralnet/trtbackend.cpp | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/cpp/neuralnet/trtbackend.cpp b/cpp/neuralnet/trtbackend.cpp index b05c5cccc..06d7c92f0 100644 --- a/cpp/neuralnet/trtbackend.cpp +++ b/cpp/neuralnet/trtbackend.cpp @@ -1092,6 +1092,7 @@ struct ComputeHandle { #ifdef TENSORRT_CUDA_GRAPH vector cudaGraphs; vector cudaGraphExecs; + inline static unordered_map serializeLockPerGpu; #endif ComputeHandle( Logger* logger, @@ -1101,7 +1102,7 @@ struct ComputeHandle { int maxBatchSz, bool requireExactNNLen) { ctx = context; - + maxBatchSize = maxBatchSz; modelVersion = loadedModel->modelDesc.modelVersion; @@ -1111,7 +1112,7 @@ struct ComputeHandle { if(getInferLibVersion() / 100 != NV_TENSORRT_VERSION / 100) { throw StringError("TensorRT backend: detected incompatible version of TensorRT library"); } - + trtLogger.setLogger(logger); auto builder = unique_ptr(createInferBuilder(trtLogger)); @@ -1321,7 +1322,15 @@ struct ComputeHandle { tuneMutex.unlock(); } else { tuneMutex.unlock(); - planBuffer.reset(builder->buildSerializedNetwork(*model->network, *config)); + { + int gpuId; + cudaGetDevice(&gpuId); + auto& serializeMutex = serializeLockPerGpu[gpuId]; + serializeMutex.lock(); + planBuffer.reset(builder->buildSerializedNetwork(*model->network, *config)); + serializeMutex.unlock(); + } + if(!planBuffer) { throw StringError("TensorRT backend: failed to create plan"); } From 51c9b02299dd8a78122a1f7ff9f89fe5e2c31c42 Mon Sep 17 00:00:00 2001 From: wang yize Date: Mon, 9 Jun 2025 13:33:58 +0800 Subject: [PATCH 8/8] Add lock to cuda graph instantiate too. 
--- cpp/neuralnet/trtbackend.cpp | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/cpp/neuralnet/trtbackend.cpp b/cpp/neuralnet/trtbackend.cpp index 06d7c92f0..105c10584 100644 --- a/cpp/neuralnet/trtbackend.cpp +++ b/cpp/neuralnet/trtbackend.cpp @@ -1092,7 +1092,7 @@ struct ComputeHandle { #ifdef TENSORRT_CUDA_GRAPH vector cudaGraphs; vector cudaGraphExecs; - inline static unordered_map serializeLockPerGpu; + inline static unordered_map mutexPerGpu; #endif ComputeHandle( Logger* logger, @@ -1325,10 +1325,10 @@ struct ComputeHandle { { int gpuId; cudaGetDevice(&gpuId); - auto& serializeMutex = serializeLockPerGpu[gpuId]; - serializeMutex.lock(); + auto& mutex = mutexPerGpu[gpuId]; + mutex.lock(); planBuffer.reset(builder->buildSerializedNetwork(*model->network, *config)); - serializeMutex.unlock(); + mutex.unlock(); } if(!planBuffer) { @@ -1715,6 +1715,10 @@ void NeuralNet::getOutput( auto& graph = gpuHandle->cudaGraphs[batchSize]; auto& instance = gpuHandle->cudaGraphExecs[batchSize]; if(instance == nullptr) { // First evaluation with current batchsize. Initialize cuda graph + int gpuId; + cudaGetDevice(&gpuId); + auto& mutex = gpuHandle->mutexPerGpu[gpuId]; + mutex.lock(); #endif auto maskInputDims = gpuHandle->getBufferDynamicShape("InputMask", batchSize); @@ -1815,6 +1819,7 @@ void NeuralNet::getOutput( #ifdef TENSORRT_CUDA_GRAPH cudaStreamEndCapture(cudaStreamPerThread, &graph); cudaGraphInstantiate(&instance, graph, 0); + mutex.unlock(); } // if (instance == nullptr) cudaGraphLaunch(instance, cudaStreamPerThread); #endif
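With PATCH 2/8 applied the whole feature is opt-in: the new USE_TENSORRT_CUDA_GRAPH CMake cache option (off by default, only meaningful with USE_BACKEND=TENSORRT) defines TENSORRT_CUDA_GRAPH, which gates every #ifdef block in the patches above. A typical configure step would look roughly like this; the in-source cmake invocation follows KataGo's usual build instructions, and everything other than the two -D options is an assumption rather than something stated in the patch:

    cd cpp
    cmake . -DUSE_BACKEND=TENSORRT -DUSE_TENSORRT_CUDA_GRAPH=1
    make -j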
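Taken together, patches 1, 5, and 8 converge on a fairly standard capture-and-replay pattern: on the first query at a given batch size, the host-to-device copies, the enqueueV3 call, and the device-to-host copies are recorded (after one warm-up enqueueV3) on cudaStreamPerThread into a cudaGraph_t, instantiated once under the per-GPU mutex, and every later query at that batch size just refreshes the pinned host buffers and relaunches the instantiated graph. Below is a minimal standalone sketch of that flow, with a toy kernel standing in for the TensorRT engine; the names (scaleKernel, MAX_BATCH, ROW_ELTS, hostIn/hostOut/devBuf) are illustrative, not KataGo identifiers, and the three-argument cudaGraphInstantiate mirrors the call in the patch, which assumes a CUDA 12 toolkit.

#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cuda_runtime.h>

#define CHECK(expr) do { cudaError_t err_ = (expr); if(err_ != cudaSuccess) { \
  fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err_)); exit(1); } } while(0)

// Stand-in for the TensorRT engine: in the real backend this kernel launch is
// replaced by gpuHandle->exec->enqueueV3(cudaStreamPerThread).
__global__ void scaleKernel(float* data, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if(i < n)
    data[i] *= 2.0f;
}

constexpr int MAX_BATCH = 8;
constexpr int ROW_ELTS = 1024;

int main() {
  // Pinned host buffers: the captured graph bakes in these addresses, and the
  // copies are only truly asynchronous from page-locked memory, which is why
  // the patch switches the host buffers from make_unique to cudaMallocHost.
  float* hostIn = nullptr;
  float* hostOut = nullptr;
  float* devBuf = nullptr;
  CHECK(cudaMallocHost((void**)&hostIn, MAX_BATCH * ROW_ELTS * sizeof(float)));
  CHECK(cudaMallocHost((void**)&hostOut, MAX_BATCH * ROW_ELTS * sizeof(float)));
  CHECK(cudaMalloc((void**)&devBuf, MAX_BATCH * ROW_ELTS * sizeof(float)));

  // One executable graph per batch size, mirroring cudaGraphs/cudaGraphExecs.
  std::vector<cudaGraphExec_t> graphExecs(MAX_BATCH + 1, nullptr);

  for(int iter = 0; iter < 3; iter++) {
    const int batchSize = 4;  // fixed here; varies per query in the real backend
    const int n = batchSize * ROW_ELTS;
    const size_t bytes = (size_t)n * sizeof(float);
    for(int i = 0; i < n; i++)
      hostIn[i] = (float)(i % 100);

    if(graphExecs[batchSize] == nullptr) {
      // First evaluation with this batch size: record the copy/infer/copy
      // sequence. ThreadLocal mode keeps other threads' work out of the capture.
      CHECK(cudaStreamBeginCapture(cudaStreamPerThread, cudaStreamCaptureModeThreadLocal));
      CHECK(cudaMemcpyAsync(devBuf, hostIn, bytes, cudaMemcpyHostToDevice, cudaStreamPerThread));
      scaleKernel<<<(n + 255) / 256, 256, 0, cudaStreamPerThread>>>(devBuf, n);
      CHECK(cudaMemcpyAsync(hostOut, devBuf, bytes, cudaMemcpyDeviceToHost, cudaStreamPerThread));
      cudaGraph_t graph;
      CHECK(cudaStreamEndCapture(cudaStreamPerThread, &graph));
      CHECK(cudaGraphInstantiate(&graphExecs[batchSize], graph, 0));
      CHECK(cudaGraphDestroy(graph));
    }
    // Replay: a single launch re-runs the whole recorded sequence against the
    // same buffers, so only the buffer contents change between queries.
    CHECK(cudaGraphLaunch(graphExecs[batchSize], cudaStreamPerThread));
    CHECK(cudaStreamSynchronize(cudaStreamPerThread));
    printf("iter %d: hostOut[7] = %.1f\n", iter, hostOut[7]);
  }

  CHECK(cudaFree(devBuf));
  CHECK(cudaFreeHost(hostIn));
  CHECK(cudaFreeHost(hostOut));
  return 0;
}

The remaining pieces of the series follow from the same constraint: the graph records fixed buffer addresses, so the host buffers must stay alive at stable, page-locked addresses for the lifetime of the graph, and patches 7 and 8 put buildSerializedNetwork and the first capture/instantiation behind a per-GPU mutex (a static map keyed by the device id from cudaGetDevice) so that several server threads sharing one GPU do not race during initialization.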