@@ -1114,38 +1114,41 @@ SILInstruction *SILCombiner::visitConvertEscapeToNoEscapeInst(
11141114 // %vjp' = convert_escape_to_noescape %vjp
11151115 // %y = differentiable_function(%orig', %jvp', %vjp')
11161116 if (auto *DFI = dyn_cast<DifferentiableFunctionInst>(Cvt->getConverted ())) {
1117- auto createConvertEscapeToNoEscape = [&](NormalDifferentiableFunctionTypeComponent extractee) {
1118- if (!DFI->hasExtractee (extractee))
1119- return SILValue ();
1117+ if (DFI->hasOneUse ()) {
1118+ auto createConvertEscapeToNoEscape =
1119+ [&](NormalDifferentiableFunctionTypeComponent extractee) {
1120+ if (!DFI->hasExtractee (extractee))
1121+ return SILValue ();
11201122
1121- auto operand = DFI->getExtractee (extractee);
1122- auto fnType = operand->getType ().castTo <SILFunctionType>();
1123- auto noEscapeFnType =
1124- fnType->getWithExtInfo (fnType->getExtInfo ().withNoEscape ());
1125- auto noEscapeType = SILType::getPrimitiveObjectType (noEscapeFnType);
1126- return Builder.createConvertEscapeToNoEscape (
1127- operand.getLoc (), operand, noEscapeType, Cvt->isLifetimeGuaranteed ())->getResult (0 );
1128- };
1123+ auto operand = DFI->getExtractee (extractee);
1124+ auto fnType = operand->getType ().castTo <SILFunctionType>();
1125+ auto noEscapeFnType =
1126+ fnType->getWithExtInfo (fnType->getExtInfo ().withNoEscape ());
1127+ auto noEscapeType = SILType::getPrimitiveObjectType (noEscapeFnType);
1128+ return Builder.createConvertEscapeToNoEscape (
1129+ operand.getLoc (), operand, noEscapeType, Cvt->isLifetimeGuaranteed ())->getResult (0 );
1130+ };
11291131
1130- SILValue originalNoEscape =
1131- createConvertEscapeToNoEscape (NormalDifferentiableFunctionTypeComponent::Original);
1132- SILValue convertedJVP = createConvertEscapeToNoEscape (
1133- NormalDifferentiableFunctionTypeComponent::JVP);
1134- SILValue convertedVJP = createConvertEscapeToNoEscape (
1135- NormalDifferentiableFunctionTypeComponent::VJP);
1136-
1137- Optional<std::pair<SILValue, SILValue>> derivativeFunctions;
1138- if (convertedJVP && convertedVJP)
1139- derivativeFunctions = std::make_pair (convertedJVP, convertedVJP);
1140-
1141- auto *newDFI = Builder.createDifferentiableFunction (
1142- DFI->getLoc (), DFI->getParameterIndices (), DFI->getResultIndices (),
1143- originalNoEscape, derivativeFunctions);
1144- assert (newDFI->getType () == Cvt->getType () &&
1145- " New `@differentiable` function instruction should have same type "
1146- " as the old `convert_escape_to_no_escape` instruction" );
1147- return newDFI;
1148- }
1132+ SILValue originalNoEscape =
1133+ createConvertEscapeToNoEscape (NormalDifferentiableFunctionTypeComponent::Original);
1134+ SILValue convertedJVP = createConvertEscapeToNoEscape (
1135+ NormalDifferentiableFunctionTypeComponent::JVP);
1136+ SILValue convertedVJP = createConvertEscapeToNoEscape (
1137+ NormalDifferentiableFunctionTypeComponent::VJP);
1138+
1139+ Optional<std::pair<SILValue, SILValue>> derivativeFunctions;
1140+ if (convertedJVP && convertedVJP)
1141+ derivativeFunctions = std::make_pair (convertedJVP, convertedVJP);
1142+
1143+ auto *newDFI = Builder.createDifferentiableFunction (
1144+ DFI->getLoc (), DFI->getParameterIndices (), DFI->getResultIndices (),
1145+ originalNoEscape, derivativeFunctions);
1146+ assert (newDFI->getType () == Cvt->getType () &&
1147+ " New `@differentiable` function instruction should have same type "
1148+ " as the old `convert_escape_to_no_escape` instruction" );
1149+ return newDFI;
1150+ }
1151+ }
11491152
11501153 return nullptr ;
11511154}
0 commit comments