From f10356a5b44a026366f5cc0f3d5e265ef97b1790 Mon Sep 17 00:00:00 2001 From: Anton Korobeynikov Date: Tue, 7 Mar 2023 18:31:30 -0800 Subject: [PATCH 1/2] Ensure autodiff code does not ignore `getSingleTerminatorOperands` return value. Extend activity analysis to handle `try_apply` normal result properly. --- include/swift/SIL/SILArgument.h | 8 +- .../DifferentiableActivityAnalysis.cpp | 18 +++- .../Differentiation/PullbackCloner.cpp | 90 ++++++++++--------- 3 files changed, 64 insertions(+), 52 deletions(-) diff --git a/include/swift/SIL/SILArgument.h b/include/swift/SIL/SILArgument.h index 9d48d7f388ada..d65cc238ce8f5 100644 --- a/include/swift/SIL/SILArgument.h +++ b/include/swift/SIL/SILArgument.h @@ -178,7 +178,7 @@ class SILArgument : public ValueBase { /// Note: this peeks through any projections or cast implied by the /// terminator. e.g. the incoming value for a switch_enum payload argument is /// the enum itself (the operand of the switch_enum). - bool getSingleTerminatorOperands( + [[nodiscard]] bool getSingleTerminatorOperands( SmallVectorImpl &returnedSingleTermOperands) const; /// Returns true if we were able to find single terminator operand values for @@ -188,7 +188,7 @@ class SILArgument : public ValueBase { /// Note: this peeks through any projections or cast implied by the /// terminator. e.g. the incoming value for a switch_enum payload argument is /// the enum itself (the operand of the switch_enum). - bool getSingleTerminatorOperands( + [[nodiscard]] bool getSingleTerminatorOperands( SmallVectorImpl> &returnedSingleTermOperands) const; @@ -303,7 +303,7 @@ class SILPhiArgument : public SILArgument { /// Note: this peeks through any projections or cast implied by the /// terminator. e.g. the incoming value for a switch_enum payload argument is /// the enum itself (the operand of the switch_enum). - bool getSingleTerminatorOperands( + [[nodiscard]] bool getSingleTerminatorOperands( SmallVectorImpl &returnedSingleTermOperands) const; /// Returns true if we were able to find single terminator operand values for @@ -313,7 +313,7 @@ class SILPhiArgument : public SILArgument { /// Note: this peeks through any projections or cast implied by the /// terminator. e.g. the incoming value for a switch_enum payload argument is /// the enum itself (the operand of the switch_enum). - bool getSingleTerminatorOperands( + [[nodiscard]] bool getSingleTerminatorOperands( SmallVectorImpl> &returnedSingleTermOperands) const; diff --git a/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp b/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp index ed41076de53ac..9e2d24a4b3b0e 100644 --- a/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp +++ b/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp @@ -303,15 +303,25 @@ void DifferentiableActivityInfo::setUsefulAndPropagateToOperands( return; } setUseful(value, dependentVariableIndex); + // If the given value is a basic block argument, propagate usefulness to // incoming values. if (auto *bbArg = dyn_cast(value)) { SmallVector incomingValues; - bbArg->getSingleTerminatorOperands(incomingValues); - for (auto incomingValue : incomingValues) - setUsefulAndPropagateToOperands(incomingValue, dependentVariableIndex); - return; + if (bbArg->getSingleTerminatorOperands(incomingValues)) { + for (auto incomingValue : incomingValues) + setUsefulAndPropagateToOperands(incomingValue, dependentVariableIndex); + return; + } else if (bbArg->isTerminatorResult()) { + if (TryApplyInst *tai = dyn_cast(bbArg->getTerminatorForResult())) { + propagateUseful(tai, dependentVariableIndex); + return; + } else + llvm::report_fatal_error("unknown terminator with result"); + } else + llvm::report_fatal_error("do not know how to handle this incoming bb argument"); } + auto *inst = value->getDefiningInstruction(); if (!inst) return; diff --git a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp index 9bcb296c0231f..fc80726fc7489 100644 --- a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp +++ b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp @@ -2515,12 +2515,11 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) { // Get predecessor terminator operands. SmallVector, 4> incomingValues; - bbArg->getSingleTerminatorOperands(incomingValues); - - // Returns true if the given terminator instruction is a `switch_enum` on - // an `Optional`-typed value. `switch_enum` instructions require - // special-case adjoint value propagation for the operand. - auto isSwitchEnumInstOnOptional = + if (bbArg->getSingleTerminatorOperands(incomingValues)) { + // Returns true if the given terminator instruction is a `switch_enum` on + // an `Optional`-typed value. `switch_enum` instructions require + // special-case adjoint value propagation for the operand. + auto isSwitchEnumInstOnOptional = [&ctx = getASTContext()](TermInst *termInst) { if (!termInst) return false; @@ -2531,49 +2530,52 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) { return false; }; - // Check the tangent value category of the active basic block argument. - switch (getTangentValueCategory(bbArg)) { - // If argument has a loadable tangent value category: materialize adjoint - // value of the argument, create a copy, and set the copy as the adjoint - // value of incoming values. - case SILValueCategory::Object: { - auto bbArgAdj = getAdjointValue(bb, bbArg); - auto concreteBBArgAdj = materializeAdjointDirect(bbArgAdj, pbLoc); - auto concreteBBArgAdjCopy = + // Check the tangent value category of the active basic block argument. + switch (getTangentValueCategory(bbArg)) { + // If argument has a loadable tangent value category: materialize adjoint + // value of the argument, create a copy, and set the copy as the adjoint + // value of incoming values. + case SILValueCategory::Object: { + auto bbArgAdj = getAdjointValue(bb, bbArg); + auto concreteBBArgAdj = materializeAdjointDirect(bbArgAdj, pbLoc); + auto concreteBBArgAdjCopy = builder.emitCopyValueOperation(pbLoc, concreteBBArgAdj); - for (auto pair : incomingValues) { - auto *predBB = std::get<0>(pair); - auto incomingValue = std::get<1>(pair); - // Handle `switch_enum` on `Optional`. - auto termInst = bbArg->getSingleTerminator(); - if (isSwitchEnumInstOnOptional(termInst)) { - accumulateAdjointForOptional(bb, incomingValue, concreteBBArgAdjCopy); - } else { - blockTemporaries[getPullbackBlock(predBB)].insert( + for (auto pair : incomingValues) { + pair.second->dump(); + auto *predBB = std::get<0>(pair); + auto incomingValue = std::get<1>(pair); + // Handle `switch_enum` on `Optional`. + auto termInst = bbArg->getSingleTerminator(); + if (isSwitchEnumInstOnOptional(termInst)) { + accumulateAdjointForOptional(bb, incomingValue, concreteBBArgAdjCopy); + } else { + blockTemporaries[getPullbackBlock(predBB)].insert( concreteBBArgAdjCopy); - setAdjointValue(predBB, incomingValue, - makeConcreteAdjointValue(concreteBBArgAdjCopy)); + setAdjointValue(predBB, incomingValue, + makeConcreteAdjointValue(concreteBBArgAdjCopy)); + } } + break; } - break; - } - // If argument has an address tangent value category: materialize adjoint - // value of the argument, create a copy, and set the copy as the adjoint - // value of incoming values. - case SILValueCategory::Address: { - auto bbArgAdjBuf = getAdjointBuffer(bb, bbArg); - for (auto pair : incomingValues) { - auto incomingValue = std::get<1>(pair); - // Handle `switch_enum` on `Optional`. - auto termInst = bbArg->getSingleTerminator(); - if (isSwitchEnumInstOnOptional(termInst)) - accumulateAdjointForOptional(bb, incomingValue, bbArgAdjBuf); - else - addToAdjointBuffer(bb, incomingValue, bbArgAdjBuf, pbLoc); + // If argument has an address tangent value category: materialize adjoint + // value of the argument, create a copy, and set the copy as the adjoint + // value of incoming values. + case SILValueCategory::Address: { + auto bbArgAdjBuf = getAdjointBuffer(bb, bbArg); + for (auto pair : incomingValues) { + auto incomingValue = std::get<1>(pair); + // Handle `switch_enum` on `Optional`. + auto termInst = bbArg->getSingleTerminator(); + if (isSwitchEnumInstOnOptional(termInst)) + accumulateAdjointForOptional(bb, incomingValue, bbArgAdjBuf); + else + addToAdjointBuffer(bb, incomingValue, bbArgAdjBuf, pbLoc); + } + break; } - break; - } - } + } + } else + llvm::report_fatal_error("do not know how to handle this incoming bb argument"); } // 3. Build the pullback successor cases for the `switch_enum` From 9c80f9c017521534a565744bacc679f7158413df Mon Sep 17 00:00:00 2001 From: Anton Korobeynikov Date: Tue, 7 Mar 2023 22:01:34 -0800 Subject: [PATCH 2/2] Add testcase from #63728 --- .../issue-63728-try-apply-activity.swift | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 test/AutoDiff/compiler_crashers_fixed/issue-63728-try-apply-activity.swift diff --git a/test/AutoDiff/compiler_crashers_fixed/issue-63728-try-apply-activity.swift b/test/AutoDiff/compiler_crashers_fixed/issue-63728-try-apply-activity.swift new file mode 100644 index 0000000000000..1c1b4df393560 --- /dev/null +++ b/test/AutoDiff/compiler_crashers_fixed/issue-63728-try-apply-activity.swift @@ -0,0 +1,59 @@ +// RUN: not %target-swift-frontend -emit-sil -verify %s + +// The testcase from https://github.com/apple/swift/issues/63728 is not valid +// (the function is not differentiable), however, it should not cause verifier errors +// Here the root case is lack of activity analysis for `try_apply` terminators + +import _Differentiation + +func a() throws { + let keyPaths = (readable: [String: KeyPath](), writable: [String: WritableKeyPath]()) + @differentiable(reverse) + func f(p: PAndT) -> Double { + var mutableP = p + let s = p.p.e + var sArray: [[Double]] = [] + sArray.append((s["a"]!.asArray()).map {$0.value}) + mutableP.s = w(mutableP.s, at: keyPaths.writable["a"]!, with: sArray[0][0]) + return mutableP.s[keyPath: keyPaths.writable["a"]!] + } +} + +public struct S { + public func asArray() -> [(index: I, value: D)] { + return [(index: I, value: D)]() + } +} +struct T: Differentiable {} +struct P: Differentiable { + public var e: F +} + +struct PAndT: Differentiable{ + @differentiable(reverse) public var p: P + @differentiable(reverse) public var s: T +} + +public struct F{ + public func asArray() -> [(index: I, value: D)] {return [(index: I, value: D)]()} + public subscript(_ name: String) -> S? {get { return self.sGet(name) }} + func sGet(_ name: String) -> S? { fatalError("") } +} + +public protocol ZProtocol: Differentiable {var z: () -> TangentVector { get }} +public protocol SProtocol: Hashable {} + +//extension S: Differentiable where D: ZProtocol, D.TangentVector == D {} + +extension F: Differentiable where D: ZProtocol, D.TangentVector == D {} +public extension ZProtocol {var z: () -> TangentVector {{ Self.TangentVector.zero }}} + +extension F: Equatable where I: Equatable, D: Equatable {} +//extension S: Equatable where I: Equatable, D: Equatable {} +extension Double: SProtocol, ZProtocol {} + +@differentiable(reverse where O: Differentiable, M: ZProtocol) +func w(_ o: O, at m: WritableKeyPath, with v: M) -> O {return o} + +@derivative(of: w) +func vjpw(_ o: O, at m: WritableKeyPath, with v: M) -> (value: O, pullback: (O.TangentVector) -> (O.TangentVector, M.TangentVector)) where O: Differentiable, M: ZProtocol{fatalError("")}