Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions include/swift/SIL/SILArgument.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ class SILArgument : public ValueBase {
/// Note: this peeks through any projections or cast implied by the
/// terminator. e.g. the incoming value for a switch_enum payload argument is
/// the enum itself (the operand of the switch_enum).
bool getSingleTerminatorOperands(
[[nodiscard]] bool getSingleTerminatorOperands(
SmallVectorImpl<SILValue> &returnedSingleTermOperands) const;

/// Returns true if we were able to find single terminator operand values for
Expand All @@ -188,7 +188,7 @@ class SILArgument : public ValueBase {
/// Note: this peeks through any projections or cast implied by the
/// terminator. e.g. the incoming value for a switch_enum payload argument is
/// the enum itself (the operand of the switch_enum).
bool getSingleTerminatorOperands(
[[nodiscard]] bool getSingleTerminatorOperands(
SmallVectorImpl<std::pair<SILBasicBlock *, SILValue>>
&returnedSingleTermOperands) const;

Expand Down Expand Up @@ -303,7 +303,7 @@ class SILPhiArgument : public SILArgument {
/// Note: this peeks through any projections or cast implied by the
/// terminator. e.g. the incoming value for a switch_enum payload argument is
/// the enum itself (the operand of the switch_enum).
bool getSingleTerminatorOperands(
[[nodiscard]] bool getSingleTerminatorOperands(
SmallVectorImpl<SILValue> &returnedSingleTermOperands) const;

/// Returns true if we were able to find single terminator operand values for
Expand All @@ -313,7 +313,7 @@ class SILPhiArgument : public SILArgument {
/// Note: this peeks through any projections or cast implied by the
/// terminator. e.g. the incoming value for a switch_enum payload argument is
/// the enum itself (the operand of the switch_enum).
bool getSingleTerminatorOperands(
[[nodiscard]] bool getSingleTerminatorOperands(
SmallVectorImpl<std::pair<SILBasicBlock *, SILValue>>
&returnedSingleTermOperands) const;

Expand Down
18 changes: 14 additions & 4 deletions lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,15 +303,25 @@ void DifferentiableActivityInfo::setUsefulAndPropagateToOperands(
return;
}
setUseful(value, dependentVariableIndex);

// If the given value is a basic block argument, propagate usefulness to
// incoming values.
if (auto *bbArg = dyn_cast<SILPhiArgument>(value)) {
SmallVector<SILValue, 4> incomingValues;
bbArg->getSingleTerminatorOperands(incomingValues);
for (auto incomingValue : incomingValues)
setUsefulAndPropagateToOperands(incomingValue, dependentVariableIndex);
return;
if (bbArg->getSingleTerminatorOperands(incomingValues)) {
for (auto incomingValue : incomingValues)
setUsefulAndPropagateToOperands(incomingValue, dependentVariableIndex);
return;
} else if (bbArg->isTerminatorResult()) {
if (TryApplyInst *tai = dyn_cast<TryApplyInst>(bbArg->getTerminatorForResult())) {
propagateUseful(tai, dependentVariableIndex);
return;
} else
llvm::report_fatal_error("unknown terminator with result");
} else
llvm::report_fatal_error("do not know how to handle this incoming bb argument");
}

auto *inst = value->getDefiningInstruction();
if (!inst)
return;
Expand Down
90 changes: 46 additions & 44 deletions lib/SILOptimizer/Differentiation/PullbackCloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2515,12 +2515,11 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {

// Get predecessor terminator operands.
SmallVector<std::pair<SILBasicBlock *, SILValue>, 4> incomingValues;
bbArg->getSingleTerminatorOperands(incomingValues);

// Returns true if the given terminator instruction is a `switch_enum` on
// an `Optional`-typed value. `switch_enum` instructions require
// special-case adjoint value propagation for the operand.
auto isSwitchEnumInstOnOptional =
if (bbArg->getSingleTerminatorOperands(incomingValues)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we do an early return here to reduce nesting?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems we'd need to eventually handle else case since it at least covers try_apply instruction among the others. The fix for activity calculation was pretty straightforward, but I will need to think how to proceed with adjoint propagation. I do not have a testcase for this right now, but hopefully @BradLarson will be able to distill some from their codebase.

For now we just do not silently generate incorrect code :)

// Returns true if the given terminator instruction is a `switch_enum` on
// an `Optional`-typed value. `switch_enum` instructions require
// special-case adjoint value propagation for the operand.
auto isSwitchEnumInstOnOptional =
[&ctx = getASTContext()](TermInst *termInst) {
if (!termInst)
return false;
Expand All @@ -2531,49 +2530,52 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
return false;
};

// Check the tangent value category of the active basic block argument.
switch (getTangentValueCategory(bbArg)) {
// If argument has a loadable tangent value category: materialize adjoint
// value of the argument, create a copy, and set the copy as the adjoint
// value of incoming values.
case SILValueCategory::Object: {
auto bbArgAdj = getAdjointValue(bb, bbArg);
auto concreteBBArgAdj = materializeAdjointDirect(bbArgAdj, pbLoc);
auto concreteBBArgAdjCopy =
// Check the tangent value category of the active basic block argument.
switch (getTangentValueCategory(bbArg)) {
// If argument has a loadable tangent value category: materialize adjoint
// value of the argument, create a copy, and set the copy as the adjoint
// value of incoming values.
case SILValueCategory::Object: {
auto bbArgAdj = getAdjointValue(bb, bbArg);
auto concreteBBArgAdj = materializeAdjointDirect(bbArgAdj, pbLoc);
auto concreteBBArgAdjCopy =
builder.emitCopyValueOperation(pbLoc, concreteBBArgAdj);
for (auto pair : incomingValues) {
auto *predBB = std::get<0>(pair);
auto incomingValue = std::get<1>(pair);
// Handle `switch_enum` on `Optional`.
auto termInst = bbArg->getSingleTerminator();
if (isSwitchEnumInstOnOptional(termInst)) {
accumulateAdjointForOptional(bb, incomingValue, concreteBBArgAdjCopy);
} else {
blockTemporaries[getPullbackBlock(predBB)].insert(
for (auto pair : incomingValues) {
pair.second->dump();
auto *predBB = std::get<0>(pair);
auto incomingValue = std::get<1>(pair);
// Handle `switch_enum` on `Optional`.
auto termInst = bbArg->getSingleTerminator();
if (isSwitchEnumInstOnOptional(termInst)) {
accumulateAdjointForOptional(bb, incomingValue, concreteBBArgAdjCopy);
} else {
blockTemporaries[getPullbackBlock(predBB)].insert(
concreteBBArgAdjCopy);
setAdjointValue(predBB, incomingValue,
makeConcreteAdjointValue(concreteBBArgAdjCopy));
setAdjointValue(predBB, incomingValue,
makeConcreteAdjointValue(concreteBBArgAdjCopy));
}
}
break;
}
break;
}
// If argument has an address tangent value category: materialize adjoint
// value of the argument, create a copy, and set the copy as the adjoint
// value of incoming values.
case SILValueCategory::Address: {
auto bbArgAdjBuf = getAdjointBuffer(bb, bbArg);
for (auto pair : incomingValues) {
auto incomingValue = std::get<1>(pair);
// Handle `switch_enum` on `Optional`.
auto termInst = bbArg->getSingleTerminator();
if (isSwitchEnumInstOnOptional(termInst))
accumulateAdjointForOptional(bb, incomingValue, bbArgAdjBuf);
else
addToAdjointBuffer(bb, incomingValue, bbArgAdjBuf, pbLoc);
// If argument has an address tangent value category: materialize adjoint
// value of the argument, create a copy, and set the copy as the adjoint
// value of incoming values.
case SILValueCategory::Address: {
auto bbArgAdjBuf = getAdjointBuffer(bb, bbArg);
for (auto pair : incomingValues) {
auto incomingValue = std::get<1>(pair);
// Handle `switch_enum` on `Optional`.
auto termInst = bbArg->getSingleTerminator();
if (isSwitchEnumInstOnOptional(termInst))
accumulateAdjointForOptional(bb, incomingValue, bbArgAdjBuf);
else
addToAdjointBuffer(bb, incomingValue, bbArgAdjBuf, pbLoc);
}
break;
}
break;
}
}
}
} else
llvm::report_fatal_error("do not know how to handle this incoming bb argument");
}

// 3. Build the pullback successor cases for the `switch_enum`
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// RUN: not %target-swift-frontend -emit-sil -verify %s

// The testcase from https://github.com/apple/swift/issues/63728 is not valid
// (the function is not differentiable), however, it should not cause verifier errors
// Here the root case is lack of activity analysis for `try_apply` terminators

import _Differentiation

func a() throws {
let keyPaths = (readable: [String: KeyPath<T, Double>](), writable: [String: WritableKeyPath<T, Double>]())
@differentiable(reverse)
func f(p: PAndT) -> Double {
var mutableP = p
let s = p.p.e
var sArray: [[Double]] = []
sArray.append((s["a"]!.asArray()).map {$0.value})
mutableP.s = w(mutableP.s, at: keyPaths.writable["a"]!, with: sArray[0][0])
return mutableP.s[keyPath: keyPaths.writable["a"]!]
}
}

public struct S<I: SProtocol, D> {
public func asArray() -> [(index: I, value: D)] {
return [(index: I, value: D)]()
}
}
struct T: Differentiable {}
struct P: Differentiable {
public var e: F<Double, Double>
}

struct PAndT: Differentiable{
@differentiable(reverse) public var p: P
@differentiable(reverse) public var s: T
}

public struct F<I: SProtocol, D>{
public func asArray() -> [(index: I, value: D)] {return [(index: I, value: D)]()}
public subscript(_ name: String) -> S<I, D>? {get { return self.sGet(name) }}
func sGet(_ name: String) -> S<I, D>? { fatalError("") }
}

public protocol ZProtocol: Differentiable {var z: () -> TangentVector { get }}
public protocol SProtocol: Hashable {}

//extension S: Differentiable where D: ZProtocol, D.TangentVector == D {}

extension F: Differentiable where D: ZProtocol, D.TangentVector == D {}
public extension ZProtocol {var z: () -> TangentVector {{ Self.TangentVector.zero }}}

extension F: Equatable where I: Equatable, D: Equatable {}
//extension S: Equatable where I: Equatable, D: Equatable {}
extension Double: SProtocol, ZProtocol {}

@differentiable(reverse where O: Differentiable, M: ZProtocol)
func w<O, M>(_ o: O, at m: WritableKeyPath<O, M>, with v: M) -> O {return o}

@derivative(of: w)
func vjpw<O, M>(_ o: O, at m: WritableKeyPath<O, M>, with v: M) -> (value: O, pullback: (O.TangentVector) -> (O.TangentVector, M.TangentVector)) where O: Differentiable, M: ZProtocol{fatalError("")}