From 811ff03a885aadaf1267a29ffcf1db5b64045547 Mon Sep 17 00:00:00 2001 From: Daniil Kovalev Date: Wed, 20 Aug 2025 13:29:41 +0300 Subject: [PATCH] [AutoDiff][gardening] Auto-format ClosureSpecialization.swift Auto-formatting is done as a prerequisite for the ongoing series of patches resolving #68944. --- .../ClosureSpecialization.swift | 804 ++++++++++-------- 1 file changed, 459 insertions(+), 345 deletions(-) diff --git a/SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift b/SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift index 8559e8365680a..fce5c35365d19 100644 --- a/SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift +++ b/SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift @@ -25,43 +25,43 @@ /// The compiler performs reverse-mode differentiation on functions marked with `@differentiable(reverse)`. In doing so, /// it generates corresponding VJP and Pullback functions, which perform the forward and reverse pass respectively. You /// can think of VJPs as functions that "differentiate" an original function and Pullbacks as the calculated -/// "derivative" of the original function. -/// -/// VJPs always return a tuple of 2 values -- the original result and the Pullback. Pullbacks are essentially a chain +/// "derivative" of the original function. +/// +/// VJPs always return a tuple of 2 values -- the original result and the Pullback. Pullbacks are essentially a chain /// of closures, where the closure-contexts are implicitly used as the so-called "tape" during the reverse /// differentiation process. It is this chain of closures contained within the Pullbacks that this optimization aims /// to optimize via closure specialization. /// /// The code patterns that this optimization targets, look similar to the one below: /// ``` swift -/// +/// /// // Since `foo` is marked with the `differentiable(reverse)` attribute the compiler /// // will generate corresponding VJP and Pullback functions in SIL. Let's assume that /// // these functions are called `vjp_foo` and `pb_foo` respectively. -/// @differentiable(reverse) -/// func foo(_ x: Float) -> Float { +/// @differentiable(reverse) +/// func foo(_ x: Float) -> Float { /// return sin(x) /// } /// -/// //============== Before closure specialization ==============// +/// //============== Before closure specialization ==============// /// // VJP of `foo`. Returns the original result and the Pullback of `foo`. -/// sil @vjp_foo: $(Float) -> (originalResult: Float, pullback: (Float) -> Float) { -/// bb0(%0: $Float): -/// // __Inlined__ `vjp_sin`: It is important for all intermediate VJPs to have +/// sil @vjp_foo: $(Float) -> (originalResult: Float, pullback: (Float) -> Float) { +/// bb0(%0: $Float): +/// // __Inlined__ `vjp_sin`: It is important for all intermediate VJPs to have /// // been inlined in `vjp_foo`, otherwise `vjp_foo` will not be able to determine /// // that `pb_foo` is closing over other closures and no specialization will happen. -/// \ +/// \ /// %originalResult = apply @sin(%0): $(Float) -> Float \__ Inlined `vjp_sin` /// %partially_applied_pb_sin = partial_apply pb_sin(%0): $(Float) -> Float / -/// / +/// / /// /// %pb_foo = function_ref @pb_foo: $@convention(thin) (Float, (Float) -> Float) -> Float /// %partially_applied_pb_foo = partial_apply %pb_foo(%partially_applied_pb_sin): $(Float, (Float) -> Float) -> Float -/// +/// /// return (%originalResult, %partially_applied_pb_foo) /// } /// -/// // Pullback of `foo`. +/// // Pullback of `foo`. /// // /// // It receives what are called as intermediate closures that represent /// // the calculations that the Pullback needs to perform to calculate a function's @@ -70,31 +70,31 @@ /// // The intermediate closures may themselves contain intermediate closures and /// // that is why the Pullback for a function differentiated at the "top" level /// // may end up being a "chain" of closures. -/// sil @pb_foo: $(Float, (Float) -> Float) -> Float { -/// bb0(%0: $Float, %pb_sin: $(Float) -> Float): -/// %derivative_of_sin = apply %pb_sin(%0): $(Float) -> Float +/// sil @pb_foo: $(Float, (Float) -> Float) -> Float { +/// bb0(%0: $Float, %pb_sin: $(Float) -> Float): +/// %derivative_of_sin = apply %pb_sin(%0): $(Float) -> Float /// return %derivative_of_sin: Float /// } /// -/// //============== After closure specialization ==============// -/// sil @vjp_foo: $(Float) -> (originalResult: Float, pullback: (Float) -> Float) { -/// bb0(%0: $Float): -/// %originalResult = apply @sin(%0): $(Float) -> Float -/// +/// //============== After closure specialization ==============// +/// sil @vjp_foo: $(Float) -> (originalResult: Float, pullback: (Float) -> Float) { +/// bb0(%0: $Float): +/// %originalResult = apply @sin(%0): $(Float) -> Float +/// /// // Before the optimization, pullback of `foo` used to take a closure for computing /// // pullback of `sin`. Now, the specialized pullback of `foo` takes the arguments that /// // pullback of `sin` used to close over and pullback of `sin` is instead copied over /// // inside pullback of `foo`. /// %specialized_pb_foo = function_ref @specialized_pb_foo: $@convention(thin) (Float, Float) -> Float -/// %partially_applied_pb_foo = partial_apply %specialized_pb_foo(%0): $(Float, Float) -> Float -/// +/// %partially_applied_pb_foo = partial_apply %specialized_pb_foo(%0): $(Float, Float) -> Float +/// /// return (%originalResult, %partially_applied_pb_foo) /// } -/// -/// sil @specialized_pb_foo: $(Float, Float) -> Float { -/// bb0(%0: $Float, %1: $Float): -/// %2 = partial_apply @pb_sin(%1): $(Float) -> Float -/// %3 = apply %2(): $() -> Float +/// +/// sil @specialized_pb_foo: $(Float, Float) -> Float { +/// bb0(%0: $Float, %1: $Float): +/// %2 = partial_apply @pb_sin(%1): $(Float) -> Float +/// %3 = apply %2(): $() -> Float /// return %3: $Float /// } /// ``` @@ -112,7 +112,9 @@ private func log(prefix: Bool = true, _ message: @autoclosure () -> String) { } // =========== Entry point =========== // -let generalClosureSpecialization = FunctionPass(name: "experimental-swift-based-closure-specialization") { +let generalClosureSpecialization = FunctionPass( + name: "experimental-swift-based-closure-specialization" +) { (function: Function, context: FunctionPassContext) in // TODO: Implement general closure specialization optimization print("NOT IMPLEMENTED") @@ -122,10 +124,11 @@ let autodiffClosureSpecialization = FunctionPass(name: "autodiff-closure-special (function: Function, context: FunctionPassContext) in guard !function.isDefinedExternally, - function.isAutodiffVJP else { + function.isAutodiffVJP + else { return } - + var remainingSpecializationRounds = 5 repeat { @@ -133,13 +136,16 @@ let autodiffClosureSpecialization = FunctionPass(name: "autodiff-closure-special break } - var (specializedFunction, alreadyExists) = getOrCreateSpecializedFunction(basedOn: pullbackClosureInfo, context) + var (specializedFunction, alreadyExists) = getOrCreateSpecializedFunction( + basedOn: pullbackClosureInfo, context) if !alreadyExists { - context.notifyNewFunction(function: specializedFunction, derivedFrom: pullbackClosureInfo.pullbackFn) + context.notifyNewFunction( + function: specializedFunction, derivedFrom: pullbackClosureInfo.pullbackFn) } - rewriteApplyInstruction(using: specializedFunction, pullbackClosureInfo: pullbackClosureInfo, context) + rewriteApplyInstruction( + using: specializedFunction, pullbackClosureInfo: pullbackClosureInfo, context) var deadClosures = InstructionWorklist(context) pullbackClosureInfo.closureArgDescriptors @@ -169,28 +175,30 @@ let autodiffClosureSpecialization = FunctionPass(name: "autodiff-closure-special private let specializationLevelLimit = 2 -private func getPullbackClosureInfo(in caller: Function, _ context: FunctionPassContext) -> PullbackClosureInfo? { +private func getPullbackClosureInfo(in caller: Function, _ context: FunctionPassContext) + -> PullbackClosureInfo? +{ /// __Root__ closures created via `partial_apply` or `thin_to_thick_function` may be converted and reabstracted /// before finally being used at an apply site. We do not want to handle these intermediate closures separately - /// as they are handled and cloned into the specialized function as part of the root closures. Therefore, we keep - /// track of these intermediate closures in a set. - /// + /// as they are handled and cloned into the specialized function as part of the root closures. Therefore, we keep + /// track of these intermediate closures in a set. + /// /// This set is populated via the `markConvertedAndReabstractedClosuresAsUsed` function which is called when we're /// handling the different uses of our root closures. /// /// Below SIL example illustrates the above point. - /// ``` + /// ``` /// // The below set of a "root" closure and its reabstractions/conversions /// // will be handled as a unit and the entire set will be copied over - /// // in the specialized version of `takesClosure` if we determine that we + /// // in the specialized version of `takesClosure` if we determine that we /// // can specialize `takesClosure` against its closure argument. - /// __ - /// %someFunction = function_ref @someFunction: $@convention(thin) (Int, Int) -> Int \ + /// __ + /// %someFunction = function_ref @someFunction: $@convention(thin) (Int, Int) -> Int \ /// %rootClosure = partial_apply [callee_guaranteed] %someFunction (%someInt): $(Int, Int) -> Int \ - /// %thunk = function_ref @reabstractionThunk : $@convention(thin) (@callee_guaranteed (Int) -> Int) -> @out Int / - /// %reabstractedClosure = partial_apply [callee_guaranteed] %thunk(%rootClosure) : / - /// $@convention(thin) (@callee_guaranteed (Int) -> Int) -> @out Int __/ - /// + /// %thunk = function_ref @reabstractionThunk : $@convention(thin) (@callee_guaranteed (Int) -> Int) -> @out Int / + /// %reabstractedClosure = partial_apply [callee_guaranteed] %thunk(%rootClosure) : / + /// $@convention(thin) (@callee_guaranteed (Int) -> Int) -> @out Int __/ + /// /// %takesClosure = function_ref @takesClosure : $@convention(thin) (@owned @callee_guaranteed (Int) -> @out Int) -> Int /// %result = partial_apply %takesClosure(%reabstractedClosure) : $@convention(thin) (@owned @callee_guaranteed () -> @out Int) -> Int /// ret %result @@ -205,17 +213,20 @@ private func getPullbackClosureInfo(in caller: Function, _ context: FunctionPass for inst in caller.instructions { if !convertedAndReabstractedClosures.contains(inst), - let rootClosure = inst.asSupportedClosure + let rootClosure = inst.asSupportedClosure { - updatePullbackClosureInfo(for: rootClosure, in: &pullbackClosureInfoOpt, - convertedAndReabstractedClosures: &convertedAndReabstractedClosures, context) + updatePullbackClosureInfo( + for: rootClosure, in: &pullbackClosureInfoOpt, + convertedAndReabstractedClosures: &convertedAndReabstractedClosures, context) } } return pullbackClosureInfoOpt } -private func getOrCreateSpecializedFunction(basedOn pullbackClosureInfo: PullbackClosureInfo, _ context: FunctionPassContext) +private func getOrCreateSpecializedFunction( + basedOn pullbackClosureInfo: PullbackClosureInfo, _ context: FunctionPassContext +) -> (function: Function, alreadyExists: Bool) { let specializedFunctionName = pullbackClosureInfo.specializedCalleeName(context) @@ -224,34 +235,40 @@ private func getOrCreateSpecializedFunction(basedOn pullbackClosureInfo: Pullbac } let pullbackFn = pullbackClosureInfo.pullbackFn - let specializedParameters = pullbackFn.convention.getSpecializedParameters(basedOn: pullbackClosureInfo) - - let specializedFunction = - context.createSpecializedFunctionDeclaration(from: pullbackFn, withName: specializedFunctionName, - withParams: specializedParameters, - makeThin: true, makeBare: true) - - context.buildSpecializedFunction(specializedFunction: specializedFunction, - buildFn: { (emptySpecializedFunction, functionPassContext) in - let closureSpecCloner = SpecializationCloner(emptySpecializedFunction: emptySpecializedFunction, functionPassContext) - closureSpecCloner.cloneAndSpecializeFunctionBody(using: pullbackClosureInfo) - }) + let specializedParameters = pullbackFn.convention.getSpecializedParameters( + basedOn: pullbackClosureInfo) + + let specializedFunction = + context.createSpecializedFunctionDeclaration( + from: pullbackFn, withName: specializedFunctionName, + withParams: specializedParameters, + makeThin: true, makeBare: true) + + context.buildSpecializedFunction( + specializedFunction: specializedFunction, + buildFn: { (emptySpecializedFunction, functionPassContext) in + let closureSpecCloner = SpecializationCloner( + emptySpecializedFunction: emptySpecializedFunction, functionPassContext) + closureSpecCloner.cloneAndSpecializeFunctionBody(using: pullbackClosureInfo) + }) return (specializedFunction, false) } -private func rewriteApplyInstruction(using specializedCallee: Function, pullbackClosureInfo: PullbackClosureInfo, - _ context: FunctionPassContext) { +private func rewriteApplyInstruction( + using specializedCallee: Function, pullbackClosureInfo: PullbackClosureInfo, + _ context: FunctionPassContext +) { let newApplyArgs = pullbackClosureInfo.getArgumentsForSpecializedApply(of: specializedCallee) for newApplyArg in newApplyArgs { if case let .PreviouslyCaptured(capturedArg, needsRetain, parentClosureArgIndex) = newApplyArg, - needsRetain + needsRetain { let closureArgDesc = pullbackClosureInfo.closureArgDesc(at: parentClosureArgIndex)! var builder = Builder(before: closureArgDesc.closure, context) - // TODO: Support only OSSA instructions once the OSSA elimination pass is moved after all function optimization + // TODO: Support only OSSA instructions once the OSSA elimination pass is moved after all function optimization // passes. if pullbackClosureInfo.paiOfPullback.parentBlock != closureArgDesc.closure.parentBlock { // Emit the retain and release that keeps the argument live across the callee using the closure. @@ -278,18 +295,19 @@ private func rewriteApplyInstruction(using specializedCallee: Function, pullback let funcRef = builder.createFunctionRef(specializedCallee) let capturedArgs = Array(newApplyArgs.map { $0.value }) - let newPartialApply = builder.createPartialApply(function: funcRef, substitutionMap: SubstitutionMap(), - capturedArguments: capturedArgs, calleeConvention: oldPartialApply.calleeConvention, - hasUnknownResultIsolation: oldPartialApply.hasUnknownResultIsolation, - isOnStack: oldPartialApply.isOnStack) + let newPartialApply = builder.createPartialApply( + function: funcRef, substitutionMap: SubstitutionMap(), + capturedArguments: capturedArgs, calleeConvention: oldPartialApply.calleeConvention, + hasUnknownResultIsolation: oldPartialApply.hasUnknownResultIsolation, + isOnStack: oldPartialApply.isOnStack) builder = Builder(before: pullbackClosureInfo.paiOfPullback.next!, context) - // TODO: Support only OSSA instructions once the OSSA elimination pass is moved after all function optimization + // TODO: Support only OSSA instructions once the OSSA elimination pass is moved after all function optimization // passes. for closureArgDesc in pullbackClosureInfo.closureArgDescriptors { if closureArgDesc.isClosureConsumed, - !closureArgDesc.isPartialApplyOnStack, - !closureArgDesc.parameterInfo.isTrivialNoescapeClosure + !closureArgDesc.isPartialApplyOnStack, + !closureArgDesc.parameterInfo.isTrivialNoescapeClosure { builder.createReleaseValue(operand: closureArgDesc.closure) } @@ -300,66 +318,73 @@ private func rewriteApplyInstruction(using specializedCallee: Function, pullback // ===================== Utility functions and extensions ===================== // -private func updatePullbackClosureInfo(for rootClosure: SingleValueInstruction, in pullbackClosureInfoOpt: inout PullbackClosureInfo?, - convertedAndReabstractedClosures: inout InstructionSet, _ context: FunctionPassContext) { +private func updatePullbackClosureInfo( + for rootClosure: SingleValueInstruction, in pullbackClosureInfoOpt: inout PullbackClosureInfo?, + convertedAndReabstractedClosures: inout InstructionSet, _ context: FunctionPassContext +) { var rootClosurePossibleLiveRange = InstructionRange(begin: rootClosure, context) defer { rootClosurePossibleLiveRange.deinitialize() } - var rootClosureApplies = OperandWorklist(context) + var rootClosureApplies = OperandWorklist(context) defer { rootClosureApplies.deinitialize() } // A "root" closure undergoing conversions and/or reabstractions has additional restrictions placed upon it, in order // for a pullback to be specialized against it. We handle conversion/reabstraction uses before we handle apply uses - // to gather the parameters required to evaluate these restrictions or to skip pullback's uses of "unsupported" + // to gather the parameters required to evaluate these restrictions or to skip pullback's uses of "unsupported" // closures altogether. // - // There are currently 2 restrictions that are evaluated prior to specializing a pullback against a converted and/or + // There are currently 2 restrictions that are evaluated prior to specializing a pullback against a converted and/or // reabstracted closure - // 1. A reabstracted root closure can only be specialized against, if the reabstracted closure is ultimately passed // trivially (as a noescape+thick function) as captured argument of pullback's partial_apply. // - // 2. A root closure may be a partial_apply [stack], in which case we need to make sure that all mark_dependence + // 2. A root closure may be a partial_apply [stack], in which case we need to make sure that all mark_dependence // bases for it will be available in the specialized callee in case the pullback is specialized against this root // closure. - let (foundUnexpectedUse, haveUsedReabstraction) = - handleNonApplies(for: rootClosure, rootClosureApplies: &rootClosureApplies, - rootClosurePossibleLiveRange: &rootClosurePossibleLiveRange, context); - + let (foundUnexpectedUse, haveUsedReabstraction) = + handleNonApplies( + for: rootClosure, rootClosureApplies: &rootClosureApplies, + rootClosurePossibleLiveRange: &rootClosurePossibleLiveRange, context) if foundUnexpectedUse { return } - let intermediateClosureArgDescriptorData = - handleApplies(for: rootClosure, pullbackClosureInfoOpt: &pullbackClosureInfoOpt, rootClosureApplies: &rootClosureApplies, - rootClosurePossibleLiveRange: &rootClosurePossibleLiveRange, - convertedAndReabstractedClosures: &convertedAndReabstractedClosures, - haveUsedReabstraction: haveUsedReabstraction, context) + let intermediateClosureArgDescriptorData = + handleApplies( + for: rootClosure, pullbackClosureInfoOpt: &pullbackClosureInfoOpt, + rootClosureApplies: &rootClosureApplies, + rootClosurePossibleLiveRange: &rootClosurePossibleLiveRange, + convertedAndReabstractedClosures: &convertedAndReabstractedClosures, + haveUsedReabstraction: haveUsedReabstraction, context) if pullbackClosureInfoOpt == nil { return } - finalizePullbackClosureInfo(for: rootClosure, in: &pullbackClosureInfoOpt, - rootClosurePossibleLiveRange: rootClosurePossibleLiveRange, - intermediateClosureArgDescriptorData: intermediateClosureArgDescriptorData, context) + finalizePullbackClosureInfo( + for: rootClosure, in: &pullbackClosureInfoOpt, + rootClosurePossibleLiveRange: rootClosurePossibleLiveRange, + intermediateClosureArgDescriptorData: intermediateClosureArgDescriptorData, context) } /// Handles all non-apply direct and transitive uses of `rootClosure`. /// -/// Returns: -/// haveUsedReabstraction - whether the root closure is reabstracted via a thunk +/// Returns: +/// haveUsedReabstraction - whether the root closure is reabstracted via a thunk /// foundUnexpectedUse - whether the root closure is directly or transitively used in an instruction that we don't know /// how to handle. If true, then `rootClosure` should not be specialized against. -private func handleNonApplies(for rootClosure: SingleValueInstruction, - rootClosureApplies: inout OperandWorklist, - rootClosurePossibleLiveRange: inout InstructionRange, - _ context: FunctionPassContext) +private func handleNonApplies( + for rootClosure: SingleValueInstruction, + rootClosureApplies: inout OperandWorklist, + rootClosurePossibleLiveRange: inout InstructionRange, + _ context: FunctionPassContext +) -> (foundUnexpectedUse: Bool, haveUsedReabstraction: Bool) { var foundUnexpectedUse = false @@ -392,7 +417,7 @@ private func handleNonApplies(for rootClosure: SingleValueInstruction, possibleMarkDependenceBases.deinitialize() } - var rootClosureConversionsAndReabstractions = OperandWorklist(context) + var rootClosureConversionsAndReabstractions = OperandWorklist(context) rootClosureConversionsAndReabstractions.pushIfNotVisited(contentsOf: rootClosure.uses) defer { rootClosureConversionsAndReabstractions.deinitialize() @@ -403,7 +428,7 @@ private func handleNonApplies(for rootClosure: SingleValueInstruction, possibleMarkDependenceBases.insert(arg) } } - + while let use = rootClosureConversionsAndReabstractions.pop() { switch use.instruction { case let cfi as ConvertFunctionInst: @@ -418,10 +443,10 @@ private func handleNonApplies(for rootClosure: SingleValueInstruction, case let pai as PartialApplyInst: if !pai.isPullbackInResultOfAutodiffVJP, - pai.isSupportedClosure, - pai.isPartialApplyOfThunk, - // Argument must be a closure - pai.arguments[0].type.isThickFunction + pai.isSupportedClosure, + pai.isPartialApplyOfThunk, + // Argument must be a closure + pai.arguments[0].type.isThickFunction { rootClosureConversionsAndReabstractions.pushIfNotVisited(contentsOf: pai.uses) possibleMarkDependenceBases.insert(pai) @@ -437,27 +462,27 @@ private func handleNonApplies(for rootClosure: SingleValueInstruction, rootClosurePossibleLiveRange.insert(use.instruction) case let mdi as MarkDependenceInst: - if possibleMarkDependenceBases.contains(mdi.base), - mdi.value == use.value, - mdi.value.type.isNoEscapeFunction, - mdi.value.type.isThickFunction + if possibleMarkDependenceBases.contains(mdi.base), + mdi.value == use.value, + mdi.value.type.isNoEscapeFunction, + mdi.value.type.isThickFunction { rootClosureConversionsAndReabstractions.pushIfNotVisited(contentsOf: mdi.uses) rootClosurePossibleLiveRange.insert(use.instruction) } - + case is CopyValueInst, - is DestroyValueInst, - is RetainValueInst, - is ReleaseValueInst, - is StrongRetainInst, - is StrongReleaseInst: + is DestroyValueInst, + is RetainValueInst, + is ReleaseValueInst, + is StrongRetainInst, + is StrongReleaseInst: rootClosurePossibleLiveRange.insert(use.instruction) case let ti as TupleInst: if ti.parentFunction.isAutodiffVJP, - let returnInst = ti.parentFunction.returnInstruction, - ti == returnInst.returnedValue + let returnInst = ti.parentFunction.returnInstruction, + ti == returnInst.returnedValue { // This is the pullback closure returned from an Autodiff VJP and we don't need to handle it. } else { @@ -467,23 +492,26 @@ private func handleNonApplies(for rootClosure: SingleValueInstruction, default: foundUnexpectedUse = true log("Found unexpected direct or transitive user of root closure: \(use.instruction)") - return (foundUnexpectedUse, haveUsedReabstraction) + return (foundUnexpectedUse, haveUsedReabstraction) } } return (foundUnexpectedUse, haveUsedReabstraction) } -private typealias IntermediateClosureArgDescriptorDatum = (applySite: SingleValueInstruction, closureArgIndex: Int, paramInfo: ParameterInfo) - -private func handleApplies(for rootClosure: SingleValueInstruction, pullbackClosureInfoOpt: inout PullbackClosureInfo?, - rootClosureApplies: inout OperandWorklist, - rootClosurePossibleLiveRange: inout InstructionRange, - convertedAndReabstractedClosures: inout InstructionSet, haveUsedReabstraction: Bool, - _ context: FunctionPassContext) -> [IntermediateClosureArgDescriptorDatum] -{ +private typealias IntermediateClosureArgDescriptorDatum = ( + applySite: SingleValueInstruction, closureArgIndex: Int, paramInfo: ParameterInfo +) + +private func handleApplies( + for rootClosure: SingleValueInstruction, pullbackClosureInfoOpt: inout PullbackClosureInfo?, + rootClosureApplies: inout OperandWorklist, + rootClosurePossibleLiveRange: inout InstructionRange, + convertedAndReabstractedClosures: inout InstructionSet, haveUsedReabstraction: Bool, + _ context: FunctionPassContext +) -> [IntermediateClosureArgDescriptorDatum] { var intermediateClosureArgDescriptorData: [IntermediateClosureArgDescriptorDatum] = [] - + while let use = rootClosureApplies.pop() { rootClosurePossibleLiveRange.insert(use.instruction) @@ -493,7 +521,9 @@ private func handleApplies(for rootClosure: SingleValueInstruction, pullbackClos } // TODO: Handling generic closures may be possible but is not yet implemented - if pai.hasSubstitutions || !pai.calleeIsDynamicFunctionRef || !pai.isPullbackInResultOfAutodiffVJP { + if pai.hasSubstitutions || !pai.calleeIsDynamicFunctionRef + || !pai.isPullbackInResultOfAutodiffVJP + { continue } @@ -534,18 +564,19 @@ private func handleApplies(for rootClosure: SingleValueInstruction, pullbackClos continue } - let onlyHaveThinToThickClosure = rootClosure is ThinToThickFunctionInst && !haveUsedReabstraction + let onlyHaveThinToThickClosure = + rootClosure is ThinToThickFunctionInst && !haveUsedReabstraction guard let closureParamInfo = pai.operandConventions[parameter: use.index] else { fatalError("While handling apply uses, parameter info not found for operand: \(use)!") } // If we are going to need to release the copied over closure, we must make sure that we understand all the exit - // blocks, i.e., they terminate with an instruction that clearly indicates whether to release the copied over + // blocks, i.e., they terminate with an instruction that clearly indicates whether to release the copied over // closure or leak it. if closureParamInfo.convention.isGuaranteed, - !onlyHaveThinToThickClosure, - !callee.blocks.allSatisfy({ $0.isReachableExitBlock || $0.terminator is UnreachableInst }) + !onlyHaveThinToThickClosure, + !callee.blocks.allSatisfy({ $0.isReachableExitBlock || $0.terminator is UnreachableInst }) { continue } @@ -565,24 +596,26 @@ private func handleApplies(for rootClosure: SingleValueInstruction, pullbackClos // again by the ClosureSpecializer and so on. This happens if a closure argument is called _and_ referenced in // another closure, which is passed to a recursive call. E.g. // - // func foo(_ c: @escaping () -> ()) { + // func foo(_ c: @escaping () -> ()) { // c() foo({ c() }) // } // // A limit of 2 is good enough and will not be exceed in "regular" optimization scenarios. - let closureCallee = rootClosure is PartialApplyInst - ? (rootClosure as! PartialApplyInst).referencedFunction! - : (rootClosure as! ThinToThickFunctionInst).referencedFunction! + let closureCallee = + rootClosure is PartialApplyInst + ? (rootClosure as! PartialApplyInst).referencedFunction! + : (rootClosure as! ThinToThickFunctionInst).referencedFunction! if closureCallee.specializationLevel > specializationLevelLimit { continue } if haveUsedReabstraction { - markConvertedAndReabstractedClosuresAsUsed(rootClosure: rootClosure, convertedAndReabstractedClosure: use.value, - convertedAndReabstractedClosures: &convertedAndReabstractedClosures) + markConvertedAndReabstractedClosuresAsUsed( + rootClosure: rootClosure, convertedAndReabstractedClosure: use.value, + convertedAndReabstractedClosures: &convertedAndReabstractedClosures) } - + if pullbackClosureInfoOpt == nil { pullbackClosureInfoOpt = PullbackClosureInfo(paiOfPullback: pai) } else { @@ -597,20 +630,26 @@ private func handleApplies(for rootClosure: SingleValueInstruction, pullbackClos } /// Finalizes the pullback closure info for a given root closure by adding a corresponding `ClosureArgDescriptor` -private func finalizePullbackClosureInfo(for rootClosure: SingleValueInstruction, in pullbackClosureInfoOpt: inout PullbackClosureInfo?, - rootClosurePossibleLiveRange: InstructionRange, - intermediateClosureArgDescriptorData: [IntermediateClosureArgDescriptorDatum], - _ context: FunctionPassContext) { +private func finalizePullbackClosureInfo( + for rootClosure: SingleValueInstruction, in pullbackClosureInfoOpt: inout PullbackClosureInfo?, + rootClosurePossibleLiveRange: InstructionRange, + intermediateClosureArgDescriptorData: [IntermediateClosureArgDescriptorDatum], + _ context: FunctionPassContext +) { assert(pullbackClosureInfoOpt != nil) - let closureInfo = ClosureInfo(closure: rootClosure, lifetimeFrontier: Array(rootClosurePossibleLiveRange.ends)) + let closureInfo = ClosureInfo( + closure: rootClosure, lifetimeFrontier: Array(rootClosurePossibleLiveRange.ends)) for (applySite, closureArgumentIndex, parameterInfo) in intermediateClosureArgDescriptorData { if pullbackClosureInfoOpt!.paiOfPullback != applySite { - fatalError("ClosureArgDescriptor's applySite field is not equal to pullback's partial_apply; got \(applySite)!") + fatalError( + "ClosureArgDescriptor's applySite field is not equal to pullback's partial_apply; got \(applySite)!" + ) } - let closureArgDesc = ClosureArgDescriptor(closureInfo: closureInfo, closureArgumentIndex: closureArgumentIndex, - parameterInfo: parameterInfo) + let closureArgDesc = ClosureArgDescriptor( + closureInfo: closureInfo, closureArgumentIndex: closureArgumentIndex, + parameterInfo: parameterInfo) pullbackClosureInfoOpt!.appendClosureArgDescriptor(closureArgDesc) } } @@ -626,9 +665,9 @@ private func isClosureApplied(in callee: Function, closureArgIndex index: Int) - } if let faiCallee = fai.referencedFunction, - !faiCallee.blocks.isEmpty, - handledFuncs.insert(faiCallee).inserted, - handledFuncs.count <= recursionBudget + !faiCallee.blocks.isEmpty, + handledFuncs.insert(faiCallee).inserted, + handledFuncs.count <= recursionBudget { if inner(faiCallee, fai.calleeArgumentIndex(of: use)!, &handledFuncs) { return true @@ -646,58 +685,70 @@ private func isClosureApplied(in callee: Function, closureArgIndex index: Int) - return inner(callee, index, &handledFuncs) } -/// Marks any converted/reabstracted closures, corresponding to a given root closure as used. We do not want to -/// look at such closures separately as during function specialization they will be handled as part of the root closure. -private func markConvertedAndReabstractedClosuresAsUsed(rootClosure: Value, convertedAndReabstractedClosure: Value, - convertedAndReabstractedClosures: inout InstructionSet) -{ +/// Marks any converted/reabstracted closures, corresponding to a given root closure as used. We do not want to +/// look at such closures separately as during function specialization they will be handled as part of the root closure. +private func markConvertedAndReabstractedClosuresAsUsed( + rootClosure: Value, convertedAndReabstractedClosure: Value, + convertedAndReabstractedClosures: inout InstructionSet +) { if convertedAndReabstractedClosure != rootClosure { switch convertedAndReabstractedClosure { case let pai as PartialApplyInst: convertedAndReabstractedClosures.insert(pai) - return - markConvertedAndReabstractedClosuresAsUsed(rootClosure: rootClosure, - convertedAndReabstractedClosure: pai.arguments[0], - convertedAndReabstractedClosures: &convertedAndReabstractedClosures) + return + markConvertedAndReabstractedClosuresAsUsed( + rootClosure: rootClosure, + convertedAndReabstractedClosure: pai.arguments[0], + convertedAndReabstractedClosures: &convertedAndReabstractedClosures) case let cvt as ConvertFunctionInst: convertedAndReabstractedClosures.insert(cvt) - return - markConvertedAndReabstractedClosuresAsUsed(rootClosure: rootClosure, - convertedAndReabstractedClosure: cvt.fromFunction, - convertedAndReabstractedClosures: &convertedAndReabstractedClosures) + return + markConvertedAndReabstractedClosuresAsUsed( + rootClosure: rootClosure, + convertedAndReabstractedClosure: cvt.fromFunction, + convertedAndReabstractedClosures: &convertedAndReabstractedClosures) case let cvt as ConvertEscapeToNoEscapeInst: convertedAndReabstractedClosures.insert(cvt) - return - markConvertedAndReabstractedClosuresAsUsed(rootClosure: rootClosure, - convertedAndReabstractedClosure: cvt.fromFunction, - convertedAndReabstractedClosures: &convertedAndReabstractedClosures) + return + markConvertedAndReabstractedClosuresAsUsed( + rootClosure: rootClosure, + convertedAndReabstractedClosure: cvt.fromFunction, + convertedAndReabstractedClosures: &convertedAndReabstractedClosures) case let mdi as MarkDependenceInst: convertedAndReabstractedClosures.insert(mdi) - return - markConvertedAndReabstractedClosuresAsUsed(rootClosure: rootClosure, convertedAndReabstractedClosure: mdi.value, - convertedAndReabstractedClosures: &convertedAndReabstractedClosures) + return + markConvertedAndReabstractedClosuresAsUsed( + rootClosure: rootClosure, convertedAndReabstractedClosure: mdi.value, + convertedAndReabstractedClosures: &convertedAndReabstractedClosures) default: log("Parent function of pullbackClosureInfo: \(rootClosure.parentFunction)") log("Root closure: \(rootClosure)") log("Converted/reabstracted closure: \(convertedAndReabstractedClosure)") - fatalError("While marking converted/reabstracted closures as used, found unexpected instruction: \(convertedAndReabstractedClosure)") + fatalError( + "While marking converted/reabstracted closures as used, found unexpected instruction: \(convertedAndReabstractedClosure)" + ) } } } -private extension SpecializationCloner { - func cloneAndSpecializeFunctionBody(using pullbackClosureInfo: PullbackClosureInfo) { +extension SpecializationCloner { + fileprivate func cloneAndSpecializeFunctionBody(using pullbackClosureInfo: PullbackClosureInfo) { self.cloneEntryBlockArgsWithoutOrigClosures(usingOrigCalleeAt: pullbackClosureInfo) - let (allSpecializedEntryBlockArgs, closureArgIndexToAllClonedReleasableClosures) = cloneAllClosures(at: pullbackClosureInfo) + let (allSpecializedEntryBlockArgs, closureArgIndexToAllClonedReleasableClosures) = + cloneAllClosures(at: pullbackClosureInfo) - self.cloneFunctionBody(from: pullbackClosureInfo.pullbackFn, entryBlockArguments: allSpecializedEntryBlockArgs) + self.cloneFunctionBody( + from: pullbackClosureInfo.pullbackFn, entryBlockArguments: allSpecializedEntryBlockArgs) self.insertCleanupCodeForClonedReleasableClosures( - from: pullbackClosureInfo, closureArgIndexToAllClonedReleasableClosures: closureArgIndexToAllClonedReleasableClosures) + from: pullbackClosureInfo, + closureArgIndexToAllClonedReleasableClosures: closureArgIndexToAllClonedReleasableClosures) } - private func cloneEntryBlockArgsWithoutOrigClosures(usingOrigCalleeAt pullbackClosureInfo: PullbackClosureInfo) { + private func cloneEntryBlockArgsWithoutOrigClosures( + usingOrigCalleeAt pullbackClosureInfo: PullbackClosureInfo + ) { let originalEntryBlock = pullbackClosureInfo.pullbackFn.entryBlock let clonedFunction = self.cloned let clonedEntryBlock = self.entryBlock @@ -707,7 +758,8 @@ private extension SpecializationCloner { .filter { index, _ in !pullbackClosureInfo.hasClosureArg(at: index) } .forEach { _, arg in let clonedEntryBlockArgType = arg.type.getLoweredType(in: clonedFunction) - let clonedEntryBlockArg = clonedEntryBlock.addFunctionArgument(type: clonedEntryBlockArgType, self.context) + let clonedEntryBlockArg = clonedEntryBlock.addFunctionArgument( + type: clonedEntryBlockArgType, self.context) clonedEntryBlockArg.copyFlags(from: arg as! FunctionArgument, self.context) } } @@ -723,9 +775,11 @@ private extension SpecializationCloner { /// of corresponding releasable closures cloned into the specialized function. We have a "list" because we clone /// "closure chains", which consist of a "root" closure and its conversions/reabstractions. This map is used to /// generate cleanup code for the cloned closures in the specialized function. - private func cloneAllClosures(at pullbackClosureInfo: PullbackClosureInfo) - -> (allSpecializedEntryBlockArgs: [Value], - closureArgIndexToAllClonedReleasableClosures: [Int: [SingleValueInstruction]]) + private func cloneAllClosures(at pullbackClosureInfo: PullbackClosureInfo) + -> ( + allSpecializedEntryBlockArgs: [Value], + closureArgIndexToAllClonedReleasableClosures: [Int: [SingleValueInstruction]] + ) { func entryBlockArgsWithOrigClosuresSkipped() -> [Value?] { var clonedNonClosureEntryBlockArgs = self.entryBlock.arguments.makeIterator() @@ -752,97 +806,116 @@ private extension SpecializationCloner { self.cloneClosureChain(representedBy: closureArgDesc, at: pullbackClosureInfo) entryBlockArgs[closureArgDesc.closureArgIndex] = finalClonedReabstractedClosure - closureArgIndexToAllClonedReleasableClosures[closureArgDesc.closureArgIndex] = allClonedReleasableClosures + closureArgIndexToAllClonedReleasableClosures[closureArgDesc.closureArgIndex] = + allClonedReleasableClosures } return (entryBlockArgs.map { $0! }, closureArgIndexToAllClonedReleasableClosures) } - private func cloneClosureChain(representedBy closureArgDesc: ClosureArgDescriptor, at pullbackClosureInfo: PullbackClosureInfo) - -> (finalClonedReabstractedClosure: SingleValueInstruction, allClonedReleasableClosures: [SingleValueInstruction]) + private func cloneClosureChain( + representedBy closureArgDesc: ClosureArgDescriptor, at pullbackClosureInfo: PullbackClosureInfo + ) + -> ( + finalClonedReabstractedClosure: SingleValueInstruction, + allClonedReleasableClosures: [SingleValueInstruction] + ) { - let (origToClonedValueMap, capturedArgRange) = self.addEntryBlockArgs(forValuesCapturedBy: closureArgDesc) + let (origToClonedValueMap, capturedArgRange) = self.addEntryBlockArgs( + forValuesCapturedBy: closureArgDesc) let clonedFunction = self.cloned let clonedEntryBlock = self.entryBlock let clonedClosureArgs = Array(clonedEntryBlock.arguments[capturedArgRange]) - let builder = clonedEntryBlock.instructions.isEmpty - ? Builder(atStartOf: clonedFunction, self.context) - : Builder(atEndOf: clonedEntryBlock, location: clonedEntryBlock.instructions.last!.location, self.context) + let builder = + clonedEntryBlock.instructions.isEmpty + ? Builder(atStartOf: clonedFunction, self.context) + : Builder( + atEndOf: clonedEntryBlock, location: clonedEntryBlock.instructions.last!.location, + self.context) - let clonedRootClosure = builder.cloneRootClosure(representedBy: closureArgDesc, capturedArguments: clonedClosureArgs) + let clonedRootClosure = builder.cloneRootClosure( + representedBy: closureArgDesc, capturedArguments: clonedClosureArgs) let finalClonedReabstractedClosure = - builder.cloneRootClosureReabstractions(rootClosure: closureArgDesc.closure, clonedRootClosure: clonedRootClosure, - reabstractedClosure: pullbackClosureInfo.appliedArgForClosure(at: closureArgDesc.closureArgIndex)!, - origToClonedValueMap: origToClonedValueMap, - self.context) - - let allClonedReleasableClosures = [ finalClonedReabstractedClosure ]; + builder.cloneRootClosureReabstractions( + rootClosure: closureArgDesc.closure, clonedRootClosure: clonedRootClosure, + reabstractedClosure: pullbackClosureInfo.appliedArgForClosure( + at: closureArgDesc.closureArgIndex)!, + origToClonedValueMap: origToClonedValueMap, + self.context) + + let allClonedReleasableClosures = [finalClonedReabstractedClosure] return (finalClonedReabstractedClosure, allClonedReleasableClosures) } - private func addEntryBlockArgs(forValuesCapturedBy closureArgDesc: ClosureArgDescriptor) - -> (origToClonedValueMap: [HashableValue: Value], capturedArgRange: Range) + private func addEntryBlockArgs(forValuesCapturedBy closureArgDesc: ClosureArgDescriptor) + -> (origToClonedValueMap: [HashableValue: Value], capturedArgRange: Range) { var origToClonedValueMap: [HashableValue: Value] = [:] let clonedFunction = self.cloned let clonedEntryBlock = self.entryBlock let capturedArgRangeStart = clonedEntryBlock.arguments.count - + for arg in closureArgDesc.arguments { - let capturedArg = clonedEntryBlock.addFunctionArgument(type: arg.type.getLoweredType(in: clonedFunction), - self.context) + let capturedArg = clonedEntryBlock.addFunctionArgument( + type: arg.type.getLoweredType(in: clonedFunction), + self.context) origToClonedValueMap[arg] = capturedArg } let capturedArgRangeEnd = clonedEntryBlock.arguments.count - let capturedArgRange = capturedArgRangeStart == capturedArgRangeEnd - ? 0..<0 - : capturedArgRangeStart.. Value? { +extension [HashableValue: Value] { + fileprivate subscript(key: Value) -> Value? { get { self[key.hashable] } @@ -852,8 +925,8 @@ private extension [HashableValue: Value] { } } -private extension PullbackClosureInfo { - enum NewApplyArg { +extension PullbackClosureInfo { + fileprivate enum NewApplyArg { case Original(Value) // TODO: This can be simplified in OSSA. We can just do a copy_value for everything - except for addresses??? case PreviouslyCaptured( @@ -869,7 +942,7 @@ private extension PullbackClosureInfo { } } - func getArgumentsForSpecializedApply(of specializedCallee: Function) -> [NewApplyArg] + fileprivate func getArgumentsForSpecializedApply(of specializedCallee: Function) -> [NewApplyArg] { var newApplyArgs: [NewApplyArg] = [] @@ -884,11 +957,14 @@ private extension PullbackClosureInfo { // Previously captured arguments for closureArgDesc in self.closureArgDescriptors { for (applySiteIndex, capturedArg) in closureArgDesc.arguments.enumerated() { - let needsRetain = closureArgDesc.isCapturedArgNonTrivialObjectType(applySiteIndex: applySiteIndex, - specializedCallee: specializedCallee) - - newApplyArgs.append(.PreviouslyCaptured(value: capturedArg, needsRetain: needsRetain, - parentClosureArgIndex: closureArgDesc.closureArgIndex)) + let needsRetain = closureArgDesc.isCapturedArgNonTrivialObjectType( + applySiteIndex: applySiteIndex, + specializedCallee: specializedCallee) + + newApplyArgs.append( + .PreviouslyCaptured( + value: capturedArg, needsRetain: needsRetain, + parentClosureArgIndex: closureArgDesc.closureArgIndex)) } } @@ -896,106 +972,126 @@ private extension PullbackClosureInfo { } } -private extension ClosureArgDescriptor { - func isCapturedArgNonTrivialObjectType(applySiteIndex: Int, specializedCallee: Function) -> Bool { - precondition(self.closure is PartialApplyInst, "ClosureArgDescriptor is not for a partial_apply closure!") +extension ClosureArgDescriptor { + fileprivate func isCapturedArgNonTrivialObjectType( + applySiteIndex: Int, specializedCallee: Function + ) -> Bool { + precondition( + self.closure is PartialApplyInst, "ClosureArgDescriptor is not for a partial_apply closure!") let capturedArg = self.arguments[applySiteIndex] let pai = self.closure as! PartialApplyInst let capturedArgIndexInCallee = applySiteIndex + pai.unappliedArgumentCount let capturedArgConvention = self.callee.argumentConventions[capturedArgIndexInCallee] - return !capturedArg.type.isTrivial(in: specializedCallee) && - !capturedArgConvention.isAllowedIndirectConvForClosureSpec + return !capturedArg.type.isTrivial(in: specializedCallee) + && !capturedArgConvention.isAllowedIndirectConvForClosureSpec } } -private extension Builder { - func cloneRootClosure(representedBy closureArgDesc: ClosureArgDescriptor, capturedArguments: [Value]) - -> SingleValueInstruction +extension Builder { + fileprivate func cloneRootClosure( + representedBy closureArgDesc: ClosureArgDescriptor, capturedArguments: [Value] + ) + -> SingleValueInstruction { let function = self.createFunctionRef(closureArgDesc.callee) if let pai = closureArgDesc.closure as? PartialApplyInst { - return self.createPartialApply(function: function, substitutionMap: SubstitutionMap(), - capturedArguments: capturedArguments, calleeConvention: pai.calleeConvention, - hasUnknownResultIsolation: pai.hasUnknownResultIsolation, - isOnStack: pai.isOnStack) + return self.createPartialApply( + function: function, substitutionMap: SubstitutionMap(), + capturedArguments: capturedArguments, calleeConvention: pai.calleeConvention, + hasUnknownResultIsolation: pai.hasUnknownResultIsolation, + isOnStack: pai.isOnStack) } else { - return self.createThinToThickFunction(thinFunction: function, resultType: closureArgDesc.closure.type) + return self.createThinToThickFunction( + thinFunction: function, resultType: closureArgDesc.closure.type) } } - func cloneRootClosureReabstractions(rootClosure: Value, clonedRootClosure: Value, reabstractedClosure: Value, - origToClonedValueMap: [HashableValue: Value], _ context: FunctionPassContext) + fileprivate func cloneRootClosureReabstractions( + rootClosure: Value, clonedRootClosure: Value, reabstractedClosure: Value, + origToClonedValueMap: [HashableValue: Value], _ context: FunctionPassContext + ) -> SingleValueInstruction { - func inner(_ rootClosure: Value, _ clonedRootClosure: Value, _ reabstractedClosure: Value, - _ origToClonedValueMap: inout [HashableValue: Value]) -> Value { + func inner( + _ rootClosure: Value, _ clonedRootClosure: Value, _ reabstractedClosure: Value, + _ origToClonedValueMap: inout [HashableValue: Value] + ) -> Value { switch reabstractedClosure { - case let reabstractedClosure where reabstractedClosure == rootClosure: - origToClonedValueMap[reabstractedClosure] = clonedRootClosure - return clonedRootClosure - - case let cvt as ConvertFunctionInst: - let toBeReabstracted = inner(rootClosure, clonedRootClosure, cvt.fromFunction, - &origToClonedValueMap) - let reabstracted = self.createConvertFunction(originalFunction: toBeReabstracted, resultType: cvt.type, - withoutActuallyEscaping: cvt.withoutActuallyEscaping) - origToClonedValueMap[cvt] = reabstracted - return reabstracted - - case let cvt as ConvertEscapeToNoEscapeInst: - let toBeReabstracted = inner(rootClosure, clonedRootClosure, cvt.fromFunction, - &origToClonedValueMap) - let reabstracted = self.createConvertEscapeToNoEscape(originalFunction: toBeReabstracted, resultType: cvt.type, - isLifetimeGuaranteed: true) - origToClonedValueMap[cvt] = reabstracted - return reabstracted - - case let pai as PartialApplyInst: - let toBeReabstracted = inner(rootClosure, clonedRootClosure, pai.arguments[0], - &origToClonedValueMap) - - guard let function = pai.referencedFunction else { - log("Parent function of pullbackClosureInfo: \(rootClosure.parentFunction)") - log("Root closure: \(rootClosure)") - log("Unsupported reabstraction closure: \(pai)") - fatalError("Encountered unsupported reabstraction (via partial_apply) of root closure!") - } - - let fri = self.createFunctionRef(function) - let reabstracted = self.createPartialApply(function: fri, substitutionMap: SubstitutionMap(), - capturedArguments: [toBeReabstracted], - calleeConvention: pai.calleeConvention, - hasUnknownResultIsolation: pai.hasUnknownResultIsolation, - isOnStack: pai.isOnStack) - origToClonedValueMap[pai] = reabstracted - return reabstracted - - case let mdi as MarkDependenceInst: - let toBeReabstracted = inner(rootClosure, clonedRootClosure, mdi.value, &origToClonedValueMap) - let base = origToClonedValueMap[mdi.base]! - let reabstracted = self.createMarkDependence(value: toBeReabstracted, base: base, kind: .Escaping) - origToClonedValueMap[mdi] = reabstracted - return reabstracted - - default: + case let reabstractedClosure where reabstractedClosure == rootClosure: + origToClonedValueMap[reabstractedClosure] = clonedRootClosure + return clonedRootClosure + + case let cvt as ConvertFunctionInst: + let toBeReabstracted = inner( + rootClosure, clonedRootClosure, cvt.fromFunction, + &origToClonedValueMap) + let reabstracted = self.createConvertFunction( + originalFunction: toBeReabstracted, resultType: cvt.type, + withoutActuallyEscaping: cvt.withoutActuallyEscaping) + origToClonedValueMap[cvt] = reabstracted + return reabstracted + + case let cvt as ConvertEscapeToNoEscapeInst: + let toBeReabstracted = inner( + rootClosure, clonedRootClosure, cvt.fromFunction, + &origToClonedValueMap) + let reabstracted = self.createConvertEscapeToNoEscape( + originalFunction: toBeReabstracted, resultType: cvt.type, + isLifetimeGuaranteed: true) + origToClonedValueMap[cvt] = reabstracted + return reabstracted + + case let pai as PartialApplyInst: + let toBeReabstracted = inner( + rootClosure, clonedRootClosure, pai.arguments[0], + &origToClonedValueMap) + + guard let function = pai.referencedFunction else { log("Parent function of pullbackClosureInfo: \(rootClosure.parentFunction)") log("Root closure: \(rootClosure)") - log("Converted/reabstracted closure: \(reabstractedClosure)") - fatalError("Encountered unsupported reabstraction of root closure: \(reabstractedClosure)") + log("Unsupported reabstraction closure: \(pai)") + fatalError("Encountered unsupported reabstraction (via partial_apply) of root closure!") + } + + let fri = self.createFunctionRef(function) + let reabstracted = self.createPartialApply( + function: fri, substitutionMap: SubstitutionMap(), + capturedArguments: [toBeReabstracted], + calleeConvention: pai.calleeConvention, + hasUnknownResultIsolation: pai.hasUnknownResultIsolation, + isOnStack: pai.isOnStack) + origToClonedValueMap[pai] = reabstracted + return reabstracted + + case let mdi as MarkDependenceInst: + let toBeReabstracted = inner( + rootClosure, clonedRootClosure, mdi.value, &origToClonedValueMap) + let base = origToClonedValueMap[mdi.base]! + let reabstracted = self.createMarkDependence( + value: toBeReabstracted, base: base, kind: .Escaping) + origToClonedValueMap[mdi] = reabstracted + return reabstracted + + default: + log("Parent function of pullbackClosureInfo: \(rootClosure.parentFunction)") + log("Root closure: \(rootClosure)") + log("Converted/reabstracted closure: \(reabstractedClosure)") + fatalError("Encountered unsupported reabstraction of root closure: \(reabstractedClosure)") } } var origToClonedValueMap = origToClonedValueMap - let finalClonedReabstractedClosure = inner(rootClosure, clonedRootClosure, reabstractedClosure, - &origToClonedValueMap) + let finalClonedReabstractedClosure = inner( + rootClosure, clonedRootClosure, reabstractedClosure, + &origToClonedValueMap) return (finalClonedReabstractedClosure as! SingleValueInstruction) } - func destroyPartialApply(pai: PartialApplyInst, _ context: FunctionPassContext){ - // TODO: Support only OSSA instructions once the OSSA elimination pass is moved after all function optimization + fileprivate func destroyPartialApply(pai: PartialApplyInst, _ context: FunctionPassContext) { + // TODO: Support only OSSA instructions once the OSSA elimination pass is moved after all function optimization // passes. if pai.isOnStack { @@ -1005,8 +1101,8 @@ private extension Builder { // self.createDestroyValue(operand: pai) if pai.parentFunction.hasOwnership { - // Under OSSA, the closure acts as an owned value whose lifetime is a borrow scope for the captures, so we need to - // end the borrow scope before ending the lifetimes of the captures themselves. + // Under OSSA, the closure acts as an owned value whose lifetime is a borrow scope for the captures, so we need to + // end the borrow scope before ending the lifetimes of the captures themselves. self.createDestroyValue(operand: pai) self.destroyCapturedArgs(for: pai) } else { @@ -1024,8 +1120,10 @@ private extension Builder { } } -private extension FunctionConvention { - func getSpecializedParameters(basedOn pullbackClosureInfo: PullbackClosureInfo) -> [ParameterInfo] { +extension FunctionConvention { + fileprivate func getSpecializedParameters(basedOn pullbackClosureInfo: PullbackClosureInfo) + -> [ParameterInfo] + { let pullbackFn = pullbackClosureInfo.pullbackFn var specializedParamInfoList: [ParameterInfo] = [] @@ -1040,7 +1138,7 @@ private extension FunctionConvention { // Now, append parameters captured by each of the original closure parameter. // - // Captured parameters are always appended to the function signature. If the argument type of the captured + // Captured parameters are always appended to the function signature. If the argument type of the captured // parameter in the callee is: // - direct and trivial, pass the new parameter as Direct_Unowned. // - direct and non-trivial, pass the new parameter as Direct_Owned. @@ -1050,16 +1148,19 @@ private extension FunctionConvention { if let closure = closureArgDesc.closure as? PartialApplyInst { let closureCallee = closureArgDesc.callee let closureCalleeConvention = closureCallee.convention - let unappliedArgumentCount = closure.unappliedArgumentCount - closureCalleeConvention.indirectSILResultCount + let unappliedArgumentCount = + closure.unappliedArgumentCount - closureCalleeConvention.indirectSILResultCount let prevCapturedParameters = closureCalleeConvention .parameters[unappliedArgumentCount...] .enumerated() .map { index, paramInfo in - let argIndexOfParam = closureCallee.argumentConventions.firstParameterIndex + unappliedArgumentCount + index + let argIndexOfParam = + closureCallee.argumentConventions.firstParameterIndex + unappliedArgumentCount + index let argType = closureCallee.argumentTypes[argIndexOfParam] - return paramInfo.withSpecializedConvention(isArgTypeTrivial: argType.isTrivial(in: closureCallee)) + return paramInfo.withSpecializedConvention( + isArgTypeTrivial: argType.isTrivial(in: closureCallee)) } specializedParamInfoList.append(contentsOf: prevCapturedParameters) @@ -1070,23 +1171,25 @@ private extension FunctionConvention { } } -private extension ParameterInfo { - func withSpecializedConvention(isArgTypeTrivial: Bool) -> Self { - let specializedParamConvention = self.convention.isAllowedIndirectConvForClosureSpec +extension ParameterInfo { + fileprivate func withSpecializedConvention(isArgTypeTrivial: Bool) -> Self { + let specializedParamConvention = + self.convention.isAllowedIndirectConvForClosureSpec ? self.convention : isArgTypeTrivial ? ArgumentConvention.directUnowned : ArgumentConvention.directOwned - return ParameterInfo(type: self.type, convention: specializedParamConvention, options: self.options, - hasLoweredAddresses: self.hasLoweredAddresses) + return ParameterInfo( + type: self.type, convention: specializedParamConvention, options: self.options, + hasLoweredAddresses: self.hasLoweredAddresses) } - var isTrivialNoescapeClosure: Bool { + fileprivate var isTrivialNoescapeClosure: Bool { SILFunctionType_isTrivialNoescape(type.bridged) } } -private extension ArgumentConvention { - var isAllowedIndirectConvForClosureSpec: Bool { +extension ArgumentConvention { + fileprivate var isAllowedIndirectConvForClosureSpec: Bool { switch self { case .indirectInout, .indirectInoutAliasable: return true @@ -1096,15 +1199,15 @@ private extension ArgumentConvention { } } -private extension PartialApplyInst { +extension PartialApplyInst { /// True, if the closure obtained from this partial_apply is the /// pullback returned from an autodiff VJP - var isPullbackInResultOfAutodiffVJP: Bool { + fileprivate var isPullbackInResultOfAutodiffVJP: Bool { if self.parentFunction.isAutodiffVJP, - let use = self.uses.singleUse, - let tupleInst = use.instruction as? TupleInst, - let returnInst = self.parentFunction.returnInstruction, - tupleInst == returnInst.returnedValue + let use = self.uses.singleUse, + let tupleInst = use.instruction as? TupleInst, + let returnInst = self.parentFunction.returnInstruction, + tupleInst == returnInst.returnedValue { return true } @@ -1112,53 +1215,55 @@ private extension PartialApplyInst { return false } - var isPartialApplyOfThunk: Bool { - if self.numArguments == 1, - let fun = self.referencedFunction, - fun.thunkKind == .reabstractionThunk || fun.thunkKind == .thunk, - self.arguments[0].type.isLoweredFunction, - self.arguments[0].type.isReferenceCounted(in: self.parentFunction) || self.callee.type.isThickFunction + fileprivate var isPartialApplyOfThunk: Bool { + if self.numArguments == 1, + let fun = self.referencedFunction, + fun.thunkKind == .reabstractionThunk || fun.thunkKind == .thunk, + self.arguments[0].type.isLoweredFunction, + self.arguments[0].type.isReferenceCounted(in: self.parentFunction) + || self.callee.type.isThickFunction { return true } - + return false } - var hasOnlyInoutIndirectArguments: Bool { + fileprivate var hasOnlyInoutIndirectArguments: Bool { self.argumentOperands .filter { !$0.value.type.isObject } - .allSatisfy { self.convention(of: $0)!.isInout } + .allSatisfy { self.convention(of: $0)!.isInout } } } -private extension Instruction { - var asSupportedClosure: SingleValueInstruction? { +extension Instruction { + fileprivate var asSupportedClosure: SingleValueInstruction? { switch self { case let tttf as ThinToThickFunctionInst where tttf.callee is FunctionRefInst: return tttf // TODO: figure out what to do with non-inout indirect arguments // https://forums.swift.org/t/non-inout-indirect-types-not-supported-in-closure-specialization-optimization/70826 - case let pai as PartialApplyInst where pai.callee is FunctionRefInst && pai.hasOnlyInoutIndirectArguments: + case let pai as PartialApplyInst + where pai.callee is FunctionRefInst && pai.hasOnlyInoutIndirectArguments: return pai default: return nil } } - var isSupportedClosure: Bool { + fileprivate var isSupportedClosure: Bool { asSupportedClosure != nil } } -private extension ApplySite { - var calleeIsDynamicFunctionRef: Bool { +extension ApplySite { + fileprivate var calleeIsDynamicFunctionRef: Bool { return !(callee is DynamicFunctionRefInst || callee is PreviousDynamicFunctionRefInst) } } -private extension Function { - var effectAllowsSpecialization: Bool { +extension Function { + fileprivate var effectAllowsSpecialization: Bool { switch self.effectAttribute { case .readNone, .readOnly, .releaseNone: return false default: return true @@ -1191,11 +1296,11 @@ private struct OrderedDict { } } - var keys: LazyMapSequence, Key> { + var keys: LazyMapSequence<[(Key, Value)], Key> { entryList.lazy.map { $0.0 } } - var values: LazyMapSequence, Value> { + var values: LazyMapSequence<[(Key, Value)], Value> { entryList.lazy.map { $0.1 } } } @@ -1312,7 +1417,8 @@ private struct PullbackClosureInfo { func appliedArgForClosure(at index: Int) -> Value? { if let closureArgDesc = closureArgDesc(at: index) { - return paiOfPullback.arguments[closureArgDesc.closureArgIndex - paiOfPullback.unappliedArgumentCount] + return paiOfPullback.arguments[ + closureArgDesc.closureArgIndex - paiOfPullback.unappliedArgumentCount] } return nil @@ -1322,14 +1428,17 @@ private struct PullbackClosureInfo { let closureArgs = Array(self.closureArgDescriptors.map { $0.closure }) let closureIndices = Array(self.closureArgDescriptors.map { $0.closureArgIndex }) - return context.mangle(withClosureArguments: closureArgs, closureArgIndices: closureIndices, - from: pullbackFn) + return context.mangle( + withClosureArguments: closureArgs, closureArgIndices: closureIndices, + from: pullbackFn) } } // ===================== Unit tests ===================== // -let getPullbackClosureInfoTest = FunctionTest("autodiff_closure_specialize_get_pullback_closure_info") { function, arguments, context in +let getPullbackClosureInfoTest = FunctionTest( + "autodiff_closure_specialize_get_pullback_closure_info" +) { function, arguments, context in print("Specializing closures in function: \(function.name)") print("===============================================") let pullbackClosureInfo = getPullbackClosureInfo(in: function, context)! @@ -1343,20 +1452,25 @@ let getPullbackClosureInfoTest = FunctionTest("autodiff_closure_specialize_get_p } let specializedFunctionSignatureAndBodyTest = FunctionTest( - "autodiff_closure_specialize_specialized_function_signature_and_body") { function, arguments, context in + "autodiff_closure_specialize_specialized_function_signature_and_body" +) { function, arguments, context in let pullbackClosureInfo = getPullbackClosureInfo(in: function, context)! - let (specializedFunction, _) = getOrCreateSpecializedFunction(basedOn: pullbackClosureInfo, context) + let (specializedFunction, _) = getOrCreateSpecializedFunction( + basedOn: pullbackClosureInfo, context) print("Generated specialized function: \(specializedFunction.name)") print("\(specializedFunction)\n") } -let rewrittenCallerBodyTest = FunctionTest("autodiff_closure_specialize_rewritten_caller_body") { function, arguments, context in +let rewrittenCallerBodyTest = FunctionTest("autodiff_closure_specialize_rewritten_caller_body") { + function, arguments, context in let pullbackClosureInfo = getPullbackClosureInfo(in: function, context)! - let (specializedFunction, _) = getOrCreateSpecializedFunction(basedOn: pullbackClosureInfo, context) - rewriteApplyInstruction(using: specializedFunction, pullbackClosureInfo: pullbackClosureInfo, context) + let (specializedFunction, _) = getOrCreateSpecializedFunction( + basedOn: pullbackClosureInfo, context) + rewriteApplyInstruction( + using: specializedFunction, pullbackClosureInfo: pullbackClosureInfo, context) print("Rewritten caller body for: \(function.name):") print("\(function)\n")