Skip to content

Commit aa485e8

Browse files
authored
Merge pull request #4457 from eeckstein/partial-apply-opt
Add an optimization to eliminate a partial_apply if all applied arguments are dead in the applied function.
2 parents fa36206 + 959e19d commit aa485e8

File tree

13 files changed

+669
-91
lines changed

13 files changed

+669
-91
lines changed

include/swift/SILOptimizer/Analysis/CallerAnalysis.h

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,11 @@
1919
#include "swift/SIL/SILModule.h"
2020
#include "llvm/ADT/ArrayRef.h"
2121
#include "llvm/ADT/DenseMap.h"
22-
#include "llvm/ADT/SmallVector.h"
22+
#include "llvm/ADT/SmallSet.h"
2323
#include "llvm/ADT/TinyPtrVector.h"
2424

2525
namespace swift {
2626

27-
/// NOTE: this can be extended to contain the callsites of the function.
28-
struct CallerAnalysisFunctionInfo {
29-
/// A list of all the functions this function calls.
30-
llvm::SetVector<SILFunction *> Callees;
31-
/// A list of all the callers this function has.
32-
llvm::SetVector<SILFunction *> Callers;
33-
};
34-
3527
/// CallerAnalysis relies on keeping the Caller/Callee relation up-to-date
3628
/// lazily. i.e. when a function is invalidated, instead of recomputing the
3729
/// function it calls right away, it's kept in a recompute list and
@@ -44,13 +36,57 @@ struct CallerAnalysisFunctionInfo {
4436
/// to run every function through a sequence of function passes which might
4537
/// invalidate the functions and make the computed list incomplete. So
4638
/// O(n) * O(n) = O(n^2).
39+
///
40+
/// In addition to caller information, this analysis also provides information
41+
/// about partial applies of a function.
4742
class CallerAnalysis : public SILAnalysis {
4843

44+
public:
45+
46+
/// NOTE: this can be extended to contain the callsites of the function.
47+
class FunctionInfo {
48+
friend class CallerAnalysis;
49+
50+
/// A list of all the functions this function calls or partially applies.
51+
llvm::SetVector<SILFunction *> Callees;
52+
/// A list of all the callers this function has.
53+
llvm::SmallSet<SILFunction *, 4> Callers;
54+
55+
/// The number of partial applied arguments of this function.
56+
/// Specifically, it stores the minimum number of partial applied arguments
57+
/// of each function which contains one or more partial_applys of this
58+
/// function.
59+
/// This is a little bit off-topic because a partial_apply is not really
60+
/// a "call" of this function.
61+
llvm::DenseMap<SILFunction *, int> PartialAppliers;
62+
63+
public:
64+
/// Returns true if this function has at least one caller.
65+
bool hasCaller() const {
66+
return !Callers.empty();
67+
}
68+
69+
/// Returns non zero if this function is partially applied anywhere.
70+
/// The return value is the minimum number of partially applied arguments.
71+
/// Usually all partial applies of a function partially apply the same
72+
/// number of arguments anyway.
73+
int getMinPartialAppliedArgs() const {
74+
int minArgs = 0;
75+
for (auto Iter : PartialAppliers) {
76+
int numArgs = Iter.second;
77+
if (minArgs == 0 || numArgs < minArgs)
78+
minArgs = numArgs;
79+
}
80+
return minArgs;
81+
}
82+
};
83+
84+
private:
4985
/// Current module we are analyzing.
5086
SILModule &Mod;
5187

5288
/// A map between all the functions and their callsites in the module.
53-
llvm::DenseMap<SILFunction *, CallerAnalysisFunctionInfo> CallInfo;
89+
llvm::DenseMap<SILFunction *, FunctionInfo> FuncInfos;
5490

5591
/// A list of functions that need to be recomputed.
5692
llvm::SetVector<SILFunction *> RecomputeFunctionList;
@@ -74,7 +110,7 @@ class CallerAnalysis : public SILAnalysis {
74110
CallerAnalysis(SILModule *M) : SILAnalysis(AnalysisKind::Caller), Mod(*M) {
75111
// Make sure we compute everything first time called.
76112
for (auto &F : Mod) {
77-
CallInfo.FindAndConstruct(&F);
113+
FuncInfos.FindAndConstruct(&F);
78114
RecomputeFunctionList.insert(&F);
79115
}
80116
}
@@ -89,7 +125,7 @@ class CallerAnalysis : public SILAnalysis {
89125

90126
virtual void invalidate(SILFunction *F, InvalidationKind K) {
91127
// Should we invalidate based on the invalidation kind.
92-
bool shouldInvalidate = K & InvalidationKind::Calls;
128+
bool shouldInvalidate = K & InvalidationKind::CallsAndInstructions;
93129
if (!shouldInvalidate)
94130
return;
95131

@@ -111,20 +147,18 @@ class CallerAnalysis : public SILAnalysis {
111147
if (!shouldInvalidate)
112148
return;
113149

114-
CallInfo.clear();
150+
FuncInfos.clear();
115151
RecomputeFunctionList.clear();
116152
for (auto &F : Mod) {
117153
RecomputeFunctionList.insert(&F);
118154
}
119155
}
120156

121-
/// Return true if the function has a caller inside current module.
122-
bool hasCaller(SILFunction *F) {
157+
const FunctionInfo &getCallerInfo(SILFunction *F) {
123158
// Recompute every function in the invalidated function list and empty the
124159
// list.
125160
processRecomputeFunctionList();
126-
auto Iter = CallInfo.FindAndConstruct(F);
127-
return !Iter.second.Callers.empty();
161+
return FuncInfos[F];
128162
}
129163
};
130164

include/swift/SILOptimizer/PassManager/Passes.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ PASS(RedundantOverflowCheckRemoval, "remove-redundant-overflow-checks",
8181
"Removes redundant overflow checks")
8282
PASS(DCE, "dce",
8383
"Eliminate dead code")
84+
PASS(DeadArgSignatureOpt, "dead-arg-signature-opt",
85+
"Create function with removed dead arguments")
8486
PASS(DeadFunctionElimination, "sil-deadfuncelim",
8587
"Remove unused functions")
8688
PASS(DeadObjectElimination, "deadobject-elim",

lib/SILOptimizer/Analysis/CallerAnalysis.cpp

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,25 +28,42 @@ void CallerAnalysis::processFunctionCallSites(SILFunction *F) {
2828
continue;
2929

3030
// Update the callee information for this function.
31-
CallerAnalysisFunctionInfo &CallerInfo
32-
= CallInfo.FindAndConstruct(F).second;
31+
FunctionInfo &CallerInfo = FuncInfos[F];
3332
CallerInfo.Callees.insert(CalleeFn);
3433

3534
// Update the callsite information for the callee.
36-
CallerAnalysisFunctionInfo &CalleeInfo
37-
= CallInfo.FindAndConstruct(CalleeFn).second;
35+
FunctionInfo &CalleeInfo = FuncInfos[CalleeFn];
3836
CalleeInfo.Callers.insert(F);
39-
}
37+
continue;
38+
}
39+
if (auto *PAI = dyn_cast<PartialApplyInst>(&II)) {
40+
SILFunction *CalleeFn = PAI->getCalleeFunction();
41+
if (!CalleeFn)
42+
continue;
43+
44+
// Update the callee information for this function.
45+
FunctionInfo &CallerInfo = FuncInfos[F];
46+
CallerInfo.Callees.insert(CalleeFn);
47+
48+
// Update the partial-apply information for the callee.
49+
FunctionInfo &CalleeInfo = FuncInfos[CalleeFn];
50+
int &minAppliedArgs = CalleeInfo.PartialAppliers[F];
51+
int numArgs = (int)PAI->getNumArguments();
52+
if (minAppliedArgs == 0 || numArgs < minAppliedArgs) {
53+
minAppliedArgs = numArgs;
54+
}
55+
continue;
56+
}
4057
}
4158
}
4259
}
4360

4461
void CallerAnalysis::invalidateExistingCalleeRelation(SILFunction *F) {
45-
CallerAnalysisFunctionInfo &CallerInfo = CallInfo.FindAndConstruct(F).second;
62+
FunctionInfo &CallerInfo = FuncInfos[F];
4663
for (auto Callee : CallerInfo.Callees) {
47-
CallerAnalysisFunctionInfo &CalleeInfo
48-
= CallInfo.FindAndConstruct(Callee).second;
49-
CalleeInfo.Callers.remove(F);
64+
FunctionInfo &CalleeInfo = FuncInfos[Callee];
65+
CalleeInfo.Callers.erase(F);
66+
CalleeInfo.PartialAppliers.erase(F);
5067
}
5168
}
5269

lib/SILOptimizer/IPO/CapturePropagation.cpp

Lines changed: 104 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ STATISTIC(NumCapturesPropagated, "Number of constant captures propagated");
3131
namespace {
3232
/// Propagate constants through closure captures by specializing the partially
3333
/// applied function.
34+
/// Also optimize away partial_apply instructions where all partially applied
35+
/// arguments are dead.
3436
class CapturePropagation : public SILFunctionTransform
3537
{
3638
public:
@@ -258,15 +260,11 @@ void CapturePropagation::rewritePartialApply(PartialApplyInst *OrigPAI,
258260
SILFunction *SpecialF) {
259261
SILBuilderWithScope Builder(OrigPAI);
260262
auto FuncRef = Builder.createFunctionRef(OrigPAI->getLoc(), SpecialF);
261-
auto NewPAI = Builder.createPartialApply(OrigPAI->getLoc(),
262-
FuncRef,
263-
SpecialF->getLoweredType(),
264-
ArrayRef<Substitution>(),
265-
ArrayRef<SILValue>(),
266-
OrigPAI->getType());
267-
OrigPAI->replaceAllUsesWith(NewPAI);
263+
auto *T2TF = Builder.createThinToThickFunction(OrigPAI->getLoc(),
264+
FuncRef, OrigPAI->getType());
265+
OrigPAI->replaceAllUsesWith(T2TF);
268266
recursivelyDeleteTriviallyDeadInstructions(OrigPAI, true);
269-
DEBUG(llvm::dbgs() << " Rewrote caller:\n" << *NewPAI);
267+
DEBUG(llvm::dbgs() << " Rewrote caller:\n" << *T2TF);
270268
}
271269

272270
/// For now, we conservatively only specialize if doing so can eliminate dynamic
@@ -286,6 +284,92 @@ static bool isProfitable(SILFunction *Callee) {
286284
return false;
287285
}
288286

287+
/// Returns true if block \p BB only contains a return or throw of the first
288+
/// block argument and side-effect-free instructions.
289+
static bool isArgReturnOrThrow(SILBasicBlock *BB) {
290+
for (SILInstruction &I : *BB) {
291+
if (isa<ReturnInst>(&I) || isa<ThrowInst>(&I)) {
292+
SILValue RetVal = I.getOperand(0);
293+
if (BB->getNumBBArg() == 1 && RetVal == BB->getBBArg(0))
294+
return true;
295+
return false;
296+
}
297+
if (I.mayHaveSideEffects() || isa<TermInst>(&I))
298+
return false;
299+
}
300+
llvm_unreachable("should have seen a terminator instruction");
301+
}
302+
303+
/// Checks if \p Orig is a thunk which calls another function but without
304+
/// passing the trailing \p numDeadParams dead parameters.
305+
static SILFunction *getSpecializedWithDeadParams(SILFunction *Orig,
306+
int numDeadParams) {
307+
SILBasicBlock &EntryBB = *Orig->begin();
308+
unsigned NumArgs = EntryBB.getNumBBArg();
309+
SILModule &M = Orig->getModule();
310+
311+
// Check if all dead parameters have trivial types. We don't support non-
312+
// trivial types because it's very hard to find places where we can release
313+
// those parameters (as a replacement for the removed partial_apply).
314+
// TODO: maybe we can skip this restriction when we have semantic ARC.
315+
for (unsigned Idx = NumArgs - numDeadParams; Idx < NumArgs; ++Idx) {
316+
SILType ArgTy = EntryBB.getBBArg(Idx)->getType();
317+
if (!ArgTy.isTrivial(M))
318+
return nullptr;
319+
}
320+
SILFunction *Specialized = nullptr;
321+
SILValue RetValue;
322+
323+
// Check all instructions of the entry block.
324+
for (SILInstruction &I : EntryBB) {
325+
if (auto FAS = FullApplySite::isa(&I)) {
326+
327+
// Check if this is the call of the specialized function.
328+
// As the original function is not generic, the specialized function
329+
// must not be generic either.
330+
if (FAS.hasSubstitutions())
331+
return nullptr;
332+
// Is it the only call?
333+
if (Specialized)
334+
return nullptr;
335+
336+
Specialized = FAS.getReferencedFunction();
337+
if (!Specialized)
338+
return nullptr;
339+
340+
// Check if parameters are passed 1-to-1
341+
unsigned NumArgs = FAS.getNumArguments();
342+
if (EntryBB.getNumBBArg() - numDeadParams != NumArgs)
343+
return nullptr;
344+
345+
for (unsigned Idx = 0; Idx < NumArgs; ++Idx) {
346+
if (FAS.getArgument(Idx) != (ValueBase *)EntryBB.getBBArg(Idx))
347+
return nullptr;
348+
}
349+
350+
if (TryApplyInst *TAI = dyn_cast<TryApplyInst>(&I)) {
351+
// Check the normal and throw blocks of the try_apply.
352+
if (isArgReturnOrThrow(TAI->getNormalBB()) &&
353+
isArgReturnOrThrow(TAI->getErrorBB()))
354+
return Specialized;
355+
return nullptr;
356+
}
357+
assert(isa<ApplyInst>(&I) && "unknown FullApplySite instruction");
358+
RetValue = &I;
359+
continue;
360+
}
361+
if (auto *RI = dyn_cast<ReturnInst>(&I)) {
362+
// Check if we return the result of the apply.
363+
if (RI->getOperand() != RetValue)
364+
return nullptr;
365+
continue;
366+
}
367+
if (I.mayHaveSideEffects() || isa<TermInst>(&I))
368+
return nullptr;
369+
}
370+
return Specialized;
371+
}
372+
289373
bool CapturePropagation::optimizePartialApply(PartialApplyInst *PAI) {
290374
// Check if the partial_apply has generic substitutions.
291375
// FIXME: We could handle generic thunks if it's worthwhile.
@@ -295,15 +379,26 @@ bool CapturePropagation::optimizePartialApply(PartialApplyInst *PAI) {
295379
SILFunction *SubstF = PAI->getReferencedFunction();
296380
if (!SubstF)
297381
return false;
382+
if (SubstF->isExternalDeclaration())
383+
return false;
298384

299385
assert(!SubstF->getLoweredFunctionType()->isPolymorphic() &&
300386
"cannot specialize generic partial apply");
301387

388+
// First possibility: Is it a partial_apply where all partially applied
389+
// arguments are dead?
390+
if (SILFunction *NewFunc = getSpecializedWithDeadParams(SubstF,
391+
PAI->getNumArguments())) {
392+
rewritePartialApply(PAI, NewFunc);
393+
return true;
394+
}
395+
396+
// Second possibility: Are all partially applied arguments constant?
302397
for (auto Arg : PAI->getArguments()) {
303398
if (!isConstant(Arg))
304399
return false;
305400
}
306-
if (SubstF->isExternalDeclaration() || !isProfitable(SubstF))
401+
if (!isProfitable(SubstF))
307402
return false;
308403

309404
DEBUG(llvm::dbgs() << "Specializing closure for constant arguments:\n"

lib/SILOptimizer/PassManager/Passes.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,11 @@ void swift::runSILOptimizationPasses(SILModule &Module) {
307307
// Run an iteration of the mid-level SSA passes.
308308
PM.setStageName("MidLevel");
309309
AddSSAPasses(PM, OptimizationLevelKind::MidLevel);
310+
311+
// Specialize partially applied functions with dead arguments as a preparation
312+
// for CapturePropagation.
313+
PM.addDeadArgSignatureOpt();
314+
310315
PM.runOneIteration();
311316
PM.resetAndRemoveTransformations();
312317

0 commit comments

Comments
 (0)