@@ -2095,6 +2095,23 @@ PullbackCloner::PullbackCloner(VJPCloner &vjpCloner)
20952095
20962096PullbackCloner::~PullbackCloner () { delete &impl; }
20972097
2098+ static SILValue getArrayValue (ApplyInst *ai) {
2099+ SILValue arrayValue;
2100+ for (auto use : ai->getUses ()) {
2101+ auto *dti = dyn_cast<DestructureTupleInst>(use->getUser ());
2102+ if (!dti)
2103+ continue ;
2104+ DEBUG_ASSERT (!arrayValue && " Array value already found" );
2105+ // The first `destructure_tuple` result is the `Array` value.
2106+ arrayValue = dti->getResult (0 );
2107+ #ifndef DEBUG_ASSERT_enabled
2108+ break ;
2109+ #endif
2110+ }
2111+ ASSERT (arrayValue);
2112+ return arrayValue;
2113+ }
2114+
20982115// --------------------------------------------------------------------------//
20992116// Entry point
21002117// --------------------------------------------------------------------------//
@@ -2439,6 +2456,134 @@ bool PullbackCloner::Implementation::run() {
24392456 // Visit original blocks in post-order and perform differentiation
24402457 // in corresponding pullback blocks. If errors occurred, back out.
24412458 else {
2459+ LLVM_DEBUG (getADDebugStream ()
2460+ << " Begin search for adjoints of loop-local active values\n " );
2461+ llvm::DenseMap<const SILLoop *, llvm::DenseSet<SILValue>>
2462+ loopLocalActiveValues;
2463+ for (auto *bb : originalBlocks) {
2464+ const SILLoop *loop = vjpCloner.getLoopInfo ()->getLoopFor (bb);
2465+ if (loop == nullptr )
2466+ continue ;
2467+ SILBasicBlock *loopHeader = loop->getHeader ();
2468+ SILBasicBlock *pbLoopHeader = getPullbackBlock (loopHeader);
2469+ LLVM_DEBUG (getADDebugStream ()
2470+ << " Original bb" << bb->getDebugID ()
2471+ << " belongs to a loop, original header bb"
2472+ << loopHeader->getDebugID () << " , pullback header bb"
2473+ << pbLoopHeader->getDebugID () << ' \n ' );
2474+ builder.setInsertionPoint (pbLoopHeader);
2475+ auto bbActiveValuesIt = activeValues.find (bb);
2476+ if (bbActiveValuesIt == activeValues.end ())
2477+ continue ;
2478+ const auto &bbActiveValues = bbActiveValuesIt->second ;
2479+ for (SILValue bbActiveValue : bbActiveValues) {
2480+ if (vjpCloner.getLoopInfo ()->getLoopFor (
2481+ bbActiveValue->getParentBlock ()) != loop) {
2482+ LLVM_DEBUG (
2483+ getADDebugStream ()
2484+ << " The following active value is NOT loop-local, skipping: "
2485+ << bbActiveValue);
2486+ continue ;
2487+ }
2488+
2489+ auto [_, wasInserted] =
2490+ loopLocalActiveValues[loop].insert (bbActiveValue);
2491+ LLVM_DEBUG (getADDebugStream ()
2492+ << " The following active value is loop-local, " );
2493+ if (!wasInserted) {
2494+ LLVM_DEBUG (llvm::dbgs () << " but it was already processed, skipping: "
2495+ << bbActiveValue);
2496+ continue ;
2497+ }
2498+
2499+ if (getTangentValueCategory (bbActiveValue) ==
2500+ SILValueCategory::Object) {
2501+ LLVM_DEBUG (llvm::dbgs ()
2502+ << " zeroing its adjoint value in loop header: "
2503+ << bbActiveValue);
2504+ setAdjointValue (bb, bbActiveValue,
2505+ makeZeroAdjointValue (getRemappedTangentType (
2506+ bbActiveValue->getType ())));
2507+ continue ;
2508+ }
2509+
2510+ ASSERT (getTangentValueCategory (bbActiveValue) ==
2511+ SILValueCategory::Address);
2512+
2513+ // getAdjointProjection might call materializeAdjointDirect which
2514+ // writes to debug output, emit \n.
2515+ LLVM_DEBUG (llvm::dbgs ()
2516+ << " checking if it's adjoint is a projection\n " );
2517+
2518+ if (!getAdjointProjection (bb, bbActiveValue)) {
2519+ LLVM_DEBUG (getADDebugStream ()
2520+ << " Adjoint for the following value is NOT a projection, "
2521+ " zeroing its adjoint buffer in loop header: "
2522+ << bbActiveValue);
2523+
2524+ // All adjoint buffers are allocated in the pullback entry and
2525+ // deallocated in the pullback exit. So, use IsNotInitialization to
2526+ // emit destroy_addr before zeroing the buffer.
2527+ ASSERT (bufferMap.contains ({bb, bbActiveValue}));
2528+ builder.emitZeroIntoBuffer (pbLoc, getAdjointBuffer (bb, bbActiveValue),
2529+ IsNotInitialization);
2530+
2531+ continue ;
2532+ }
2533+
2534+ LLVM_DEBUG (getADDebugStream ()
2535+ << " Adjoint for the following value is a projection, " );
2536+
2537+ // If Projection::isAddressProjection(v) is true for a value v, it
2538+ // is not added to active values list (see recordValueIfActive).
2539+ //
2540+ // Ensure that only the following value types conforming to
2541+ // getAdjointProjection but not conforming to
2542+ // Projection::isAddressProjection can go here.
2543+ //
2544+ // Instructions conforming to Projection::isAddressProjection and
2545+ // thus never corresponding to an active value do not need any
2546+ // handling, because only active values can have adjoints from
2547+ // previous iterations propagated via BB arguments.
2548+ do {
2549+ // Consider '%X = begin_access [modify] [static] %Y'.
2550+ // 1. If %Y is loop-local, its adjoint buffer will
2551+ // be zeroed, and we'll have zero adjoint projection to it.
2552+ // 2. Otherwise, we do not need to zero the projection buffer.
2553+ // Thus, we can just skip.
2554+ if (dyn_cast<BeginAccessInst>(bbActiveValue)) {
2555+ LLVM_DEBUG (llvm::dbgs () << " skipping: " << bbActiveValue);
2556+ break ;
2557+ }
2558+
2559+ // Consider the following sequence:
2560+ // %1 = function_ref @allocUninitArray
2561+ // %2 = apply %1<Float>(%0)
2562+ // (%3, %4) = destructure_tuple %2
2563+ // %5 = mark_dependence %4 on %3
2564+ // %6 = pointer_to_address %5 to [strict] $*Float
2565+ // Since %6 is active, %3 (which is an array) must also be active.
2566+ // Thus, adjoint for %3 will be zeroed if needed. Ensure that expected
2567+ // invariants hold and then skip.
2568+ if (auto *ai = getAllocateUninitializedArrayIntrinsicElementAddress (
2569+ bbActiveValue)) {
2570+ ASSERT (isa<PointerToAddressInst>(bbActiveValue));
2571+ SILValue arrayValue = getArrayValue (ai);
2572+ ASSERT (llvm::find (bbActiveValues, arrayValue) !=
2573+ bbActiveValues.end ());
2574+ ASSERT (vjpCloner.getLoopInfo ()->getLoopFor (
2575+ arrayValue->getParentBlock ()) == loop);
2576+ LLVM_DEBUG (llvm::dbgs () << " skipping: " << bbActiveValue);
2577+ break ;
2578+ }
2579+
2580+ ASSERT (false );
2581+ } while (false );
2582+ }
2583+ }
2584+ LLVM_DEBUG (getADDebugStream ()
2585+ << " End search for adjoints of loop-local active values\n " );
2586+
24422587 for (auto *bb : originalBlocks) {
24432588 visitSILBasicBlock (bb);
24442589 if (errorOccurred)
@@ -3339,19 +3484,9 @@ SILValue PullbackCloner::Implementation::getAdjointProjection(
33393484 eltIndex = ili->getValue ().getLimitedValue ();
33403485 }
33413486 // Get the array adjoint value.
3342- SILValue arrayAdjoint;
3343- assert (ai && " Expected `array.uninitialized_intrinsic` application" );
3344- for (auto use : ai->getUses ()) {
3345- auto *dti = dyn_cast<DestructureTupleInst>(use->getUser ());
3346- if (!dti)
3347- continue ;
3348- assert (!arrayAdjoint && " Array adjoint already found" );
3349- // The first `destructure_tuple` result is the `Array` value.
3350- auto arrayValue = dti->getResult (0 );
3351- arrayAdjoint = materializeAdjointDirect (
3352- getAdjointValue (origBB, arrayValue), definingInst->getLoc ());
3353- }
3354- assert (arrayAdjoint && " Array does not have adjoint value" );
3487+ SILValue arrayValue = getArrayValue (ai);
3488+ SILValue arrayAdjoint = materializeAdjointDirect (
3489+ getAdjointValue (origBB, arrayValue), definingInst->getLoc ());
33553490 // Apply `Array.TangentVector.subscript` to get array element adjoint value.
33563491 auto *eltAdjBuffer =
33573492 getArrayAdjointElementBuffer (arrayAdjoint, eltIndex, ai->getLoc ());
0 commit comments