@@ -777,9 +777,11 @@ class MemorySanitizerOnSpirv {
777777 IsSPIRV = TargetTriple.isSPIROrSPIRV ();
778778
779779 IntptrTy = DL.getIntPtrType (C);
780+ Int32Ty = Type::getInt32Ty (C);
780781 }
781782
782783 bool instrumentModule ();
784+ void instrumentFunction (Function &F);
783785
784786 Constant *getOrCreateGlobalString (StringRef Name, StringRef Value,
785787 unsigned AddressSpace);
@@ -788,6 +790,7 @@ class MemorySanitizerOnSpirv {
788790 void initializeCallbacks ();
789791 void instrumentGlobalVariables ();
790792 void instrumentStaticLocalMemory ();
793+ void instrumentDynamicLocalMemory (Function &F);
791794 void instrumentKernelsMetadata ();
792795
793796 void initializeRetVecMap (Function *F);
@@ -799,15 +802,25 @@ class MemorySanitizerOnSpirv {
799802 const DataLayout &DL;
800803 bool IsSPIRV;
801804 Type *IntptrTy;
805+ Type *Int32Ty;
802806
803807 StringMap<GlobalVariable *> GlobalStringMap;
804808
805809 DenseMap<Function *, SmallVector<Instruction *, 8 >> KernelToRetVecMap;
806810 DenseMap<Function *, SmallVector<Constant *, 8 >> KernelToLocalMemMap;
807811 DenseMap<Function *, DenseSet<Function *>> FuncToKernelCallerMap;
808812
813+ // Make sure that we insert barriers only once per function, and the barrier
814+ // needs to be inserted after all "MsanPoisonShadowStaticLocalFunc" and
815+ // "MsanPoisonShadowDynamicLocalFunc", and before
816+ // "MsanUnpoisonShadowStaticLocalFunc" and
817+ // "MsanUnpoisonShadowDynamicLocalFunc".
818+ DenseMap<Function *, bool > InsertBarrier;
819+
809820 FunctionCallee MsanPoisonShadowStaticLocalFunc;
810821 FunctionCallee MsanUnpoisonShadowStaticLocalFunc;
822+ FunctionCallee MsanPoisonShadowDynamicLocalFunc;
823+ FunctionCallee MsanUnpoisonShadowDynamicLocalFunc;
811824 FunctionCallee MsanBarrierFunc;
812825};
813826
@@ -874,6 +887,21 @@ void MemorySanitizerOnSpirv::initializeCallbacks() {
874887 M.getOrInsertFunction (" __msan_unpoison_shadow_static_local" ,
875888 IRB.getVoidTy (), IntptrTy, IntptrTy);
876889
890+ // __asan_poison_shadow_dynamic_local(
891+ // uptr ptr,
892+ // uint32_t num_args
893+ // )
894+ MsanPoisonShadowDynamicLocalFunc = M.getOrInsertFunction (
895+ " __msan_poison_shadow_dynamic_local" , IRB.getVoidTy (), IntptrTy, Int32Ty);
896+
897+ // __asan_unpoison_shadow_dynamic_local(
898+ // uptr ptr,
899+ // uint32_t num_args
900+ // )
901+ MsanUnpoisonShadowDynamicLocalFunc =
902+ M.getOrInsertFunction (" __msan_unpoison_shadow_dynamic_local" ,
903+ IRB.getVoidTy (), IntptrTy, Int32Ty);
904+
877905 // __msan_barrier()
878906 MsanBarrierFunc = M.getOrInsertFunction (" __msan_barrier" , IRB.getVoidTy ());
879907}
@@ -951,16 +979,15 @@ void MemorySanitizerOnSpirv::instrumentStaticLocalMemory() {
951979 if (!ClSpirOffloadLocals)
952980 return ;
953981
954- DenseMap<Function *, bool > InsertBarrier;
955-
956- auto Instrument = [this , &InsertBarrier](GlobalVariable *G, Function *F) {
982+ auto Instrument = [this ](GlobalVariable *G, Function *F) {
957983 const uint64_t SizeInBytes = DL.getTypeAllocSize (G->getValueType ());
958984
959- // Poison shadow of static local memory
960985 if (!InsertBarrier[F]) {
961986 IRBuilder<> Builder (&F->getEntryBlock ().front ());
962987 Builder.CreateCall (MsanBarrierFunc);
963988 }
989+
990+ // Poison shadow of static local memory
964991 IRBuilder<> Builder (&F->getEntryBlock ().front ());
965992 Builder.CreateCall (MsanPoisonShadowStaticLocalFunc,
966993 {Builder.CreatePointerCast (G, IntptrTy),
@@ -1001,6 +1028,54 @@ void MemorySanitizerOnSpirv::instrumentStaticLocalMemory() {
10011028 }
10021029}
10031030
1031+ void MemorySanitizerOnSpirv::instrumentDynamicLocalMemory (Function &F) {
1032+ if (!ClSpirOffloadLocals)
1033+ return ;
1034+
1035+ // Poison shadow of local memory in kernel argument, required by CPU device
1036+ SmallVector<Argument *> LocalArgs;
1037+ for (auto &Arg : F.args ()) {
1038+ Type *PtrTy = dyn_cast<PointerType>(Arg.getType ()->getScalarType ());
1039+ if (PtrTy && PtrTy->getPointerAddressSpace () == kSpirOffloadLocalAS )
1040+ LocalArgs.push_back (&Arg);
1041+ }
1042+
1043+ if (LocalArgs.empty ())
1044+ return ;
1045+
1046+ if (!InsertBarrier[&F]) {
1047+ IRBuilder<> Builder (&F.getEntryBlock ().front ());
1048+ Builder.CreateCall (MsanBarrierFunc);
1049+ }
1050+
1051+ IRBuilder<> IRB (&F.getEntryBlock ().front ());
1052+
1053+ AllocaInst *ArgsArray = IRB.CreateAlloca (
1054+ IntptrTy, ConstantInt::get (Int32Ty, LocalArgs.size ()), " local_args" );
1055+ for (size_t i = 0 ; i < LocalArgs.size (); i++) {
1056+ auto *StoreDest =
1057+ IRB.CreateGEP (IntptrTy, ArgsArray, ConstantInt::get (Int32Ty, i));
1058+ IRB.CreateStore (IRB.CreatePointerCast (LocalArgs[i], IntptrTy), StoreDest);
1059+ }
1060+
1061+ auto *ArgsArrayAddr = IRB.CreatePointerCast (ArgsArray, IntptrTy);
1062+ IRB.CreateCall (MsanPoisonShadowDynamicLocalFunc,
1063+ {ArgsArrayAddr, ConstantInt::get (Int32Ty, LocalArgs.size ())});
1064+
1065+ // Unpoison shadow of dynamic local memory, required by CPU device
1066+ initializeRetVecMap (&F);
1067+ for (Instruction *Ret : KernelToRetVecMap[&F]) {
1068+ IRBuilder<> IRBRet (Ret);
1069+ if (!InsertBarrier[&F])
1070+ IRBRet.CreateCall (MsanBarrierFunc);
1071+ IRBRet.CreateCall (
1072+ MsanUnpoisonShadowDynamicLocalFunc,
1073+ {ArgsArrayAddr, ConstantInt::get (Int32Ty, LocalArgs.size ())});
1074+ }
1075+
1076+ InsertBarrier[&F] = true ;
1077+ }
1078+
10041079// Instrument __MsanKernelMetadata, which records information of sanitized
10051080// kernel
10061081void MemorySanitizerOnSpirv::instrumentKernelsMetadata () {
@@ -1087,6 +1162,14 @@ bool MemorySanitizerOnSpirv::instrumentModule() {
10871162 return true ;
10881163}
10891164
1165+ void MemorySanitizerOnSpirv::instrumentFunction (Function &F) {
1166+ if (!IsSPIRV)
1167+ return ;
1168+
1169+ if (F.getCallingConv () == CallingConv::SPIR_KERNEL)
1170+ instrumentDynamicLocalMemory (F);
1171+ }
1172+
10901173PreservedAnalyses MemorySanitizerPass::run (Module &M,
10911174 ModuleAnalysisManager &AM) {
10921175 // Return early if nosanitize_memory module flag is present for the module.
@@ -1110,6 +1193,7 @@ PreservedAnalyses MemorySanitizerPass::run(Module &M,
11101193 MemorySanitizer Msan (*F.getParent (), MsanSpirv, Options);
11111194 Modified |=
11121195 Msan.sanitizeFunction (F, FAM.getResult <TargetLibraryAnalysis>(F));
1196+ MsanSpirv.instrumentFunction (F);
11131197 }
11141198
11151199 if (!Modified)
0 commit comments