File tree Expand file tree Collapse file tree 3 files changed +49
-0
lines changed
test/AutoDiff/SILOptimizer Expand file tree Collapse file tree 3 files changed +49
-0
lines changed Original file line number Diff line number Diff line change @@ -7106,6 +7106,11 @@ void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() {
71067106ArrayRef<AutoDiffConfig>
71077107AbstractFunctionDecl::getDerivativeFunctionConfigurations () {
71087108 prepareDerivativeFunctionConfigurations ();
7109+ // Resolve derivative function configurations from `@differentiable`
7110+ // attributes by type-checking them.
7111+ for (auto *diffAttr : getAttrs ().getAttributes <DifferentiableAttr>())
7112+ (void )diffAttr->getParameterIndices ();
7113+ // Load derivative configurations from imported modules.
71097114 auto &ctx = getASTContext ();
71107115 if (ctx.getCurrentGeneration () > DerivativeFunctionConfigGeneration) {
71117116 unsigned previousGeneration = DerivativeFunctionConfigGeneration;
Original file line number Diff line number Diff line change 1+ import _Differentiation
2+
3+ protocol Protocol : Differentiable {
4+ // Test cross-file `@differentiable` attribute.
5+ @differentiable ( wrt: self )
6+ func identityDifferentiableAttr( ) -> Self
7+ }
8+
9+ extension Protocol {
10+ func identityDerivativeAttr( ) -> Self { self }
11+
12+ // Test cross-file `@derivative` attribute.
13+ @derivative ( of: identityDerivativeAttr)
14+ func vjpIdentityDerivativeAttr( ) -> (
15+ value: Self , pullback: ( TangentVector ) -> TangentVector
16+ ) {
17+ fatalError ( )
18+ }
19+ }
Original file line number Diff line number Diff line change 1+ // RUN: %target-swift-frontend -emit-sil -verify -primary-file %s %S/Inputs/differentiation_diagnostics_other_file.swift -module-name main -o /dev/null
2+
3+ // Test differentiation transform cross-file diagnostics.
4+
5+ import _Differentiation
6+
7+ // TF-1271: Test `@differentiable` original function in other file.
8+ @differentiable
9+ func crossFileDifferentiableAttr< T: Protocol > (
10+ _ input: T
11+ ) -> T {
12+ return input. identityDifferentiableAttr ( )
13+ }
14+
15+ // TF-1272: Test original function with registered derivatives in other files.
16+ // FIXME(TF-1272): Find a way to type-check `@derivative` attributes in other
17+ // files.
18+ @differentiable
19+ func crossFileDerivativeAttr< T: Protocol > (
20+ _ input: T
21+ ) -> T {
22+ // expected-error @+2 {{expression is not differentiable}}
23+ // expected-note @+1 {{cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files}}
24+ return input. identityDerivativeAttr ( )
25+ }
You can’t perform that action at this time.
0 commit comments