diff --git a/sycl-jit/common/include/Kernel.h b/sycl-jit/common/include/Kernel.h index eb5ba0f05c913..09f7392cf0a8b 100644 --- a/sycl-jit/common/include/Kernel.h +++ b/sycl-jit/common/include/Kernel.h @@ -411,6 +411,9 @@ struct RTCDevImgInfo { using RTCBundleInfo = DynArray; +// LLVM's APIs prefer `char *` for byte buffers. +using RTCDeviceCodeIR = DynArray; + } // namespace jit_compiler #endif // SYCL_FUSION_COMMON_KERNEL_H diff --git a/sycl-jit/jit-compiler/CMakeLists.txt b/sycl-jit/jit-compiler/CMakeLists.txt index 63bb2ecc34ad9..dc45039c684a0 100644 --- a/sycl-jit/jit-compiler/CMakeLists.txt +++ b/sycl-jit/jit-compiler/CMakeLists.txt @@ -18,6 +18,7 @@ add_llvm_library(sycl-jit LINK_COMPONENTS BitReader + BitWriter Core Support Option diff --git a/sycl-jit/jit-compiler/include/KernelFusion.h b/sycl-jit/jit-compiler/include/KernelFusion.h index 7310d69a91952..1a50af5584f86 100644 --- a/sycl-jit/jit-compiler/include/KernelFusion.h +++ b/sycl-jit/jit-compiler/include/KernelFusion.h @@ -56,13 +56,45 @@ class JITResult { sycl::detail::string ErrorMessage; }; +class RTCHashResult { +public: + static RTCHashResult success(const char *Hash) { + return RTCHashResult{/*Failed=*/false, Hash}; + } + + static RTCHashResult failure(const char *PreprocLog) { + return RTCHashResult{/*Failed=*/true, PreprocLog}; + } + + bool failed() { return Failed; } + + const char *getPreprocLog() { + assert(failed() && "No preprocessor log"); + return HashOrLog.c_str(); + } + + const char *getHash() { + assert(!failed() && "No hash"); + return HashOrLog.c_str(); + } + +private: + RTCHashResult(bool Failed, const char *HashOrLog) + : Failed(Failed), HashOrLog(HashOrLog) {} + + bool Failed; + sycl::detail::string HashOrLog; +}; + class RTCResult { public: explicit RTCResult(const char *BuildLog) : Failed{true}, BundleInfo{}, BuildLog{BuildLog} {} - RTCResult(RTCBundleInfo &&BundleInfo, const char *BuildLog) - : Failed{false}, BundleInfo{std::move(BundleInfo)}, BuildLog{BuildLog} {} + RTCResult(RTCBundleInfo &&BundleInfo, RTCDeviceCodeIR &&DeviceCodeIR, + const char *BuildLog) + : Failed{false}, BundleInfo{std::move(BundleInfo)}, + DeviceCodeIR(std::move(DeviceCodeIR)), BuildLog{BuildLog} {} bool failed() const { return Failed; } @@ -73,9 +105,15 @@ class RTCResult { return BundleInfo; } + const RTCDeviceCodeIR &getDeviceCodeIR() const { + assert(!failed() && "No device code IR"); + return DeviceCodeIR; + } + private: bool Failed; RTCBundleInfo BundleInfo; + RTCDeviceCodeIR DeviceCodeIR; sycl::detail::string BuildLog; }; @@ -100,9 +138,14 @@ KF_EXPORT_SYMBOL JITResult materializeSpecConstants( const char *KernelName, jit_compiler::SYCLKernelBinaryInfo &BinInfo, View SpecConstBlob); +KF_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile, + View IncludeFiles, + View UserArgs); + KF_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile, View IncludeFiles, - View UserArgs); + View UserArgs, + View CachedIR, bool SaveIR); KF_EXPORT_SYMBOL void destroyBinary(BinaryAddress Address); diff --git a/sycl-jit/jit-compiler/ld-version-script.txt b/sycl-jit/jit-compiler/ld-version-script.txt index 2c6f307c88d03..ac707cc182c9c 100644 --- a/sycl-jit/jit-compiler/ld-version-script.txt +++ b/sycl-jit/jit-compiler/ld-version-script.txt @@ -3,6 +3,7 @@ /* Export the library entry points */ fuseKernels; materializeSpecConstants; + calculateHash; compileSYCL; destroyBinary; resetJITConfiguration; diff --git a/sycl-jit/jit-compiler/lib/KernelFusion.cpp b/sycl-jit/jit-compiler/lib/KernelFusion.cpp index 34c67c8fb22b6..b85e9ce146a53 100644 --- a/sycl-jit/jit-compiler/lib/KernelFusion.cpp +++ b/sycl-jit/jit-compiler/lib/KernelFusion.cpp @@ -19,7 +19,10 @@ #include "translation/SPIRVLLVMTranslation.h" #include +#include +#include #include +#include #include #include @@ -31,17 +34,21 @@ using namespace jit_compiler; using FusedFunction = helper::FusionHelper::FusedFunction; using FusedFunctionList = std::vector; -template -static ResultType errorTo(llvm::Error &&Err, const std::string &Msg) { +static std::string formatError(llvm::Error &&Err, const std::string &Msg) { std::stringstream ErrMsg; ErrMsg << Msg << "\nDetailed information:\n"; llvm::handleAllErrors(std::move(Err), [&ErrMsg](const llvm::StringError &StrErr) { - // Cannot throw an exception here if LLVM itself is - // compiled without exception support. ErrMsg << "\t" << StrErr.getMessage() << "\n"; }); - return ResultType{ErrMsg.str().c_str()}; + return ErrMsg.str(); +} + +template +static ResultType errorTo(llvm::Error &&Err, const std::string &Msg) { + // Cannot throw an exception here if LLVM itself is compiled without exception + // support. + return ResultType{formatError(std::move(Err), Msg).c_str()}; } static std::vector @@ -240,10 +247,42 @@ fuseKernels(View KernelInformation, const char *FusedKernelName, return JITResult{FusedKernelInfo}; } +extern "C" KF_EXPORT_SYMBOL RTCHashResult +calculateHash(InMemoryFile SourceFile, View IncludeFiles, + View UserArgs) { + auto UserArgListOrErr = parseUserArgs(UserArgs); + if (!UserArgListOrErr) { + return RTCHashResult::failure( + formatError(UserArgListOrErr.takeError(), + "Parsing of user arguments failed") + .c_str()); + } + llvm::opt::InputArgList UserArgList = std::move(*UserArgListOrErr); + + auto Start = std::chrono::high_resolution_clock::now(); + auto HashOrError = calculateHash(SourceFile, IncludeFiles, UserArgList); + if (!HashOrError) { + return RTCHashResult::failure( + formatError(HashOrError.takeError(), "Hashing failed").c_str()); + } + auto Hash = *HashOrError; + auto Stop = std::chrono::high_resolution_clock::now(); + + if (UserArgList.hasArg(clang::driver::options::OPT_ftime_trace_EQ)) { + std::chrono::duration HashTime = Stop - Start; + llvm::dbgs() << "Hashing of " << SourceFile.Path << " took " + << int(HashTime.count()) << " ms\n"; + } + + return RTCHashResult::success(Hash.c_str()); +} + extern "C" KF_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile, View IncludeFiles, - View UserArgs) { + View UserArgs, View CachedIR, bool SaveIR) { + llvm::LLVMContext Context; std::string BuildLog; + configureDiagnostics(Context, BuildLog); auto UserArgListOrErr = parseUserArgs(UserArgs); if (!UserArgListOrErr) { @@ -272,16 +311,43 @@ compileSYCL(InMemoryFile SourceFile, View IncludeFiles, Verbose); } - auto ModuleOrErr = - compileDeviceCode(SourceFile, IncludeFiles, UserArgList, BuildLog); - if (!ModuleOrErr) { - return errorTo(ModuleOrErr.takeError(), - "Device compilation failed"); + std::unique_ptr Module; + + if (CachedIR.size() > 0) { + llvm::StringRef IRStr{CachedIR.begin(), CachedIR.size()}; + std::unique_ptr IRBuf = + llvm::MemoryBuffer::getMemBuffer(IRStr, /*BufferName=*/"", + /*RequiresNullTerminator=*/false); + auto ModuleOrError = llvm::parseBitcodeFile(*IRBuf, Context); + if (!ModuleOrError) { + // Not a fatal error, we'll just compile the source string normally. + BuildLog.append(formatError(ModuleOrError.takeError(), + "Loading of cached device code failed")); + } else { + Module = std::move(*ModuleOrError); + } } - std::unique_ptr Context; - std::unique_ptr Module = std::move(*ModuleOrErr); - Context.reset(&Module->getContext()); + bool FromSource = false; + if (!Module) { + auto ModuleOrErr = compileDeviceCode(SourceFile, IncludeFiles, UserArgList, + BuildLog, Context); + if (!ModuleOrErr) { + return errorTo(ModuleOrErr.takeError(), + "Device compilation failed"); + } + + Module = std::move(*ModuleOrErr); + FromSource = true; + } + + RTCDeviceCodeIR IR; + if (SaveIR && FromSource) { + std::string BCString; + llvm::raw_string_ostream BCStream{BCString}; + llvm::WriteBitcodeToFile(*Module, BCStream); + IR = RTCDeviceCodeIR{BCString.data(), BCString.data() + BCString.size()}; + } if (auto Error = linkDeviceLibraries(*Module, UserArgList, BuildLog)) { return errorTo(std::move(Error), "Device linking failed"); @@ -314,7 +380,7 @@ compileSYCL(InMemoryFile SourceFile, View IncludeFiles, } } - return RTCResult{std::move(BundleInfo), BuildLog.c_str()}; + return RTCResult{std::move(BundleInfo), std::move(IR), BuildLog.c_str()}; } extern "C" KF_EXPORT_SYMBOL void destroyBinary(BinaryAddress Address) { diff --git a/sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp b/sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp index 8e59966daff1c..adc212f44eba1 100644 --- a/sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp +++ b/sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp @@ -16,8 +16,10 @@ #include #include #include +#include #include #include +#include #include #include @@ -32,6 +34,8 @@ #include #include #include +#include +#include #include #include @@ -132,13 +136,46 @@ static const std::string &getDPCPPRoot() { namespace { -struct GetLLVMModuleAction : public ToolAction { +class HashPreprocessedAction : public PreprocessorFrontendAction { +protected: + void ExecuteAction() override { + CompilerInstance &CI = getCompilerInstance(); + + std::string PreprocessedSource; + raw_string_ostream PreprocessStream(PreprocessedSource); + + PreprocessorOutputOptions Opts; + Opts.ShowCPP = 1; + Opts.MinimizeWhitespace = 1; + // Make cache key insensitive to virtual source file and header locations. + Opts.ShowLineMarkers = 0; + + DoPrintPreprocessedInput(CI.getPreprocessor(), &PreprocessStream, Opts); + + Hash = BLAKE3::hash(arrayRefFromStringRef(PreprocessedSource)); + Executed = true; + } + +public: + BLAKE3Result<> takeHash() { + assert(Executed); + Executed = false; + return std::move(Hash); + } + +private: + BLAKE3Result<> Hash; + bool Executed = false; +}; + +class RTCToolActionBase : public ToolAction { +public: // Code adapted from `FrontendActionFactory::runInvocation`. bool runInvocation(std::shared_ptr Invocation, FileManager *Files, std::shared_ptr PCHContainerOps, DiagnosticConsumer *DiagConsumer) override { - assert(!Module && "Action should only be invoked on a single file"); + assert(!hasExecuted() && "Action should only be invoked on a single file"); // Create a compiler instance to handle the actual work. CompilerInstance Compiler(std::move(PCHContainerOps)); @@ -157,23 +194,75 @@ struct GetLLVMModuleAction : public ToolAction { Compiler.createSourceManager(*Files); + return executeAction(Compiler, Files); + } + + virtual ~RTCToolActionBase() = default; + +protected: + virtual bool hasExecuted() = 0; + virtual bool executeAction(CompilerInstance &, FileManager *) = 0; +}; + +class GetSourceHashAction : public RTCToolActionBase { +protected: + bool executeAction(CompilerInstance &CI, FileManager *Files) override { + HashPreprocessedAction HPA; + const bool Success = CI.ExecuteAction(HPA); + Files->clearStatCache(); + if (!Success) { + return false; + } + + Hash = HPA.takeHash(); + Executed = true; + return true; + } + + bool hasExecuted() override { return Executed; } + +public: + BLAKE3Result<> takeHash() { + assert(Executed); + Executed = false; + return std::move(Hash); + } + +private: + BLAKE3Result<> Hash; + bool Executed = false; +}; + +struct GetLLVMModuleAction : public RTCToolActionBase { +protected: + bool executeAction(CompilerInstance &CI, FileManager *Files) override { // Ignore `Compiler.getFrontendOpts().ProgramAction` (would be `EmitBC`) and // create/execute an `EmitLLVMOnlyAction` (= codegen to LLVM module without // emitting anything) instead. - EmitLLVMOnlyAction ELOA; - const bool Success = Compiler.ExecuteAction(ELOA); + EmitLLVMOnlyAction ELOA{&Context}; + const bool Success = CI.ExecuteAction(ELOA); Files->clearStatCache(); if (!Success) { return false; } - // Take the module and its context to extend the objects' lifetime. + // Take the module to extend its lifetime. Module = ELOA.takeModule(); - ELOA.takeLLVMContext(); return true; } + bool hasExecuted() override { return static_cast(Module); } + +public: + GetLLVMModuleAction(LLVMContext &Context) : Context{Context}, Module{} {} + std::unique_ptr takeModule() { + assert(Module); + return std::move(Module); + } + +private: + LLVMContext &Context; std::unique_ptr Module; }; @@ -223,16 +312,9 @@ class LLVMDiagnosticWrapper : public llvm::DiagnosticHandler { } // anonymous namespace -Expected> jit_compiler::compileDeviceCode( - InMemoryFile SourceFile, View IncludeFiles, - const InputArgList &UserArgList, std::string &BuildLog) { - TimeTraceScope TTS{"compileDeviceCode"}; - - const std::string &DPCPPRoot = getDPCPPRoot(); - if (DPCPPRoot == InvalidDPCPPRoot) { - return createStringError("Could not locate DPCPP root directory"); - } - +static void adjustArgs(const InputArgList &UserArgList, + const std::string &DPCPPRoot, + SmallVectorImpl &CommandLine) { DerivedArgList DAL{UserArgList}; const auto &OptTable = getDriverOptTable(); DAL.AddFlagArg(nullptr, OptTable.getOption(OPT_fsycl_device_only)); @@ -251,17 +333,15 @@ Expected> jit_compiler::compileDeviceCode( DAL.eraseArg(OPT_ftime_trace_granularity_EQ); DAL.eraseArg(OPT_ftime_trace_verbose); - SmallVector CommandLine; for (auto *Arg : DAL) { CommandLine.emplace_back(Arg->getAsString(DAL)); } +} - FixedCompilationDatabase DB{".", CommandLine}; - ClangTool Tool{DB, {SourceFile.Path}}; - - IntrusiveRefCntPtr DiagOpts{new DiagnosticOptions}; - ClangDiagnosticWrapper Wrapper(BuildLog, DiagOpts.get()); - Tool.setDiagnosticConsumer(Wrapper.consumer()); +static void setupTool(ClangTool &Tool, const std::string &DPCPPRoot, + InMemoryFile SourceFile, View IncludeFiles, + DiagnosticConsumer *Consumer) { + Tool.setDiagnosticConsumer(Consumer); // Suppress message "Error while processing" being printed to stdout. Tool.setPrintErrorMessage(false); @@ -284,10 +364,72 @@ Expected> jit_compiler::compileDeviceCode( NewArgs[0] = (Twine(DPCPPRoot) + "/bin/clang++").str(); return NewArgs; }); +} + +Expected +jit_compiler::calculateHash(InMemoryFile SourceFile, + View IncludeFiles, + const InputArgList &UserArgList) { + TimeTraceScope TTS{"calculateHash"}; - GetLLVMModuleAction Action; + const std::string &DPCPPRoot = getDPCPPRoot(); + if (DPCPPRoot == InvalidDPCPPRoot) { + return createStringError("Could not locate DPCPP root directory"); + } + + SmallVector CommandLine; + adjustArgs(UserArgList, DPCPPRoot, CommandLine); + + FixedCompilationDatabase DB{".", CommandLine}; + ClangTool Tool{DB, {SourceFile.Path}}; + + clang::IgnoringDiagConsumer DiagConsumer; + setupTool(Tool, DPCPPRoot, SourceFile, IncludeFiles, &DiagConsumer); + + GetSourceHashAction Action; if (!Tool.run(&Action)) { - return std::move(Action.Module); + BLAKE3Result<> SourceHash = Action.takeHash(); + // The adjusted command line contains the DPCPP root and clang major + // version. + BLAKE3Result<> CommandLineHash = + BLAKE3::hash(arrayRefFromStringRef(join(CommandLine, ","))); + + std::string EncodedHash = + encodeBase64(SourceHash) + encodeBase64(CommandLineHash); + // Make the encoding filesystem-friendly. + std::replace(EncodedHash.begin(), EncodedHash.end(), '/', '-'); + return std::move(EncodedHash); + } + + return createStringError("Calculating source hash failed"); +} + +Expected> +jit_compiler::compileDeviceCode(InMemoryFile SourceFile, + View IncludeFiles, + const InputArgList &UserArgList, + std::string &BuildLog, LLVMContext &Context) { + TimeTraceScope TTS{"compileDeviceCode"}; + + const std::string &DPCPPRoot = getDPCPPRoot(); + if (DPCPPRoot == InvalidDPCPPRoot) { + return createStringError("Could not locate DPCPP root directory"); + } + + SmallVector CommandLine; + adjustArgs(UserArgList, DPCPPRoot, CommandLine); + + FixedCompilationDatabase DB{".", CommandLine}; + ClangTool Tool{DB, {SourceFile.Path}}; + + IntrusiveRefCntPtr DiagOpts{new DiagnosticOptions}; + ClangDiagnosticWrapper Wrapper(BuildLog, DiagOpts.get()); + + setupTool(Tool, DPCPPRoot, SourceFile, IncludeFiles, Wrapper.consumer()); + + GetLLVMModuleAction Action{Context}; + if (!Tool.run(&Action)) { + return Action.takeModule(); } return createStringError(BuildLog); @@ -409,8 +551,6 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module, } LLVMContext &Context = Module.getContext(); - Context.setDiagnosticHandler( - std::make_unique(BuildLog)); for (const std::string &LibName : LibNames) { std::string LibPath = DPCPPRoot + "/lib/" + LibName; @@ -652,3 +792,9 @@ jit_compiler::parseUserArgs(View UserArgs) { return std::move(AL); } + +void jit_compiler::configureDiagnostics(LLVMContext &Context, + std::string &BuildLog) { + Context.setDiagnosticHandler( + std::make_unique(BuildLog)); +} diff --git a/sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.h b/sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.h index 1c30e5a61fb4b..7708c1ca857fd 100644 --- a/sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.h +++ b/sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.h @@ -22,10 +22,14 @@ namespace jit_compiler { +llvm::Expected +calculateHash(InMemoryFile SourceFile, View IncludeFiles, + const llvm::opt::InputArgList &UserArgList); + llvm::Expected> compileDeviceCode(InMemoryFile SourceFile, View IncludeFiles, const llvm::opt::InputArgList &UserArgList, - std::string &BuildLog); + std::string &BuildLog, llvm::LLVMContext &Context); llvm::Error linkDeviceLibraries(llvm::Module &Module, const llvm::opt::InputArgList &UserArgList, @@ -40,6 +44,8 @@ performPostLink(std::unique_ptr Module, llvm::Expected parseUserArgs(View UserArgs); +void configureDiagnostics(llvm::LLVMContext &Context, std::string &BuildLog); + } // namespace jit_compiler #endif // SYCL_JIT_COMPILER_RTC_DEVICE_COMPILATION_H diff --git a/sycl/source/detail/jit_compiler.cpp b/sycl/source/detail/jit_compiler.cpp index 6fc88bb812a20..7dd1b5cf31816 100644 --- a/sycl/source/detail/jit_compiler.cpp +++ b/sycl/source/detail/jit_compiler.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -90,6 +91,15 @@ jit_compiler::jit_compiler() return false; } + this->CalculateHashHandle = reinterpret_cast( + sycl::detail::ur::getOsLibraryFuncAddress(LibraryPtr.get(), + "calculateHash")); + if (!this->CalculateHashHandle) { + printPerformanceWarning( + "Cannot resolve JIT library function entry point"); + return false; + } + this->CompileSYCLHandle = reinterpret_cast( sycl::detail::ur::getOsLibraryFuncAddress(LibraryPtr.get(), "compileSYCL")); @@ -1142,7 +1152,7 @@ sycl_device_binaries jit_compiler::createPIDeviceBinary( sycl_device_binaries jit_compiler::createDeviceBinaries( const ::jit_compiler::RTCBundleInfo &BundleInfo, - const std::string &OffloadEntryPrefix) { + const std::string &Prefix) { auto Collection = std::make_unique(); for (const auto &DevImgInfo : BundleInfo) { @@ -1153,7 +1163,7 @@ sycl_device_binaries jit_compiler::createDeviceBinaries( // entrypoints remain unchanged. // It seems to be OK to set zero for most of the information here, at // least that is the case for compiled SPIR-V binaries. - std::string PrefixedName = OffloadEntryPrefix + Symbol.c_str(); + std::string PrefixedName = Prefix + Symbol.c_str(); OffloadEntryContainer Entry{PrefixedName, /*Addr=*/nullptr, /*Size=*/0, /*Flags=*/0, /*Reserved=*/0}; Binary.addOffloadEntry(std::move(Entry)); @@ -1246,11 +1256,16 @@ std::vector jit_compiler::encodeReqdWorkGroupSize( return Encoded; } -sycl_device_binaries jit_compiler::compileSYCL( +std::pair jit_compiler::compileSYCL( const std::string &CompilationID, const std::string &SYCLSource, const std::vector> &IncludePairs, const std::vector &UserArgs, std::string *LogPtr, const std::vector &RegisteredKernelNames) { + auto appendToLog = [LogPtr](const char *Msg) { + if (LogPtr) { + LogPtr->append(Msg); + } + }; // RegisteredKernelNames may contain template specializations, so we just put // them in main() which ensures they are instantiated. @@ -1281,18 +1296,42 @@ sycl_device_binaries jit_compiler::compileSYCL( std::back_inserter(UserArgsView), [](const auto &Arg) { return Arg.c_str(); }); - auto Result = CompileSYCLHandle(SourceFile, IncludeFilesView, UserArgsView); + std::string CacheKey; + std::vector CachedIR; + if (PersistentDeviceCodeCache::isEnabled()) { + auto Result = + CalculateHashHandle(SourceFile, IncludeFilesView, UserArgsView); - if (LogPtr) { - LogPtr->append(Result.getBuildLog()); + if (Result.failed()) { + appendToLog(Result.getPreprocLog()); + } else { + CacheKey = Result.getHash(); + CachedIR = PersistentDeviceCodeCache::getDeviceCodeIRFromDisc(CacheKey); + } } + auto Result = CompileSYCLHandle(SourceFile, IncludeFilesView, UserArgsView, + CachedIR, /*SaveIR=*/!CacheKey.empty()); + + appendToLog(Result.getBuildLog()); if (Result.failed()) { throw sycl::exception(sycl::errc::build, Result.getBuildLog()); } - return createDeviceBinaries(Result.getBundleInfo(), - /*OffloadEntryPrefix=*/CompilationID + '$'); + const auto &IR = Result.getDeviceCodeIR(); + if (!CacheKey.empty() && !IR.empty()) { + // The RTC result contains the bitcode blob iff the frontend was invoked on + // the source string, meaning we encountered either a cache miss, or a cache + // hit that returned unusable IR (e.g. due to a bitcode version mismatch). + // There's no explicit mechanism to invalidate the cache entry - we just + // overwrite the entry with the newly compiled IR. + std::vector SavedIR{IR.begin(), IR.end()}; + PersistentDeviceCodeCache::putDeviceCodeIRToDisc(CacheKey, SavedIR); + } + + std::string Prefix = CompilationID + '$'; + return std::make_pair(createDeviceBinaries(Result.getBundleInfo(), Prefix), + std::move(Prefix)); } } // namespace detail diff --git a/sycl/source/detail/jit_compiler.hpp b/sycl/source/detail/jit_compiler.hpp index cf404e7bb723e..6a3bbe56e3d46 100644 --- a/sycl/source/detail/jit_compiler.hpp +++ b/sycl/source/detail/jit_compiler.hpp @@ -49,7 +49,7 @@ class jit_compiler { const std::string &KernelName, const std::vector &SpecConstBlob); - sycl_device_binaries compileSYCL( + std::pair compileSYCL( const std::string &CompilationID, const std::string &SYCLSource, const std::vector> &IncludePairs, const std::vector &UserArgs, std::string *LogPtr, @@ -78,7 +78,7 @@ class jit_compiler { sycl_device_binaries createDeviceBinaries(const ::jit_compiler::RTCBundleInfo &BundleInfo, - const std::string &OffloadEntryPrefix); + const std::string &Prefix); std::vector encodeArgUsageMask(const ::jit_compiler::ArgUsageMask &Mask) const; @@ -105,12 +105,14 @@ class jit_compiler { using FuseKernelsFuncT = decltype(::jit_compiler::fuseKernels) *; using MaterializeSpecConstFuncT = decltype(::jit_compiler::materializeSpecConstants) *; + using CalculateHashFuncT = decltype(::jit_compiler::calculateHash) *; using CompileSYCLFuncT = decltype(::jit_compiler::compileSYCL) *; using DestroyBinaryFuncT = decltype(::jit_compiler::destroyBinary) *; using ResetConfigFuncT = decltype(::jit_compiler::resetJITConfiguration) *; using AddToConfigFuncT = decltype(::jit_compiler::addToJITConfiguration) *; FuseKernelsFuncT FuseKernelsHandle = nullptr; MaterializeSpecConstFuncT MaterializeSpecConstHandle = nullptr; + CalculateHashFuncT CalculateHashHandle = nullptr; CompileSYCLFuncT CompileSYCLHandle = nullptr; DestroyBinaryFuncT DestroyBinaryHandle = nullptr; ResetConfigFuncT ResetConfigHandle = nullptr; diff --git a/sycl/source/detail/kernel_bundle_impl.hpp b/sycl/source/detail/kernel_bundle_impl.hpp index be22afa63712a..8bad9bb34ee4a 100644 --- a/sycl/source/detail/kernel_bundle_impl.hpp +++ b/sycl/source/detail/kernel_bundle_impl.hpp @@ -498,10 +498,8 @@ class kernel_bundle_impl { if (MLanguage == syclex::source_language::sycl_jit) { // Build device images via the program manager. - // TODO: Support persistent caching. - const std::string &SourceStr = std::get(MSource); - auto [Binaries, CompilationID] = syclex::detail::SYCL_JIT_to_SPIRV( + auto [Binaries, Prefix] = syclex::detail::SYCL_JIT_to_SPIRV( SourceStr, MIncludePairs, BuildOptions, LogPtr, RegisteredKernelNames); @@ -510,9 +508,6 @@ class kernel_bundle_impl { std::vector KernelIDs; std::vector KernelNames; - // `jit_compiler::compileSYCL(..)` uses `CompilationID + '$'` as prefix - // for offload entry names. - std::string Prefix = CompilationID + '$'; for (const auto &KernelID : PM.getAllSYCLKernelIDs()) { std::string_view KernelName{KernelID.get_name()}; if (KernelName.find(Prefix) == 0) { diff --git a/sycl/source/detail/kernel_compiler/kernel_compiler_sycl.cpp b/sycl/source/detail/kernel_compiler/kernel_compiler_sycl.cpp index 56b79340b1309..ce5793e356abf 100644 --- a/sycl/source/detail/kernel_compiler/kernel_compiler_sycl.cpp +++ b/sycl/source/detail/kernel_compiler/kernel_compiler_sycl.cpp @@ -305,18 +305,16 @@ bool SYCL_JIT_Compilation_Available() { std::pair SYCL_JIT_to_SPIRV( [[maybe_unused]] const std::string &SYCLSource, - [[maybe_unused]] include_pairs_t IncludePairs, + [[maybe_unused]] const include_pairs_t &IncludePairs, [[maybe_unused]] const std::vector &UserArgs, [[maybe_unused]] std::string *LogPtr, [[maybe_unused]] const std::vector &RegisteredKernelNames) { #if SYCL_EXT_JIT_ENABLE static std::atomic_uintptr_t CompilationCounter; std::string CompilationID = "rtc_" + std::to_string(CompilationCounter++); - sycl_device_binaries Binaries = - sycl::detail::jit_compiler::get_instance().compileSYCL( - CompilationID, SYCLSource, IncludePairs, UserArgs, LogPtr, - RegisteredKernelNames); - return std::make_pair(Binaries, std::move(CompilationID)); + return sycl::detail::jit_compiler::get_instance().compileSYCL( + CompilationID, SYCLSource, IncludePairs, UserArgs, LogPtr, + RegisteredKernelNames); #else throw sycl::exception(sycl::errc::build, "kernel_compiler via sycl-jit is not available"); diff --git a/sycl/source/detail/kernel_compiler/kernel_compiler_sycl.hpp b/sycl/source/detail/kernel_compiler/kernel_compiler_sycl.hpp index 1a1a2665ae313..fdcaaea537046 100644 --- a/sycl/source/detail/kernel_compiler/kernel_compiler_sycl.hpp +++ b/sycl/source/detail/kernel_compiler/kernel_compiler_sycl.hpp @@ -35,8 +35,14 @@ bool SYCL_Compilation_Available(); std::string userArgsAsString(const std::vector &UserArguments); +// Compile the given SYCL source string and virtual include files into the image +// format understood by the program manager. +// +// Returns a pointer to the image (owned by the `jit_compiler` class), and the +// bundle-specific prefix used for loading the kernels. std::pair -SYCL_JIT_to_SPIRV(const std::string &Source, include_pairs_t IncludePairs, +SYCL_JIT_to_SPIRV(const std::string &Source, + const include_pairs_t &IncludePairs, const std::vector &UserArgs, std::string *LogPtr, const std::vector &RegisteredKernelNames); diff --git a/sycl/source/detail/persistent_device_code_cache.cpp b/sycl/source/detail/persistent_device_code_cache.cpp index 213948d526f59..0a6e708543e27 100644 --- a/sycl/source/detail/persistent_device_code_cache.cpp +++ b/sycl/source/detail/persistent_device_code_cache.cpp @@ -557,6 +557,53 @@ void PersistentDeviceCodeCache::putCompiledKernelToDisc( updateCacheFileSizeAndTriggerEviction(getRootDir(), TotalSize); } +void PersistentDeviceCodeCache::putDeviceCodeIRToDisc( + const std::string &Key, const std::vector &IR) { + + repopulateCacheSizeFile(getRootDir()); + + // Do not insert any new item if eviction is in progress. + // Since evictions are rare, we can afford to spin lock here. + const std::string EvictionInProgressFile = + getRootDir() + EvictionInProgressFileSuffix; + // Stall until the other process finishes eviction. + while (OSUtil::isPathPresent(EvictionInProgressFile)) + continue; + + // Total size of the item that we are writing to the cache. + size_t TotalSize = 0; + + std::string DirName = getDeviceCodeIRPath(Key); + std::string FileName = DirName + "/ir"; + std::string FullFileName = FileName + ".bin"; + + try { + OSUtil::makeDir(DirName.c_str()); + LockCacheItem Lock{FileName}; + if (Lock.isOwned()) { + writeBinaryDataToFile(FullFileName, IR); + PersistentDeviceCodeCache::trace_KernelCompiler( + "storing device code IR: ", FullFileName); + + TotalSize = getFileSize(FullFileName); + saveCurrentTimeInAFile(FileName + CacheEntryAccessTimeSuffix); + } else { + PersistentDeviceCodeCache::trace_KernelCompiler("cache lock not owned ", + FileName); + } + } catch (std::exception &e) { + PersistentDeviceCodeCache::trace_KernelCompiler( + std::string("exception encountered making cache: ") + e.what()); + } catch (...) { + PersistentDeviceCodeCache::trace_KernelCompiler( + std::string("error outputting cache: ") + std::strerror(errno)); + } + + // Update the cache size file and trigger cache eviction if needed. + if (TotalSize) + updateCacheFileSizeAndTriggerEviction(getRootDir(), TotalSize); +} + /* Program binaries built for one or more devices are read from persistent * cache and returned in form of vector of programs. Each binary program is * stored in vector of chars. There is a one-to-one correspondence between @@ -664,6 +711,39 @@ PersistentDeviceCodeCache::getCompiledKernelFromDisc( return Binaries; } +std::vector +PersistentDeviceCodeCache::getDeviceCodeIRFromDisc(const std::string &Key) { + std::vector IR; + + std::string DirName = getDeviceCodeIRPath(Key); + std::string FileName = DirName + "/ir"; + std::string FullFileName = FileName + ".bin"; + + if (DirName.empty() || !OSUtil::isPathPresent(FullFileName)) { + trace_KernelCompiler("cache miss: ", Key); + return {}; + } + + if (!LockCacheItem::isLocked(FileName)) { + try { + IR = readBinaryDataFromFile(FullFileName); + + // Explicitly update the access time of the file. This is required for + // eviction. + if (isEvictionEnabled()) + saveCurrentTimeInAFile(FileName + CacheEntryAccessTimeSuffix); + } catch (...) { + // If read was unsuccessfull give up + trace_KernelCompiler("cache miss: ", Key); + return {}; + } + } + + PersistentDeviceCodeCache::trace_KernelCompiler( + "using cached device code IR: ", FullFileName); + return IR; +} + /* Returns string value which can be used to identify different device */ std::string PersistentDeviceCodeCache::getDeviceIDString(const device &Device) { @@ -864,6 +944,17 @@ std::string PersistentDeviceCodeCache::getCompiledKernelItemPath( std::to_string(StringHasher(SourceString)); } +std::string +PersistentDeviceCodeCache::getDeviceCodeIRPath(const std::string &Key) { + std::string cache_root{getRootDir()}; + if (cache_root.empty()) { + trace("Disable persistent cache due to unconfigured cache root."); + return {}; + } + + return cache_root + "/ext_kernel_compiler/" + Key; +} + /* Returns true if persistent cache is enabled. */ bool PersistentDeviceCodeCache::isEnabled() { diff --git a/sycl/source/detail/persistent_device_code_cache.hpp b/sycl/source/detail/persistent_device_code_cache.hpp index 9346461c9229f..464794ddbde30 100644 --- a/sycl/source/detail/persistent_device_code_cache.hpp +++ b/sycl/source/detail/persistent_device_code_cache.hpp @@ -178,6 +178,14 @@ class PersistentDeviceCodeCache { const std::string &BuildOptionsString, const std::string &SourceString); + /* Get directory name when storing runtime compiled device code IR (via + * kernel_compiler, sycl_jit language). The key is computed in the sycl-jit + * library, and encompasses the preprocesses source code, build options and + * compiler location. The frontend invocation (whose output we cache here) is + * device-agnostic, hence the device (list) is not part of the lookup. + */ + static std::string getDeviceCodeIRPath(const std::string &Key); + /* Program binaries built for one or more devices are read from persistent * cache and returned in form of vector of programs. Each binary program is * stored in vector of chars. @@ -193,6 +201,8 @@ class PersistentDeviceCodeCache { const std::string &BuildOptionsString, const std::string &SourceStr); + static std::vector getDeviceCodeIRFromDisc(const std::string &Key); + /* Stores build program in persistent cache */ static void @@ -207,6 +217,9 @@ class PersistentDeviceCodeCache { const std::string &SourceStr, const ur_program_handle_t &NativePrg); + static void putDeviceCodeIRToDisc(const std::string &Key, + const std::vector &IR); + /* Sends message to std:cerr stream when SYCL_CACHE_TRACE environemnt is set*/ static void trace(const std::string &msg, const std::string &path = "") { static const bool traceEnabled = diff --git a/sycl/test-e2e/KernelCompiler/kernel_compiler_sycl_jit_cache.cpp b/sycl/test-e2e/KernelCompiler/kernel_compiler_sycl_jit_cache.cpp new file mode 100644 index 0000000000000..fa527c49a6854 --- /dev/null +++ b/sycl/test-e2e/KernelCompiler/kernel_compiler_sycl_jit_cache.cpp @@ -0,0 +1,145 @@ +//==- kernel_compiler_sycl_jit_cache.cpp --- persistent cache for SYCL-RTC -==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// REQUIRES: (opencl || level_zero) +// REQUIRES: aspect-usm_device_allocations + +// UNSUPPORTED: accelerator +// UNSUPPORTED-INTENDED: while accelerator is AoT only, this cannot run there. + +// DEFINE: %{cache_vars} = env SYCL_CACHE_PERSISTENT=1 SYCL_CACHE_TRACE=7 SYCL_CACHE_DIR=%t/cache_dir +// DEFINE: %{max_cache_size} = SYCL_CACHE_MAX_SIZE=30000 +// RUN: %{build} -o %t.out +// RUN: %{run-aux} rm -rf %t/cache_dir +// RUN: %{cache_vars} %{run-unfiltered-devices} %t.out 2>&1 | FileCheck %s --check-prefixes=CHECK,CHECK-UNLIM +// RUN: %{run-aux} rm -rf %t/cache_dir +// RUN: %{cache_vars} %{max_cache_size} %{run-unfiltered-devices} %t.out 2>&1 | FileCheck %s --check-prefixes=CHECK,CHECK-EVICT + +#include +#include + +auto constexpr SYCLSource = R"""( +#include + +extern "C" SYCL_EXTERNAL +SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((sycl::ext::oneapi::experimental::nd_range_kernel<1>)) +void vec_add(float* in1, float* in2, float* out){ + size_t id = sycl::ext::oneapi::this_work_item::get_nd_item<1>().get_global_linear_id(); + out[id] = in1[id] + in2[id]; +} +)"""; + +auto constexpr SYCLSourceWithInclude = R"""( + #include "myheader.h" + #include + + extern "C" SYCL_EXTERNAL + SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((sycl::ext::oneapi::experimental::nd_range_kernel<1>)) + void KERNEL_NAME(float* in1, float* out){ + size_t id = sycl::ext::oneapi::this_work_item::get_nd_item<1>().get_global_linear_id(); + out[id] = in1[id]; + } + )"""; + +static void dumpKernelIDs() { + for (auto &kernelID : sycl::get_kernel_ids()) + std::cout << kernelID.get_name() << std::endl; +} + +int test_persistent_cache() { + namespace syclex = sycl::ext::oneapi::experimental; + using source_kb = sycl::kernel_bundle; + using exe_kb = sycl::kernel_bundle; + + sycl::queue q; + sycl::context ctx = q.get_context(); + + bool ok = + q.get_device().ext_oneapi_can_compile(syclex::source_language::sycl_jit); + if (!ok) { + std::cout << "Apparently this device does not support `sycl_jit` source " + "kernel bundle extension: " + << q.get_device().get_info() + << std::endl; + return -1; + } + + source_kb kbSrc1 = syclex::create_kernel_bundle_from_source( + ctx, syclex::source_language::sycl_jit, SYCLSource); + + // Bundle is entered into cache on first build. + // CHECK: [kernel_compiler Persistent Cache]: cache miss: [[KEY1:.*]] + // CHECK: [kernel_compiler Persistent Cache]: storing device code IR: {{.*}}/[[KEY1]] + exe_kb kbExe1a = syclex::build(kbSrc1); + dumpKernelIDs(); + // CHECK: rtc_0$__sycl_kernel_vec_add + + // Cache hit! We get independent bundles with their own version of the kernel. + // CHECK: [kernel_compiler Persistent Cache]: using cached device code IR: {{.*}}/[[KEY1]] + exe_kb kbExe1b = syclex::build(kbSrc1); + dumpKernelIDs(); + // CHECK-DAG: rtc_0$__sycl_kernel_vec_add + // CHECK-DAG: rtc_1$__sycl_kernel_vec_add + + source_kb kbSrc2 = syclex::create_kernel_bundle_from_source( + ctx, syclex::source_language::sycl_jit, SYCLSource); + + // Different source bundle, but identical source is a cache hit. + // CHECK: [kernel_compiler Persistent Cache]: using cached device code IR: {{.*}}/[[KEY1]] + exe_kb kbExe2a = syclex::build(kbSrc2); + + // Different build_options means no cache hit. + // CHECK: [kernel_compiler Persistent Cache]: cache miss: [[KEY2:.*]] + // CHECK: [kernel_compiler Persistent Cache]: storing device code IR: {{.*}}/[[KEY2]] + std::vector flags{"-g", "-fno-fast-math"}; + exe_kb kbExe1c = + syclex::build(kbSrc1, syclex::properties{syclex::build_options{flags}}); + + // The kbExe1c build should trigger eviction if cache size is limited. + // CHECK-UNLIM: [kernel_compiler Persistent Cache]: using cached device code IR: {{.*}}/[[KEY1]] + // CHECK-EVICT: [Persistent Cache]: Cache eviction triggered. + // CHECK-EVICT: [Persistent Cache]: File removed: {{.*}}/[[KEY1]] + // CHECK-EVICT: [kernel_compiler Persistent Cache]: cache miss: [[KEY1]] + // CHECK-EVICT: [kernel_compiler Persistent Cache]: storing device code IR: {{.*}}/[[KEY1]] + exe_kb kbExe2b = syclex::build(kbSrc2); + + source_kb kbSrc3 = syclex::create_kernel_bundle_from_source( + ctx, syclex::source_language::sycl_jit, SYCLSourceWithInclude, + syclex::properties{ + syclex::include_files{"myheader.h", "#define KERNEL_NAME foo"}}); + + // New source string -> cache miss + // CHECK: [kernel_compiler Persistent Cache]: cache miss: [[KEY3:.*]] + // CHECK: [kernel_compiler Persistent Cache]: storing device code IR: {{.*}}/[[KEY3]] + exe_kb kbExe3a = syclex::build(kbSrc3); + dumpKernelIDs(); + // CHECK: rtc_5$__sycl_kernel_foo + + source_kb kbSrc4 = syclex::create_kernel_bundle_from_source( + ctx, syclex::source_language::sycl_jit, SYCLSourceWithInclude, + syclex::properties{ + syclex::include_files{"myheader.h", "#define KERNEL_NAME bar"}}); + + // Same source string, but different header contents -> cache miss + // CHECK: [kernel_compiler Persistent Cache]: cache miss: [[KEY4:.*]] + // CHECK: [kernel_compiler Persistent Cache]: storing device code IR: {{.*}}/[[KEY4]] + exe_kb kbExe4a = syclex::build(kbSrc4); + dumpKernelIDs(); + // CHECK: rtc_6$__sycl_kernel_bar + + return 0; +} + +int main(int argc, char **) { +#ifdef SYCL_EXT_ONEAPI_KERNEL_COMPILER + return test_persistent_cache(); +#else + static_assert(false, "Kernel Compiler feature test macro undefined"); +#endif + return 0; +} diff --git a/sycl/test/e2e_test_requirements/no_sycl_hpp_in_e2e_tests.cpp b/sycl/test/e2e_test_requirements/no_sycl_hpp_in_e2e_tests.cpp index 692ca4b8a16d2..dd068fb40752a 100644 --- a/sycl/test/e2e_test_requirements/no_sycl_hpp_in_e2e_tests.cpp +++ b/sycl/test/e2e_test_requirements/no_sycl_hpp_in_e2e_tests.cpp @@ -6,7 +6,7 @@ // CHECK-DAG: README.md // CHECK-DAG: lit.cfg.py // -// CHECK-NUM-MATCHES: 6 +// CHECK-NUM-MATCHES: 7 // // This test verifies that `` isn't used in E2E tests. Instead, // fine-grained includes should used, see