Skip to content

Commit f568182

Browse files
committed
[AutoDiff] Closure specialization: specialize branch tracing enums
This patch contains part of the changes intended to resolve #68944. The patch contains almost all needed C++ logic (while the closure specialization pass itself is implemented in Swift). Particularly, the patch contains branch tracing enum specialization logic and related pieces of logic (both in C++ and Swift).
1 parent e419d30 commit f568182

File tree

13 files changed

+1169
-15
lines changed

13 files changed

+1169
-15
lines changed

SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift

Lines changed: 347 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,14 +100,17 @@
100100
/// ```
101101

102102
import AST
103+
import AutoDiffClosureSpecializationBridging
104+
import Cxx
105+
import CxxStdlib
103106
import SIL
104107
import SILBridging
105108

106109
private let verbose = false
107110

108111
private func log(prefix: Bool = true, _ message: @autoclosure () -> String) {
109112
if verbose {
110-
debugLog(prefix: prefix, message())
113+
debugLog(prefix: prefix, "[ADCS] " + message())
111114
}
112115
}
113116

@@ -120,6 +123,25 @@ let generalClosureSpecialization = FunctionPass(
120123
print("NOT IMPLEMENTED")
121124
}
122125

126+
extension Type {
127+
func isBranchTracingEnumIn(vjp: Function) -> Bool {
128+
return self.bridged.isAutodiffBranchTracingEnumInVJP(vjp.bridged)
129+
}
130+
}
131+
132+
extension Collection {
133+
func getExactlyOneOrNil() -> Element? {
134+
assert(self.count <= 1)
135+
return self.first
136+
}
137+
}
138+
139+
extension BasicBlock {
140+
fileprivate func getBranchTracingEnumArg(vjp: Function) -> Argument? {
141+
return self.arguments.filter { $0.type.isBranchTracingEnumIn(vjp: vjp) }.getExactlyOneOrNil()
142+
}
143+
}
144+
123145
let autodiffClosureSpecialization = FunctionPass(name: "autodiff-closure-specialization") {
124146
(function: Function, context: FunctionPassContext) in
125147

@@ -175,6 +197,93 @@ let autodiffClosureSpecialization = FunctionPass(name: "autodiff-closure-special
175197

176198
private let specializationLevelLimit = 2
177199

200+
private func getPartialApplyOfPullbackInExitVJPBB(vjp: Function) -> PartialApplyInst? {
201+
log("getPartialApplyOfPullbackInExitVJPBB: running for VJP \(vjp.name)")
202+
guard let exitBB = vjp.blocks.filter({ $0.terminator as? ReturnInst != nil }).getExactlyOneOrNil()
203+
else {
204+
log("getPartialApplyOfPullbackInExitVJPBB: exit BB not found, aborting")
205+
return nil
206+
}
207+
208+
let ri = exitBB.terminator as! ReturnInst
209+
guard let retValDefiningInstr = ri.returnedValue.definingInstruction else {
210+
log("getPartialApplyOfPullbackInExitVJPBB: return value is not defined by an instruction, aborting")
211+
return nil
212+
}
213+
214+
func handleConvertFunctionOrPartialApply(inst: Instruction) -> PartialApplyInst? {
215+
if let pai = inst as? PartialApplyInst {
216+
log("getPartialApplyOfPullbackInExitVJPBB: success")
217+
return pai
218+
}
219+
if let cfi = inst as? ConvertFunctionInst {
220+
if let pai = cfi.fromFunction as? PartialApplyInst {
221+
log("getPartialApplyOfPullbackInExitVJPBB: success")
222+
return pai
223+
}
224+
log("getPartialApplyOfPullbackInExitVJPBB: fromFunction operand of convert_function instruction is not defined by partial_apply instruction, aborting")
225+
return nil
226+
}
227+
log("getPartialApplyOfPullbackInExitVJPBB: unexpected instruction type, aborting")
228+
return nil
229+
}
230+
231+
if let ti = retValDefiningInstr as? TupleInst {
232+
log("getPartialApplyOfPullbackInExitVJPBB: return value is defined by tuple instruction")
233+
if ti.operands.count < 2 {
234+
log("getPartialApplyOfPullbackInExitVJPBB: tuple instruction has \(ti.operands.count) operands, but at least 2 expected, aborting")
235+
return nil
236+
}
237+
guard let lastTupleElemDefiningInst = ti.operands.last!.value.definingInstruction else {
238+
log("getPartialApplyOfPullbackInExitVJPBB: last tuple element is not defined by an instruction, aborting")
239+
return nil
240+
}
241+
return handleConvertFunctionOrPartialApply(inst: lastTupleElemDefiningInst)
242+
}
243+
244+
return handleConvertFunctionOrPartialApply(inst: retValDefiningInstr)
245+
}
246+
247+
typealias EnumTypeAndCase = (enumType: Type, caseIdx: Int)
248+
249+
typealias ClosureInfoMultiBB = (
250+
closure: SingleValueInstruction,
251+
capturedArgs: [Value],
252+
subsetThunk: PartialApplyInst?,
253+
payloadTuple: TupleInst,
254+
idxInPayload: Int,
255+
enumTypeAndCase: EnumTypeAndCase
256+
)
257+
258+
private func getPullbackClosureInfoMultiBB(in vjp: Function, _ context: FunctionPassContext)
259+
-> PullbackClosureInfo
260+
{
261+
let paiOfPbInExitVjpBB = getPartialApplyOfPullbackInExitVJPBB(vjp: vjp)!
262+
var pullbackClosureInfo = PullbackClosureInfo(paiOfPullback: paiOfPbInExitVjpBB)
263+
var subsetThunkArr = [SingleValueInstruction]()
264+
265+
for inst in vjp.instructions {
266+
if inst == paiOfPbInExitVjpBB {
267+
continue
268+
}
269+
if inst.asSupportedClosure == nil {
270+
continue
271+
}
272+
273+
let rootClosure = inst.asSupportedClosure!
274+
if subsetThunkArr.contains(rootClosure) {
275+
continue
276+
}
277+
278+
let closureInfoArr = handleNonAppliesMultiBB(for: rootClosure, context)
279+
pullbackClosureInfo.closureInfosMultiBB.append(contentsOf: closureInfoArr)
280+
subsetThunkArr.append(
281+
contentsOf: closureInfoArr.filter { $0.subsetThunk != nil }.map { $0.subsetThunk! })
282+
}
283+
284+
return pullbackClosureInfo
285+
}
286+
178287
private func getPullbackClosureInfo(in caller: Function, _ context: FunctionPassContext)
179288
-> PullbackClosureInfo?
180289
{
@@ -374,6 +483,144 @@ private func updatePullbackClosureInfo(
374483
intermediateClosureArgDescriptorData: intermediateClosureArgDescriptorData, context)
375484
}
376485

486+
typealias BTEPayloadArgOfPbBBWithBTETypeAndCase = (arg: Argument, enumTypeAndCase: EnumTypeAndCase)
487+
488+
// If the pullback's basic block has an argument which is a payload tuple of the
489+
// branch tracing enum corresponding to the given VJP, return this argument and any valid combination
490+
// of a branch tracing enum type and its case index having the same payload tuple type as the argument.
491+
// The function assumes that no more than one such argument is present.
492+
private func getBTEPayloadArgOfPbBBWithBTETypeAndCase(_ bb: BasicBlock, vjp: Function)
493+
-> BTEPayloadArgOfPbBBWithBTETypeAndCase?
494+
{
495+
log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: basic block \(bb.shortDescription) in pullback \(bb.parentFunction.name)")
496+
guard let predBB = bb.predecessors.first else {
497+
log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: the bb has no predecessors, aborting")
498+
return nil
499+
}
500+
501+
log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: start iterating over bb args")
502+
for arg in bb.arguments {
503+
log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: \(arg)")
504+
if !arg.type.isTuple {
505+
log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: arg is not a tuple, skipping")
506+
continue
507+
}
508+
509+
if let bi = predBB.terminator as? BranchInst {
510+
log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: terminator of pred bb is branch instruction")
511+
guard let uedi = bi.operands[arg.index].value.definingInstruction as? UncheckedEnumDataInst
512+
else {
513+
log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: operand corresponding to the argument is not defined by unchecked_enum_data instruction")
514+
continue
515+
}
516+
let enumType = uedi.`enum`.type
517+
if !enumType.isBranchTracingEnumIn(vjp: vjp) {
518+
log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: enum type \(enumType) is not a branch tracing enum in VJP \(vjp.name)")
519+
continue
520+
}
521+
522+
log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: success")
523+
return BTEPayloadArgOfPbBBWithBTETypeAndCase(
524+
arg: arg,
525+
enumTypeAndCase: (
526+
enumType: enumType,
527+
caseIdx: uedi.caseIndex
528+
)
529+
)
530+
}
531+
532+
if let sei = predBB.terminator as? SwitchEnumInst {
533+
log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: terminator of pred bb is switch_enum instruction")
534+
let enumType = sei.enumOp.type
535+
if !enumType.isBranchTracingEnumIn(vjp: vjp) {
536+
log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: enum type \(enumType) is not a branch tracing enum in VJP \(vjp.name)")
537+
continue
538+
}
539+
540+
log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: success")
541+
return BTEPayloadArgOfPbBBWithBTETypeAndCase(
542+
arg: arg,
543+
enumTypeAndCase: (
544+
enumType: enumType,
545+
caseIdx: sei.getUniqueCase(forSuccessor: bb)!
546+
)
547+
)
548+
}
549+
}
550+
551+
log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: finish iterating over bb args; branch tracing enum arg not found")
552+
return nil
553+
}
554+
555+
extension PartialApplyInst {
556+
func isSubsetThunk() -> Bool {
557+
if self.argumentOperands.singleElement == nil {
558+
return false
559+
}
560+
guard let function = self.referencedFunction else {
561+
return false
562+
}
563+
return function.bridged.isAutodiffSubsetParametersThunk()
564+
}
565+
}
566+
567+
private func handleNonAppliesMultiBB(
568+
for rootClosure: SingleValueInstruction,
569+
_ context: FunctionPassContext
570+
)
571+
-> [ClosureInfoMultiBB]
572+
{
573+
log("handleNonAppliesMultiBB: running for \(rootClosure)")
574+
let vjp = rootClosure.parentFunction
575+
var closureInfoArr = [ClosureInfoMultiBB]()
576+
577+
var closure = rootClosure
578+
var subsetThunk = PartialApplyInst?(nil)
579+
if rootClosure.uses.singleElement != nil {
580+
if let pai = closure.uses.singleElement!.instruction as? PartialApplyInst {
581+
if pai.isSubsetThunk() {
582+
log("handleNonAppliesMultiBB: found subset thunk \(pai)")
583+
subsetThunk = pai
584+
closure = pai
585+
}
586+
}
587+
}
588+
589+
for use in closure.uses {
590+
guard let ti = use.instruction as? TupleInst else {
591+
log("handleNonAppliesMultiBB: unexpected use of closure, aborting: \(use)")
592+
return []
593+
}
594+
for tiUse in ti.uses {
595+
guard let ei = tiUse.instruction as? EnumInst else {
596+
log("handleNonAppliesMultiBB: unexpected use of payload tuple, aborting: \(tiUse)")
597+
return []
598+
}
599+
if !ei.type.isBranchTracingEnumIn(vjp: vjp) {
600+
log("handleNonAppliesMultiBB: enum type \(ei.type) is not a branch tracing enum in VJP \(vjp.name), aborting")
601+
return []
602+
}
603+
var capturedArgs = [Value]()
604+
if let pai = rootClosure as? PartialApplyInst {
605+
capturedArgs = pai.argumentOperands.map { $0.value }
606+
}
607+
log("handleNonAppliesMultiBB: creating closure info with enum type \(ei.type), case index \(ei.caseIndex), index in payload tuple \(use.index) and payload tuple \(ti)")
608+
let enumTypeAndCase = EnumTypeAndCase(enumType: ei.type, caseIdx: ei.caseIndex)
609+
closureInfoArr.append(
610+
ClosureInfoMultiBB(
611+
closure: rootClosure,
612+
capturedArgs: capturedArgs,
613+
subsetThunk: subsetThunk,
614+
payloadTuple: ti,
615+
idxInPayload: use.index,
616+
enumTypeAndCase: enumTypeAndCase
617+
))
618+
}
619+
}
620+
log("handleNonAppliesMultiBB: created \(closureInfoArr.count) closure info entries for \(rootClosure)")
621+
return closureInfoArr
622+
}
623+
377624
/// Handles all non-apply direct and transitive uses of `rootClosure`.
378625
///
379626
/// Returns:
@@ -1391,6 +1638,7 @@ private struct ClosureArgDescriptor {
13911638
private struct PullbackClosureInfo {
13921639
let paiOfPullback: PartialApplyInst
13931640
var closureArgDescriptors: [ClosureArgDescriptor] = []
1641+
var closureInfosMultiBB: [ClosureInfoMultiBB] = []
13941642

13951643
init(paiOfPullback: PartialApplyInst) {
13961644
self.paiOfPullback = paiOfPullback
@@ -1474,3 +1722,101 @@ let rewrittenCallerBodyTest = FunctionTest("autodiff_closure_specialize_rewritte
14741722
print("Rewritten caller body for: \(function.name):")
14751723
print("\(function)\n")
14761724
}
1725+
1726+
let getPullbackClosureInfoMultiBBTest = FunctionTest(
1727+
"autodiff_closure_specialize_get_pullback_closure_info_multi_bb"
1728+
) {
1729+
function, arguments, context in
1730+
let pullbackClosureInfo = getPullbackClosureInfoMultiBB(in: function, context)
1731+
print("Run getPullbackClosureInfoMultiBB for VJP \(function.name): pullbackClosureInfo = (")
1732+
print(" pullbackFn = \(pullbackClosureInfo.pullbackFn.name)")
1733+
print(" closureInfosMultiBB = [")
1734+
for closureInfoMultiBB in pullbackClosureInfo.closureInfosMultiBB {
1735+
print(" ClosureInfoMultiBB(")
1736+
print(" closure: \(closureInfoMultiBB.closure)")
1737+
print(" capturedArgs: [")
1738+
for capturedArg in closureInfoMultiBB.capturedArgs {
1739+
print(" \(capturedArg)")
1740+
}
1741+
print(" ]")
1742+
let subsetThunkStr =
1743+
(closureInfoMultiBB.subsetThunk == nil ? "nil" : "\(closureInfoMultiBB.subsetThunk!)")
1744+
print(" subsetThunk: \(subsetThunkStr)")
1745+
print(" payloadTuple: \(closureInfoMultiBB.payloadTuple)")
1746+
print(" idxInPayload: \(closureInfoMultiBB.idxInPayload)")
1747+
print(" enumTypeAndCase: \(closureInfoMultiBB.enumTypeAndCase)")
1748+
print(" )")
1749+
}
1750+
print(" ]\n)\n")
1751+
}
1752+
1753+
typealias SpecBTEDict = SpecializedBranchTracingEnumDict
1754+
1755+
func getSpecBTEDict(vjp: Function, context: FunctionPassContext) -> SpecBTEDict {
1756+
let pullbackClosureInfo = getPullbackClosureInfoMultiBB(in: vjp, context)
1757+
let pb = pullbackClosureInfo.pullbackFn
1758+
let enumTypeOfEntryBBArg = pb.entryBlock.getBranchTracingEnumArg(vjp: vjp)!.type
1759+
1760+
let vectorOfClosureInfoMultiBB = VectorOfBranchTracingEnumAndClosureInfo(
1761+
pullbackClosureInfo.closureInfosMultiBB.map {
1762+
BranchTracingEnumAndClosureInfo(
1763+
enumType: $0.enumTypeAndCase.enumType.bridged,
1764+
enumCaseIdx: $0.enumTypeAndCase.caseIdx,
1765+
closure: $0.closure.bridged,
1766+
idxInPayload: $0.idxInPayload)
1767+
})
1768+
1769+
let enumDict = autodiffSpecializeBranchTracingEnums(
1770+
vjp.bridged, enumTypeOfEntryBBArg.bridged, vectorOfClosureInfoMultiBB)
1771+
1772+
return enumDict
1773+
}
1774+
1775+
let specializeBranchTracingEnums = FunctionTest("autodiff_specialize_branch_tracing_enums") {
1776+
function, arguments, context in
1777+
let enumDict = getSpecBTEDict(vjp: function, context: context)
1778+
print(
1779+
"Specialized branch tracing enum dict for VJP \(function.name) contains \(enumDict.size()) elements:"
1780+
)
1781+
print(
1782+
"\(String(taking: getSpecializedBranchTracingEnumDictAsString(enumDict)))") //, function.bridged)))")
1783+
}
1784+
1785+
let specializeBTEArgInVjpBB = FunctionTest("autodiff_specialize_bte_arg_in_vjp_bb") {
1786+
function, arguments, context in
1787+
let enumDict = getSpecBTEDict(vjp: function, context: context)
1788+
print("Specialized BTE arguments of basic blocks in VJP \(function.name):")
1789+
for bb in function.blocks {
1790+
guard let arg = bb.getBranchTracingEnumArg(vjp: function) else {
1791+
continue
1792+
}
1793+
let newArg = specializeBranchTracingEnumBBArgInVJP(arg.bridged, enumDict).argument
1794+
print("\(newArg)")
1795+
bb.eraseArgument(at: newArg.index, context)
1796+
}
1797+
print("")
1798+
}
1799+
1800+
let specializePayloadArgInPullbackBB = FunctionTest("autodiff_specialize_payload_arg_in_pb_bb") {
1801+
function, arguments, context in
1802+
let pullbackClosureInfo = getPullbackClosureInfoMultiBB(in: function, context)
1803+
let pb = pullbackClosureInfo.pullbackFn
1804+
let enumDict = getSpecBTEDict(vjp: function, context: context)
1805+
1806+
print("Specialized BTE payload arguments of basic blocks in pullback \(pb.name):")
1807+
for bb in pb.blocks {
1808+
guard
1809+
let (arg, enumTypeAndCase) = getBTEPayloadArgOfPbBBWithBTETypeAndCase(bb, vjp: function)
1810+
else {
1811+
continue
1812+
}
1813+
1814+
let enumType = enumDict[enumTypeAndCase.enumType.bridged]!
1815+
let newArg = specializePayloadTupleBBArgInPullback(
1816+
arg.bridged, enumType, enumTypeAndCase.caseIdx
1817+
).argument
1818+
print("\(newArg)")
1819+
bb.eraseArgument(at: newArg.index, context)
1820+
}
1821+
print("")
1822+
}

0 commit comments

Comments
 (0)