1919#include " translation/SPIRVLLVMTranslation.h"
2020
2121#include < llvm/ADT/StringExtras.h>
22+ #include < llvm/Bitcode/BitcodeReader.h>
23+ #include < llvm/Bitcode/BitcodeWriter.h>
2224#include < llvm/Support/Error.h>
25+ #include < llvm/Support/MemoryBuffer.h>
2326#include < llvm/Support/TimeProfiler.h>
2427
2528#include < clang/Driver/Options.h>
@@ -31,17 +34,21 @@ using namespace jit_compiler;
3134using FusedFunction = helper::FusionHelper::FusedFunction;
3235using FusedFunctionList = std::vector<FusedFunction>;
3336
34- template <typename ResultType>
35- static ResultType errorTo (llvm::Error &&Err, const std::string &Msg) {
37+ static std::string formatError (llvm::Error &&Err, const std::string &Msg) {
3638 std::stringstream ErrMsg;
3739 ErrMsg << Msg << " \n Detailed information:\n " ;
3840 llvm::handleAllErrors (std::move (Err),
3941 [&ErrMsg](const llvm::StringError &StrErr) {
40- // Cannot throw an exception here if LLVM itself is
41- // compiled without exception support.
4242 ErrMsg << " \t " << StrErr.getMessage () << " \n " ;
4343 });
44- return ResultType{ErrMsg.str ().c_str ()};
44+ return ErrMsg.str ();
45+ }
46+
47+ template <typename ResultType>
48+ static ResultType errorTo (llvm::Error &&Err, const std::string &Msg) {
49+ // Cannot throw an exception here if LLVM itself is compiled without exception
50+ // support.
51+ return ResultType{formatError (std::move (Err), Msg).c_str ()};
4552}
4653
4754static std::vector<jit_compiler::NDRange>
@@ -240,10 +247,42 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
240247 return JITResult{FusedKernelInfo};
241248}
242249
250+ extern " C" KF_EXPORT_SYMBOL RTCHashResult
251+ calculateHash (InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
252+ View<const char *> UserArgs) {
253+ auto UserArgListOrErr = parseUserArgs (UserArgs);
254+ if (!UserArgListOrErr) {
255+ return RTCHashResult::failure (
256+ formatError (UserArgListOrErr.takeError (),
257+ " Parsing of user arguments failed" )
258+ .c_str ());
259+ }
260+ llvm::opt::InputArgList UserArgList = std::move (*UserArgListOrErr);
261+
262+ auto Start = std::chrono::high_resolution_clock::now ();
263+ auto HashOrError = calculateHash (SourceFile, IncludeFiles, UserArgList);
264+ if (!HashOrError) {
265+ return RTCHashResult::failure (
266+ formatError (HashOrError.takeError (), " Hashing failed" ).c_str ());
267+ }
268+ auto Hash = *HashOrError;
269+ auto Stop = std::chrono::high_resolution_clock::now ();
270+
271+ if (UserArgList.hasArg (clang::driver::options::OPT_ftime_trace_EQ)) {
272+ std::chrono::duration<double , std::milli> HashTime = Stop - Start;
273+ llvm::dbgs () << " Hashing of " << SourceFile.Path << " took "
274+ << int (HashTime.count ()) << " ms\n " ;
275+ }
276+
277+ return RTCHashResult::success (Hash.c_str ());
278+ }
279+
243280extern " C" KF_EXPORT_SYMBOL RTCResult
244281compileSYCL (InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
245- View<const char *> UserArgs) {
282+ View<const char *> UserArgs, View<char > CachedIR, bool SaveIR) {
283+ llvm::LLVMContext Context;
246284 std::string BuildLog;
285+ configureDiagnostics (Context, BuildLog);
247286
248287 auto UserArgListOrErr = parseUserArgs (UserArgs);
249288 if (!UserArgListOrErr) {
@@ -272,16 +311,43 @@ compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
272311 Verbose);
273312 }
274313
275- auto ModuleOrErr =
276- compileDeviceCode (SourceFile, IncludeFiles, UserArgList, BuildLog);
277- if (!ModuleOrErr) {
278- return errorTo<RTCResult>(ModuleOrErr.takeError (),
279- " Device compilation failed" );
314+ std::unique_ptr<llvm::Module> Module;
315+
316+ if (CachedIR.size () > 0 ) {
317+ llvm::StringRef IRStr{CachedIR.begin (), CachedIR.size ()};
318+ std::unique_ptr<llvm::MemoryBuffer> IRBuf =
319+ llvm::MemoryBuffer::getMemBuffer (IRStr, /* BufferName=*/ " " ,
320+ /* RequiresNullTerminator=*/ false );
321+ auto ModuleOrError = llvm::parseBitcodeFile (*IRBuf, Context);
322+ if (!ModuleOrError) {
323+ // Not a fatal error, we'll just compile the source string normally.
324+ BuildLog.append (formatError (ModuleOrError.takeError (),
325+ " Loading of cached device code failed" ));
326+ } else {
327+ Module = std::move (*ModuleOrError);
328+ }
280329 }
281330
282- std::unique_ptr<llvm::LLVMContext> Context;
283- std::unique_ptr<llvm::Module> Module = std::move (*ModuleOrErr);
284- Context.reset (&Module->getContext ());
331+ bool FromSource = false ;
332+ if (!Module) {
333+ auto ModuleOrErr = compileDeviceCode (SourceFile, IncludeFiles, UserArgList,
334+ BuildLog, Context);
335+ if (!ModuleOrErr) {
336+ return errorTo<RTCResult>(ModuleOrErr.takeError (),
337+ " Device compilation failed" );
338+ }
339+
340+ Module = std::move (*ModuleOrErr);
341+ FromSource = true ;
342+ }
343+
344+ RTCDeviceCodeIR IR;
345+ if (SaveIR && FromSource) {
346+ std::string BCString;
347+ llvm::raw_string_ostream BCStream{BCString};
348+ llvm::WriteBitcodeToFile (*Module, BCStream);
349+ IR = RTCDeviceCodeIR{BCString.data (), BCString.data () + BCString.size ()};
350+ }
285351
286352 if (auto Error = linkDeviceLibraries (*Module, UserArgList, BuildLog)) {
287353 return errorTo<RTCResult>(std::move (Error), " Device linking failed" );
@@ -314,7 +380,7 @@ compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
314380 }
315381 }
316382
317- return RTCResult{std::move (BundleInfo), BuildLog.c_str ()};
383+ return RTCResult{std::move (BundleInfo), std::move (IR), BuildLog.c_str ()};
318384}
319385
320386extern " C" KF_EXPORT_SYMBOL void destroyBinary (BinaryAddress Address) {
0 commit comments