Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions llvm/include/llvm/CodeGen/RegAllocCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,11 @@ namespace llvm {
class TargetRegisterClass;
class TargetRegisterInfo;

/// Filter function for register classes during regalloc. Default register class
/// filter is nullptr, where all registers should be allocated.
typedef std::function<bool(const TargetRegisterInfo &TRI,
const TargetRegisterClass &RC)> RegClassFilterFunc;

/// Default register class filter function for register allocation. All virtual
/// registers should be allocated.
static inline bool allocateAllRegClasses(const TargetRegisterInfo &,
const TargetRegisterClass &) {
return true;
}

const TargetRegisterClass &RC)>
RegClassFilterFunc;
}

#endif // LLVM_CODEGEN_REGALLOCCOMMON_H
2 changes: 1 addition & 1 deletion llvm/include/llvm/CodeGen/RegAllocFast.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
namespace llvm {

struct RegAllocFastPassOptions {
RegClassFilterFunc Filter = allocateAllRegClasses;
RegClassFilterFunc Filter = nullptr;
StringRef FilterName = "all";
bool ClearVRegs = true;
};
Expand Down
4 changes: 3 additions & 1 deletion llvm/include/llvm/Passes/PassBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "llvm/Transforms/IPO/ModuleInliner.h"
#include "llvm/Transforms/Instrumentation.h"
#include "llvm/Transforms/Scalar/LoopPassManager.h"
#include <optional>
#include <vector>

namespace llvm {
Expand Down Expand Up @@ -390,7 +391,8 @@ class PassBuilder {
Error parseAAPipeline(AAManager &AA, StringRef PipelineText);

/// Parse RegClassFilterName to get RegClassFilterFunc.
RegClassFilterFunc parseRegAllocFilter(StringRef RegClassFilterName);
std::optional<RegClassFilterFunc>
parseRegAllocFilter(StringRef RegClassFilterName);

/// Print pass names.
void printPassNames(raw_ostream &OS);
Expand Down
3 changes: 1 addition & 2 deletions llvm/lib/CodeGen/RegAllocBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,7 @@ void RegAllocBase::enqueue(const LiveInterval *LI) {
if (VRM->hasPhys(Reg))
return;

const TargetRegisterClass &RC = *MRI->getRegClass(Reg);
if (ShouldAllocateClass(*TRI, RC)) {
if (shouldAllocateRegister(Reg)) {
LLVM_DEBUG(dbgs() << "Enqueuing " << printReg(Reg, TRI) << '\n');
enqueueImpl(LI);
} else {
Expand Down
15 changes: 13 additions & 2 deletions llvm/lib/CodeGen/RegAllocBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#define LLVM_LIB_CODEGEN_REGALLOCBASE_H

#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/RegAllocCommon.h"
#include "llvm/CodeGen/RegisterClassInfo.h"

Expand Down Expand Up @@ -68,22 +69,32 @@ class RegAllocBase {
LiveIntervals *LIS = nullptr;
LiveRegMatrix *Matrix = nullptr;
RegisterClassInfo RegClassInfo;

private:
/// Private, callees should go through shouldAllocateRegister
const RegClassFilterFunc ShouldAllocateClass;

protected:
/// Inst which is a def of an original reg and whose defs are already all
/// dead after remat is saved in DeadRemats. The deletion of such inst is
/// postponed till all the allocations are done, so its remat expr is
/// always available for the remat of all the siblings of the original reg.
SmallPtrSet<MachineInstr *, 32> DeadRemats;

RegAllocBase(const RegClassFilterFunc F = allocateAllRegClasses) :
ShouldAllocateClass(F) {}
RegAllocBase(const RegClassFilterFunc F = nullptr) : ShouldAllocateClass(F) {}

virtual ~RegAllocBase() = default;

// A RegAlloc pass should call this before allocatePhysRegs.
void init(VirtRegMap &vrm, LiveIntervals &lis, LiveRegMatrix &mat);

/// Get whether a given register should be allocated
bool shouldAllocateRegister(Register Reg) {
if (!ShouldAllocateClass)
return true;
return ShouldAllocateClass(*TRI, *MRI->getRegClass(Reg));
}

// The top-level driver. The output is a VirtRegMap that us updated with
// physical register assignments.
void allocatePhysRegs();
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/RegAllocBasic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class RABasic : public MachineFunctionPass,
void LRE_WillShrinkVirtReg(Register) override;

public:
RABasic(const RegClassFilterFunc F = allocateAllRegClasses);
RABasic(const RegClassFilterFunc F = nullptr);

/// Return the pass name.
StringRef getPassName() const override { return "Basic Register Allocator"; }
Expand Down
7 changes: 4 additions & 3 deletions llvm/lib/CodeGen/RegAllocFast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ class InstrPosIndexes {

class RegAllocFastImpl {
public:
RegAllocFastImpl(const RegClassFilterFunc F = allocateAllRegClasses,
RegAllocFastImpl(const RegClassFilterFunc F = nullptr,
bool ClearVirtRegs_ = true)
: ShouldAllocateClass(F), StackSlotForVirtReg(-1),
ClearVirtRegs(ClearVirtRegs_) {}
Expand Down Expand Up @@ -387,8 +387,7 @@ class RegAllocFast : public MachineFunctionPass {
public:
static char ID;

RegAllocFast(const RegClassFilterFunc F = allocateAllRegClasses,
bool ClearVirtRegs_ = true)
RegAllocFast(const RegClassFilterFunc F = nullptr, bool ClearVirtRegs_ = true)
: MachineFunctionPass(ID), Impl(F, ClearVirtRegs_) {}

bool runOnMachineFunction(MachineFunction &MF) override {
Expand Down Expand Up @@ -431,6 +430,8 @@ INITIALIZE_PASS(RegAllocFast, "regallocfast", "Fast Register Allocator", false,

bool RegAllocFastImpl::shouldAllocateRegister(const Register Reg) const {
assert(Reg.isVirtual());
if (!ShouldAllocateClass)
return true;
const TargetRegisterClass &RC = *MRI->getRegClass(Reg);
return ShouldAllocateClass(*TRI, RC);
}
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/CodeGen/RegAllocGreedy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2308,7 +2308,7 @@ void RAGreedy::tryHintRecoloring(const LiveInterval &VirtReg) {

// This may be a skipped class
if (!VRM->hasPhys(Reg)) {
assert(!ShouldAllocateClass(*TRI, *MRI->getRegClass(Reg)) &&
assert(!shouldAllocateRegister(Reg) &&
"We have an unallocated variable which should have been handled");
continue;
}
Expand Down Expand Up @@ -2698,7 +2698,7 @@ bool RAGreedy::hasVirtRegAlloc() {
const TargetRegisterClass *RC = MRI->getRegClass(Reg);
if (!RC)
continue;
if (ShouldAllocateClass(*TRI, *RC))
if (shouldAllocateRegister(Reg))
return true;
}

Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/RegAllocGreedy.h
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ class LLVM_LIBRARY_VISIBILITY RAGreedy : public MachineFunctionPass,
bool ReverseLocalAssignment = false;

public:
RAGreedy(const RegClassFilterFunc F = allocateAllRegClasses);
RAGreedy(const RegClassFilterFunc F = nullptr);

/// Return the pass name.
StringRef getPassName() const override { return "Greedy Register Allocator"; }
Expand Down
12 changes: 7 additions & 5 deletions llvm/lib/Passes/PassBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1173,14 +1173,15 @@ parseRegAllocFastPassOptions(PassBuilder &PB, StringRef Params) {
std::tie(ParamName, Params) = Params.split(';');

if (ParamName.consume_front("filter=")) {
RegClassFilterFunc Filter = PB.parseRegAllocFilter(ParamName);
std::optional<RegClassFilterFunc> Filter =
PB.parseRegAllocFilter(ParamName);
if (!Filter) {
return make_error<StringError>(
formatv("invalid regallocfast register filter '{0}' ", ParamName)
.str(),
inconvertibleErrorCode());
}
Opts.Filter = Filter;
Opts.Filter = *Filter;
Opts.FilterName = ParamName;
continue;
}
Expand Down Expand Up @@ -2220,13 +2221,14 @@ Error PassBuilder::parseAAPipeline(AAManager &AA, StringRef PipelineText) {
return Error::success();
}

RegClassFilterFunc PassBuilder::parseRegAllocFilter(StringRef FilterName) {
std::optional<RegClassFilterFunc>
PassBuilder::parseRegAllocFilter(StringRef FilterName) {
if (FilterName == "all")
return allocateAllRegClasses;
return nullptr;
for (auto &C : RegClassFilterParsingCallbacks)
if (auto F = C(FilterName))
return F;
return nullptr;
return std::nullopt;
}

static void printPassName(StringRef PassName, raw_ostream &OS) {
Expand Down