Skip to content

Commit 813b362

Browse files
wdx727lifengxiang1025zcfh
committed
Adding Matching and Inference Functionality to Propeller-PR4: Implement matching and inference and create clusters.
Co-authored-by: lifengxiang1025 <[email protected]> Co-authored-by: zcfh <[email protected]>
1 parent f8b5f86 commit 813b362

13 files changed

+488
-9
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
//===- llvm/CodeGen/BasicBlockMatchingAndInference.h ------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Infer weights for all basic blocks using matching and inference.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef LLVM_CODEGEN_BASIC_BLOCK_AND_INFERENCE_H
14+
#define LLVM_CODEGEN_BASIC_BLOCK_AND_INFERENCE_H
15+
16+
#include "llvm/CodeGen/BasicBlockSectionsProfileReader.h"
17+
#include "llvm/CodeGen/MachineFunctionPass.h"
18+
#include "llvm/Transforms/Utils/SampleProfileInference.h"
19+
20+
namespace llvm {
21+
22+
class BasicBlockMatchingAndInference : public MachineFunctionPass {
23+
private:
24+
using Edge = std::pair<const MachineBasicBlock *, const MachineBasicBlock *>;
25+
using BlockWeightMap = DenseMap<const MachineBasicBlock *, uint64_t>;
26+
using EdgeWeightMap = DenseMap<Edge, uint64_t>;
27+
using BlockEdgeMap = DenseMap<const MachineBasicBlock *,
28+
SmallVector<const MachineBasicBlock *, 8>>;
29+
30+
struct WeightInfo {
31+
// Weight of basic blocks.
32+
BlockWeightMap BlockWeights;
33+
// Weight of edges.
34+
EdgeWeightMap EdgeWeights;
35+
};
36+
37+
public:
38+
static char ID;
39+
BasicBlockMatchingAndInference();
40+
41+
StringRef getPassName() const override {
42+
return "Basic Block Matching and Inference";
43+
}
44+
45+
void getAnalysisUsage(AnalysisUsage &AU) const override;
46+
47+
bool runOnMachineFunction(MachineFunction &F) override;
48+
49+
std::optional<WeightInfo> getWeightInfo(StringRef FuncName) const;
50+
51+
private:
52+
StringMap<WeightInfo> ProgramWeightInfo;
53+
54+
WeightInfo initWeightInfoByMatching(MachineFunction &MF);
55+
56+
void generateWeightInfoByInference(MachineFunction &MF,
57+
WeightInfo &MatchWeight);
58+
};
59+
60+
} // end namespace llvm
61+
62+
#endif // LLVM_CODEGEN_BASIC_BLOCK_AND_INFERENCE_H

