@@ -287,7 +287,7 @@ mlir::ChangeResult LatticePoint::join(const AbstractDenseLattice &lattice) {
287287
288288void LatticePoint::print (llvm::raw_ostream &os) const {
289289 for (const auto &[value, state] : stateMap) {
290- os << value << " : " ;
290+ os << " \n * " << value << " : " ;
291291 ::print (os, state);
292292 }
293293}
@@ -361,6 +361,13 @@ void AllocationAnalysis::visitOperation(mlir::Operation *op,
361361 } else if (mlir::isa<fir::FreeMemOp>(op)) {
362362 assert (op->getNumOperands () == 1 && " fir.freemem has one operand" );
363363 mlir::Value operand = op->getOperand (0 );
364+
365+ // Note: StackArrays is scheduled in the pass pipeline after lowering hlfir
366+ // to fir. Therefore, we only need to handle `fir::DeclareOp`s.
367+ if (auto declareOp =
368+ llvm::dyn_cast_if_present<fir::DeclareOp>(operand.getDefiningOp ()))
369+ operand = declareOp.getMemref ();
370+
364371 std::optional<AllocationState> operandState = before.get (operand);
365372 if (operandState && *operandState == AllocationState::Allocated) {
366373 // don't tag things not allocated in this function as freed, so that we
@@ -452,6 +459,9 @@ StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
452459 };
453460 func->walk ([&](mlir::func::ReturnOp child) { joinOperationLattice (child); });
454461 func->walk ([&](fir::UnreachableOp child) { joinOperationLattice (child); });
462+ func->walk (
463+ [&](mlir::omp::TerminatorOp child) { joinOperationLattice (child); });
464+
455465 llvm::DenseSet<mlir::Value> freedValues;
456466 point.appendFreedValues (freedValues);
457467
@@ -518,9 +528,18 @@ AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem,
518528
519529 // remove freemem operations
520530 llvm::SmallVector<mlir::Operation *> erases;
521- for (mlir::Operation *user : allocmem.getOperation ()->getUsers ())
531+ for (mlir::Operation *user : allocmem.getOperation ()->getUsers ()) {
532+ if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
533+ for (mlir::Operation *user : declareOp->getUsers ()) {
534+ if (mlir::isa<fir::FreeMemOp>(user))
535+ erases.push_back (user);
536+ }
537+ }
538+
522539 if (mlir::isa<fir::FreeMemOp>(user))
523540 erases.push_back (user);
541+ }
542+
524543 // now we are done iterating the users, it is safe to mutate them
525544 for (mlir::Operation *erase : erases)
526545 rewriter.eraseOp (erase);
@@ -633,9 +652,19 @@ AllocMemConversion::findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc) {
633652
634653 // find freemem ops
635654 llvm::SmallVector<mlir::Operation *, 1 > freeOps;
636- for (mlir::Operation *user : oldAllocOp->getUsers ())
655+
656+ for (mlir::Operation *user : oldAllocOp->getUsers ()) {
657+ if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
658+ for (mlir::Operation *user : declareOp->getUsers ()) {
659+ if (mlir::isa<fir::FreeMemOp>(user))
660+ freeOps.push_back (user);
661+ }
662+ }
663+
637664 if (mlir::isa<fir::FreeMemOp>(user))
638665 freeOps.push_back (user);
666+ }
667+
639668 assert (freeOps.size () && " DFA should only return freed memory" );
640669
641670 // Don't attempt to reason about a stacksave/stackrestore between different
@@ -717,12 +746,23 @@ void AllocMemConversion::insertStackSaveRestore(
717746 mlir::SymbolRefAttr stackRestoreSym =
718747 builder.getSymbolRefAttr (stackRestoreFn.getName ());
719748
749+ auto createStackRestoreCall = [&](mlir::Operation *user) {
750+ builder.setInsertionPoint (user);
751+ builder.create <fir::CallOp>(user->getLoc (),
752+ stackRestoreFn.getFunctionType ().getResults (),
753+ stackRestoreSym, mlir::ValueRange{sp});
754+ };
755+
720756 for (mlir::Operation *user : oldAlloc->getUsers ()) {
757+ if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
758+ for (mlir::Operation *user : declareOp->getUsers ()) {
759+ if (mlir::isa<fir::FreeMemOp>(user))
760+ createStackRestoreCall (user);
761+ }
762+ }
763+
721764 if (mlir::isa<fir::FreeMemOp>(user)) {
722- builder.setInsertionPoint (user);
723- builder.create <fir::CallOp>(user->getLoc (),
724- stackRestoreFn.getFunctionType ().getResults (),
725- stackRestoreSym, mlir::ValueRange{sp});
765+ createStackRestoreCall (user);
726766 }
727767 }
728768
0 commit comments