@@ -59,7 +59,7 @@ LinearMapInfo::LinearMapInfo(ADContext &context, AutoDiffLinearMapKind kind,
5959 : kind(kind), original(original), derivative(derivative),
6060 activityInfo (activityInfo), indices(indices),
6161 typeConverter(context.getTypeConverter()) {
62- generateDifferentiationDataStructures (context, indices, derivative);
62+ generateDifferentiationDataStructures (context, derivative);
6363}
6464
6565SILType LinearMapInfo::remapTypeInDerivative (SILType ty) {
@@ -122,27 +122,24 @@ void LinearMapInfo::computeAccessLevel(NominalTypeDecl *nominal,
122122 }
123123}
124124
125- EnumDecl *LinearMapInfo::createBranchingTraceDecl (
126- SILBasicBlock *originalBB, SILAutoDiffIndices indices,
127- CanGenericSignature genericSig, SILLoopInfo *loopInfo) {
125+ EnumDecl *
126+ LinearMapInfo::createBranchingTraceDecl (SILBasicBlock *originalBB,
127+ CanGenericSignature genericSig,
128+ SILLoopInfo *loopInfo) {
128129 assert (originalBB->getParent () == original);
129130 auto &astCtx = original->getASTContext ();
130131 auto *moduleDecl = original->getModule ().getSwiftModule ();
131132 auto &file = getDeclarationFileUnit ();
132133 // Create a branching trace enum.
133- std::string enumName;
134- switch (kind) {
135- case AutoDiffLinearMapKind::Differential:
136- enumName = " _AD__" + original->getName ().str () + " _bb" +
137- std::to_string (originalBB->getDebugID ()) + " __Succ__" +
138- indices.mangle ();
139- break ;
140- case AutoDiffLinearMapKind::Pullback:
141- enumName = " _AD__" + original->getName ().str () + " _bb" +
142- std::to_string (originalBB->getDebugID ()) + " __Pred__" +
143- indices.mangle ();
144- break ;
145- }
134+ Mangle::ASTMangler mangler;
135+ auto *resultIndices = IndexSubset::get (
136+ original->getASTContext (),
137+ original->getLoweredFunctionType ()->getNumResults (), indices.source );
138+ auto *parameterIndices = indices.parameters ;
139+ AutoDiffConfig config (parameterIndices, resultIndices, genericSig);
140+ auto enumName = mangler.mangleAutoDiffGeneratedDeclaration (
141+ AutoDiffGeneratedDeclarationKind::BranchingTraceEnum,
142+ original->getName ().str (), originalBB->getDebugID (), kind, config);
146143 auto enumId = astCtx.getIdentifier (enumName);
147144 auto loc = original->getLocation ().getSourceLoc ();
148145 GenericParamList *genericParams = nullptr ;
@@ -199,25 +196,21 @@ EnumDecl *LinearMapInfo::createBranchingTraceDecl(
199196
200197StructDecl *
201198LinearMapInfo::createLinearMapStruct (SILBasicBlock *originalBB,
202- SILAutoDiffIndices indices,
203199 CanGenericSignature genericSig) {
204200 assert (originalBB->getParent () == original);
205201 auto *original = originalBB->getParent ();
206202 auto &astCtx = original->getASTContext ();
207203 auto &file = getDeclarationFileUnit ();
208- std::string structName;
209- switch (kind) {
210- case swift::AutoDiffLinearMapKind::Differential:
211- structName = " _AD__" + original->getName ().str () + " _bb" +
212- std::to_string (originalBB->getDebugID ()) + " __DF__" +
213- indices.mangle ();
214- break ;
215- case swift::AutoDiffLinearMapKind::Pullback:
216- structName = " _AD__" + original->getName ().str () + " _bb" +
217- std::to_string (originalBB->getDebugID ()) + " __PB__" +
218- indices.mangle ();
219- break ;
220- }
204+ // Create a linear map struct.
205+ Mangle::ASTMangler mangler;
206+ auto *resultIndices = IndexSubset::get (
207+ original->getASTContext (),
208+ original->getLoweredFunctionType ()->getNumResults (), indices.source );
209+ auto *parameterIndices = indices.parameters ;
210+ AutoDiffConfig config (parameterIndices, resultIndices, genericSig);
211+ auto structName = mangler.mangleAutoDiffGeneratedDeclaration (
212+ AutoDiffGeneratedDeclarationKind::LinearMapStruct,
213+ original->getName ().str (), originalBB->getDebugID (), kind, config);
221214 auto structId = astCtx.getIdentifier (structName);
222215 GenericParamList *genericParams = nullptr ;
223216 if (genericSig)
@@ -274,8 +267,7 @@ VarDecl *LinearMapInfo::addLinearMapDecl(ApplyInst *ai, SILType linearMapType) {
274267 return linearMapDecl;
275268}
276269
277- void LinearMapInfo::addLinearMapToStruct (ADContext &context, ApplyInst *ai,
278- SILAutoDiffIndices indices) {
270+ void LinearMapInfo::addLinearMapToStruct (ADContext &context, ApplyInst *ai) {
279271 SmallVector<SILValue, 4 > allResults;
280272 SmallVector<unsigned , 8 > activeParamIndices;
281273 SmallVector<unsigned , 8 > activeResultIndices;
@@ -379,7 +371,7 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
379371}
380372
381373void LinearMapInfo::generateDifferentiationDataStructures (
382- ADContext &context, SILAutoDiffIndices indices, SILFunction *derivativeFn) {
374+ ADContext &context, SILFunction *derivativeFn) {
383375 auto &astCtx = original->getASTContext ();
384376 auto *loopAnalysis = context.getPassManager ().getAnalysis <SILLoopAnalysis>();
385377 auto *loopInfo = loopAnalysis->get (original);
@@ -392,8 +384,7 @@ void LinearMapInfo::generateDifferentiationDataStructures(
392384
393385 // Create linear map struct for each original block.
394386 for (auto &origBB : *original) {
395- auto *linearMapStruct =
396- createLinearMapStruct (&origBB, indices, derivativeFnGenSig);
387+ auto *linearMapStruct = createLinearMapStruct (&origBB, derivativeFnGenSig);
397388 linearMapStructs.insert ({&origBB, linearMapStruct});
398389 }
399390
@@ -409,8 +400,8 @@ void LinearMapInfo::generateDifferentiationDataStructures(
409400 break ;
410401 }
411402 for (auto &origBB : *original) {
412- auto *traceEnum = createBranchingTraceDecl (&origBB, indices,
413- derivativeFnGenSig, loopInfo);
403+ auto *traceEnum =
404+ createBranchingTraceDecl (&origBB, derivativeFnGenSig, loopInfo);
414405 branchingTraceDecls.insert ({&origBB, traceEnum});
415406 if (origBB.isEntry ())
416407 continue ;
@@ -433,7 +424,7 @@ void LinearMapInfo::generateDifferentiationDataStructures(
433424 continue ;
434425 LLVM_DEBUG (getADDebugStream ()
435426 << " Adding linear map struct field for " << *ai);
436- addLinearMapToStruct (context, ai, indices );
427+ addLinearMapToStruct (context, ai);
437428 }
438429 }
439430 }
0 commit comments