diff --git a/lib/SILGen/SILGen.cpp b/lib/SILGen/SILGen.cpp index 7f2b690297c55..bc93cdf2387a3 100644 --- a/lib/SILGen/SILGen.cpp +++ b/lib/SILGen/SILGen.cpp @@ -935,6 +935,43 @@ void SILGenModule::postEmitFunction(SILDeclRef constant, emitDifferentiabilityWitnessesForFunction(constant, F); } +/// Returns the SIL differentiability witness generic signature given the +/// original declaration's generic signature and the derivative generic +/// signature. +/// +/// In general, the differentiability witness generic signature is equal to the +/// derivative generic signature. +/// +/// Edge case, if two conditions are satisfied: +/// 1. The derivative generic signature is equal to the original generic +/// signature. +/// 2. The derivative generic signature has *all concrete* generic parameters +/// (i.e. all generic parameters are bound to concrete types via same-type +/// requirements). +/// +/// Then the differentiability witness generic signature is `nullptr`. +/// +/// Both the original and derivative declarations are lowered to SIL functions +/// with a fully concrete type and no generic signature, so the +/// differentiability witness should similarly have no generic signature. +static GenericSignature +getDifferentiabilityWitnessGenericSignature(GenericSignature origGenSig, + GenericSignature derivativeGenSig) { + // If there is no derivative generic signature, return the original generic + // signature. + if (!derivativeGenSig) + return origGenSig; + // If derivative generic signature has all concrete generic parameters and is + // equal to the original generic signature, return `nullptr`. + auto derivativeCanGenSig = derivativeGenSig.getCanonicalSignature(); + auto origCanGenSig = origGenSig.getCanonicalSignature(); + if (origCanGenSig == derivativeCanGenSig && + derivativeCanGenSig->areAllParamsConcrete()) + return GenericSignature(); + // Otherwise, return the derivative generic signature. + return derivativeGenSig; +} + void SILGenModule::emitDifferentiabilityWitnessesForFunction( SILDeclRef constant, SILFunction *F) { // Visit `@derivative` attributes and generate SIL differentiability @@ -955,8 +992,11 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction( diffAttr->getDerivativeGenericSignature()) && "Type-checking should resolve derivative generic signatures for " "all original SIL functions with generic signatures"); + auto witnessGenSig = getDifferentiabilityWitnessGenericSignature( + AFD->getGenericSignature(), + diffAttr->getDerivativeGenericSignature()); AutoDiffConfig config(diffAttr->getParameterIndices(), resultIndices, - diffAttr->getDerivativeGenericSignature()); + witnessGenSig); emitDifferentiabilityWitness(AFD, F, config, /*jvp*/ nullptr, /*vjp*/ nullptr, diffAttr); } @@ -975,10 +1015,11 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction( auto origDeclRef = SILDeclRef(origAFD).asForeign(requiresForeignEntryPoint(origAFD)); auto *origFn = getFunction(origDeclRef, NotForDefinition); - auto derivativeGenSig = AFD->getGenericSignature(); + auto witnessGenSig = getDifferentiabilityWitnessGenericSignature( + origAFD->getGenericSignature(), AFD->getGenericSignature()); auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0}); AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices, - derivativeGenSig); + witnessGenSig); emitDifferentiabilityWitness(origAFD, origFn, config, jvp, vjp, derivAttr); } diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 147affbcbc40e..9fde8797c9cb2 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -4188,24 +4188,6 @@ bool resolveDifferentiableAttrDerivativeGenericSignature( attr->getLocation(), /*allowConcreteGenericParams=*/true); } - // Set the resolved derivative generic signature in the attribute. - // Do not set the derivative generic signature if the original function's - // generic signature is equal to `derivativeGenSig` and all generic parameters - // are concrete. In that case, the original function and derivative functions - // are all lowered as SIL functions with no generic signature (specialized - // with concrete types from same-type requirements), so the derivative generic - // signature should not be set. - auto skipDerivativeGenericSignature = [&] { - auto origCanGenSig = - original->getGenericSignature().getCanonicalSignature(); - auto derivativeCanGenSig = derivativeGenSig.getCanonicalSignature(); - if (!derivativeCanGenSig) - return false; - return origCanGenSig == derivativeCanGenSig && - derivativeCanGenSig->areAllParamsConcrete(); - }; - if (skipDerivativeGenericSignature()) - derivativeGenSig = GenericSignature(); attr->setDerivativeGenericSignature(derivativeGenSig); return false; } diff --git a/test/AutoDiff/SILGen/differentiability_witness_generic_signature.swift b/test/AutoDiff/SILGen/differentiability_witness_generic_signature.swift new file mode 100644 index 0000000000000..ff198e430e33c --- /dev/null +++ b/test/AutoDiff/SILGen/differentiability_witness_generic_signature.swift @@ -0,0 +1,177 @@ +// RUN: %target-swift-emit-silgen -verify -module-name main %s | %FileCheck %s +// RUN: %target-swift-emit-sil -verify -module-name main %s + +// NOTE(SR-11950): SILParser crashes for SILGen round-trip. + +// This file tests: +// - The "derivative generic signature" of `@differentiable` and `@derivative` +// attributes. +// - The generic signature of lowered SIL differentiability witnesses. + +// Context: +// - For `@differentiable` attributes: the derivative generic signature is +// resolved from the original declaration's generic signature and additional +// `where` clause requirements. +// - For `@derivative` attributes: the derivative generic signature is the +// attributed declaration's generic signature. + +import _Differentiation + +//===----------------------------------------------------------------------===// +// Same-type requirements +//===----------------------------------------------------------------------===// + +// Test original declaration with a generic signature and derivative generic +// signature where all generic parameters are concrete (i.e. bound to concrete +// types via same-type requirements). + +struct AllConcrete: Differentiable {} + +extension AllConcrete { + // Original generic signature: `` + // Derivative generic signature: `` + // Witness generic signature: `` + @_silgen_name("allconcrete_where_gensig_constrained") + @differentiable(where T == Float) + func whereClauseGenericSignatureConstrained() -> AllConcrete { + return self + } +} +extension AllConcrete where T == Float { + @derivative(of: whereClauseGenericSignatureConstrained) + func jvpWhereClauseGenericSignatureConstrained() -> ( + value: AllConcrete, differential: (TangentVector) -> TangentVector + ) { + (whereClauseGenericSignatureConstrained(), { $0 }) + } +} + +// CHECK-LABEL: // differentiability witness for allconcrete_where_gensig_constrained +// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] @allconcrete_where_gensig_constrained : $@convention(method) (AllConcrete) -> AllConcrete { +// CHECK-NEXT: jvp: @AD__allconcrete_where_gensig_constrained__jvp_src_0_wrt_0_SfRszl : $@convention(method) (AllConcrete) -> (AllConcrete, @owned @callee_guaranteed (AllConcrete.TangentVector) -> AllConcrete.TangentVector) +// CHECK-NEXT: } + +// If a `@differentiable` or `@derivative` attribute satisfies two conditions: +// 1. The derivative generic signature is equal to the original generic signature. +// 2. The derivative generic signature has *all concrete* generic parameters. +// +// Then the attribute should be lowered to a SIL differentiability witness with +// *no* derivative generic signature. + +extension AllConcrete where T == Float { + // Original generic signature: `` + // Derivative generic signature: `` + // Witness generic signature: none + @_silgen_name("allconcrete_original_gensig") + @differentiable + func originalGenericSignature() -> AllConcrete { + return self + } + + @derivative(of: originalGenericSignature) + func jvpOriginalGenericSignature() -> ( + value: AllConcrete, differential: (TangentVector) -> TangentVector + ) { + (originalGenericSignature(), { $0 }) + } + +// CHECK-LABEL: // differentiability witness for allconcrete_original_gensig +// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] @allconcrete_original_gensig : $@convention(method) (AllConcrete) -> AllConcrete { +// CHECK-NEXT: jvp: @AD__allconcrete_original_gensig__jvp_src_0_wrt_0 : $@convention(method) (AllConcrete) -> (AllConcrete, @owned @callee_guaranteed (AllConcrete.TangentVector) -> AllConcrete.TangentVector) +// CHECK-NEXT: } + + // Original generic signature: `` + // Derivative generic signature: `` (explicit `where` clause) + // Witness generic signature: none + @_silgen_name("allconcrete_where_gensig") + @differentiable(where T == Float) + func whereClauseGenericSignature() -> AllConcrete { + return self + } + + @derivative(of: whereClauseGenericSignature) + func jvpWhereClauseGenericSignature() -> ( + value: AllConcrete, differential: (TangentVector) -> TangentVector + ) { + (whereClauseGenericSignature(), { $0 }) + } + +// CHECK-LABEL: // differentiability witness for allconcrete_where_gensig +// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] @allconcrete_where_gensig : $@convention(method) (AllConcrete) -> AllConcrete { +// CHECK-NEXT: jvp: @AD__allconcrete_where_gensig__jvp_src_0_wrt_0 : $@convention(method) (AllConcrete) -> (AllConcrete, @owned @callee_guaranteed (AllConcrete.TangentVector) -> AllConcrete.TangentVector) +// CHECK-NEXT: } +} + +// Test original declaration with a generic signature and derivative generic +// signature where *not* all generic parameters are concrete. +// types via same-type requirements). + +struct NotAllConcrete: Differentiable {} + +extension NotAllConcrete { + // Original generic signature: `` + // Derivative generic signature: `` + // Witness generic signature: `` (not all concrete) + @_silgen_name("notallconcrete_where_gensig_constrained") + @differentiable(where T == Float) + func whereClauseGenericSignatureConstrained() -> NotAllConcrete { + return self + } +} +extension NotAllConcrete where T == Float { + @derivative(of: whereClauseGenericSignatureConstrained) + func jvpWhereClauseGenericSignatureConstrained() -> ( + value: NotAllConcrete, differential: (TangentVector) -> TangentVector + ) { + (whereClauseGenericSignatureConstrained(), { $0 }) + } +} + +// CHECK-LABEL: // differentiability witness for notallconcrete_where_gensig_constrained +// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] @notallconcrete_where_gensig_constrained : $@convention(method) (NotAllConcrete) -> NotAllConcrete { +// CHECK-NEXT: jvp: @AD__notallconcrete_where_gensig_constrained__jvp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete) -> (NotAllConcrete, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for .TangentVector, NotAllConcrete.TangentVector>) +// CHECK-NEXT: } + +extension NotAllConcrete where T == Float { + // Original generic signature: `` + // Derivative generic signature: `` + // Witness generic signature: `` (not all concrete) + @_silgen_name("notallconcrete_original_gensig") + @differentiable + func originalGenericSignature() -> NotAllConcrete { + return self + } + + @derivative(of: originalGenericSignature) + func jvpOriginalGenericSignature() -> ( + value: NotAllConcrete, differential: (TangentVector) -> TangentVector + ) { + (originalGenericSignature(), { $0 }) + } + +// CHECK-LABEL: // differentiability witness for notallconcrete_original_gensig +// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] @notallconcrete_original_gensig : $@convention(method) (NotAllConcrete) -> NotAllConcrete { +// CHECK-NEXT: jvp: @AD__notallconcrete_original_gensig__jvp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete) -> (NotAllConcrete, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for .TangentVector, NotAllConcrete.TangentVector>) +// CHECK-NEXT: } + + // Original generic signature: `` + // Derivative generic signature: `` (explicit `where` clause) + // Witness generic signature: `` (not all concrete) + @_silgen_name("notallconcrete_where_gensig") + @differentiable(where T == Float) + func whereClauseGenericSignature() -> NotAllConcrete { + return self + } + + @derivative(of: whereClauseGenericSignature) + func jvpWhereClauseGenericSignature() -> ( + value: NotAllConcrete, differential: (TangentVector) -> TangentVector + ) { + (whereClauseGenericSignature(), { $0 }) + } + +// CHECK-LABEL: // differentiability witness for notallconcrete_where_gensig +// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] @notallconcrete_where_gensig : $@convention(method) (NotAllConcrete) -> NotAllConcrete { +// CHECK-NEXT: jvp: @AD__notallconcrete_where_gensig__jvp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete) -> (NotAllConcrete, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for .TangentVector, NotAllConcrete.TangentVector>) +// CHECK-NEXT: } +}