@@ -403,14 +403,17 @@ SILParameterInfo LargeSILTypeMapper::getNewParameter(GenericEnvironment *env,
403403 } else if (isLargeLoadableType (env, storageType, IGM)) {
404404 if (param.getConvention () == ParameterConvention::Direct_Guaranteed)
405405 return SILParameterInfo (storageType.getASTType (),
406- ParameterConvention::Indirect_In_Guaranteed);
406+ ParameterConvention::Indirect_In_Guaranteed,
407+ param.getDifferentiability ());
407408 else
408409 return SILParameterInfo (storageType.getASTType (),
409- ParameterConvention::Indirect_In_Constant);
410+ ParameterConvention::Indirect_In_Constant,
411+ param.getDifferentiability ());
410412 } else {
411413 auto newType = getNewSILType (env, storageType, IGM);
412414 return SILParameterInfo (newType.getASTType (),
413- param.getConvention ());
415+ param.getConvention (),
416+ param.getDifferentiability ());
414417 }
415418}
416419
@@ -1704,6 +1707,9 @@ class LoadableByAddress : public SILModuleTransform {
17041707 bool fixStoreToBlockStorageInstr (SILInstruction &I,
17051708 SmallVectorImpl<SILInstruction *> &Delete);
17061709
1710+ bool recreateDifferentiabilityWitnessFunction (
1711+ SILInstruction &I, SmallVectorImpl<SILInstruction *> &Delete);
1712+
17071713private:
17081714 llvm::SetVector<SILFunction *> modFuncs;
17091715 llvm::SetVector<SingleValueInstruction *> conversionInstrs;
@@ -2708,6 +2714,33 @@ bool LoadableByAddress::fixStoreToBlockStorageInstr(
27082714 return true ;
27092715}
27102716
2717+ bool LoadableByAddress::recreateDifferentiabilityWitnessFunction (
2718+ SILInstruction &I, SmallVectorImpl<SILInstruction *> &Delete) {
2719+ auto *instr = dyn_cast<DifferentiabilityWitnessFunctionInst>(&I);
2720+ if (!instr)
2721+ return false ;
2722+
2723+ // Check if we need to recreate the instruction.
2724+ auto *currIRMod = getIRGenModule ()->IRGen .getGenModule (instr->getFunction ());
2725+ auto resultFnTy = instr->getType ().castTo <SILFunctionType>();
2726+ auto genSig = resultFnTy->getSubstGenericSignature ();
2727+ GenericEnvironment *genEnv = nullptr ;
2728+ if (genSig)
2729+ genEnv = genSig->getGenericEnvironment ();
2730+ auto newResultFnTy =
2731+ MapperCache.getNewSILFunctionType (genEnv, resultFnTy, *currIRMod);
2732+ if (resultFnTy == newResultFnTy)
2733+ return true ;
2734+
2735+ SILBuilderWithScope builder (instr);
2736+ auto *newInstr = builder.createDifferentiabilityWitnessFunction (
2737+ instr->getLoc (), instr->getWitnessKind (), instr->getWitness (),
2738+ SILType::getPrimitiveObjectType (newResultFnTy));
2739+ instr->replaceAllUsesWith (newInstr);
2740+ Delete.push_back (instr);
2741+ return true ;
2742+ }
2743+
27112744bool LoadableByAddress::recreateTupleInstr (
27122745 SILInstruction &I, SmallVectorImpl<SILInstruction *> &Delete) {
27132746 auto *tupleInstr = dyn_cast<TupleInst>(&I);
@@ -2750,6 +2783,19 @@ bool LoadableByAddress::recreateConvInstr(SILInstruction &I,
27502783 auto currSILFunctionType = currSILType.castTo <SILFunctionType>();
27512784 GenericEnvironment *genEnv =
27522785 getSubstGenericEnvironment (convInstr->getFunction ());
2786+ // Differentiable function conversion instructions can happen while the
2787+ // function is still generic. In that case, we must calculate the new type
2788+ // using the converted function's generic environment rather than the
2789+ // converting function's generic environment.
2790+ //
2791+ // This happens in witness thunks for default implementations of derivative
2792+ // requirements.
2793+ if (convInstr->getKind () == SILInstructionKind::DifferentiableFunctionInst ||
2794+ convInstr->getKind () == SILInstructionKind::DifferentiableFunctionExtractInst ||
2795+ convInstr->getKind () == SILInstructionKind::LinearFunctionInst ||
2796+ convInstr->getKind () == SILInstructionKind::LinearFunctionExtractInst)
2797+ if (auto genSig = currSILFunctionType->getSubstGenericSignature ())
2798+ genEnv = genSig->getGenericEnvironment ();
27532799 CanSILFunctionType newFnType = MapperCache.getNewSILFunctionType (
27542800 genEnv, currSILFunctionType, *currIRMod);
27552801 SILType newType = SILType::getPrimitiveObjectType (newFnType);
@@ -2790,6 +2836,34 @@ bool LoadableByAddress::recreateConvInstr(SILInstruction &I,
27902836 instr->getLoc (), instr->getValue (), instr->getBase ());
27912837 break ;
27922838 }
2839+ case SILInstructionKind::DifferentiableFunctionInst: {
2840+ auto instr = cast<DifferentiableFunctionInst>(convInstr);
2841+ newInstr = convBuilder.createDifferentiableFunction (
2842+ instr->getLoc (), instr->getParameterIndices (),
2843+ instr->getOriginalFunction (),
2844+ instr->getOptionalDerivativeFunctionPair ());
2845+ break ;
2846+ }
2847+ case SILInstructionKind::DifferentiableFunctionExtractInst: {
2848+ auto instr = cast<DifferentiableFunctionExtractInst>(convInstr);
2849+ // Rewrite `differentiable_function_extract` with explicit extractee type.
2850+ newInstr = convBuilder.createDifferentiableFunctionExtract (
2851+ instr->getLoc (), instr->getExtractee (), instr->getOperand (), newType);
2852+ break ;
2853+ }
2854+ case SILInstructionKind::LinearFunctionInst: {
2855+ auto instr = cast<LinearFunctionInst>(convInstr);
2856+ newInstr = convBuilder.createLinearFunction (
2857+ instr->getLoc (), instr->getParameterIndices (),
2858+ instr->getOriginalFunction (), instr->getOptionalTransposeFunction ());
2859+ break ;
2860+ }
2861+ case SILInstructionKind::LinearFunctionExtractInst: {
2862+ auto instr = cast<LinearFunctionExtractInst>(convInstr);
2863+ newInstr = convBuilder.createLinearFunctionExtract (
2864+ instr->getLoc (), instr->getExtractee (), instr->getFunctionOperand ());
2865+ break ;
2866+ }
27932867 default :
27942868 llvm_unreachable (" Unexpected conversion instruction" );
27952869 }
@@ -2878,7 +2952,11 @@ void LoadableByAddress::run() {
28782952 case SILInstructionKind::ConvertEscapeToNoEscapeInst:
28792953 case SILInstructionKind::MarkDependenceInst:
28802954 case SILInstructionKind::ThinFunctionToPointerInst:
2881- case SILInstructionKind::ThinToThickFunctionInst: {
2955+ case SILInstructionKind::ThinToThickFunctionInst:
2956+ case SILInstructionKind::DifferentiableFunctionInst:
2957+ case SILInstructionKind::LinearFunctionInst:
2958+ case SILInstructionKind::LinearFunctionExtractInst:
2959+ case SILInstructionKind::DifferentiableFunctionExtractInst: {
28822960 conversionInstrs.insert (
28832961 cast<SingleValueInstruction>(currInstr));
28842962 break ;
@@ -2945,6 +3023,11 @@ void LoadableByAddress::run() {
29453023 if (modApplies.count (PAI) == 0 ) {
29463024 modApplies.insert (PAI);
29473025 }
3026+ } else if (isa<DifferentiableFunctionInst>(&I) ||
3027+ isa<LinearFunctionInst>(&I) ||
3028+ isa<DifferentiableFunctionExtractInst>(&I) ||
3029+ isa<LinearFunctionExtractInst>(&I)) {
3030+ conversionInstrs.insert (cast<SingleValueInstruction>(&I));
29483031 }
29493032 }
29503033 }
@@ -2988,6 +3071,8 @@ void LoadableByAddress::run() {
29883071 continue ;
29893072 else if (recreateApply (I, Delete))
29903073 continue ;
3074+ else if (recreateDifferentiabilityWitnessFunction (I, Delete))
3075+ continue ;
29913076 else
29923077 fixStoreToBlockStorageInstr (I, Delete);
29933078 }
0 commit comments