Patch summary (text before the "diff --git" line is ignored by patch tools):

  * Adds a file-local helper template, mgpuGetMemRefDataAndShape<rank>, that
    reinterprets an opaque pointer as a ranked StridedMemRefType descriptor,
    returns its data pointer through `addr`, and copies its `sizes` into
    `globalDim` in reversed order (globalDim[i] = sizes[rank - i - 1]).
  * mgpuTensorMapEncodeTiledMemref now takes the descriptor as `void *` and
    dispatches on `tensorRank` (1 through 5) to the matching instantiation
    instead of reading the sizes through a fixed-type descriptor pointer.
  * Any other `tensorRank` value (including values below 1) takes the
    `default` branch: it prints the "rank is too high" message to stderr and
    returns NULL before anything is encoded.

NOTE(review): this patch text was reconstructed from a whitespace-collapsed
copy in which template argument lists had been stripped. The restored
arguments (<int rank>, <StridedMemRefType<char, rank>, <uint64_t>, <uint32_t>,
<char, 1>) and the leading indentation/comment alignment should be confirmed
against the target file; apply with whitespace-insensitive matching if needed.

diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index b8ac9ab90a9f3..5ec87d58cc57f 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -423,9 +423,24 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuTensorMapEncodeTiled(
       elementStrides[4], interleave, swizzle, l2Promotion, oobFill);
 }
 
+namespace {
+
+template <int rank>
+void mgpuGetMemRefDataAndShape(void *raw_descriptor, char **addr,
+                               uint64_t *globalDim) {
+  auto descriptor =
+      reinterpret_cast<StridedMemRefType<char, rank> *>(raw_descriptor);
+  *addr = descriptor->data;
+  for (int i = 0; i < rank; ++i) {
+    globalDim[i] = static_cast<uint64_t>(descriptor->sizes[rank - i - 1]);
+  }
+}
+
+} // namespace
+
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
     int64_t tensorRank,                       // Dimensionality of tensor
-    StridedMemRefType<char, 1> *descriptor,   // Starting address
+    void *ranked_descriptor,                  // Ranked MemRef descriptor
     const CUtensorMapDataType tensorDataType, // Stride size (in bytes)
     CUtensorMapInterleave interleave,         // Type of interleaved layout
     CUtensorMapSwizzle swizzle,               // Bank swizzling pattern
@@ -435,17 +450,39 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
 ) {
   CUtensorMap tensorMap;
 
-  auto *globalAddress = descriptor->data;
   uint32_t boxDim[5] = {1, 1, 1, 1, 1}, elementStrides[5] = {1, 1, 1, 1, 1};
   uint64_t globalDim[5] = {1, 1, 1, 1, 1}, globalStrides[5] = {0};
   uint32_t tensorRank32 = uint32_t(tensorRank);
 
+  char *globalAddress = nullptr;
+  switch (tensorRank) {
+  case 1:
+    mgpuGetMemRefDataAndShape<1>(ranked_descriptor, &globalAddress, globalDim);
+    break;
+  case 2:
+    mgpuGetMemRefDataAndShape<2>(ranked_descriptor, &globalAddress, globalDim);
+    break;
+  case 3:
+    mgpuGetMemRefDataAndShape<3>(ranked_descriptor, &globalAddress, globalDim);
+    break;
+  case 4:
+    mgpuGetMemRefDataAndShape<4>(ranked_descriptor, &globalAddress, globalDim);
+    break;
+  case 5:
+    mgpuGetMemRefDataAndShape<5>(ranked_descriptor, &globalAddress, globalDim);
+    break;
+  default:
+    fprintf(
+        stderr,
+        "'mgpuTensorMapEncodeTiledMemref' failed with 'rank is too high'\n");
+    return NULL;
+  }
+
   static const int elementSizeInBytes[] = {1, 2, 4, 4, 8, 8, 2,
                                            4, 8, 2, 4, 4, 4};
   for (int64_t r = 0; r < tensorRank; ++r) {
     elementStrides[r] = uint32_t(1);
     boxDim[r] = static_cast<uint32_t>(inputBoxDims[tensorRank - r - 1]);
-    globalDim[r] = static_cast<uint64_t>(descriptor->sizes[tensorRank - r - 1]);
   }
 
   globalStrides[0] = globalDim[0] * elementSizeInBytes[tensorDataType];