diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index b28ab04ce3fd3..ea072f6537180 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -7106,6 +7106,11 @@ void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() { ArrayRef AbstractFunctionDecl::getDerivativeFunctionConfigurations() { prepareDerivativeFunctionConfigurations(); + // Resolve derivative function configurations from `@differentiable` + // attributes by type-checking them. + for (auto *diffAttr : getAttrs().getAttributes()) + (void)diffAttr->getParameterIndices(); + // Load derivative configurations from imported modules. auto &ctx = getASTContext(); if (ctx.getCurrentGeneration() > DerivativeFunctionConfigGeneration) { unsigned previousGeneration = DerivativeFunctionConfigGeneration; diff --git a/test/AutoDiff/SILOptimizer/Inputs/differentiation_diagnostics_other_file.swift b/test/AutoDiff/SILOptimizer/Inputs/differentiation_diagnostics_other_file.swift new file mode 100644 index 0000000000000..818d07a5a446a --- /dev/null +++ b/test/AutoDiff/SILOptimizer/Inputs/differentiation_diagnostics_other_file.swift @@ -0,0 +1,19 @@ +import _Differentiation + +protocol Protocol: Differentiable { + // Test cross-file `@differentiable` attribute. + @differentiable(wrt: self) + func identityDifferentiableAttr() -> Self +} + +extension Protocol { + func identityDerivativeAttr() -> Self { self } + + // Test cross-file `@derivative` attribute. + @derivative(of: identityDerivativeAttr) + func vjpIdentityDerivativeAttr() -> ( + value: Self, pullback: (TangentVector) -> TangentVector + ) { + fatalError() + } +} diff --git a/test/AutoDiff/SILOptimizer/differentiation_diagnostics_cross_file.swift b/test/AutoDiff/SILOptimizer/differentiation_diagnostics_cross_file.swift new file mode 100644 index 0000000000000..95e048e1864f4 --- /dev/null +++ b/test/AutoDiff/SILOptimizer/differentiation_diagnostics_cross_file.swift @@ -0,0 +1,25 @@ +// RUN: %target-swift-frontend -emit-sil -verify -primary-file %s %S/Inputs/differentiation_diagnostics_other_file.swift -module-name main -o /dev/null + +// Test differentiation transform cross-file diagnostics. + +import _Differentiation + +// TF-1271: Test `@differentiable` original function in other file. +@differentiable +func crossFileDifferentiableAttr( + _ input: T +) -> T { + return input.identityDifferentiableAttr() +} + +// TF-1272: Test original function with registered derivatives in other files. +// FIXME(TF-1272): Find a way to type-check `@derivative` attributes in other +// files. +@differentiable +func crossFileDerivativeAttr( + _ input: T +) -> T { + // expected-error @+2 {{expression is not differentiable}} + // expected-note @+1 {{cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files}} + return input.identityDerivativeAttr() +}