@@ -531,6 +531,19 @@ static bool getDeviceLibraries(const ArgList &Args,
531531 return FoundUnknownLib;
532532}
533533
534+ static Expected<std::unique_ptr<llvm::Module>>
535+ loadBitcodeLibrary (StringRef LibPath, LLVMContext &Context) {
536+ SMDiagnostic Diag;
537+ std::unique_ptr<llvm::Module> Lib = parseIRFile (LibPath, Diag, Context);
538+ if (!Lib) {
539+ std::string DiagMsg;
540+ raw_string_ostream SOS (DiagMsg);
541+ Diag.print (/* ProgName=*/ nullptr , SOS);
542+ return createStringError (DiagMsg);
543+ }
544+ return std::move (Lib);
545+ }
546+
534547Error jit_compiler::linkDeviceLibraries (llvm::Module &Module,
535548 const InputArgList &UserArgList,
536549 std::string &BuildLog) {
@@ -558,16 +571,13 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
558571 for (const std::string &LibName : LibNames) {
559572 std::string LibPath = DPCPPRoot + " /lib/" + LibName;
560573
561- SMDiagnostic Diag;
562- std::unique_ptr<llvm::Module> Lib = parseIRFile (LibPath, Diag, Context);
563- if (!Lib) {
564- std::string DiagMsg;
565- raw_string_ostream SOS (DiagMsg);
566- Diag.print (/* ProgName=*/ nullptr , SOS);
567- return createStringError (DiagMsg);
574+ auto LibOrErr = loadBitcodeLibrary (LibPath, Context);
575+ if (!LibOrErr) {
576+ return LibOrErr.takeError ();
568577 }
569578
570- if (Linker::linkModules (Module, std::move (Lib), Linker::LinkOnlyNeeded)) {
579+ if (Linker::linkModules (Module, std::move (*LibOrErr),
580+ Linker::LinkOnlyNeeded)) {
571581 return createStringError (" Unable to link device library %s: %s" ,
572582 LibPath.c_str (), BuildLog.c_str ());
573583 }
@@ -607,6 +617,31 @@ static IRSplitMode getDeviceCodeSplitMode(const InputArgList &UserArgList) {
607617 return SPLIT_AUTO;
608618}
609619
620+ static void encodeProperties (PropertySetRegistry &Properties,
621+ RTCDevImgInfo &DevImgInfo) {
622+ const auto &PropertySets = Properties.getPropSets ();
623+
624+ DevImgInfo.Properties = FrozenPropertyRegistry{PropertySets.size ()};
625+ for (auto [KV, FrozenPropSet] :
626+ zip_equal (PropertySets, DevImgInfo.Properties )) {
627+ const auto &PropertySetName = KV.first ;
628+ const auto &PropertySet = KV.second ;
629+ FrozenPropSet =
630+ FrozenPropertySet{PropertySetName.str (), PropertySet.size ()};
631+ for (auto [KV2, FrozenProp] :
632+ zip_equal (PropertySet, FrozenPropSet.Values )) {
633+ const auto &PropertyName = KV2.first ;
634+ const auto &PropertyValue = KV2.second ;
635+ FrozenProp = PropertyValue.getType () == PropertyValue::Type::UINT32
636+ ? FrozenPropertyValue{PropertyName.str (),
637+ PropertyValue.asUint32 ()}
638+ : FrozenPropertyValue{
639+ PropertyName.str (), PropertyValue.asRawByteArray (),
640+ PropertyValue.getRawByteArraySize ()};
641+ }
642+ };
643+ }
644+
610645Expected<PostLinkResult>
611646jit_compiler::performPostLink (std::unique_ptr<llvm::Module> Module,
612647 const InputArgList &UserArgList) {
@@ -637,9 +672,9 @@ jit_compiler::performPostLink(std::unique_ptr<llvm::Module> Module,
637672 // Otherwise: Port over the `removeSYCLKernelsConstRefArray` and
638673 // `removeDeviceGlobalFromCompilerUsed` methods.
639674
640- assert (!isModuleUsingAsan (*Module));
641- // Otherwise: Need to instrument each image scope device globals if the module
642- // has been instrumented by sanitizer pass .
675+ assert (!( isModuleUsingAsan (*Module) || isModuleUsingMsan (*Module) ||
676+ isModuleUsingTsan (*Module)));
677+ // Otherwise: Run `SanitizerKernelMetadataPass` .
643678
644679 // Transform Joint Matrix builtin calls to align them with SPIR-V friendly
645680 // LLVM IR specification.
@@ -668,6 +703,7 @@ jit_compiler::performPostLink(std::unique_ptr<llvm::Module> Module,
668703 // `-fno-sycl-device-code-split-esimd` as a prerequisite for compiling
669704 // `invoke_simd` code.
670705
706+ bool IsBF16DeviceLibUsed = false ;
671707 while (Splitter->hasMoreSplits ()) {
672708 ModuleDesc MDesc = Splitter->nextSplit ();
673709
@@ -701,35 +737,58 @@ jit_compiler::performPostLink(std::unique_ptr<llvm::Module> Module,
701737 /* DeviceGlobals=*/ false };
702738 PropertySetRegistry Properties =
703739 computeModuleProperties (MDesc.getModule (), MDesc.entries (), PropReq);
740+
741+ // When the split mode is none, the required work group size will be added
742+ // to the whole module, which will make the runtime unable to launch the
743+ // other kernels in the module that have different required work group
744+ // sizes or no required work group sizes. So we need to remove the
745+ // required work group size metadata in this case.
746+ if (SplitMode == module_split::SPLIT_NONE) {
747+ Properties.remove (PropSetRegTy::SYCL_DEVICE_REQUIREMENTS,
748+ PropSetRegTy::PROPERTY_REQD_WORK_GROUP_SIZE);
749+ }
750+
704751 // TODO: Manually add `compile_target` property as in
705752 // `saveModuleProperties`?
706- const auto &PropertySets = Properties.getPropSets ();
707-
708- DevImgInfo.Properties = FrozenPropertyRegistry{PropertySets.size ()};
709- for (auto [KV, FrozenPropSet] :
710- zip_equal (PropertySets, DevImgInfo.Properties )) {
711- const auto &PropertySetName = KV.first ;
712- const auto &PropertySet = KV.second ;
713- FrozenPropSet =
714- FrozenPropertySet{PropertySetName.str (), PropertySet.size ()};
715- for (auto [KV2, FrozenProp] :
716- zip_equal (PropertySet, FrozenPropSet.Values )) {
717- const auto &PropertyName = KV2.first ;
718- const auto &PropertyValue = KV2.second ;
719- FrozenProp =
720- PropertyValue.getType () == PropertyValue::Type::UINT32
721- ? FrozenPropertyValue{PropertyName.str (),
722- PropertyValue.asUint32 ()}
723- : FrozenPropertyValue{PropertyName.str (),
724- PropertyValue.asRawByteArray (),
725- PropertyValue.getRawByteArraySize ()};
726- }
727- };
728753
754+ encodeProperties (Properties, DevImgInfo);
755+
756+ IsBF16DeviceLibUsed |= isSYCLDeviceLibBF16Used (MDesc.getModule ());
729757 Modules.push_back (MDesc.releaseModulePtr ());
730758 }
731759 }
732760
761+ if (IsBF16DeviceLibUsed) {
762+ const std::string &DPCPPRoot = getDPCPPRoot ();
763+ if (DPCPPRoot == InvalidDPCPPRoot) {
764+ return createStringError (" Could not locate DPCPP root directory" );
765+ }
766+
767+ auto &Ctx = Modules.front ()->getContext ();
768+ auto WrapLibraryInDevImg = [&](const std::string &LibName) -> Error {
769+ std::string LibPath = DPCPPRoot + " /lib/" + LibName;
770+ auto LibOrErr = loadBitcodeLibrary (LibPath, Ctx);
771+ if (!LibOrErr) {
772+ return LibOrErr.takeError ();
773+ }
774+
775+ std::unique_ptr<llvm::Module> LibModule = std::move (*LibOrErr);
776+ PropertySetRegistry Properties =
777+ computeDeviceLibProperties (*LibModule, LibName);
778+ encodeProperties (Properties, DevImgInfoVec.emplace_back ());
779+ Modules.push_back (std::move (LibModule));
780+
781+ return Error::success ();
782+ };
783+
784+ if (auto Err = WrapLibraryInDevImg (" libsycl-fallback-bfloat16.bc" )) {
785+ return std::move (Err);
786+ }
787+ if (auto Err = WrapLibraryInDevImg (" libsycl-native-bfloat16.bc" )) {
788+ return std::move (Err);
789+ }
790+ }
791+
733792 assert (DevImgInfoVec.size () == Modules.size ());
734793 RTCBundleInfo BundleInfo;
735794 BundleInfo.DevImgInfos = DynArray<RTCDevImgInfo>{DevImgInfoVec.size ()};
0 commit comments