100100/// ```
101101
102102import AST
103+ import AutoDiffClosureSpecializationBridging
104+ import Cxx
105+ import CxxStdlib
103106import SIL
104107import SILBridging
105108
@@ -120,6 +123,25 @@ let generalClosureSpecialization = FunctionPass(
120123 print ( " NOT IMPLEMENTED " )
121124}
122125
126+ extension Type {
127+ func isBranchTracingEnumIn( vjp: Function ) -> Bool {
128+ return self . bridged. isAutodiffBranchTracingEnumInVJP ( vjp. bridged)
129+ }
130+ }
131+
132+ extension Collection {
133+ func getExactlyOneOrNil( ) -> Element ? {
134+ assert ( self . count <= 1 )
135+ return self . first
136+ }
137+ }
138+
139+ extension BasicBlock {
140+ fileprivate func getBranchTracingEnumArg( vjp: Function ) -> Argument ? {
141+ return self . arguments. filter { $0. type. isBranchTracingEnumIn ( vjp: vjp) } . getExactlyOneOrNil ( )
142+ }
143+ }
144+
123145let autodiffClosureSpecialization = FunctionPass ( name: " autodiff-closure-specialization " ) {
124146 ( function: Function , context: FunctionPassContext ) in
125147
@@ -175,6 +197,80 @@ let autodiffClosureSpecialization = FunctionPass(name: "autodiff-closure-special
175197
176198private let specializationLevelLimit = 2
177199
200+ private func getPartialApplyOfPullbackInExitVJPBB( vjp: Function ) -> PartialApplyInst ? {
201+ guard let exitBB = vjp. blocks. filter ( { $0. terminator as? ReturnInst != nil } ) . getExactlyOneOrNil ( )
202+ else {
203+ return nil
204+ }
205+
206+ let ri = exitBB. terminator as! ReturnInst
207+ guard let retValDefiningInstr = ri. returnedValue. definingInstruction else {
208+ return nil
209+ }
210+
211+ func handleConvertFunctionOrPartialApply( inst: Instruction ) -> PartialApplyInst ? {
212+ if let pai = inst as? PartialApplyInst {
213+ return pai
214+ }
215+ if let cfi = inst as? ConvertFunctionInst {
216+ return cfi. fromFunction as? PartialApplyInst
217+ }
218+ return nil
219+ }
220+
221+ if let ti = retValDefiningInstr as? TupleInst {
222+ if ti. operands. count < 2 {
223+ return nil
224+ }
225+ guard let lastTupleElemDefiningInst = ti. operands. last!. value. definingInstruction else {
226+ return nil
227+ }
228+ return handleConvertFunctionOrPartialApply ( inst: lastTupleElemDefiningInst)
229+ }
230+
231+ return handleConvertFunctionOrPartialApply ( inst: retValDefiningInstr)
232+ }
233+
234+ typealias EnumTypeAndCase = ( enumType: Type , caseIdx: Int )
235+
236+ typealias ClosureInfoMultiBB = (
237+ closure: SingleValueInstruction ,
238+ capturedArgs: [ Value ] ,
239+ subsetThunk: PartialApplyInst ? ,
240+ payloadTuple: TupleInst ,
241+ idxInPayload: Int ,
242+ enumTypeAndCase: EnumTypeAndCase
243+ )
244+
245+ private func getPullbackClosureInfoMultiBB( in vjp: Function , _ context: FunctionPassContext )
246+ -> PullbackClosureInfo
247+ {
248+ let paiOfPbInExitVjpBB = getPartialApplyOfPullbackInExitVJPBB ( vjp: vjp) !
249+ var pullbackClosureInfo = PullbackClosureInfo ( paiOfPullback: paiOfPbInExitVjpBB)
250+ var subsetThunkArr = [ SingleValueInstruction] ( )
251+
252+ for inst in vjp. instructions {
253+ if inst == paiOfPbInExitVjpBB {
254+ continue
255+ }
256+ if inst. asSupportedClosure == nil {
257+ continue
258+ }
259+
260+ let rootClosure = inst. asSupportedClosure!
261+ if subsetThunkArr. contains ( rootClosure) {
262+ continue
263+ }
264+
265+ let closureInfoArr = handleNonAppliesMultiBB ( for: rootClosure, context)
266+ pullbackClosureInfo. closureInfosMultiBB. append ( contentsOf: closureInfoArr)
267+ subsetThunkArr. append (
268+ contentsOf: closureInfoArr. filter { $0. subsetThunk != nil } . map { $0. subsetThunk! } )
269+ }
270+
271+ return pullbackClosureInfo
272+ }
273+
178274private func getPullbackClosureInfo( in caller: Function , _ context: FunctionPassContext )
179275 -> PullbackClosureInfo ?
180276{
@@ -373,6 +469,127 @@ private func updatePullbackClosureInfo(
373469 intermediateClosureArgDescriptorData: intermediateClosureArgDescriptorData, context)
374470}
375471
472+ typealias BTEPayloadArgOfPbBBWithBTETypeAndCase = ( arg: Argument , enumTypeAndCase: EnumTypeAndCase )
473+
474+ // If the pullback's basic block has a single argument which is a payload tuple of the
475+ // branch tracing enum corresponding to the given VJP, return this argument and any valid combination
476+ // of a branch tracing enum type and its case index having the same payload tuple type as the argument.
477+ private func getBTEPayloadArgOfPbBBWithBTETypeAndCase( _ bb: BasicBlock , vjp: Function )
478+ -> BTEPayloadArgOfPbBBWithBTETypeAndCase ?
479+ {
480+ guard let predBB = bb. predecessors. first else {
481+ return nil
482+ }
483+ guard let arg = bb. arguments. singleElement else {
484+ return nil
485+ }
486+ if !arg. type. isTuple {
487+ return nil
488+ }
489+
490+ if let bi = predBB. terminator as? BranchInst {
491+ guard let uedi = bi. operands [ arg. index] . value. definingInstruction as? UncheckedEnumDataInst
492+ else {
493+ return nil
494+ }
495+ let enumType = uedi. `enum`. type
496+ if !enumType. isBranchTracingEnumIn ( vjp: vjp) {
497+ return nil
498+ }
499+
500+ return BTEPayloadArgOfPbBBWithBTETypeAndCase (
501+ arg: arg,
502+ enumTypeAndCase: (
503+ enumType: enumType,
504+ caseIdx: uedi. caseIndex
505+ )
506+ )
507+ }
508+
509+ if let sei = predBB. terminator as? SwitchEnumInst {
510+ let enumType = sei. enumOp. type
511+ if !enumType. isBranchTracingEnumIn ( vjp: vjp) {
512+ return nil
513+ }
514+ return BTEPayloadArgOfPbBBWithBTETypeAndCase (
515+ arg: arg,
516+ enumTypeAndCase: (
517+ enumType: enumType,
518+ caseIdx: sei. getUniqueCase ( forSuccessor: bb) !
519+ )
520+ )
521+ }
522+
523+ return nil
524+ }
525+
526+ extension PartialApplyInst {
527+ func isSubsetThunk( ) -> Bool {
528+ if self . argumentOperands. singleElement == nil {
529+ return false
530+ }
531+ guard let desc = self . referencedFunction? . description else {
532+ return false
533+ }
534+ // TODO: do not rely on description which is intended for debug purposes.
535+ return desc. starts (
536+ with: " // autodiff subset parameters thunk for " )
537+ }
538+ }
539+
540+ private func handleNonAppliesMultiBB(
541+ for rootClosure: SingleValueInstruction ,
542+ _ context: FunctionPassContext
543+ )
544+ -> [ ClosureInfoMultiBB ]
545+ {
546+ let vjp = rootClosure. parentFunction
547+ var closureInfoArr = [ ClosureInfoMultiBB] ( )
548+
549+ var closure = rootClosure
550+ var subsetThunk = PartialApplyInst ? ( nil )
551+ if rootClosure. uses. singleElement != nil {
552+ if let pai = closure. uses. singleElement!. instruction as? PartialApplyInst {
553+ if pai. isSubsetThunk ( ) {
554+ subsetThunk = pai
555+ closure = pai
556+ }
557+ }
558+ }
559+
560+ for use in closure. uses {
561+ guard let ti = use. instruction as? TupleInst else {
562+ // Unexpected use of closure, return nothing
563+ return [ ]
564+ }
565+ for tiUse in ti. uses {
566+ guard let ei = tiUse. instruction as? EnumInst else {
567+ // Unexpected use of payload tuple, return nothing
568+ return [ ]
569+ }
570+ if !ei. type. isBranchTracingEnumIn ( vjp: vjp) {
571+ // Unexpected use of payload tuple, return nothing
572+ return [ ]
573+ }
574+ var capturedArgs = [ Value] ( )
575+ if let pai = rootClosure as? PartialApplyInst {
576+ capturedArgs = pai. argumentOperands. map { $0. value }
577+ }
578+ let enumTypeAndCase = EnumTypeAndCase ( enumType: ei. type, caseIdx: ei. caseIndex)
579+ closureInfoArr. append (
580+ ClosureInfoMultiBB (
581+ closure: rootClosure,
582+ capturedArgs: capturedArgs,
583+ subsetThunk: subsetThunk,
584+ payloadTuple: ti,
585+ idxInPayload: use. index,
586+ enumTypeAndCase: enumTypeAndCase
587+ ) )
588+ }
589+ }
590+ return closureInfoArr
591+ }
592+
376593/// Handles all non-apply direct and transitive uses of `rootClosure`.
377594///
378595/// Returns:
@@ -1390,6 +1607,7 @@ private struct ClosureArgDescriptor {
13901607private struct PullbackClosureInfo {
13911608 let paiOfPullback : PartialApplyInst
13921609 var closureArgDescriptors : [ ClosureArgDescriptor ] = [ ]
1610+ var closureInfosMultiBB : [ ClosureInfoMultiBB ] = [ ]
13931611
13941612 init ( paiOfPullback: PartialApplyInst ) {
13951613 self . paiOfPullback = paiOfPullback
@@ -1475,3 +1693,101 @@ let rewrittenCallerBodyTest = FunctionTest("autodiff_closure_specialize_rewritte
14751693 print ( " Rewritten caller body for: \( function. name) : " )
14761694 print ( " \( function) \n " )
14771695}
1696+
1697+ let getPullbackClosureInfoMultiBBTest = FunctionTest (
1698+ " autodiff_closure_specialize_get_pullback_closure_info_multi_bb "
1699+ ) {
1700+ function, arguments, context in
1701+ let pullbackClosureInfo = getPullbackClosureInfoMultiBB ( in: function, context)
1702+ print ( " Run getPullbackClosureInfoMultiBB for VJP \( function. name) : pullbackClosureInfo = ( " )
1703+ print ( " pullbackFn = \( pullbackClosureInfo. pullbackFn. name) " )
1704+ print ( " closureInfosMultiBB = [ " )
1705+ for closureInfoMultiBB in pullbackClosureInfo. closureInfosMultiBB {
1706+ print ( " ClosureInfoMultiBB( " )
1707+ print ( " closure: \( closureInfoMultiBB. closure) " )
1708+ print ( " capturedArgs: [ " )
1709+ for capturedArg in closureInfoMultiBB. capturedArgs {
1710+ print ( " \( capturedArg) " )
1711+ }
1712+ print ( " ] " )
1713+ let subsetThunkStr =
1714+ ( closureInfoMultiBB. subsetThunk == nil ? " nil " : " \( closureInfoMultiBB. subsetThunk!) " )
1715+ print ( " subsetThunk: \( subsetThunkStr) " )
1716+ print ( " payloadTuple: \( closureInfoMultiBB. payloadTuple) " )
1717+ print ( " idxInPayload: \( closureInfoMultiBB. idxInPayload) " )
1718+ print ( " enumTypeAndCase: \( closureInfoMultiBB. enumTypeAndCase) " )
1719+ print ( " ) " )
1720+ }
1721+ print ( " ] \n ) \n " )
1722+ }
1723+
1724+ typealias SpecBTEDict = SpecializedBranchTracingEnumDict
1725+
1726+ func getSpecBTEDict( vjp: Function , context: FunctionPassContext ) -> SpecBTEDict {
1727+ let pullbackClosureInfo = getPullbackClosureInfoMultiBB ( in: vjp, context)
1728+ let pb = pullbackClosureInfo. pullbackFn
1729+ let enumTypeOfEntryBBArg = pb. entryBlock. getBranchTracingEnumArg ( vjp: vjp) !. type
1730+
1731+ let vectorOfClosureInfoMultiBB = VectorOfBranchTracingEnumAndClosureInfo (
1732+ pullbackClosureInfo. closureInfosMultiBB. map {
1733+ BranchTracingEnumAndClosureInfo (
1734+ enumType: $0. enumTypeAndCase. enumType. bridged,
1735+ enumCaseIdx: $0. enumTypeAndCase. caseIdx,
1736+ closure: $0. closure. bridged,
1737+ idxInPayload: $0. idxInPayload)
1738+ } )
1739+
1740+ let enumDict = autodiffSpecializeBranchTracingEnums (
1741+ vjp. bridged, enumTypeOfEntryBBArg. bridged, vectorOfClosureInfoMultiBB)
1742+
1743+ return enumDict
1744+ }
1745+
1746+ let specializeBranchTracingEnums = FunctionTest ( " autodiff_specialize_branch_tracing_enums " ) {
1747+ function, arguments, context in
1748+ let enumDict = getSpecBTEDict ( vjp: function, context: context)
1749+ print (
1750+ " Specialized branch tracing enum dict for VJP \( function. name) contains \( enumDict. size ( ) ) elements: "
1751+ )
1752+ print (
1753+ " \( String ( taking: getSpecializedBranchTracingEnumDictAsString ( enumDict) ) ) " ) //, function.bridged)))")
1754+ }
1755+
1756+ let specializeBTEArgInVjpBB = FunctionTest ( " autodiff_specialize_bte_arg_in_vjp_bb " ) {
1757+ function, arguments, context in
1758+ let enumDict = getSpecBTEDict ( vjp: function, context: context)
1759+ print ( " Specialized BTE arguments of basic blocks in VJP \( function. name) : " )
1760+ for bb in function. blocks {
1761+ guard let arg = bb. getBranchTracingEnumArg ( vjp: function) else {
1762+ continue
1763+ }
1764+ let newArg = specializeBranchTracingEnumBBArgInVJP ( arg. bridged, enumDict) . argument
1765+ print ( " \( newArg) " )
1766+ bb. eraseArgument ( at: newArg. index, context)
1767+ }
1768+ print ( " " )
1769+ }
1770+
1771+ let specializePayloadArgInPullbackBB = FunctionTest ( " autodiff_specialize_payload_arg_in_pb_bb " ) {
1772+ function, arguments, context in
1773+ let pullbackClosureInfo = getPullbackClosureInfoMultiBB ( in: function, context)
1774+ let pb = pullbackClosureInfo. pullbackFn
1775+ let enumDict = getSpecBTEDict ( vjp: function, context: context)
1776+
1777+ print ( " Specialized BTE payload arguments of basic blocks in pullback \( pb. name) : " )
1778+ for bb in pb. blocks {
1779+ guard
1780+ let ( arg, enumTypeAndCase) = getBTEPayloadArgOfPbBBWithBTETypeAndCase ( bb, vjp: function)
1781+ else {
1782+ continue
1783+ }
1784+
1785+ let enumType = enumDict [ enumTypeAndCase. enumType. bridged] !
1786+ let newArg = specializePayloadTupleBBArgInPullback (
1787+ arg. bridged, enumType, enumTypeAndCase. caseIdx
1788+ ) . argument
1789+ print ( " \( newArg) " )
1790+ bb. eraseArgument ( at: newArg. index, context)
1791+ }
1792+ print ( " " )
1793+ }
0 commit comments