llvm/include/llvm/CodeGen/BasicBlockSectionsProfileReader.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ class BasicBlockSectionsProfileReader {
9090
uint64_t getEdgeCount(StringRef FuncName, const UniqueBBID &SrcBBID,
9191
const UniqueBBID &SinkBBID) const;
9292

93+
// Return the complete function path and cluster info for the given function.
94+
std::pair<bool, FunctionPathAndClusterInfo>
95+
getFunctionPathAndClusterInfo(StringRef FuncName) const;
96+
9397
private:
9498
StringRef getAliasName(StringRef FuncName) const {
9599
auto R = FuncAliasMap.find(FuncName);
@@ -199,6 +203,9 @@ class BasicBlockSectionsProfileReaderWrapperPass : public ImmutablePass {
199203
uint64_t getEdgeCount(StringRef FuncName, const UniqueBBID &SrcBBID,
200204
const UniqueBBID &DestBBID) const;
201205

206+
std::pair<bool, FunctionPathAndClusterInfo>
207+
getFunctionPathAndClusterInfo(StringRef FuncName) const;
208+
202209
// Initializes the FunctionNameToDIFilename map for the current module and
203210
// then reads the profile for the matching functions.
204211
bool doInitialization(Module &M) override;

llvm/include/llvm/CodeGen/MachineBlockHashInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ struct BlendedBlockHash {
8080
return Dist;
8181
}
8282

83+
/// Returns the value of the `OpcodeHash` member of this blended hash.
uint16_t getOpcodeHash() const { return OpcodeHash; }
84+
8385
private:
8486
/// The offset of the basic block from the function start.
8587
uint16_t Offset{0};

llvm/include/llvm/CodeGen/Passes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ LLVM_ABI MachineFunctionPass *createBasicBlockSectionsPass();
6969

7070
LLVM_ABI MachineFunctionPass *createBasicBlockPathCloningPass();
7171

72+
/// createBasicBlockMatchingAndInferencePass - This pass enables matching
73+
/// and inference when using propeller.
74+
LLVM_ABI MachineFunctionPass *createBasicBlockMatchingAndInferencePass();
75+
7276
/// createMachineBlockHashInfoPass - This pass computes basic block hashes.
7377
LLVM_ABI MachineFunctionPass *createMachineBlockHashInfoPass();
7478

llvm/include/llvm/InitializePasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ LLVM_ABI void initializeAlwaysInlinerLegacyPassPass(PassRegistry &);
5555
LLVM_ABI void initializeAssignmentTrackingAnalysisPass(PassRegistry &);
5656
LLVM_ABI void initializeAssumptionCacheTrackerPass(PassRegistry &);
5757
LLVM_ABI void initializeAtomicExpandLegacyPass(PassRegistry &);
58+
LLVM_ABI void initializeBasicBlockMatchingAndInferencePass(PassRegistry &);
5859
LLVM_ABI void initializeBasicBlockPathCloningPass(PassRegistry &);
5960
LLVM_ABI void
6061
initializeBasicBlockSectionsProfileReaderWrapperPassPass(PassRegistry &);

llvm/include/llvm/Transforms/Utils/SampleProfileInference.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,11 @@ template <typename FT> class SampleProfileInference {
130130
SampleProfileInference(FunctionT &F, BlockEdgeMap &Successors,
131131
BlockWeightMap &SampleBlockWeights)
132132
: F(F), Successors(Successors), SampleBlockWeights(SampleBlockWeights) {}
133+
/// Construct an inference instance that is additionally seeded with sampled
/// edge weights. The block weights are held by reference, whereas
/// \p SampleEdgeWeights is copied into the by-value member of the same name.
SampleProfileInference(FunctionT &F, BlockEdgeMap &Successors,
134+
BlockWeightMap &SampleBlockWeights,
135+
EdgeWeightMap &SampleEdgeWeights)
136+
: F(F), Successors(Successors), SampleBlockWeights(SampleBlockWeights),
137+
SampleEdgeWeights(SampleEdgeWeights) {}
133138

134139
/// Apply the profile inference algorithm for a given function
135140
void apply(BlockWeightMap &BlockWeights, EdgeWeightMap &EdgeWeights);
@@ -157,6 +162,9 @@ template <typename FT> class SampleProfileInference {
157162

158163
/// Map basic blocks to their sampled weights.
159164
BlockWeightMap &SampleBlockWeights;
165+
166+
/// Map edges to their sampled weights.
167+
EdgeWeightMap SampleEdgeWeights;
160168
};
161169

162170
template <typename BT>
@@ -266,6 +274,14 @@ FlowFunction SampleProfileInference<BT>::createFlowFunction(
266274
FlowJump Jump;
267275
Jump.Source = BlockIndex[BB];
268276
Jump.Target = BlockIndex[Succ];
277+
auto It = SampleEdgeWeights.find(std::make_pair(BB, Succ));
278+
if (It != SampleEdgeWeights.end()) {
279+
Jump.HasUnknownWeight = false;
280+
Jump.Weight = It->second;
281+
} else {
282+
Jump.HasUnknownWeight = true;
283+
Jump.Weight = 0;
284+
}
269285
Func.Jumps.push_back(Jump);
270286
}
271287
}
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
//===- llvm/CodeGen/BasicBlockMatchingAndInference.cpp ----------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Infer weights for all basic blocks using matching and inference.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "llvm/CodeGen/BasicBlockMatchingAndInference.h"
14+
#include "llvm/CodeGen/BasicBlockSectionsProfileReader.h"
15+
#include "llvm/CodeGen/MachineBlockHashInfo.h"
16+
#include "llvm/CodeGen/Passes.h"
17+
#include "llvm/InitializePasses.h"
18+
#include <llvm/Support/CommandLine.h>
19+
20+
using namespace llvm;
21+
22+
/// Minimum fraction of a function's machine basic blocks that must receive a
/// matched profile weight before inference is attempted; functions below this
/// ratio are left without inferred weights (see runOnMachineFunction).
static cl::opt<float> PropellerInferThreshold(
    "propeller-infer-threshold",
    cl::desc("Minimum ratio of matched basic blocks to total basic blocks "
             "required to infer weights from a stale profile"),
    cl::init(0.6f), cl::Optional);
26+
27+
/// The object is used to identify and match basic blocks given their hashes.
28+
class StaleMatcher {
29+
public:
30+
/// Initialize stale matcher.
31+
void init(const std::vector<MachineBasicBlock *> &Blocks,
32+
const std::vector<BlendedBlockHash> &Hashes) {
33+
assert(Blocks.size() == Hashes.size() &&
34+
"incorrect matcher initialization");
35+
for (size_t I = 0; I < Blocks.size(); I++) {
36+
MachineBasicBlock *Block = Blocks[I];
37+
uint16_t OpHash = Hashes[I].getOpcodeHash();
38+
OpHashToBlocks[OpHash].push_back(std::make_pair(Hashes[I], Block));
39+
}
40+
}
41+
42+
/// Find the most similar block for a given hash.
43+
MachineBasicBlock *matchBlock(BlendedBlockHash BlendedHash) const {
44+
auto BlockIt = OpHashToBlocks.find(BlendedHash.getOpcodeHash());
45+
if (BlockIt == OpHashToBlocks.end()) {
46+
return nullptr;
47+
}
48+
MachineBasicBlock *BestBlock = nullptr;
49+
uint64_t BestDist = std::numeric_limits<uint64_t>::max();
50+
for (auto It : BlockIt->second) {
51+
MachineBasicBlock *Block = It.second;
52+
BlendedBlockHash Hash = It.first;
53+
uint64_t Dist = Hash.distance(BlendedHash);
54+
if (BestBlock == nullptr || Dist < BestDist) {
55+
BestDist = Dist;
56+
BestBlock = Block;
57+
}
58+
}
59+
return BestBlock;
60+
}
61+
62+
private:
63+
using HashBlockPairType = std::pair<BlendedBlockHash, MachineBasicBlock *>;
64+
std::unordered_map<uint16_t, std::vector<HashBlockPairType>> OpHashToBlocks;
65+
};
66+
67+
// Pass registration: the pass is registered under the argument string
// "machine-block-match-infer" and lists MachineBlockHashInfo and
// BasicBlockSectionsProfileReaderWrapperPass as initialization dependencies.
// NOTE(review): the trailing `true, true` arguments are presumably the
// CFG-only and is-analysis flags of the INITIALIZE_PASS macros — confirm.
INITIALIZE_PASS_BEGIN(BasicBlockMatchingAndInference,
68+
"machine-block-match-infer",
69+
"Machine Block Matching and Inference Analysis", true,
70+
true)
71+
INITIALIZE_PASS_DEPENDENCY(MachineBlockHashInfo)
72+
INITIALIZE_PASS_DEPENDENCY(BasicBlockSectionsProfileReaderWrapperPass)
73+
INITIALIZE_PASS_END(BasicBlockMatchingAndInference, "machine-block-match-infer",
74+
"Machine Block Matching and Inference Analysis", true, true)
75+
76+
// Definition of the static pass identifier declared in the class; it is passed
// to the MachineFunctionPass base-class constructor below.
char BasicBlockMatchingAndInference::ID = 0;
77+
78+
/// Constructs the pass and invokes its initializer on the PassRegistry
/// returned by PassRegistry::getPassRegistry().
BasicBlockMatchingAndInference::BasicBlockMatchingAndInference()
79+
: MachineFunctionPass(ID) {
80+
initializeBasicBlockMatchingAndInferencePass(
81+
*PassRegistry::getPassRegistry());
82+
}
83+
84+
/// Declares that this pass requires MachineBlockHashInfo (block hashes) and
/// BasicBlockSectionsProfileReaderWrapperPass (the profile), marks all analyses
/// as preserved, and then forwards to the base-class implementation.
void BasicBlockMatchingAndInference::getAnalysisUsage(AnalysisUsage &AU) const {
85+
AU.addRequired<MachineBlockHashInfo>();
86+
AU.addRequired<BasicBlockSectionsProfileReaderWrapperPass>();
87+
AU.setPreservesAll();
88+
MachineFunctionPass::getAnalysisUsage(AU);
89+
}
90+
91+
std::optional<BasicBlockMatchingAndInference::WeightInfo>
92+
BasicBlockMatchingAndInference::getWeightInfo(StringRef FuncName) const {
93+
auto It = ProgramWeightInfo.find(FuncName);
94+
if (It == ProgramWeightInfo.end()) {
95+
return std::nullopt;
96+
}
97+
return It->second;
98+
}
99+
100+
/// Build the initial, pre-inference weights for \p MF by matching the blocks
/// and edges recorded in the profile to the machine basic blocks of \p MF.
///
/// A profiled block is matched through the hash stored for its base ID; blocks
/// without a stored hash, or whose hash matches no block of \p MF, are
/// skipped. Counts that map to the same block or edge are summed.
///
/// \returns the matched weights, or an empty WeightInfo when the profile
/// reader reports no path/cluster info for \p MF.
BasicBlockMatchingAndInference::WeightInfo
BasicBlockMatchingAndInference::initWeightInfoByMatching(MachineFunction &MF) {
  WeightInfo MatchWeight;

  // Query the profile first so that no block is hashed or indexed for a
  // function that has no profile data at all.
  auto *BSPR = &getAnalysis<BasicBlockSectionsProfileReaderWrapperPass>();
  auto ProfileLookup = BSPR->getFunctionPathAndClusterInfo(MF.getName());
  if (!ProfileLookup.first)
    return MatchWeight;
  const auto &PathAndClusterInfo = ProfileLookup.second;
  const auto &BBHashes = PathAndClusterInfo.BBHashes;

  // Index every block of MF by its blended hash.
  auto *MBHI = &getAnalysis<MachineBlockHashInfo>();
  std::vector<MachineBasicBlock *> Blocks;
  std::vector<BlendedBlockHash> Hashes;
  Blocks.reserve(MF.size());
  Hashes.reserve(MF.size());
  for (auto &Block : MF) {
    Blocks.push_back(&Block);
    Hashes.push_back(BlendedBlockHash(MBHI->getMBBHash(Block)));
  }
  StaleMatcher Matcher;
  Matcher.init(Blocks, Hashes);

  // Map a profiled base ID to a block of MF. Returns nullptr when the profile
  // stores no hash for the ID or when no block matches that hash. A single
  // find() replaces the former count() + operator[] pair.
  auto MatchByBaseID = [&](const auto &BaseID) -> MachineBasicBlock * {
    auto HashIt = BBHashes.find(BaseID);
    if (HashIt == BBHashes.end())
      return nullptr;
    return Matcher.matchBlock(BlendedBlockHash(HashIt->second));
  };

  // When a basic block has clone copies, sum their counts.
  for (const auto &BlockCount : PathAndClusterInfo.NodeCounts)
    if (MachineBasicBlock *Block = MatchByBaseID(BlockCount.first.BaseID))
      MatchWeight.BlockWeights[Block] += BlockCount.second;

  for (const auto &PredItem : PathAndClusterInfo.EdgeCounts) {
    MachineBasicBlock *PredBlock = MatchByBaseID(PredItem.first.BaseID);
    if (PredBlock == nullptr)
      continue;
    // When an edge has clone copies, sum their counts.
    for (const auto &SuccItem : PredItem.second)
      if (MachineBasicBlock *SuccBlock = MatchByBaseID(SuccItem.first.BaseID))
        MatchWeight.EdgeWeights[Edge(PredBlock, SuccBlock)] += SuccItem.second;
  }
  return MatchWeight;
}
151+
152+
/// Run profile inference for \p MF, seeded with the matched block and edge
/// weights in \p MatchWeight, and record the inferred weights under the
/// function's name. An entry that already exists for that name is left
/// untouched (try_emplace).
void BasicBlockMatchingAndInference::generateWeightInfoByInference(
    MachineFunction &MF,
    BasicBlockMatchingAndInference::WeightInfo &MatchWeight) {
  // Collect the successor list of every block that has successors.
  BlockEdgeMap Successors;
  for (auto &MBB : MF)
    for (auto *Target : MBB.successors())
      Successors[&MBB].push_back(Target);

  SampleProfileInference<MachineFunction> Inference(
      MF, Successors, MatchWeight.BlockWeights, MatchWeight.EdgeWeights);

  // Let the inference fill the result maps, then move them into the store.
  BlockWeightMap InferredBlockWeights;
  EdgeWeightMap InferredEdgeWeights;
  Inference.apply(InferredBlockWeights, InferredEdgeWeights);

  WeightInfo Inferred{std::move(InferredBlockWeights),
                      std::move(InferredEdgeWeights)};
  ProgramWeightInfo.try_emplace(MF.getName(), std::move(Inferred));
}
169+
170+
/// Match the profile against \p MF and, if enough blocks were matched, infer
/// and record weights for the function. Always returns false.
bool BasicBlockMatchingAndInference::runOnMachineFunction(MachineFunction &MF) {
  if (!MF.empty()) {
    auto Matched = initWeightInfoByMatching(MF);
    // Fraction of the function's MBBs that received a matched weight. When it
    // is below the threshold the function is abandoned without inference.
    const float MatchedRatio =
        static_cast<float>(Matched.BlockWeights.size()) / MF.size();
    if (!(MatchedRatio < PropellerInferThreshold))
      generateWeightInfoByInference(MF, Matched);
  }
  return false;
}
184+
185+
/// Factory for the pass: returns a newly heap-allocated
/// BasicBlockMatchingAndInference instance.
/// NOTE(review): ownership presumably transfers to the caller / pass manager —
/// confirm against the call sites.
MachineFunctionPass *llvm::createBasicBlockMatchingAndInferencePass() {
186+
return new BasicBlockMatchingAndInference();
187+
}

0 commit comments

Comments
 (0)