Skip to content

Commit 4971e53

Browse files
authored
[mlir][Target] Support Fatbin target for static nvptxcompiler (#118044)
### Background In `lib/Target/LLVM/NVVM/Target.cpp`, `NVPTXSerializer` compiles PTX to binary using one of two flows, selected by `MLIR_ENABLE_NVPTXCOMPILER`. When MLIR is built with `-DMLIR_ENABLE_NVPTXCOMPILER=ON`, the flow does not check whether the target is `gpu::CompilationTarget::Fatbin` and compiles PTX to cubin directly, which is inconsistent with the other flow. ### Implementation Use the static [nvfatbin](https://docs.nvidia.com/cuda/nvfatbin/index.html) library. I have tested it locally: both flows return the same Fatbin result when given the same `GpuModule`.
1 parent b0f8f32 commit 4971e53

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

mlir/lib/Target/LLVM/CMakeLists.txt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,20 @@ if ("NVPTX" IN_LIST LLVM_TARGETS_TO_BUILD)
8888
# Link against `nvptxcompiler_static`. TODO: use `CUDA::nvptxcompiler_static`.
8989
target_link_libraries(MLIRNVVMTarget PRIVATE MLIR_NVPTXCOMPILER_LIB)
9090
target_include_directories(obj.MLIRNVVMTarget PUBLIC ${CUDAToolkit_INCLUDE_DIRS})
91+
92+
# Add the `nvfatbin` library. The static `nvptxcompiler` flow needs it to
# package the compiled cubin together with the PTX into a fatbin.
find_library(MLIR_NVFATBIN_LIB_PATH nvfatbin_static
             PATHS ${CUDAToolkit_LIBRARY_DIR} NO_DEFAULT_PATH)
# Fail if `nvfatbin_static` couldn't be found.
if(MLIR_NVFATBIN_LIB_PATH STREQUAL "MLIR_NVFATBIN_LIB_PATH-NOTFOUND")
  # `message` concatenates its arguments, so splitting the text into two
  # strings keeps source indentation out of the emitted message.
  message(FATAL_ERROR
          "Requested using the static `nvptxcompiler` library which requires the "
          "`nvfatbin` library, but it couldn't be found.")
endif()

# Wrap the located archive in an imported target and link it into the NVVM
# target.
add_library(MLIR_NVFATBIN_LIB STATIC IMPORTED GLOBAL)
set_property(TARGET MLIR_NVFATBIN_LIB PROPERTY IMPORTED_LOCATION "${MLIR_NVFATBIN_LIB_PATH}")
target_link_libraries(MLIRNVVMTarget PRIVATE MLIR_NVFATBIN_LIB)
91105
endif()
92106
else()
93107
# Fail if `MLIR_ENABLE_NVPTXCOMPILER` is enabled and the toolkit couldn't be found.

mlir/lib/Target/LLVM/NVVM/Target.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,18 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
473473
} \
474474
} while (false)
475475

476+
#include "nvFatbin.h"
477+
478+
// Evaluates the nvFatbin call `expr` once. If the returned status is not
// NVFATBIN_SUCCESS, emits an error at `loc` containing the stringified
// expression and the nvFatbin error string, then returns std::nullopt from the
// enclosing function. A variable named `loc` must be in scope at the point of
// expansion.
#define RETURN_ON_NVFATBIN_ERROR(expr)                                         \
  do {                                                                         \
    auto fatbinStatus = (expr);                                                \
    if (fatbinStatus != nvFatbinResult::NVFATBIN_SUCCESS) {                    \
      emitError(loc) << llvm::Twine(#expr).concat(" failed with error: ")      \
                     << nvFatbinGetErrorString(fatbinStatus);                  \
      return std::nullopt;                                                     \
    }                                                                          \
  } while (false)
487+
476488
std::optional<SmallVector<char, 0>>
477489
NVPTXSerializer::compileToBinaryNVPTX(const std::string &ptxCode) {
478490
Location loc = getOperation().getLoc();
@@ -538,6 +550,32 @@ NVPTXSerializer::compileToBinaryNVPTX(const std::string &ptxCode) {
538550
});
539551
#undef DEBUG_TYPE
540552
RETURN_ON_NVPTXCOMPILER_ERROR(nvPTXCompilerDestroy(&compiler));
553+
554+
if (targetOptions.getCompilationTarget() == gpu::CompilationTarget::Fatbin) {
555+
bool useFatbin32 = llvm::any_of(cmdOpts.second, [](const char *option) {
556+
return llvm::StringRef(option) == "-32";
557+
});
558+
559+
const char *cubinOpts[1] = {useFatbin32 ? "-32" : "-64"};
560+
nvFatbinHandle handle;
561+
562+
auto chip = getTarget().getChip();
563+
chip.consume_front("sm_");
564+
565+
RETURN_ON_NVFATBIN_ERROR(nvFatbinCreate(&handle, cubinOpts, 1));
566+
RETURN_ON_NVFATBIN_ERROR(nvFatbinAddCubin(
567+
handle, binary.data(), binary.size(), chip.data(), nullptr));
568+
RETURN_ON_NVFATBIN_ERROR(nvFatbinAddPTX(
569+
handle, ptxCode.data(), ptxCode.size(), chip.data(), nullptr, nullptr));
570+
571+
size_t fatbinSize;
572+
RETURN_ON_NVFATBIN_ERROR(nvFatbinSize(handle, &fatbinSize));
573+
SmallVector<char, 0> fatbin(fatbinSize, 0);
574+
RETURN_ON_NVFATBIN_ERROR(nvFatbinGet(handle, (void *)fatbin.data()));
575+
RETURN_ON_NVFATBIN_ERROR(nvFatbinDestroy(&handle));
576+
return fatbin;
577+
}
578+
541579
return binary;
542580
}
543581
#endif // MLIR_ENABLE_NVPTXCOMPILER

0 commit comments

Comments
 (0)