Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7111,10 +7111,19 @@ void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() {
ArrayRef<AutoDiffConfig>
AbstractFunctionDecl::getDerivativeFunctionConfigurations() {
prepareDerivativeFunctionConfigurations();

// Resolve derivative function configurations from `@differentiable`
// attributes by type-checking them.
for (auto *diffAttr : getAttrs().getAttributes<DifferentiableAttr>())
(void)diffAttr->getParameterIndices();
// For accessors: resolve derivative function configurations from storage
// `@differentiable` attributes by type-checking them.
if (auto *accessor = dyn_cast<AccessorDecl>(this)) {
auto *storage = accessor->getStorage();
for (auto *diffAttr : storage->getAttrs().getAttributes<DifferentiableAttr>())
(void)diffAttr->getParameterIndices();
}

// Load derivative configurations from imported modules.
auto &ctx = getASTContext();
if (ctx.getCurrentGeneration() > DerivativeFunctionConfigGeneration) {
Expand Down
12 changes: 6 additions & 6 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -656,9 +656,9 @@ emitDerivativeFunctionReference(
auto loc = witnessMethod->getLoc();
auto requirementDeclRef = witnessMethod->getMember();
auto *requirementDecl = requirementDeclRef.getAbstractFunctionDecl();
// If requirement declaration does not have any `@differentiable`
// attributes, produce an error.
if (!requirementDecl->getAttrs().hasAttribute<DifferentiableAttr>()) {
// If requirement declaration does not have any derivative function
// configurations, produce an error.
if (requirementDecl->getDerivativeFunctionConfigurations().empty()) {
context.emitNondifferentiabilityError(
original, invoker, diag::autodiff_protocol_member_not_differentiable);
return None;
Expand Down Expand Up @@ -701,9 +701,9 @@ emitDerivativeFunctionReference(
auto loc = classMethod->getLoc();
auto methodDeclRef = classMethod->getMember();
auto *methodDecl = methodDeclRef.getAbstractFunctionDecl();
// If method declaration does not have any `@differentiable` attributes,
// produce an error.
if (!methodDecl->getAttrs().hasAttribute<DifferentiableAttr>()) {
// If method declaration does not have any derivative function
// configurations, produce an error.
if (methodDecl->getDerivativeFunctionConfigurations().empty()) {
context.emitNondifferentiabilityError(
original, invoker, diag::autodiff_class_member_not_differentiable);
return None;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@ protocol Protocol: Differentiable {
// Test cross-file `@differentiable` attribute.
@differentiable(wrt: self)
func identityDifferentiableAttr() -> Self

// Test `@differentiable` propagation from storage declaration to accessors.
@differentiable
var property: Float { get set }

// Test `@differentiable` propagation from storage declaration to accessors.
@differentiable
subscript() -> Float { get set }
}

extension Protocol {
Expand All @@ -17,3 +25,19 @@ extension Protocol {
fatalError()
}
}

class Class: Differentiable {
// Test `@differentiable` propagation from storage declaration to accessors.
@differentiable
var property: Float {
get { 1 }
set {}
}

// Test `@differentiable` propagation from storage declaration to accessors.
@differentiable
subscript() -> Float {
get { 1 }
set {}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,37 @@ func crossFileDerivativeAttr<T: Protocol>(
// expected-note @+1 {{cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files}}
return input.identityDerivativeAttr()
}

// TF-1234: Test `@differentiable` propagation from protocol requirement storage
// declarations to their accessors in other file.

@differentiable
func protocolRequirementGetters<T: Protocol>(_ x: T) -> Float {
x.property + x[]
}

// TODO(TF-1184): Make `@differentiable` on storage declarations propagate to
// the setter in addition to the getter.
@differentiable
func protocolRequirementSetters<T: Protocol>(_ x: inout T, _ newValue: Float) {
// expected-error @+2 {{expression is not differentiable}}
// expected-note @+1 {{member is not differentiable because the corresponding protocol requirement is not '@differentiable'}}
x.property = newValue
// expected-error @+2 {{expression is not differentiable}}
// expected-note @+1 {{member is not differentiable because the corresponding protocol requirement is not '@differentiable'}}
x[] = newValue
}

// TF-1234: Test `@differentiable` propagation from class member storage
// declarations to their accessors in other file.

@differentiable
func classRequirementGetters(_ x: Class) -> Float {
x.property + x[]
}

@differentiable
func classRequirementSetters(_ x: inout Class, _ newValue: Float) {
x.property = newValue
x[] = newValue
}