100100/// ```
101101
102102import AST
103+ import AutoDiffClosureSpecializationBridging
104+ import Cxx
105+ import CxxStdlib
103106import SIL
104107import SILBridging
105108
106109private let verbose = false
107110
108111private func log( prefix: Bool = true , _ message: @autoclosure ( ) -> String ) {
109112 if verbose {
110- debugLog ( prefix: prefix, message ( ) )
113+ debugLog ( prefix: prefix, " [ADCS] " + message( ) )
111114 }
112115}
113116
@@ -120,6 +123,25 @@ let generalClosureSpecialization = FunctionPass(
120123 print ( " NOT IMPLEMENTED " )
121124}
122125
126+ extension Type {
127+ func isBranchTracingEnumIn( vjp: Function ) -> Bool {
128+ return self . bridged. isAutodiffBranchTracingEnumInVJP ( vjp. bridged)
129+ }
130+ }
131+
132+ extension Collection {
133+ func getExactlyOneOrNil( ) -> Element ? {
134+ assert ( self . count <= 1 )
135+ return self . first
136+ }
137+ }
138+
139+ extension BasicBlock {
140+ fileprivate func getBranchTracingEnumArg( vjp: Function ) -> Argument ? {
141+ return self . arguments. filter { $0. type. isBranchTracingEnumIn ( vjp: vjp) } . getExactlyOneOrNil ( )
142+ }
143+ }
144+
123145let autodiffClosureSpecialization = FunctionPass ( name: " autodiff-closure-specialization " ) {
124146 ( function: Function , context: FunctionPassContext ) in
125147
@@ -175,6 +197,93 @@ let autodiffClosureSpecialization = FunctionPass(name: "autodiff-closure-special
175197
176198private let specializationLevelLimit = 2
177199
200+ private func getPartialApplyOfPullbackInExitVJPBB( vjp: Function ) -> PartialApplyInst ? {
201+ log ( " getPartialApplyOfPullbackInExitVJPBB: running for VJP \( vjp. name) " )
202+ guard let exitBB = vjp. blocks. filter ( { $0. terminator as? ReturnInst != nil } ) . getExactlyOneOrNil ( )
203+ else {
204+ log ( " getPartialApplyOfPullbackInExitVJPBB: exit BB not found, aborting " )
205+ return nil
206+ }
207+
208+ let ri = exitBB. terminator as! ReturnInst
209+ guard let retValDefiningInstr = ri. returnedValue. definingInstruction else {
210+ log ( " getPartialApplyOfPullbackInExitVJPBB: return value is not defined by an instruction, aborting " )
211+ return nil
212+ }
213+
214+ func handleConvertFunctionOrPartialApply( inst: Instruction ) -> PartialApplyInst ? {
215+ if let pai = inst as? PartialApplyInst {
216+ log ( " getPartialApplyOfPullbackInExitVJPBB: success " )
217+ return pai
218+ }
219+ if let cfi = inst as? ConvertFunctionInst {
220+ if let pai = cfi. fromFunction as? PartialApplyInst {
221+ log ( " getPartialApplyOfPullbackInExitVJPBB: success " )
222+ return pai
223+ }
224+ log ( " getPartialApplyOfPullbackInExitVJPBB: fromFunction operand of convert_function instruction is not defined by partial_apply instruction, aborting " )
225+ return nil
226+ }
227+ log ( " getPartialApplyOfPullbackInExitVJPBB: unexpected instruction type, aborting " )
228+ return nil
229+ }
230+
231+ if let ti = retValDefiningInstr as? TupleInst {
232+ log ( " getPartialApplyOfPullbackInExitVJPBB: return value is defined by tuple instruction " )
233+ if ti. operands. count < 2 {
234+ log ( " getPartialApplyOfPullbackInExitVJPBB: tuple instruction has \( ti. operands. count) operands, but at least 2 expected, aborting " )
235+ return nil
236+ }
237+ guard let lastTupleElemDefiningInst = ti. operands. last!. value. definingInstruction else {
238+ log ( " getPartialApplyOfPullbackInExitVJPBB: last tuple element is not defined by an instruction, aborting " )
239+ return nil
240+ }
241+ return handleConvertFunctionOrPartialApply ( inst: lastTupleElemDefiningInst)
242+ }
243+
244+ return handleConvertFunctionOrPartialApply ( inst: retValDefiningInstr)
245+ }
246+
247+ typealias EnumTypeAndCase = ( enumType: Type , caseIdx: Int )
248+
249+ typealias ClosureInfoMultiBB = (
250+ closure: SingleValueInstruction ,
251+ capturedArgs: [ Value ] ,
252+ subsetThunk: PartialApplyInst ? ,
253+ payloadTuple: TupleInst ,
254+ idxInPayload: Int ,
255+ enumTypeAndCase: EnumTypeAndCase
256+ )
257+
258+ private func getPullbackClosureInfoMultiBB( in vjp: Function , _ context: FunctionPassContext )
259+ -> PullbackClosureInfo
260+ {
261+ let paiOfPbInExitVjpBB = getPartialApplyOfPullbackInExitVJPBB ( vjp: vjp) !
262+ var pullbackClosureInfo = PullbackClosureInfo ( paiOfPullback: paiOfPbInExitVjpBB)
263+ var subsetThunkArr = [ SingleValueInstruction] ( )
264+
265+ for inst in vjp. instructions {
266+ if inst == paiOfPbInExitVjpBB {
267+ continue
268+ }
269+ if inst. asSupportedClosure == nil {
270+ continue
271+ }
272+
273+ let rootClosure = inst. asSupportedClosure!
274+ if subsetThunkArr. contains ( rootClosure) {
275+ continue
276+ }
277+
278+ let closureInfoArr = handleNonAppliesMultiBB ( for: rootClosure, context)
279+ pullbackClosureInfo. closureInfosMultiBB. append ( contentsOf: closureInfoArr)
280+ subsetThunkArr. append (
281+ contentsOf: closureInfoArr. filter { $0. subsetThunk != nil } . map { $0. subsetThunk! } )
282+ }
283+
284+ return pullbackClosureInfo
285+ }
286+
178287private func getPullbackClosureInfo( in caller: Function , _ context: FunctionPassContext )
179288 -> PullbackClosureInfo ?
180289{
@@ -374,6 +483,144 @@ private func updatePullbackClosureInfo(
374483 intermediateClosureArgDescriptorData: intermediateClosureArgDescriptorData, context)
375484}
376485
486+ typealias BTEPayloadArgOfPbBBWithBTETypeAndCase = ( arg: Argument , enumTypeAndCase: EnumTypeAndCase )
487+
488+ // If the pullback's basic block has an argument which is a payload tuple of the
489+ // branch tracing enum corresponding to the given VJP, return this argument and any valid combination
490+ // of a branch tracing enum type and its case index having the same payload tuple type as the argument.
491+ // The function assumes that no more than one such argument is present.
492+ private func getBTEPayloadArgOfPbBBWithBTETypeAndCase( _ bb: BasicBlock , vjp: Function )
493+ -> BTEPayloadArgOfPbBBWithBTETypeAndCase ?
494+ {
495+ log ( " getBTEPayloadArgOfPbBBWithBTETypeAndCase: basic block \( bb. shortDescription) in pullback \( bb. parentFunction. name) " )
496+ guard let predBB = bb. predecessors. first else {
497+ log ( " getBTEPayloadArgOfPbBBWithBTETypeAndCase: the bb has no predecessors, aborting " )
498+ return nil
499+ }
500+
501+ log ( " getBTEPayloadArgOfPbBBWithBTETypeAndCase: start iterating over bb args " )
502+ for arg in bb. arguments {
503+ log ( " getBTEPayloadArgOfPbBBWithBTETypeAndCase: \( arg) " )
504+ if !arg. type. isTuple {
505+ log ( " getBTEPayloadArgOfPbBBWithBTETypeAndCase: arg is not a tuple, skipping " )
506+ continue
507+ }
508+
509+ if let bi = predBB. terminator as? BranchInst {
510+ log ( " getBTEPayloadArgOfPbBBWithBTETypeAndCase: terminator of pred bb is branch instruction " )
511+ guard let uedi = bi. operands [ arg. index] . value. definingInstruction as? UncheckedEnumDataInst
512+ else {
513+ log ( " getBTEPayloadArgOfPbBBWithBTETypeAndCase: operand corresponding to the argument is not defined by unchecked_enum_data instruction " )
514+ continue
515+ }
516+ let enumType = uedi. `enum`. type
517+ if !enumType. isBranchTracingEnumIn ( vjp: vjp) {
518+ log ( " getBTEPayloadArgOfPbBBWithBTETypeAndCase: enum type \( enumType) is not a branch tracing enum in VJP \( vjp. name) " )
519+ continue
520+ }
521+
522+ log ( " getBTEPayloadArgOfPbBBWithBTETypeAndCase: success " )
523+ return BTEPayloadArgOfPbBBWithBTETypeAndCase (
524+ arg: arg,
525+ enumTypeAndCase: (
526+ enumType: enumType,
527+ caseIdx: uedi. caseIndex
528+ )
529+ )
530+ }
531+
532+ if let sei = predBB. terminator as? SwitchEnumInst {
533+ log ( " getBTEPayloadArgOfPbBBWithBTETypeAndCase: terminator of pred bb is switch_enum instruction " )
534+ let enumType = sei. enumOp. type
535+ if !enumType. isBranchTracingEnumIn ( vjp: vjp) {
536+ log ( " getBTEPayloadArgOfPbBBWithBTETypeAndCase: enum type \( enumType) is not a branch tracing enum in VJP \( vjp. name) " )
537+ continue
538+ }
539+
540+ log ( " getBTEPayloadArgOfPbBBWithBTETypeAndCase: success " )
541+ return BTEPayloadArgOfPbBBWithBTETypeAndCase (
542+ arg: arg,
543+ enumTypeAndCase: (
544+ enumType: enumType,
545+ caseIdx: sei. getUniqueCase ( forSuccessor: bb) !
546+ )
547+ )
548+ }
549+ }
550+
551+ log ( " getBTEPayloadArgOfPbBBWithBTETypeAndCase: finish iterating over bb args; branch tracing enum arg not found " )
552+ return nil
553+ }
554+
555+ extension PartialApplyInst {
556+ func isSubsetThunk( ) -> Bool {
557+ if self . argumentOperands. singleElement == nil {
558+ return false
559+ }
560+ guard let function = self . referencedFunction else {
561+ return false
562+ }
563+ return function. bridged. isAutodiffSubsetParametersThunk ( )
564+ }
565+ }
566+
567+ private func handleNonAppliesMultiBB(
568+ for rootClosure: SingleValueInstruction ,
569+ _ context: FunctionPassContext
570+ )
571+ -> [ ClosureInfoMultiBB ]
572+ {
573+ log ( " handleNonAppliesMultiBB: running for \( rootClosure) " )
574+ let vjp = rootClosure. parentFunction
575+ var closureInfoArr = [ ClosureInfoMultiBB] ( )
576+
577+ var closure = rootClosure
578+ var subsetThunk = PartialApplyInst ? ( nil )
579+ if rootClosure. uses. singleElement != nil {
580+ if let pai = closure. uses. singleElement!. instruction as? PartialApplyInst {
581+ if pai. isSubsetThunk ( ) {
582+ log ( " handleNonAppliesMultiBB: found subset thunk \( pai) " )
583+ subsetThunk = pai
584+ closure = pai
585+ }
586+ }
587+ }
588+
589+ for use in closure. uses {
590+ guard let ti = use. instruction as? TupleInst else {
591+ log ( " handleNonAppliesMultiBB: unexpected use of closure, aborting: \( use) " )
592+ return [ ]
593+ }
594+ for tiUse in ti. uses {
595+ guard let ei = tiUse. instruction as? EnumInst else {
596+ log ( " handleNonAppliesMultiBB: unexpected use of payload tuple, aborting: \( tiUse) " )
597+ return [ ]
598+ }
599+ if !ei. type. isBranchTracingEnumIn ( vjp: vjp) {
600+ log ( " handleNonAppliesMultiBB: enum type \( ei. type) is not a branch tracing enum in VJP \( vjp. name) , aborting " )
601+ return [ ]
602+ }
603+ var capturedArgs = [ Value] ( )
604+ if let pai = rootClosure as? PartialApplyInst {
605+ capturedArgs = pai. argumentOperands. map { $0. value }
606+ }
607+ log ( " handleNonAppliesMultiBB: creating closure info with enum type \( ei. type) , case index \( ei. caseIndex) , index in payload tuple \( use. index) and payload tuple \( ti) " )
608+ let enumTypeAndCase = EnumTypeAndCase ( enumType: ei. type, caseIdx: ei. caseIndex)
609+ closureInfoArr. append (
610+ ClosureInfoMultiBB (
611+ closure: rootClosure,
612+ capturedArgs: capturedArgs,
613+ subsetThunk: subsetThunk,
614+ payloadTuple: ti,
615+ idxInPayload: use. index,
616+ enumTypeAndCase: enumTypeAndCase
617+ ) )
618+ }
619+ }
620+ log ( " handleNonAppliesMultiBB: created \( closureInfoArr. count) closure info entries for \( rootClosure) " )
621+ return closureInfoArr
622+ }
623+
377624/// Handles all non-apply direct and transitive uses of `rootClosure`.
378625///
379626/// Returns:
@@ -1391,6 +1638,7 @@ private struct ClosureArgDescriptor {
13911638private struct PullbackClosureInfo {
13921639 let paiOfPullback : PartialApplyInst
13931640 var closureArgDescriptors : [ ClosureArgDescriptor ] = [ ]
1641+ var closureInfosMultiBB : [ ClosureInfoMultiBB ] = [ ]
13941642
13951643 init ( paiOfPullback: PartialApplyInst ) {
13961644 self . paiOfPullback = paiOfPullback
@@ -1474,3 +1722,101 @@ let rewrittenCallerBodyTest = FunctionTest("autodiff_closure_specialize_rewritte
14741722 print ( " Rewritten caller body for: \( function. name) : " )
14751723 print ( " \( function) \n " )
14761724}
1725+
1726+ let getPullbackClosureInfoMultiBBTest = FunctionTest (
1727+ " autodiff_closure_specialize_get_pullback_closure_info_multi_bb "
1728+ ) {
1729+ function, arguments, context in
1730+ let pullbackClosureInfo = getPullbackClosureInfoMultiBB ( in: function, context)
1731+ print ( " Run getPullbackClosureInfoMultiBB for VJP \( function. name) : pullbackClosureInfo = ( " )
1732+ print ( " pullbackFn = \( pullbackClosureInfo. pullbackFn. name) " )
1733+ print ( " closureInfosMultiBB = [ " )
1734+ for closureInfoMultiBB in pullbackClosureInfo. closureInfosMultiBB {
1735+ print ( " ClosureInfoMultiBB( " )
1736+ print ( " closure: \( closureInfoMultiBB. closure) " )
1737+ print ( " capturedArgs: [ " )
1738+ for capturedArg in closureInfoMultiBB. capturedArgs {
1739+ print ( " \( capturedArg) " )
1740+ }
1741+ print ( " ] " )
1742+ let subsetThunkStr =
1743+ ( closureInfoMultiBB. subsetThunk == nil ? " nil " : " \( closureInfoMultiBB. subsetThunk!) " )
1744+ print ( " subsetThunk: \( subsetThunkStr) " )
1745+ print ( " payloadTuple: \( closureInfoMultiBB. payloadTuple) " )
1746+ print ( " idxInPayload: \( closureInfoMultiBB. idxInPayload) " )
1747+ print ( " enumTypeAndCase: \( closureInfoMultiBB. enumTypeAndCase) " )
1748+ print ( " ) " )
1749+ }
1750+ print ( " ] \n ) \n " )
1751+ }
1752+
1753+ typealias SpecBTEDict = SpecializedBranchTracingEnumDict
1754+
1755+ func getSpecBTEDict( vjp: Function , context: FunctionPassContext ) -> SpecBTEDict {
1756+ let pullbackClosureInfo = getPullbackClosureInfoMultiBB ( in: vjp, context)
1757+ let pb = pullbackClosureInfo. pullbackFn
1758+ let enumTypeOfEntryBBArg = pb. entryBlock. getBranchTracingEnumArg ( vjp: vjp) !. type
1759+
1760+ let vectorOfClosureInfoMultiBB = VectorOfBranchTracingEnumAndClosureInfo (
1761+ pullbackClosureInfo. closureInfosMultiBB. map {
1762+ BranchTracingEnumAndClosureInfo (
1763+ enumType: $0. enumTypeAndCase. enumType. bridged,
1764+ enumCaseIdx: $0. enumTypeAndCase. caseIdx,
1765+ closure: $0. closure. bridged,
1766+ idxInPayload: $0. idxInPayload)
1767+ } )
1768+
1769+ let enumDict = autodiffSpecializeBranchTracingEnums (
1770+ vjp. bridged, enumTypeOfEntryBBArg. bridged, vectorOfClosureInfoMultiBB)
1771+
1772+ return enumDict
1773+ }
1774+
1775+ let specializeBranchTracingEnums = FunctionTest ( " autodiff_specialize_branch_tracing_enums " ) {
1776+ function, arguments, context in
1777+ let enumDict = getSpecBTEDict ( vjp: function, context: context)
1778+ print (
1779+ " Specialized branch tracing enum dict for VJP \( function. name) contains \( enumDict. size ( ) ) elements: "
1780+ )
1781+ print (
1782+ " \( String ( taking: getSpecializedBranchTracingEnumDictAsString ( enumDict) ) ) " ) //, function.bridged)))")
1783+ }
1784+
1785+ let specializeBTEArgInVjpBB = FunctionTest ( " autodiff_specialize_bte_arg_in_vjp_bb " ) {
1786+ function, arguments, context in
1787+ let enumDict = getSpecBTEDict ( vjp: function, context: context)
1788+ print ( " Specialized BTE arguments of basic blocks in VJP \( function. name) : " )
1789+ for bb in function. blocks {
1790+ guard let arg = bb. getBranchTracingEnumArg ( vjp: function) else {
1791+ continue
1792+ }
1793+ let newArg = specializeBranchTracingEnumBBArgInVJP ( arg. bridged, enumDict) . argument
1794+ print ( " \( newArg) " )
1795+ bb. eraseArgument ( at: newArg. index, context)
1796+ }
1797+ print ( " " )
1798+ }
1799+
1800+ let specializePayloadArgInPullbackBB = FunctionTest ( " autodiff_specialize_payload_arg_in_pb_bb " ) {
1801+ function, arguments, context in
1802+ let pullbackClosureInfo = getPullbackClosureInfoMultiBB ( in: function, context)
1803+ let pb = pullbackClosureInfo. pullbackFn
1804+ let enumDict = getSpecBTEDict ( vjp: function, context: context)
1805+
1806+ print ( " Specialized BTE payload arguments of basic blocks in pullback \( pb. name) : " )
1807+ for bb in pb. blocks {
1808+ guard
1809+ let ( arg, enumTypeAndCase) = getBTEPayloadArgOfPbBBWithBTETypeAndCase ( bb, vjp: function)
1810+ else {
1811+ continue
1812+ }
1813+
1814+ let enumType = enumDict [ enumTypeAndCase. enumType. bridged] !
1815+ let newArg = specializePayloadTupleBBArgInPullback (
1816+ arg. bridged, enumType, enumTypeAndCase. caseIdx
1817+ ) . argument
1818+ print ( " \( newArg) " )
1819+ bb. eraseArgument ( at: newArg. index, context)
1820+ }
1821+ print ( " " )
1822+ }
0 commit comments