@@ -458,7 +458,6 @@ struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
458458 using Base = PtrUseVisitor<ArgUseChecker>;
459459
460460 bool IsGridConstant;
461- SmallPtrSet<Value *, 16 > AllArgUsers;
462461 // Set of phi/select instructions using the Arg
463462 SmallPtrSet<Instruction *, 4 > Conditionals;
464463
@@ -471,13 +470,11 @@ struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
471470 IsOffsetKnown = false ;
472471 Offset = APInt (IntIdxTy->getBitWidth (), 0 );
473472 PI.reset ();
474- AllArgUsers.clear ();
475473 Conditionals.clear ();
476474
477475 LLVM_DEBUG (dbgs () << " Checking Argument " << A << " \n " );
478476 // Enqueue the uses of this pointer.
479477 enqueueUsers (A);
480- AllArgUsers.insert (&A);
481478
482479 // Visit all the uses off the worklist until it is empty.
483480 // Note that unlike PtrUseVisitor we're intentionally do not track offset.
@@ -486,7 +483,6 @@ struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
486483 UseToVisit ToVisit = Worklist.pop_back_val ();
487484 U = ToVisit.UseAndIsOffsetKnown .getPointer ();
488485 Instruction *I = cast<Instruction>(U->getUser ());
489- AllArgUsers.insert (I);
490486 if (isa<PHINode>(I) || isa<SelectInst>(I))
491487 Conditionals.insert (I);
492488 LLVM_DEBUG (dbgs () << " Processing " << *I << " \n " );
@@ -498,8 +494,8 @@ struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
498494 else if (PI.isAborted ())
499495 LLVM_DEBUG (dbgs () << " Pointer use needs a copy: " << *PI.getAbortingInst ()
500496 << " \n " );
501- LLVM_DEBUG (dbgs () << " Traversed " << AllArgUsers .size () << " with "
502- << Conditionals. size () << " conditionals\n " );
497+ LLVM_DEBUG (dbgs () << " Traversed " << Conditionals .size ()
498+ << " conditionals\n " );
503499 return PI;
504500 }
505501
@@ -535,25 +531,17 @@ struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
535531 void visitMemTransferInst (MemTransferInst &II) {
536532 if (*U == II.getRawDest () && !IsGridConstant)
537533 PI.setAborted (&II);
538-
539- // TODO: memcpy from arg is OK as it can get unrolled into ld.param.
540- // However, memcpys are currently expected to be unrolled before we
541- // get here, so we never see them in practice, and we do not currently
542- // handle them when we convert IR to access param space directly. So,
543- // we'll mark it as an escape for now. It would still force a copy on
544- // pre-sm_70 GPUs where we can't take address of a parameter w/o a copy.
545- //
546- // PI.setEscaped(&II);
534+ // memcpy/memmove are OK when the pointer is source. We can convert them to
535+ // AS-specific memcpy.
547536 }
548537
549538 void visitMemSetInst (MemSetInst &II) {
550- if (*U == II. getRawDest () && !IsGridConstant)
539+ if (!IsGridConstant)
551540 PI.setAborted (&II);
552541 }
553- // debug only helper.
554- auto &getVisitedUses () { return VisitedUses; }
555- };
542+ }; // struct ArgUseChecker
556543} // namespace
544+
557545void NVPTXLowerArgs::handleByValParam (const NVPTXTargetMachine &TM,
558546 Argument *Arg) {
559547 Function *Func = Arg->getParent ();
@@ -566,8 +554,9 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
566554
567555 ArgUseChecker AUC (DL, IsGridConstant);
568556 ArgUseChecker::PtrInfo PI = AUC.visitArgPtr (*Arg);
557+ bool ArgUseIsReadOnly = !(PI.isEscaped () || PI.isAborted ());
569558 // Easy case, accessing parameter directly is fine.
570- if (!(PI. isEscaped () || PI. isAborted ()) && AUC.Conditionals .empty ()) {
559+ if (ArgUseIsReadOnly && AUC.Conditionals .empty ()) {
571560 // Convert all loads and intermediate operations to use parameter AS and
572561 // skip creation of a local copy of the argument.
573562 SmallVector<Use *, 16 > UsesToUpdate;
@@ -595,7 +584,7 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
595584 // `__grid_constant__` for the argument, we'll consider escaped pointer as
596585 // read-only.
597586 unsigned AS = DL.getAllocaAddrSpace ();
598- if (HasCvtaParam && (!(PI. isEscaped () || PI. isAborted ()) || IsGridConstant)) {
587+ if (HasCvtaParam && (ArgUseIsReadOnly || IsGridConstant)) {
599588 LLVM_DEBUG (dbgs () << " Using non-copy pointer to " << *Arg << " \n " );
600589 // Replace all argument pointer uses (which might include a device function
601590 // call) with a cast to the generic address space using cvta.param
0 commit comments