diff --git a/SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift b/SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift index c2a52eb69713c..ff2458dfa702b 100644 --- a/SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift +++ b/SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift @@ -122,36 +122,29 @@ let autodiffClosureSpecialization = FunctionPass(name: "autodiff-closure-special (function: Function, context: FunctionPassContext) in guard !function.isDefinedExternally, - function.isAutodiffVJP, - function.blocks.singleElement != nil else { + function.isAutodiffVJP else { return } var remainingSpecializationRounds = 5 repeat { - // TODO: Names here are pretty misleading. We are looking for a place where - // the pullback closure is created (so for `partial_apply` instruction). - var callSites = gatherCallSites(in: function, context) - guard !callSites.isEmpty else { - return + guard let pullbackClosureInfo = getPullbackClosureInfo(in: function, context) else { + break } - for callSite in callSites { - var (specializedFunction, alreadyExists) = getOrCreateSpecializedFunction(basedOn: callSite, context) + var (specializedFunction, alreadyExists) = getOrCreateSpecializedFunction(basedOn: pullbackClosureInfo, context) - if !alreadyExists { - context.notifyNewFunction(function: specializedFunction, derivedFrom: callSite.applyCallee) - } - - rewriteApplyInstruction(using: specializedFunction, callSite: callSite, context) + if !alreadyExists { + context.notifyNewFunction(function: specializedFunction, derivedFrom: pullbackClosureInfo.pullbackFn) } - var deadClosures: InstructionWorklist = callSites.reduce(into: InstructionWorklist(context)) { deadClosures, callSite in - callSite.closureArgDescriptors + rewriteApplyInstruction(using: specializedFunction, pullbackClosureInfo: pullbackClosureInfo, context) + + var deadClosures = InstructionWorklist(context) + pullbackClosureInfo.closureArgDescriptors .map { $0.closure } .forEach { deadClosures.pushIfNotVisited($0) } - } defer { deadClosures.deinitialize() @@ -176,7 +169,7 @@ let autodiffClosureSpecialization = FunctionPass(name: "autodiff-closure-special private let specializationLevelLimit = 2 -private func gatherCallSites(in caller: Function, _ context: FunctionPassContext) -> [CallSite] { +private func getPullbackClosureInfo(in caller: Function, _ context: FunctionPassContext) -> PullbackClosureInfo? { /// __Root__ closures created via `partial_apply` or `thin_to_thick_function` may be converted and reabstracted /// before finally being used at an apply site. We do not want to handle these intermediate closures separately /// as they are handled and cloned into the specialized function as part of the root closures. Therefore, we keep @@ -208,59 +201,59 @@ private func gatherCallSites(in caller: Function, _ context: FunctionPassContext convertedAndReabstractedClosures.deinitialize() } - var callSiteMap = CallSiteMap() + var pullbackClosureInfoOpt = PullbackClosureInfo?(nil) for inst in caller.instructions { if !convertedAndReabstractedClosures.contains(inst), let rootClosure = inst.asSupportedClosure { - updateCallSites(for: rootClosure, in: &callSiteMap, - convertedAndReabstractedClosures: &convertedAndReabstractedClosures, context) + updatePullbackClosureInfo(for: rootClosure, in: &pullbackClosureInfoOpt, + convertedAndReabstractedClosures: &convertedAndReabstractedClosures, context) } } - return callSiteMap.callSites + return pullbackClosureInfoOpt } -private func getOrCreateSpecializedFunction(basedOn callSite: CallSite, _ context: FunctionPassContext) +private func getOrCreateSpecializedFunction(basedOn pullbackClosureInfo: PullbackClosureInfo, _ context: FunctionPassContext) -> (function: Function, alreadyExists: Bool) { - let specializedFunctionName = callSite.specializedCalleeName(context) + let specializedFunctionName = pullbackClosureInfo.specializedCalleeName(context) if let specializedFunction = context.lookupFunction(name: specializedFunctionName) { return (specializedFunction, true) } - let applySiteCallee = callSite.applyCallee - let specializedParameters = applySiteCallee.convention.getSpecializedParameters(basedOn: callSite) + let pullbackFn = pullbackClosureInfo.pullbackFn + let specializedParameters = pullbackFn.convention.getSpecializedParameters(basedOn: pullbackClosureInfo) let specializedFunction = - context.createSpecializedFunctionDeclaration(from: applySiteCallee, withName: specializedFunctionName, + context.createSpecializedFunctionDeclaration(from: pullbackFn, withName: specializedFunctionName, withParams: specializedParameters, makeThin: true, makeBare: true) context.buildSpecializedFunction(specializedFunction: specializedFunction, buildFn: { (emptySpecializedFunction, functionPassContext) in let closureSpecCloner = SpecializationCloner(emptySpecializedFunction: emptySpecializedFunction, functionPassContext) - closureSpecCloner.cloneAndSpecializeFunctionBody(using: callSite) + closureSpecCloner.cloneAndSpecializeFunctionBody(using: pullbackClosureInfo) }) return (specializedFunction, false) } -private func rewriteApplyInstruction(using specializedCallee: Function, callSite: CallSite, +private func rewriteApplyInstruction(using specializedCallee: Function, pullbackClosureInfo: PullbackClosureInfo, _ context: FunctionPassContext) { - let newApplyArgs = callSite.getArgumentsForSpecializedApply(of: specializedCallee) + let newApplyArgs = pullbackClosureInfo.getArgumentsForSpecializedApply(of: specializedCallee) for newApplyArg in newApplyArgs { if case let .PreviouslyCaptured(capturedArg, needsRetain, parentClosureArgIndex) = newApplyArg, needsRetain { - let closureArgDesc = callSite.closureArgDesc(at: parentClosureArgIndex)! + let closureArgDesc = pullbackClosureInfo.closureArgDesc(at: parentClosureArgIndex)! var builder = Builder(before: closureArgDesc.closure, context) // TODO: Support only OSSA instructions once the OSSA elimination pass is moved after all function optimization // passes. - if callSite.applySite.parentBlock != closureArgDesc.closure.parentBlock { + if pullbackClosureInfo.paiOfPullback.parentBlock != closureArgDesc.closure.parentBlock { // Emit the retain and release that keeps the argument live across the callee using the closure. builder.createRetainValue(operand: capturedArg) @@ -271,7 +264,7 @@ private func rewriteApplyInstruction(using specializedCallee: Function, callSite // Emit the retain that matches the captured argument by the partial_apply in the callee that is consumed by // the partial_apply. - builder = Builder(before: callSite.applySite, context) + builder = Builder(before: pullbackClosureInfo.paiOfPullback, context) builder.createRetainValue(operand: capturedArg) } else { builder.createRetainValue(operand: capturedArg) @@ -280,20 +273,20 @@ private func rewriteApplyInstruction(using specializedCallee: Function, callSite } // Rewrite apply instruction - var builder = Builder(before: callSite.applySite, context) - let oldApply = callSite.applySite as! PartialApplyInst + var builder = Builder(before: pullbackClosureInfo.paiOfPullback, context) + let oldPartialApply = pullbackClosureInfo.paiOfPullback let funcRef = builder.createFunctionRef(specializedCallee) let capturedArgs = Array(newApplyArgs.map { $0.value }) - let newApply = builder.createPartialApply(function: funcRef, substitutionMap: SubstitutionMap(), - capturedArguments: capturedArgs, calleeConvention: oldApply.calleeConvention, - hasUnknownResultIsolation: oldApply.hasUnknownResultIsolation, - isOnStack: oldApply.isOnStack) + let newPartialApply = builder.createPartialApply(function: funcRef, substitutionMap: SubstitutionMap(), + capturedArguments: capturedArgs, calleeConvention: oldPartialApply.calleeConvention, + hasUnknownResultIsolation: oldPartialApply.hasUnknownResultIsolation, + isOnStack: oldPartialApply.isOnStack) - builder = Builder(before: callSite.applySite.next!, context) + builder = Builder(before: pullbackClosureInfo.paiOfPullback.next!, context) // TODO: Support only OSSA instructions once the OSSA elimination pass is moved after all function optimization // passes. - for closureArgDesc in callSite.closureArgDescriptors { + for closureArgDesc in pullbackClosureInfo.closureArgDescriptors { if closureArgDesc.isClosureConsumed, !closureArgDesc.isPartialApplyOnStack, !closureArgDesc.parameterInfo.isTrivialNoescapeClosure @@ -302,13 +295,13 @@ private func rewriteApplyInstruction(using specializedCallee: Function, callSite } } - oldApply.replace(with: newApply, context) + oldPartialApply.replace(with: newPartialApply, context) } // ===================== Utility functions and extensions ===================== // -private func updateCallSites(for rootClosure: SingleValueInstruction, in callSiteMap: inout CallSiteMap, - convertedAndReabstractedClosures: inout InstructionSet, _ context: FunctionPassContext) { +private func updatePullbackClosureInfo(for rootClosure: SingleValueInstruction, in pullbackClosureInfoOpt: inout PullbackClosureInfo?, + convertedAndReabstractedClosures: inout InstructionSet, _ context: FunctionPassContext) { var rootClosurePossibleLiveRange = InstructionRange(begin: rootClosure, context) defer { rootClosurePossibleLiveRange.deinitialize() @@ -320,17 +313,17 @@ private func updateCallSites(for rootClosure: SingleValueInstruction, in callSit } // A "root" closure undergoing conversions and/or reabstractions has additional restrictions placed upon it, in order - // for a call site to be specialized against it. We handle conversion/reabstraction uses before we handle apply uses - // to gather the parameters required to evaluate these restrictions or to skip call site uses of "unsupported" + // for a pullback to be specialized against it. We handle conversion/reabstraction uses before we handle apply uses + // to gather the parameters required to evaluate these restrictions or to skip pullback's uses of "unsupported" // closures altogether. // - // There are currently 2 restrictions that are evaluated prior to specializing a callsite against a converted and/or + // There are currently 2 restrictions that are evaluated prior to specializing a pullback against a converted and/or // reabstracted closure - // 1. A reabstracted root closure can only be specialized against, if the reabstracted closure is ultimately passed - // trivially (as a noescape+thick function) into the call site. + // trivially (as a noescape+thick function) as captured argument of pullback's partial_apply. // // 2. A root closure may be a partial_apply [stack], in which case we need to make sure that all mark_dependence - // bases for it will be available in the specialized callee in case the call site is specialized against this root + // bases for it will be available in the specialized callee in case the pullback is specialized against this root // closure. let (foundUnexpectedUse, haveUsedReabstraction) = @@ -343,14 +336,18 @@ private func updateCallSites(for rootClosure: SingleValueInstruction, in callSit } let intermediateClosureArgDescriptorData = - handleApplies(for: rootClosure, callSiteMap: &callSiteMap, rootClosureApplies: &rootClosureApplies, + handleApplies(for: rootClosure, pullbackClosureInfoOpt: &pullbackClosureInfoOpt, rootClosureApplies: &rootClosureApplies, rootClosurePossibleLiveRange: &rootClosurePossibleLiveRange, convertedAndReabstractedClosures: &convertedAndReabstractedClosures, haveUsedReabstraction: haveUsedReabstraction, context) - finalizeCallSites(for: rootClosure, in: &callSiteMap, - rootClosurePossibleLiveRange: rootClosurePossibleLiveRange, - intermediateClosureArgDescriptorData: intermediateClosureArgDescriptorData, context) + if pullbackClosureInfoOpt == nil { + return + } + + finalizePullbackClosureInfo(for: rootClosure, in: &pullbackClosureInfoOpt, + rootClosurePossibleLiveRange: rootClosurePossibleLiveRange, + intermediateClosureArgDescriptorData: intermediateClosureArgDescriptorData, context) } /// Handles all non-apply direct and transitive uses of `rootClosure`. @@ -370,7 +367,7 @@ private func handleNonApplies(for rootClosure: SingleValueInstruction, /// The root closure or an intermediate closure created by reabstracting the root closure may be a `partial_apply /// [stack]` and we need to make sure that all `mark_dependence` bases for this `onStack` closure will be available in - /// the specialized callee, in case the call site is specialized against this root closure. + /// the specialized callee, in case the pullback is specialized against this root closure. /// /// `possibleMarkDependenceBases` keeps track of all potential values that may be used as bases for creating /// `mark_dependence`s for our `onStack` root/reabstracted closures. For root closures these values are non-trivial @@ -389,7 +386,7 @@ private func handleNonApplies(for rootClosure: SingleValueInstruction, /// ``` /// /// Any value outside of the aforementioned values is not going to be available in the specialized callee and a - /// `mark_dependence` of the root closure on such a value means that we cannot specialize the call site against it. + /// `mark_dependence` of the root closure on such a value means that we cannot specialize the pullback against it. var possibleMarkDependenceBases = ValueSet(context) defer { possibleMarkDependenceBases.deinitialize() @@ -479,7 +476,7 @@ private func handleNonApplies(for rootClosure: SingleValueInstruction, private typealias IntermediateClosureArgDescriptorDatum = (applySite: SingleValueInstruction, closureArgIndex: Int, paramInfo: ParameterInfo) -private func handleApplies(for rootClosure: SingleValueInstruction, callSiteMap: inout CallSiteMap, +private func handleApplies(for rootClosure: SingleValueInstruction, pullbackClosureInfoOpt: inout PullbackClosureInfo?, rootClosureApplies: inout OperandWorklist, rootClosurePossibleLiveRange: inout InstructionRange, convertedAndReabstractedClosures: inout InstructionSet, haveUsedReabstraction: Bool, @@ -586,8 +583,10 @@ private func handleApplies(for rootClosure: SingleValueInstruction, callSiteMap: convertedAndReabstractedClosures: &convertedAndReabstractedClosures) } - if callSiteMap[pai] == nil { - callSiteMap.insert(key: pai, value: CallSite(applySite: pai)) + if pullbackClosureInfoOpt == nil { + pullbackClosureInfoOpt = PullbackClosureInfo(paiOfPullback: pai) + } else { + assert(pullbackClosureInfoOpt!.paiOfPullback == pai) } intermediateClosureArgDescriptorData @@ -597,23 +596,22 @@ private func handleApplies(for rootClosure: SingleValueInstruction, callSiteMap: return intermediateClosureArgDescriptorData } -/// Finalizes the call sites for a given root closure by adding a corresponding `ClosureArgDescriptor` -/// to all call sites where the closure is ultimately passed as an argument. -private func finalizeCallSites(for rootClosure: SingleValueInstruction, in callSiteMap: inout CallSiteMap, - rootClosurePossibleLiveRange: InstructionRange, - intermediateClosureArgDescriptorData: [IntermediateClosureArgDescriptorDatum], - _ context: FunctionPassContext) -{ +/// Finalizes the pullback closure info for a given root closure by adding a corresponding `ClosureArgDescriptor` +private func finalizePullbackClosureInfo(for rootClosure: SingleValueInstruction, in pullbackClosureInfoOpt: inout PullbackClosureInfo?, + rootClosurePossibleLiveRange: InstructionRange, + intermediateClosureArgDescriptorData: [IntermediateClosureArgDescriptorDatum], + _ context: FunctionPassContext) { + assert(pullbackClosureInfoOpt != nil) + let closureInfo = ClosureInfo(closure: rootClosure, lifetimeFrontier: Array(rootClosurePossibleLiveRange.ends)) for (applySite, closureArgumentIndex, parameterInfo) in intermediateClosureArgDescriptorData { - guard var callSite = callSiteMap[applySite] else { - fatalError("While finalizing call sites, call site descriptor not found for call site: \(applySite)!") + if pullbackClosureInfoOpt!.paiOfPullback != applySite { + fatalError("ClosureArgDescriptor's applySite field is not equal to pullback's partial_apply; got \(applySite)!") } let closureArgDesc = ClosureArgDescriptor(closureInfo: closureInfo, closureArgumentIndex: closureArgumentIndex, parameterInfo: parameterInfo) - callSite.appendClosureArgDescriptor(closureArgDesc) - callSiteMap.update(key: applySite, value: callSite) + pullbackClosureInfoOpt!.appendClosureArgDescriptor(closureArgDesc) } } @@ -679,7 +677,7 @@ private func markConvertedAndReabstractedClosuresAsUsed(rootClosure: Value, conv markConvertedAndReabstractedClosuresAsUsed(rootClosure: rootClosure, convertedAndReabstractedClosure: mdi.value, convertedAndReabstractedClosures: &convertedAndReabstractedClosures) default: - log("Parent function of callSite: \(rootClosure.parentFunction)") + log("Parent function of pullbackClosureInfo: \(rootClosure.parentFunction)") log("Root closure: \(rootClosure)") log("Converted/reabstracted closure: \(convertedAndReabstractedClosure)") fatalError("While marking converted/reabstracted closures as used, found unexpected instruction: \(convertedAndReabstractedClosure)") @@ -688,25 +686,25 @@ private func markConvertedAndReabstractedClosuresAsUsed(rootClosure: Value, conv } private extension SpecializationCloner { - func cloneAndSpecializeFunctionBody(using callSite: CallSite) { - self.cloneEntryBlockArgsWithoutOrigClosures(usingOrigCalleeAt: callSite) + func cloneAndSpecializeFunctionBody(using pullbackClosureInfo: PullbackClosureInfo) { + self.cloneEntryBlockArgsWithoutOrigClosures(usingOrigCalleeAt: pullbackClosureInfo) - let (allSpecializedEntryBlockArgs, closureArgIndexToAllClonedReleasableClosures) = cloneAllClosures(at: callSite) + let (allSpecializedEntryBlockArgs, closureArgIndexToAllClonedReleasableClosures) = cloneAllClosures(at: pullbackClosureInfo) - self.cloneFunctionBody(from: callSite.applyCallee, entryBlockArguments: allSpecializedEntryBlockArgs) + self.cloneFunctionBody(from: pullbackClosureInfo.pullbackFn, entryBlockArguments: allSpecializedEntryBlockArgs) self.insertCleanupCodeForClonedReleasableClosures( - from: callSite, closureArgIndexToAllClonedReleasableClosures: closureArgIndexToAllClonedReleasableClosures) + from: pullbackClosureInfo, closureArgIndexToAllClonedReleasableClosures: closureArgIndexToAllClonedReleasableClosures) } - private func cloneEntryBlockArgsWithoutOrigClosures(usingOrigCalleeAt callSite: CallSite) { - let originalEntryBlock = callSite.applyCallee.entryBlock + private func cloneEntryBlockArgsWithoutOrigClosures(usingOrigCalleeAt pullbackClosureInfo: PullbackClosureInfo) { + let originalEntryBlock = pullbackClosureInfo.pullbackFn.entryBlock let clonedFunction = self.cloned let clonedEntryBlock = self.entryBlock originalEntryBlock.arguments .enumerated() - .filter { index, _ in !callSite.hasClosureArg(at: index) } + .filter { index, _ in !pullbackClosureInfo.hasClosureArg(at: index) } .forEach { _, arg in let clonedEntryBlockArgType = arg.type.getLoweredType(in: clonedFunction) let clonedEntryBlockArg = clonedEntryBlock.addFunctionArgument(type: clonedEntryBlockArgType, self.context) @@ -714,31 +712,31 @@ private extension SpecializationCloner { } } - /// Clones all closures, originally passed to the callee at the given callSite, into the specialized function. + /// Clones all closures, originally passed to the callee at the given pullbackClosureInfo, into the specialized function. /// /// Returns the following - /// - allSpecializedEntryBlockArgs: Complete list of entry block arguments for the specialized function. This includes /// the original arguments to the function (minus the closure arguments) and the arguments representing the values /// originally captured by the skipped closure arguments. /// - /// - closureArgIndexToAllClonedReleasableClosures: Mapping from a closure's argument index at `callSite` to the list + /// - closureArgIndexToAllClonedReleasableClosures: Mapping from a closure's argument index at `pullbackClosureInfo` to the list /// of corresponding releasable closures cloned into the specialized function. We have a "list" because we clone /// "closure chains", which consist of a "root" closure and its conversions/reabstractions. This map is used to /// generate cleanup code for the cloned closures in the specialized function. - private func cloneAllClosures(at callSite: CallSite) + private func cloneAllClosures(at pullbackClosureInfo: PullbackClosureInfo) -> (allSpecializedEntryBlockArgs: [Value], closureArgIndexToAllClonedReleasableClosures: [Int: [SingleValueInstruction]]) { func entryBlockArgsWithOrigClosuresSkipped() -> [Value?] { var clonedNonClosureEntryBlockArgs = self.entryBlock.arguments.makeIterator() - return callSite.applyCallee + return pullbackClosureInfo.pullbackFn .entryBlock .arguments .enumerated() .reduce(into: []) { result, origArgTuple in let (index, _) = origArgTuple - if !callSite.hasClosureArg(at: index) { + if !pullbackClosureInfo.hasClosureArg(at: index) { result.append(clonedNonClosureEntryBlockArgs.next()) } else { result.append(Optional.none) @@ -749,9 +747,9 @@ private extension SpecializationCloner { var entryBlockArgs: [Value?] = entryBlockArgsWithOrigClosuresSkipped() var closureArgIndexToAllClonedReleasableClosures: [Int: [SingleValueInstruction]] = [:] - for closureArgDesc in callSite.closureArgDescriptors { + for closureArgDesc in pullbackClosureInfo.closureArgDescriptors { let (finalClonedReabstractedClosure, allClonedReleasableClosures) = - self.cloneClosureChain(representedBy: closureArgDesc, at: callSite) + self.cloneClosureChain(representedBy: closureArgDesc, at: pullbackClosureInfo) entryBlockArgs[closureArgDesc.closureArgIndex] = finalClonedReabstractedClosure closureArgIndexToAllClonedReleasableClosures[closureArgDesc.closureArgIndex] = allClonedReleasableClosures @@ -760,7 +758,7 @@ private extension SpecializationCloner { return (entryBlockArgs.map { $0! }, closureArgIndexToAllClonedReleasableClosures) } - private func cloneClosureChain(representedBy closureArgDesc: ClosureArgDescriptor, at callSite: CallSite) + private func cloneClosureChain(representedBy closureArgDesc: ClosureArgDescriptor, at pullbackClosureInfo: PullbackClosureInfo) -> (finalClonedReabstractedClosure: SingleValueInstruction, allClonedReleasableClosures: [SingleValueInstruction]) { let (origToClonedValueMap, capturedArgRange) = self.addEntryBlockArgs(forValuesCapturedBy: closureArgDesc) @@ -776,7 +774,7 @@ private extension SpecializationCloner { let finalClonedReabstractedClosure = builder.cloneRootClosureReabstractions(rootClosure: closureArgDesc.closure, clonedRootClosure: clonedRootClosure, - reabstractedClosure: callSite.appliedArgForClosure(at: closureArgDesc.closureArgIndex)!, + reabstractedClosure: pullbackClosureInfo.appliedArgForClosure(at: closureArgDesc.closureArgIndex)!, origToClonedValueMap: origToClonedValueMap, self.context) @@ -807,10 +805,10 @@ private extension SpecializationCloner { return (origToClonedValueMap, capturedArgRange) } - private func insertCleanupCodeForClonedReleasableClosures(from callSite: CallSite, + private func insertCleanupCodeForClonedReleasableClosures(from pullbackClosureInfo: PullbackClosureInfo, closureArgIndexToAllClonedReleasableClosures: [Int: [SingleValueInstruction]]) { - for closureArgDesc in callSite.closureArgDescriptors { + for closureArgDesc in pullbackClosureInfo.closureArgDescriptors { let allClonedReleasableClosures = closureArgIndexToAllClonedReleasableClosures[closureArgDesc.closureArgIndex]! // Insert a `destroy_value`, for all releasable closures, in all reachable exit BBs if the closure was passed as a @@ -819,7 +817,7 @@ private extension SpecializationCloner { if closureArgDesc.isClosureGuaranteed || closureArgDesc.parameterInfo.isTrivialNoescapeClosure, !allClonedReleasableClosures.isEmpty { - for exitBlock in callSite.reachableExitBBsInCallee { + for exitBlock in pullbackClosureInfo.reachableExitBBsInCallee { let clonedExitBlock = self.getClonedBlock(for: exitBlock) let terminator = clonedExitBlock.terminator is UnreachableInst @@ -854,7 +852,7 @@ private extension [HashableValue: Value] { } } -private extension CallSite { +private extension PullbackClosureInfo { enum NewApplyArg { case Original(Value) // TODO: This can be simplified in OSSA. We can just do a copy_value for everything - except for addresses??? @@ -876,8 +874,8 @@ private extension CallSite { var newApplyArgs: [NewApplyArg] = [] // Original arguments - for (applySiteIndex, arg) in self.applySite.arguments.enumerated() { - let calleeArgIndex = self.applySite.unappliedArgumentCount + applySiteIndex + for (applySiteIndex, arg) in self.paiOfPullback.arguments.enumerated() { + let calleeArgIndex = self.paiOfPullback.unappliedArgumentCount + applySiteIndex if !self.hasClosureArg(at: calleeArgIndex) { newApplyArgs.append(.Original(arg)) } @@ -960,7 +958,7 @@ private extension Builder { &origToClonedValueMap) guard let function = pai.referencedFunction else { - log("Parent function of callSite: \(rootClosure.parentFunction)") + log("Parent function of pullbackClosureInfo: \(rootClosure.parentFunction)") log("Root closure: \(rootClosure)") log("Unsupported reabstraction closure: \(pai)") fatalError("Encountered unsupported reabstraction (via partial_apply) of root closure!") @@ -983,7 +981,7 @@ private extension Builder { return reabstracted default: - log("Parent function of callSite: \(rootClosure.parentFunction)") + log("Parent function of pullbackClosureInfo: \(rootClosure.parentFunction)") log("Root closure: \(rootClosure)") log("Converted/reabstracted closure: \(reabstractedClosure)") fatalError("Encountered unsupported reabstraction of root closure: \(reabstractedClosure)") @@ -1027,15 +1025,15 @@ private extension Builder { } private extension FunctionConvention { - func getSpecializedParameters(basedOn callSite: CallSite) -> [ParameterInfo] { - let applySiteCallee = callSite.applyCallee + func getSpecializedParameters(basedOn pullbackClosureInfo: PullbackClosureInfo) -> [ParameterInfo] { + let pullbackFn = pullbackClosureInfo.pullbackFn var specializedParamInfoList: [ParameterInfo] = [] // Start by adding all original parameters except for the closure parameters. - let firstParamIndex = applySiteCallee.argumentConventions.firstParameterIndex - for (index, paramInfo) in applySiteCallee.convention.parameters.enumerated() { + let firstParamIndex = pullbackFn.argumentConventions.firstParameterIndex + for (index, paramInfo) in pullbackFn.convention.parameters.enumerated() { let argIndex = index + firstParamIndex - if !callSite.hasClosureArg(at: argIndex) { + if !pullbackClosureInfo.hasClosureArg(at: argIndex) { specializedParamInfoList.append(paramInfo) } } @@ -1048,7 +1046,7 @@ private extension FunctionConvention { // - direct and non-trivial, pass the new parameter as Direct_Owned. // - indirect, pass the new parameter using the same parameter convention as in // the original closure. - for closureArgDesc in callSite.closureArgDescriptors { + for closureArgDesc in pullbackClosureInfo.closureArgDescriptors { if let closure = closureArgDesc.closure as? PartialApplyInst { let closureCallee = closureArgDesc.callee let closureCalleeConvention = closureCallee.convention @@ -1202,18 +1200,10 @@ private struct OrderedDict { } } -private typealias CallSiteMap = OrderedDict - -private extension CallSiteMap { - var callSites: [CallSite] { - Array(self.values) - } -} - -/// Represents all the information required to represent a closure in isolation, i.e., outside of a callsite context -/// where the closure may be getting passed as an argument. +/// Represents all the information required to represent a closure in isolation, i.e., outside of a pullback's partial_apply context +/// where the closure may be getting captured as an argument. /// -/// Composed with other information inside a `ClosureArgDescriptor` to represent a closure as an argument at a callsite. +/// Composed with other information inside a `ClosureArgDescriptor` to represent a closure as a captured argument of a pullback's partial_apply. private struct ClosureInfo { let closure: SingleValueInstruction let lifetimeFrontier: [Instruction] @@ -1225,10 +1215,10 @@ private struct ClosureInfo { } -/// Represents a closure as an argument at a callsite. +/// Represents a closure as a captured argument of a pullback's partial_apply. private struct ClosureArgDescriptor { let closureInfo: ClosureInfo - /// The index of the closure in the callsite's argument list. + /// The index of the closure in the pullback's partial_apply argument list. let closureArgumentIndex: Int let parameterInfo: ParameterInfo @@ -1291,25 +1281,25 @@ private struct ClosureArgDescriptor { } } -/// Represents a callsite containing one or more closure arguments. -private struct CallSite { - let applySite: ApplySite +/// Represents a partial_apply of pullback capturing one or more closure arguments. +private struct PullbackClosureInfo { + let paiOfPullback: PartialApplyInst var closureArgDescriptors: [ClosureArgDescriptor] = [] - init(applySite: ApplySite) { - self.applySite = applySite + init(paiOfPullback: PartialApplyInst) { + self.paiOfPullback = paiOfPullback } mutating func appendClosureArgDescriptor(_ descriptor: ClosureArgDescriptor) { self.closureArgDescriptors.append(descriptor) } - var applyCallee: Function { - applySite.referencedFunction! + var pullbackFn: Function { + paiOfPullback.referencedFunction! } var reachableExitBBsInCallee: [BasicBlock] { - applyCallee.blocks.filter { $0.isReachableExitBlock } + pullbackFn.blocks.filter { $0.isReachableExitBlock } } func hasClosureArg(at index: Int) -> Bool { @@ -1322,7 +1312,7 @@ private struct CallSite { func appliedArgForClosure(at index: Int) -> Value? { if let closureArgDesc = closureArgDesc(at: index) { - return applySite.arguments[closureArgDesc.closureArgIndex - applySite.unappliedArgumentCount] + return paiOfPullback.arguments[closureArgDesc.closureArgIndex - paiOfPullback.unappliedArgumentCount] } return nil @@ -1333,48 +1323,41 @@ private struct CallSite { let closureIndices = Array(self.closureArgDescriptors.map { $0.closureArgIndex }) return context.mangle(withClosureArguments: closureArgs, closureArgIndices: closureIndices, - from: applyCallee) + from: pullbackFn) } } // ===================== Unit tests ===================== // -let gatherCallSitesTest = FunctionTest("closure_specialize_gather_call_sites") { function, arguments, context in +let getPullbackClosureInfoTest = FunctionTest("autodiff_closure_specialize_get_pullback_closure_info") { function, arguments, context in print("Specializing closures in function: \(function.name)") print("===============================================") - var callSites = gatherCallSites(in: function, context) - - callSites.forEach { callSite in - print("PartialApply call site: \(callSite.applySite)") - print("Passed in closures: ") - for index in callSite.closureArgDescriptors.indices { - var closureArgDescriptor = callSite.closureArgDescriptors[index] - print("\(index+1). \(closureArgDescriptor.closureInfo.closure)") - } + let pullbackClosureInfo = getPullbackClosureInfo(in: function, context)! + print("PartialApply of pullback: \(pullbackClosureInfo.paiOfPullback)") + print("Passed in closures: ") + for index in pullbackClosureInfo.closureArgDescriptors.indices { + var closureArgDescriptor = pullbackClosureInfo.closureArgDescriptors[index] + print("\(index+1). \(closureArgDescriptor.closureInfo.closure)") } print("\n") } let specializedFunctionSignatureAndBodyTest = FunctionTest( - "closure_specialize_specialized_function_signature_and_body") { function, arguments, context in + "autodiff_closure_specialize_specialized_function_signature_and_body") { function, arguments, context in - var callSites = gatherCallSites(in: function, context) + let pullbackClosureInfo = getPullbackClosureInfo(in: function, context)! - for callSite in callSites { - let (specializedFunction, _) = getOrCreateSpecializedFunction(basedOn: callSite, context) - print("Generated specialized function: \(specializedFunction.name)") - print("\(specializedFunction)\n") - } + let (specializedFunction, _) = getOrCreateSpecializedFunction(basedOn: pullbackClosureInfo, context) + print("Generated specialized function: \(specializedFunction.name)") + print("\(specializedFunction)\n") } -let rewrittenCallerBodyTest = FunctionTest("closure_specialize_rewritten_caller_body") { function, arguments, context in - var callSites = gatherCallSites(in: function, context) +let rewrittenCallerBodyTest = FunctionTest("autodiff_closure_specialize_rewritten_caller_body") { function, arguments, context in + let pullbackClosureInfo = getPullbackClosureInfo(in: function, context)! - for callSite in callSites { - let (specializedFunction, _) = getOrCreateSpecializedFunction(basedOn: callSite, context) - rewriteApplyInstruction(using: specializedFunction, callSite: callSite, context) + let (specializedFunction, _) = getOrCreateSpecializedFunction(basedOn: pullbackClosureInfo, context) + rewriteApplyInstruction(using: specializedFunction, pullbackClosureInfo: pullbackClosureInfo, context) - print("Rewritten caller body for: \(function.name):") - print("\(function)\n") - } + print("Rewritten caller body for: \(function.name):") + print("\(function)\n") } diff --git a/SwiftCompilerSources/Sources/Optimizer/Utilities/Test.swift b/SwiftCompilerSources/Sources/Optimizer/Utilities/Test.swift index 8e73223022a62..bba9f2d9ac4dc 100644 --- a/SwiftCompilerSources/Sources/Optimizer/Utilities/Test.swift +++ b/SwiftCompilerSources/Sources/Optimizer/Utilities/Test.swift @@ -160,7 +160,7 @@ public func registerOptimizerTests() { enclosingValuesTest, forwardingDefUseTest, forwardingUseDefTest, - gatherCallSitesTest, + getPullbackClosureInfoTest, interiorLivenessTest, lifetimeDependenceRootTest, lifetimeDependenceScopeTest, diff --git a/test/AutoDiff/SILOptimizer/closure_specialization.sil b/test/AutoDiff/SILOptimizer/closure_specialization.sil deleted file mode 100644 index 3d7c7ade772f2..0000000000000 --- a/test/AutoDiff/SILOptimizer/closure_specialization.sil +++ /dev/null @@ -1,509 +0,0 @@ -// RUN: %target-sil-opt -sil-print-types -test-runner %s -o /dev/null 2>&1 | %FileCheck %s - -// REQUIRES: swift_in_compiler - -sil_stage canonical - -import Builtin -import Swift -import SwiftShims - -import _Differentiation - -//////////////////////////////////////////////////////////////// -// Single closure call site where closure is passed as @owned // -//////////////////////////////////////////////////////////////// -sil @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float) - -sil private @$pullback_f : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float { -bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> (Float, Float)): - %2 = apply %1(%0) : $@callee_guaranteed (Float) -> (Float, Float) - strong_release %1 : $@callee_guaranteed (Float) -> (Float, Float) // id: %3 - %4 = tuple_extract %2 : $(Float, Float), 0 - %5 = tuple_extract %2 : $(Float, Float), 1 - %6 = struct_extract %5 : $Float, #Float._value - %7 = struct_extract %4 : $Float, #Float._value - %8 = builtin "fadd_FPIEEE32"(%6 : $Builtin.FPIEEE32, %7 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 - %9 = struct $Float (%8 : $Builtin.FPIEEE32) - debug_value %9 : $Float, let, name "x", argno 1 // id: %10 - return %9 : $Float // id: %11 -} - -// reverse-mode derivative of f(_:) -sil hidden @$s4test1fyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { -bb0(%0 : $Float): - //=========== Test callsite and closure gathering logic ===========// - specify_test "closure_specialize_gather_call_sites" - // CHECK-LABEL: Specializing closures in function: $s4test1fyS2fFTJrSpSr - // CHECK: PartialApply call site: %8 = partial_apply [callee_guaranteed] %7(%6) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float - // CHECK: Passed in closures: - // CHECK: 1. %6 = partial_apply [callee_guaranteed] %5(%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float) - - - //=========== Test specialized function signature and body ===========// - specify_test "closure_specialize_specialized_function_signature_and_body" - // CHECK-LABEL: Generated specialized function: $s11$pullback_f12$vjpMultiplyS2fTf1nc_n - // CHECK: sil private @$s11$pullback_f12$vjpMultiplyS2fTf1nc_n : $@convention(thin) (Float, Float, Float) -> Float { - // CHECK: bb0(%0 : $Float, %1 : $Float, %2 : $Float): - // CHECK: %3 = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float) - // CHECK: %4 = partial_apply [callee_guaranteed] %3(%1, %2) : $@convention(thin) (Float, Float, Float) -> (Float, Float) - // CHECK: %5 = apply %4(%0) : $@callee_guaranteed (Float) -> (Float, Float) - // CHECK: strong_release %4 : $@callee_guaranteed (Float) -> (Float, Float) // id: %6 - // CHECK: return - - //=========== Test rewritten body ===========// - specify_test "closure_specialize_rewritten_caller_body" - // CHECK-LABEL: Rewritten caller body for: $s4test1fyS2fFTJrSpSr - // CHECK: sil hidden @$s4test1fyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { - // CHECK: bb0(%0 : $Float): - // CHECK: %2 = struct_extract %0 : $Float, #Float._value - // CHECK: %3 = builtin "fmul_FPIEEE32"(%2 : $Builtin.FPIEEE32, %2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 - // CHECK: %4 = struct $Float (%3 : $Builtin.FPIEEE32) - // CHECK: %5 = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float) - // CHECK: %6 = partial_apply [callee_guaranteed] %5(%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float) - // CHECK: %8 = function_ref @$s11$pullback_f12$vjpMultiplyS2fTf1nc_n : $@convention(thin) (Float, Float, Float) -> Float - // CHECK: %9 = partial_apply [callee_guaranteed] %8(%0, %0) : $@convention(thin) (Float, Float, Float) -> Float - // CHECK: release_value %6 : $@callee_guaranteed (Float) -> (Float, Float) // id: %10 - // CHECK: %11 = tuple (%4 : $Float, %9 : $@callee_guaranteed (Float) -> Float) - // CHECK: return %11 - - debug_value %0 : $Float, let, name "x", argno 1 // id: %1 - %2 = struct_extract %0 : $Float, #Float._value - %3 = builtin "fmul_FPIEEE32"(%2 : $Builtin.FPIEEE32, %2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 - %4 = struct $Float (%3 : $Builtin.FPIEEE32) - // function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:) - %5 = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float) - %6 = partial_apply [callee_guaranteed] %5(%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float) - // function_ref pullback of f(_:) - %7 = function_ref @$pullback_f : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float - %8 = partial_apply [callee_guaranteed] %7(%6) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float - %9 = tuple (%4 : $Float, %8 : $@callee_guaranteed (Float) -> Float) - return %9 : $(Float, @callee_guaranteed (Float) -> Float) // id: %10 -} - -///////////////////////////////////////////////////////////////////// -// Single closure call site where closure is passed as @guaranteed // -///////////////////////////////////////////////////////////////////// -sil private @$pullback_k : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float { -// %0 -// %1 -bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> (Float, Float)): - %2 = apply %1(%0) : $@callee_guaranteed (Float) -> (Float, Float) - %3 = tuple_extract %2 : $(Float, Float), 0 - %4 = tuple_extract %2 : $(Float, Float), 1 - %5 = struct_extract %4 : $Float, #Float._value - %6 = struct_extract %3 : $Float, #Float._value - %7 = builtin "fadd_FPIEEE32"(%5 : $Builtin.FPIEEE32, %6 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 - %8 = struct $Float (%7 : $Builtin.FPIEEE32) - debug_value %8 : $Float, let, name "x", argno 1 // id: %9 - return %8 : $Float // id: %10 -} // end sil function '$pullback_k' - -// reverse-mode derivative of k(_:) -sil hidden @$s4test1kyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { -bb0(%0 : $Float): - //=========== Test callsite and closure gathering logic ===========// - specify_test "closure_specialize_gather_call_sites" - // CHECK-LABEL: Specializing closures in function: $s4test1kyS2fFTJrSpSr - // CHECK: PartialApply call site: %8 = partial_apply [callee_guaranteed] %7(%6) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float - // CHECK: Passed in closures: - // CHECK: 1. %6 = partial_apply [callee_guaranteed] %5(%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float) - - - //=========== Test specialized function signature and body ===========// - specify_test "closure_specialize_specialized_function_signature_and_body" - // CHECK-LABEL: Generated specialized function: $s11$pullback_k12$vjpMultiplyS2fTf1nc_n - // CHECK: sil private @$s11$pullback_k12$vjpMultiplyS2fTf1nc_n : $@convention(thin) (Float, Float, Float) -> Float { - // CHECK: bb0(%0 : $Float, %1 : $Float, %2 : $Float): - // CHECK: %3 = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float) - // CHECK: %4 = partial_apply [callee_guaranteed] %3(%1, %2) : $@convention(thin) (Float, Float, Float) -> (Float, Float) - // CHECK: %5 = apply %4(%0) : $@callee_guaranteed (Float) -> (Float, Float) - // CHECK: release_value %4 : $@callee_guaranteed (Float) -> (Float, Float) // id: %13 - // CHECK: return %11 - - //=========== Test rewritten body ===========// - specify_test "closure_specialize_rewritten_caller_body" - // CHECK-LABEL: Rewritten caller body for: $s4test1kyS2fFTJrSpSr - // CHECK: sil hidden @$s4test1kyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { - // CHECK: bb0(%0 : $Float): - // CHECK: %2 = struct_extract %0 : $Float, #Float._value - // CHECK: %3 = builtin "fmul_FPIEEE32"(%2 : $Builtin.FPIEEE32, %2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 - // CHECK: %4 = struct $Float (%3 : $Builtin.FPIEEE32) - // CHECK: %5 = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float) - // CHECK: %6 = partial_apply [callee_guaranteed] %5(%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float) - // CHECK: %8 = function_ref @$s11$pullback_k12$vjpMultiplyS2fTf1nc_n : $@convention(thin) (Float, Float, Float) -> Float - // CHECK: %9 = partial_apply [callee_guaranteed] %8(%0, %0) : $@convention(thin) (Float, Float, Float) -> Float - // CHECK: strong_release %6 : $@callee_guaranteed (Float) -> (Float, Float) - // CHECK: %11 = tuple (%4 : $Float, %9 : $@callee_guaranteed (Float) -> Float) - // CHECK: return %11 - - debug_value %0 : $Float, let, name "x", argno 1 // id: %1 - %2 = struct_extract %0 : $Float, #Float._value - %3 = builtin "fmul_FPIEEE32"(%2 : $Builtin.FPIEEE32, %2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 - %4 = struct $Float (%3 : $Builtin.FPIEEE32) - // function_ref $vjpMultiply - %5 = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float) - %6 = partial_apply [callee_guaranteed] %5(%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float) - // function_ref $pullback_k - %7 = function_ref @$pullback_k : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float - %8 = partial_apply [callee_guaranteed] %7(%6) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float - strong_release %6 : $@callee_guaranteed (Float) -> (Float, Float) // id: %9 - %10 = tuple (%4 : $Float, %8 : $@callee_guaranteed (Float) -> Float) - return %10 : $(Float, @callee_guaranteed (Float) -> Float) // id: %11 -} // end sil function '$s4test1kyS2fFTJrSpSr' - -/////////////////////////////// -// Multiple closure callsite // -/////////////////////////////// -sil @$vjpSin : $@convention(thin) (Float, Float) -> Float -sil @$vjpCos : $@convention(thin) (Float, Float) -> Float - -// pullback of g(_:) -sil private @$pullback_g : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float { -bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> Float, %2 : $@callee_guaranteed (Float) -> Float, %3 : $@callee_guaranteed (Float) -> (Float, Float)): - %4 = apply %3(%0) : $@callee_guaranteed (Float) -> (Float, Float) - strong_release %3 : $@callee_guaranteed (Float) -> (Float, Float) // id: %5 - %6 = tuple_extract %4 : $(Float, Float), 0 - %7 = tuple_extract %4 : $(Float, Float), 1 - %8 = apply %2(%7) : $@callee_guaranteed (Float) -> Float - strong_release %2 : $@callee_guaranteed (Float) -> Float // id: %9 - %10 = apply %1(%6) : $@callee_guaranteed (Float) -> Float - strong_release %1 : $@callee_guaranteed (Float) -> Float // id: %11 - %12 = struct_extract %8 : $Float, #Float._value - %13 = struct_extract %10 : $Float, #Float._value - %14 = builtin "fadd_FPIEEE32"(%13 : $Builtin.FPIEEE32, %12 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 - %15 = struct $Float (%14 : $Builtin.FPIEEE32) - debug_value %15 : $Float, let, name "x", argno 1 // id: %16 - return %15 : $Float // id: %17 -} - -// reverse-mode derivative of g(_:) -sil hidden @$s4test1gyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { -bb0(%0 : $Float): - //=========== Test callsite and closure gathering logic ===========// - specify_test "closure_specialize_gather_call_sites" - // CHECK-LABEL: Specializing closures in function: $s4test1gyS2fFTJrSpSr - // CHECK: PartialApply call site: %16 = partial_apply [callee_guaranteed] %15(%6, %10, %14) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float - // CHECK: Passed in closures: - // CHECK: 1. %6 = partial_apply [callee_guaranteed] %5(%0) : $@convention(thin) (Float, Float) -> Float - // CHECK: 2. %10 = partial_apply [callee_guaranteed] %9(%0) : $@convention(thin) (Float, Float) -> Float - // CHECK: 3. %14 = partial_apply [callee_guaranteed] %13(%8, %4) : $@convention(thin) (Float, Float, Float) -> (Float, Float) - - //=========== Test specialized function signature and body ===========// - specify_test "closure_specialize_specialized_function_signature_and_body" - // CHECK-LABEL: Generated specialized function: $s11$pullback_g7$vjpSinSf0B3CosSf0B8MultiplyS2fTf1nccc_n - // CHECK: sil private @$s11$pullback_g7$vjpSinSf0B3CosSf0B8MultiplyS2fTf1nccc_n : $@convention(thin) (Float, Float, Float, Float, Float) -> Float { - // CHECK: bb0(%0 : $Float, %1 : $Float, %2 : $Float, %3 : $Float, %4 : $Float): - // CHECK: %5 = function_ref @$vjpSin : $@convention(thin) (Float, Float) -> Float - // CHECK: %6 = partial_apply [callee_guaranteed] %5(%1) : $@convention(thin) (Float, Float) -> Float - // CHECK: %7 = function_ref @$vjpCos : $@convention(thin) (Float, Float) -> Float - // CHECK: %8 = partial_apply [callee_guaranteed] %7(%2) : $@convention(thin) (Float, Float) -> Float - // CHECK: %9 = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float) - // CHECK: %10 = partial_apply [callee_guaranteed] %9(%3, %4) : $@convention(thin) (Float, Float, Float) -> (Float, Float) - // CHECK: %11 = apply %10(%0) : $@callee_guaranteed (Float) -> (Float, Float) - // CHECK: strong_release %10 : $@callee_guaranteed (Float) -> (Float, Float) // id: %12 - // CHECK: %15 = apply %8(%14) : $@callee_guaranteed (Float) -> Float - // CHECK: strong_release %8 : $@callee_guaranteed (Float) -> Float // id: %16 - // CHECK: %17 = apply %6(%13) : $@callee_guaranteed (Float) -> Float - // CHECK: strong_release %6 : $@callee_guaranteed (Float) -> Float // id: %18 - // CHECK: return - - //=========== Test rewritten body ===========// - specify_test "closure_specialize_rewritten_caller_body" - // CHECK-LABEL: Rewritten caller body for: $s4test1gyS2fFTJrSpSr - // CHECK: sil hidden @$s4test1gyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { - // CHECK: bb0(%0 : $Float): - // CHECK: %2 = struct_extract %0 : $Float, #Float._value - // CHECK: %3 = builtin "int_sin_FPIEEE32"(%2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 - // CHECK: %4 = struct $Float (%3 : $Builtin.FPIEEE32) - // CHECK: %5 = function_ref @$vjpSin : $@convention(thin) (Float, Float) -> Float - // CHECK: %6 = partial_apply [callee_guaranteed] %5(%0) : $@convention(thin) (Float, Float) -> Float - // CHECK: %7 = builtin "int_cos_FPIEEE32"(%2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 - // CHECK: %8 = struct $Float (%7 : $Builtin.FPIEEE32) - // CHECK: %9 = function_ref @$vjpCos : $@convention(thin) (Float, Float) -> Float - // CHECK: %10 = partial_apply [callee_guaranteed] %9(%0) : $@convention(thin) (Float, Float) -> Float - // CHECK: %11 = builtin "fmul_FPIEEE32"(%3 : $Builtin.FPIEEE32, %7 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 - // CHECK: %12 = struct $Float (%11 : $Builtin.FPIEEE32) - // CHECK: %13 = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float) - // CHECK: %14 = partial_apply [callee_guaranteed] %13(%8, %4) : $@convention(thin) (Float, Float, Float) -> (Float, Float) - // CHECK: %16 = function_ref @$s11$pullback_g7$vjpSinSf0B3CosSf0B8MultiplyS2fTf1nccc_n : $@convention(thin) (Float, Float, Float, Float, Float) -> Float - // CHECK: %17 = partial_apply [callee_guaranteed] %16(%0, %0, %8, %4) : $@convention(thin) (Float, Float, Float, Float, Float) -> Float - // CHECK: release_value %6 : $@callee_guaranteed (Float) -> Float // id: %18 - // CHECK: release_value %10 : $@callee_guaranteed (Float) -> Float // id: %19 - // CHECK: release_value %14 : $@callee_guaranteed (Float) -> (Float, Float) // id: %20 - // CHECK: %21 = tuple (%12 : $Float, %17 : $@callee_guaranteed (Float) -> Float) - // CHECK: return %21 - - debug_value %0 : $Float, let, name "x", argno 1 // id: %1 - %2 = struct_extract %0 : $Float, #Float._value - %3 = builtin "int_sin_FPIEEE32"(%2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 - %4 = struct $Float (%3 : $Builtin.FPIEEE32) - // function_ref closure #1 in _vjpSin(_:) - %5 = function_ref @$vjpSin : $@convention(thin) (Float, Float) -> Float - %6 = partial_apply [callee_guaranteed] %5(%0) : $@convention(thin) (Float, Float) -> Float - %7 = builtin "int_cos_FPIEEE32"(%2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 - %8 = struct $Float (%7 : $Builtin.FPIEEE32) - // function_ref closure #1 in _vjpCos(_:) - %9 = function_ref @$vjpCos : $@convention(thin) (Float, Float) -> Float - %10 = partial_apply [callee_guaranteed] %9(%0) : $@convention(thin) (Float, Float) -> Float - %11 = builtin "fmul_FPIEEE32"(%3 : $Builtin.FPIEEE32, %7 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 - %12 = struct $Float (%11 : $Builtin.FPIEEE32) - // function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:) - %13 = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float) - %14 = partial_apply [callee_guaranteed] %13(%8, %4) : $@convention(thin) (Float, Float, Float) -> (Float, Float) - // function_ref pullback of g(_:) - %15 = function_ref @$pullback_g : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float - %16 = partial_apply [callee_guaranteed] %15(%6, %10, %14) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float - %17 = tuple (%12 : $Float, %16 : $@callee_guaranteed (Float) -> Float) - return %17 : $(Float, @callee_guaranteed (Float) -> Float) // id: %18 -} - -/////////////////////////////// -/// Parameter subset thunks /// -/////////////////////////////// -struct X : Differentiable { - @_hasStorage var a: Float { get set } - @_hasStorage var b: Double { get set } - struct TangentVector : AdditiveArithmetic, Differentiable { - @_hasStorage var a: Float { get set } - @_hasStorage var b: Double { get set } - static func + (lhs: X.TangentVector, rhs: X.TangentVector) -> X.TangentVector - static func - (lhs: X.TangentVector, rhs: X.TangentVector) -> X.TangentVector - @_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: X.TangentVector, _ b: X.TangentVector) -> Bool - typealias TangentVector = X.TangentVector - init(a: Float, b: Double) - static var zero: X.TangentVector { get } - } - init(a: Float, b: Double) - mutating func move(by offset: X.TangentVector) -} - -sil [transparent] [thunk] @subset_parameter_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector - -sil @pullback_f : $@convention(thin) (Float, Double) -> X.TangentVector - -sil shared @pullback_g : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) -> X.TangentVector { -bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> X.TangentVector): - %2 = apply %1(%0) : $@callee_guaranteed (Float) -> X.TangentVector - strong_release %1 : $@callee_guaranteed (Float) -> X.TangentVector - return %2 : $X.TangentVector -} - -sil hidden @$s5test21g1xSfAA1XV_tFTJrSpSr : $@convention(thin) (X) -> (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) { -bb0(%0 : $X): - //=========== Test callsite and closure gathering logic ===========// - specify_test "closure_specialize_gather_call_sites" - // CHECK-LABEL: Specializing closures in function: $s5test21g1xSfAA1XV_tFTJrSpSr - // CHECK: PartialApply call site: %7 = partial_apply [callee_guaranteed] %6(%5) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) -> X.TangentVector - // CHECK: Passed in closures: - // CHECK: 1. %3 = thin_to_thick_function %2 : $@convention(thin) (Float, Double) -> X.TangentVector to $@callee_guaranteed (Float, Double) -> X.TangentVector - - //=========== Test specialized function signature and body ===========// - specify_test "closure_specialize_specialized_function_signature_and_body" - // CHECK-LABEL: Generated specialized function: $s10pullback_g0A2_fTf1nc_n - // CHECK: sil shared @$s10pullback_g0A2_fTf1nc_n : $@convention(thin) (Float) -> X.TangentVector { - // CHECK: bb0(%0 : $Float): - // CHECK: %1 = function_ref @pullback_f : $@convention(thin) (Float, Double) -> X.TangentVector - // CHECK: %2 = thin_to_thick_function %1 : $@convention(thin) (Float, Double) -> X.TangentVector to $@callee_guaranteed (Float, Double) -> X.TangentVector - // CHECK: %3 = function_ref @subset_parameter_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector - // CHECK: %4 = partial_apply [callee_guaranteed] %3(%2) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector - // CHECK: %5 = apply %4(%0) : $@callee_guaranteed (Float) -> X.TangentVector - // CHECK: strong_release %4 : $@callee_guaranteed (Float) -> X.TangentVector // id: %6 - // CHECK: return %5 : $X.TangentVector // id: %7 - - //=========== Test rewritten body ===========// - specify_test "closure_specialize_rewritten_caller_body" - // CHECK-LABEL: Rewritten caller body for: $s5test21g1xSfAA1XV_tFTJrSpSr - // CHECK: sil hidden @$s5test21g1xSfAA1XV_tFTJrSpSr : $@convention(thin) (X) -> (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) { - // CHECK: bb0(%0 : $X): - // CHECK: %1 = struct_extract %0 : $X, #X.a - // CHECK: %2 = function_ref @pullback_f : $@convention(thin) (Float, Double) -> X.TangentVector - // CHECK: %3 = thin_to_thick_function %2 : $@convention(thin) (Float, Double) -> X.TangentVector to $@callee_guaranteed (Float, Double) -> X.TangentVector - // CHECK: %4 = function_ref @subset_parameter_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector - // CHECK: %5 = partial_apply [callee_guaranteed] %4(%3) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector - // CHECK: %7 = function_ref @$s10pullback_g0A2_fTf1nc_n : $@convention(thin) (Float) -> X.TangentVector - // CHECK: %8 = partial_apply [callee_guaranteed] %7() : $@convention(thin) (Float) -> X.TangentVector - // CHECK: release_value %3 : $@callee_guaranteed (Float, Double) -> X.TangentVector // id: %9 - // CHECK: %10 = tuple (%1 : $Float, %8 : $@callee_guaranteed (Float) -> X.TangentVector) - // CHECK: return %10 - - %1 = struct_extract %0 : $X, #X.a - // function_ref pullback_f - %2 = function_ref @pullback_f : $@convention(thin) (Float, Double) -> X.TangentVector - %3 = thin_to_thick_function %2 : $@convention(thin) (Float, Double) -> X.TangentVector to $@callee_guaranteed (Float, Double) -> X.TangentVector - // function_ref subset_parameter_thunk - %4 = function_ref @subset_parameter_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector - %5 = partial_apply [callee_guaranteed] %4(%3) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector - // function_ref pullback_g - %6 = function_ref @pullback_g : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) -> X.TangentVector - %7 = partial_apply [callee_guaranteed] %6(%5) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) -> X.TangentVector - %8 = tuple (%1 : $Float, %7 : $@callee_guaranteed (Float) -> X.TangentVector) - return %8 : $(Float, @callee_guaranteed (Float) -> X.TangentVector) // id: %9 -} - -/////////////////////////////////////////////////////////////////////// -///////// Specialized generic closures - PartialApply Closure ///////// -/////////////////////////////////////////////////////////////////////// - -// closure #1 in static Float._vjpMultiply(lhs:rhs:) -sil @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) - -// thunk for @escaping @callee_guaranteed (@unowned Float) -> (@unowned Float, @unowned Float) -sil [transparent] [reabstraction_thunk] @$sS3fIegydd_S3fIegnrr_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float) - -// function_ref specialized pullback of f(a:) -sil [transparent] [thunk] @pullback_f_specialized : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for ) -> @out Float - -// thunk for @escaping @callee_guaranteed (@in_guaranteed Float) -> (@out Float) -sil [transparent] [reabstraction_thunk] @$sS2fIegnr_S2fIegyd_TR : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float - -sil private [signature_optimized_thunk] [always_inline] @pullback_h : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float { -bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> Float): - %2 = apply %1(%0) : $@callee_guaranteed (Float) -> Float - strong_release %1 : $@callee_guaranteed (Float) -> Float - return %2 : $Float -} - -// reverse-mode derivative of h(x:) -sil hidden @$s5test21h1xS2f_tFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { -bb0(%0 : $Float): - //=========== Test callsite and closure gathering logic ===========// - specify_test "closure_specialize_gather_call_sites" - // CHECK-LABEL: Specializing closures in function: $s5test21h1xS2f_tFTJrSpSr - // CHECK: PartialApply call site: %14 = partial_apply [callee_guaranteed] %13(%11) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float - // CHECK: Passed in closures: - // CHECK: 1. %4 = partial_apply [callee_guaranteed] %3(%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float) - - //=========== Test specialized function signature and body ===========// - specify_test "closure_specialize_specialized_function_signature_and_body" - // CHECK-LABEL: Generated specialized function: $s10pullback_h073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbackti1_j5FZSf_J6SfcfU_S2fTf1nc_n - // CHECK: sil private [signature_optimized_thunk] [always_inline] @$s10pullback_h073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbackti1_j5FZSf_J6SfcfU_S2fTf1nc_n : $@convention(thin) (Float, Float, Float) -> Float { - // CHECK: bb0(%0 : $Float, %1 : $Float, %2 : $Float): - // CHECK: %3 = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) - // CHECK: %4 = partial_apply [callee_guaranteed] %3(%1, %2) : $@convention(thin) (Float, Float, Float) -> (Float, Float) - // CHECK: %5 = function_ref @$sS3fIegydd_S3fIegnrr_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float) - // CHECK: %6 = partial_apply [callee_guaranteed] %5(%4) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float) - // CHECK: %7 = convert_function %6 : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @out Float) to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for - // CHECK: %8 = function_ref @pullback_f_specialized : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for ) -> @out Float - // CHECK: %9 = partial_apply [callee_guaranteed] %8(%7) : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for ) -> @out Float - // CHECK: %10 = function_ref @$sS2fIegnr_S2fIegyd_TR : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float - // CHECK: %11 = partial_apply [callee_guaranteed] %10(%9) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float - // CHECK: %12 = apply %11(%0) : $@callee_guaranteed (Float) -> Float - // CHECK: strong_release %11 : $@callee_guaranteed (Float) -> Float // id: %13 - // CHECK: return %12 : $Float - - //=========== Test rewritten body ===========// - specify_test "closure_specialize_rewritten_caller_body" - // CHECK-LABEL: Rewritten caller body for: $s5test21h1xS2f_tFTJrSpSr - // CHECK:sil hidden @$s5test21h1xS2f_tFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { - // CHECK:bb0(%0 : $Float): - // CHECK: %1 = struct_extract %0 : $Float, #Float._value - // CHECK: %2 = builtin "fmul_FPIEEE32"(%1 : $Builtin.FPIEEE32, %1 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 - // CHECK: %3 = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) - // CHECK: %4 = partial_apply [callee_guaranteed] %3(%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float) - // CHECK: %5 = function_ref @$sS3fIegydd_S3fIegnrr_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float) - // CHECK: %6 = partial_apply [callee_guaranteed] %5(%4) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float) - // CHECK: %7 = convert_function %6 : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @out Float) to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for - // CHECK: %8 = function_ref @pullback_f_specialized : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for ) -> @out Float - // CHECK: %9 = partial_apply [callee_guaranteed] %8(%7) : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for ) -> @out Float - // CHECK: %10 = function_ref @$sS2fIegnr_S2fIegyd_TR : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float - // CHECK: %11 = partial_apply [callee_guaranteed] %10(%9) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float - // CHECK: %12 = struct $Float (%2 : $Builtin.FPIEEE32) - // CHECK: %14 = function_ref @$s10pullback_h073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbackti1_j5FZSf_J6SfcfU_S2fTf1nc_n : $@convention(thin) (Float, Float, Float) -> Float - // CHECK: %15 = partial_apply [callee_guaranteed] %14(%0, %0) : $@convention(thin) (Float, Float, Float) -> Float - // CHECK: release_value %4 : $@callee_guaranteed (Float) -> (Float, Float) // id: %16 - // CHECK: %17 = tuple (%12 : $Float, %15 : $@callee_guaranteed (Float) -> Float) - // CHECK: return %17 - - %1 = struct_extract %0 : $Float, #Float._value - %2 = builtin "fmul_FPIEEE32"(%1 : $Builtin.FPIEEE32, %1 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 - - // function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:) - %3 = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) - %4 = partial_apply [callee_guaranteed] %3(%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float) - - // function_ref thunk for @escaping @callee_guaranteed (@unowned Float) -> (@unowned Float, @unowned Float) - %5 = function_ref @$sS3fIegydd_S3fIegnrr_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float) - %6 = partial_apply [callee_guaranteed] %5(%4) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float) - %7 = convert_function %6 : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @out Float) to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for - - // function_ref pullback_f_specialized - %8 = function_ref @pullback_f_specialized : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for ) -> @out Float - %9 = partial_apply [callee_guaranteed] %8(%7) : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for ) -> @out Float - - // function_ref thunk for @escaping @callee_guaranteed (@in_guaranteed Float) -> (@out Float) - %10 = function_ref @$sS2fIegnr_S2fIegyd_TR : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float - %11 = partial_apply [callee_guaranteed] %10(%9) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float - %12 = struct $Float (%2 : $Builtin.FPIEEE32) - - // function_ref pullback_h - %13 = function_ref @pullback_h : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float - %14 = partial_apply [callee_guaranteed] %13(%11) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float - %15 = tuple (%12 : $Float, %14 : $@callee_guaranteed (Float) -> Float) - return %15 : $(Float, @callee_guaranteed (Float) -> Float) // id: %16 -} - -////////////////////////////////////////////////////////////////////////////// -///////// Specialized generic closures - ThinToThickFunction closure ///////// -////////////////////////////////////////////////////////////////////////////// - -sil [transparent] [thunk] @pullback_y_specialized : $@convention(thin) (@in_guaranteed Float) -> @out Float - -sil [transparent] [reabstraction_thunk] @reabstraction_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float - -sil private [signature_optimized_thunk] [always_inline] @pullback_z : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float { -bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> Float): - %2 = apply %1(%0) : $@callee_guaranteed (Float) -> Float - strong_release %1 : $@callee_guaranteed (Float) -> Float - return %2 : $Float -} - -sil hidden @$s5test21z1xS2f_tFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { -bb0(%0 : $Float): - //=========== Test callsite and closure gathering logic ===========// - specify_test "closure_specialize_gather_call_sites" - // CHECK-LABEL: Specializing closures in function: $s5test21z1xS2f_tFTJrSpSr - // CHECK: PartialApply call site: %6 = partial_apply [callee_guaranteed] %5(%4) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float - // CHECK: Passed in closures: - // CHECK: 1. %2 = thin_to_thick_function %1 : $@convention(thin) (@in_guaranteed Float) -> @out Float to $@callee_guaranteed (@in_guaranteed Float) -> @out Float - - //=========== Test specialized function signature and body ===========// - specify_test "closure_specialize_specialized_function_signature_and_body" - // CHECK-LABEL: Generated specialized function: $s10pullback_z0A14_y_specializedTf1nc_n - // CHECK: sil private [signature_optimized_thunk] [always_inline] @$s10pullback_z0A14_y_specializedTf1nc_n : $@convention(thin) (Float) -> Float { - // CHECK: bb0(%0 : $Float): - // CHECK: %1 = function_ref @pullback_y_specialized : $@convention(thin) (@in_guaranteed Float) -> @out Float - // CHECK: %2 = thin_to_thick_function %1 : $@convention(thin) (@in_guaranteed Float) -> @out Float to $@callee_guaranteed (@in_guaranteed Float) -> @out Float - // CHECK: %3 = function_ref @reabstraction_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float - // CHECK: %4 = partial_apply [callee_guaranteed] %3(%2) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float - // CHECK: %5 = apply %4(%0) : $@callee_guaranteed (Float) -> Float - // CHECK: strong_release %4 : $@callee_guaranteed (Float) -> Float // id: %6 - // CHECK: return %5 : $Float - - //=========== Test rewritten body ===========// - specify_test "closure_specialize_rewritten_caller_body" - // CHECK-LABEL: Rewritten caller body for: $s5test21z1xS2f_tFTJrSpSr - // CHECK: sil hidden @$s5test21z1xS2f_tFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { - // CHECK: bb0(%0 : $Float): - // CHECK: %1 = function_ref @pullback_y_specialized : $@convention(thin) (@in_guaranteed Float) -> @out Float - // CHECK: %2 = thin_to_thick_function %1 : $@convention(thin) (@in_guaranteed Float) -> @out Float to $@callee_guaranteed (@in_guaranteed Float) -> @out Float - // CHECK: %3 = function_ref @reabstraction_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float - // CHECK: %4 = partial_apply [callee_guaranteed] %3(%2) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float - // CHECK: %6 = function_ref @$s10pullback_z0A14_y_specializedTf1nc_n : $@convention(thin) (Float) -> Float - // CHECK: %7 = partial_apply [callee_guaranteed] %6() : $@convention(thin) (Float) -> Float - // CHECK: release_value %2 : $@callee_guaranteed (@in_guaranteed Float) -> @out Float // id: %8 - // CHECK: %9 = tuple (%0 : $Float, %7 : $@callee_guaranteed (Float) -> Float) - // CHECK: return %9 - - // function_ref pullback_y_specialized - %1 = function_ref @pullback_y_specialized : $@convention(thin) (@in_guaranteed Float) -> @out Float - %2 = thin_to_thick_function %1 : $@convention(thin) (@in_guaranteed Float) -> @out Float to $@callee_guaranteed (@in_guaranteed Float) -> @out Float - // function_ref reabstraction_thunk - %3 = function_ref @reabstraction_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float - %4 = partial_apply [callee_guaranteed] %3(%2) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float - // function_ref pullback_z - %5 = function_ref @pullback_z : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float - %6 = partial_apply [callee_guaranteed] %5(%4) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float - %7 = tuple (%0 : $Float, %6 : $@callee_guaranteed (Float) -> Float) - return %7 : $(Float, @callee_guaranteed (Float) -> Float) // id: %8 -} diff --git a/test/AutoDiff/SILOptimizer/closure_specialization/multi_bb_bte.sil b/test/AutoDiff/SILOptimizer/closure_specialization/multi_bb_bte.sil new file mode 100644 index 0000000000000..1f69534a5c2b3 --- /dev/null +++ b/test/AutoDiff/SILOptimizer/closure_specialization/multi_bb_bte.sil @@ -0,0 +1,689 @@ +/// Multi basic block VJP, pullback accepting branch tracing enum argument. + +// RUN: %target-sil-opt -sil-print-types -test-runner %s -o /dev/null 2>&1 | %FileCheck %s --check-prefixes=TRUNNER,CHECK +// RUN: %target-sil-opt -sil-print-types -autodiff-closure-specialization -sil-combine %s -o - | %FileCheck %s --check-prefixes=COMBINE,CHECK + +// REQUIRES: swift_in_compiler + +sil_stage canonical + +import Builtin +import Swift +import SwiftShims + +import _Differentiation + +/////////////////// +/// Test case 1 /// +/////////////////// + +/// This SIL corresponds to the following Swift: +/// +/// @differentiable(reverse) +/// func mul42(_ a: Float?) -> Float { +/// let b = 42 * a! +/// return b +/// } + +enum _AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0 { + case bb0(()) +} + +sil @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) +sil [transparent] [thunk] @$sS3fIegydd_TJSpSSUpSrUSUP : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + +// pullback of mul42(_:) +sil private [signature_optimized_thunk] [always_inline] @$s4test5mul42yS2fSgFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> Float) -> Optional.TangentVector { +bb0(%0 : $Float, %1 : $_AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0, %2 : $@callee_guaranteed (Float) -> Float): + %4 = apply %2(%0) : $@callee_guaranteed (Float) -> Float + strong_release %2 + %6 = enum $Optional, #Optional.some!enumelt, %4 + %7 = struct $Optional.TangentVector (%6) + return %7 +} // end sil function '$s4test5mul42yS2fSgFTJpSpSr' + +// reverse-mode derivative of mul42(_:) +sil hidden @$s4test5mul42yS2fSgFTJrSpSr : $@convention(thin) (Optional) -> (Float, @owned @callee_guaranteed (Float) -> Optional.TangentVector) { +bb0(%0 : $Optional): + //=========== Test callsite and closure gathering logic ===========// + specify_test "autodiff_closure_specialize_get_pullback_closure_info" + // TRUNNER-LABEL: Specializing closures in function: $s4test5mul42yS2fSgFTJrSpSr + // TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]], %[[#]]) : $@convention(thin) (Float, @owned _AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> Float) -> Optional.TangentVector + // TRUNNER-NEXT: Passed in closures: + // TRUNNER-NEXT: 1. %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]], %[[#]]) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER-EMPTY: + + //=========== Test specialized function signature and body ===========// + specify_test "autodiff_closure_specialize_specialized_function_signature_and_body" + // TRUNNER-LABEL: Generated specialized function: $s4test5mul42yS2fSgFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktj1_k5FZSf_K6SfcfU_S2fTf1nnc_n + // CHECK: sil private [signature_optimized_thunk] [always_inline] @$s4test5mul42yS2fSgFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktj1_k5FZSf_K6SfcfU_S2fTf1nnc_n : $@convention(thin) (Float, @owned _AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0, Float, Float) -> Optional.TangentVector { + // CHECK: bb0(%0 : $Float, %1 : $_AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0, %2 : $Float, %3 : $Float): + // CHECK: %[[#A4:]] = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // CHECK: %[[#A5:]] = partial_apply [callee_guaranteed] %[[#A4]](%2, %3) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // CHECK: %[[#A6:]] = function_ref @$sS3fIegydd_TJSpSSUpSrUSUP : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + // TRUNNER: %[[#A7:]] = partial_apply [callee_guaranteed] %[[#A6]](%[[#A5]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + // TRUNNER: %[[#A8:]] = apply %[[#A7]](%0) : $@callee_guaranteed (Float) -> Float + // COMBINE-NOT: = partial_apply + // COMBINE: %[[#A8:]] = apply %[[#A6]](%0, %[[#A5]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + // TRUNNER: strong_release %[[#A7]] : $@callee_guaranteed (Float) -> Float + // CHECK: %[[#A10:]] = enum $Optional, #Optional.some!enumelt, %[[#A8]] : $Float + // CHECK: %[[#A11:]] = struct $Optional.TangentVector (%[[#A10]] : $Optional) + // CHECK: return %[[#A11]] : $Optional.TangentVector + + //=========== Test rewritten body ===========// + specify_test "autodiff_closure_specialize_rewritten_caller_body" + // TRUNNER-LABEL: Rewritten caller body for: $s4test5mul42yS2fSgFTJrSpSr: + // CHECK: sil hidden @$s4test5mul42yS2fSgFTJrSpSr : $@convention(thin) (Optional) -> (Float, @owned @callee_guaranteed (Float) -> Optional.TangentVector) { + // CHECK: bb1(%2 : $Float): + // CHECK: %[[#B4:]] = struct $Float (%[[#]] : $Builtin.FPIEEE32) + // TRUNNER: %[[#B10:]] = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER: %[[#B11:]] = partial_apply [callee_guaranteed] %[[#B10]](%2, %[[#B4]]) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER: %[[#B12:]] = function_ref @$sS3fIegydd_TJSpSSUpSrUSUP : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + // TRUNNER: %[[#B13:]] = partial_apply [callee_guaranteed] %[[#B12]](%[[#B11]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + // TRUNNER: %[[#B14:]] = function_ref @$s4test5mul42yS2fSgFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> Float) -> Optional.TangentVector + // COMBINE-NOT: = partial_apply + // COMBINE-NOT: = function_ref + // CHECK: %[[#B15:]] = function_ref @$s4test5mul42yS2fSgFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktj1_k5FZSf_K6SfcfU_S2fTf1nnc_n : $@convention(thin) (Float, @owned _AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0, Float, Float) -> Optional.TangentVector + // CHECK: %[[#B16:]] = partial_apply [callee_guaranteed] %[[#B15]](%[[#]], %2, %[[#B4]]) : $@convention(thin) (Float, @owned _AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0, Float, Float) -> Optional.TangentVector + // TRUNNER: release_value %[[#B11]] : $@callee_guaranteed (Float) -> (Float, Float) + // CHECK: %[[#B18:]] = tuple (%[[#]] : $Float, %[[#B16]] : $@callee_guaranteed (Float) -> Optional.TangentVector) + // CHECK: return %[[#B18]] + + switch_enum %0, case #Optional.some!enumelt: bb1, case #Optional.none!enumelt: bb2 + +bb1(%3 : $Float): + %4 = float_literal $Builtin.FPIEEE32, 0x42280000 // 42 + %5 = struct $Float (%4) + %6 = tuple () + %7 = enum $_AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0, #_AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0.bb0!enumelt, %6 + %8 = struct_extract %3, #Float._value + %9 = builtin "fmul_FPIEEE32"(%4, %8) : $Builtin.FPIEEE32 + %10 = struct $Float (%9) + // function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:) + %11 = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) + %12 = partial_apply [callee_guaranteed] %11(%3, %5) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // function_ref autodiff subset parameters thunk for pullback from @escaping @callee_guaranteed (@unowned Float) -> (@unowned Float, @unowned Float) + %13 = function_ref @$sS3fIegydd_TJSpSSUpSrUSUP : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + %14 = partial_apply [callee_guaranteed] %13(%12) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + // function_ref pullback of mul42(_:) + %16 = function_ref @$s4test5mul42yS2fSgFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> Float) -> Optional.TangentVector + %17 = partial_apply [callee_guaranteed] %16(%7, %14) : $@convention(thin) (Float, @owned _AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> Float) -> Optional.TangentVector + %18 = tuple (%10, %17) + return %18 + +bb2: + unreachable +} // end sil function '$s4test5mul42yS2fSgFTJrSpSr' + +/////////////////// +/// Test case 2 /// +/////////////////// + +/// This SIL corresponds to the following Swift: +/// +/// struct Class: Differentiable { +/// var stored: Float +/// var optional: Float? +/// +/// init(stored: Float, optional: Float?) { +/// self.stored = stored +/// self.optional = optional +/// } +/// +/// @differentiable(reverse) +/// func method() -> Float { +/// let c: Class +/// do { +/// let tmp = Class(stored: 1 * stored, optional: optional) +/// let tuple = (tmp, tmp) +/// c = tuple.0 +/// } +/// var ret : Float = 0 +/// if let x = c.optional { +/// ret = x * c.stored +/// } else { +/// ret = 1 * c.stored +/// } +/// return 1 * ret * ret +/// } +/// } + +struct Class : Differentiable { + @_hasStorage var stored: Float { get set } + @_hasStorage @_hasInitialValue var optional: Float? { get set } + init(stored: Float, optional: Float?) + @differentiable(reverse, wrt: self) + func method() -> Float + struct TangentVector : AdditiveArithmetic, Differentiable { + @_hasStorage var stored: Float { get set } + @_hasStorage var optional: Optional.TangentVector { get set } + static func + (lhs: Class.TangentVector, rhs: Class.TangentVector) -> Class.TangentVector + static func - (lhs: Class.TangentVector, rhs: Class.TangentVector) -> Class.TangentVector + typealias TangentVector = Class.TangentVector + @_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: Class.TangentVector, _ b: Class.TangentVector) -> Bool + init(stored: Float, optional: Optional.TangentVector) + static var zero: Class.TangentVector { get } + } + mutating func move(by offset: Class.TangentVector) +} + +enum _AD__$s4test5ClassV6methodSfyF_bb0__Pred__src_0_wrt_0 { +} + +enum _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0 { + case bb0(((Float) -> Float, (Class.TangentVector) -> (Float, Optional.TangentVector))) +} + +enum _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0 { + case bb0(((Float) -> Float, (Class.TangentVector) -> (Float, Optional.TangentVector))) +} + +enum _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0 { + case bb2((predecessor: _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, (Float) -> Float)) + case bb1((predecessor: _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, (Float) -> (Float, Float))) +} + +enum _AD__$s4test13methodWrapperySfAA5ClassVF_bb0__Pred__src_0_wrt_0 { +} + +enum _AD__$s4test5ClassV6stored8optionalACSf_SfSgtcfC_bb0__Pred__src_0_wrt_0_1 { +} + +sil @$s4test5ClassV6stored8optionalACSf_SfSgtcfCTJpSSUpSr : $@convention(thin) (Class.TangentVector) -> (Float, Optional.TangentVector) + +// pullback of Class.method() +sil private @$s4test5ClassV6methodSfyFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Class.TangentVector { +bb0(%0 : $Float, %1 : $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, %2 : $@callee_guaranteed (Float) -> Float, %3 : $@callee_guaranteed (Float) -> (Float, Float)): + %4 = float_literal $Builtin.FPIEEE32, 0x0 // 0 + %8 = apply %3(%0) : $@callee_guaranteed (Float) -> (Float, Float) + strong_release %3 + %10 = tuple_extract %8, 0 + %11 = tuple_extract %8, 1 + %12 = struct_extract %11, #Float._value + %13 = builtin "fadd_FPIEEE32"(%4, %12) : $Builtin.FPIEEE32 + %15 = apply %2(%10) : $@callee_guaranteed (Float) -> Float + strong_release %2 + %17 = struct_extract %15, #Float._value + %18 = builtin "fadd_FPIEEE32"(%13, %17) : $Builtin.FPIEEE32 + switch_enum %1, case #_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0.bb2!enumelt: bb1, case #_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0.bb1!enumelt: bb2 + +bb1(%37 : $(predecessor: _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float)): + %38 = tuple_extract %37, 0 + %39 = tuple_extract %37, 1 + %40 = builtin "fadd_FPIEEE32"(%18, %4) : $Builtin.FPIEEE32 + %41 = struct $Float (%40) + %42 = apply %39(%41) : $@callee_guaranteed (Float) -> Float + strong_release %39 + %44 = struct_extract %42, #Float._value + %45 = builtin "fadd_FPIEEE32"(%44, %4) : $Builtin.FPIEEE32 + %46 = builtin "fadd_FPIEEE32"(%4, %45) : $Builtin.FPIEEE32 + %50 = unchecked_enum_data %38, #_AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0.bb0!enumelt + %51 = tuple_extract %50, 1 + %52 = tuple_extract %50, 0 + br bb3(%4, %46, %52, %51) + +bb2(%54 : $(predecessor: _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> (Float, Float))): + %55 = tuple_extract %54, 0 + %56 = tuple_extract %54, 1 + %57 = builtin "fadd_FPIEEE32"(%18, %4) : $Builtin.FPIEEE32 + %58 = struct $Float (%57) + %59 = apply %56(%58) : $@callee_guaranteed (Float) -> (Float, Float) + strong_release %56 + %61 = tuple_extract %59, 0 + %62 = tuple_extract %59, 1 + %63 = struct_extract %61, #Float._value + %64 = builtin "fadd_FPIEEE32"(%63, %4) : $Builtin.FPIEEE32 + %65 = struct_extract %62, #Float._value + %66 = builtin "fadd_FPIEEE32"(%65, %4) : $Builtin.FPIEEE32 + %67 = builtin "fadd_FPIEEE32"(%4, %66) : $Builtin.FPIEEE32 + %69 = builtin "fadd_FPIEEE32"(%64, %4) : $Builtin.FPIEEE32 + %70 = builtin "fadd_FPIEEE32"(%69, %4) : $Builtin.FPIEEE32 + %73 = unchecked_enum_data %55, #_AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0.bb0!enumelt + %74 = tuple_extract %73, 1 + %75 = tuple_extract %73, 0 + br bb3(%70, %67, %75, %74) + +bb3(%77 : $Builtin.FPIEEE32, %78 : $Builtin.FPIEEE32, %79 : $@callee_guaranteed (Float) -> Float, %80 : $@callee_guaranteed (Class.TangentVector) -> (Float, Optional.TangentVector)): + %81 = builtin "fadd_FPIEEE32"(%4, %77) : $Builtin.FPIEEE32 + %85 = builtin "fadd_FPIEEE32"(%78, %4) : $Builtin.FPIEEE32 + %87 = builtin "fadd_FPIEEE32"(%81, %4) : $Builtin.FPIEEE32 + %89 = builtin "fadd_FPIEEE32"(%85, %4) : $Builtin.FPIEEE32 + %91 = builtin "fadd_FPIEEE32"(%87, %4) : $Builtin.FPIEEE32 + %93 = builtin "fadd_FPIEEE32"(%89, %4) : $Builtin.FPIEEE32 + %95 = builtin "fadd_FPIEEE32"(%91, %4) : $Builtin.FPIEEE32 + %97 = builtin "fadd_FPIEEE32"(%93, %4) : $Builtin.FPIEEE32 + %99 = builtin "fadd_FPIEEE32"(%95, %4) : $Builtin.FPIEEE32 + %100 = builtin "fadd_FPIEEE32"(%4, %97) : $Builtin.FPIEEE32 + %102 = builtin "fadd_FPIEEE32"(%4, %99) : $Builtin.FPIEEE32 + %104 = builtin "fadd_FPIEEE32"(%100, %4) : $Builtin.FPIEEE32 + %105 = struct $Float (%104) + %106 = builtin "fadd_FPIEEE32"(%102, %4) : $Builtin.FPIEEE32 + %107 = struct $Float (%106) + %108 = enum $Optional, #Optional.some!enumelt, %107 + %109 = struct $Optional.TangentVector (%108) + %110 = struct $Class.TangentVector (%105, %109) + %111 = apply %80(%110) : $@callee_guaranteed (Class.TangentVector) -> (Float, Optional.TangentVector) + strong_release %80 + %113 = tuple_extract %111, 1 + %114 = struct_extract %113, #Optional.TangentVector.value + switch_enum %114, case #Optional.none!enumelt: bb4, case #Optional.some!enumelt: bb5 + +bb4: + br bb6(%4) + +bb5(%117 : $Float): + %118 = unchecked_enum_data %114, #Optional.some!enumelt + %119 = struct_extract %118, #Float._value + %120 = builtin "fadd_FPIEEE32"(%119, %4) : $Builtin.FPIEEE32 + br bb6(%120) + +bb6(%122 : $Builtin.FPIEEE32): + %123 = tuple_extract %111, 0 + %124 = struct_extract %123, #Float._value + %125 = builtin "fadd_FPIEEE32"(%124, %4) : $Builtin.FPIEEE32 + %126 = struct $Float (%125) + %127 = apply %79(%126) : $@callee_guaranteed (Float) -> Float + strong_release %79 + %129 = struct_extract %127, #Float._value + %130 = builtin "fadd_FPIEEE32"(%129, %4) : $Builtin.FPIEEE32 + %131 = builtin "fadd_FPIEEE32"(%4, %130) : $Builtin.FPIEEE32 + %132 = builtin "fadd_FPIEEE32"(%4, %122) : $Builtin.FPIEEE32 + %133 = struct $Float (%132) + %134 = enum $Optional, #Optional.some!enumelt, %133 + %135 = struct $Float (%131) + %136 = struct $Optional.TangentVector (%134) + %137 = struct $Class.TangentVector (%135, %136) + return %137 +} // end sil function '$s4test5ClassV6methodSfyFTJpSpSr' + +// reverse-mode derivative of Class.method() +sil hidden @$s4test5ClassV6methodSfyFTJrSpSr : $@convention(method) (Class) -> (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) { +bb0(%0 : $Class): + //=========== Test callsite and closure gathering logic ===========// + specify_test "autodiff_closure_specialize_get_pullback_closure_info" + // TRUNNER-LABEL: Specializing closures in function: $s4test5ClassV6methodSfyFTJrSpSr + // TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]], %[[#]], %[[#C42:]]) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Class.TangentVector + // TRUNNER-NEXT: Passed in closures: + // TRUNNER-NEXT: 1. %[[#]] = partial_apply [callee_guaranteed] %[[#C7:]](%[[#C34:]], %[[#]]) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER-NEXT: 2. %[[#C42]] = partial_apply [callee_guaranteed] %[[#C7]](%[[#C34]], %[[#]]) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER-EMPTY: + + //=========== Test specialized function signature and body ===========// + specify_test "autodiff_closure_specialize_specialized_function_signature_and_body" + // TRUNNER-LABEL: Generated specialized function: $s4test5ClassV6methodSfyFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktk1_l5FZSf_L6SfcfU_S2fAES2fTf1nncc_n + // CHECK: sil private @$s4test5ClassV6methodSfyFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktk1_l5FZSf_L6SfcfU_S2fAES2fTf1nncc_n : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, Float, Float, Float, Float) -> Class.TangentVector { + // CHECK: bb0(%0 : $Float, %1 : $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, %2 : $Float, %3 : $Float, %4 : $Float, %5 : $Float): + // CHECK: %[[#D6:]] = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // CHECK: %[[#D7:]] = partial_apply [callee_guaranteed] %[[#D6]](%2, %3) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // CHECK: %[[#D8:]] = function_ref @$sS3fIegydd_TJSpSSUpSrUSUP : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + // TRUNNER: %[[#D9:]] = partial_apply [callee_guaranteed] %[[#D8]](%[[#D7]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + // COMBINE-NOT: = partial_apply + // CHECK: %[[#D10:]] = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER: %[[#D11:]] = partial_apply [callee_guaranteed] %[[#D10]](%4, %5) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // COMBINE-NOT: = partial_apply + // TRUNNER: %[[#D13:]] = apply %[[#D11]](%0) : $@callee_guaranteed (Float) -> (Float, Float) + // COMBINE: %[[#D13:]] = apply %[[#D10]](%0, %4, %5) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER: strong_release %[[#D11]] : $@callee_guaranteed (Float) -> (Float, Float) + // CHECK: %[[#D15:]] = tuple_extract %[[#D13]] : $(Float, Float), 0 + // TRUNNER: %[[#]] = apply %[[#D9]](%[[#D15]]) : $@callee_guaranteed (Float) -> Float + // COMBINE: %[[#]] = apply %[[#D8]](%[[#D15]], %[[#D7]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + // TRUNNER: strong_release %[[#D9]] : $@callee_guaranteed (Float) -> Float + + //=========== Test rewritten body ===========// + specify_test "autodiff_closure_specialize_rewritten_caller_body" + // TRUNNER-LABEL: Rewritten caller body for: $s4test5ClassV6methodSfyFTJrSpSr: + // CHECK: sil hidden @$s4test5ClassV6methodSfyFTJrSpSr : $@convention(method) (Class) -> (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) { + // CHECK: bb0(%0 : $Class): + // CHECK: %[[#E2:]] = struct $Float + // CHECK: %[[#E7:]] = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // CHECK: %[[#E9:]] = function_ref @$sS3fIegydd_TJSpSSUpSrUSUP : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + // CHECK: bb1(%[[#]] : $Float): + // CHECK: bb2: + // CHECK: bb3(%[[#E33:]] : $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, %[[#E34:]] : $Float): + // CHECK: %[[#E37:]] = struct $Float + + // TRUNNER: %[[#E38:]] = partial_apply [callee_guaranteed] %[[#E7]](%[[#E34]], %[[#E2]]) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER: %[[#]] = partial_apply [callee_guaranteed] %[[#E9]](%[[#E38]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + // TRUNNER: %[[#E42:]] = partial_apply [callee_guaranteed] %[[#E7]](%[[#E34]], %[[#E37]]) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER: %[[#]] = function_ref @$s4test5ClassV6methodSfyFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Class.TangentVector + + // COMBINE-NOT: = partial_apply + // COMBINE-NOT: = function_ref @$s4test5ClassV6methodSfyFTJpSpSr + + // CHECK: %[[#E44:]] = function_ref @$s4test5ClassV6methodSfyFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktk1_l5FZSf_L6SfcfU_S2fAES2fTf1nncc_n : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, Float, Float, Float, Float) -> Class.TangentVector + // CHECK: %[[#E45:]] = partial_apply [callee_guaranteed] %[[#E44]](%[[#E33]], %[[#E34]], %[[#E2]], %[[#E34]], %[[#E37]]) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, Float, Float, Float, Float) -> Class.TangentVector + // TRUNNER: release_value %[[#E38]] : $@callee_guaranteed (Float) -> (Float, Float) + // TRUNNER: release_value %[[#E42]] : $@callee_guaranteed (Float) -> (Float, Float) + // CHECK: %[[#E48:]] = tuple (%[[#]] : $Float, %[[#E45]] : $@callee_guaranteed (Float) -> Class.TangentVector) + // CHECK: return %[[#E48]] + + %2 = float_literal $Builtin.FPIEEE32, 0x3F800000 // 1 + %3 = struct $Float (%2) + %4 = struct_extract %0, #Class.stored + %5 = struct_extract %4, #Float._value + %6 = builtin "fmul_FPIEEE32"(%2, %5) : $Builtin.FPIEEE32 + %7 = struct $Float (%6) + // function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:) + %8 = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) + %9 = partial_apply [callee_guaranteed] %8(%4, %3) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // function_ref autodiff subset parameters thunk for pullback from @escaping @callee_guaranteed (@unowned Float) -> (@unowned Float, @unowned Float) + %10 = function_ref @$sS3fIegydd_TJSpSSUpSrUSUP : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + %11 = partial_apply [callee_guaranteed] %10(%9) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + %12 = struct_extract %0, #Class.optional + // function_ref pullback of Class.init(stored:optional:) + %25 = function_ref @$s4test5ClassV6stored8optionalACSf_SfSgtcfCTJpSSUpSr : $@convention(thin) (Class.TangentVector) -> (Float, Optional.TangentVector) + %26 = thin_to_thick_function %25 to $@callee_guaranteed (Class.TangentVector) -> (Float, Optional.TangentVector) + %27 = tuple (%11, %26) + switch_enum %12, case #Optional.some!enumelt: bb1, case #Optional.none!enumelt: bb2 + +bb1(%29 : $Float): + %30 = enum $_AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, #_AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0.bb0!enumelt, %27 + %32 = struct_extract %29, #Float._value + %33 = builtin "fmul_FPIEEE32"(%32, %6) : $Builtin.FPIEEE32 + %34 = struct $Float (%33) + %35 = partial_apply [callee_guaranteed] %8(%7, %29) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + %36 = tuple $(predecessor: _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> (Float, Float)) (%30, %35) + %37 = enum $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, #_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0.bb1!enumelt, %36 + br bb3(%37, %34) + +bb2: + %39 = enum $_AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, #_AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0.bb0!enumelt, %27 + %40 = builtin "fmul_FPIEEE32"(%2, %6) : $Builtin.FPIEEE32 + %41 = struct $Float (%40) + %42 = partial_apply [callee_guaranteed] %8(%7, %3) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + %43 = partial_apply [callee_guaranteed] %10(%42) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + %44 = tuple $(predecessor: _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float) (%39, %43) + %45 = enum $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, #_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0.bb2!enumelt, %44 + br bb3(%45, %41) + +bb3(%47 : $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, %48 : $Float): + %49 = struct_extract %48, #Float._value + %50 = builtin "fmul_FPIEEE32"(%2, %49) : $Builtin.FPIEEE32 + %51 = struct $Float (%50) + %52 = partial_apply [callee_guaranteed] %8(%48, %3) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + %53 = partial_apply [callee_guaranteed] %10(%52) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + %54 = builtin "fmul_FPIEEE32"(%50, %49) : $Builtin.FPIEEE32 + %55 = struct $Float (%54) + %56 = partial_apply [callee_guaranteed] %8(%48, %51) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // function_ref pullback of Class.method() + %57 = function_ref @$s4test5ClassV6methodSfyFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Class.TangentVector + %58 = partial_apply [callee_guaranteed] %57(%47, %53, %56) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Class.TangentVector + %59 = tuple (%55, %58) + return %59 +} // end sil function '$s4test5ClassV6methodSfyFTJrSpSr' + + +/////////////////// +/// Test case 3 /// +/////////////////// + +/// This SIL corresponds to the following Swift: +/// +/// @differentiable(reverse) +/// func cond_tuple_var(_ x: Float) -> Float { +/// // Convoluted function returning `x + x`. +/// var y: (Float, Float) = (x, x) +/// var z: (Float, Float) = (x + x, x - x) +/// if x > 0 { +/// let w = (x, x) +/// y.0 = w.1 +/// y.1 = w.0 +/// z.0 = z.0 - y.0 +/// z.1 = z.1 + y.0 +/// } else { +/// z = (1 * x, x) +/// } +/// return y.0 + y.1 - z.0 + z.1 +/// } + +enum _AD__$s4test14cond_tuple_varyS2fF_bb0__Pred__src_0_wrt_0 { +} + +enum _AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0 { + case bb0(((Float) -> (Float, Float), (Float) -> (Float, Float))) +} + +enum _AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0 { + case bb0(((Float) -> (Float, Float), (Float) -> (Float, Float))) +} + +enum _AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0 { + case bb2((predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0, (Float) -> Float)) + case bb1((predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0, (Float) -> (Float, Float), (Float) -> (Float, Float))) +} + +sil @$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float) -> (Float, Float) +sil @$sSf16_DifferentiationE12_vjpSubtract3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float) -> (Float, Float) + +// pullback of cond_tuple_var(_:) +sil private @$s4test14cond_tuple_varyS2fFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> (Float, Float), @owned @callee_guaranteed (Float) -> (Float, Float), @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float { +bb0(%0 : $Float, %1 : $_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0, %2 : $@callee_guaranteed (Float) -> (Float, Float), %3 : $@callee_guaranteed (Float) -> (Float, Float), %4 : $@callee_guaranteed (Float) -> (Float, Float)): + %5 = float_literal $Builtin.FPIEEE32, 0x0 // 0 + %10 = apply %4(%0) : $@callee_guaranteed (Float) -> (Float, Float) + strong_release %4 + %12 = tuple_extract %10, 0 + %13 = tuple_extract %10, 1 + %14 = struct_extract %13, #Float._value + %15 = builtin "fadd_FPIEEE32"(%5, %14) : $Builtin.FPIEEE32 + %17 = apply %3(%12) : $@callee_guaranteed (Float) -> (Float, Float) + strong_release %3 + %19 = tuple_extract %17, 0 + %20 = tuple_extract %17, 1 + %21 = struct_extract %20, #Float._value + %22 = builtin "fadd_FPIEEE32"(%5, %21) : $Builtin.FPIEEE32 + %24 = apply %2(%19) : $@callee_guaranteed (Float) -> (Float, Float) + strong_release %2 + %26 = tuple_extract %24, 0 + %27 = tuple_extract %24, 1 + %28 = struct_extract %27, #Float._value + %29 = builtin "fadd_FPIEEE32"(%5, %28) : $Builtin.FPIEEE32 + %31 = struct_extract %26, #Float._value + %32 = builtin "fadd_FPIEEE32"(%5, %31) : $Builtin.FPIEEE32 + switch_enum %1, case #_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0.bb2!enumelt: bb1, case #_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0.bb1!enumelt: bb2 + +bb1(%44 : $(predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float)): + %45 = tuple_extract %44, 0 + %46 = tuple_extract %44, 1 + %47 = builtin "fadd_FPIEEE32"(%15, %5) : $Builtin.FPIEEE32 + %49 = builtin "fadd_FPIEEE32"(%22, %5) : $Builtin.FPIEEE32 + %50 = struct $Float (%49) + %52 = apply %46(%50) : $@callee_guaranteed (Float) -> Float + strong_release %46 + %54 = struct_extract %52, #Float._value + %55 = builtin "fadd_FPIEEE32"(%54, %47) : $Builtin.FPIEEE32 + %61 = unchecked_enum_data %45, #_AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0.bb0!enumelt + %62 = tuple_extract %61, 1 + %63 = tuple_extract %61, 0 + br bb3(%55, %32, %29, %5, %5, %63, %62) + +bb2(%65 : $(predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> (Float, Float), @callee_guaranteed (Float) -> (Float, Float))): + %66 = tuple_extract %65, 0 + %67 = tuple_extract %65, 1 + %68 = tuple_extract %65, 2 + %69 = builtin "fadd_FPIEEE32"(%15, %5) : $Builtin.FPIEEE32 + %70 = struct $Float (%69) + %72 = apply %68(%70) : $@callee_guaranteed (Float) -> (Float, Float) + strong_release %68 + %74 = tuple_extract %72, 0 + %75 = tuple_extract %72, 1 + %76 = struct_extract %74, #Float._value + %77 = builtin "fadd_FPIEEE32"(%76, %5) : $Builtin.FPIEEE32 + %78 = struct_extract %75, #Float._value + %79 = builtin "fadd_FPIEEE32"(%78, %5) : $Builtin.FPIEEE32 + %80 = builtin "fadd_FPIEEE32"(%32, %79) : $Builtin.FPIEEE32 + %82 = builtin "fadd_FPIEEE32"(%5, %77) : $Builtin.FPIEEE32 + %84 = builtin "fadd_FPIEEE32"(%22, %5) : $Builtin.FPIEEE32 + %85 = struct $Float (%84) + %87 = apply %67(%85) : $@callee_guaranteed (Float) -> (Float, Float) + strong_release %67 + %89 = tuple_extract %87, 0 + %90 = tuple_extract %87, 1 + %91 = struct_extract %89, #Float._value + %92 = builtin "fadd_FPIEEE32"(%91, %5) : $Builtin.FPIEEE32 + %93 = struct_extract %90, #Float._value + %94 = builtin "fadd_FPIEEE32"(%93, %5) : $Builtin.FPIEEE32 + %95 = builtin "fadd_FPIEEE32"(%80, %94) : $Builtin.FPIEEE32 + %97 = builtin "fadd_FPIEEE32"(%5, %92) : $Builtin.FPIEEE32 + %99 = builtin "fadd_FPIEEE32"(%29, %5) : $Builtin.FPIEEE32 + %101 = builtin "fadd_FPIEEE32"(%99, %5) : $Builtin.FPIEEE32 + %102 = builtin "fadd_FPIEEE32"(%95, %5) : $Builtin.FPIEEE32 + %104 = builtin "fadd_FPIEEE32"(%102, %5) : $Builtin.FPIEEE32 + %105 = builtin "fadd_FPIEEE32"(%101, %5) : $Builtin.FPIEEE32 + %106 = builtin "fadd_FPIEEE32"(%104, %5) : $Builtin.FPIEEE32 + %107 = builtin "fadd_FPIEEE32"(%105, %5) : $Builtin.FPIEEE32 + %108 = builtin "fadd_FPIEEE32"(%106, %107) : $Builtin.FPIEEE32 + %114 = unchecked_enum_data %66, #_AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0.bb0!enumelt + %115 = tuple_extract %114, 1 + %116 = tuple_extract %114, 0 + br bb3(%108, %5, %5, %97, %82, %116, %115) + +bb3(%118 : $Builtin.FPIEEE32, %119 : $Builtin.FPIEEE32, %120 : $Builtin.FPIEEE32, %121 : $Builtin.FPIEEE32, %122 : $Builtin.FPIEEE32, %123 : $@callee_guaranteed (Float) -> (Float, Float), %124 : $@callee_guaranteed (Float) -> (Float, Float)): + %125 = builtin "fadd_FPIEEE32"(%122, %5) : $Builtin.FPIEEE32 + %126 = struct $Float (%125) + %127 = apply %124(%126) : $@callee_guaranteed (Float) -> (Float, Float) + strong_release %124 + %129 = tuple_extract %127, 0 + %130 = tuple_extract %127, 1 + %131 = struct_extract %129, #Float._value + %132 = builtin "fadd_FPIEEE32"(%131, %118) : $Builtin.FPIEEE32 + %133 = struct_extract %130, #Float._value + %134 = builtin "fadd_FPIEEE32"(%133, %132) : $Builtin.FPIEEE32 + %135 = builtin "fadd_FPIEEE32"(%121, %5) : $Builtin.FPIEEE32 + %136 = struct $Float (%135) + %137 = apply %123(%136) : $@callee_guaranteed (Float) -> (Float, Float) + strong_release %123 + %139 = tuple_extract %137, 0 + %140 = tuple_extract %137, 1 + %141 = struct_extract %139, #Float._value + %142 = builtin "fadd_FPIEEE32"(%141, %134) : $Builtin.FPIEEE32 + %143 = struct_extract %140, #Float._value + %144 = builtin "fadd_FPIEEE32"(%143, %142) : $Builtin.FPIEEE32 + %145 = builtin "fadd_FPIEEE32"(%120, %144) : $Builtin.FPIEEE32 + %146 = builtin "fadd_FPIEEE32"(%119, %145) : $Builtin.FPIEEE32 + %147 = struct $Float (%146) + return %147 +} // end sil function '$s4test14cond_tuple_varyS2fFTJpSpSr' + +// reverse-mode derivative of cond_tuple_var(_:) +sil hidden @$s4test14cond_tuple_varyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { +[global: ] +bb0(%0 : $Float): + //=========== Test callsite and closure gathering logic ===========// + specify_test "autodiff_closure_specialize_get_pullback_closure_info" + // TRUNNER-LABEL: Specializing closures in function: $s4test14cond_tuple_varyS2fFTJrSpSr + // TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]], %[[#F1:]], %[[#F2:]], %[[#F3:]]) : $@convention(thin) (Float, @owned _AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> (Float, Float), @owned @callee_guaranteed (Float) -> (Float, Float), @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float + // TRUNNER-NEXT: Passed in closures: + // TRUNNER-NEXT: 1. %[[#F1]] = thin_to_thick_function %[[#]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) + // TRUNNER-NEXT: 2. %[[#F2]] = thin_to_thick_function %[[#]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) + // TRUNNER-NEXT: 3. %[[#F3]] = thin_to_thick_function %[[#]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) + // TRUNNER-EMPTY: + + //=========== Test specialized function signature and body ===========// + specify_test "autodiff_closure_specialize_specialized_function_signature_and_body" + // TRUNNER-LABEL: Generated specialized function: $s4test14cond_tuple_varyS2fFTJpSpSr067$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktl1_m5FZSf_M6SfcfU_0ef1_g4E12_i16Subtract3lhs3rhsk1_l1_mnl1_mo1_mP2U_ACTf1nnccc_n + // CHECK: sil private @$s4test14cond_tuple_varyS2fFTJpSpSr067$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktl1_m5FZSf_M6SfcfU_0ef1_g4E12_i16Subtract3lhs3rhsk1_l1_mnl1_mo1_mP2U_ACTf1nnccc_n : $@convention(thin) (Float, @owned _AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0) -> Float { + // CHECK: bb0(%0 : $Float, %1 : $_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0): + // CHECK: %[[#F2:]] = function_ref @$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float) -> (Float, Float) + // TRUNNER: %[[#F3:]] = thin_to_thick_function %[[#F2]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) + // COMBINE-NOT: = thin_to_thick_function + // CHECK: %[[#F4:]] = function_ref @$sSf16_DifferentiationE12_vjpSubtract3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float) -> (Float, Float) + // TRUNNER: %[[#F5:]] = thin_to_thick_function %[[#F4]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) + // COMBINE-NOT: = thin_to_thick_function + // CHECK: %[[#F6:]] = function_ref @$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float) -> (Float, Float) + // TRUNNER: %[[#F7:]] = thin_to_thick_function %[[#F6]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) + // COMBINE-NOT: = thin_to_thick_function + // CHECK: %[[#F8:]] = float_literal $Builtin.FPIEEE32, 0x0 // 0 + // TRUNNER: %[[#F9:]] = apply %[[#F7]](%0) : $@callee_guaranteed (Float) -> (Float, Float) + // COMBINE: %[[#F9:]] = apply %[[#F6]](%0) : $@callee_guaranteed (Float) -> (Float, Float) + // TRUNNER: strong_release %[[#F7]] : $@callee_guaranteed (Float) -> (Float, Float) + // CHECK: %[[#F11:]] = tuple_extract %[[#F9]] : $(Float, Float), 0 + // TRUNNER: %[[#F15:]] = apply %[[#F5]](%[[#F11]]) : $@callee_guaranteed (Float) -> (Float, Float) + // COMBINE: %[[#F15:]] = apply %[[#F4]](%[[#F11]]) : $@callee_guaranteed (Float) -> (Float, Float) + // TRUNNER: strong_release %[[#F5]] : $@callee_guaranteed (Float) -> (Float, Float) + // CHECK: %[[#F17:]] = tuple_extract %[[#F15]] : $(Float, Float), 0 + // TRUNNER: %[[#]] = apply %[[#F3]](%[[#F17]]) : $@callee_guaranteed (Float) -> (Float, Float) + // COMBINE: %[[#]] = apply %[[#F2]](%[[#F17]]) : $@callee_guaranteed (Float) -> (Float, Float) + // TRUNNER: strong_release %[[#F3]] : $@callee_guaranteed (Float) -> (Float, Float) + // CHECK: {{^}}bb1{{.*}}: + // CHECK: {{^}}bb2{{.*}}: + // CHECK: {{^}}bb3{{.*}}: + + //=========== Test rewritten body ===========// + specify_test "autodiff_closure_specialize_rewritten_caller_body" + // TRUNNER-LABEL: Rewritten caller body for: $s4test14cond_tuple_varyS2fFTJrSpSr: + // CHECK: sil hidden @$s4test14cond_tuple_varyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { + // CHECK: bb0(%0 : $Float): + // CHECK: %[[#G7:]] = function_ref @$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float) -> (Float, Float) + // CHECK: %[[#G11:]] = function_ref @$sSf16_DifferentiationE12_vjpSubtract3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float) -> (Float, Float) + // CHECK: bb1: + // CHECK: bb2: + // CHECK: bb3(%[[#G31:]] : $_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0, %[[#]] : $Builtin.FPIEEE32, %[[#]] : $Builtin.FPIEEE32): + + // COMBINE-NOT: = thin_to_thick_function + // COMBINE-NOT: = function_ref @$s4test14cond_tuple_varyS2fFTJpSpSr + + // TRUNNER: %[[#G33:]] = thin_to_thick_function %[[#G7]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) + // TRUNNER: %[[#G34:]] = thin_to_thick_function %[[#G11]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) + // TRUNNER: %[[#G35:]] = thin_to_thick_function %[[#G7]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) + // TRUNNER: %[[#]] = function_ref @$s4test14cond_tuple_varyS2fFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> (Float, Float), @owned @callee_guaranteed (Float) -> (Float, Float), @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float + + // CHECK: %[[#G41:]] = function_ref @$s4test14cond_tuple_varyS2fFTJpSpSr067$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktl1_m5FZSf_M6SfcfU_0ef1_g4E12_i16Subtract3lhs3rhsk1_l1_mnl1_mo1_mP2U_ACTf1nnccc_n : $@convention(thin) (Float, @owned _AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0) -> Float + // CHECK: %[[#G42:]] = partial_apply [callee_guaranteed] %[[#G41]](%[[#G31]]) : $@convention(thin) (Float, @owned _AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0) -> Float + // TRUNNER: release_value %[[#G33]] : $@callee_guaranteed (Float) -> (Float, Float) + // TRUNNER: release_value %[[#G34]] : $@callee_guaranteed (Float) -> (Float, Float) + // TRUNNER: release_value %[[#G35]] : $@callee_guaranteed (Float) -> (Float, Float) + // CHECK: %[[#G46:]] = tuple (%[[#]] : $Float, %[[#G42]] : $@callee_guaranteed (Float) -> Float) + // CHECK: return %[[#G46]] + + %4 = struct_extract %0, #Float._value + %5 = builtin "fadd_FPIEEE32"(%4, %4) : $Builtin.FPIEEE32 + // function_ref closure #1 in static Float._vjpAdd(lhs:rhs:) + %7 = function_ref @$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float) -> (Float, Float) + %8 = thin_to_thick_function %7 to $@callee_guaranteed (Float) -> (Float, Float) + %9 = builtin "fsub_FPIEEE32"(%4, %4) : $Builtin.FPIEEE32 + // function_ref closure #1 in static Float._vjpSubtract(lhs:rhs:) + %11 = function_ref @$sSf16_DifferentiationE12_vjpSubtract3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float) -> (Float, Float) + %12 = thin_to_thick_function %11 to $@callee_guaranteed (Float) -> (Float, Float) + %13 = float_literal $Builtin.FPIEEE32, 0x0 // 0 + %14 = builtin "fcmp_olt_FPIEEE32"(%13, %4) : $Builtin.Int1 + %15 = tuple (%8, %12) + cond_br %14, bb1, bb2 + +bb1: + %17 = enum $_AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0, #_AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0.bb0!enumelt, %15 + %22 = builtin "fsub_FPIEEE32"(%5, %4) : $Builtin.FPIEEE32 + %23 = thin_to_thick_function %11 to $@callee_guaranteed (Float) -> (Float, Float) + %25 = builtin "fadd_FPIEEE32"(%9, %4) : $Builtin.FPIEEE32 + %27 = thin_to_thick_function %7 to $@callee_guaranteed (Float) -> (Float, Float) + %28 = tuple $(predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> (Float, Float), @callee_guaranteed (Float) -> (Float, Float)) (%17, %23, %27) + %29 = enum $_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0, #_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0.bb1!enumelt, %28 + br bb3(%29, %25, %22) + +bb2: + %31 = enum $_AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0, #_AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0.bb0!enumelt, %15 + %32 = float_literal $Builtin.FPIEEE32, 0x3F800000 // 1 + %33 = struct $Float (%32) + %34 = builtin "fmul_FPIEEE32"(%32, %4) : $Builtin.FPIEEE32 + // function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:) + %36 = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) + %37 = partial_apply [callee_guaranteed] %36(%0, %33) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // function_ref autodiff subset parameters thunk for pullback from @escaping @callee_guaranteed (@unowned Float) -> (@unowned Float, @unowned Float) + %38 = function_ref @$sS3fIegydd_TJSpSSUpSrUSUP : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + %39 = partial_apply [callee_guaranteed] %38(%37) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + %41 = tuple $(predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float) (%31, %39) + %42 = enum $_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0, #_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0.bb2!enumelt, %41 + br bb3(%42, %4, %34) + +bb3(%44 : $_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0, %45 : $Builtin.FPIEEE32, %46 : $Builtin.FPIEEE32): + %47 = thin_to_thick_function %7 to $@callee_guaranteed (Float) -> (Float, Float) + %48 = builtin "fsub_FPIEEE32"(%5, %46) : $Builtin.FPIEEE32 + %49 = thin_to_thick_function %11 to $@callee_guaranteed (Float) -> (Float, Float) + %50 = builtin "fadd_FPIEEE32"(%48, %45) : $Builtin.FPIEEE32 + %51 = struct $Float (%50) + %52 = thin_to_thick_function %7 to $@callee_guaranteed (Float) -> (Float, Float) + // function_ref pullback of cond_tuple_var(_:) + %53 = function_ref @$s4test14cond_tuple_varyS2fFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> (Float, Float), @owned @callee_guaranteed (Float) -> (Float, Float), @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float + %54 = partial_apply [callee_guaranteed] %53(%44, %47, %49, %52) : $@convention(thin) (Float, @owned _AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> (Float, Float), @owned @callee_guaranteed (Float) -> (Float, Float), @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float + %55 = tuple (%51, %54) + return %55 +} // end sil function '$s4test14cond_tuple_varyS2fFTJrSpSr' diff --git a/test/AutoDiff/SILOptimizer/closure_specialization/multi_bb_no_bte1.sil b/test/AutoDiff/SILOptimizer/closure_specialization/multi_bb_no_bte1.sil new file mode 100644 index 0000000000000..0901c645843fd --- /dev/null +++ b/test/AutoDiff/SILOptimizer/closure_specialization/multi_bb_no_bte1.sil @@ -0,0 +1,181 @@ +/// Multi basic block VJP, pullback not accepting branch tracing enum argument. + +// RUN: %target-sil-opt -sil-print-types -test-runner %s -o /dev/null 2>&1 | %FileCheck %s --check-prefixes=TRUNNER,CHECK +// RUN: %target-sil-opt -sil-print-types -autodiff-closure-specialization -sil-combine %s -o - | %FileCheck %s --check-prefixes=COMBINE,CHECK + +// REQUIRES: swift_in_compiler + +sil_stage canonical + +import Builtin +import Swift +import SwiftShims + +import _Differentiation + +/// This SIL corresponds to the following Swift: +/// +/// struct Class: Differentiable { +/// var stored: Float +/// var optional: Float? +/// +/// init(stored: Float, optional: Float?) { +/// self.stored = stored +/// self.optional = optional +/// } +/// +/// @differentiable(reverse) +/// func method() -> Float { +/// let c: Class +/// do { +/// let tmp = Class(stored: 1 * stored, optional: optional) +/// let tuple = (tmp, tmp) +/// c = tuple.0 +/// } +/// if let x = c.optional { +/// return x * c.stored +/// } +/// return 1 * c.stored +/// } +/// } +/// +/// @differentiable(reverse) +/// func methodWrapper(_ x: Class) -> Float { +/// x.method() +/// } + +struct Class : Differentiable { + @_hasStorage var stored: Float { get set } + @_hasStorage @_hasInitialValue var optional: Float? { get set } + init(stored: Float, optional: Float?) + @differentiable(reverse, wrt: self) + func method() -> Float + struct TangentVector : AdditiveArithmetic, Differentiable { + @_hasStorage var stored: Float { get set } + @_hasStorage var optional: Optional.TangentVector { get set } + static func + (lhs: Class.TangentVector, rhs: Class.TangentVector) -> Class.TangentVector + static func - (lhs: Class.TangentVector, rhs: Class.TangentVector) -> Class.TangentVector + typealias TangentVector = Class.TangentVector + @_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: Class.TangentVector, _ b: Class.TangentVector) -> Bool + init(stored: Float, optional: Optional.TangentVector) + static var zero: Class.TangentVector { get } + } + mutating func move(by offset: Class.TangentVector) +} + +enum _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0 { + case bb0(((Float) -> Float, (Class.TangentVector) -> (Float, Optional.TangentVector))) +} + +enum _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0 { + case bb0(((Float) -> Float, (Class.TangentVector) -> (Float, Optional.TangentVector))) +} + +enum _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0 { + case bb2((predecessor: _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, (Float) -> Float)) + case bb1((predecessor: _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, (Float) -> (Float, Float))) +} + +sil @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) +sil [transparent] [thunk] @$sS3fIegydd_TJSpSSUpSrUSUP : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float +sil @$s4test5ClassV6stored8optionalACSf_SfSgtcfCTJpSSUpSr : $@convention(thin) (Class.TangentVector) -> (Float, Optional.TangentVector) +sil @$s4test5ClassV6methodSfyFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector + +// pullback of methodWrapper(_:) +sil private [signature_optimized_thunk] [always_inline] @$s4test13methodWrapperySfAA5ClassVFTJpSpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) -> Class.TangentVector { +bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> Class.TangentVector): + %2 = apply %1(%0) : $@callee_guaranteed (Float) -> Class.TangentVector + strong_release %1 + return %2 +} // end sil function '$s4test13methodWrapperySfAA5ClassVFTJpSpSr' + +// reverse-mode derivative of methodWrapper(_:) +sil hidden @$s4test13methodWrapperySfAA5ClassVFTJrSpSr : $@convention(thin) (Class) -> (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) { +bb0(%0 : $Class): + //=========== Test callsite and closure gathering logic ===========// + specify_test "autodiff_closure_specialize_get_pullback_closure_info" + // TRUNNER-LABEL: Specializing closures in function: $s4test13methodWrapperySfAA5ClassVFTJrSpSr + // TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#A36:]]) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) -> Class.TangentVector + // TRUNNER-NEXT: Passed in closures: + // TRUNNER-NEXT: 1. %[[#A36]] = partial_apply [callee_guaranteed] %[[#]](%[[#]]) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector + // TRUNNER-EMPTY: + + //=========== Test specialized function signature and body ===========// + specify_test "autodiff_closure_specialize_specialized_function_signature_and_body" + // TRUNNER-LABEL: Generated specialized function: $s4test13methodWrapperySfAA5ClassVFTJpSpSr08$s4test5D19V6methodSfyFTJpSpSr4main05_AD__edfG24F_bb3__Pred__src_0_wrt_0OTf1nc_n + // CHECK: sil private [signature_optimized_thunk] [always_inline] @$s4test13methodWrapperySfAA5ClassVFTJpSpSr08$s4test5D19V6methodSfyFTJpSpSr4main05_AD__edfG24F_bb3__Pred__src_0_wrt_0OTf1nc_n : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector { + // CHECK: bb0(%0 : $Float, %1 : $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0): + // CHECK: %[[#B2:]] = function_ref @$s4test5ClassV6methodSfyFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector + // TRUNNER: %[[#B3:]] = partial_apply [callee_guaranteed] %[[#B2]](%1) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector + // TRUNNER: %[[#B4:]] = apply %[[#B3]](%0) : $@callee_guaranteed (Float) -> Class.TangentVector + // COMBINE-NOT: partial_apply + // COMBINE: %[[#B4:]] = apply %[[#B2]](%0, %1) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector + // TRUNNER: strong_release %[[#B3]] : $@callee_guaranteed (Float) -> Class.TangentVector + // CHECK: return %[[#B4]] + + //=========== Test rewritten body ===========// + specify_test "autodiff_closure_specialize_rewritten_caller_body" + // TRUNNER-LABEL: Rewritten caller body for: $s4test13methodWrapperySfAA5ClassVFTJrSpSr: + // CHECK: sil hidden @$s4test13methodWrapperySfAA5ClassVFTJrSpSr : $@convention(thin) (Class) -> (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) { + // CHECK: bb3(%[[#C33:]] : $Float, %[[#C34:]] : $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0): + // TRUNNER: %[[#C35:]] = function_ref @$s4test5ClassV6methodSfyFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector + // TRUNNER: %[[#C37:]] = partial_apply [callee_guaranteed] %[[#C35]](%[[#C34]]) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector + // TRUNNER: %[[#C38:]] = function_ref @$s4test13methodWrapperySfAA5ClassVFTJpSpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) -> Class.TangentVector + // COMBINE-NOT: function_ref @$s4test5ClassV6methodSfyFTJpSpSr + // COMBINE-NOT: partial_apply + // COMBINE-NOT: function_ref @$s4test13methodWrapperySfAA5ClassVFTJpSpSr + // CHECK: %[[#C39:]] = function_ref @$s4test13methodWrapperySfAA5ClassVFTJpSpSr08$s4test5D19V6methodSfyFTJpSpSr4main05_AD__edfG24F_bb3__Pred__src_0_wrt_0OTf1nc_n : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector + // CHECK: %[[#C40:]] = partial_apply [callee_guaranteed] %[[#C39]](%[[#C34]]) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector + // TRUNNER: release_value %[[#C37]] : $@callee_guaranteed (Float) -> Class.TangentVector + // CHECK: %[[#C42:]] = tuple (%[[#C33]] : $Float, %[[#C40]] : $@callee_guaranteed (Float) -> Class.TangentVector) + // CHECK: return %[[#C42]] + + %3 = float_literal $Builtin.FPIEEE32, 0x3F800000 // 1 + %4 = struct $Float (%3) + %5 = struct_extract %0, #Class.stored + %6 = struct_extract %5, #Float._value + %7 = builtin "fmul_FPIEEE32"(%3, %6) : $Builtin.FPIEEE32 + %8 = struct $Float (%7) + // function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:) + %9 = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) + %10 = partial_apply [callee_guaranteed] %9(%5, %4) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // function_ref autodiff subset parameters thunk for pullback from @escaping @callee_guaranteed (@unowned Float) -> (@unowned Float, @unowned Float) + %11 = function_ref @$sS3fIegydd_TJSpSSUpSrUSUP : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + %12 = partial_apply [callee_guaranteed] %11(%10) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + %13 = struct_extract %0, #Class.optional + // function_ref pullback of Class.init(stored:optional:) + %26 = function_ref @$s4test5ClassV6stored8optionalACSf_SfSgtcfCTJpSSUpSr : $@convention(thin) (Class.TangentVector) -> (Float, Optional.TangentVector) + %27 = thin_to_thick_function %26 to $@callee_guaranteed (Class.TangentVector) -> (Float, Optional.TangentVector) + %28 = tuple (%12, %27) + switch_enum %13, case #Optional.some!enumelt: bb1, case #Optional.none!enumelt: bb2 + +bb1(%30 : $Float): + %31 = enum $_AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, #_AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0.bb0!enumelt, %28 + %33 = struct_extract %30, #Float._value + %34 = builtin "fmul_FPIEEE32"(%33, %7) : $Builtin.FPIEEE32 + %35 = struct $Float (%34) + %36 = partial_apply [callee_guaranteed] %9(%8, %30) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + %37 = tuple $(predecessor: _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> (Float, Float)) (%31, %36) + %38 = enum $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, #_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0.bb1!enumelt, %37 + br bb3(%35, %38) + +bb2: + %40 = enum $_AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, #_AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0.bb0!enumelt, %28 + %41 = builtin "fmul_FPIEEE32"(%3, %7) : $Builtin.FPIEEE32 + %42 = struct $Float (%41) + %43 = partial_apply [callee_guaranteed] %9(%8, %4) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + %44 = partial_apply [callee_guaranteed] %11(%43) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + %45 = tuple $(predecessor: _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float) (%40, %44) + %46 = enum $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, #_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0.bb2!enumelt, %45 + br bb3(%42, %46) + +bb3(%48 : $Float, %49 : $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0): + // function_ref pullback of Class.method() + %50 = function_ref @$s4test5ClassV6methodSfyFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector + %51 = partial_apply [callee_guaranteed] %50(%49) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector + // function_ref pullback of methodWrapper(_:) + %52 = function_ref @$s4test13methodWrapperySfAA5ClassVFTJpSpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) -> Class.TangentVector + %53 = partial_apply [callee_guaranteed] %52(%51) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) -> Class.TangentVector + %54 = tuple (%48, %53) + return %54 +} // end sil function '$s4test13methodWrapperySfAA5ClassVFTJrSpSr' diff --git a/test/AutoDiff/SILOptimizer/closure_specialization/multi_bb_no_bte2.sil b/test/AutoDiff/SILOptimizer/closure_specialization/multi_bb_no_bte2.sil new file mode 100644 index 0000000000000..03325afb05af1 --- /dev/null +++ b/test/AutoDiff/SILOptimizer/closure_specialization/multi_bb_no_bte2.sil @@ -0,0 +1,203 @@ +/// Multi basic block VJP, pullback not accepting branch tracing enum argument. + +// RUN: %target-sil-opt -sil-print-types -test-runner %s -o /dev/null 2>&1 | %FileCheck %s --check-prefixes=TRUNNER,CHECK +// RUN: %target-sil-opt -sil-print-types -autodiff-closure-specialization -sil-combine %s -o - | %FileCheck %s --check-prefixes=COMBINE,CHECK + +// REQUIRES: swift_in_compiler + +/// _ArrayBuffer is part of the ObjC runtime interop. +// REQUIRES: objc_interop + +sil_stage canonical + +import Builtin +import Swift +import SwiftShims + +import _Differentiation + +sil [_semantics "array.append_contentsOf"] @$sSa6append10contentsOfyqd__n_t7ElementQyd__RszSTRd__lFSf_SaySfGTg5 : $@convention(method) (@owned Array, @inout Array) -> () + +sil @$sSa16_DifferentiationAA14DifferentiableRzlE15_vjpConcatenateySayxG5value_SaA2aBRzlE0B4ViewVy13TangentVectorQz_G_AJtAJc8pullbacktAD_ADtFZAKL_yAJ_AJtAjaBRzlFSf_Tg5 : $@convention(thin) (@guaranteed Array.DifferentiableView, @guaranteed Array, @guaranteed Array) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView) + +sil [transparent] [reabstraction_thunk] @$sSa16_DifferentiationAA14DifferentiableRzlE0B4ViewVy13TangentVectorAaBPQz_GA2HIeggoo_A3HIeggoo_AaBRzlTRSf_Tg5 : $@convention(thin) (@guaranteed Array.DifferentiableView, @guaranteed @callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView)) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView) + +sil [_semantics "array.check_subscript"] @$sSa15_checkSubscript_20wasNativeTypeCheckeds16_DependenceTokenVSi_SbtFSf_Tg5 : $@convention(method) (Int, Bool, @guaranteed Array) -> _DependenceToken + +sil [transparent] [thunk] @$sSa16_DifferentiationAA14DifferentiableRzlE13_vjpSubscript5indexx5value_SaA2aBRzlE0B4ViewVy13TangentVectorQz_GAIc8pullbacktSi_tFAKL_yAjiaBRzlFSf_TG5 : $@convention(thin) (@in_guaranteed Float, @guaranteed Array, Int) -> @owned Array.DifferentiableView + +sil [transparent] [thunk] @$s13TangentVector16_Differentiation14DifferentiablePQzSaA2bCRzlE0D4ViewVyAE_GIegno_AeHIegno_AbCRzlTRSf_TG5 : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @owned Array.DifferentiableView) -> @owned Array.DifferentiableView + +// specialized pullback of sum1(_:_:) +sil private [signature_optimized_thunk] [always_inline] @$s4test4sum1ySfSaySfG_ACtFTJpSSpSr055$sSfSa16_DifferentiationAA14DifferentiableRzlE0B4ViewVyd8_GIegno_D10AEIegyo_TRSfSa01_F0AE0H0RzlE0hL0VySf_GIegno_Tf1nnc_n : $@convention(thin) (Float, @owned @callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView), @owned @callee_guaranteed (@in_guaranteed Float) -> @owned Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView) { +bb0(%0 : $Float, %1 : $@callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView), %2 : $@callee_guaranteed (@in_guaranteed Float) -> @owned Array.DifferentiableView): + %3 = float_literal $Builtin.FPIEEE32, 0x0 // 0 + %4 = struct_extract %0, #Float._value + %5 = builtin "fadd_FPIEEE32"(%3, %4) : $Builtin.FPIEEE32 + %6 = struct $Float (%5) + %7 = alloc_stack $Float + store %6 to %7 + %9 = apply %2(%7) : $@callee_guaranteed (@in_guaranteed Float) -> @owned Array.DifferentiableView + strong_release %2 + dealloc_stack %7 + %13 = apply %1(%9) : $@callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView) + release_value %9 + strong_release %1 + return %13 +} // end sil function '$s4test4sum1ySfSaySfG_ACtFTJpSSpSr055$sSfSa16_DifferentiationAA14DifferentiableRzlE0B4ViewVyd8_GIegno_D10AEIegyo_TRSfSa01_F0AE0H0RzlE0hL0VySf_GIegno_Tf1nnc_n' + +// reverse-mode derivative of sum1(_:_:) +sil hidden @$s4test4sum1ySfSaySfG_ACtFTJrSSpSr : $@convention(thin) (@guaranteed Array, @guaranteed Array) -> (Float, @owned @callee_guaranteed (Float) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView)) { +bb0(%0 : $Array, %1 : $Array): + //=========== Test callsite and closure gathering logic ===========// + specify_test "autodiff_closure_specialize_get_pullback_closure_info" + // TRUNNER-LABEL: Specializing closures in function: $s4test4sum1ySfSaySfG_ACtFTJrSSpSr + // TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]], %[[#]]) : $@convention(thin) (Float, @owned @callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView), @owned @callee_guaranteed (@in_guaranteed Float) -> @owned Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView) + // TRUNNER-NEXT: Passed in closures: + // TRUNNER-NEXT: 1. %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]], %[[#]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed Array, Int) -> @owned Array.DifferentiableView + // TRUNNER-EMPTY: + + //=========== Test specialized function signature and body ===========// + specify_test "autodiff_closure_specialize_specialized_function_signature_and_body" + // TRUNNER-LABEL: Generated specialized function: $s4test4sum1ySfSaySfG_ACtFTJpSSpSr055$sSfSa16_DifferentiationAA14DifferentiableRzlE0B4ViewVyd8_GIegno_D10AEIegyo_TRSfSa01_F0AE0H0RzlE0hL0VySf_GIegno_Tf1nnc_n0ce1_fghi39E13_vjpSubscript5indexx5value_SaA2aBRzljkl48Vy13TangentVectorQz_GAIc8pullbacktSi_tFAKL_yAjiaV7FSf_TG5ACSiTf1nnc_n + // CHECK: sil private [signature_optimized_thunk] [always_inline] @$s4test4sum1ySfSaySfG_ACtFTJpSSpSr055$sSfSa16_DifferentiationAA14DifferentiableRzlE0B4ViewVyd8_GIegno_D10AEIegyo_TRSfSa01_F0AE0H0RzlE0hL0VySf_GIegno_Tf1nnc_n0ce1_fghi39E13_vjpSubscript5indexx5value_SaA2aBRzljkl48Vy13TangentVectorQz_GAIc8pullbacktSi_tFAKL_yAjiaV7FSf_TG5ACSiTf1nnc_n : $@convention(thin) (Float, @owned @callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView), @owned Array, Int) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView) { + // CHECK: bb0(%0 : $Float, %1 : $@callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView), %2 : $Array, %3 : $Int): + // CHECK: %[[#D4:]] = function_ref @$sSa16_DifferentiationAA14DifferentiableRzlE13_vjpSubscript5indexx5value_SaA2aBRzlE0B4ViewVy13TangentVectorQz_GAIc8pullbacktSi_tFAKL_yAjiaBRzlFSf_TG5 : $@convention(thin) (@in_guaranteed Float, @guaranteed Array, Int) -> @owned Array.DifferentiableView + // CHECK: %[[#D5:]] = partial_apply [callee_guaranteed] %[[#D4]](%2, %3) : $@convention(thin) (@in_guaranteed Float, @guaranteed Array, Int) -> @owned Array.DifferentiableView + // CHECK: %[[#D6:]] = function_ref @$s13TangentVector16_Differentiation14DifferentiablePQzSaA2bCRzlE0D4ViewVyAE_GIegno_AeHIegno_AbCRzlTRSf_TG5 : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @owned Array.DifferentiableView) -> @owned Array.DifferentiableView + // TRUNNER: %[[#D7:]] = partial_apply [callee_guaranteed] %[[#D6]](%[[#D5]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @owned Array.DifferentiableView) -> @owned Array.DifferentiableView + // CHECK: store %[[#D11:]] to %[[#D12:]] : $*Float + // TRUNNER: %[[#D14:]] = apply %[[#D7]](%[[#D12]]) : $@callee_guaranteed (@in_guaranteed Float) -> @owned Array.DifferentiableView + // COMBINE: %[[#D14:]] = apply %[[#D6]](%[[#D12]], %[[#D5]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @owned Array.DifferentiableView) -> @owned Array.DifferentiableView + // TRUNNER: strong_release %[[#D7]] : $@callee_guaranteed (@in_guaranteed Float) -> @owned Array.DifferentiableView + // COMBINE: strong_release %[[#D5]] : $@callee_guaranteed (@in_guaranteed Float) -> @owned Array.DifferentiableView + // CHECK: %[[#D17:]] = apply %1(%[[#D14]]) : $@callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView) + // CHECK: return %[[#D17]] : $(Array.DifferentiableView, Array.DifferentiableView) + + //=========== Test rewritten body ===========// + specify_test "autodiff_closure_specialize_rewritten_caller_body" + // TRUNNER-LABEL: Rewritten caller body for: $s4test4sum1ySfSaySfG_ACtFTJrSSpSr: + // CHECK: sil hidden @$s4test4sum1ySfSaySfG_ACtFTJrSSpSr : $@convention(thin) (@guaranteed Array, @guaranteed Array) -> (Float, @owned @callee_guaranteed (Float) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView)) { + // CHECK: bb8: + // CHECK: %[[#E42:]] = tuple_extract %[[#]] : $(Array, @callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView)), 1 + + // TRUNNER: %[[#E44:]] = function_ref @$sSa16_DifferentiationAA14DifferentiableRzlE13_vjpSubscript5indexx5value_SaA2aBRzlE0B4ViewVy13TangentVectorQz_GAIc8pullbacktSi_tFAKL_yAjiaBRzlFSf_TG5 : $@convention(thin) (@in_guaranteed Float, @guaranteed Array, Int) -> @owned Array.DifferentiableView + // TRUNNER: %[[#E45:]] = partial_apply [callee_guaranteed] %[[#E44]](%[[#]], %[[#]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed Array, Int) -> @owned Array.DifferentiableView + // TRUNNER: %[[#E46:]] = function_ref @$s13TangentVector16_Differentiation14DifferentiablePQzSaA2bCRzlE0D4ViewVyAE_GIegno_AeHIegno_AbCRzlTRSf_TG5 : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @owned Array.DifferentiableView) -> @owned Array.DifferentiableView + // TRUNNER: %[[#E47:]] = partial_apply [callee_guaranteed] %[[#E46]](%[[#E45]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @owned Array.DifferentiableView) -> @owned Array.DifferentiableView + // TRUNNER: %[[#]] = function_ref @$s4test4sum1ySfSaySfG_ACtFTJpSSpSr055$sSfSa16_DifferentiationAA14DifferentiableRzlE0B4ViewVyd8_GIegno_D10AEIegyo_TRSfSa01_F0AE0H0RzlE0hL0VySf_GIegno_Tf1nnc_n : $@convention(thin) (Float, @owned @callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView), @owned @callee_guaranteed (@in_guaranteed Float) -> @owned Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView) + + // COMBINE-NOT: function_ref @$sSa16_DifferentiationAA14DifferentiableRzlE13_vjpSubscript5indexx5value_SaA2aBRzlE0B4ViewVy13TangentVectorQz_GAIc8pullbacktSi_tFAKL_yAjiaBRzlFSf_TG5 + // COMBINE-NOT: partial_apply + // COMBINE-NOT: function_ref @$s13TangentVector16_Differentiation14DifferentiablePQzSaA2bCRzlE0D4ViewVyAE_GIegno_AeHIegno_AbCRzlTRSf_TG5 + // COMBINE-NOT: partial_apply + // COMBINE-NOT: function_ref @$s4test4sum1ySfSaySfG_ACtFTJpSSpSr055$sSfSa16_DifferentiationAA14DifferentiableRzlE0B4ViewVyd8_GIegno_D10AEIegyo_TRSfSa01_F0AE0H0RzlE0hL0VySf_GIegno_Tf1nnc_n + + // CHECK: %[[#E52:]] = function_ref @$s4test4sum1ySfSaySfG_ACtFTJpSSpSr055$sSfSa16_DifferentiationAA14DifferentiableRzlE0B4ViewVyd8_GIegno_D10AEIegyo_TRSfSa01_F0AE0H0RzlE0hL0VySf_GIegno_Tf1nnc_n0ce1_fghi39E13_vjpSubscript5indexx5value_SaA2aBRzljkl48Vy13TangentVectorQz_GAIc8pullbacktSi_tFAKL_yAjiaV7FSf_TG5ACSiTf1nnc_n : $@convention(thin) (Float, @owned @callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView), @owned Array, Int) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView) + // CHECK: %[[#E53:]] = partial_apply [callee_guaranteed] %[[#E52]](%[[#E42]], %[[#]], %[[#]]) : $@convention(thin) (Float, @owned @callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView), @owned Array, Int) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView) + // CHECK: %[[#E55:]] = tuple (%[[#]] : $Float, %[[#E53]] : $@callee_guaranteed (Float) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView)) + // CHECK: return %[[#E55]] + + %4 = alloc_stack [var_decl] $Array + store %0 to %4 + // function_ref specialized Array.append(contentsOf:) + %6 = function_ref @$sSa6append10contentsOfyqd__n_t7ElementQyd__RszSTRd__lFSf_SaySfGTg5 : $@convention(method) (@owned Array, @inout Array) -> () + retain_value %0 + retain_value %1 + %9 = apply %6(%1, %4) : $@convention(method) (@owned Array, @inout Array) -> () + %10 = load %4 + dealloc_stack %4 + // function_ref specialized pullback #1 (_:) in static Array._vjpConcatenate(_:_:) + %12 = function_ref @$sSa16_DifferentiationAA14DifferentiableRzlE15_vjpConcatenateySayxG5value_SaA2aBRzlE0B4ViewVy13TangentVectorQz_G_AJtAJc8pullbacktAD_ADtFZAKL_yAJ_AJtAjaBRzlFSf_Tg5 : $@convention(thin) (@guaranteed Array.DifferentiableView, @guaranteed Array, @guaranteed Array) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView) + %13 = partial_apply [callee_guaranteed] %12(%0, %1) : $@convention(thin) (@guaranteed Array.DifferentiableView, @guaranteed Array, @guaranteed Array) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView) + // function_ref specialized thunk for @escaping @callee_guaranteed (@guaranteed [A.Differentiable.TangentVector].DifferentiableView) -> (@owned [A.Differentiable.TangentVector].DifferentiableView, @owned [A.Differentiable.TangentVector].DifferentiableView) + %14 = function_ref @$sSa16_DifferentiationAA14DifferentiableRzlE0B4ViewVy13TangentVectorAaBPQz_GA2HIeggoo_A3HIeggoo_AaBRzlTRSf_Tg5 : $@convention(thin) (@guaranteed Array.DifferentiableView, @guaranteed @callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView)) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView) + %15 = partial_apply [callee_guaranteed] %14(%13) : $@convention(thin) (@guaranteed Array.DifferentiableView, @guaranteed @callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView)) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView) + %16 = convert_function %15 to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@guaranteed τ_0_0) -> (@owned τ_0_1, @owned τ_0_2) for .DifferentiableView, Array.DifferentiableView, Array.DifferentiableView> + %17 = tuple (%10, %16) + %18 = unchecked_bitwise_cast %17 to $(Array, @callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView)) + %19 = tuple_extract %18, 0 + %21 = integer_literal $Builtin.Int64, 0 + %22 = struct $Int (%21) + %23 = integer_literal $Builtin.Int1, -1 + %24 = struct $Bool (%23) + // function_ref specialized Array._checkSubscript(_:wasNativeTypeChecked:) + %25 = function_ref @$sSa15_checkSubscript_20wasNativeTypeCheckeds16_DependenceTokenVSi_SbtFSf_Tg5 : $@convention(method) (Int, Bool, @guaranteed Array) -> _DependenceToken + %26 = apply %25(%22, %24, %19) : $@convention(method) (Int, Bool, @guaranteed Array) -> _DependenceToken + %27 = struct_extract %19, #Array._buffer + %28 = struct_extract %27, #_ArrayBuffer._storage + %29 = string_literal utf8 "Swift/BridgeStorage.swift" + %30 = integer_literal $Builtin.Word, 25 + %31 = builtin "ptrtoint_Word"(%29) : $Builtin.Word + %32 = integer_literal $Builtin.Int8, 2 + %33 = struct $StaticString (%31, %30, %32) + %34 = string_literal utf8 "" + %35 = integer_literal $Builtin.Word, 0 + %36 = builtin "ptrtoint_Word"(%34) : $Builtin.Word + %37 = struct $StaticString (%36, %35, %32) + %38 = struct_extract %28, #_BridgeStorage.rawValue + %39 = classify_bridge_object %38 + %40 = tuple_extract %39, 0 + %41 = tuple_extract %39, 1 + %42 = builtin "or_Int1"(%40, %41) : $Builtin.Int1 + %43 = integer_literal $Builtin.Int1, 0 + %44 = builtin "int_expect_Int1"(%42, %43) : $Builtin.Int1 + cond_br %44, bb1, bb2 + +bb1: + unreachable + +bb2: + %57 = bridge_object_to_word %38 to $Builtin.Word + %58 = builtin "zextOrBitCast_Word_Int64"(%57) : $Builtin.Int64 + %59 = integer_literal $Builtin.Int64, 9151314442816847879 + %60 = builtin "and_Int64"(%58, %59) : $Builtin.Int64 + %61 = builtin "cmp_eq_Int64"(%60, %21) : $Builtin.Int1 + %62 = builtin "int_expect_Int1"(%61, %23) : $Builtin.Int1 + cond_br %62, bb4, bb3 + +bb3: + unreachable + +bb4: + %75 = unchecked_ref_cast %38 to $__ContiguousArrayStorageBase + %76 = integer_literal $Builtin.Word, 24 + %77 = ref_element_addr %75, #__ContiguousArrayStorageBase.countAndCapacity + %78 = struct_element_addr %77, #_ArrayBody._storage + %79 = struct_element_addr %78, #_SwiftArrayBodyStorage.count + %80 = struct_element_addr %79, #Int._value + %81 = load %80 + %82 = builtin "cmp_slt_Int64"(%81, %21) : $Builtin.Int1 + %83 = builtin "int_expect_Int1"(%82, %43) : $Builtin.Int1 + cond_br %83, bb6, bb7 + +bb5: + unreachable + +bb6: + unreachable + +bb7: + %117 = builtin "assumeNonNegative_Int64"(%81) : $Builtin.Int64 + %118 = builtin "cmp_slt_Int64"(%21, %117) : $Builtin.Int1 + %119 = builtin "int_expect_Int1"(%118, %23) : $Builtin.Int1 + cond_br %119, bb8, bb5 + +bb8: + %121 = tuple_extract %18, 1 + %122 = ref_tail_addr [immutable] %75, $Float + %123 = load %122 + // function_ref specialized pullback #1 (_:) in Array._vjpSubscript(index:) + %124 = function_ref @$sSa16_DifferentiationAA14DifferentiableRzlE13_vjpSubscript5indexx5value_SaA2aBRzlE0B4ViewVy13TangentVectorQz_GAIc8pullbacktSi_tFAKL_yAjiaBRzlFSf_TG5 : $@convention(thin) (@in_guaranteed Float, @guaranteed Array, Int) -> @owned Array.DifferentiableView + %125 = partial_apply [callee_guaranteed] %124(%19, %22) : $@convention(thin) (@in_guaranteed Float, @guaranteed Array, Int) -> @owned Array.DifferentiableView + // function_ref specialized thunk for @escaping @callee_guaranteed (@in_guaranteed A.Differentiable.TangentVector) -> (@owned [A.Differentiable.TangentVector].DifferentiableView) + %126 = function_ref @$s13TangentVector16_Differentiation14DifferentiablePQzSaA2bCRzlE0D4ViewVyAE_GIegno_AeHIegno_AbCRzlTRSf_TG5 : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @owned Array.DifferentiableView) -> @owned Array.DifferentiableView + %127 = partial_apply [callee_guaranteed] %126(%125) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @owned Array.DifferentiableView) -> @owned Array.DifferentiableView + // function_ref specialized pullback of sum1(_:_:) + %128 = function_ref @$s4test4sum1ySfSaySfG_ACtFTJpSSpSr055$sSfSa16_DifferentiationAA14DifferentiableRzlE0B4ViewVyd8_GIegno_D10AEIegyo_TRSfSa01_F0AE0H0RzlE0hL0VySf_GIegno_Tf1nnc_n : $@convention(thin) (Float, @owned @callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView), @owned @callee_guaranteed (@in_guaranteed Float) -> @owned Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView) + %129 = partial_apply [callee_guaranteed] %128(%121, %127) : $@convention(thin) (Float, @owned @callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView), @owned @callee_guaranteed (@in_guaranteed Float) -> @owned Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView) + %130 = tuple (%123, %129) + retain_value %0 + retain_value %1 + return %130 +} // end sil function '$s4test4sum1ySfSaySfG_ACtFTJrSSpSr' diff --git a/test/AutoDiff/SILOptimizer/closure_specialization/single_bb.sil b/test/AutoDiff/SILOptimizer/closure_specialization/single_bb.sil new file mode 100644 index 0000000000000..176a4c9593a57 --- /dev/null +++ b/test/AutoDiff/SILOptimizer/closure_specialization/single_bb.sil @@ -0,0 +1,490 @@ +// RUN: %target-sil-opt -sil-print-types -test-runner %s -o /dev/null 2>&1 | %FileCheck %s --check-prefixes=TRUNNER,CHECK +// RUN: %target-sil-opt -sil-print-types -autodiff-closure-specialization -sil-combine %s -o - | %FileCheck %s --check-prefixes=COMBINE,CHECK + +// REQUIRES: swift_in_compiler + +sil_stage canonical + +import Builtin +import Swift +import SwiftShims + +import _Differentiation + +//////////////////////////////////////////////////////////////// +// Single closure call site where closure is passed as @owned // +//////////////////////////////////////////////////////////////// +sil @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float) + +sil private @$pullback_f : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float { +bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> (Float, Float)): + %2 = apply %1(%0) : $@callee_guaranteed (Float) -> (Float, Float) + strong_release %1 : $@callee_guaranteed (Float) -> (Float, Float) + %4 = tuple_extract %2 : $(Float, Float), 0 + %5 = tuple_extract %2 : $(Float, Float), 1 + %6 = struct_extract %5 : $Float, #Float._value + %7 = struct_extract %4 : $Float, #Float._value + %8 = builtin "fadd_FPIEEE32"(%6 : $Builtin.FPIEEE32, %7 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 + %9 = struct $Float (%8 : $Builtin.FPIEEE32) + return %9 : $Float +} + +// reverse-mode derivative of f(_:) +sil hidden @$s4test1fyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { +bb0(%0 : $Float): + //=========== Test callsite and closure gathering logic ===========// + specify_test "autodiff_closure_specialize_get_pullback_closure_info" + // TRUNNER-LABEL: Specializing closures in function: $s4test1fyS2fFTJrSpSr + // TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#A1:]]) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float + // TRUNNER-NEXT: Passed in closures: + // TRUNNER-NEXT: 1. %[[#A1]] = partial_apply [callee_guaranteed] %[[#]](%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER-EMPTY: + + //=========== Test specialized function signature and body ===========// + specify_test "autodiff_closure_specialize_specialized_function_signature_and_body" + // TRUNNER-LABEL: Generated specialized function: $s11$pullback_f12$vjpMultiplyS2fTf1nc_n + // CHECK: sil private @$s11$pullback_f12$vjpMultiplyS2fTf1nc_n : $@convention(thin) (Float, Float, Float) -> Float { + // CHECK: bb0(%0 : $Float, %1 : $Float, %2 : $Float): + // CHECK: %[[#A2:]] = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER: %[[#A3:]] = partial_apply [callee_guaranteed] %[[#A2]](%1, %2) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER: %[[#]] = apply %[[#A3]](%0) : $@callee_guaranteed (Float) -> (Float, Float) + // COMBINE: %[[#]] = apply %[[#A2]](%0, %1, %2) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER: strong_release %[[#A3]] : $@callee_guaranteed (Float) -> (Float, Float) + + //=========== Test rewritten body ===========// + specify_test "autodiff_closure_specialize_rewritten_caller_body" + // TRUNNER-LABEL: Rewritten caller body for: $s4test1fyS2fFTJrSpSr + // CHECK: sil hidden @$s4test1fyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { + // CHECK: bb0(%0 : $Float): + // CHECK: %[[#A4:]] = struct_extract %0 : $Float, #Float._value + // CHECK: %[[#A5:]] = builtin "fmul_FPIEEE32"(%[[#A4]] : $Builtin.FPIEEE32, %[[#A4]] : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 + // CHECK: %[[#A6:]] = struct $Float (%[[#A5]] : $Builtin.FPIEEE32) + // COMBINE-NOT: function_ref @$vjpMultiply + // CHECK: %[[#A7:]] = function_ref @$s11$pullback_f12$vjpMultiplyS2fTf1nc_n : $@convention(thin) (Float, Float, Float) -> Float + // CHECK: %[[#A8:]] = partial_apply [callee_guaranteed] %[[#A7]](%0, %0) : $@convention(thin) (Float, Float, Float) -> Float + // CHECK: %[[#A9:]] = tuple (%[[#A6]] : $Float, %[[#A8]] : $@callee_guaranteed (Float) -> Float) + // CHECK: return %[[#A9]] + + %2 = struct_extract %0 : $Float, #Float._value + %3 = builtin "fmul_FPIEEE32"(%2 : $Builtin.FPIEEE32, %2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 + %4 = struct $Float (%3 : $Builtin.FPIEEE32) + // function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:) + %5 = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float) + %6 = partial_apply [callee_guaranteed] %5(%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // function_ref pullback of f(_:) + %7 = function_ref @$pullback_f : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float + %8 = partial_apply [callee_guaranteed] %7(%6) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float + %9 = tuple (%4 : $Float, %8 : $@callee_guaranteed (Float) -> Float) + return %9 : $(Float, @callee_guaranteed (Float) -> Float) +} + +///////////////////////////////////////////////////////////////////// +// Single closure call site where closure is passed as @guaranteed // +///////////////////////////////////////////////////////////////////// +sil private @$pullback_k : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float { +bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> (Float, Float)): + %2 = apply %1(%0) : $@callee_guaranteed (Float) -> (Float, Float) + %3 = tuple_extract %2 : $(Float, Float), 0 + %4 = tuple_extract %2 : $(Float, Float), 1 + %5 = struct_extract %4 : $Float, #Float._value + %6 = struct_extract %3 : $Float, #Float._value + %7 = builtin "fadd_FPIEEE32"(%5 : $Builtin.FPIEEE32, %6 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 + %8 = struct $Float (%7 : $Builtin.FPIEEE32) + return %8 : $Float +} // end sil function '$pullback_k' + +// reverse-mode derivative of k(_:) +sil hidden @$s4test1kyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { +bb0(%0 : $Float): + //=========== Test callsite and closure gathering logic ===========// + specify_test "autodiff_closure_specialize_get_pullback_closure_info" + // TRUNNER-LABEL: Specializing closures in function: $s4test1kyS2fFTJrSpSr + // TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#B1:]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + // TRUNNER: Passed in closures: + // TRUNNER: 1. %[[#B1]] = partial_apply [callee_guaranteed] %[[#]](%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER-EMPTY: + + //=========== Test specialized function signature and body ===========// + specify_test "autodiff_closure_specialize_specialized_function_signature_and_body" + // TRUNNER-LABEL: Generated specialized function: $s11$pullback_k12$vjpMultiplyS2fTf1nc_n + // CHECK: sil private @$s11$pullback_k12$vjpMultiplyS2fTf1nc_n : $@convention(thin) (Float, Float, Float) -> Float { + // CHECK: bb0(%0 : $Float, %1 : $Float, %2 : $Float): + // CHECK: %[[#B2:]] = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER: %[[#B3:]] = partial_apply [callee_guaranteed] %[[#B2]](%1, %2) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER: %[[#]] = apply %[[#B3]](%0) : $@callee_guaranteed (Float) -> (Float, Float) + // COMBINE: %[[#]] = apply %[[#B2]](%0, %1, %2) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER: release_value %[[#B3]] : $@callee_guaranteed (Float) -> (Float, Float) + + //=========== Test rewritten body ===========// + specify_test "autodiff_closure_specialize_rewritten_caller_body" + // TRUNNER-LABEL: Rewritten caller body for: $s4test1kyS2fFTJrSpSr + // CHECK: sil hidden @$s4test1kyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { + // CHECK: bb0(%0 : $Float): + // CHECK: %[[#B4:]] = struct_extract %0 : $Float, #Float._value + // CHECK: %[[#B5:]] = builtin "fmul_FPIEEE32"(%[[#B4]] : $Builtin.FPIEEE32, %[[#B4]] : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 + // CHECK: %[[#B6:]] = struct $Float (%[[#B5]] : $Builtin.FPIEEE32) + // COMBINE-NOT: function_ref @$vjpMultiply + // CHECK: %[[#B7:]] = function_ref @$s11$pullback_k12$vjpMultiplyS2fTf1nc_n : $@convention(thin) (Float, Float, Float) -> Float + // CHECK: %[[#B8:]] = partial_apply [callee_guaranteed] %[[#B7]](%0, %0) : $@convention(thin) (Float, Float, Float) -> Float + // CHECK: %[[#B9:]] = tuple (%[[#B6]] : $Float, %[[#B8]] : $@callee_guaranteed (Float) -> Float) + // CHECK: return %[[#B9]] + + %2 = struct_extract %0 : $Float, #Float._value + %3 = builtin "fmul_FPIEEE32"(%2 : $Builtin.FPIEEE32, %2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 + %4 = struct $Float (%3 : $Builtin.FPIEEE32) + // function_ref $vjpMultiply + %5 = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float) + %6 = partial_apply [callee_guaranteed] %5(%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // function_ref $pullback_k + %7 = function_ref @$pullback_k : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + %8 = partial_apply [callee_guaranteed] %7(%6) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + strong_release %6 : $@callee_guaranteed (Float) -> (Float, Float) + %10 = tuple (%4 : $Float, %8 : $@callee_guaranteed (Float) -> Float) + return %10 : $(Float, @callee_guaranteed (Float) -> Float) +} // end sil function '$s4test1kyS2fFTJrSpSr' + +/////////////////////////////// +// Multiple closure callsite // +/////////////////////////////// +sil @$vjpSin : $@convention(thin) (Float, Float) -> Float +sil @$vjpCos : $@convention(thin) (Float, Float) -> Float + +// pullback of g(_:) +sil private @$pullback_g : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float { +bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> Float, %2 : $@callee_guaranteed (Float) -> Float, %3 : $@callee_guaranteed (Float) -> (Float, Float)): + %4 = apply %3(%0) : $@callee_guaranteed (Float) -> (Float, Float) + strong_release %3 : $@callee_guaranteed (Float) -> (Float, Float) + %6 = tuple_extract %4 : $(Float, Float), 0 + %7 = tuple_extract %4 : $(Float, Float), 1 + %8 = apply %2(%7) : $@callee_guaranteed (Float) -> Float + strong_release %2 : $@callee_guaranteed (Float) -> Float + %10 = apply %1(%6) : $@callee_guaranteed (Float) -> Float + strong_release %1 : $@callee_guaranteed (Float) -> Float + %12 = struct_extract %8 : $Float, #Float._value + %13 = struct_extract %10 : $Float, #Float._value + %14 = builtin "fadd_FPIEEE32"(%13 : $Builtin.FPIEEE32, %12 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 + %15 = struct $Float (%14 : $Builtin.FPIEEE32) + return %15 : $Float +} + +// reverse-mode derivative of g(_:) +sil hidden @$s4test1gyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { +bb0(%0 : $Float): + //=========== Test callsite and closure gathering logic ===========// + specify_test "autodiff_closure_specialize_get_pullback_closure_info" + // TRUNNER-LABEL: Specializing closures in function: $s4test1gyS2fFTJrSpSr + // TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#C1:]], %[[#C2:]], %[[#C3:]]) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float + // TRUNNER: Passed in closures: + // TRUNNER: 1. %[[#C1]] = partial_apply [callee_guaranteed] %[[#]](%0) : $@convention(thin) (Float, Float) -> Float + // TRUNNER: 2. %[[#C2]] = partial_apply [callee_guaranteed] %[[#]](%0) : $@convention(thin) (Float, Float) -> Float + // TRUNNER: 3. %[[#C3]] = partial_apply [callee_guaranteed] %[[#]](%[[#]], %[[#]]) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER-EMPTY: + + //=========== Test specialized function signature and body ===========// + specify_test "autodiff_closure_specialize_specialized_function_signature_and_body" + // TRUNNER-LABEL: Generated specialized function: $s11$pullback_g7$vjpSinSf0B3CosSf0B8MultiplyS2fTf1nccc_n + // CHECK: sil private @$s11$pullback_g7$vjpSinSf0B3CosSf0B8MultiplyS2fTf1nccc_n : $@convention(thin) (Float, Float, Float, Float, Float) -> Float { + // CHECK: bb0(%0 : $Float, %1 : $Float, %2 : $Float, %3 : $Float, %4 : $Float): + // CHECK: %[[#C4:]] = function_ref @$vjpSin : $@convention(thin) (Float, Float) -> Float + // TRUNNER: %[[#C5:]] = partial_apply [callee_guaranteed] %[[#C4]](%1) : $@convention(thin) (Float, Float) -> Float + // CHECK: %[[#C6:]] = function_ref @$vjpCos : $@convention(thin) (Float, Float) -> Float + // TRUNNER: %[[#C7:]] = partial_apply [callee_guaranteed] %[[#C6]](%2) : $@convention(thin) (Float, Float) -> Float + // CHECK: %[[#C8:]] = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER: %[[#C9:]] = partial_apply [callee_guaranteed] %[[#C8]](%3, %4) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER: %[[#]] = apply %[[#C9]](%0) : $@callee_guaranteed (Float) -> (Float, Float) + // COMBINE: %[[#]] = apply %[[#C8]](%0, %3, %4) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER: strong_release %[[#C9]] : $@callee_guaranteed (Float) -> (Float, Float) + // TRUNNER: %[[#]] = apply %[[#C7]](%[[#]]) : $@callee_guaranteed (Float) -> Float + // COMBINE: %[[#]] = apply %[[#C6]](%[[#]], %2) : $@convention(thin) (Float, Float) -> Float + // TRUNNER: strong_release %[[#C7]] : $@callee_guaranteed (Float) -> Float + // TRUNNER: %[[#]] = apply %[[#C5]](%[[#]]) : $@callee_guaranteed (Float) -> Float + // COMBINE: %[[#]] = apply %[[#C4]](%[[#]], %1) : $@convention(thin) (Float, Float) -> Float + // TRUNNER: strong_release %[[#C5]] : $@callee_guaranteed (Float) -> Float + + //=========== Test rewritten body ===========// + specify_test "autodiff_closure_specialize_rewritten_caller_body" + // TRUNNER-LABEL: Rewritten caller body for: $s4test1gyS2fFTJrSpSr + // CHECK: sil hidden @$s4test1gyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { + // CHECK: bb0(%0 : $Float): + // CHECK: %[[#C10:]] = struct_extract %0 : $Float, #Float._value + // CHECK: %[[#C11:]] = builtin "int_sin_FPIEEE32"(%[[#C10]] : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 + // CHECK: %[[#C12:]] = struct $Float (%[[#C11]] : $Builtin.FPIEEE32) + // COMBINE-NOT: function_ref @$vjpSin + // CHECK: %[[#C13:]] = builtin "int_cos_FPIEEE32"(%[[#C10]] : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 + // COMBINE-NOT: function_ref @$vjpCos + // COMBINE-NOT: function_ref @$vjpMultiply + // CHECK: %[[#C14:]] = struct $Float (%[[#C13]] : $Builtin.FPIEEE32) + // CHECK: %[[#C15:]] = function_ref @$s11$pullback_g7$vjpSinSf0B3CosSf0B8MultiplyS2fTf1nccc_n : $@convention(thin) (Float, Float, Float, Float, Float) -> Float + // CHECK: %[[#C16:]] = partial_apply [callee_guaranteed] %[[#C15]](%0, %0, %[[#C14]], %[[#C12]]) : $@convention(thin) (Float, Float, Float, Float, Float) -> Float + // CHECK: %[[#C17:]] = tuple (%[[#]] : $Float, %[[#C16]] : $@callee_guaranteed (Float) -> Float) + // CHECK: return %[[#C17]] + + %2 = struct_extract %0 : $Float, #Float._value + %3 = builtin "int_sin_FPIEEE32"(%2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 + %4 = struct $Float (%3 : $Builtin.FPIEEE32) + // function_ref closure #1 in _vjpSin(_:) + %5 = function_ref @$vjpSin : $@convention(thin) (Float, Float) -> Float + %6 = partial_apply [callee_guaranteed] %5(%0) : $@convention(thin) (Float, Float) -> Float + %7 = builtin "int_cos_FPIEEE32"(%2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 + %8 = struct $Float (%7 : $Builtin.FPIEEE32) + // function_ref closure #1 in _vjpCos(_:) + %9 = function_ref @$vjpCos : $@convention(thin) (Float, Float) -> Float + %10 = partial_apply [callee_guaranteed] %9(%0) : $@convention(thin) (Float, Float) -> Float + %11 = builtin "fmul_FPIEEE32"(%3 : $Builtin.FPIEEE32, %7 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 + %12 = struct $Float (%11 : $Builtin.FPIEEE32) + // function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:) + %13 = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float) + %14 = partial_apply [callee_guaranteed] %13(%8, %4) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // function_ref pullback of g(_:) + %15 = function_ref @$pullback_g : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float + %16 = partial_apply [callee_guaranteed] %15(%6, %10, %14) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float + %17 = tuple (%12 : $Float, %16 : $@callee_guaranteed (Float) -> Float) + return %17 : $(Float, @callee_guaranteed (Float) -> Float) +} + +/////////////////////////////// +/// Parameter subset thunks /// +/////////////////////////////// +struct X : Differentiable { + @_hasStorage var a: Float { get set } + @_hasStorage var b: Double { get set } + struct TangentVector : AdditiveArithmetic, Differentiable { + @_hasStorage var a: Float { get set } + @_hasStorage var b: Double { get set } + static func + (lhs: X.TangentVector, rhs: X.TangentVector) -> X.TangentVector + static func - (lhs: X.TangentVector, rhs: X.TangentVector) -> X.TangentVector + @_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: X.TangentVector, _ b: X.TangentVector) -> Bool + typealias TangentVector = X.TangentVector + init(a: Float, b: Double) + static var zero: X.TangentVector { get } + } + init(a: Float, b: Double) + mutating func move(by offset: X.TangentVector) +} + +sil [transparent] [thunk] @subset_parameter_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector + +sil @pullback_f : $@convention(thin) (Float, Double) -> X.TangentVector + +sil shared @pullback_g : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) -> X.TangentVector { +bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> X.TangentVector): + %2 = apply %1(%0) : $@callee_guaranteed (Float) -> X.TangentVector + strong_release %1 : $@callee_guaranteed (Float) -> X.TangentVector + return %2 : $X.TangentVector +} + +sil hidden @$s5test21g1xSfAA1XV_tFTJrSpSr : $@convention(thin) (X) -> (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) { +bb0(%0 : $X): + //=========== Test callsite and closure gathering logic ===========// + specify_test "autodiff_closure_specialize_get_pullback_closure_info" + // TRUNNER-LABEL: Specializing closures in function: $s5test21g1xSfAA1XV_tFTJrSpSr + // TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]]) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) -> X.TangentVector + // TRUNNER: Passed in closures: + // TRUNNER: 1. %[[#]] = thin_to_thick_function %[[#]] : $@convention(thin) (Float, Double) -> X.TangentVector to $@callee_guaranteed (Float, Double) -> X.TangentVector + // TRUNNER-EMPTY: + + //=========== Test specialized function signature and body ===========// + specify_test "autodiff_closure_specialize_specialized_function_signature_and_body" + // TRUNNER-LABEL: Generated specialized function: $s10pullback_g0A2_fTf1nc_n + // CHECK: sil shared @$s10pullback_g0A2_fTf1nc_n : $@convention(thin) (Float) -> X.TangentVector { + // CHECK: bb0(%0 : $Float): + // CHECK: %[[#D1:]] = function_ref @pullback_f : $@convention(thin) (Float, Double) -> X.TangentVector + // CHECK: %[[#D2:]] = thin_to_thick_function %[[#D1]] : $@convention(thin) (Float, Double) -> X.TangentVector to $@callee_guaranteed (Float, Double) -> X.TangentVector + // CHECK: %[[#D3:]] = function_ref @subset_parameter_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector + // TRUNNER: %[[#D4:]] = partial_apply [callee_guaranteed] %[[#D3]](%[[#D2]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector + // TRUNNER: %[[#D5:]] = apply %[[#D4]](%0) : $@callee_guaranteed (Float) -> X.TangentVector + // COMBINE: %[[#D5:]] = apply %[[#D3]](%0, %[[#D2]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector + // TRUNNER: strong_release %[[#D4]] : $@callee_guaranteed (Float) -> X.TangentVector + // CHECK: return %[[#D5]] + + //=========== Test rewritten body ===========// + specify_test "autodiff_closure_specialize_rewritten_caller_body" + // TRUNNER-LABEL: Rewritten caller body for: $s5test21g1xSfAA1XV_tFTJrSpSr + // CHECK: sil hidden @$s5test21g1xSfAA1XV_tFTJrSpSr : $@convention(thin) (X) -> (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) { + // CHECK: bb0(%0 : $X): + // CHECK: %[[#D6:]] = struct_extract %0 : $X, #X.a + // COMBINE-NOT: function_ref @pullback_f + // COMBINE-NOT: function_ref @subset_parameter_thunk + // CHECK: %[[#D7:]] = function_ref @$s10pullback_g0A2_fTf1nc_n : $@convention(thin) (Float) -> X.TangentVector + // TRUNNER: %[[#D8:]] = partial_apply [callee_guaranteed] %[[#D7]]() : $@convention(thin) (Float) -> X.TangentVector + // COMBINE: %[[#D8:]] = thin_to_thick_function %[[#D7]] : $@convention(thin) (Float) -> X.TangentVector to $@callee_guaranteed (Float) -> X.TangentVector + // CHECK: %[[#D9:]] = tuple (%[[#D6]] : $Float, %[[#D8]] : $@callee_guaranteed (Float) -> X.TangentVector) + // CHECK: return %[[#D9]] + + %1 = struct_extract %0 : $X, #X.a + // function_ref pullback_f + %2 = function_ref @pullback_f : $@convention(thin) (Float, Double) -> X.TangentVector + %3 = thin_to_thick_function %2 : $@convention(thin) (Float, Double) -> X.TangentVector to $@callee_guaranteed (Float, Double) -> X.TangentVector + // function_ref subset_parameter_thunk + %4 = function_ref @subset_parameter_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector + %5 = partial_apply [callee_guaranteed] %4(%3) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector + // function_ref pullback_g + %6 = function_ref @pullback_g : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) -> X.TangentVector + %7 = partial_apply [callee_guaranteed] %6(%5) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) -> X.TangentVector + %8 = tuple (%1 : $Float, %7 : $@callee_guaranteed (Float) -> X.TangentVector) + return %8 : $(Float, @callee_guaranteed (Float) -> X.TangentVector) +} + +/////////////////////////////////////////////////////////////////////// +///////// Specialized generic closures - PartialApply Closure ///////// +/////////////////////////////////////////////////////////////////////// + +// closure #1 in static Float._vjpMultiply(lhs:rhs:) +sil @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) + +// thunk for @escaping @callee_guaranteed (@unowned Float) -> (@unowned Float, @unowned Float) +sil [transparent] [reabstraction_thunk] @$sS3fIegydd_S3fIegnrr_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float) + +// function_ref specialized pullback of f(a:) +sil [transparent] [thunk] @pullback_f_specialized : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for ) -> @out Float + +// thunk for @escaping @callee_guaranteed (@in_guaranteed Float) -> (@out Float) +sil [transparent] [reabstraction_thunk] @$sS2fIegnr_S2fIegyd_TR : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float + +sil private [signature_optimized_thunk] [always_inline] @pullback_h : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float { +bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> Float): + %2 = apply %1(%0) : $@callee_guaranteed (Float) -> Float + strong_release %1 : $@callee_guaranteed (Float) -> Float + return %2 : $Float +} + +// reverse-mode derivative of h(x:) +sil hidden @$s5test21h1xS2f_tFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { +bb0(%0 : $Float): + //=========== Test callsite and closure gathering logic ===========// + specify_test "autodiff_closure_specialize_get_pullback_closure_info" + // TRUNNER-LABEL: Specializing closures in function: $s5test21h1xS2f_tFTJrSpSr + // TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]]) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float + // TRUNNER: Passed in closures: + // TRUNNER: 1. %[[#]] = partial_apply [callee_guaranteed] %[[#]](%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER-EMPTY: + + //=========== Test specialized function signature and body ===========// + specify_test "autodiff_closure_specialize_specialized_function_signature_and_body" + // TRUNNER-LABEL: Generated specialized function: $s10pullback_h073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbackti1_j5FZSf_J6SfcfU_S2fTf1nc_n + // CHECK: sil private [signature_optimized_thunk] [always_inline] @$s10pullback_h073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbackti1_j5FZSf_J6SfcfU_S2fTf1nc_n : $@convention(thin) (Float, Float, Float) -> Float { + // CHECK: bb0(%0 : $Float, %1 : $Float, %2 : $Float): + // CHECK: %[[#E1:]] = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // CHECK: %[[#E2:]] = partial_apply [callee_guaranteed] %[[#E1]](%1, %2) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // CHECK: %[[#E3:]] = function_ref @$sS3fIegydd_S3fIegnrr_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float) + // CHECK: %[[#E4:]] = partial_apply [callee_guaranteed] %[[#E3]](%[[#E2]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float) + // CHECK: %[[#E5:]] = convert_function %[[#E4]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @out Float) to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for + // CHECK: %[[#E6:]] = function_ref @pullback_f_specialized : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for ) -> @out Float + // CHECK: %[[#E7:]] = partial_apply [callee_guaranteed] %[[#E6]](%[[#E5]]) : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for ) -> @out Float + // CHECK: %[[#E8:]] = function_ref @$sS2fIegnr_S2fIegyd_TR : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float + // TRUNNER: %[[#E9:]] = partial_apply [callee_guaranteed] %[[#E8]](%[[#E7]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float + // TRUNNER: %[[#E10:]] = apply %[[#E9]](%0) : $@callee_guaranteed (Float) -> Float + // COMBINE: %[[#E10:]] = apply %[[#E8]](%0, %[[#E7]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float + // TRUNNER: strong_release %[[#E9]] : $@callee_guaranteed (Float) -> Float + // COMBINE: strong_release %[[#E7]] : $@callee_guaranteed (@in_guaranteed Float) -> @out Float + // CHECK: return %[[#E10]] + + //=========== Test rewritten body ===========// + specify_test "autodiff_closure_specialize_rewritten_caller_body" + // TRUNNER-LABEL: Rewritten caller body for: $s5test21h1xS2f_tFTJrSpSr + // CHECK: sil hidden @$s5test21h1xS2f_tFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { + // CHECK: bb0(%0 : $Float): + // CHECK: %[[#E11:]] = struct_extract %0 : $Float, #Float._value + // CHECK: %[[#E12:]] = builtin "fmul_FPIEEE32"(%[[#E11]] : $Builtin.FPIEEE32, %1 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 + // COMBINE-NOT: function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ + // COMBINE-NOT: function_ref @$sS3fIegydd_S3fIegnrr_TR + // COMBINE-NOT: function_ref @pullback_f_specialized + // COMBINE-NOT: function_ref @$sS2fIegnr_S2fIegyd_TR + // CHECK: %[[#E13:]] = struct $Float (%[[#E12]] : $Builtin.FPIEEE32) + // CHECK: %[[#E14:]] = function_ref @$s10pullback_h073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbackti1_j5FZSf_J6SfcfU_S2fTf1nc_n : $@convention(thin) (Float, Float, Float) -> Float + // CHECK: %[[#E15:]] = partial_apply [callee_guaranteed] %[[#E14]](%0, %0) : $@convention(thin) (Float, Float, Float) -> Float + // CHECK: %[[#E16:]] = tuple (%[[#E13]] : $Float, %[[#E15]] : $@callee_guaranteed (Float) -> Float) + // CHECK: return %[[#E16]] + + %1 = struct_extract %0 : $Float, #Float._value + %2 = builtin "fmul_FPIEEE32"(%1 : $Builtin.FPIEEE32, %1 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 + + // function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:) + %3 = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) + %4 = partial_apply [callee_guaranteed] %3(%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + + // function_ref thunk for @escaping @callee_guaranteed (@unowned Float) -> (@unowned Float, @unowned Float) + %5 = function_ref @$sS3fIegydd_S3fIegnrr_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float) + %6 = partial_apply [callee_guaranteed] %5(%4) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float) + %7 = convert_function %6 : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @out Float) to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for + + // function_ref pullback_f_specialized + %8 = function_ref @pullback_f_specialized : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for ) -> @out Float + %9 = partial_apply [callee_guaranteed] %8(%7) : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for ) -> @out Float + + // function_ref thunk for @escaping @callee_guaranteed (@in_guaranteed Float) -> (@out Float) + %10 = function_ref @$sS2fIegnr_S2fIegyd_TR : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float + %11 = partial_apply [callee_guaranteed] %10(%9) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float + %12 = struct $Float (%2 : $Builtin.FPIEEE32) + + // function_ref pullback_h + %13 = function_ref @pullback_h : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float + %14 = partial_apply [callee_guaranteed] %13(%11) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float + %15 = tuple (%12 : $Float, %14 : $@callee_guaranteed (Float) -> Float) + return %15 : $(Float, @callee_guaranteed (Float) -> Float) +} + +////////////////////////////////////////////////////////////////////////////// +///////// Specialized generic closures - ThinToThickFunction closure ///////// +////////////////////////////////////////////////////////////////////////////// + +sil [transparent] [thunk] @pullback_y_specialized : $@convention(thin) (@in_guaranteed Float) -> @out Float + +sil [transparent] [reabstraction_thunk] @reabstraction_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float + +sil private [signature_optimized_thunk] [always_inline] @pullback_z : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float { +bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> Float): + %2 = apply %1(%0) : $@callee_guaranteed (Float) -> Float + strong_release %1 : $@callee_guaranteed (Float) -> Float + return %2 : $Float +} + +sil hidden @$s5test21z1xS2f_tFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { +bb0(%0 : $Float): + //=========== Test callsite and closure gathering logic ===========// + specify_test "autodiff_closure_specialize_get_pullback_closure_info" + // TRUNNER-LABEL: Specializing closures in function: $s5test21z1xS2f_tFTJrSpSr + // TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]]) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float + // TRUNNER: Passed in closures: + // TRUNNER: 1. %[[#]] = thin_to_thick_function %[[#]] : $@convention(thin) (@in_guaranteed Float) -> @out Float to $@callee_guaranteed (@in_guaranteed Float) -> @out Float + // TRUNNER-EMPTY: + + //=========== Test specialized function signature and body ===========// + specify_test "autodiff_closure_specialize_specialized_function_signature_and_body" + // TRUNNER-LABEL: Generated specialized function: $s10pullback_z0A14_y_specializedTf1nc_n + // CHECK: sil private [signature_optimized_thunk] [always_inline] @$s10pullback_z0A14_y_specializedTf1nc_n : $@convention(thin) (Float) -> Float { + // CHECK: bb0(%0 : $Float): + // CHECK: %[[#F1:]] = function_ref @pullback_y_specialized : $@convention(thin) (@in_guaranteed Float) -> @out Float + // CHECK: %[[#F2:]] = thin_to_thick_function %[[#F1]] : $@convention(thin) (@in_guaranteed Float) -> @out Float to $@callee_guaranteed (@in_guaranteed Float) -> @out Float + // CHECK: %[[#F3:]] = function_ref @reabstraction_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float + // TRUNNER: %[[#F4:]] = partial_apply [callee_guaranteed] %[[#F3]](%[[#F2]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float + // TRUNNER: %[[#F5:]] = apply %[[#F4]](%0) : $@callee_guaranteed (Float) -> Float + // COMBINE: %[[#F5:]] = apply %[[#F3]](%0, %[[#F2]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float + // TRUNNER: strong_release %[[#F4]] : $@callee_guaranteed (Float) -> Float + // CHECK: return %[[#F5]] + + //=========== Test rewritten body ===========// + specify_test "autodiff_closure_specialize_rewritten_caller_body" + // TRUNNER-LABEL: Rewritten caller body for: $s5test21z1xS2f_tFTJrSpSr + // CHECK: sil hidden @$s5test21z1xS2f_tFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { + // CHECK: bb0(%0 : $Float): + // COMBINE-NOT: function_ref @pullback_y_specialized + // COMBINE-NOT: function_ref @reabstraction_thunk + // CHECK: %[[#F6:]] = function_ref @$s10pullback_z0A14_y_specializedTf1nc_n : $@convention(thin) (Float) -> Float + // TRUNNER: %[[#F7:]] = partial_apply [callee_guaranteed] %[[#F6]]() : $@convention(thin) (Float) -> Float + // COMBINE: %[[#F7:]] = thin_to_thick_function %[[#F6]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float + // CHECK: %[[#F8:]] = tuple (%0 : $Float, %[[#F7]] : $@callee_guaranteed (Float) -> Float) + // CHECK: return %[[#F8]] + + // function_ref pullback_y_specialized + %1 = function_ref @pullback_y_specialized : $@convention(thin) (@in_guaranteed Float) -> @out Float + %2 = thin_to_thick_function %1 : $@convention(thin) (@in_guaranteed Float) -> @out Float to $@callee_guaranteed (@in_guaranteed Float) -> @out Float + // function_ref reabstraction_thunk + %3 = function_ref @reabstraction_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float + %4 = partial_apply [callee_guaranteed] %3(%2) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float + // function_ref pullback_z + %5 = function_ref @pullback_z : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float + %6 = partial_apply [callee_guaranteed] %5(%4) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float + %7 = tuple (%0 : $Float, %6 : $@callee_guaranteed (Float) -> Float) + return %7 : $(Float, @callee_guaranteed (Float) -> Float) +} diff --git a/test/AutoDiff/validation-test/closure_specialization/multi_bb_bte.swift b/test/AutoDiff/validation-test/closure_specialization/multi_bb_bte.swift new file mode 100644 index 0000000000000..a33732aa293e1 --- /dev/null +++ b/test/AutoDiff/validation-test/closure_specialization/multi_bb_bte.swift @@ -0,0 +1,133 @@ +/// Multi basic block VJP, pullback accepting branch tracing enum argument. + +// REQUIRES: executable_test + +// RUN: %empty-directory(%t) +// RUN: %target-build-swift %s -o %t/none.out -Onone +// RUN: %target-build-swift %s -o %t/opt.out -O +// RUN: %target-run %t/none.out +// RUN: %target-run %t/opt.out + +// RUN: %target-swift-frontend -emit-sil %s -O -o %t/out.sil +// RUN: cat %t/out.sil | %FileCheck %s --check-prefix=CHECK1 +// RUN: cat %t/out.sil | %FileCheck %s --check-prefix=CHECK2 +// RUN: cat %t/out.sil | %FileCheck %s --check-prefix=CHECK3 + +import DifferentiationUnittest +import StdlibUnittest + +var AutoDiffClosureSpecMultiBBBTETests = TestSuite("AutoDiffClosureSpecMultiBBBTE") + +AutoDiffClosureSpecMultiBBBTETests.testWithLeakChecking("Test1") { + // CHECK1-LABEL: {{^}}// reverse-mode derivative of mul42 #1 (_:) + // CHECK1-NEXT: sil private @$s3outyycfU_5mul42L_yS2fSgFTJrSpSr : $@convention(thin) (Optional) -> (Float, @owned @callee_guaranteed (Float) -> Optional.TangentVector) { + // CHECK1: %[[#A12:]] = function_ref @$s3outyycfU_5mul42L_yS2fSgFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktj1_k5FZSf_K6SfcfU_S2fTf1nnc_n : $@convention(thin) (Float, @owned _AD__$s3outyycfU_5mul42L_yS2fSgF_bb2__Pred__src_0_wrt_0, Float, Float) -> Optional.TangentVector + // CHECK1: %[[#A13:]] = partial_apply [callee_guaranteed] %[[#A12]](%[[#]], %[[#]], %[[#]]) : $@convention(thin) (Float, @owned _AD__$s3outyycfU_5mul42L_yS2fSgF_bb2__Pred__src_0_wrt_0, Float, Float) -> Optional.TangentVector + // CHECK1: %[[#A14:]] = tuple (%[[#]], %[[#A13]]) + // CHECK1: return %[[#A14]] + // CHECK1: } // end sil function '$s3outyycfU_5mul42L_yS2fSgFTJrSpSr' + + // CHECK1-NONE: {{^}}// pullback of mul42 + // CHECK1: {{^}}// specialized pullback of mul42 + // CHECK1: sil private @$s3outyycfU_5mul42L_yS2fSgFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktj1_k5FZSf_K6SfcfU_S2fTf1nnc_n : $@convention(thin) (Float, @owned _AD__$s3outyycfU_5mul42L_yS2fSgF_bb2__Pred__src_0_wrt_0, Float, Float) -> Optional.TangentVector { + + @differentiable(reverse) + func mul42(_ a: Float?) -> Float { + let b = 42 * a! + return b + } + + expectEqual((-84, 42), valueWithGradient(at: -2, of: mul42)) + expectEqual((0, 42), valueWithGradient(at: 0, of: mul42)) + expectEqual((42, 42), valueWithGradient(at: 1, of: mul42)) + expectEqual((210, 42), valueWithGradient(at: 5, of: mul42)) +} + +AutoDiffClosureSpecMultiBBBTETests.testWithLeakChecking("Test2") { + // CHECK2-LABEL: {{^}}// reverse-mode derivative of cond_tuple_var #1 (_:) + // CHECK2-NEXT: sil private @$s3outyycfU0_14cond_tuple_varL_yS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { + // CHECK2: %[[#E41:]] = function_ref @$s3outyycfU0_14cond_tuple_varL_yS2fFTJpSpSr067$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktl1_m5FZSf_M6SfcfU_0ef1_g4E12_i16Subtract3lhs3rhsk1_l1_mnl1_mo1_mP2U_ACTf1nnccc_n : $@convention(thin) (Float, @owned _AD__$s3outyycfU0_14cond_tuple_varL_yS2fF_bb3__Pred__src_0_wrt_0) -> Float + // CHECK2: %[[#E42:]] = partial_apply [callee_guaranteed] %[[#E41]](%[[#]]) : $@convention(thin) (Float, @owned _AD__$s3outyycfU0_14cond_tuple_varL_yS2fF_bb3__Pred__src_0_wrt_0) -> Float + // CHECK2: %[[#E46:]] = tuple (%[[#]], %[[#E42]]) + // CHECK2: return %[[#E46]] + // CHECK2: } // end sil function '$s3outyycfU0_14cond_tuple_varL_yS2fFTJrSpSr' + + // CHECK2-NONE: {{^}}// pullback of cond_tuple_var + // CHECK2: {{^}}// specialized pullback of cond_tuple_var + // CHECK2: sil private @$s3outyycfU0_14cond_tuple_varL_yS2fFTJpSpSr067$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktl1_m5FZSf_M6SfcfU_0ef1_g4E12_i16Subtract3lhs3rhsk1_l1_mnl1_mo1_mP2U_ACTf1nnccc_n : $@convention(thin) (Float, @owned _AD__$s3outyycfU0_14cond_tuple_varL_yS2fF_bb3__Pred__src_0_wrt_0) -> Float { + + func cond_tuple_var(_ x: Float) -> Float { + // Convoluted function returning `x + x`. + var y: (Float, Float) = (x, x) + var z: (Float, Float) = (x + x, x - x) + if x > 0 { + var w = (x, x) + y.0 = w.1 + y.1 = w.0 + z.0 = z.0 - y.0 + z.1 = z.1 + y.0 + } else { + z = (1 * x, x) + } + return y.0 + y.1 - z.0 + z.1 + } + + expectEqual((8, 2), valueWithGradient(at: 4, of: cond_tuple_var)) + expectEqual((-20, 2), valueWithGradient(at: -10, of: cond_tuple_var)) + expectEqual((-2674, 2), valueWithGradient(at: -1337, of: cond_tuple_var)) +} + +AutoDiffClosureSpecMultiBBBTETests.testWithLeakChecking("Test3") { + struct Class: Differentiable { + var stored: Float + var optional: Float? + + init(stored: Float, optional: Float?) { + self.stored = stored + self.optional = optional + } + + // CHECK3-LABEL: {{^}}// reverse-mode derivative of method() + // CHECK3-NEXT: sil private @$s3outyycfU1_5ClassL_V6methodSfyFTJrSpSr : $@convention(method) (Class) -> (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) { + // CHECK3: %[[#C44:]] = function_ref @$s3outyycfU1_5ClassL_V6methodSfyFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktk1_l5FZSf_L6SfcfU_S2fAES2fTf1nncc_n : $@convention(thin) (Float, @owned _AD__$s3outyycfU1_5ClassL_V6methodSfyF_bb3__Pred__src_0_wrt_0, Float, Float, Float, Float) -> Class.TangentVector + // CHECK3: %[[#C45:]] = partial_apply [callee_guaranteed] %[[#C44]](%[[#]], %[[#]], %[[#]], %[[#]], %[[#]]) : $@convention(thin) (Float, @owned _AD__$s3outyycfU1_5ClassL_V6methodSfyF_bb3__Pred__src_0_wrt_0, Float, Float, Float, Float) -> Class.TangentVector + // CHECK3: %[[#C48:]] = tuple (%[[#]], %[[#C45]]) + // CHECK3: return %[[#C48]] + // CHECK3: } // end sil function '$s3outyycfU1_5ClassL_V6methodSfyFTJrSpSr' + + // CHECK3-NONE: {{^}}// pullback of method + // CHECK3: {{^}}// specialized pullback of method() + // CHECK3: sil private @$s3outyycfU1_5ClassL_V6methodSfyFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktk1_l5FZSf_L6SfcfU_S2fAES2fTf1nncc_n : $@convention(thin) (Float, @owned _AD__$s3outyycfU1_5ClassL_V6methodSfyF_bb3__Pred__src_0_wrt_0, Float, Float, Float, Float) -> Class.TangentVector { + + @differentiable(reverse) + func method() -> Float { + let c: Class + do { + let tmp = Class(stored: 1 * stored, optional: optional) + let tuple = (tmp, tmp) + c = tuple.0 + } + var ret : Float = 0 + if let x = c.optional { + ret = x * c.stored + } else { + ret = 1 * c.stored + } + return 1 * ret * ret + } + } + + @differentiable(reverse) + func methodWrapper(_ x: Class) -> Float { + x.method() + } + + expectEqual( + valueWithGradient(at: Class(stored: 3, optional: 4), of: methodWrapper), + (144, .init(stored: 96, optional: .init(72)))) + expectEqual( + valueWithGradient(at: Class(stored: 3, optional: nil), of: methodWrapper), + (9, .init(stored: 6, optional: .init(0)))) +} + +runAllTests() diff --git a/test/AutoDiff/validation-test/closure_specialization/multi_bb_no_bte.swift b/test/AutoDiff/validation-test/closure_specialization/multi_bb_no_bte.swift new file mode 100644 index 0000000000000..0aa0642aa4ff2 --- /dev/null +++ b/test/AutoDiff/validation-test/closure_specialization/multi_bb_no_bte.swift @@ -0,0 +1,190 @@ +/// Multi basic block VJP, pullback not accepting branch tracing enum argument. + +// REQUIRES: executable_test + +// RUN: %empty-directory(%t) +// RUN: %target-build-swift %s -o %t/none.out -Onone +// RUN: %target-build-swift %s -o %t/opt.out -O +// RUN: %target-run %t/none.out +// RUN: %target-run %t/opt.out + +// RUN: %target-swift-frontend -emit-sil %s -O -o %t/out.sil +// RUN: cat %t/out.sil | %FileCheck %s --check-prefix=CHECK1 +// RUN: cat %t/out.sil | %FileCheck %s --check-prefix=CHECK2 +// RUN: cat %t/out.sil | %FileCheck %s --check-prefix=CHECK3 +// RUN: cat %t/out.sil | %FileCheck %s --check-prefix=CHECK4 + +import DifferentiationUnittest +import StdlibUnittest + +var AutoDiffClosureSpecMultiBBNoBTETests = TestSuite("AutoDiffClosureSpecMultiBBNoBTE") + +typealias FloatArrayTan = Array.TangentVector + +AutoDiffClosureSpecMultiBBNoBTETests.testWithLeakChecking("Test1") { + // CHECK1-LABEL: {{^}}// reverse-mode derivative of sumFirstThreeConcatenating1 #1 (_:_:) + // CHECK1-NEXT: sil private @$s3outyycfU_27sumFirstThreeConcatenating1L_ySfSaySfG_ACtFTJrSSpSr : $@convention(thin) (@guaranteed Array, @guaranteed Array) -> (Float, @owned @callee_guaranteed (Float) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView)) { + // CHECK1: %[[#E52:]] = function_ref @$s3outyycfU_27sumFirstThreeConcatenating1L_ySfSaySfG_ACtFTJpSSpSr055$sSfSa16_DifferentiationAA14DifferentiableRzlE0B4ViewVyg8_GIegno_G10AEIegyo_TRSfSa01_I0AE0K0RzlE0kO0VySf_GIegno_ADSfAIIegno_0f5Sf16_i26E7_vjpAdd3lhs3rhsSf5value_g17_SftSfc8pullbacktg1_y5FZSf_Y6SfcfU_ADSfAIIegno_AJTf1nnccccc_n0fh1_ijkl4E13_v32Subscript5indexx5value_SaA2aBRzlmnO59Vy13TangentVectorQz_GAIc8pullbacktSi_tFAKL_yAjiaBRzlFSf_TG5ACSiAkCSiAkCSiTf1nnccc_n : $@convention(thin) (Float, @owned @callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView), @owned Array, Int, @owned Array, Int, @owned Array, Int) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView) + // CHECK1: %[[#E53:]] = partial_apply [callee_guaranteed] %[[#E52]](%[[#]], %[[#]], %[[#]], %[[#]], %[[#]], %[[#]], %[[#]]) : $@convention(thin) (Float, @owned @callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView), @owned Array, Int, @owned Array, Int, @owned Array, Int) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView) + // CHECK1: %[[#E55:]] = tuple (%[[#]], %[[#E53]]) + // CHECK1: return %[[#E55]] + // CHECK1: } // end sil function '$s3outyycfU_27sumFirstThreeConcatenating1L_ySfSaySfG_ACtFTJrSSpSr' + + // CHECK1-NONE: {{^}}// pullback of sumFirstThreeConcatenating1 + // CHECK1: {{^}}// specialized pullback of sumFirstThreeConcatenating1 + // CHECK1: sil private @$s3outyycfU_27sumFirstThreeConcatenating1L_ySfSaySfG_ACtFTJpSSpSr055$sSfSa16_DifferentiationAA14DifferentiableRzlE0B4ViewVyg8_GIegno_G10AEIegyo_TRSfSa01_I0AE0K0RzlE0kO0VySf_GIegno_ADSfAIIegno_0f5Sf16_i26E7_vjpAdd3lhs3rhsSf5value_g17_SftSfc8pullbacktg1_y5FZSf_Y6SfcfU_ADSfAIIegno_AJTf1nnccccc_n0fh1_ijkl4E13_v32Subscript5indexx5value_SaA2aBRzlmnO59Vy13TangentVectorQz_GAIc8pullbacktSi_tFAKL_yAjiaBRzlFSf_TG5ACSiAkCSiAkCSiTf1nnccc_nTf4ngnnnnnn_n : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView), @owned Array, Int, @owned Array, Int, @owned Array, Int) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView) { + func sumFirstThreeConcatenating1(_ a: [Float], _ b: [Float]) -> Float { + let c = a + b + return c[0] + c[1] + c[2] + } + + expectEqual( + (.init([1, 1]), .init([1, 0])), + gradient(at: [0, 0], [0, 0], of: sumFirstThreeConcatenating1)) + expectEqual( + (.init([1, 1, 1, 0]), .init([0, 0])), + gradient(at: [0, 0, 0, 0], [0, 0], of: sumFirstThreeConcatenating1)) + expectEqual( + (.init([]), .init([1, 1, 1, 0])), + gradient(at: [], [0, 0, 0, 0], of: sumFirstThreeConcatenating1)) +} + +AutoDiffClosureSpecMultiBBNoBTETests.testWithLeakChecking("Test2") { + // CHECK2-LABEL: {{^}}// reverse-mode derivative of sumFirstThreeConcatenating2 #1 (_:_:) + // CHECK2-NEXT: sil private @$s3outyycfU0_27sumFirstThreeConcatenating2L_ySfSaySfG_ACtFTJrSSpSr : $@convention(thin) (@guaranteed Array, @guaranteed Array) -> (Float, @owned @callee_guaranteed (Float) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView)) { + // CHECK2: %[[#E52:]] = function_ref @$s3outyycfU0_27sumFirstThreeConcatenating2L_ySfSaySfG_ACtFTJpSSpSr055$sSfSa16_DifferentiationAA14DifferentiableRzlE0B4ViewVyg8_GIegno_G10AEIegyo_TRSfSa01_I0AE0K0RzlE0kO0VySf_GIegno_ADSfAIIegno_0f5Sf16_i26E7_vjpAdd3lhs3rhsSf5value_g17_SftSfc8pullbacktg1_y5FZSf_Y6SfcfU_ADSfAIIegno_AJTf1nnccccc_n0fh1_ijkl4E10_v25Appendyyt5value_SaA2aBRzlmno55Vy13TangentVectorQz_GAIzc8pullbacktSayxGz_AKtFZA2IzcfU_G4_Tg5Si0fh1_ijkl4E13_v32Subscript5indexx5value_SaA2aBRzlmnO59Vy13TangentVectorQz_GAIc8pullbacktSi_tFAKL_yAjiaBRzlFSf_TG5ACSiAlCSiAlCSiTf1ncccc_n : $@convention(thin) (Float, Int, @owned Array, Int, @owned Array, Int, @owned Array, Int) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView) + // CHECK2: %[[#E53:]] = partial_apply [callee_guaranteed] %[[#E52]](%[[#]], %[[#]], %[[#]], %[[#]], %[[#]], %[[#]], %[[#]]) : $@convention(thin) (Float, Int, @owned Array, Int, @owned Array, Int, @owned Array, Int) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView) + // CHECK2: %[[#E55:]] = tuple (%[[#]], %[[#E53]]) + // CHECK2: return %[[#E55]] + // CHECK2: } // end sil function '$s3outyycfU0_27sumFirstThreeConcatenating2L_ySfSaySfG_ACtFTJrSSpSr' + + // CHECK2-NONE: {{^}}// pullback of sumFirstThreeConcatenating2 + // CHECK2: {{^}}// specialized pullback of sumFirstThreeConcatenating2 + // CHECK2: sil private @$s3outyycfU0_27sumFirstThreeConcatenating2L_ySfSaySfG_ACtFTJpSSpSr055$sSfSa16_DifferentiationAA14DifferentiableRzlE0B4ViewVyg8_GIegno_G10AEIegyo_TRSfSa01_I0AE0K0RzlE0kO0VySf_GIegno_ADSfAIIegno_0f5Sf16_i26E7_vjpAdd3lhs3rhsSf5value_g17_SftSfc8pullbacktg1_y5FZSf_Y6SfcfU_ADSfAIIegno_AJTf1nnccccc_n0fh1_ijkl4E10_v25Appendyyt5value_SaA2aBRzlmno55Vy13TangentVectorQz_GAIzc8pullbacktSayxGz_AKtFZA2IzcfU_G4_Tg5Si0fh1_ijkl4E13_v32Subscript5indexx5value_SaA2aBRzlmnO59Vy13TangentVectorQz_GAIc8pullbacktSi_tFAKL_yAjiaBRzlFSf_TG5ACSiAlCSiAlCSiTf1ncccc_n : $@convention(thin) (Float, Int, @owned Array, Int, @owned Array, Int, @owned Array, Int) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView) { + + func sumFirstThreeConcatenating2(_ a: [Float], _ b: [Float]) -> Float { + var c = a + c += b + return c[0] + c[1] + c[2] + } + + expectEqual( + (.init([1, 1]), .init([1, 0])), + gradient(at: [0, 0], [0, 0], of: sumFirstThreeConcatenating2)) + expectEqual( + (.init([1, 1, 1, 0]), .init([0, 0])), + gradient(at: [0, 0, 0, 0], [0, 0], of: sumFirstThreeConcatenating2)) + expectEqual( + (.init([]), .init([1, 1, 1, 0])), + gradient(at: [], [0, 0, 0, 0], of: sumFirstThreeConcatenating2)) +} + +AutoDiffClosureSpecMultiBBNoBTETests.testWithLeakChecking("Test3") { + @propertyWrapper + enum Wrapper { + case case1(Float) + case case2(Float) + + init(wrappedValue: Float) { + self = .case1(wrappedValue) + } + + var wrappedValue: Float { + get { + switch self { + case .case1(let val): + return val + case .case2(let val): + return val * 2 + } + } + set { + self = .case2(wrappedValue) + } + } + } + + struct RealPropertyWrappers: Differentiable { + @Wrapper var x: Float = 3 + var y: Float = 4 + } + + // CHECK3: {{^}}// reverse-mode derivative of multiply #1 (_:) + // CHECK3-NEXT: sil private @$s3outyycfU1_8multiplyL_ySfAAyycfU1_20RealPropertyWrappersL_VFTJrSpSr : $@convention(thin) (RealPropertyWrappers) -> (Float, @owned @callee_guaranteed (Float) -> RealPropertyWrappers.TangentVector) { + // CHECK3: %[[#A22:]] = function_ref @$s3outyycfU1_8multiplyL_ySfAAyycfU1_20RealPropertyWrappersL_VFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktm1_n5FZSf_N6SfcfU_S2fTf1nnc_n015$s3outyycfU1_20cdE16L_V1xSfvgTJpSpSrTf1ncnn_n : $@convention(thin) (Float, Float, Float) -> RealPropertyWrappers.TangentVector + // CHECK3: %[[#A23:]] = partial_apply [callee_guaranteed] %[[#A22]](%[[#]], %[[#]]) : $@convention(thin) (Float, Float, Float) -> RealPropertyWrappers.TangentVector + // CHECK3: %[[#A24:]] = tuple (%[[#]], %[[#A23]]) + // CHECK3: return %[[#A24]] + // CHECK3: } // end sil function '$s3outyycfU1_8multiplyL_ySfAAyycfU1_20RealPropertyWrappersL_VFTJrSpSr' + + // CHECK3-NONE: {{^}}// pullback of multiply + // CHECK3: {{^}}// specialized pullback of multiply + // CHECK3: sil private @$s3outyycfU1_8multiplyL_ySfAAyycfU1_20RealPropertyWrappersL_VFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktm1_n5FZSf_N6SfcfU_S2fTf1nnc_n015$s3outyycfU1_20cdE16L_V1xSfvgTJpSpSrTf1ncnn_n : $@convention(thin) (Float, Float, Float) -> RealPropertyWrappers.TangentVector { + + @differentiable(reverse) + func multiply(_ s: RealPropertyWrappers) -> Float { + return s.x * s.y + } + + expectEqual( + .init(x: 4, y: 3), + gradient(at: RealPropertyWrappers(x: 3, y: 4), of: multiply)) +} + +AutoDiffClosureSpecMultiBBNoBTETests.testWithLeakChecking("Test4") { + struct Class: Differentiable { + var stored: Float + var optional: Float? + + init(stored: Float, optional: Float?) { + self.stored = stored + self.optional = optional + } + + @differentiable(reverse) + func method() -> Float { + let c: Class + do { + let tmp = Class(stored: 1 * stored, optional: optional) + let tuple = (tmp, tmp) + c = tuple.0 + } + if let x = c.optional { + return x * c.stored + } + return 1 * c.stored + } + } + + // CHECK4-LABEL: {{^}}// reverse-mode derivative of methodWrapper #1 (_:) + // CHECK4-NEXT: sil private @$s3outyycfU2_13methodWrapperL_ySfAAyycfU2_5ClassL_VFTJrSpSr : $@convention(thin) (Class) -> (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) { + // CHECK4: %[[#C39:]] = function_ref @$s3outyycfU2_13methodWrapperL_ySfAAyycfU2_5ClassL_VFTJpSpSr014$s3outyycfU2_5D21L_V6methodSfyFTJpSpSrAA05_AD__ef2_5d2L_gH24F_bb3__Pred__src_0_wrt_033_E588B908471A5F020CF23EC392ADD7D3LLOTf1nc_n : $@convention(thin) (Float, @owned _AD__$s3outyycfU2_5ClassL_V6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector + // CHECK4: %[[#C40:]] = partial_apply [callee_guaranteed] %[[#C39]](%[[#]]) : $@convention(thin) (Float, @owned _AD__$s3outyycfU2_5ClassL_V6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector + // CHECK4: %[[#C42:]] = tuple (%[[#]], %[[#C40]]) + // CHECK4: return %[[#C42]] + // CHECK4: } // end sil function '$s3outyycfU2_13methodWrapperL_ySfAAyycfU2_5ClassL_VFTJrSpSr' + + /// TODO: even though branch tracing enum is not passed to top-level pullback + /// directly, it is captured by one of the closures which was specialized. + /// Because of that, this enum argument is now an argument of specialized top-level + /// pullback. Specializing closures passed as payload tuple elements of the enum + /// is currently not supported. + + // CHECK4-NONE: {{^}}// pullback of methodWrapper + // CHECK4: {{^}}// specialized pullback of methodWrapper + // CHECK4: sil private @$s3outyycfU2_13methodWrapperL_ySfAAyycfU2_5ClassL_VFTJpSpSr014$s3outyycfU2_5D21L_V6methodSfyFTJpSpSrAA05_AD__ef2_5d2L_gH24F_bb3__Pred__src_0_wrt_033_E588B908471A5F020CF23EC392ADD7D3LLOTf1nc_n : $@convention(thin) (Float, @owned _AD__$s3outyycfU2_5ClassL_V6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector { + + @differentiable(reverse) + func methodWrapper(_ x: Class) -> Float { + x.method() + } + + expectEqual( + valueWithGradient(at: Class(stored: 3, optional: 4), of: methodWrapper), + (12, .init(stored: 4, optional: .init(3)))) + expectEqual( + valueWithGradient(at: Class(stored: 3, optional: nil), of: methodWrapper), + (3, .init(stored: 1, optional: .init(0)))) +} + +runAllTests() diff --git a/test/AutoDiff/validation-test/closure_specialization/single_bb1.swift b/test/AutoDiff/validation-test/closure_specialization/single_bb1.swift new file mode 100644 index 0000000000000..3c15e18ff1a35 --- /dev/null +++ b/test/AutoDiff/validation-test/closure_specialization/single_bb1.swift @@ -0,0 +1,115 @@ +/// Single basic block VJP. + +// REQUIRES: executable_test +// UNSUPPORTED: OS=windows-msvc + +// RUN: %empty-directory(%t) +// RUN: %target-build-swift %s -o %t/none.out -Onone +// RUN: %target-build-swift %s -o %t/opt.out -O +// RUN: %target-run %t/none.out +// RUN: %target-run %t/opt.out + +// RUN: %target-swift-frontend -emit-sil %s -O -o %t/out.sil +// RUN: cat %t/out.sil | %FileCheck %s --check-prefix=CHECK1 +// RUN: cat %t/out.sil | %FileCheck %s --check-prefix=CHECK2 +// RUN: cat %t/out.sil | %FileCheck %s --check-prefix=CHECK3 + +import DifferentiationUnittest +import StdlibUnittest + +#if canImport(Glibc) +import Glibc +#elseif canImport(Android) +import Android +#else +import Foundation +#endif + +var AutoDiffClosureSpecSingleBBTests = TestSuite("AutoDiffClosureSpecSingleBB") + +AutoDiffClosureSpecSingleBBTests.testWithLeakChecking("Test1") { + // CHECK1-LABEL: {{^}}// reverse-mode derivative of test1 #1 (_:) + // CHECK1-NEXT: sil private @$s3outyycfU_5test1L_yS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { + // CHECK1: %[[#A10:]] = function_ref @$s3outyycfU_5test1L_yS2fFTJpSpSr62$s16_Differentiation7_vjpSinySf5value_S2fc8pullbacktSfFS2fcfU_Sf0c1_d1_e4Cosyg1_hiJ2U_Sf026$sSf16_DifferentiationE12_e16Multiply3lhs3rhsg1_i17_SftSfc8pullbackti1_q5FZSf_Q6SfcfU_S2fTf1nccc_n : $@convention(thin) (Float, Float, Float, Float, Float) -> Float + // CHECK1: %[[#A11:]] = partial_apply [callee_guaranteed] %[[#A10]](%[[#]], %[[#]], %[[#]], %[[#]]) : $@convention(thin) (Float, Float, Float, Float, Float) -> Float + // CHECK1: %[[#A12:]] = tuple (%[[#]], %[[#A11]]) + // CHECK1: return %[[#A12]] + // CHECK1: } // end sil function '$s3outyycfU_5test1L_yS2fFTJrSpSr' + + // CHECK1-NONE: {{^}}// pullback of test1 #1 (_:) + // CHECK1: {{^}}// specialized pullback of test1 #1 (_:) + // CHECK1: sil private @$s3outyycfU_5test1L_yS2fFTJpSpSr62$s16_Differentiation7_vjpSinySf5value_S2fc8pullbacktSfFS2fcfU_Sf0c1_d1_e4Cosyg1_hiJ2U_Sf026$sSf16_DifferentiationE12_e16Multiply3lhs3rhsg1_i17_SftSfc8pullbackti1_q5FZSf_Q6SfcfU_S2fTf1nccc_n : $@convention(thin) (Float, Float, Float, Float, Float) -> Float { + + @differentiable(reverse) + func test1(_ x: Float) -> Float { + return sin(x) * cos(x) + } + + func test1Derivative(_ x: Float) -> Float { + return cos(x) * cos(x) - sin(x) * sin(x) + } + + for x in -100...100 { + expectEqual((1000 * gradient(at: Float(x), of: test1)).rounded(), (1000 * test1Derivative(Float(x))).rounded()) + } +} + +AutoDiffClosureSpecSingleBBTests.testWithLeakChecking("Test2") { + // CHECK2-LABEL: {{^}}// reverse-mode derivative of test2 #1 (_:) + // CHECK2-NEXT: sil private @$s3outyycfU0_5test2L_yS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { + // CHECK2: %[[#B19:]] = function_ref @$s3outyycfU0_5test2L_yS2fFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktj1_k5FZSf_K6SfcfU_S2f022$s16_Differentiation7_g4Sinyi15_S2fc8pullbacktJ8FS2fcfU_SfADSf0o1_p1_g4Cosyi1_rjS2U_SfACS2fAESf0cd1_e3E7_g11Add3lhs3rhsi1_j1_klj1_km1_kN2U_Tf1nccccccc_n : $@convention(thin) (Float, Float, Float, Float, Float, Float, Float, Float, Float) -> Float + // CHECK2: %[[#B20:]] = partial_apply [callee_guaranteed] %[[#B19]](%[[#]], %[[#]], %[[#]], %[[#]], %[[#]], %[[#]], %[[#]], %[[#]]) : $@convention(thin) (Float, Float, Float, Float, Float, Float, Float, Float, Float) -> Float + // CHECK2: %[[#B21:]] = tuple (%[[#]], %[[#B20]]) + // CHECK2: return %[[#B21]] + // CHECK2: } // end sil function '$s3outyycfU0_5test2L_yS2fFTJrSpSr' + + // CHECK2-NONE: {{^}}// pullback of test2 #1 (_:) + // CHECK2: {{^}}// specialized pullback of test2 #1 (_:) + // CHECK2: sil private @$s3outyycfU0_5test2L_yS2fFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktj1_k5FZSf_K6SfcfU_S2f022$s16_Differentiation7_g4Sinyi15_S2fc8pullbacktJ8FS2fcfU_SfADSf0o1_p1_g4Cosyi1_rjS2U_SfACS2fAESf0cd1_e3E7_g11Add3lhs3rhsi1_j1_klj1_km1_kN2U_Tf1nccccccc_n : $@convention(thin) (Float, Float, Float, Float, Float, Float, Float, Float, Float) -> Float { + + @differentiable(reverse) + func test2(_ x: Float) -> Float { + return sin(37 * x) * cos(sin(x)) + cos(x) + } + + func test2Derivative(_ x: Float) -> Float { + return -cos(x)*sin(37*x)*sin(sin(x)) + 37*cos(37*x)*cos(sin(x)) - sin(x) + } + + for x in -100...100 { + expectEqual((1000 * gradient(at: Float(x), of: test2)).rounded(), (1000 * test2Derivative(Float(x))).rounded()) + } +} + +AutoDiffClosureSpecSingleBBTests.testWithLeakChecking("Test3") { + // CHECK3-LABEL: {{^}}// reverse-mode derivative of test3 #1 (_:_:_:) + // CHECK3-NEXT: sil private @$s3outyycfU1_5test3L_yS2f_S2ftFTJrSSSpSr : $@convention(thin) (Float, Float, Float) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float, Float)) { + // CHECK3: %[[#C18:]] = function_ref @$s3outyycfU1_5test3L_yS2f_S2ftFTJpSSSpSr62$s16_Differentiation7_vjpSinySf5value_S2fc8pullbacktSfFS2fcfU_Sf0c1_d1_e4Cosyg1_hiJ2U_Sf025$sSf16_DifferentiationE7_e11Add3lhs3rhsg1_i17_SftSfc8pullbackti1_q5FZSf_Q6SfcfU_0c1_d1_e4Tanyg1_hiJ2U_Sf0lm1_n4E12_e16Subtract3lhs3rhsg1_i1_qri1_qs1_qT2U_Tf1nccccc_n : $@convention(thin) (Float, Float, Float, Float) -> (Float, Float, Float) + // CHECK3: %[[#C19:]] = partial_apply [callee_guaranteed] %[[#C18]](%[[#]], %[[#]], %[[#]]) : $@convention(thin) (Float, Float, Float, Float) -> (Float, Float, Float) + // CHECK3: %[[#C20:]] = tuple (%[[#]], %[[#C19]]) + // CHECK3: return %[[#C20]] + // CHECK3: } // end sil function '$s3outyycfU1_5test3L_yS2f_S2ftFTJrSSSpSr' + + // CHECK3-NONE: {{^}}// pullback of test3 #1 (_:_:_:) + // CHECK3: {{^}}// specialized pullback of test3 #1 (_:_:_:) + // CHECK3: sil private @$s3outyycfU1_5test3L_yS2f_S2ftFTJpSSSpSr62$s16_Differentiation7_vjpSinySf5value_S2fc8pullbacktSfFS2fcfU_Sf0c1_d1_e4Cosyg1_hiJ2U_Sf025$sSf16_DifferentiationE7_e11Add3lhs3rhsg1_i17_SftSfc8pullbackti1_q5FZSf_Q6SfcfU_0c1_d1_e4Tanyg1_hiJ2U_Sf0lm1_n4E12_e16Subtract3lhs3rhsg1_i1_qri1_qs1_qT2U_Tf1nccccc_n : $@convention(thin) (Float, Float, Float, Float) -> (Float, Float, Float) { + + @differentiable(reverse) + func test3(_ x: Float, _ y: Float, _ z: Float) -> Float { + return sin(x) + cos(y) - tan(z) + } + + for x in -5...5 { + for y in -5...5 { + for z in -5...5 { + let pb = pullback(at: Float(x), Float(y), Float(z), of: test3) + let (der1, der2, der3) = pb(1) + expectEqual((10000 * der1).rounded(), (10000 * cos(Float(x))).rounded()) + expectEqual((10000 * der2).rounded(), (10000 * (-sin(Float(y)))).rounded()) + expectEqual((10000 * der3).rounded(), (10000 * (-(tan(Float(z)) * tan(Float(z)) + 1))).rounded()) + } + } + } +} + +runAllTests() diff --git a/test/AutoDiff/validation-test/closure_specialization/single_bb2.swift b/test/AutoDiff/validation-test/closure_specialization/single_bb2.swift new file mode 100644 index 0000000000000..d5dae828a5d0c --- /dev/null +++ b/test/AutoDiff/validation-test/closure_specialization/single_bb2.swift @@ -0,0 +1,85 @@ +/// Single basic block VJP. + +// REQUIRES: executable_test + +// RUN: %empty-directory(%t) +// RUN: %target-build-swift %s -o %t/none.out -Onone +// RUN: %target-build-swift %s -o %t/opt.out -O +// RUN: %target-run %t/none.out +// RUN: %target-run %t/opt.out + +// RUN: %target-swift-frontend -emit-sil %s -O -o %t/out.sil +// RUN: cat %t/out.sil | %FileCheck %s --check-prefix=CHECK4 +// RUN: cat %t/out.sil | %FileCheck %s --check-prefix=CHECK5 + +import DifferentiationUnittest +import StdlibUnittest + +var AutoDiffClosureSpecSingleBBTests = TestSuite("AutoDiffClosureSpecSingleBB") + +AutoDiffClosureSpecSingleBBTests.testWithLeakChecking("Test4") { + func square(_ x: Float) -> Float { + return x * x + } + + func double(_ x: Float) -> Float { + return x + x + } + + // CHECK4-LABEL: {{^}}// reverse-mode derivative of test4 #1 (_:) + // CHECK4-NEXT: sil private @$s3outyycfU_5test4L_yS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { + // CHECK4: %[[#D9:]] = function_ref @$s3outyycfU_5test4L_yS2fFTJpSpSr128$s3outyycfU_6doubleL_yS2fFTJpSpSr067$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktj1_k5FZSf_K6SfcfU_Tf1nc_n0c12U_6squareL_yefg7Sr073$si1_j4E12_l16Multiply3lhs3rhsn1_o1_pq1_rs1_tu2U_eV2_nS2fTf1ncc_n : $@convention(thin) (Float, Float, Float) -> Float + // CHECK4: %[[#D10:]] = partial_apply [callee_guaranteed] %[[#D9]](%[[#]], %[[#]]) : $@convention(thin) (Float, Float, Float) -> Float + // CHECK4: %[[#D11:]] = tuple (%[[#]], %[[#D10]]) + // CHECK4: return %[[#D11]] + // CHECK4: } // end sil function '$s3outyycfU_5test4L_yS2fFTJrSpSr' + + // CHECK4-NONE: {{^}}// pullback of test4 #1 (_:) + // CHECK4: {{^}}// specialized pullback of test4 #1 (_:) + // CHECK4: sil private @$s3outyycfU_5test4L_yS2fFTJpSpSr128$s3outyycfU_6doubleL_yS2fFTJpSpSr067$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktj1_k5FZSf_K6SfcfU_Tf1nc_n0c12U_6squareL_yefg7Sr073$si1_j4E12_l16Multiply3lhs3rhsn1_o1_pq1_rs1_tu2U_eV2_nS2fTf1ncc_n : $@convention(thin) (Float, Float, Float) -> Float { + + @differentiable(reverse) + func test4(_ x: Float) -> Float { + return square(double(x)) + } + + func test4Derivative(_ x: Float) -> Float { + return 8 * x + } + + for x in -100...100 { + expectEqual(gradient(at: Float(x), of: test4), test4Derivative(Float(x))) + } +} + +AutoDiffClosureSpecSingleBBTests.testWithLeakChecking("Test5") { + // CHECK5-LABEL: {{^}}// reverse-mode derivative of test5 #1 (_:_:) + // CHECK5-NEXT: sil private @$s3outyycfU0_5test5L_ySaySfGAC_SftFTJrSSpSr : $@convention(thin) (@guaranteed Array, Float) -> (@owned Array, @owned @callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, Float)) { + // CHECK5: %[[#E79:]] = function_ref @$s3outyycfU0_5test5L_ySaySfGAC_SftFTJpSSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktj1_k5FZSf_K6SfcfU_S2f0cd1_ef1_g16Subtract3lhs3rhsi1_j1_klj1_km1_kN2U_Tf1ncncn_n : $@convention(thin) (@guaranteed Array.DifferentiableView, @owned @callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView), @owned @callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView), Float, Float) -> (@owned Array.DifferentiableView, Float) + // CHECK5: %[[#E80:]] = partial_apply [callee_guaranteed] %[[#E79]](%[[#]], %[[#]], %[[#]], %[[#]]) : $@convention(thin) (@guaranteed Array.DifferentiableView, @owned @callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView), @owned @callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView), Float, Float) -> (@owned Array.DifferentiableView, Float) + // CHECK5: %[[#E81:]] = tuple (%[[#]], %[[#E80]]) + // CHECK5: return %[[#E81]] + // CHECK5: } // end sil function '$s3outyycfU0_5test5L_ySaySfGAC_SftFTJrSSpSr' + + // CHECK5-NONE: {{^}}// pullback of test5 #1 (_:_:) + // CHECK5: {{^}}// specialized pullback of test5 #1 (_:_:) + // CHECK5: sil private @$s3outyycfU0_5test5L_ySaySfGAC_SftFTJpSSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktj1_k5FZSf_K6SfcfU_S2f0cd1_ef1_g16Subtract3lhs3rhsi1_j1_klj1_km1_kN2U_Tf1ncncn_n : $@convention(thin) (@guaranteed Array.DifferentiableView, @owned @callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView), @owned @callee_guaranteed (@guaranteed Array.DifferentiableView) -> (@owned Array.DifferentiableView, @owned Array.DifferentiableView), Float, Float) -> (@owned Array.DifferentiableView, Float) { + + @differentiable(reverse) + func test5(_ x: [Float], _ y: Float) -> [Float] { + return [42 * y] + x + [37 - y] + } + + let pb = pullback(at: [Float(1), Float(1)], Float(1), of: test5) + for a in -10...10 { + for b in -10...10 { + for c in -10...10 { + for d in -10...10 { + expectEqual(pb([Float(a), Float(b), Float(c), Float(d)]), ([Float(b), Float(c)], Float(42 * a - d))) + } + } + } + } +} + +runAllTests()