diff --git a/llvm/include/llvm/CodeGen/RegAllocCommon.h b/llvm/include/llvm/CodeGen/RegAllocCommon.h index 757ca8e112eec..ad533eab1861c 100644 --- a/llvm/include/llvm/CodeGen/RegAllocCommon.h +++ b/llvm/include/llvm/CodeGen/RegAllocCommon.h @@ -16,16 +16,11 @@ namespace llvm { class TargetRegisterClass; class TargetRegisterInfo; +/// Filter function for register classes during regalloc. Default register class +/// filter is nullptr, where all registers should be allocated. typedef std::function RegClassFilterFunc; - -/// Default register class filter function for register allocation. All virtual -/// registers should be allocated. -static inline bool allocateAllRegClasses(const TargetRegisterInfo &, - const TargetRegisterClass &) { - return true; -} - + const TargetRegisterClass &RC)> + RegClassFilterFunc; } #endif // LLVM_CODEGEN_REGALLOCCOMMON_H diff --git a/llvm/include/llvm/CodeGen/RegAllocFast.h b/llvm/include/llvm/CodeGen/RegAllocFast.h index c50deccabd995..c62bd14d0b4cb 100644 --- a/llvm/include/llvm/CodeGen/RegAllocFast.h +++ b/llvm/include/llvm/CodeGen/RegAllocFast.h @@ -15,7 +15,7 @@ namespace llvm { struct RegAllocFastPassOptions { - RegClassFilterFunc Filter = allocateAllRegClasses; + RegClassFilterFunc Filter = nullptr; StringRef FilterName = "all"; bool ClearVRegs = true; }; diff --git a/llvm/include/llvm/Passes/PassBuilder.h b/llvm/include/llvm/Passes/PassBuilder.h index ed817127c3db1..551d297e0c089 100644 --- a/llvm/include/llvm/Passes/PassBuilder.h +++ b/llvm/include/llvm/Passes/PassBuilder.h @@ -27,6 +27,7 @@ #include "llvm/Transforms/IPO/ModuleInliner.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" +#include #include namespace llvm { @@ -390,7 +391,8 @@ class PassBuilder { Error parseAAPipeline(AAManager &AA, StringRef PipelineText); /// Parse RegClassFilterName to get RegClassFilterFunc. - RegClassFilterFunc parseRegAllocFilter(StringRef RegClassFilterName); + std::optional + parseRegAllocFilter(StringRef RegClassFilterName); /// Print pass names. void printPassNames(raw_ostream &OS); diff --git a/llvm/lib/CodeGen/RegAllocBase.cpp b/llvm/lib/CodeGen/RegAllocBase.cpp index d0dec372f6896..71288469b8f0f 100644 --- a/llvm/lib/CodeGen/RegAllocBase.cpp +++ b/llvm/lib/CodeGen/RegAllocBase.cpp @@ -181,8 +181,7 @@ void RegAllocBase::enqueue(const LiveInterval *LI) { if (VRM->hasPhys(Reg)) return; - const TargetRegisterClass &RC = *MRI->getRegClass(Reg); - if (ShouldAllocateClass(*TRI, RC)) { + if (shouldAllocateRegister(Reg)) { LLVM_DEBUG(dbgs() << "Enqueuing " << printReg(Reg, TRI) << '\n'); enqueueImpl(LI); } else { diff --git a/llvm/lib/CodeGen/RegAllocBase.h b/llvm/lib/CodeGen/RegAllocBase.h index a8bf305a50c98..643094671d682 100644 --- a/llvm/lib/CodeGen/RegAllocBase.h +++ b/llvm/lib/CodeGen/RegAllocBase.h @@ -37,6 +37,7 @@ #define LLVM_LIB_CODEGEN_REGALLOCBASE_H #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/RegAllocCommon.h" #include "llvm/CodeGen/RegisterClassInfo.h" @@ -68,22 +69,32 @@ class RegAllocBase { LiveIntervals *LIS = nullptr; LiveRegMatrix *Matrix = nullptr; RegisterClassInfo RegClassInfo; + +private: + /// Private, callees should go through shouldAllocateRegister const RegClassFilterFunc ShouldAllocateClass; +protected: /// Inst which is a def of an original reg and whose defs are already all /// dead after remat is saved in DeadRemats. The deletion of such inst is /// postponed till all the allocations are done, so its remat expr is /// always available for the remat of all the siblings of the original reg. SmallPtrSet DeadRemats; - RegAllocBase(const RegClassFilterFunc F = allocateAllRegClasses) : - ShouldAllocateClass(F) {} + RegAllocBase(const RegClassFilterFunc F = nullptr) : ShouldAllocateClass(F) {} virtual ~RegAllocBase() = default; // A RegAlloc pass should call this before allocatePhysRegs. void init(VirtRegMap &vrm, LiveIntervals &lis, LiveRegMatrix &mat); + /// Get whether a given register should be allocated + bool shouldAllocateRegister(Register Reg) { + if (!ShouldAllocateClass) + return true; + return ShouldAllocateClass(*TRI, *MRI->getRegClass(Reg)); + } + // The top-level driver. The output is a VirtRegMap that us updated with // physical register assignments. void allocatePhysRegs(); diff --git a/llvm/lib/CodeGen/RegAllocBasic.cpp b/llvm/lib/CodeGen/RegAllocBasic.cpp index 181337ca4d60f..5d84e1e39e27c 100644 --- a/llvm/lib/CodeGen/RegAllocBasic.cpp +++ b/llvm/lib/CodeGen/RegAllocBasic.cpp @@ -74,7 +74,7 @@ class RABasic : public MachineFunctionPass, void LRE_WillShrinkVirtReg(Register) override; public: - RABasic(const RegClassFilterFunc F = allocateAllRegClasses); + RABasic(const RegClassFilterFunc F = nullptr); /// Return the pass name. StringRef getPassName() const override { return "Basic Register Allocator"; } diff --git a/llvm/lib/CodeGen/RegAllocFast.cpp b/llvm/lib/CodeGen/RegAllocFast.cpp index 09ce8c42a3850..dddc004be9293 100644 --- a/llvm/lib/CodeGen/RegAllocFast.cpp +++ b/llvm/lib/CodeGen/RegAllocFast.cpp @@ -177,7 +177,7 @@ class InstrPosIndexes { class RegAllocFastImpl { public: - RegAllocFastImpl(const RegClassFilterFunc F = allocateAllRegClasses, + RegAllocFastImpl(const RegClassFilterFunc F = nullptr, bool ClearVirtRegs_ = true) : ShouldAllocateClass(F), StackSlotForVirtReg(-1), ClearVirtRegs(ClearVirtRegs_) {} @@ -387,8 +387,7 @@ class RegAllocFast : public MachineFunctionPass { public: static char ID; - RegAllocFast(const RegClassFilterFunc F = allocateAllRegClasses, - bool ClearVirtRegs_ = true) + RegAllocFast(const RegClassFilterFunc F = nullptr, bool ClearVirtRegs_ = true) : MachineFunctionPass(ID), Impl(F, ClearVirtRegs_) {} bool runOnMachineFunction(MachineFunction &MF) override { @@ -431,6 +430,8 @@ INITIALIZE_PASS(RegAllocFast, "regallocfast", "Fast Register Allocator", false, bool RegAllocFastImpl::shouldAllocateRegister(const Register Reg) const { assert(Reg.isVirtual()); + if (!ShouldAllocateClass) + return true; const TargetRegisterClass &RC = *MRI->getRegClass(Reg); return ShouldAllocateClass(*TRI, RC); } diff --git a/llvm/lib/CodeGen/RegAllocGreedy.cpp b/llvm/lib/CodeGen/RegAllocGreedy.cpp index 500ceb3d8b700..19c1ee23af858 100644 --- a/llvm/lib/CodeGen/RegAllocGreedy.cpp +++ b/llvm/lib/CodeGen/RegAllocGreedy.cpp @@ -2308,7 +2308,7 @@ void RAGreedy::tryHintRecoloring(const LiveInterval &VirtReg) { // This may be a skipped class if (!VRM->hasPhys(Reg)) { - assert(!ShouldAllocateClass(*TRI, *MRI->getRegClass(Reg)) && + assert(!shouldAllocateRegister(Reg) && "We have an unallocated variable which should have been handled"); continue; } @@ -2698,7 +2698,7 @@ bool RAGreedy::hasVirtRegAlloc() { const TargetRegisterClass *RC = MRI->getRegClass(Reg); if (!RC) continue; - if (ShouldAllocateClass(*TRI, *RC)) + if (shouldAllocateRegister(Reg)) return true; } diff --git a/llvm/lib/CodeGen/RegAllocGreedy.h b/llvm/lib/CodeGen/RegAllocGreedy.h index 06cf0828ea79b..ac300c0024f5a 100644 --- a/llvm/lib/CodeGen/RegAllocGreedy.h +++ b/llvm/lib/CodeGen/RegAllocGreedy.h @@ -281,7 +281,7 @@ class LLVM_LIBRARY_VISIBILITY RAGreedy : public MachineFunctionPass, bool ReverseLocalAssignment = false; public: - RAGreedy(const RegClassFilterFunc F = allocateAllRegClasses); + RAGreedy(const RegClassFilterFunc F = nullptr); /// Return the pass name. StringRef getPassName() const override { return "Greedy Register Allocator"; } diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp index 19e8a8ab68a73..b1488f9b86886 100644 --- a/llvm/lib/Passes/PassBuilder.cpp +++ b/llvm/lib/Passes/PassBuilder.cpp @@ -1173,14 +1173,15 @@ parseRegAllocFastPassOptions(PassBuilder &PB, StringRef Params) { std::tie(ParamName, Params) = Params.split(';'); if (ParamName.consume_front("filter=")) { - RegClassFilterFunc Filter = PB.parseRegAllocFilter(ParamName); + std::optional Filter = + PB.parseRegAllocFilter(ParamName); if (!Filter) { return make_error( formatv("invalid regallocfast register filter '{0}' ", ParamName) .str(), inconvertibleErrorCode()); } - Opts.Filter = Filter; + Opts.Filter = *Filter; Opts.FilterName = ParamName; continue; } @@ -2220,13 +2221,14 @@ Error PassBuilder::parseAAPipeline(AAManager &AA, StringRef PipelineText) { return Error::success(); } -RegClassFilterFunc PassBuilder::parseRegAllocFilter(StringRef FilterName) { +std::optional +PassBuilder::parseRegAllocFilter(StringRef FilterName) { if (FilterName == "all") - return allocateAllRegClasses; + return nullptr; for (auto &C : RegClassFilterParsingCallbacks) if (auto F = C(FilterName)) return F; - return nullptr; + return std::nullopt; } static void printPassName(StringRef PassName, raw_ostream &OS) {