@@ -108,15 +108,17 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
108108 // Compute the argument types of the corresponding scalar call and the scalar
109109 // function name. For calls, it additionally finds the function to replace
110110 // and checks that all vector operands match the previously found EC.
111- SmallVector<Type *, 8 > ScalarArgTypes;
111+ SmallVector<Type *, 8 > ScalarArgTypes, OrigArgTypes ;
112112 std::string ScalarName;
113113 Function *FuncToReplace = nullptr ;
114- if (auto *CI = dyn_cast<CallInst>(&I)) {
114+ auto *CI = dyn_cast<CallInst>(&I);
115+ if (CI) {
115116 FuncToReplace = CI->getCalledFunction ();
116117 Intrinsic::ID IID = FuncToReplace->getIntrinsicID ();
117118 assert (IID != Intrinsic::not_intrinsic && " Not an intrinsic" );
118119 for (auto Arg : enumerate(CI->args ())) {
119120 auto *ArgTy = Arg.value ()->getType ();
121+ OrigArgTypes.push_back (ArgTy);
120122 if (isVectorIntrinsicWithScalarOpAtArg (IID, Arg.index ())) {
121123 ScalarArgTypes.push_back (ArgTy);
122124 } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
@@ -174,6 +176,24 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
174176
175177 Function *TLIFunc = getTLIFunction (I.getModule (), VectorFTy,
176178 VD->getVectorFnName (), FuncToReplace);
179+
180+ // For calls, bail out when their arguments do not match with the TLI mapping.
181+ if (CI) {
182+ int IdxNonPred = 0 ;
183+ for (auto [OrigTy, VFParam] :
184+ zip (OrigArgTypes, OptInfo->Shape .Parameters )) {
185+ if (VFParam.ParamKind == VFParamKind::GlobalPredicate)
186+ continue ;
187+ ++IdxNonPred;
188+ if (OrigTy->isVectorTy () != (VFParam.ParamKind == VFParamKind::Vector)) {
189+ LLVM_DEBUG (dbgs () << DEBUG_TYPE
190+ << " : Will not replace: wrong type at index: "
191+ << IdxNonPred << " : " << *OrigTy << " \n " );
192+ return false ;
193+ }
194+ }
195+ }
196+
177197 replaceWithTLIFunction (I, *OptInfo, TLIFunc);
178198 LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Replaced call to `" << ScalarName
179199 << " ` with call to `" << TLIFunc->getName () << " `.\n " );
0 commit comments