100100/// ```
101101
102102import AST
103+ import AutoDiffClosureSpecializationBridging
104+ import Cxx
105+ import CxxStdlib
103106import SIL
104107import SILBridging
105108
@@ -120,6 +123,25 @@ let generalClosureSpecialization = FunctionPass(
120123 print ( " NOT IMPLEMENTED " )
121124}
122125
126+ extension Type {
127+ func isBranchTracingEnumIn( vjp: Function ) -> Bool {
128+ return self . bridged. isAutodiffBranchTracingEnumInVJP ( vjp. bridged)
129+ }
130+ }
131+
132+ extension Collection {
133+ func getExactlyOneOrNil( ) -> Element ? {
134+ assert ( self . count <= 1 )
135+ return self . first
136+ }
137+ }
138+
139+ extension BasicBlock {
140+ fileprivate func getBranchTracingEnumArg( vjp: Function ) -> Argument ? {
141+ return self . arguments. filter { $0. type. isBranchTracingEnumIn ( vjp: vjp) } . getExactlyOneOrNil ( )
142+ }
143+ }
144+
123145let autodiffClosureSpecialization = FunctionPass ( name: " autodiff-closure-specialization " ) {
124146 ( function: Function , context: FunctionPassContext ) in
125147
@@ -175,6 +197,80 @@ let autodiffClosureSpecialization = FunctionPass(name: "autodiff-closure-special
175197
176198private let specializationLevelLimit = 2
177199
200+ private func getPartialApplyOfPullbackInExitVJPBB( vjp: Function ) -> PartialApplyInst ? {
201+ guard let exitBB = vjp. blocks. filter ( { $0. terminator as? ReturnInst != nil } ) . getExactlyOneOrNil ( )
202+ else {
203+ return nil
204+ }
205+
206+ let ri = exitBB. terminator as! ReturnInst
207+ guard let retValDefiningInstr = ri. returnedValue. definingInstruction else {
208+ return nil
209+ }
210+
211+ func handleConvertFunctionOrPartialApply( inst: Instruction ) -> PartialApplyInst ? {
212+ if let pai = inst as? PartialApplyInst {
213+ return pai
214+ }
215+ if let cfi = inst as? ConvertFunctionInst {
216+ return cfi. fromFunction as? PartialApplyInst
217+ }
218+ return nil
219+ }
220+
221+ if let ti = retValDefiningInstr as? TupleInst {
222+ if ti. operands. count < 2 {
223+ return nil
224+ }
225+ guard let lastTupleElemDefiningInst = ti. operands. last!. value. definingInstruction else {
226+ return nil
227+ }
228+ return handleConvertFunctionOrPartialApply ( inst: lastTupleElemDefiningInst)
229+ }
230+
231+ return handleConvertFunctionOrPartialApply ( inst: retValDefiningInstr)
232+ }
233+
234+ typealias EnumTypeAndCase = ( enumType: Type , caseIdx: Int )
235+
236+ typealias ClosureInfoMultiBB = (
237+ closure: SingleValueInstruction ,
238+ capturedArgs: [ Value ] ,
239+ subsetThunk: PartialApplyInst ? ,
240+ payloadTuple: TupleInst ,
241+ idxInPayload: Int ,
242+ enumTypeAndCase: EnumTypeAndCase
243+ )
244+
245+ private func getPullbackClosureInfoMultiBB( in vjp: Function , _ context: FunctionPassContext )
246+ -> PullbackClosureInfo
247+ {
248+ let paiOfPbInExitVjpBB = getPartialApplyOfPullbackInExitVJPBB ( vjp: vjp) !
249+ var pullbackClosureInfo = PullbackClosureInfo ( paiOfPullback: paiOfPbInExitVjpBB)
250+ var subsetThunkArr = [ SingleValueInstruction] ( )
251+
252+ for inst in vjp. instructions {
253+ if inst == paiOfPbInExitVjpBB {
254+ continue
255+ }
256+ if inst. asSupportedClosure == nil {
257+ continue
258+ }
259+
260+ let rootClosure = inst. asSupportedClosure!
261+ if subsetThunkArr. contains ( rootClosure) {
262+ continue
263+ }
264+
265+ let closureInfoArr = handleNonAppliesMultiBB ( for: rootClosure, context)
266+ pullbackClosureInfo. closureInfosMultiBB. append ( contentsOf: closureInfoArr)
267+ subsetThunkArr. append (
268+ contentsOf: closureInfoArr. filter { $0. subsetThunk != nil } . map { $0. subsetThunk! } )
269+ }
270+
271+ return pullbackClosureInfo
272+ }
273+
178274private func getPullbackClosureInfo( in caller: Function , _ context: FunctionPassContext )
179275 -> PullbackClosureInfo ?
180276{
@@ -373,6 +469,128 @@ private func updatePullbackClosureInfo(
373469 intermediateClosureArgDescriptorData: intermediateClosureArgDescriptorData, context)
374470}
375471
472+ typealias BTEPayloadArgOfPbBBWithBTETypeAndCase = ( arg: Argument , enumTypeAndCase: EnumTypeAndCase )
473+
474+ // If the pullback's basic block has an argument which is a payload tuple of the
475+ // branch tracing enum corresponding to the given VJP, return this argument and any valid combination
476+ // of a branch tracing enum type and its case index having the same payload tuple type as the argument.
477+ // The function assumes that no more than one such argument is present.
478+ private func getBTEPayloadArgOfPbBBWithBTETypeAndCase( _ bb: BasicBlock , vjp: Function )
479+ -> BTEPayloadArgOfPbBBWithBTETypeAndCase ?
480+ {
481+ guard let predBB = bb. predecessors. first else {
482+ return nil
483+ }
484+
485+ for arg in bb. arguments {
486+ if !arg. type. isTuple {
487+ continue
488+ }
489+
490+ if let bi = predBB. terminator as? BranchInst {
491+ guard let uedi = bi. operands [ arg. index] . value. definingInstruction as? UncheckedEnumDataInst
492+ else {
493+ continue
494+ }
495+ let enumType = uedi. `enum`. type
496+ if !enumType. isBranchTracingEnumIn ( vjp: vjp) {
497+ continue
498+ }
499+
500+ return BTEPayloadArgOfPbBBWithBTETypeAndCase (
501+ arg: arg,
502+ enumTypeAndCase: (
503+ enumType: enumType,
504+ caseIdx: uedi. caseIndex
505+ )
506+ )
507+ }
508+
509+ if let sei = predBB. terminator as? SwitchEnumInst {
510+ let enumType = sei. enumOp. type
511+ if !enumType. isBranchTracingEnumIn ( vjp: vjp) {
512+ continue
513+ }
514+ return BTEPayloadArgOfPbBBWithBTETypeAndCase (
515+ arg: arg,
516+ enumTypeAndCase: (
517+ enumType: enumType,
518+ caseIdx: sei. getUniqueCase ( forSuccessor: bb) !
519+ )
520+ )
521+ }
522+ }
523+
524+ return nil
525+ }
526+
527+ extension PartialApplyInst {
528+ func isSubsetThunk( ) -> Bool {
529+ if self . argumentOperands. singleElement == nil {
530+ return false
531+ }
532+ guard let desc = self . referencedFunction? . description else {
533+ return false
534+ }
535+ // TODO: do not rely on description which is intended for debug purposes.
536+ return desc. starts (
537+ with: " // autodiff subset parameters thunk for " )
538+ }
539+ }
540+
541+ private func handleNonAppliesMultiBB(
542+ for rootClosure: SingleValueInstruction ,
543+ _ context: FunctionPassContext
544+ )
545+ -> [ ClosureInfoMultiBB ]
546+ {
547+ let vjp = rootClosure. parentFunction
548+ var closureInfoArr = [ ClosureInfoMultiBB] ( )
549+
550+ var closure = rootClosure
551+ var subsetThunk = PartialApplyInst ? ( nil )
552+ if rootClosure. uses. singleElement != nil {
553+ if let pai = closure. uses. singleElement!. instruction as? PartialApplyInst {
554+ if pai. isSubsetThunk ( ) {
555+ subsetThunk = pai
556+ closure = pai
557+ }
558+ }
559+ }
560+
561+ for use in closure. uses {
562+ guard let ti = use. instruction as? TupleInst else {
563+ // Unexpected use of closure, return nothing
564+ return [ ]
565+ }
566+ for tiUse in ti. uses {
567+ guard let ei = tiUse. instruction as? EnumInst else {
568+ // Unexpected use of payload tuple, return nothing
569+ return [ ]
570+ }
571+ if !ei. type. isBranchTracingEnumIn ( vjp: vjp) {
572+ // Unexpected use of payload tuple, return nothing
573+ return [ ]
574+ }
575+ var capturedArgs = [ Value] ( )
576+ if let pai = rootClosure as? PartialApplyInst {
577+ capturedArgs = pai. argumentOperands. map { $0. value }
578+ }
579+ let enumTypeAndCase = EnumTypeAndCase ( enumType: ei. type, caseIdx: ei. caseIndex)
580+ closureInfoArr. append (
581+ ClosureInfoMultiBB (
582+ closure: rootClosure,
583+ capturedArgs: capturedArgs,
584+ subsetThunk: subsetThunk,
585+ payloadTuple: ti,
586+ idxInPayload: use. index,
587+ enumTypeAndCase: enumTypeAndCase
588+ ) )
589+ }
590+ }
591+ return closureInfoArr
592+ }
593+
376594/// Handles all non-apply direct and transitive uses of `rootClosure`.
377595///
378596/// Returns:
@@ -1390,6 +1608,7 @@ private struct ClosureArgDescriptor {
13901608private struct PullbackClosureInfo {
13911609 let paiOfPullback : PartialApplyInst
13921610 var closureArgDescriptors : [ ClosureArgDescriptor ] = [ ]
1611+ var closureInfosMultiBB : [ ClosureInfoMultiBB ] = [ ]
13931612
13941613 init ( paiOfPullback: PartialApplyInst ) {
13951614 self . paiOfPullback = paiOfPullback
@@ -1475,3 +1694,101 @@ let rewrittenCallerBodyTest = FunctionTest("autodiff_closure_specialize_rewritte
14751694 print ( " Rewritten caller body for: \( function. name) : " )
14761695 print ( " \( function) \n " )
14771696}
1697+
1698+ let getPullbackClosureInfoMultiBBTest = FunctionTest (
1699+ " autodiff_closure_specialize_get_pullback_closure_info_multi_bb "
1700+ ) {
1701+ function, arguments, context in
1702+ let pullbackClosureInfo = getPullbackClosureInfoMultiBB ( in: function, context)
1703+ print ( " Run getPullbackClosureInfoMultiBB for VJP \( function. name) : pullbackClosureInfo = ( " )
1704+ print ( " pullbackFn = \( pullbackClosureInfo. pullbackFn. name) " )
1705+ print ( " closureInfosMultiBB = [ " )
1706+ for closureInfoMultiBB in pullbackClosureInfo. closureInfosMultiBB {
1707+ print ( " ClosureInfoMultiBB( " )
1708+ print ( " closure: \( closureInfoMultiBB. closure) " )
1709+ print ( " capturedArgs: [ " )
1710+ for capturedArg in closureInfoMultiBB. capturedArgs {
1711+ print ( " \( capturedArg) " )
1712+ }
1713+ print ( " ] " )
1714+ let subsetThunkStr =
1715+ ( closureInfoMultiBB. subsetThunk == nil ? " nil " : " \( closureInfoMultiBB. subsetThunk!) " )
1716+ print ( " subsetThunk: \( subsetThunkStr) " )
1717+ print ( " payloadTuple: \( closureInfoMultiBB. payloadTuple) " )
1718+ print ( " idxInPayload: \( closureInfoMultiBB. idxInPayload) " )
1719+ print ( " enumTypeAndCase: \( closureInfoMultiBB. enumTypeAndCase) " )
1720+ print ( " ) " )
1721+ }
1722+ print ( " ] \n ) \n " )
1723+ }
1724+
1725+ typealias SpecBTEDict = SpecializedBranchTracingEnumDict
1726+
1727+ func getSpecBTEDict( vjp: Function , context: FunctionPassContext ) -> SpecBTEDict {
1728+ let pullbackClosureInfo = getPullbackClosureInfoMultiBB ( in: vjp, context)
1729+ let pb = pullbackClosureInfo. pullbackFn
1730+ let enumTypeOfEntryBBArg = pb. entryBlock. getBranchTracingEnumArg ( vjp: vjp) !. type
1731+
1732+ let vectorOfClosureInfoMultiBB = VectorOfBranchTracingEnumAndClosureInfo (
1733+ pullbackClosureInfo. closureInfosMultiBB. map {
1734+ BranchTracingEnumAndClosureInfo (
1735+ enumType: $0. enumTypeAndCase. enumType. bridged,
1736+ enumCaseIdx: $0. enumTypeAndCase. caseIdx,
1737+ closure: $0. closure. bridged,
1738+ idxInPayload: $0. idxInPayload)
1739+ } )
1740+
1741+ let enumDict = autodiffSpecializeBranchTracingEnums (
1742+ vjp. bridged, enumTypeOfEntryBBArg. bridged, vectorOfClosureInfoMultiBB)
1743+
1744+ return enumDict
1745+ }
1746+
1747+ let specializeBranchTracingEnums = FunctionTest ( " autodiff_specialize_branch_tracing_enums " ) {
1748+ function, arguments, context in
1749+ let enumDict = getSpecBTEDict ( vjp: function, context: context)
1750+ print (
1751+ " Specialized branch tracing enum dict for VJP \( function. name) contains \( enumDict. size ( ) ) elements: "
1752+ )
1753+ print (
1754+ " \( String ( taking: getSpecializedBranchTracingEnumDictAsString ( enumDict) ) ) " ) //, function.bridged)))")
1755+ }
1756+
1757+ let specializeBTEArgInVjpBB = FunctionTest ( " autodiff_specialize_bte_arg_in_vjp_bb " ) {
1758+ function, arguments, context in
1759+ let enumDict = getSpecBTEDict ( vjp: function, context: context)
1760+ print ( " Specialized BTE arguments of basic blocks in VJP \( function. name) : " )
1761+ for bb in function. blocks {
1762+ guard let arg = bb. getBranchTracingEnumArg ( vjp: function) else {
1763+ continue
1764+ }
1765+ let newArg = specializeBranchTracingEnumBBArgInVJP ( arg. bridged, enumDict) . argument
1766+ print ( " \( newArg) " )
1767+ bb. eraseArgument ( at: newArg. index, context)
1768+ }
1769+ print ( " " )
1770+ }
1771+
1772+ let specializePayloadArgInPullbackBB = FunctionTest ( " autodiff_specialize_payload_arg_in_pb_bb " ) {
1773+ function, arguments, context in
1774+ let pullbackClosureInfo = getPullbackClosureInfoMultiBB ( in: function, context)
1775+ let pb = pullbackClosureInfo. pullbackFn
1776+ let enumDict = getSpecBTEDict ( vjp: function, context: context)
1777+
1778+ print ( " Specialized BTE payload arguments of basic blocks in pullback \( pb. name) : " )
1779+ for bb in pb. blocks {
1780+ guard
1781+ let ( arg, enumTypeAndCase) = getBTEPayloadArgOfPbBBWithBTETypeAndCase ( bb, vjp: function)
1782+ else {
1783+ continue
1784+ }
1785+
1786+ let enumType = enumDict [ enumTypeAndCase. enumType. bridged] !
1787+ let newArg = specializePayloadTupleBBArgInPullback (
1788+ arg. bridged, enumType, enumTypeAndCase. caseIdx
1789+ ) . argument
1790+ print ( " \( newArg) " )
1791+ bb. eraseArgument ( at: newArg. index, context)
1792+ }
1793+ print ( " " )
1794+ }
0 commit comments