@@ -1331,6 +1331,7 @@ void PullbackEmitter::visitSILInstruction(SILInstruction *inst) {
13311331AllocStackInst *
13321332PullbackEmitter::getArrayAdjointElementBuffer (SILValue arrayAdjoint,
13331333 int eltIndex, SILLocation loc) {
1334+ auto &ctx = builder.getASTContext ();
13341335 auto arrayTanType = cast<StructType>(arrayAdjoint->getType ().getASTType ());
13351336 auto arrayType = arrayTanType->getParent ()->castTo <BoundGenericStructType>();
13361337 auto eltTanType = arrayType->getGenericArgs ().front ()->getCanonicalType ();
@@ -1340,7 +1341,19 @@ PullbackEmitter::getArrayAdjointElementBuffer(SILValue arrayAdjoint,
13401341 auto *arrayTanStructDecl = arrayTanType->getStructOrBoundGenericStruct ();
13411342 auto subscriptLookup =
13421343 arrayTanStructDecl->lookupDirect (DeclBaseName::createSubscript ());
1343- auto *subscriptDecl = cast<SubscriptDecl>(subscriptLookup.front ());
1344+ SubscriptDecl *subscriptDecl = nullptr ;
1345+ for (auto *candidate : subscriptLookup) {
1346+ auto candidateModule = candidate->getModuleContext ();
1347+ if (candidateModule->getName () == ctx.Id_Differentiation ||
1348+ candidateModule->isStdlibModule ()) {
1349+ assert (!subscriptDecl && " Multiple `Array.TangentVector.subscript`s" );
1350+ subscriptDecl = cast<SubscriptDecl>(candidate);
1351+ #ifdef NDEBUG
1352+ break ;
1353+ #endif
1354+ }
1355+ }
1356+ assert (subscriptDecl && " No `Array.TangentVector.subscript`" );
13441357 auto *subscriptGetterDecl = subscriptDecl->getAccessor (AccessorKind::Get);
13451358 assert (subscriptGetterDecl && " No `Array.TangentVector.subscript` getter" );
13461359 SILOptFunctionBuilder fb (getContext ().getTransform ());
@@ -1352,7 +1365,6 @@ PullbackEmitter::getArrayAdjointElementBuffer(SILValue arrayAdjoint,
13521365 subscriptGetterFn->getLoweredFunctionType ()->getSubstGenericSignature ();
13531366 // Apply `Array.TangentVector.subscript.getter` to get array element adjoint
13541367 // buffer.
1355- auto &ctx = builder.getASTContext ();
13561368 // %index_literal = integer_literal $Builtin.IntXX, <index>
13571369 auto builtinIntType =
13581370 SILType::getPrimitiveObjectType (ctx.getIntDecl ()
0 commit comments