@@ -1110,39 +1110,42 @@ SILInstruction *SILCombiner::visitConvertEscapeToNoEscapeInst(
11101110 // %vjp' = convert_escape_to_noescape %vjp
11111111 // %y = differentiable_function(%orig', %jvp', %vjp')
11121112 if (auto *DFI = dyn_cast<DifferentiableFunctionInst>(Cvt->getOperand ())) {
1113- auto createConvertEscapeToNoEscape = [&](NormalDifferentiableFunctionTypeComponent extractee) {
1114- if (!DFI->hasExtractee (extractee))
1115- return SILValue ();
1113+ if (DFI->hasOneUse ()) {
1114+ auto createConvertEscapeToNoEscape =
1115+ [&](NormalDifferentiableFunctionTypeComponent extractee) {
1116+ if (!DFI->hasExtractee (extractee))
1117+ return SILValue ();
11161118
1117- auto operand = DFI->getExtractee (extractee);
1118- auto fnType = operand->getType ().castTo <SILFunctionType>();
1119- auto noEscapeFnType =
1120- fnType->getWithExtInfo (fnType->getExtInfo ().withNoEscape ());
1121- auto noEscapeType = SILType::getPrimitiveObjectType (noEscapeFnType);
1122- return Builder.createConvertEscapeToNoEscape (
1123- operand.getLoc (), operand, noEscapeType, Cvt->isLifetimeGuaranteed ())->getResult (0 );
1124- };
1119+ auto operand = DFI->getExtractee (extractee);
1120+ auto fnType = operand->getType ().castTo <SILFunctionType>();
1121+ auto noEscapeFnType =
1122+ fnType->getWithExtInfo (fnType->getExtInfo ().withNoEscape ());
1123+ auto noEscapeType = SILType::getPrimitiveObjectType (noEscapeFnType);
1124+ return Builder.createConvertEscapeToNoEscape (
1125+ operand.getLoc (), operand, noEscapeType, Cvt->isLifetimeGuaranteed ())->getResult (0 );
1126+ };
11251127
1126- SILValue originalNoEscape =
1127- createConvertEscapeToNoEscape (NormalDifferentiableFunctionTypeComponent::Original);
1128- SILValue convertedJVP = createConvertEscapeToNoEscape (
1129- NormalDifferentiableFunctionTypeComponent::JVP);
1130- SILValue convertedVJP = createConvertEscapeToNoEscape (
1131- NormalDifferentiableFunctionTypeComponent::VJP);
1132-
1133- llvm::Optional<std::pair<SILValue, SILValue>> derivativeFunctions;
1134- if (convertedJVP && convertedVJP)
1135- derivativeFunctions = std::make_pair (convertedJVP, convertedVJP);
1136-
1137- auto *newDFI = Builder.createDifferentiableFunction (
1138- DFI->getLoc (), DFI->getParameterIndices (), DFI->getResultIndices (),
1139- originalNoEscape, derivativeFunctions);
1140- assert (newDFI->getType () == Cvt->getType () &&
1141- " New `@differentiable` function instruction should have same type "
1142- " as the old `convert_escape_to_no_escape` instruction" );
1143- return newDFI;
1144- }
1128+ SILValue originalNoEscape =
1129+ createConvertEscapeToNoEscape (NormalDifferentiableFunctionTypeComponent::Original);
1130+ SILValue convertedJVP = createConvertEscapeToNoEscape (
1131+ NormalDifferentiableFunctionTypeComponent::JVP);
1132+ SILValue convertedVJP = createConvertEscapeToNoEscape (
1133+ NormalDifferentiableFunctionTypeComponent::VJP);
1134+
1135+ llvm::Optional<std::pair<SILValue, SILValue>> derivativeFunctions;
1136+ if (convertedJVP && convertedVJP)
1137+ derivativeFunctions = std::make_pair (convertedJVP, convertedVJP);
11451138
1139+ auto *newDFI = Builder.createDifferentiableFunction (
1140+ DFI->getLoc (), DFI->getParameterIndices (), DFI->getResultIndices (),
1141+ originalNoEscape, derivativeFunctions);
1142+ assert (newDFI->getType () == Cvt->getType () &&
1143+ " New `@differentiable` function instruction should have same type "
1144+ " as the old `convert_escape_to_no_escape` instruction" );
1145+ return newDFI;
1146+ }
1147+ }
1148+
11461149 return nullptr ;
11471150}
11481151
0 commit comments