139139#include " NVPTX.h"
140140#include " NVPTXTargetMachine.h"
141141#include " NVPTXUtilities.h"
142+ #include " llvm/ADT/STLExtras.h"
143+ #include " llvm/Analysis/PtrUseVisitor.h"
142144#include " llvm/Analysis/ValueTracking.h"
143145#include " llvm/CodeGen/TargetPassConfig.h"
144146#include " llvm/IR/Function.h"
145147#include " llvm/IR/IRBuilder.h"
146148#include " llvm/IR/Instructions.h"
149+ #include " llvm/IR/IntrinsicInst.h"
147150#include " llvm/IR/IntrinsicsNVPTX.h"
148151#include " llvm/IR/Module.h"
149152#include " llvm/IR/Type.h"
150153#include " llvm/InitializePasses.h"
151154#include " llvm/Pass.h"
155+ #include " llvm/Support/Debug.h"
156+ #include " llvm/Support/ErrorHandling.h"
152157#include < numeric>
153158#include < queue>
154159
@@ -217,7 +222,8 @@ INITIALIZE_PASS_END(NVPTXLowerArgs, "nvptx-lower-args",
217222// pointer in parameter AS.
218223// For "escapes" (to memory, a function call, or a ptrtoint), cast the OldUse to
219224// generic using cvta.param.
220- static void convertToParamAS(Use *OldUse, Value *Param, bool GridConstant) {
225+ static void convertToParamAS(Use *OldUse, Value *Param, bool HasCvtaParam,
226+ bool IsGridConstant) {
221227 Instruction *I = dyn_cast<Instruction>(OldUse->getUser ());
222228 assert (I && " OldUse must be in an instruction" );
223229 struct IP {
@@ -228,7 +234,8 @@ static void convertToParamAS(Use *OldUse, Value *Param, bool GridConstant) {
228234 SmallVector<IP> ItemsToConvert = {{OldUse, I, Param}};
229235 SmallVector<Instruction *> InstructionsToDelete;
230236
231- auto CloneInstInParamAS = [GridConstant](const IP &I) -> Value * {
237+ auto CloneInstInParamAS = [HasCvtaParam,
238+ IsGridConstant](const IP &I) -> Value * {
232239 if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction )) {
233240 LI->setOperand (0 , I.NewParam );
234241 return LI;
@@ -252,8 +259,25 @@ static void convertToParamAS(Use *OldUse, Value *Param, bool GridConstant) {
252259 // Just pass through the argument, the old ASC is no longer needed.
253260 return I.NewParam ;
254261 }
262+ if (auto *MI = dyn_cast<MemTransferInst>(I.OldInstruction )) {
263+ if (MI->getRawSource () == I.OldUse ->get ()) {
264+ // convert to memcpy/memmove from param space.
265+ IRBuilder<> Builder (I.OldInstruction );
266+ Intrinsic::ID ID = MI->getIntrinsicID ();
267+
268+ CallInst *B = Builder.CreateMemTransferInst (
269+ ID, MI->getRawDest (), MI->getDestAlign (), I.NewParam ,
270+ MI->getSourceAlign (), MI->getLength (), MI->isVolatile ());
271+ for (unsigned I : {0 , 1 })
272+ if (uint64_t Bytes = MI->getParamDereferenceableBytes (I))
273+ B->addDereferenceableParamAttr (I, Bytes);
274+ return B;
275+ }
276+ // We may be able to handle other cases if the argument is
277+ // __grid_constant__
278+ }
255279
256- if (GridConstant ) {
280+ if (HasCvtaParam ) {
257281 auto GetParamAddrCastToGeneric =
258282 [](Value *Addr, Instruction *OriginalUser) -> Value * {
259283 PointerType *ReturnTy =
@@ -269,24 +293,44 @@ static void convertToParamAS(Use *OldUse, Value *Param, bool GridConstant) {
269293 OriginalUser->getIterator ());
270294 return CvtToGenCall;
271295 };
272-
273- if (auto *CI = dyn_cast<CallInst>(I.OldInstruction )) {
274- I.OldUse ->set (GetParamAddrCastToGeneric (I.NewParam , CI));
275- return CI;
296+ auto *ParamInGenericAS =
297+ GetParamAddrCastToGeneric (I.NewParam , I.OldInstruction );
298+
299+ // phi/select could use generic arg pointers w/o __grid_constant__
300+ if (auto *PHI = dyn_cast<PHINode>(I.OldInstruction )) {
301+ for (auto [Idx, V] : enumerate(PHI->incoming_values ())) {
302+ if (V.get () == I.OldUse ->get ())
303+ PHI->setIncomingValue (Idx, ParamInGenericAS);
304+ }
276305 }
277- if (auto *SI = dyn_cast<StoreInst >(I.OldInstruction )) {
278- // byval address is being stored, cast it to generic
279- if ( SI->getValueOperand () == I. OldUse -> get ())
280- SI->setOperand ( 0 , GetParamAddrCastToGeneric (I. NewParam , SI));
281- return SI ;
306+ if (auto *SI = dyn_cast<SelectInst >(I.OldInstruction )) {
307+ if (SI-> getTrueValue () == I. OldUse -> get ())
308+ SI->setTrueValue (ParamInGenericAS);
309+ if ( SI->getFalseValue () == I. OldUse -> get ())
310+ SI-> setFalseValue (ParamInGenericAS) ;
282311 }
283- if (auto *PI = dyn_cast<PtrToIntInst>(I.OldInstruction )) {
284- if (PI->getPointerOperand () == I.OldUse ->get ())
285- PI->setOperand (0 , GetParamAddrCastToGeneric (I.NewParam , PI));
286- return PI;
312+
313+ // Escapes or writes can only use generic param pointers if
314+ // __grid_constant__ is in effect.
315+ if (IsGridConstant) {
316+ if (auto *CI = dyn_cast<CallInst>(I.OldInstruction )) {
317+ I.OldUse ->set (ParamInGenericAS);
318+ return CI;
319+ }
320+ if (auto *SI = dyn_cast<StoreInst>(I.OldInstruction )) {
321+ // byval address is being stored, cast it to generic
322+ if (SI->getValueOperand () == I.OldUse ->get ())
323+ SI->setOperand (0 , ParamInGenericAS);
324+ return SI;
325+ }
326+ if (auto *PI = dyn_cast<PtrToIntInst>(I.OldInstruction )) {
327+ if (PI->getPointerOperand () == I.OldUse ->get ())
328+ PI->setOperand (0 , ParamInGenericAS);
329+ return PI;
330+ }
331+ // TODO: If we allow stores, we should allow memcpy/memset to
332+ // parameter, too.
287333 }
288- llvm_unreachable (
289- " Instruction unsupported even for grid_constant argument" );
290334 }
291335
292336 llvm_unreachable (" Unsupported instruction" );
@@ -409,49 +453,110 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
409453 }
410454}
411455
456+ namespace {
     // Use walker that classifies how a pointer argument is used. It visits the
     // uses placed on the worklist and reports, through the returned PtrInfo,
     // whether the pointer escaped or whether some use aborted the walk (logged
     // below as "Pointer use needs a copy"). PHI/select users encountered on
     // the way are collected in `Conditionals`.
457+ struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
458+ using Base = PtrUseVisitor<ArgUseChecker>;
459+
     // When set, stores through the pointer, memset/memcpy into it and
     // ptrtoint are not treated as aborting uses by the visitors below.
460+ bool IsGridConstant;
461+ // Set of phi/select instructions using the Arg (directly or via a visited
     // user). Cleared at the start of every visitArgPtr() call.
462+ SmallPtrSet<Instruction *, 4 > Conditionals;
463+
464+ ArgUseChecker (const DataLayout &DL, bool IsGridConstant)
465+ : PtrUseVisitor(DL), IsGridConstant(IsGridConstant) {}
466+
     // Entry point: resets the visitor state, enqueues the users of `A` and
     // drains the worklist, stopping early once PI is marked aborted.
     // `A` must have pointer type (asserted). Returns the accumulated PtrInfo.
467+ PtrInfo visitArgPtr (Argument &A) {
468+ assert (A.getType ()->isPointerTy ());
469+ IntegerType *IntIdxTy = cast<IntegerType>(DL.getIndexType (A.getType ()));
     // Start from a clean state: unknown offset, a zero Offset of the index
     // type's width, fresh PtrInfo and an empty set of conditionals.
470+ IsOffsetKnown = false ;
471+ Offset = APInt (IntIdxTy->getBitWidth (), 0 );
472+ PI.reset ();
473+ Conditionals.clear ();
474+
475+ LLVM_DEBUG (dbgs () << " Checking Argument " << A << " \n " );
476+ // Enqueue the uses of this pointer.
477+ enqueueUsers (A);
478+
479+ // Visit all the uses off the worklist until it is empty.
480+ // Note that unlike PtrUseVisitor we intentionally do not track offsets.
481+ // We're only interested in how we use the pointer.
482+ while (!(Worklist.empty () || PI.isAborted ())) {
483+ UseToVisit ToVisit = Worklist.pop_back_val ();
484+ U = ToVisit.UseAndIsOffsetKnown .getPointer ();
485+ Instruction *I = cast<Instruction>(U->getUser ());
     // Remember every phi/select that consumes the pointer; the count is
     // reported in the debug output below.
486+ if (isa<PHINode>(I) || isa<SelectInst>(I))
487+ Conditionals.insert (I);
488+ LLVM_DEBUG (dbgs () << " Processing " << *I << " \n " );
489+ Base::visit (I);
490+ }
     // Debug-only summary of the outcome; it does not alter PI.
491+ if (PI.isEscaped ())
492+ LLVM_DEBUG (dbgs () << " Argument pointer escaped: " << *PI.getEscapingInst ()
493+ << " \n " );
494+ else if (PI.isAborted ())
495+ LLVM_DEBUG (dbgs () << " Pointer use needs a copy: " << *PI.getAbortingInst ()
496+ << " \n " );
497+ LLVM_DEBUG (dbgs () << " Traversed " << Conditionals.size ()
498+ << " conditionals\n " );
499+ return PI;
500+ }
501+
     // Stores: the pointer used as the stored value is an escape; the pointer
     // used as the store address aborts only when not grid-constant.
502+ void visitStoreInst (StoreInst &SI) {
503+ // Storing the pointer escapes it.
504+ if (U->get () == SI.getValueOperand ())
505+ return PI.setEscapedAndAborted (&SI);
506+ // Writes to the pointer are UB w/ __grid_constant__, but do not force a
507+ // copy.
508+ if (!IsGridConstant)
509+ return PI.setAborted (&SI);
510+ }
511+
     // Address space casts: anything other than a cast to the param address
     // space is recorded as an escape that aborts the walk.
512+ void visitAddrSpaceCastInst (AddrSpaceCastInst &ASC) {
513+ // ASC to param space are no-ops and do not need a copy
514+ if (ASC.getDestAddressSpace () != ADDRESS_SPACE_PARAM)
515+ return PI.setEscapedAndAborted (&ASC);
516+ Base::visitAddrSpaceCastInst (ASC);
517+ }
518+
     // ptrtoint is accepted without further action for grid-constant
     // arguments; otherwise the base visitor decides.
     // NOTE(review): the base handling presumably marks the pointer as
     // escaped -- confirm against PtrUseVisitor.
519+ void visitPtrToIntInst (PtrToIntInst &I) {
520+ if (IsGridConstant)
521+ return ;
522+ Base::visitPtrToIntInst (I);
523+ }
     // Only asserts the instruction kind; it takes no other action here.
524+ void visitPHINodeOrSelectInst (Instruction &I) {
525+ assert (isa<PHINode>(I) || isa<SelectInst>(I));
526+ }
527+ // PHI and select just pass through the pointers, so their users are
     // enqueued and visited like users of the argument itself.
528+ void visitPHINode (PHINode &PN) { enqueueUsers (PN); }
529+ void visitSelectInst (SelectInst &SI) { enqueueUsers (SI); }
530+
     // memcpy/memmove: aborts only when the visited pointer is the raw
     // destination and the argument is not grid-constant.
531+ void visitMemTransferInst (MemTransferInst &II) {
532+ if (*U == II.getRawDest () && !IsGridConstant)
533+ PI.setAborted (&II);
534+ // memcpy/memmove are OK when the pointer is source. We can convert them to
535+ // AS-specific memcpy.
536+ }
537+
     // memset: aborts the walk unless the argument is grid-constant.
538+ void visitMemSetInst (MemSetInst &II) {
539+ if (!IsGridConstant)
540+ PI.setAborted (&II);
541+ }
542+ }; // struct ArgUseChecker
543+ } // namespace
544+
412545void NVPTXLowerArgs::handleByValParam (const NVPTXTargetMachine &TM,
413546 Argument *Arg) {
414- bool IsGridConstant = isParamGridConstant (*Arg);
415547 Function *Func = Arg->getParent ();
548+ bool HasCvtaParam = TM.getSubtargetImpl (*Func)->hasCvtaParam ();
549+ bool IsGridConstant = HasCvtaParam && isParamGridConstant (*Arg);
550+ const DataLayout &DL = Func->getDataLayout ();
416551 BasicBlock::iterator FirstInst = Func->getEntryBlock ().begin ();
417552 Type *StructType = Arg->getParamByValType ();
418553 assert (StructType && " Missing byval type" );
419554
420- auto AreSupportedUsers = [&](Value *Start) {
421- SmallVector<Value *, 16 > ValuesToCheck = {Start};
422- auto IsSupportedUse = [IsGridConstant](Value *V) -> bool {
423- if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
424- return true ;
425- // ASC to param space are OK, too -- we'll just strip them.
426- if (auto *ASC = dyn_cast<AddrSpaceCastInst>(V)) {
427- if (ASC->getDestAddressSpace () == ADDRESS_SPACE_PARAM)
428- return true ;
429- }
430- // Simple calls and stores are supported for grid_constants
431- // writes to these pointers are undefined behaviour
432- if (IsGridConstant &&
433- (isa<CallInst>(V) || isa<StoreInst>(V) || isa<PtrToIntInst>(V)))
434- return true ;
435- return false ;
436- };
437-
438- while (!ValuesToCheck.empty ()) {
439- Value *V = ValuesToCheck.pop_back_val ();
440- if (!IsSupportedUse (V)) {
441- LLVM_DEBUG (dbgs () << " Need a "
442- << (isParamGridConstant (*Arg) ? " cast " : " copy " )
443- << " of " << *Arg << " because of " << *V << " \n " );
444- (void )Arg;
445- return false ;
446- }
447- if (!isa<LoadInst>(V) && !isa<CallInst>(V) && !isa<StoreInst>(V) &&
448- !isa<PtrToIntInst>(V))
449- llvm::append_range (ValuesToCheck, V->users ());
450- }
451- return true ;
452- };
453-
454- if (llvm::all_of (Arg->users (), AreSupportedUsers)) {
555+ ArgUseChecker AUC (DL, IsGridConstant);
556+ ArgUseChecker::PtrInfo PI = AUC.visitArgPtr (*Arg);
557+ bool ArgUseIsReadOnly = !(PI.isEscaped () || PI.isAborted ());
558+ // Easy case, accessing parameter directly is fine.
559+ if (ArgUseIsReadOnly && AUC.Conditionals .empty ()) {
455560 // Convert all loads and intermediate operations to use parameter AS and
456561 // skip creation of a local copy of the argument.
457562 SmallVector<Use *, 16 > UsesToUpdate;
@@ -462,7 +567,7 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
462567 Arg, PointerType::get (StructType, ADDRESS_SPACE_PARAM), Arg->getName (),
463568 FirstInst);
464569 for (Use *U : UsesToUpdate)
465- convertToParamAS (U, ArgInParamAS, IsGridConstant);
570+ convertToParamAS (U, ArgInParamAS, HasCvtaParam, IsGridConstant);
466571 LLVM_DEBUG (dbgs () << " No need to copy or cast " << *Arg << " \n " );
467572
468573 const auto *TLI =
@@ -473,13 +578,17 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
473578 return ;
474579 }
475580
476- const DataLayout &DL = Func->getDataLayout ();
581+ // We can't access the byval arg directly and need a pointer. On sm_70+ we
582+ // have the ability to take a pointer to the argument without making a local
583+ // copy. However, we're still not allowed to write to it. If the user specified
584+ // `__grid_constant__` for the argument, we'll consider the escaped pointer as
585+ // read-only.
477586 unsigned AS = DL.getAllocaAddrSpace ();
478- if (isParamGridConstant (*Arg )) {
479- // Writes to a grid constant are undefined behaviour. We do not need a
480- // temporary copy. When a pointer might have escaped, conservatively replace
481- // all of its uses (which might include a device function call) with a cast
482- // to the generic address space .
587+ if (HasCvtaParam && (ArgUseIsReadOnly || IsGridConstant )) {
588+ LLVM_DEBUG ( dbgs () << " Using non-copy pointer to " << *Arg << " \n " );
589+ // Replace all argument pointer uses (which might include a device function
590+ // call) with a cast to the generic address space using cvta.param
591+ // instruction, which avoids a local copy .
483592 IRBuilder<> IRB (&Func->getEntryBlock ().front ());
484593
485594 // Cast argument to param address space
@@ -500,6 +609,7 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
500609 // Do not replace Arg in the cast to param space
501610 CastToParam->setOperand (0 , Arg);
502611 } else {
612+ LLVM_DEBUG (dbgs () << " Creating a local copy of " << *Arg << " \n " );
503613 // Otherwise we have to create a temporary copy.
504614 AllocaInst *AllocA =
505615 new AllocaInst (StructType, AS, Arg->getName (), FirstInst);
0 commit comments