Skip to content

Commit 37f6ddb

Browse files
committed
This patch contains part of the changes intended to resolve #68944.
The patch contains almost all needed C++ logic (while the closure specialization pass itself is implemented in Swift). Particularly, the patch contains branch tracing enum specialization logic and related pieces of logic (both in C++ and Swift).
1 parent ae2b46e commit 37f6ddb

File tree

13 files changed

+1126
-14
lines changed

13 files changed

+1126
-14
lines changed

SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift

Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@
100100
/// ```
101101

102102
import AST
103+
import AutoDiffClosureSpecializationBridging
104+
import Cxx
105+
import CxxStdlib
103106
import SIL
104107
import SILBridging
105108

@@ -120,6 +123,25 @@ let generalClosureSpecialization = FunctionPass(
120123
print("NOT IMPLEMENTED")
121124
}
122125

126+
extension Type {
127+
func isBranchTracingEnumIn(vjp: Function) -> Bool {
128+
return self.bridged.isAutodiffBranchTracingEnumInVJP(vjp.bridged)
129+
}
130+
}
131+
132+
extension Collection {
133+
func getExactlyOneOrNil() -> Element? {
134+
assert(self.count <= 1)
135+
return self.first
136+
}
137+
}
138+
139+
extension BasicBlock {
140+
fileprivate func getBranchTracingEnumArg(vjp: Function) -> Argument? {
141+
return self.arguments.filter { $0.type.isBranchTracingEnumIn(vjp: vjp) }.getExactlyOneOrNil()
142+
}
143+
}
144+
123145
let autodiffClosureSpecialization = FunctionPass(name: "autodiff-closure-specialization") {
124146
(function: Function, context: FunctionPassContext) in
125147

@@ -175,6 +197,80 @@ let autodiffClosureSpecialization = FunctionPass(name: "autodiff-closure-special
175197

176198
private let specializationLevelLimit = 2
177199

200+
private func getPartialApplyOfPullbackInExitVJPBB(vjp: Function) -> PartialApplyInst? {
201+
guard let exitBB = vjp.blocks.filter({ $0.terminator as? ReturnInst != nil }).getExactlyOneOrNil()
202+
else {
203+
return nil
204+
}
205+
206+
let ri = exitBB.terminator as! ReturnInst
207+
guard let retValDefiningInstr = ri.returnedValue.definingInstruction else {
208+
return nil
209+
}
210+
211+
func handleConvertFunctionOrPartialApply(inst: Instruction) -> PartialApplyInst? {
212+
if let pai = inst as? PartialApplyInst {
213+
return pai
214+
}
215+
if let cfi = inst as? ConvertFunctionInst {
216+
return cfi.fromFunction as? PartialApplyInst
217+
}
218+
return nil
219+
}
220+
221+
if let ti = retValDefiningInstr as? TupleInst {
222+
if ti.operands.count < 2 {
223+
return nil
224+
}
225+
guard let lastTupleElemDefiningInst = ti.operands.last!.value.definingInstruction else {
226+
return nil
227+
}
228+
return handleConvertFunctionOrPartialApply(inst: lastTupleElemDefiningInst)
229+
}
230+
231+
return handleConvertFunctionOrPartialApply(inst: retValDefiningInstr)
232+
}
233+
234+
typealias EnumTypeAndCase = (enumType: Type, caseIdx: Int)
235+
236+
typealias ClosureInfoMultiBB = (
237+
closure: SingleValueInstruction,
238+
capturedArgs: [Value],
239+
subsetThunk: PartialApplyInst?,
240+
payloadTuple: TupleInst,
241+
idxInPayload: Int,
242+
enumTypeAndCase: EnumTypeAndCase
243+
)
244+
245+
private func getPullbackClosureInfoMultiBB(in vjp: Function, _ context: FunctionPassContext)
246+
-> PullbackClosureInfo
247+
{
248+
let paiOfPbInExitVjpBB = getPartialApplyOfPullbackInExitVJPBB(vjp: vjp)!
249+
var pullbackClosureInfo = PullbackClosureInfo(paiOfPullback: paiOfPbInExitVjpBB)
250+
var subsetThunkArr = [SingleValueInstruction]()
251+
252+
for inst in vjp.instructions {
253+
if inst == paiOfPbInExitVjpBB {
254+
continue
255+
}
256+
if inst.asSupportedClosure == nil {
257+
continue
258+
}
259+
260+
let rootClosure = inst.asSupportedClosure!
261+
if subsetThunkArr.contains(rootClosure) {
262+
continue
263+
}
264+
265+
let closureInfoArr = handleNonAppliesMultiBB(for: rootClosure, context)
266+
pullbackClosureInfo.closureInfosMultiBB.append(contentsOf: closureInfoArr)
267+
subsetThunkArr.append(
268+
contentsOf: closureInfoArr.filter { $0.subsetThunk != nil }.map { $0.subsetThunk! })
269+
}
270+
271+
return pullbackClosureInfo
272+
}
273+
178274
private func getPullbackClosureInfo(in caller: Function, _ context: FunctionPassContext)
179275
-> PullbackClosureInfo?
180276
{
@@ -373,6 +469,127 @@ private func updatePullbackClosureInfo(
373469
intermediateClosureArgDescriptorData: intermediateClosureArgDescriptorData, context)
374470
}
375471

472+
typealias BTEPayloadArgOfPbBBWithBTETypeAndCase = (arg: Argument, enumTypeAndCase: EnumTypeAndCase)
473+
474+
// If the pullback's basic block has a single argument which is a payload tuple of the
475+
// branch tracing enum corresponding to the given VJP, return this argument and any valid combination
476+
// of a branch tracing enum type and its case index having the same payload tuple type as the argument.
477+
private func getBTEPayloadArgOfPbBBWithBTETypeAndCase(_ bb: BasicBlock, vjp: Function)
478+
-> BTEPayloadArgOfPbBBWithBTETypeAndCase?
479+
{
480+
guard let predBB = bb.predecessors.first else {
481+
return nil
482+
}
483+
guard let arg = bb.arguments.singleElement else {
484+
return nil
485+
}
486+
if !arg.type.isTuple {
487+
return nil
488+
}
489+
490+
if let bi = predBB.terminator as? BranchInst {
491+
guard let uedi = bi.operands[arg.index].value.definingInstruction as? UncheckedEnumDataInst
492+
else {
493+
return nil
494+
}
495+
let enumType = uedi.`enum`.type
496+
if !enumType.isBranchTracingEnumIn(vjp: vjp) {
497+
return nil
498+
}
499+
500+
return BTEPayloadArgOfPbBBWithBTETypeAndCase(
501+
arg: arg,
502+
enumTypeAndCase: (
503+
enumType: enumType,
504+
caseIdx: uedi.caseIndex
505+
)
506+
)
507+
}
508+
509+
if let sei = predBB.terminator as? SwitchEnumInst {
510+
let enumType = sei.enumOp.type
511+
if !enumType.isBranchTracingEnumIn(vjp: vjp) {
512+
return nil
513+
}
514+
return BTEPayloadArgOfPbBBWithBTETypeAndCase(
515+
arg: arg,
516+
enumTypeAndCase: (
517+
enumType: enumType,
518+
caseIdx: sei.getUniqueCase(forSuccessor: bb)!
519+
)
520+
)
521+
}
522+
523+
return nil
524+
}
525+
526+
extension PartialApplyInst {
527+
func isSubsetThunk() -> Bool {
528+
if self.argumentOperands.singleElement == nil {
529+
return false
530+
}
531+
guard let desc = self.referencedFunction?.description else {
532+
return false
533+
}
534+
// TODO: do not rely on description which is intended for debug purposes.
535+
return desc.starts(
536+
with: "// autodiff subset parameters thunk for")
537+
}
538+
}
539+
540+
private func handleNonAppliesMultiBB(
541+
for rootClosure: SingleValueInstruction,
542+
_ context: FunctionPassContext
543+
)
544+
-> [ClosureInfoMultiBB]
545+
{
546+
let vjp = rootClosure.parentFunction
547+
var closureInfoArr = [ClosureInfoMultiBB]()
548+
549+
var closure = rootClosure
550+
var subsetThunk = PartialApplyInst?(nil)
551+
if rootClosure.uses.singleElement != nil {
552+
if let pai = closure.uses.singleElement!.instruction as? PartialApplyInst {
553+
if pai.isSubsetThunk() {
554+
subsetThunk = pai
555+
closure = pai
556+
}
557+
}
558+
}
559+
560+
for use in closure.uses {
561+
guard let ti = use.instruction as? TupleInst else {
562+
// Unexpected use of closure, return nothing
563+
return []
564+
}
565+
for tiUse in ti.uses {
566+
guard let ei = tiUse.instruction as? EnumInst else {
567+
// Unexpected use of payload tuple, return nothing
568+
return []
569+
}
570+
if !ei.type.isBranchTracingEnumIn(vjp: vjp) {
571+
// Unexpected use of payload tuple, return nothing
572+
return []
573+
}
574+
var capturedArgs = [Value]()
575+
if let pai = rootClosure as? PartialApplyInst {
576+
capturedArgs = pai.argumentOperands.map { $0.value }
577+
}
578+
let enumTypeAndCase = EnumTypeAndCase(enumType: ei.type, caseIdx: ei.caseIndex)
579+
closureInfoArr.append(
580+
ClosureInfoMultiBB(
581+
closure: rootClosure,
582+
capturedArgs: capturedArgs,
583+
subsetThunk: subsetThunk,
584+
payloadTuple: ti,
585+
idxInPayload: use.index,
586+
enumTypeAndCase: enumTypeAndCase
587+
))
588+
}
589+
}
590+
return closureInfoArr
591+
}
592+
376593
/// Handles all non-apply direct and transitive uses of `rootClosure`.
377594
///
378595
/// Returns:
@@ -1390,6 +1607,7 @@ private struct ClosureArgDescriptor {
13901607
private struct PullbackClosureInfo {
13911608
let paiOfPullback: PartialApplyInst
13921609
var closureArgDescriptors: [ClosureArgDescriptor] = []
1610+
var closureInfosMultiBB: [ClosureInfoMultiBB] = []
13931611

13941612
init(paiOfPullback: PartialApplyInst) {
13951613
self.paiOfPullback = paiOfPullback
@@ -1475,3 +1693,101 @@ let rewrittenCallerBodyTest = FunctionTest("autodiff_closure_specialize_rewritte
14751693
print("Rewritten caller body for: \(function.name):")
14761694
print("\(function)\n")
14771695
}
1696+
1697+
let getPullbackClosureInfoMultiBBTest = FunctionTest(
1698+
"autodiff_closure_specialize_get_pullback_closure_info_multi_bb"
1699+
) {
1700+
function, arguments, context in
1701+
let pullbackClosureInfo = getPullbackClosureInfoMultiBB(in: function, context)
1702+
print("Run getPullbackClosureInfoMultiBB for VJP \(function.name): pullbackClosureInfo = (")
1703+
print(" pullbackFn = \(pullbackClosureInfo.pullbackFn.name)")
1704+
print(" closureInfosMultiBB = [")
1705+
for closureInfoMultiBB in pullbackClosureInfo.closureInfosMultiBB {
1706+
print(" ClosureInfoMultiBB(")
1707+
print(" closure: \(closureInfoMultiBB.closure)")
1708+
print(" capturedArgs: [")
1709+
for capturedArg in closureInfoMultiBB.capturedArgs {
1710+
print(" \(capturedArg)")
1711+
}
1712+
print(" ]")
1713+
let subsetThunkStr =
1714+
(closureInfoMultiBB.subsetThunk == nil ? "nil" : "\(closureInfoMultiBB.subsetThunk!)")
1715+
print(" subsetThunk: \(subsetThunkStr)")
1716+
print(" payloadTuple: \(closureInfoMultiBB.payloadTuple)")
1717+
print(" idxInPayload: \(closureInfoMultiBB.idxInPayload)")
1718+
print(" enumTypeAndCase: \(closureInfoMultiBB.enumTypeAndCase)")
1719+
print(" )")
1720+
}
1721+
print(" ]\n)\n")
1722+
}
1723+
1724+
typealias SpecBTEDict = SpecializedBranchTracingEnumDict
1725+
1726+
func getSpecBTEDict(vjp: Function, context: FunctionPassContext) -> SpecBTEDict {
1727+
let pullbackClosureInfo = getPullbackClosureInfoMultiBB(in: vjp, context)
1728+
let pb = pullbackClosureInfo.pullbackFn
1729+
let enumTypeOfEntryBBArg = pb.entryBlock.getBranchTracingEnumArg(vjp: vjp)!.type
1730+
1731+
let vectorOfClosureInfoMultiBB = VectorOfBranchTracingEnumAndClosureInfo(
1732+
pullbackClosureInfo.closureInfosMultiBB.map {
1733+
BranchTracingEnumAndClosureInfo(
1734+
enumType: $0.enumTypeAndCase.enumType.bridged,
1735+
enumCaseIdx: $0.enumTypeAndCase.caseIdx,
1736+
closure: $0.closure.bridged,
1737+
idxInPayload: $0.idxInPayload)
1738+
})
1739+
1740+
let enumDict = autodiffSpecializeBranchTracingEnums(
1741+
vjp.bridged, enumTypeOfEntryBBArg.bridged, vectorOfClosureInfoMultiBB)
1742+
1743+
return enumDict
1744+
}
1745+
1746+
let specializeBranchTracingEnums = FunctionTest("autodiff_specialize_branch_tracing_enums") {
1747+
function, arguments, context in
1748+
let enumDict = getSpecBTEDict(vjp: function, context: context)
1749+
print(
1750+
"Specialized branch tracing enum dict for VJP \(function.name) contains \(enumDict.size()) elements:"
1751+
)
1752+
print(
1753+
"\(String(taking: getSpecializedBranchTracingEnumDictAsString(enumDict)))") //, function.bridged)))")
1754+
}
1755+
1756+
let specializeBTEArgInVjpBB = FunctionTest("autodiff_specialize_bte_arg_in_vjp_bb") {
1757+
function, arguments, context in
1758+
let enumDict = getSpecBTEDict(vjp: function, context: context)
1759+
print("Specialized BTE arguments of basic blocks in VJP \(function.name):")
1760+
for bb in function.blocks {
1761+
guard let arg = bb.getBranchTracingEnumArg(vjp: function) else {
1762+
continue
1763+
}
1764+
let newArg = specializeBranchTracingEnumBBArgInVJP(arg.bridged, enumDict).argument
1765+
print("\(newArg)")
1766+
bb.eraseArgument(at: newArg.index, context)
1767+
}
1768+
print("")
1769+
}
1770+
1771+
let specializePayloadArgInPullbackBB = FunctionTest("autodiff_specialize_payload_arg_in_pb_bb") {
1772+
function, arguments, context in
1773+
let pullbackClosureInfo = getPullbackClosureInfoMultiBB(in: function, context)
1774+
let pb = pullbackClosureInfo.pullbackFn
1775+
let enumDict = getSpecBTEDict(vjp: function, context: context)
1776+
1777+
print("Specialized BTE payload arguments of basic blocks in pullback \(pb.name):")
1778+
for bb in pb.blocks {
1779+
guard
1780+
let (arg, enumTypeAndCase) = getBTEPayloadArgOfPbBBWithBTETypeAndCase(bb, vjp: function)
1781+
else {
1782+
continue
1783+
}
1784+
1785+
let enumType = enumDict[enumTypeAndCase.enumType.bridged]!
1786+
let newArg = specializePayloadTupleBBArgInPullback(
1787+
arg.bridged, enumType, enumTypeAndCase.caseIdx
1788+
).argument
1789+
print("\(newArg)")
1790+
bb.eraseArgument(at: newArg.index, context)
1791+
}
1792+
print("")
1793+
}

SwiftCompilerSources/Sources/Optimizer/Utilities/FunctionTest.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ public func registerOptimizerTests() {
4444
addressOwnershipLiveRangeTest,
4545
argumentConventionsTest,
4646
getPullbackClosureInfoTest,
47+
getPullbackClosureInfoMultiBBTest,
4748
interiorLivenessTest,
4849
lifetimeDependenceRootTest,
4950
lifetimeDependenceScopeTest,
@@ -53,6 +54,9 @@ public func registerOptimizerTests() {
5354
localVariableReachingAssignmentsTest,
5455
rangeOverlapsPathTest,
5556
rewrittenCallerBodyTest,
57+
specializeBranchTracingEnums,
58+
specializeBTEArgInVjpBB,
59+
specializePayloadArgInPullbackBB,
5660
specializedFunctionSignatureAndBodyTest,
5761
variableIntroducerTest
5862
)

include/module.modulemap

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,9 @@ module OptimizerBridging {
2626
header "swift/SILOptimizer/OptimizerBridging.h"
2727
export *
2828
}
29+
30+
module AutoDiffClosureSpecializationBridging {
31+
header "swift/SILOptimizer/AutoDiffClosureSpecializationBridging.h"
32+
requires cplusplus
33+
export *
34+
}

include/swift/SIL/SILBridging.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,8 @@ struct BridgedType {
263263
BRIDGED_INLINE bool isExactSuperclassOf(BridgedType t) const;
264264
BRIDGED_INLINE bool isMarkedAsImmortal() const;
265265
BRIDGED_INLINE bool isAddressableForDeps(BridgedFunction f) const;
266+
SWIFT_IMPORT_UNSAFE bool
267+
isAutodiffBranchTracingEnumInVJP(BridgedFunction vjp) const;
266268
BRIDGED_INLINE SWIFT_IMPORT_UNSAFE BridgedASTType getRawLayoutSubstitutedLikeType() const;
267269
BRIDGED_INLINE SWIFT_IMPORT_UNSAFE BridgedASTType getRawLayoutSubstitutedCountType() const;
268270
BRIDGED_INLINE SwiftInt getCaseIdxOfEnumType(BridgedStringRef name) const;

0 commit comments

Comments
 (0)