@@ -296,6 +296,63 @@ collectVirtualRegUses(SmallVectorImpl<RegisterMaskPair> &RegMaskPairs,
296296 }
297297}
298298
299+ // / Mostly copy/paste from CodeGen/RegisterPressure.cpp
300+ static LaneBitmask getLanesWithProperty (
301+ const LiveIntervals &LIS, const MachineRegisterInfo &MRI,
302+ bool TrackLaneMasks, Register RegUnit, SlotIndex Pos,
303+ LaneBitmask SafeDefault,
304+ function_ref<bool (const LiveRange &LR, SlotIndex Pos)> Property) {
305+ if (RegUnit.isVirtual ()) {
306+ const LiveInterval &LI = LIS.getInterval (RegUnit);
307+ LaneBitmask Result;
308+ if (TrackLaneMasks && LI.hasSubRanges ()) {
309+ for (const LiveInterval::SubRange &SR : LI.subranges ()) {
310+ if (Property (SR, Pos))
311+ Result |= SR.LaneMask ;
312+ }
313+ } else if (Property (LI, Pos)) {
314+ Result = TrackLaneMasks ? MRI.getMaxLaneMaskForVReg (RegUnit)
315+ : LaneBitmask::getAll ();
316+ }
317+
318+ return Result;
319+ }
320+
321+ const LiveRange *LR = LIS.getCachedRegUnit (RegUnit);
322+ if (LR == nullptr )
323+ return SafeDefault;
324+ return Property (*LR, Pos) ? LaneBitmask::getAll () : LaneBitmask::getNone ();
325+ }
326+
327+ // / Mostly copy/paste from CodeGen/RegisterPressure.cpp
328+ // / Helper to find a vreg use between two indices {PriorUseIdx, NextUseIdx}.
329+ // / The query starts with a lane bitmask which gets lanes/bits removed for every
330+ // / use we find.
331+ static LaneBitmask findUseBetween (unsigned Reg, LaneBitmask LastUseMask,
332+ SlotIndex PriorUseIdx, SlotIndex NextUseIdx,
333+ const MachineRegisterInfo &MRI,
334+ const SIRegisterInfo *TRI,
335+ const LiveIntervals *LIS,
336+ bool Upward = false ) {
337+ for (const MachineOperand &MO : MRI.use_nodbg_operands (Reg)) {
338+ if (MO.isUndef ())
339+ continue ;
340+ const MachineInstr *MI = MO.getParent ();
341+ SlotIndex InstSlot = LIS->getInstructionIndex (*MI).getRegSlot ();
342+ bool InRange = Upward ? (InstSlot > PriorUseIdx && InstSlot <= NextUseIdx)
343+ : (InstSlot >= PriorUseIdx && InstSlot < NextUseIdx);
344+ if (!InRange)
345+ continue ;
346+
347+ unsigned SubRegIdx = MO.getSubReg ();
348+ LaneBitmask UseMask = TRI->getSubRegIndexLaneMask (SubRegIdx);
349+ LastUseMask &= ~UseMask;
350+ if (LastUseMask.none ())
351+ return LaneBitmask::getNone ();
352+ }
353+ return LastUseMask;
354+ }
355+
299356// /////////////////////////////////////////////////////////////////////////////
300357// GCNRPTracker
301358
@@ -354,17 +411,28 @@ void GCNRPTracker::reset(const MachineInstr &MI,
354411 MaxPressure = CurPressure = getRegPressure (*MRI, LiveRegs);
355412}
356413
357- // //////////////////////////////////////////////////////////////////////////////
358- // GCNUpwardRPTracker
359-
360- void GCNUpwardRPTracker::reset (const MachineRegisterInfo &MRI_,
361- const LiveRegSet &LiveRegs_) {
414+ void GCNRPTracker::reset (const MachineRegisterInfo &MRI_,
415+ const LiveRegSet &LiveRegs_) {
362416 MRI = &MRI_;
363417 LiveRegs = LiveRegs_;
364418 LastTrackedMI = nullptr ;
365419 MaxPressure = CurPressure = getRegPressure (MRI_, LiveRegs_);
366420}
367421
422+ // / Mostly copy/paste from CodeGen/RegisterPressure.cpp
423+ LaneBitmask GCNRPTracker::getLastUsedLanes (Register RegUnit,
424+ SlotIndex Pos) const {
425+ return getLanesWithProperty (
426+ LIS, *MRI, true , RegUnit, Pos.getBaseIndex (), LaneBitmask::getNone (),
427+ [](const LiveRange &LR, SlotIndex Pos) {
428+ const LiveRange::Segment *S = LR.getSegmentContaining (Pos);
429+ return S != nullptr && S->end == Pos.getRegSlot ();
430+ });
431+ }
432+
433+ // //////////////////////////////////////////////////////////////////////////////
434+ // GCNUpwardRPTracker
435+
368436void GCNUpwardRPTracker::recede (const MachineInstr &MI) {
369437 assert (MRI && " call reset first" );
370438
@@ -441,25 +509,37 @@ bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
441509 return true ;
442510}
443511
444- bool GCNDownwardRPTracker::advanceBeforeNext () {
512+ bool GCNDownwardRPTracker::advanceBeforeNext (MachineInstr *MI,
513+ bool UseInternalIterator) {
445514 assert (MRI && " call reset first" );
446- if (!LastTrackedMI)
447- return NextMI == MBBEnd;
448-
449- assert (NextMI == MBBEnd || !NextMI->isDebugInstr ());
515+ SlotIndex SI;
516+ const MachineInstr *CurrMI;
517+ if (UseInternalIterator) {
518+ if (!LastTrackedMI)
519+ return NextMI == MBBEnd;
520+
521+ assert (NextMI == MBBEnd || !NextMI->isDebugInstr ());
522+ CurrMI = LastTrackedMI;
523+
524+ SI = NextMI == MBBEnd
525+ ? LIS.getInstructionIndex (*LastTrackedMI).getDeadSlot ()
526+ : LIS.getInstructionIndex (*NextMI).getBaseIndex ();
527+ } else { // ! UseInternalIterator
528+ SI = LIS.getInstructionIndex (*MI).getBaseIndex ();
529+ CurrMI = MI;
530+ }
450531
451- SlotIndex SI = NextMI == MBBEnd
452- ? LIS.getInstructionIndex (*LastTrackedMI).getDeadSlot ()
453- : LIS.getInstructionIndex (*NextMI).getBaseIndex ();
454532 assert (SI.isValid ());
455533
456534 // Remove dead registers or mask bits.
457535 SmallSet<Register, 8 > SeenRegs;
458- for (auto &MO : LastTrackedMI ->operands ()) {
536+ for (auto &MO : CurrMI ->operands ()) {
459537 if (!MO.isReg () || !MO.getReg ().isVirtual ())
460538 continue ;
461539 if (MO.isUse () && !MO.readsReg ())
462540 continue ;
541+ if (!UseInternalIterator && MO.isDef ())
542+ continue ;
463543 if (!SeenRegs.insert (MO.getReg ()).second )
464544 continue ;
465545 const LiveInterval &LI = LIS.getInterval (MO.getReg ());
@@ -492,15 +572,22 @@ bool GCNDownwardRPTracker::advanceBeforeNext() {
492572
493573 LastTrackedMI = nullptr ;
494574
495- return NextMI == MBBEnd;
575+ return UseInternalIterator && ( NextMI == MBBEnd) ;
496576}
497577
498- void GCNDownwardRPTracker::advanceToNext () {
499- LastTrackedMI = &*NextMI++;
500- NextMI = skipDebugInstructionsForward (NextMI, MBBEnd);
578+ void GCNDownwardRPTracker::advanceToNext (MachineInstr *MI,
579+ bool UseInternalIterator) {
580+ if (UseInternalIterator) {
581+ LastTrackedMI = &*NextMI++;
582+ NextMI = skipDebugInstructionsForward (NextMI, MBBEnd);
583+ } else {
584+ LastTrackedMI = MI;
585+ }
586+
587+ const MachineInstr *CurrMI = LastTrackedMI;
501588
502589 // Add new registers or mask bits.
503- for (const auto &MO : LastTrackedMI ->all_defs ()) {
590+ for (const auto &MO : CurrMI ->all_defs ()) {
504591 Register Reg = MO.getReg ();
505592 if (!Reg.isVirtual ())
506593 continue ;
@@ -513,11 +600,16 @@ void GCNDownwardRPTracker::advanceToNext() {
513600 MaxPressure = max (MaxPressure, CurPressure);
514601}
515602
/// Move the downward tracker past one instruction.
///
/// With \p UseInternalIterator set, the tracker steps over its own NextMI and
/// returns false without doing anything once NextMI has reached MBBEnd.
/// Otherwise the caller supplies the instruction in \p MI and the call always
/// returns true.
///
/// The step consists of advanceBeforeNext() followed by advanceToNext(), both
/// forwarded \p MI and \p UseInternalIterator.
516- bool GCNDownwardRPTracker::advance () {
517- if (NextMI == MBBEnd)
603+ bool GCNDownwardRPTracker::advance (MachineInstr *MI, bool UseInternalIterator ) {
604+ if (UseInternalIterator && NextMI == MBBEnd)
518605 return false ;
519- advanceBeforeNext ();
520- advanceToNext ();
606+
607+ advanceBeforeNext (MI, UseInternalIterator);
608+ advanceToNext (MI, UseInternalIterator);
609+ if (!UseInternalIterator) {
610+ // We must remove any dead def lanes from the current RP
// The first advanceBeforeNext() call above skips def operands when the
// iterator is external, so a second pass is made in internal-iterator mode.
// NOTE(review): in that mode the slot index is derived from NextMI/MBBEnd;
// confirm those members are meaningful when the caller drives iteration.
611+ advanceBeforeNext (MI, true );
612+ }
521613 return true ;
522614}
523615
@@ -559,6 +651,67 @@ Printable llvm::reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
559651 });
560652}
561653
654+ GCNRegPressure
655+ GCNDownwardRPTracker::bumpDownwardPressure (const MachineInstr *MI,
656+ const SIRegisterInfo *TRI) const {
657+ assert (!MI->isDebugOrPseudoInstr () && " Expect a nondebug instruction." );
658+
659+ SlotIndex SlotIdx;
660+ SlotIdx = LIS.getInstructionIndex (*MI).getRegSlot ();
661+
662+ // Account for register pressure similar to RegPressureTracker::recede().
663+ RegisterOperands RegOpers;
664+ RegOpers.collect (*MI, *TRI, *MRI, true , /* IgnoreDead=*/ false );
665+ RegOpers.adjustLaneLiveness (LIS, *MRI, SlotIdx);
666+ GCNRegPressure TempPressure = CurPressure;
667+
668+ for (const RegisterMaskPair &Use : RegOpers.Uses ) {
669+ Register Reg = Use.RegUnit ;
670+ if (!Reg.isVirtual ())
671+ continue ;
672+ LaneBitmask LastUseMask = getLastUsedLanes (Reg, SlotIdx);
673+ if (LastUseMask.none ())
674+ continue ;
675+ // The LastUseMask is queried from the liveness information of instruction
676+ // which may be further down the schedule. Some lanes may actually not be
677+ // last uses for the current position.
678+ // FIXME: allow the caller to pass in the list of vreg uses that remain
679+ // to be bottom-scheduled to avoid searching uses at each query.
680+ SlotIndex CurrIdx;
681+ const MachineBasicBlock *MBB = MI->getParent ();
682+ MachineBasicBlock::const_iterator IdxPos = skipDebugInstructionsForward (
683+ LastTrackedMI ? LastTrackedMI : MBB->begin (), MBB->end ());
684+ if (IdxPos == MBB->end ()) {
685+ CurrIdx = LIS.getMBBEndIdx (MBB);
686+ } else {
687+ CurrIdx = LIS.getInstructionIndex (*IdxPos).getRegSlot ();
688+ }
689+
690+ LastUseMask =
691+ findUseBetween (Reg, LastUseMask, CurrIdx, SlotIdx, *MRI, TRI, &LIS);
692+ if (LastUseMask.none ())
693+ continue ;
694+
695+ LaneBitmask LiveMask =
696+ LiveRegs.contains (Reg) ? LiveRegs.at (Reg) : LaneBitmask (0 );
697+ LaneBitmask NewMask = LiveMask & ~LastUseMask;
698+ TempPressure.inc (Reg, LiveMask, NewMask, *MRI);
699+ }
700+
701+ // Generate liveness for defs.
702+ for (const RegisterMaskPair &Def : RegOpers.Defs ) {
703+ Register Reg = Def.RegUnit ;
704+ if (!Reg.isVirtual ())
705+ continue ;
706+ LaneBitmask LiveMask =
707+ LiveRegs.contains (Reg) ? LiveRegs.at (Reg) : LaneBitmask (0 );
708+ LaneBitmask NewMask = LiveMask | Def.LaneMask ;
709+ TempPressure.inc (Reg, LiveMask, NewMask, *MRI);
710+ }
711+
712+ return TempPressure;
713+ }
714+
562715bool GCNUpwardRPTracker::isValid () const {
563716 const auto &SI = LIS.getInstructionIndex (*LastTrackedMI).getBaseIndex ();
564717 const auto LISLR = llvm::getLiveRegs (SI, LIS, *MRI);
0 commit comments