diff --git a/llvm/lib/Target/AMDGPU/GCNRegPressure.cpp b/llvm/lib/Target/AMDGPU/GCNRegPressure.cpp index 5c394e6d6296d..709ff0c3b4e42 100644 --- a/llvm/lib/Target/AMDGPU/GCNRegPressure.cpp +++ b/llvm/lib/Target/AMDGPU/GCNRegPressure.cpp @@ -288,6 +288,102 @@ collectVirtualRegUses(SmallVectorImpl &RegMaskPairs, } } +/// Mostly copy/paste from CodeGen/RegisterPressure.cpp +static LaneBitmask getRegLanes(ArrayRef RegUnits, + Register RegUnit) { + auto I = llvm::find_if(RegUnits, [RegUnit](const RegisterMaskPair Other) { + return Other.RegUnit == RegUnit; + }); + if (I == RegUnits.end()) + return LaneBitmask::getNone(); + return I->LaneMask; +} + +/// Mostly copy/paste from CodeGen/RegisterPressure.cpp +static LaneBitmask getLanesWithProperty( + const LiveIntervals &LIS, const MachineRegisterInfo &MRI, + bool TrackLaneMasks, Register RegUnit, SlotIndex Pos, + LaneBitmask SafeDefault, + function_ref Property) { + if (RegUnit.isVirtual()) { + const LiveInterval &LI = LIS.getInterval(RegUnit); + LaneBitmask Result; + if (TrackLaneMasks && LI.hasSubRanges()) { + for (const LiveInterval::SubRange &SR : LI.subranges()) { + if (Property(SR, Pos)) + Result |= SR.LaneMask; + } + } else if (Property(LI, Pos)) { + Result = TrackLaneMasks ? MRI.getMaxLaneMaskForVReg(RegUnit) + : LaneBitmask::getAll(); + } + + return Result; + } + + const LiveRange *LR = LIS.getCachedRegUnit(RegUnit); + if (LR == nullptr) + return SafeDefault; + return Property(*LR, Pos) ? LaneBitmask::getAll() : LaneBitmask::getNone(); +} + +/// Mostly copy/paste from CodeGen/RegisterPressure.cpp +/// Helper to find a vreg use between two indices [PriorUseIdx, NextUseIdx). +/// The query starts with a lane bitmask which gets lanes/bits removed for every +/// use we find. 
+static LaneBitmask findUseBetween(unsigned Reg, LaneBitmask LastUseMask, + SlotIndex PriorUseIdx, SlotIndex NextUseIdx, + const MachineRegisterInfo &MRI, + const SIRegisterInfo *TRI, + const LiveIntervals *LIS, + bool Upward = false) { + for (const MachineOperand &MO : MRI.use_nodbg_operands(Reg)) { + if (MO.isUndef()) + continue; + const MachineInstr *MI = MO.getParent(); + SlotIndex InstSlot = LIS->getInstructionIndex(*MI).getRegSlot(); + bool InRange = Upward ? (InstSlot > PriorUseIdx && InstSlot <= NextUseIdx) + : (InstSlot >= PriorUseIdx && InstSlot < NextUseIdx); + if (InRange) { + unsigned SubRegIdx = MO.getSubReg(); + LaneBitmask UseMask = TRI->getSubRegIndexLaneMask(SubRegIdx); + LastUseMask &= ~UseMask; + if (LastUseMask.none()) + return LaneBitmask::getNone(); + } + } + return LastUseMask; +} + +/// Mostly copy/paste from CodeGen/RegisterPressure.cpp +static LaneBitmask getLiveLanesAt(const LiveIntervals &LIS, + const MachineRegisterInfo &MRI, + bool TrackLaneMasks, Register RegUnit, + SlotIndex Pos) { + return getLanesWithProperty( + LIS, MRI, TrackLaneMasks, RegUnit, Pos, LaneBitmask::getAll(), + [](const LiveRange &LR, SlotIndex Pos) { return LR.liveAt(Pos); }); +} + +// Copy/paste from RegisterPressure.cpp (RegisterOperands::adjustLaneLiveness) +static void adjustDefLaneLiveness(SmallVectorImpl &Defs, + SlotIndex &Pos, const LiveIntervals &LIS, + const MachineRegisterInfo &MRI) { + for (auto *I = Defs.begin(); I != Defs.end();) { + LaneBitmask LiveAfter = + getLiveLanesAt(LIS, MRI, true, I->RegUnit, Pos.getDeadSlot()); + // If the def is all that is live after the instruction, then in case + // of a subregister def we need a read-undef flag. 
+ LaneBitmask ActualDef = I->LaneMask & LiveAfter; + if (ActualDef.none()) { + I = Defs.erase(I); + } else { + I->LaneMask = ActualDef; + ++I; + } + } +} + /////////////////////////////////////////////////////////////////////////////// // GCNRPTracker @@ -343,17 +439,41 @@ void GCNRPTracker::reset(const MachineInstr &MI, MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs); } -//////////////////////////////////////////////////////////////////////////////// -// GCNUpwardRPTracker - -void GCNUpwardRPTracker::reset(const MachineRegisterInfo &MRI_, - const LiveRegSet &LiveRegs_) { +void GCNRPTracker::reset(const MachineRegisterInfo &MRI_, + const LiveRegSet &LiveRegs_) { MRI = &MRI_; LiveRegs = LiveRegs_; LastTrackedMI = nullptr; MaxPressure = CurPressure = getRegPressure(MRI_, LiveRegs_); } +void GCNRPTracker::bumpDeadDefs(ArrayRef DeadDefs) { + GCNRegPressure TempPressure = CurPressure; + for (const RegisterMaskPair &P : DeadDefs) { + Register Reg = P.RegUnit; + if (!Reg.isVirtual()) + continue; + LaneBitmask LiveMask = LiveRegs[Reg]; + LaneBitmask BumpedMask = LiveMask | P.LaneMask; + CurPressure.inc(Reg, LiveMask, BumpedMask, *MRI); + } + MaxPressure = max(MaxPressure, CurPressure); + CurPressure = TempPressure; +} +/// Mostly copy/paste from CodeGen/RegisterPressure.cpp +LaneBitmask GCNRPTracker::getLastUsedLanes(Register RegUnit, + SlotIndex Pos) const { + return getLanesWithProperty( + LIS, *MRI, true, RegUnit, Pos.getBaseIndex(), LaneBitmask::getNone(), + [](const LiveRange &LR, SlotIndex Pos) { + const LiveRange::Segment *S = LR.getSegmentContaining(Pos); + return S != nullptr && S->end == Pos.getRegSlot(); + }); +} + +//////////////////////////////////////////////////////////////////////////////// +// GCNUpwardRPTracker + void GCNUpwardRPTracker::recede(const MachineInstr &MI) { assert(MRI && "call reset first"); @@ -414,6 +534,49 @@ void GCNUpwardRPTracker::recede(const MachineInstr &MI) { assert(CurPressure == getRegPressure(*MRI, LiveRegs)); } +void 
GCNUpwardRPTracker::bumpUpwardPressure(const MachineInstr *MI, + const SIRegisterInfo *TRI) { + assert(!MI->isDebugOrPseudoInstr() && "Expect a nondebug instruction."); + + SlotIndex SlotIdx = LIS.getInstructionIndex(*MI).getRegSlot(); + + // Account for register pressure similar to RegPressureTracker::recede(). + RegisterOperands RegOpers; + + RegOpers.collect(*MI, *TRI, *MRI, true, /*IgnoreDead=*/true); + assert(RegOpers.DeadDefs.empty()); + adjustDefLaneLiveness(RegOpers.Defs, SlotIdx, LIS, *MRI); + RegOpers.detectDeadDefs(*MI, LIS); + + // Boost max pressure for all dead defs together. + // bumpDeadDefs restores CurPressure afterwards, so only MaxPressure changes. + bumpDeadDefs(RegOpers.DeadDefs); + + // Kill liveness at live defs. + for (const RegisterMaskPair &P : RegOpers.Defs) { + Register Reg = P.RegUnit; + if (!Reg.isVirtual()) + continue; + LaneBitmask LiveAfter = LiveRegs[Reg]; + LaneBitmask UseLanes = getRegLanes(RegOpers.Uses, Reg); + LaneBitmask DefLanes = P.LaneMask; + LaneBitmask LiveBefore = (LiveAfter & ~DefLanes) | UseLanes; + + CurPressure.inc(Reg, LiveAfter, LiveAfter & LiveBefore, *MRI); + MaxPressure = max(MaxPressure, CurPressure); + } + // Generate liveness for uses. 
+ for (const RegisterMaskPair &P : RegOpers.Uses) { + Register Reg = P.RegUnit; + if (!Reg.isVirtual()) + continue; + LaneBitmask LiveAfter = LiveRegs[Reg]; + LaneBitmask LiveBefore = LiveAfter | P.LaneMask; + CurPressure.inc(Reg, LiveAfter, LiveBefore, *MRI); + } + MaxPressure = max(MaxPressure, CurPressure); +} + //////////////////////////////////////////////////////////////////////////////// // GCNDownwardRPTracker @@ -430,28 +593,44 @@ bool GCNDownwardRPTracker::reset(const MachineInstr &MI, return true; } -bool GCNDownwardRPTracker::advanceBeforeNext() { +bool GCNDownwardRPTracker::advanceBeforeNext(MachineInstr *MI, + bool UseInternalIterator, + LiveIntervals *TheLIS) { assert(MRI && "call reset first"); - if (!LastTrackedMI) - return NextMI == MBBEnd; - - assert(NextMI == MBBEnd || !NextMI->isDebugInstr()); + SlotIndex SI; + const LiveIntervals *CurrLIS; + const MachineInstr *CurrMI; + if (UseInternalIterator) { + if (!LastTrackedMI) + return NextMI == MBBEnd; + + assert(NextMI == MBBEnd || !NextMI->isDebugInstr()); + CurrLIS = &LIS; + CurrMI = LastTrackedMI; + + SI = NextMI == MBBEnd + ? CurrLIS->getInstructionIndex(*LastTrackedMI).getDeadSlot() + : CurrLIS->getInstructionIndex(*NextMI).getBaseIndex(); + } else { //! UseInternalIterator + CurrLIS = TheLIS; + SI = CurrLIS->getInstructionIndex(*MI).getBaseIndex(); + CurrMI = MI; + } - SlotIndex SI = NextMI == MBBEnd - ? LIS.getInstructionIndex(*LastTrackedMI).getDeadSlot() - : LIS.getInstructionIndex(*NextMI).getBaseIndex(); assert(SI.isValid()); // Remove dead registers or mask bits. 
SmallSet SeenRegs; - for (auto &MO : LastTrackedMI->operands()) { + for (auto &MO : CurrMI->operands()) { if (!MO.isReg() || !MO.getReg().isVirtual()) continue; if (MO.isUse() && !MO.readsReg()) continue; + if (!UseInternalIterator && MO.isDef()) + continue; if (!SeenRegs.insert(MO.getReg()).second) continue; - const LiveInterval &LI = LIS.getInterval(MO.getReg()); + const LiveInterval &LI = CurrLIS->getInterval(MO.getReg()); if (LI.hasSubRanges()) { auto It = LiveRegs.end(); for (const auto &S : LI.subranges()) { @@ -481,15 +660,22 @@ bool GCNDownwardRPTracker::advanceBeforeNext() { LastTrackedMI = nullptr; - return NextMI == MBBEnd; + return UseInternalIterator && (NextMI == MBBEnd); } -void GCNDownwardRPTracker::advanceToNext() { - LastTrackedMI = &*NextMI++; - NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); +void GCNDownwardRPTracker::advanceToNext(MachineInstr *MI, + bool UseInternalIterator) { + if (UseInternalIterator) { + LastTrackedMI = &*NextMI++; + NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); + } else { + LastTrackedMI = MI; + } + + const MachineInstr *CurrMI = LastTrackedMI; // Add new registers or mask bits. 
- for (const auto &MO : LastTrackedMI->all_defs()) { + for (const auto &MO : CurrMI->all_defs()) { Register Reg = MO.getReg(); if (!Reg.isVirtual()) continue; @@ -502,11 +688,17 @@ void GCNDownwardRPTracker::advanceToNext() { MaxPressure = max(MaxPressure, CurPressure); } -bool GCNDownwardRPTracker::advance() { - if (NextMI == MBBEnd) +bool GCNDownwardRPTracker::advance(MachineInstr *MI, bool UseInternalIterator, + LiveIntervals *TheLIS) { + if (UseInternalIterator && NextMI == MBBEnd) return false; - advanceBeforeNext(); - advanceToNext(); + + advanceBeforeNext(MI, UseInternalIterator, TheLIS); + advanceToNext(MI, UseInternalIterator); + if (!UseInternalIterator) { + // We must remove any dead def lanes from the current RP + advanceBeforeNext(MI, true, TheLIS); + } return true; } @@ -548,6 +740,65 @@ Printable llvm::reportMismatch(const GCNRPTracker::LiveRegSet &LISLR, }); } +void GCNDownwardRPTracker::bumpDownwardPressure(const MachineInstr *MI, + const SIRegisterInfo *TRI) { + assert(!MI->isDebugOrPseudoInstr() && "Expect a nondebug instruction."); + + SlotIndex SlotIdx; + SlotIdx = LIS.getInstructionIndex(*MI).getRegSlot(); + + // Account for register pressure similar to RegPressureTracker::recede(). + RegisterOperands RegOpers; + RegOpers.collect(*MI, *TRI, *MRI, true, /*IgnoreDead=*/false); + RegOpers.adjustLaneLiveness(LIS, *MRI, SlotIdx); + + for (const RegisterMaskPair &Use : RegOpers.Uses) { + Register Reg = Use.RegUnit; + if (!Reg.isVirtual()) + continue; + LaneBitmask LastUseMask = getLastUsedLanes(Reg, SlotIdx); + if (LastUseMask.none()) + continue; + // The LastUseMask is queried from the liveness information of instruction + // which may be further down the schedule. Some lanes may actually not be + // last uses for the current position. + // FIXME: allow the caller to pass in the list of vreg uses that remain + // to be bottom-scheduled to avoid searching uses at each query. 
+ SlotIndex CurrIdx; + const MachineBasicBlock *MBB = MI->getParent(); + MachineBasicBlock::const_iterator IdxPos = skipDebugInstructionsForward( + LastTrackedMI ? LastTrackedMI : MBB->begin(), MBB->end()); + if (IdxPos == MBB->end()) { + CurrIdx = LIS.getMBBEndIdx(MBB); + } else { + CurrIdx = LIS.getInstructionIndex(*IdxPos).getRegSlot(); + } + + LastUseMask = + findUseBetween(Reg, LastUseMask, CurrIdx, SlotIdx, *MRI, TRI, &LIS); + if (LastUseMask.none()) + continue; + + LaneBitmask LiveMask = LiveRegs[Reg]; + LaneBitmask NewMask = LiveMask & ~LastUseMask; + CurPressure.inc(Reg, LiveMask, NewMask, *MRI); + } + + // Generate liveness for defs. + for (const RegisterMaskPair &Def : RegOpers.Defs) { + Register Reg = Def.RegUnit; + if (!Reg.isVirtual()) + continue; + LaneBitmask LiveMask = LiveRegs[Reg]; + LaneBitmask NewMask = LiveMask | Def.LaneMask; + CurPressure.inc(Reg, LiveMask, NewMask, *MRI); + } + MaxPressure = max(MaxPressure, CurPressure); + + // Boost pressure for all dead defs together. 
+ bumpDeadDefs(RegOpers.DeadDefs); +} + bool GCNUpwardRPTracker::isValid() const { const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex(); const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI); diff --git a/llvm/lib/Target/AMDGPU/GCNRegPressure.h b/llvm/lib/Target/AMDGPU/GCNRegPressure.h index 752f53752fa68..c868a54b7c68d 100644 --- a/llvm/lib/Target/AMDGPU/GCNRegPressure.h +++ b/llvm/lib/Target/AMDGPU/GCNRegPressure.h @@ -19,6 +19,7 @@ #include "GCNSubtarget.h" #include "llvm/CodeGen/LiveIntervals.h" +#include "llvm/CodeGen/RegisterPressure.h" #include namespace llvm { @@ -143,6 +144,9 @@ inline GCNRegPressure operator-(const GCNRegPressure &P1, return Diff; } +/////////////////////////////////////////////////////////////////////////////// +// GCNRPTracker + class GCNRPTracker { public: using LiveRegSet = DenseMap; @@ -159,7 +163,15 @@ class GCNRPTracker { void reset(const MachineInstr &MI, const LiveRegSet *LiveRegsCopy, bool After); + /// Mostly copy/paste from CodeGen/RegisterPressure.cpp + void bumpDeadDefs(ArrayRef DeadDefs); + + LaneBitmask getLastUsedLanes(Register RegUnit, SlotIndex Pos) const; + public: + // reset tracker and set live register set to the specified value. + void reset(const MachineRegisterInfo &MRI_, const LiveRegSet &LiveRegs_); + // live regs for the current state const decltype(LiveRegs) &getLiveRegs() const { return LiveRegs; } const MachineInstr *getLastTrackedMI() const { return LastTrackedMI; } @@ -176,34 +188,45 @@ class GCNRPTracker { GCNRPTracker::LiveRegSet getLiveRegs(SlotIndex SI, const LiveIntervals &LIS, const MachineRegisterInfo &MRI); +//////////////////////////////////////////////////////////////////////////////// +// GCNUpwardRPTracker + class GCNUpwardRPTracker : public GCNRPTracker { public: GCNUpwardRPTracker(const LiveIntervals &LIS_) : GCNRPTracker(LIS_) {} - // reset tracker and set live register set to the specified value. 
- void reset(const MachineRegisterInfo &MRI_, const LiveRegSet &LiveRegs_); + using GCNRPTracker::reset; - // reset tracker at the specified slot index. + /// reset tracker at the specified slot index \p SI. void reset(const MachineRegisterInfo &MRI, SlotIndex SI) { - reset(MRI, llvm::getLiveRegs(SI, LIS, MRI)); + GCNRPTracker::reset(MRI, llvm::getLiveRegs(SI, LIS, MRI)); } - // reset tracker to the end of the MBB. + /// reset tracker to the end of the \p MBB. void reset(const MachineBasicBlock &MBB) { reset(MBB.getParent()->getRegInfo(), LIS.getSlotIndexes()->getMBBEndIdx(&MBB)); } - // reset tracker to the point just after MI (in program order). + /// reset tracker to the point just after \p MI (in program order). void reset(const MachineInstr &MI) { reset(MI.getMF()->getRegInfo(), LIS.getInstructionIndex(MI).getDeadSlot()); } - // move to the state just before the MI (in program order). + /// Move to the state of RP just before the \p MI (in program order). + /// recede takes no iterator arguments; to estimate the impact of an + /// instruction outside of program order, use bumpUpwardPressure. void recede(const MachineInstr &MI); - // checks whether the tracker's state after receding MI corresponds - // to reported by LIS. + /// Mostly copy/paste from CodeGen/RegisterPressure.cpp + /// Calculate the impact \p MI will have on CurPressure and MaxPressure. This + /// does not rely on the implicit program ordering in the LiveIntervals to + /// support RP Speculation. It leaves the state of pressure inconsistent with + /// the current position. + void bumpUpwardPressure(const MachineInstr *MI, const SIRegisterInfo *TRI); + + /// \returns whether the tracker's state after receding MI corresponds + /// to the one reported by LIS. 
bool isValid() const; const GCNRegPressure &getMaxPressure() const { return MaxPressure; } @@ -217,6 +240,9 @@ class GCNUpwardRPTracker : public GCNRPTracker { } }; +//////////////////////////////////////////////////////////////////////////////// +// GCNDownwardRPTracker + class GCNDownwardRPTracker : public GCNRPTracker { // Last position of reset or advanceBeforeNext MachineBasicBlock::const_iterator NextMI; @@ -226,37 +252,67 @@ class GCNDownwardRPTracker : public GCNRPTracker { public: GCNDownwardRPTracker(const LiveIntervals &LIS_) : GCNRPTracker(LIS_) {} + using GCNRPTracker::reset; + MachineBasicBlock::const_iterator getNext() const { return NextMI; } - // Return MaxPressure and clear it. + /// \returns MaxPressure and clears it. GCNRegPressure moveMaxPressure() { auto Res = MaxPressure; MaxPressure.clear(); return Res; } - // Reset tracker to the point before the MI - // filling live regs upon this point using LIS. - // Returns false if block is empty except debug values. + /// Reset tracker to the point before the \p MI + /// filling \p LiveRegs upon this point using LIS. + /// \returns false if block is empty except debug values. bool reset(const MachineInstr &MI, const LiveRegSet *LiveRegs = nullptr); - // Move to the state right before the next MI or after the end of MBB. - // Returns false if reached end of the block. - bool advanceBeforeNext(); - - // Move to the state at the MI, advanceBeforeNext has to be called first. - void advanceToNext(); - - // Move to the state at the next MI. Returns false if reached end of block. - bool advance(); - - // Advance instructions until before End. + /// Move to the state right before the next MI or after the end of MBB. + /// Returns true if the internal iterator has reached the end of the block. + /// If \p UseInternalIterator is true, then internal iterators are used and + /// set to process in program order. 
If \p UseInternalIterator is false, then + /// it is assumed that the tracker is using an externally managed iterator, + /// and advance* calls will not update the state of the iterator. In such + /// cases, the tracker will move to the state right before the provided \p MI + /// and use the provided \p TheLIS for RP calculations. + bool advanceBeforeNext(MachineInstr *MI = nullptr, + bool UseInternalIterator = true, + LiveIntervals *TheLIS = nullptr); + + /// Move to the state at the MI, advanceBeforeNext has to be called first. + /// If \p UseInternalIterator is true, then internal iterators are used and + /// set to process in program order. If \p UseInternalIterator is false, then + /// it is assumed that the tracker is using an externally managed iterator, + /// and advance* calls will not update the state of the iterator. In such + /// cases, the tracker will move to the state at the provided \p MI . + void advanceToNext(MachineInstr *MI = nullptr, + bool UseInternalIterator = true); + + /// Move to the state at the next MI. Returns false if reached end of + /// block. If \p UseInternalIterator is true, then internal iterators are used + /// and set to process in program order. If \p UseInternalIterator is false, + /// then it is assumed that the tracker is using an externally managed + /// iterator, and advance* calls will not update the state of the iterator. In + /// such cases, the tracker will move to the state at the provided + /// \p MI and use the provided \p TheLIS for RP calculations. + bool advance(MachineInstr *MI = nullptr, bool UseInternalIterator = true, + LiveIntervals *TheLIS = nullptr); + + /// Advance instructions until before \p End. bool advance(MachineBasicBlock::const_iterator End); - // Reset to Begin and advance to End. + /// Reset to \p Begin and advance to \p End. 
bool advance(MachineBasicBlock::const_iterator Begin, MachineBasicBlock::const_iterator End, const LiveRegSet *LiveRegsCopy = nullptr); + + /// Mostly copy/paste from CodeGen/RegisterPressure.cpp + /// Calculate the impact \p MI will have on CurPressure and MaxPressure. This + /// does not rely on the implicit program ordering in the LiveIntervals to + /// support RP Speculation. It leaves the state of pressure inconsistent with + /// the current position + void bumpDownwardPressure(const MachineInstr *MI, const SIRegisterInfo *TRI); }; LaneBitmask getLiveLaneMask(unsigned Reg,