From 4969ff63268d0da1c5cb6a85142369a8d95772f0 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Tue, 28 Oct 2025 01:05:35 -0700 Subject: [PATCH] Revert " [MLIR] Revamp RegionBranchOpInterface (#161575)" This reverts commit ab1fd21b541056ccd1e0584e438082f417ad3cb4. --- flang/lib/Optimizer/Dialect/FIROps.cpp | 7 +- .../mlir/Analysis/DataFlow/DenseAnalysis.h | 6 +- .../mlir/Analysis/DataFlow/SparseAnalysis.h | 2 +- mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 7 - .../mlir/Dialect/Transform/IR/TransformOps.td | 6 +- .../TuneExtension/TuneExtensionOps.td | 2 +- mlir/include/mlir/IR/Diagnostics.h | 2 - mlir/include/mlir/IR/Operation.h | 1 - mlir/include/mlir/IR/Region.h | 2 - .../mlir/Interfaces/ControlFlowInterfaces.h | 104 ++---- .../mlir/Interfaces/ControlFlowInterfaces.td | 108 +----- .../AliasAnalysis/LocalAliasAnalysis.cpp | 325 ++++++------------ .../Analysis/DataFlow/DeadCodeAnalysis.cpp | 9 +- mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp | 4 +- mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp | 6 +- mlir/lib/Analysis/SliceWalk.cpp | 2 +- mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 50 ++- mlir/lib/Dialect/Async/IR/Async.cpp | 11 +- .../OwnershipBasedBufferDeallocation.cpp | 11 +- mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 8 +- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 2 +- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 2 +- mlir/lib/Dialect/SCF/IR/SCF.cpp | 52 ++- .../lib/Dialect/SCF/Transforms/ForToWhile.cpp | 1 + .../Dialect/SCF/Transforms/ForallToFor.cpp | 1 + mlir/lib/Dialect/Shape/IR/Shape.cpp | 2 +- .../SparseTensor/IR/SparseTensorDialect.cpp | 4 +- .../lib/Dialect/Transform/IR/TransformOps.cpp | 37 +- .../TuneExtension/TuneExtensionOps.cpp | 5 +- mlir/lib/IR/Diagnostics.cpp | 4 - mlir/lib/IR/Region.cpp | 15 - mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 305 +++++----------- mlir/lib/Transforms/RemoveDeadValues.cpp | 25 +- mlir/test/Dialect/SCF/invalid.mlir | 8 +- .../TestDenseBackwardDataFlowAnalysis.cpp | 4 +- mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 26 +- mlir/test/lib/Dialect/Test/TestOps.td | 2 +- .../Interfaces/ControlFlowInterfacesTest.cpp | 38 +- 38 files changed, 378 insertions(+), 828 deletions(-) diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index 4f97acaa88b7a..d0164f32d9b6a 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -4484,7 +4484,7 @@ void fir::IfOp::getSuccessorRegions( llvm::SmallVectorImpl ®ions) { // The `then` and the `else` region branch back to the parent operation. if (!point.isParent()) { - regions.push_back(mlir::RegionSuccessor(getOperation(), getResults())); + regions.push_back(mlir::RegionSuccessor(getResults())); return; } @@ -4494,8 +4494,7 @@ void fir::IfOp::getSuccessorRegions( // Don't consider the else region if it is empty. mlir::Region *elseRegion = &this->getElseRegion(); if (elseRegion->empty()) - regions.push_back( - mlir::RegionSuccessor(getOperation(), getOperation()->getResults())); + regions.push_back(mlir::RegionSuccessor()); else regions.push_back(mlir::RegionSuccessor(elseRegion)); } @@ -4514,7 +4513,7 @@ void fir::IfOp::getEntrySuccessorRegions( if (!getElseRegion().empty()) regions.emplace_back(&getElseRegion()); else - regions.emplace_back(getOperation(), getOperation()->getResults()); + regions.emplace_back(getResults()); } } diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h index 3c87c453a4cf0..8bcfe51ad7cd1 100644 --- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h @@ -397,7 +397,7 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis { /// itself. virtual void visitRegionBranchControlFlowTransfer( RegionBranchOpInterface branch, RegionBranchPoint regionFrom, - RegionSuccessor regionTo, const AbstractDenseLattice &after, + RegionBranchPoint regionTo, const AbstractDenseLattice &after, AbstractDenseLattice *before) { meet(before, after); } @@ -526,7 +526,7 @@ class DenseBackwardDataFlowAnalysis /// and "to" regions. virtual void visitRegionBranchControlFlowTransfer( RegionBranchOpInterface branch, RegionBranchPoint regionFrom, - RegionSuccessor regionTo, const LatticeT &after, LatticeT *before) { + RegionBranchPoint regionTo, const LatticeT &after, LatticeT *before) { AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer( branch, regionFrom, regionTo, after, before); } @@ -571,7 +571,7 @@ class DenseBackwardDataFlowAnalysis } void visitRegionBranchControlFlowTransfer( RegionBranchOpInterface branch, RegionBranchPoint regionForm, - RegionSuccessor regionTo, const AbstractDenseLattice &after, + RegionBranchPoint regionTo, const AbstractDenseLattice &after, AbstractDenseLattice *before) final { visitRegionBranchControlFlowTransfer(branch, regionForm, regionTo, static_cast(after), diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h index 985573476ab78..1a33ecf8b5aa9 100644 --- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h @@ -286,7 +286,7 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis { /// and propagating therefrom. virtual void visitRegionSuccessors(ProgramPoint *point, RegionBranchOpInterface branch, - RegionSuccessor successor, + RegionBranchPoint successor, ArrayRef lattices); }; diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index 48690151caf01..fadd3fc10bfc4 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -644,13 +644,6 @@ def ForallOp : SCF_Op<"forall", [ /// Returns true if the mapping specified for this forall op is linear. bool usesLinearMapping(); - - /// RegionBranchOpInterface - - OperandRange getEntrySuccessorOperands(RegionSuccessor successor) { - return getInits(); - } - }]; } diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index ed69287410509..62e66b3dabee8 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -25,7 +25,7 @@ include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" def AlternativesOp : TransformDialectOp<"alternatives", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, @@ -624,7 +624,7 @@ def ForeachOp : TransformDialectOp<"foreach", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, + "getSuccessorRegions", "getEntrySuccessorOperands"]>, SingleBlockImplicitTerminator<"::mlir::transform::YieldOp"> ]> { let summary = "Executes the body for each element of the payload"; @@ -1237,7 +1237,7 @@ def SelectOp : TransformDialectOp<"select", def SequenceOp : TransformDialectOp<"sequence", [DeclareOpInterfaceMethods, MatchOpInterface, DeclareOpInterfaceMethods, diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td index 4079848fd203a..d095659fc4838 100644 --- a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td +++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td @@ -63,7 +63,7 @@ def KnobOp : Op, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h index a0a99f4953822..7ff718ad7f241 100644 --- a/mlir/include/mlir/IR/Diagnostics.h +++ b/mlir/include/mlir/IR/Diagnostics.h @@ -29,7 +29,6 @@ class MLIRContext; class Operation; class OperationName; class OpPrintingFlags; -class OpWithFlags; class Type; class Value; @@ -200,7 +199,6 @@ class Diagnostic { /// Stream in an Operation. Diagnostic &operator<<(Operation &op); - Diagnostic &operator<<(OpWithFlags op); Diagnostic &operator<<(Operation *op) { return *this << *op; } /// Append an operation with the given printing flags. Diagnostic &appendOp(Operation &op, const OpPrintingFlags &flags); diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index b2019574a820d..5569392cf0b41 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -1114,7 +1114,6 @@ class OpWithFlags { : op(op), theFlags(flags) {} OpPrintingFlags &flags() { return theFlags; } const OpPrintingFlags &flags() const { return theFlags; } - Operation *getOperation() const { return op; } private: Operation *op; diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h index 53d461df98710..1fcb316750230 100644 --- a/mlir/include/mlir/IR/Region.h +++ b/mlir/include/mlir/IR/Region.h @@ -379,8 +379,6 @@ class RegionRange friend RangeBaseT; }; -llvm::raw_ostream &operator<<(llvm::raw_ostream &os, Region ®ion); - } // namespace mlir #endif // MLIR_IR_REGION_H diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h index 47afd252c6d68..d63800c12d132 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h @@ -15,16 +15,10 @@ #define MLIR_INTERFACES_CONTROLFLOWINTERFACES_H #include "mlir/IR/OpDefinition.h" -#include "mlir/IR/Operation.h" -#include "llvm/ADT/PointerUnion.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/Support/DebugLog.h" -#include "llvm/Support/raw_ostream.h" namespace mlir { class BranchOpInterface; class RegionBranchOpInterface; -class RegionBranchTerminatorOpInterface; /// This class models how operands are forwarded to block arguments in control /// flow. It consists of a number, denoting how many of the successors block @@ -192,40 +186,27 @@ class RegionSuccessor { public: /// Initialize a successor that branches to another region of the parent /// operation. - /// TODO: the default value for the regionInputs is somehow broken. - /// A region successor should have its input correctly set. RegionSuccessor(Region *region, Block::BlockArgListType regionInputs = {}) - : successor(region), inputs(regionInputs) { - assert(region && "Region must not be null"); - } + : region(region), inputs(regionInputs) {} /// Initialize a successor that branches back to/out of the parent operation. - /// The target must be one of the recursive parent operations. - RegionSuccessor(Operation *successorOp, Operation::result_range results) - : successor(successorOp), inputs(ValueRange(results)) { - assert(successorOp && "Successor op must not be null"); - } + RegionSuccessor(Operation::result_range results) + : inputs(ValueRange(results)) {} + /// Constructor with no arguments. + RegionSuccessor() : inputs(ValueRange()) {} /// Return the given region successor. Returns nullptr if the successor is the /// parent operation. - Region *getSuccessor() const { return dyn_cast(successor); } + Region *getSuccessor() const { return region; } /// Return true if the successor is the parent operation. - bool isParent() const { return isa(successor); } + bool isParent() const { return region == nullptr; } /// Return the inputs to the successor that are remapped by the exit values of /// the current region. ValueRange getSuccessorInputs() const { return inputs; } - bool operator==(RegionSuccessor rhs) const { - return successor == rhs.successor && inputs == rhs.inputs; - } - - friend bool operator!=(RegionSuccessor lhs, RegionSuccessor rhs) { - return !(lhs == rhs); - } - private: - llvm::PointerUnion successor{nullptr}; + Region *region{nullptr}; ValueRange inputs; }; @@ -233,67 +214,64 @@ class RegionSuccessor { /// `RegionBranchOpInterface`. /// One can branch from one of two kinds of places: /// * The parent operation (aka the `RegionBranchOpInterface` implementation) -/// * A RegionBranchTerminatorOpInterface inside a region within the parent -// operation. +/// * A region within the parent operation. class RegionBranchPoint { public: /// Returns an instance of `RegionBranchPoint` representing the parent /// operation. static constexpr RegionBranchPoint parent() { return RegionBranchPoint(); } - /// Creates a `RegionBranchPoint` that branches from the given terminator. - inline RegionBranchPoint(RegionBranchTerminatorOpInterface predecessor); + /// Creates a `RegionBranchPoint` that branches from the given region. + /// The pointer must not be null. + RegionBranchPoint(Region *region) : maybeRegion(region) { + assert(region && "Region must not be null"); + } + + RegionBranchPoint(Region ®ion) : RegionBranchPoint(®ion) {} /// Explicitly stops users from constructing with `nullptr`. RegionBranchPoint(std::nullptr_t) = delete; + /// Constructs a `RegionBranchPoint` from the the target of a + /// `RegionSuccessor` instance. + RegionBranchPoint(RegionSuccessor successor) { + if (successor.isParent()) + maybeRegion = nullptr; + else + maybeRegion = successor.getSuccessor(); + } + + /// Assigns a region being branched from. + RegionBranchPoint &operator=(Region ®ion) { + maybeRegion = ®ion; + return *this; + } + /// Returns true if branching from the parent op. - bool isParent() const { return predecessor == nullptr; } + bool isParent() const { return maybeRegion == nullptr; } - /// Returns the terminator if branching from a region. + /// Returns the region if branching from a region. /// A null pointer otherwise. - Operation *getTerminatorPredecessorOrNull() const { return predecessor; } + Region *getRegionOrNull() const { return maybeRegion; } /// Returns true if the two branch points are equal. friend bool operator==(RegionBranchPoint lhs, RegionBranchPoint rhs) { - return lhs.predecessor == rhs.predecessor; + return lhs.maybeRegion == rhs.maybeRegion; } private: // Private constructor to encourage the use of `RegionBranchPoint::parent`. - constexpr RegionBranchPoint() = default; + constexpr RegionBranchPoint() : maybeRegion(nullptr) {} /// Internal encoding. Uses nullptr for representing branching from the parent - /// op and the region terminator being branched from otherwise. - Operation *predecessor = nullptr; + /// op and the region being branched from otherwise. + Region *maybeRegion; }; inline bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs) { return !(lhs == rhs); } -inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, - RegionBranchPoint point) { - if (point.isParent()) - return os << ""; - return os << "getParentRegion() - ->getRegionNumber() - << ", terminator " - << OpWithFlags(point.getTerminatorPredecessorOrNull(), - OpPrintingFlags().skipRegions()) - << ">"; -} - -inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, - RegionSuccessor successor) { - if (successor.isParent()) - return os << ""; - return os << "getRegionNumber() - << " with " << successor.getSuccessorInputs().size() << " inputs>"; -} - /// This class represents upper and lower bounds on the number of times a region /// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least /// zero, but the upper bound may not be known. @@ -370,10 +348,4 @@ struct ReturnLike : public TraitBase { /// Include the generated interface declarations. #include "mlir/Interfaces/ControlFlowInterfaces.h.inc" -namespace mlir { -inline RegionBranchPoint::RegionBranchPoint( - RegionBranchTerminatorOpInterface predecessor) - : predecessor(predecessor.getOperation()) {} -} // namespace mlir - #endif // MLIR_INTERFACES_CONTROLFLOWINTERFACES_H diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td index 94242e3ba39ce..b8d08cc553caa 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -117,7 +117,7 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> { def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> { let description = [{ - This interface provides information for region-holding operations that exhibit + This interface provides information for region operations that exhibit branching behavior between held regions. I.e., this interface allows for expressing control flow information for region holding operations. @@ -126,12 +126,12 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> { be side-effect free. A "region branch point" indicates a point from which a branch originates. It - can indicate either a terminator in any of the immediately nested region of - this op or `RegionBranchPoint::parent()`. In the latter case, the branch - originates from outside of the op, i.e., when first executing this op. + can indicate either a region of this op or `RegionBranchPoint::parent()`. In + the latter case, the branch originates from outside of the op, i.e., when + first executing this op. A "region successor" indicates the target of a branch. It can indicate - either a region of this op or this op itself. In the former case, the region + either a region of this op or this op. In the former case, the region successor is a region pointer and a range of block arguments to which the "successor operands" are forwarded to. In the latter case, the control flow leaves this op and the region successor is a range of results of this op to @@ -151,10 +151,10 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> { } ``` - `scf.for` has one region. The `scf.yield` has two region successors: the - region body itself and the `scf.for` op. `%b` is an entry successor - operand. `%c` is a successor operand. `%a` is a successor block argument. - `%r` is a successor result. + `scf.for` has one region. The region has two region successors: the region + itself and the `scf.for` op. %b is an entry successor operand. %c is a + successor operand. %a is a successor block argument. %r is a successor + result. }]; let cppNamespace = "::mlir"; @@ -162,16 +162,16 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> { InterfaceMethod<[{ Returns the operands of this operation that are forwarded to the region successor's block arguments or this operation's results when branching - to `successor`. `successor` is guaranteed to be among the successors that are + to `point`. `point` is guaranteed to be among the successors that are returned by `getEntrySuccessorRegions`/`getSuccessorRegions(parent())`. Example: In the above example, this method returns the operand %b of the - `scf.for` op, regardless of the value of `successor`. I.e., this op always + `scf.for` op, regardless of the value of `point`. I.e., this op always forwards the same operands, regardless of whether the loop has 0 or more iterations. }], "::mlir::OperandRange", "getEntrySuccessorOperands", - (ins "::mlir::RegionSuccessor":$successor), [{}], + (ins "::mlir::RegionBranchPoint":$point), [{}], /*defaultImplementation=*/[{ auto operandEnd = this->getOperation()->operand_end(); return ::mlir::OperandRange(operandEnd, operandEnd); @@ -224,80 +224,6 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> { (ins "::mlir::RegionBranchPoint":$point, "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions) >, - InterfaceMethod<[{ - Returns the potential region successors when branching from any - terminator in `region`. - These are the regions that may be selected during the flow of control. - }], - "void", "getSuccessorRegions", - (ins "::mlir::Region&":$region, - "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions), - [{}], - /*defaultImplementation=*/[{ - for (::mlir::Block &block : region) { - if (block.empty()) - continue; - if (auto terminator = - dyn_cast(block.back())) - $_op.getSuccessorRegions(RegionBranchPoint(terminator), - regions); - } - }]>, - InterfaceMethod<[{ - Returns the potential branching point (predecessors) for a given successor. - }], - "void", "getPredecessors", - (ins "::mlir::RegionSuccessor":$successor, - "::llvm::SmallVectorImpl<::mlir::RegionBranchPoint> &":$predecessors), - [{}], - /*defaultImplementation=*/[{ - ::llvm::SmallVector<::mlir::RegionSuccessor> successors; - $_op.getSuccessorRegions(RegionBranchPoint::parent(), - successors); - if (llvm::any_of(successors, [&] (const RegionSuccessor & succ) { - return succ.getSuccessor() == successor.getSuccessor() || - (succ.isParent() && successor.isParent()); - })) - predecessors.push_back(RegionBranchPoint::parent()); - for (Region ®ion : $_op->getRegions()) { - for (::mlir::Block &block : region) { - if (block.empty()) - continue; - if (auto terminator = - dyn_cast(block.back())) { - ::llvm::SmallVector<::mlir::RegionSuccessor> successors; - $_op.getSuccessorRegions(RegionBranchPoint(terminator), - successors); - if (llvm::any_of(successors, [&] (const RegionSuccessor & succ) { - return succ.getSuccessor() == successor.getSuccessor() || - (succ.isParent() && successor.isParent()); - })) - predecessors.push_back(terminator); - } - } - } - }]>, - InterfaceMethod<[{ - Returns the potential values across all (predecessors) for a given successor - input, modeled by its index (its position in the list of values). - }], - "void", "getPredecessorValues", - (ins "::mlir::RegionSuccessor":$successor, - "int":$index, - "::llvm::SmallVectorImpl<::mlir::Value> &":$predecessorValues), - [{}], - /*defaultImplementation=*/[{ - ::llvm::SmallVector<::mlir::RegionBranchPoint> predecessors; - $_op.getPredecessors(successor, predecessors); - for (auto predecessor : predecessors) { - if (predecessor.isParent()) { - predecessorValues.push_back($_op.getEntrySuccessorOperands(successor)[index]); - continue; - } - auto terminator = cast(predecessor.getTerminatorPredecessorOrNull()); - predecessorValues.push_back(terminator.getSuccessorOperands(successor)[index]); - } - }]>, InterfaceMethod<[{ Populates `invocationBounds` with the minimum and maximum number of times this operation will invoke the attached regions (assuming the @@ -372,7 +298,7 @@ def RegionBranchTerminatorOpInterface : passing them to the region successor indicated by `point`. }], "::mlir::MutableOperandRange", "getMutableSuccessorOperands", - (ins "::mlir::RegionSuccessor":$point) + (ins "::mlir::RegionBranchPoint":$point) >, InterfaceMethod<[{ Returns the potential region successors that are branched to after this @@ -391,7 +317,7 @@ def RegionBranchTerminatorOpInterface : /*defaultImplementation=*/[{ ::mlir::Operation *op = $_op; ::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp()) - .getSuccessorRegions(::llvm::cast<::mlir::RegionBranchTerminatorOpInterface>(op), regions); + .getSuccessorRegions(op->getParentRegion(), regions); }] >, ]; @@ -411,8 +337,8 @@ def RegionBranchTerminatorOpInterface : // them to the region successor given by `index`. If `index` is None, this // function returns the operands that are passed as a result to the parent // operation. - ::mlir::OperandRange getSuccessorOperands(::mlir::RegionSuccessor successor) { - return getMutableSuccessorOperands(successor); + ::mlir::OperandRange getSuccessorOperands(::mlir::RegionBranchPoint point) { + return getMutableSuccessorOperands(point); } }]; } @@ -578,7 +504,7 @@ def ReturnLike : TraitList<[ /*extraOpDeclaration=*/"", /*extraOpDefinition=*/[{ ::mlir::MutableOperandRange $cppClass::getMutableSuccessorOperands( - ::mlir::RegionSuccessor successor) { + ::mlir::RegionBranchPoint point) { return ::mlir::MutableOperandRange(*this); } }] diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp index 24cb123e51877..a84d10d5d609d 100644 --- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp +++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp @@ -16,21 +16,19 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/Region.h" #include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "mlir/Support/LLVM.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/DebugLog.h" #include #include #include using namespace mlir; -#define DEBUG_TYPE "local-alias-analysis" - //===----------------------------------------------------------------------===// // Underlying Address Computation //===----------------------------------------------------------------------===// @@ -44,47 +42,81 @@ static void collectUnderlyingAddressValues(Value value, unsigned maxDepth, DenseSet &visited, SmallVectorImpl &output); -/// Given a RegionBranchOpInterface operation (`branch`), a Value`inputValue` -/// which is an input for the provided successor (`initialSuccessor`), try to -/// find the possible sources for the value along the control flow edges. -static void collectUnderlyingAddressValues2( - RegionBranchOpInterface branch, RegionSuccessor initialSuccessor, - Value inputValue, unsigned inputIndex, unsigned maxDepth, - DenseSet &visited, SmallVectorImpl &output) { - LDBG() << "collectUnderlyingAddressValues2: " - << OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions()); - LDBG() << " with initialSuccessor " << initialSuccessor; - LDBG() << " inputValue: " << inputValue; - LDBG() << " inputIndex: " << inputIndex; - LDBG() << " maxDepth: " << maxDepth; - ValueRange inputs = initialSuccessor.getSuccessorInputs(); - if (inputs.empty()) { - LDBG() << " input is empty, enqueue value"; - output.push_back(inputValue); - return; - } - unsigned firstInputIndex, lastInputIndex; - if (isa(inputs[0])) { - firstInputIndex = cast(inputs[0]).getArgNumber(); - lastInputIndex = cast(inputs.back()).getArgNumber(); - } else { - firstInputIndex = cast(inputs[0]).getResultNumber(); - lastInputIndex = cast(inputs.back()).getResultNumber(); - } - if (firstInputIndex > inputIndex || lastInputIndex < inputIndex) { - LDBG() << " !! Input index " << inputIndex << " out of range " - << firstInputIndex << " to " << lastInputIndex - << ", adding input value to output"; - output.push_back(inputValue); - return; +/// Given a successor (`region`) of a RegionBranchOpInterface, collect all of +/// the underlying values being addressed by one of the successor inputs. If the +/// provided `region` is null, as per `RegionBranchOpInterface` this represents +/// the parent operation. +static void collectUnderlyingAddressValues(RegionBranchOpInterface branch, + Region *region, Value inputValue, + unsigned inputIndex, + unsigned maxDepth, + DenseSet &visited, + SmallVectorImpl &output) { + // Given the index of a region of the branch (`predIndex`), or std::nullopt to + // represent the parent operation, try to return the index into the outputs of + // this region predecessor that correspond to the input values of `region`. If + // an index could not be found, std::nullopt is returned instead. + auto getOperandIndexIfPred = + [&](RegionBranchPoint pred) -> std::optional { + SmallVector successors; + branch.getSuccessorRegions(pred, successors); + for (RegionSuccessor &successor : successors) { + if (successor.getSuccessor() != region) + continue; + // Check that the successor inputs map to the given input value. + ValueRange inputs = successor.getSuccessorInputs(); + if (inputs.empty()) { + output.push_back(inputValue); + break; + } + unsigned firstInputIndex, lastInputIndex; + if (region) { + firstInputIndex = cast(inputs[0]).getArgNumber(); + lastInputIndex = cast(inputs.back()).getArgNumber(); + } else { + firstInputIndex = cast(inputs[0]).getResultNumber(); + lastInputIndex = cast(inputs.back()).getResultNumber(); + } + if (firstInputIndex > inputIndex || lastInputIndex < inputIndex) { + output.push_back(inputValue); + break; + } + return inputIndex - firstInputIndex; + } + return std::nullopt; + }; + + // Check branches from the parent operation. + auto branchPoint = RegionBranchPoint::parent(); + if (region) + branchPoint = region; + + if (std::optional operandIndex = + getOperandIndexIfPred(/*predIndex=*/RegionBranchPoint::parent())) { + collectUnderlyingAddressValues( + branch.getEntrySuccessorOperands(branchPoint)[*operandIndex], maxDepth, + visited, output); } - SmallVector predecessorValues; - branch.getPredecessorValues(initialSuccessor, inputIndex - firstInputIndex, - predecessorValues); - LDBG() << " Found " << predecessorValues.size() << " predecessor values"; - for (Value predecessorValue : predecessorValues) { - LDBG() << " Processing predecessor value: " << predecessorValue; - collectUnderlyingAddressValues(predecessorValue, maxDepth, visited, output); + // Check branches from each child region. + Operation *op = branch.getOperation(); + for (Region ®ion : op->getRegions()) { + if (std::optional operandIndex = getOperandIndexIfPred(region)) { + for (Block &block : region) { + // Try to determine possible region-branch successor operands for the + // current region. + if (auto term = dyn_cast( + block.getTerminator())) { + collectUnderlyingAddressValues( + term.getSuccessorOperands(branchPoint)[*operandIndex], maxDepth, + visited, output); + } else if (block.getNumSuccessors()) { + // Otherwise, if this terminator may exit the region we can't make + // any assumptions about which values get passed. + output.push_back(inputValue); + return; + } + } + } } } @@ -92,28 +124,22 @@ static void collectUnderlyingAddressValues2( static void collectUnderlyingAddressValues(OpResult result, unsigned maxDepth, DenseSet &visited, SmallVectorImpl &output) { - LDBG() << "collectUnderlyingAddressValues (OpResult): " << result; - LDBG() << " maxDepth: " << maxDepth; - Operation *op = result.getOwner(); // If this is a view, unwrap to the source. if (ViewLikeOpInterface view = dyn_cast(op)) { if (result == view.getViewDest()) { - LDBG() << " Unwrapping view to source: " << view.getViewSource(); return collectUnderlyingAddressValues(view.getViewSource(), maxDepth, visited, output); } } // Check to see if we can reason about the control flow of this op. if (auto branch = dyn_cast(op)) { - LDBG() << " Processing region branch operation"; - return collectUnderlyingAddressValues2( - branch, RegionSuccessor(op, op->getResults()), result, - result.getResultNumber(), maxDepth, visited, output); + return collectUnderlyingAddressValues(branch, /*region=*/nullptr, result, + result.getResultNumber(), maxDepth, + visited, output); } - LDBG() << " Adding result to output: " << result; output.push_back(result); } @@ -122,23 +148,14 @@ static void collectUnderlyingAddressValues(OpResult result, unsigned maxDepth, static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth, DenseSet &visited, SmallVectorImpl &output) { - LDBG() << "collectUnderlyingAddressValues (BlockArgument): " << arg; - LDBG() << " maxDepth: " << maxDepth; - LDBG() << " argNumber: " << arg.getArgNumber(); - LDBG() << " isEntryBlock: " << arg.getOwner()->isEntryBlock(); - Block *block = arg.getOwner(); unsigned argNumber = arg.getArgNumber(); // Handle the case of a non-entry block. if (!block->isEntryBlock()) { - LDBG() << " Processing non-entry block with " - << std::distance(block->pred_begin(), block->pred_end()) - << " predecessors"; for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) { auto branch = dyn_cast((*it)->getTerminator()); if (!branch) { - LDBG() << " Cannot analyze control flow, adding argument to output"; // We can't analyze the control flow, so bail out early. output.push_back(arg); return; @@ -148,12 +165,10 @@ static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth, unsigned index = it.getSuccessorIndex(); Value operand = branch.getSuccessorOperands(index)[argNumber]; if (!operand) { - LDBG() << " No operand found for argument, adding to output"; // We can't analyze the control flow, so bail out early. output.push_back(arg); return; } - LDBG() << " Processing operand from predecessor: " << operand; collectUnderlyingAddressValues(operand, maxDepth, visited, output); } return; @@ -163,35 +178,10 @@ static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth, Region *region = block->getParent(); Operation *op = region->getParentOp(); if (auto branch = dyn_cast(op)) { - LDBG() << " Processing region branch operation for entry block"; - // We have to find the successor matching the region, so that the input - // arguments are correctly set. - // TODO: this isn't comprehensive: the successor may not be reachable from - // the entry block. - SmallVector successors; - branch.getSuccessorRegions(RegionBranchPoint::parent(), successors); - RegionSuccessor regionSuccessor(region); - bool found = false; - for (RegionSuccessor &successor : successors) { - if (successor.getSuccessor() == region) { - LDBG() << " Found matching region successor: " << successor; - found = true; - regionSuccessor = successor; - break; - } - } - if (!found) { - LDBG() - << " No matching region successor found, adding argument to output"; - output.push_back(arg); - return; - } - return collectUnderlyingAddressValues2( - branch, regionSuccessor, arg, argNumber, maxDepth, visited, output); + return collectUnderlyingAddressValues(branch, region, arg, argNumber, + maxDepth, visited, output); } - LDBG() - << " Cannot reason about underlying address, adding argument to output"; // We can't reason about the underlying address of this argument. output.push_back(arg); } @@ -200,26 +190,17 @@ static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth, static void collectUnderlyingAddressValues(Value value, unsigned maxDepth, DenseSet &visited, SmallVectorImpl &output) { - LDBG() << "collectUnderlyingAddressValues: " << value; - LDBG() << " maxDepth: " << maxDepth; - // Check that we don't infinitely recurse. - if (!visited.insert(value).second) { - LDBG() << " Value already visited, skipping"; + if (!visited.insert(value).second) return; - } if (maxDepth == 0) { - LDBG() << " Max depth reached, adding value to output"; output.push_back(value); return; } --maxDepth; - if (BlockArgument arg = dyn_cast(value)) { - LDBG() << " Processing as BlockArgument"; + if (BlockArgument arg = dyn_cast(value)) return collectUnderlyingAddressValues(arg, maxDepth, visited, output); - } - LDBG() << " Processing as OpResult"; collectUnderlyingAddressValues(cast(value), maxDepth, visited, output); } @@ -227,11 +208,9 @@ static void collectUnderlyingAddressValues(Value value, unsigned maxDepth, /// Given a value, collect all of the underlying values being addressed. static void collectUnderlyingAddressValues(Value value, SmallVectorImpl &output) { - LDBG() << "collectUnderlyingAddressValues: " << value; DenseSet visited; collectUnderlyingAddressValues(value, maxUnderlyingValueSearchDepth, visited, output); - LDBG() << " Collected " << output.size() << " underlying values"; } //===----------------------------------------------------------------------===// @@ -248,33 +227,19 @@ static LogicalResult getAllocEffectFor(Value value, std::optional &effect, Operation *&allocScopeOp) { - LDBG() << "getAllocEffectFor: " << value; - // Try to get a memory effect interface for the parent operation. Operation *op; - if (BlockArgument arg = dyn_cast(value)) { + if (BlockArgument arg = dyn_cast(value)) op = arg.getOwner()->getParentOp(); - LDBG() << " BlockArgument, parent op: " - << OpWithFlags(op, OpPrintingFlags().skipRegions()); - } else { + else op = cast(value).getOwner(); - LDBG() << " OpResult, owner op: " - << OpWithFlags(op, OpPrintingFlags().skipRegions()); - } - MemoryEffectOpInterface interface = dyn_cast(op); - if (!interface) { - LDBG() << " No memory effect interface found"; + if (!interface) return failure(); - } // Try to find an allocation effect on the resource. - if (!(effect = interface.getEffectOnValue(value))) { - LDBG() << " No allocation effect found on value"; + if (!(effect = interface.getEffectOnValue(value))) return failure(); - } - - LDBG() << " Found allocation effect"; // If we found an allocation effect, try to find a scope for the allocation. // If the resource of this allocation is automatically scoped, find the parent @@ -282,12 +247,6 @@ getAllocEffectFor(Value value, if (llvm::isa( effect->getResource())) { allocScopeOp = op->getParentWithTrait(); - if (allocScopeOp) { - LDBG() << " Automatic allocation scope found: " - << OpWithFlags(allocScopeOp, OpPrintingFlags().skipRegions()); - } else { - LDBG() << " Automatic allocation scope found: null"; - } return success(); } @@ -296,12 +255,6 @@ getAllocEffectFor(Value value, // For now assume allocation scope to the function scope (we don't care if // pointer escape outside function). allocScopeOp = op->getParentOfType(); - if (allocScopeOp) { - LDBG() << " Function scope found: " - << OpWithFlags(allocScopeOp, OpPrintingFlags().skipRegions()); - } else { - LDBG() << " Function scope found: null"; - } return success(); } @@ -340,44 +293,33 @@ static std::optional checkDistinctObjects(Value lhs, Value rhs) { /// Given the two values, return their aliasing behavior. AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) { - LDBG() << "aliasImpl: " << lhs << " vs " << rhs; - - if (lhs == rhs) { - LDBG() << " Same value, must alias"; + if (lhs == rhs) return AliasResult::MustAlias; - } - Operation *lhsAllocScope = nullptr, *rhsAllocScope = nullptr; std::optional lhsAlloc, rhsAlloc; // Handle the case where lhs is a constant. Attribute lhsAttr, rhsAttr; if (matchPattern(lhs, m_Constant(&lhsAttr))) { - LDBG() << " lhs is constant"; // TODO: This is overly conservative. Two matching constants don't // necessarily map to the same address. For example, if the two values // correspond to different symbols that both represent a definition. - if (matchPattern(rhs, m_Constant(&rhsAttr))) { - LDBG() << " rhs is also constant, may alias"; + if (matchPattern(rhs, m_Constant(&rhsAttr))) return AliasResult::MayAlias; - } // Try to find an alloc effect on rhs. If an effect was found we can't // alias, otherwise we might. - bool rhsHasAlloc = - succeeded(getAllocEffectFor(rhs, rhsAlloc, rhsAllocScope)); - LDBG() << " rhs has alloc effect: " << rhsHasAlloc; - return rhsHasAlloc ? AliasResult::NoAlias : AliasResult::MayAlias; + return succeeded(getAllocEffectFor(rhs, rhsAlloc, rhsAllocScope)) + ? AliasResult::NoAlias + : AliasResult::MayAlias; } // Handle the case where rhs is a constant. if (matchPattern(rhs, m_Constant(&rhsAttr))) { - LDBG() << " rhs is constant"; // Try to find an alloc effect on lhs. If an effect was found we can't // alias, otherwise we might. - bool lhsHasAlloc = - succeeded(getAllocEffectFor(lhs, lhsAlloc, lhsAllocScope)); - LDBG() << " lhs has alloc effect: " << lhsHasAlloc; - return lhsHasAlloc ? AliasResult::NoAlias : AliasResult::MayAlias; + return succeeded(getAllocEffectFor(lhs, lhsAlloc, lhsAllocScope)) + ? AliasResult::NoAlias + : AliasResult::MayAlias; } if (std::optional result = checkDistinctObjects(lhs, rhs)) @@ -387,14 +329,9 @@ AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) { // an allocation effect. bool lhsHasAlloc = succeeded(getAllocEffectFor(lhs, lhsAlloc, lhsAllocScope)); bool rhsHasAlloc = succeeded(getAllocEffectFor(rhs, rhsAlloc, rhsAllocScope)); - LDBG() << " lhs has alloc effect: " << lhsHasAlloc; - LDBG() << " rhs has alloc effect: " << rhsHasAlloc; - if (lhsHasAlloc == rhsHasAlloc) { // If both values have an allocation effect we know they don't alias, and if // neither have an effect we can't make an assumptions. - LDBG() << " Both have same alloc status: " - << (lhsHasAlloc ? "NoAlias" : "MayAlias"); return lhsHasAlloc ? AliasResult::NoAlias : AliasResult::MayAlias; } @@ -402,7 +339,6 @@ AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) { // and one without. Move the one with the effect to the lhs to make the next // checks simpler. if (rhsHasAlloc) { - LDBG() << " Swapping lhs and rhs to put alloc effect on lhs"; std::swap(lhs, rhs); lhsAlloc = rhsAlloc; lhsAllocScope = rhsAllocScope; @@ -411,74 +347,49 @@ AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) { // If the effect has a scoped allocation region, check to see if the // non-effect value is defined above that scope. if (lhsAllocScope) { - LDBG() << " Checking allocation scope: " - << OpWithFlags(lhsAllocScope, OpPrintingFlags().skipRegions()); // If the parent operation of rhs is an ancestor of the allocation scope, or // if rhs is an entry block argument of the allocation scope we know the two // values can't alias. Operation *rhsParentOp = rhs.getParentRegion()->getParentOp(); - if (rhsParentOp->isProperAncestor(lhsAllocScope)) { - LDBG() << " rhs parent is ancestor of alloc scope, no alias"; + if (rhsParentOp->isProperAncestor(lhsAllocScope)) return AliasResult::NoAlias; - } if (rhsParentOp == lhsAllocScope) { BlockArgument rhsArg = dyn_cast(rhs); - if (rhsArg && rhs.getParentBlock()->isEntryBlock()) { - LDBG() << " rhs is entry block arg of alloc scope, no alias"; + if (rhsArg && rhs.getParentBlock()->isEntryBlock()) return AliasResult::NoAlias; - } } } // If we couldn't reason about the relationship between the two values, // conservatively assume they might alias. - LDBG() << " Cannot reason about relationship, may alias"; return AliasResult::MayAlias; } /// Given the two values, return their aliasing behavior. AliasResult LocalAliasAnalysis::alias(Value lhs, Value rhs) { - LDBG() << "alias: " << lhs << " vs " << rhs; - - if (lhs == rhs) { - LDBG() << " Same value, must alias"; + if (lhs == rhs) return AliasResult::MustAlias; - } // Get the underlying values being addressed. SmallVector lhsValues, rhsValues; collectUnderlyingAddressValues(lhs, lhsValues); collectUnderlyingAddressValues(rhs, rhsValues); - LDBG() << " lhs underlying values: " << lhsValues.size(); - LDBG() << " rhs underlying values: " << rhsValues.size(); - // If we failed to collect for either of the values somehow, conservatively // assume they may alias. - if (lhsValues.empty() || rhsValues.empty()) { - LDBG() << " Failed to collect underlying values, may alias"; + if (lhsValues.empty() || rhsValues.empty()) return AliasResult::MayAlias; - } // Check the alias results against each of the underlying values. std::optional result; for (Value lhsVal : lhsValues) { for (Value rhsVal : rhsValues) { - LDBG() << " Checking underlying values: " << lhsVal << " vs " << rhsVal; AliasResult nextResult = aliasImpl(lhsVal, rhsVal); - LDBG() << " Result: " - << (nextResult == AliasResult::MustAlias ? "MustAlias" - : nextResult == AliasResult::NoAlias ? "NoAlias" - : "MayAlias"); result = result ? result->merge(nextResult) : nextResult; } } // We should always have a valid result here. - LDBG() << " Final result: " - << (result->isMust() ? "MustAlias" - : result->isNo() ? "NoAlias" - : "MayAlias"); return *result; } @@ -487,12 +398,8 @@ AliasResult LocalAliasAnalysis::alias(Value lhs, Value rhs) { //===----------------------------------------------------------------------===// ModRefResult LocalAliasAnalysis::getModRef(Operation *op, Value location) { - LDBG() << "getModRef: " << OpWithFlags(op, OpPrintingFlags().skipRegions()) - << " on location " << location; - // Check to see if this operation relies on nested side effects. if (op->hasTrait()) { - LDBG() << " Operation has recursive memory effects, returning ModAndRef"; // TODO: To check recursive operations we need to check all of the nested // operations, which can result in a quadratic number of queries. We should // introduce some caching of some kind to help alleviate this, especially as @@ -503,64 +410,38 @@ ModRefResult LocalAliasAnalysis::getModRef(Operation *op, Value location) { // Otherwise, check to see if this operation has a memory effect interface. MemoryEffectOpInterface interface = dyn_cast(op); - if (!interface) { - LDBG() << " No memory effect interface, returning ModAndRef"; + if (!interface) return ModRefResult::getModAndRef(); - } // Build a ModRefResult by merging the behavior of the effects of this // operation. SmallVector effects; interface.getEffects(effects); - LDBG() << " Found " << effects.size() << " memory effects"; ModRefResult result = ModRefResult::getNoModRef(); for (const MemoryEffects::EffectInstance &effect : effects) { - if (isa(effect.getEffect())) { - LDBG() << " Skipping alloc/free effect"; + if (isa(effect.getEffect())) continue; - } // Check for an alias between the effect and our memory location. // TODO: Add support for checking an alias with a symbol reference. AliasResult aliasResult = AliasResult::MayAlias; - if (Value effectValue = effect.getValue()) { - LDBG() << " Checking alias between effect value " << effectValue - << " and location " << location; + if (Value effectValue = effect.getValue()) aliasResult = alias(effectValue, location); - LDBG() << " Alias result: " - << (aliasResult.isMust() ? "MustAlias" - : aliasResult.isNo() ? "NoAlias" - : "MayAlias"); - } else { - LDBG() << " No effect value, assuming MayAlias"; - } // If we don't alias, ignore this effect. - if (aliasResult.isNo()) { - LDBG() << " No alias, ignoring effect"; + if (aliasResult.isNo()) continue; - } // Merge in the corresponding mod or ref for this effect. if (isa(effect.getEffect())) { - LDBG() << " Adding Ref to result"; result = result.merge(ModRefResult::getRef()); } else { assert(isa(effect.getEffect())); - LDBG() << " Adding Mod to result"; result = result.merge(ModRefResult::getMod()); } - if (result.isModAndRef()) { - LDBG() << " Result is now ModAndRef, breaking"; + if (result.isModAndRef()) break; - } } - - LDBG() << " Final ModRef result: " - << (result.isModAndRef() ? "ModAndRef" - : result.isMod() ? "Mod" - : result.isRef() ? "Ref" - : "NoModRef"); return result; } diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp index 0fc5b4482bf3e..377f7ebe06750 100644 --- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp @@ -501,10 +501,11 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op, return; SmallVector successors; - auto terminator = dyn_cast(op); - if (!terminator) - return; - terminator.getSuccessorRegions(*operands, successors); + if (auto terminator = dyn_cast(op)) + terminator.getSuccessorRegions(*operands, successors); + else + branch.getSuccessorRegions(op->getParentRegion(), successors); + visitRegionBranchEdges(branch, op, successors); } diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp index 0682e5f26785a..daa3db55b2852 100644 --- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp @@ -588,9 +588,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) { // flow, propagate the lattice back along the control flow edge. if (auto branch = dyn_cast(block->getParentOp())) { LDBG() << " Exit block of region branch operation"; - auto terminator = - cast(block->getTerminator()); - visitRegionBranchOperation(point, branch, terminator, before); + visitRegionBranchOperation(point, branch, block->getParent(), before); return; } diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp index 8e63ae86753b4..0d2e2ed85549d 100644 --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -130,7 +130,7 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) { // The results of a region branch operation are determined by control-flow. if (auto branch = dyn_cast(op)) { visitRegionSuccessors(getProgramPointAfter(branch), branch, - /*successor=*/{branch, branch->getResults()}, + /*successor=*/RegionBranchPoint::parent(), resultLattices); return success(); } @@ -279,7 +279,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitCallableOperation( void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors( ProgramPoint *point, RegionBranchOpInterface branch, - RegionSuccessor successor, ArrayRef lattices) { + RegionBranchPoint successor, ArrayRef lattices) { const auto *predecessors = getOrCreateFor(point, point); assert(predecessors->allPredecessorsKnown() && "unexpected unresolved region successors"); @@ -314,7 +314,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors( visitNonControlFlowArgumentsImpl( branch, RegionSuccessor( - branch, branch->getResults().slice(firstIndex, inputs.size())), + branch->getResults().slice(firstIndex, inputs.size())), lattices, firstIndex); } else { if (!inputs.empty()) diff --git a/mlir/lib/Analysis/SliceWalk.cpp b/mlir/lib/Analysis/SliceWalk.cpp index 863f260cd4b6a..817d71a3452ca 100644 --- a/mlir/lib/Analysis/SliceWalk.cpp +++ b/mlir/lib/Analysis/SliceWalk.cpp @@ -114,7 +114,7 @@ mlir::getControlFlowPredecessors(Value value) { if (!regionOp) return std::nullopt; // Add the control flow predecessor operands to the work list. - RegionSuccessor region(regionOp, regionOp->getResults()); + RegionSuccessor region(regionOp->getResults()); SmallVector predecessorOperands = getRegionPredecessorOperands( regionOp, region, opResult.getResultNumber()); return predecessorOperands; diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 0c3592124cdec..e0a53cd52f143 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2716,9 +2716,8 @@ LogicalResult AffineForOp::fold(FoldAdaptor adaptor, return success(folded); } -OperandRange AffineForOp::getEntrySuccessorOperands(RegionSuccessor successor) { - assert((successor.isParent() || successor.getSuccessor() == &getRegion()) && - "invalid region point"); +OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) { + assert((point.isParent() || point == getRegion()) && "invalid region point"); // The initial operands map to the loop arguments after the induction // variable or are forwarded to the results when the trip count is zero. @@ -2727,41 +2726,34 @@ OperandRange AffineForOp::getEntrySuccessorOperands(RegionSuccessor successor) { void AffineForOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { - assert((point.isParent() || - point.getTerminatorPredecessorOrNull()->getParentRegion() == - &getRegion()) && - "expected loop region"); + assert((point.isParent() || point == getRegion()) && "expected loop region"); // The loop may typically branch back to its body or to the parent operation. // If the predecessor is the parent op and the trip count is known to be at // least one, branch into the body using the iterator arguments. And in cases // we know the trip count is zero, it can only branch back to its parent. std::optional tripCount = getTrivialConstantTripCount(*this); - if (tripCount.has_value()) { - if (!point.isParent()) { - // From the loop body, if the trip count is one, we can only branch back - // to the parent. - if (tripCount == 1) { - regions.push_back(RegionSuccessor(getOperation(), getResults())); - return; - } - if (tripCount == 0) - return; - } else { - if (tripCount.value() > 0) { - regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); - return; - } - if (tripCount.value() == 0) { - regions.push_back(RegionSuccessor(getOperation(), getResults())); - return; - } + if (point.isParent() && tripCount.has_value()) { + if (tripCount.value() > 0) { + regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); + return; + } + if (tripCount.value() == 0) { + regions.push_back(RegionSuccessor(getResults())); + return; } } + // From the loop body, if the trip count is one, we can only branch back to + // the parent. + if (!point.isParent() && tripCount == 1) { + regions.push_back(RegionSuccessor(getResults())); + return; + } + // In all other cases, the loop may branch back to itself or the parent // operation. regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); - regions.push_back(RegionSuccessor(getOperation(), getResults())); + regions.push_back(RegionSuccessor(getResults())); } AffineBound AffineForOp::getLowerBound() { @@ -3150,7 +3142,7 @@ void AffineIfOp::getSuccessorRegions( RegionSuccessor(&getThenRegion(), getThenRegion().getArguments())); // If the "else" region is empty, branch bach into parent. if (getElseRegion().empty()) { - regions.push_back(RegionSuccessor(getOperation(), getResults())); + regions.push_back(getResults()); } else { regions.push_back( RegionSuccessor(&getElseRegion(), getElseRegion().getArguments())); @@ -3160,7 +3152,7 @@ void AffineIfOp::getSuccessorRegions( // If the predecessor is the `else`/`then` region, then branching into parent // op is valid. - regions.push_back(RegionSuccessor(getOperation(), getResults())); + regions.push_back(RegionSuccessor(getResults())); } LogicalResult AffineIfOp::verify() { diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp index 8e4a49df76b52..dc7b07d911c17 100644 --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -36,9 +36,8 @@ void AsyncDialect::initialize() { constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes"; -OperandRange ExecuteOp::getEntrySuccessorOperands(RegionSuccessor successor) { - assert(successor.getSuccessor() == &getBodyRegion() && - "invalid region index"); +OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) { + assert(point == getBodyRegion() && "invalid region index"); return getBodyOperands(); } @@ -54,10 +53,8 @@ bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) { void ExecuteOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { // The `body` region branch back to the parent operation. - if (!point.isParent() && - point.getTerminatorPredecessorOrNull()->getParentRegion() == - &getBodyRegion()) { - regions.push_back(RegionSuccessor(getOperation(), getBodyResults())); + if (point == getBodyRegion()) { + regions.push_back(RegionSuccessor(getBodyResults())); return; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp index 36a759c279eb7..b593ccab060c7 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp @@ -562,11 +562,8 @@ LogicalResult BufferDeallocation::updateFunctionSignature(FunctionOpInterface op) { SmallVector returnOperandTypes(llvm::map_range( op.getFunctionBody().getOps(), - [&](RegionBranchTerminatorOpInterface branchOp) { - return branchOp - .getSuccessorOperands(RegionSuccessor( - op.getOperation(), op.getOperation()->getResults())) - .getTypes(); + [](RegionBranchTerminatorOpInterface op) { + return op.getSuccessorOperands(RegionBranchPoint::parent()).getTypes(); })); if (!llvm::all_equal(returnOperandTypes)) return op->emitError( @@ -945,8 +942,8 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) { // about, but we would need to check how many successors there are and under // which condition they are taken, etc. - MutableOperandRange operands = op.getMutableSuccessorOperands( - RegionSuccessor(op.getOperation(), op.getOperation()->getResults())); + MutableOperandRange operands = + op.getMutableSuccessorOperands(RegionBranchPoint::parent()); SmallVector updatedOwnerships; auto result = deallocation_impl::insertDeallocOpForReturnLike( diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 0992ce14b4afb..4754f0bfe895e 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -845,8 +845,7 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { // The `then` and the `else` region branch back to the parent operation. if (!point.isParent()) { - regions.push_back( - RegionSuccessor(getOperation(), getOperation()->getResults())); + regions.push_back(RegionSuccessor()); return; } @@ -855,8 +854,7 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point, // Don't consider the else region if it is empty. Region *elseRegion = &this->getElseRegion(); if (elseRegion->empty()) - regions.push_back( - RegionSuccessor(getOperation(), getOperation()->getResults())); + regions.push_back(RegionSuccessor()); else regions.push_back(RegionSuccessor(elseRegion)); } @@ -873,7 +871,7 @@ void IfOp::getEntrySuccessorRegions(ArrayRef operands, if (!getElseRegion().empty()) regions.emplace_back(&getElseRegion()); else - regions.emplace_back(getOperation(), getOperation()->getResults()); + regions.emplace_back(); } } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 6c6d8d2bad55d..b5f8ddaadacdf 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -2399,7 +2399,7 @@ ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser, void WarpExecuteOnLane0Op::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { if (!point.isParent()) { - regions.push_back(RegionSuccessor(getOperation(), getResults())); + regions.push_back(RegionSuccessor(getResults())); return; } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 1c21a2f270da6..c551fba93e367 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -405,7 +405,7 @@ ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) { void AllocaScopeOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { if (!point.isParent()) { - regions.push_back(RegionSuccessor(getOperation(), getResults())); + regions.push_back(RegionSuccessor(getResults())); return; } diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 2946b53c8cb36..1ab01d86bcd10 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -397,7 +397,7 @@ void ExecuteRegionOp::getSuccessorRegions( } // Otherwise, the region branches back to the parent operation. - regions.push_back(RegionSuccessor(getOperation(), getResults())); + regions.push_back(RegionSuccessor(getResults())); } //===----------------------------------------------------------------------===// @@ -405,11 +405,10 @@ void ExecuteRegionOp::getSuccessorRegions( //===----------------------------------------------------------------------===// MutableOperandRange -ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) { - assert( - (point.isParent() || point.getSuccessor() == &getParentOp().getAfter()) && - "condition op can only exit the loop or branch to the after" - "region"); +ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) { + assert((point.isParent() || point == getParentOp().getAfter()) && + "condition op can only exit the loop or branch to the after" + "region"); // Pass all operands except the condition to the successor region. return getArgsMutable(); } @@ -427,7 +426,7 @@ void ConditionOp::getSuccessorRegions( regions.emplace_back(&whileOp.getAfter(), whileOp.getAfter().getArguments()); if (!boolAttr || !boolAttr.getValue()) - regions.emplace_back(whileOp.getOperation(), whileOp.getResults()); + regions.emplace_back(whileOp.getResults()); } //===----------------------------------------------------------------------===// @@ -750,7 +749,7 @@ ForOp mlir::scf::getForInductionVarOwner(Value val) { return dyn_cast_or_null(containingOp); } -OperandRange ForOp::getEntrySuccessorOperands(RegionSuccessor successor) { +OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) { return getInitArgs(); } @@ -760,7 +759,7 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point, // back into the operation itself. It is possible for loop not to enter the // body. regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); - regions.push_back(RegionSuccessor(getOperation(), getResults())); + regions.push_back(RegionSuccessor(getResults())); } SmallVector ForallOp::getLoopRegions() { return {&getRegion()}; } @@ -2054,10 +2053,9 @@ void ForallOp::getSuccessorRegions(RegionBranchPoint point, // parallel by multiple threads. We should not expect to branch back into // the forall body after the region's execution is complete. if (point.isParent()) - regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); + regions.push_back(RegionSuccessor(&getRegion())); else - regions.push_back( - RegionSuccessor(getOperation(), getOperation()->getResults())); + regions.push_back(RegionSuccessor()); } //===----------------------------------------------------------------------===// @@ -2335,10 +2333,9 @@ void IfOp::print(OpAsmPrinter &p) { void IfOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { - // The `then` and the `else` region branch back to the parent operation or one - // of the recursive parent operations (early exit case). + // The `then` and the `else` region branch back to the parent operation. if (!point.isParent()) { - regions.push_back(RegionSuccessor(getOperation(), getResults())); + regions.push_back(RegionSuccessor(getResults())); return; } @@ -2347,8 +2344,7 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point, // Don't consider the else region if it is empty. Region *elseRegion = &this->getElseRegion(); if (elseRegion->empty()) - regions.push_back( - RegionSuccessor(getOperation(), getOperation()->getResults())); + regions.push_back(RegionSuccessor()); else regions.push_back(RegionSuccessor(elseRegion)); } @@ -2365,7 +2361,7 @@ void IfOp::getEntrySuccessorRegions(ArrayRef operands, if (!getElseRegion().empty()) regions.emplace_back(&getElseRegion()); else - regions.emplace_back(getOperation(), getResults()); + regions.emplace_back(getResults()); } } @@ -3389,8 +3385,7 @@ void ParallelOp::getSuccessorRegions( // back into the operation itself. It is possible for loop not to enter the // body. regions.push_back(RegionSuccessor(&getRegion())); - regions.push_back(RegionSuccessor( - getOperation(), ResultRange{getResults().end(), getResults().end()})); + regions.push_back(RegionSuccessor()); } //===----------------------------------------------------------------------===// @@ -3436,7 +3431,7 @@ LogicalResult ReduceOp::verifyRegions() { } MutableOperandRange -ReduceOp::getMutableSuccessorOperands(RegionSuccessor point) { +ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) { // No operands are forwarded to the next iteration. return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0); } @@ -3519,8 +3514,8 @@ Block::BlockArgListType WhileOp::getRegionIterArgs() { return getBeforeArguments(); } -OperandRange WhileOp::getEntrySuccessorOperands(RegionSuccessor successor) { - assert(successor.getSuccessor() == &getBefore() && +OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) { + assert(point == getBefore() && "WhileOp is expected to branch only to the first region"); return getInits(); } @@ -3533,18 +3528,15 @@ void WhileOp::getSuccessorRegions(RegionBranchPoint point, return; } - assert(llvm::is_contained( - {&getAfter(), &getBefore()}, - point.getTerminatorPredecessorOrNull()->getParentRegion()) && + assert(llvm::is_contained({&getAfter(), &getBefore()}, point) && "there are only two regions in a WhileOp"); // The body region always branches back to the condition region. - if (point.getTerminatorPredecessorOrNull()->getParentRegion() == - &getAfter()) { + if (point == getAfter()) { regions.emplace_back(&getBefore(), getBefore().getArguments()); return; } - regions.emplace_back(getOperation(), getResults()); + regions.emplace_back(getResults()); regions.emplace_back(&getAfter(), getAfter().getArguments()); } @@ -4453,7 +4445,7 @@ void IndexSwitchOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl &successors) { // All regions branch back to the parent op. if (!point.isParent()) { - successors.emplace_back(getOperation(), getResults()); + successors.emplace_back(getResults()); return; } diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp index ddcbda86cf1f3..ae52af5009dc9 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp @@ -23,6 +23,7 @@ namespace mlir { #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" } // namespace mlir +using namespace llvm; using namespace mlir; using scf::ForOp; using scf::WhileOp; diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp index 00bef707fadd3..a2f03f1e1056e 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp @@ -21,6 +21,7 @@ namespace mlir { #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" } // namespace mlir +using namespace llvm; using namespace mlir; using scf::LoopNest; diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index f0f22e5ef4a83..5ba828918c22a 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -346,7 +346,7 @@ void AssumingOp::getSuccessorRegions( // parent, so return the correct RegionSuccessor purely based on the index // being None or 0. if (!point.isParent()) { - regions.push_back(RegionSuccessor(getOperation(), getResults())); + regions.push_back(RegionSuccessor(getResults())); return; } diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 3962e3e84dd31..1a9d9e158ee75 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -2597,7 +2597,7 @@ std::optional> IterateOp::getYieldedValuesMutable() { std::optional IterateOp::getLoopResults() { return getResults(); } -OperandRange IterateOp::getEntrySuccessorOperands(RegionSuccessor successor) { +OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) { return getInitArgs(); } @@ -2607,7 +2607,7 @@ void IterateOp::getSuccessorRegions(RegionBranchPoint point, // or back into the operation itself. regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); // It is possible for loop not to enter the body. - regions.push_back(RegionSuccessor(getOperation(), getResults())); + regions.push_back(RegionSuccessor(getResults())); } void CoIterateOp::build(OpBuilder &builder, OperationState &odsState, diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 062606e7e10b6..365afab3764c8 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -96,9 +96,9 @@ ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, // AlternativesOp //===----------------------------------------------------------------------===// -OperandRange transform::AlternativesOp::getEntrySuccessorOperands( - RegionSuccessor successor) { - if (!successor.isParent() && getOperation()->getNumOperands() == 1) +OperandRange +transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) { + if (!point.isParent() && getOperation()->getNumOperands() == 1) return getOperation()->getOperands(); return OperandRange(getOperation()->operand_end(), getOperation()->operand_end()); @@ -107,18 +107,15 @@ OperandRange transform::AlternativesOp::getEntrySuccessorOperands( void transform::AlternativesOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { for (Region &alternative : llvm::drop_begin( - getAlternatives(), point.isParent() - ? 0 - : point.getTerminatorPredecessorOrNull() - ->getParentRegion() - ->getRegionNumber() + - 1)) { + getAlternatives(), + point.isParent() ? 0 + : point.getRegionOrNull()->getRegionNumber() + 1)) { regions.emplace_back(&alternative, !getOperands().empty() ? alternative.getArguments() : Block::BlockArgListType()); } if (!point.isParent()) - regions.emplace_back(getOperation(), getOperation()->getResults()); + regions.emplace_back(getOperation()->getResults()); } void transform::AlternativesOp::getRegionInvocationBounds( @@ -1743,18 +1740,16 @@ void transform::ForeachOp::getSuccessorRegions( } // Branch back to the region or the parent. - assert(point.getTerminatorPredecessorOrNull()->getParentRegion() == - &getBody() && - "unexpected region index"); + assert(point == getBody() && "unexpected region index"); regions.emplace_back(bodyRegion, bodyRegion->getArguments()); - regions.emplace_back(getOperation(), getOperation()->getResults()); + regions.emplace_back(); } OperandRange -transform::ForeachOp::getEntrySuccessorOperands(RegionSuccessor successor) { +transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) { // Each block argument handle is mapped to a subset (one op to be precise) // of the payload of the corresponding `targets` operand of ForeachOp. - assert(successor.getSuccessor() == &getBody() && "unexpected region index"); + assert(point == getBody() && "unexpected region index"); return getOperation()->getOperands(); } @@ -2953,8 +2948,8 @@ void transform::SequenceOp::getEffects( } OperandRange -transform::SequenceOp::getEntrySuccessorOperands(RegionSuccessor successor) { - assert(successor.getSuccessor() == &getBody() && "unexpected region index"); +transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) { + assert(point == getBody() && "unexpected region index"); if (getOperation()->getNumOperands() > 0) return getOperation()->getOperands(); return OperandRange(getOperation()->operand_end(), @@ -2971,10 +2966,8 @@ void transform::SequenceOp::getSuccessorRegions( return; } - assert(point.getTerminatorPredecessorOrNull()->getParentRegion() == - &getBody() && - "unexpected region index"); - regions.emplace_back(getOperation(), getOperation()->getResults()); + assert(point == getBody() && "unexpected region index"); + regions.emplace_back(getOperation()->getResults()); } void transform::SequenceOp::getRegionInvocationBounds( diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp index f727118f3f9a0..c627158e999ed 100644 --- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp +++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp @@ -9,7 +9,6 @@ #include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/IR/OpImplementation.h" -#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h" @@ -113,7 +112,7 @@ static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer, } OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands( - RegionSuccessor successor) { + RegionBranchPoint point) { // No operands will be forwarded to the region(s). return getOperands().slice(0, 0); } @@ -129,7 +128,7 @@ void transform::tune::AlternativesOp::getSuccessorRegions( for (Region &alternative : getAlternatives()) regions.emplace_back(&alternative, Block::BlockArgListType()); else - regions.emplace_back(getOperation(), getOperation()->getResults()); + regions.emplace_back(getOperation()->getResults()); } void transform::tune::AlternativesOp::getRegionInvocationBounds( diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp index f4c9242ed3479..776b5c6588c71 100644 --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -138,10 +138,6 @@ Diagnostic &Diagnostic::operator<<(Operation &op) { return appendOp(op, OpPrintingFlags()); } -Diagnostic &Diagnostic::operator<<(OpWithFlags op) { - return appendOp(*op.getOperation(), op.flags()); -} - Diagnostic &Diagnostic::appendOp(Operation &op, const OpPrintingFlags &flags) { std::string str; llvm::raw_string_ostream os(str); diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp index 15a941f380225..46b6298076d48 100644 --- a/mlir/lib/IR/Region.cpp +++ b/mlir/lib/IR/Region.cpp @@ -253,21 +253,6 @@ void Region::OpIterator::skipOverBlocksWithNoOps() { operation = block->begin(); } -llvm::raw_ostream &mlir::operator<<(llvm::raw_ostream &os, Region ®ion) { - if (!region.getParentOp()) { - os << "Region has no parent op"; - } else { - os << "Region #" << region.getRegionNumber() << " in operation " - << region.getParentOp()->getName(); - } - for (auto it : llvm::enumerate(region.getBlocks())) { - os << "\n Block #" << it.index() << ":"; - for (Operation &op : it.value().getOperations()) - os << "\n " << OpWithFlags(&op, OpPrintingFlags().skipRegions()); - } - return os; -} - //===----------------------------------------------------------------------===// // RegionRange //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp index 1e56810ff7aaf..ca3f7666dba8a 100644 --- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp +++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp @@ -9,9 +9,7 @@ #include #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Operation.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" -#include "llvm/Support/DebugLog.h" using namespace mlir; @@ -40,31 +38,20 @@ SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount, std::optional detail::getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor) { - LDBG() << "Getting branch successor argument for operand index " - << operandIndex << " in successor block"; - OperandRange forwardedOperands = operands.getForwardedOperands(); // Check that the operands are valid. - if (forwardedOperands.empty()) { - LDBG() << "No forwarded operands, returning nullopt"; + if (forwardedOperands.empty()) return std::nullopt; - } // Check to ensure that this operand is within the range. unsigned operandsStart = forwardedOperands.getBeginOperandIndex(); if (operandIndex < operandsStart || - operandIndex >= (operandsStart + forwardedOperands.size())) { - LDBG() << "Operand index " << operandIndex << " out of range [" - << operandsStart << ", " - << (operandsStart + forwardedOperands.size()) - << "), returning nullopt"; + operandIndex >= (operandsStart + forwardedOperands.size())) return std::nullopt; - } // Index the successor. unsigned argIndex = operands.getProducedOperandCount() + operandIndex - operandsStart; - LDBG() << "Computed argument index " << argIndex << " for successor block"; return successor->getArgument(argIndex); } @@ -72,15 +59,9 @@ detail::getBranchSuccessorArgument(const SuccessorOperands &operands, LogicalResult detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, const SuccessorOperands &operands) { - LDBG() << "Verifying branch successor operands for successor #" << succNo - << " in operation " << op->getName(); - // Check the count. unsigned operandCount = operands.size(); Block *destBB = op->getSuccessor(succNo); - LDBG() << "Branch has " << operandCount << " operands, target block has " - << destBB->getNumArguments() << " arguments"; - if (operandCount != destBB->getNumArguments()) return op->emitError() << "branch has " << operandCount << " operands for successor #" << succNo @@ -88,22 +69,13 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, << destBB->getNumArguments(); // Check the types. - LDBG() << "Checking type compatibility for " - << (operandCount - operands.getProducedOperandCount()) - << " forwarded operands"; for (unsigned i = operands.getProducedOperandCount(); i != operandCount; ++i) { - Type operandType = operands[i].getType(); - Type argType = destBB->getArgument(i).getType(); - LDBG() << "Checking type compatibility: operand type " << operandType - << " vs argument type " << argType; - - if (!cast(op).areTypesCompatible(operandType, argType)) + if (!cast(op).areTypesCompatible( + operands[i].getType(), destBB->getArgument(i).getType())) return op->emitError() << "type mismatch for bb argument #" << i << " of successor #" << succNo; } - - LDBG() << "Branch successor operand verification successful"; return success(); } @@ -154,15 +126,15 @@ LogicalResult detail::verifyRegionBranchWeights(Operation *op) { static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag, RegionBranchPoint sourceNo, - RegionSuccessor succRegionNo) { + RegionBranchPoint succRegionNo) { diag << "from "; - if (Operation *op = sourceNo.getTerminatorPredecessorOrNull()) - diag << "Operation " << op->getName(); + if (Region *region = sourceNo.getRegionOrNull()) + diag << "Region #" << region->getRegionNumber(); else diag << "parent operands"; diag << " to "; - if (Region *region = succRegionNo.getSuccessor()) + if (Region *region = succRegionNo.getRegionOrNull()) diag << "Region #" << region->getRegionNumber(); else diag << "parent results"; @@ -173,12 +145,13 @@ static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag, /// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the /// types of the inputs that flow to a successor region. static LogicalResult -verifyTypesAlongAllEdges(RegionBranchOpInterface branchOp, - RegionBranchPoint sourcePoint, - function_ref(RegionSuccessor)> +verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint, + function_ref(RegionBranchPoint)> getInputsTypesForRegion) { + auto regionInterface = cast(op); + SmallVector successors; - branchOp.getSuccessorRegions(sourcePoint, successors); + regionInterface.getSuccessorRegions(sourcePoint, successors); for (RegionSuccessor &succ : successors) { FailureOr sourceTypes = getInputsTypesForRegion(succ); @@ -187,14 +160,10 @@ verifyTypesAlongAllEdges(RegionBranchOpInterface branchOp, TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes(); if (sourceTypes->size() != succInputsTypes.size()) { - InFlightDiagnostic diag = - branchOp->emitOpError("region control flow edge "); - std::string succStr; - llvm::raw_string_ostream os(succStr); - os << succ; + InFlightDiagnostic diag = op->emitOpError("region control flow edge "); return printRegionEdgeName(diag, sourcePoint, succ) << ": source has " << sourceTypes->size() - << " operands, but target successor " << os.str() << " needs " + << " operands, but target successor needs " << succInputsTypes.size(); } @@ -202,10 +171,8 @@ verifyTypesAlongAllEdges(RegionBranchOpInterface branchOp, llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) { Type sourceType = std::get<0>(typesIdx.value()); Type inputType = std::get<1>(typesIdx.value()); - - if (!branchOp.areTypesCompatible(sourceType, inputType)) { - InFlightDiagnostic diag = - branchOp->emitOpError("along control flow edge "); + if (!regionInterface.areTypesCompatible(sourceType, inputType)) { + InFlightDiagnostic diag = op->emitOpError("along control flow edge "); return printRegionEdgeName(diag, sourcePoint, succ) << ": source type #" << typesIdx.index() << " " << sourceType << " should match input type #" << typesIdx.index() << " " @@ -213,7 +180,6 @@ verifyTypesAlongAllEdges(RegionBranchOpInterface branchOp, } } } - return success(); } @@ -221,18 +187,34 @@ verifyTypesAlongAllEdges(RegionBranchOpInterface branchOp, LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { auto regionInterface = cast(op); - auto inputTypesFromParent = [&](RegionSuccessor successor) -> TypeRange { - return regionInterface.getEntrySuccessorOperands(successor).getTypes(); + auto inputTypesFromParent = [&](RegionBranchPoint point) -> TypeRange { + return regionInterface.getEntrySuccessorOperands(point).getTypes(); }; // Verify types along control flow edges originating from the parent. - if (failed(verifyTypesAlongAllEdges( - regionInterface, RegionBranchPoint::parent(), inputTypesFromParent))) + if (failed(verifyTypesAlongAllEdges(op, RegionBranchPoint::parent(), + inputTypesFromParent))) return failure(); + auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) { + if (lhs.size() != rhs.size()) + return false; + for (auto types : llvm::zip(lhs, rhs)) { + if (!regionInterface.areTypesCompatible(std::get<0>(types), + std::get<1>(types))) { + return false; + } + } + return true; + }; + // Verify types along control flow edges originating from each region. for (Region ®ion : op->getRegions()) { - // Collect all return-like terminators in the region. + + // Since there can be multiple terminators implementing the + // `RegionBranchTerminatorOpInterface`, all should have the same operand + // types when passing them to the same region. + SmallVector regionReturnOps; for (Block &block : region) if (!block.empty()) @@ -245,20 +227,33 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { if (regionReturnOps.empty()) continue; - // Verify types along control flow edges originating from each return-like - // terminator. - for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) { - - auto inputTypesForRegion = - [&](RegionSuccessor successor) -> FailureOr { - OperandRange terminatorOperands = - regionReturnOp.getSuccessorOperands(successor); - return TypeRange(terminatorOperands.getTypes()); - }; - if (failed(verifyTypesAlongAllEdges(regionInterface, regionReturnOp, - inputTypesForRegion))) - return failure(); - } + auto inputTypesForRegion = + [&](RegionBranchPoint point) -> FailureOr { + std::optional regionReturnOperands; + for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) { + auto terminatorOperands = regionReturnOp.getSuccessorOperands(point); + + if (!regionReturnOperands) { + regionReturnOperands = terminatorOperands; + continue; + } + + // Found more than one ReturnLike terminator. Make sure the operand + // types match with the first one. + if (!areTypesCompatible(regionReturnOperands->getTypes(), + terminatorOperands.getTypes())) { + InFlightDiagnostic diag = op->emitOpError("along control flow edge"); + return printRegionEdgeName(diag, region, point) + << " operands mismatch between return-like terminators"; + } + } + + // All successors get the same set of operand types. + return TypeRange(regionReturnOperands->getTypes()); + }; + + if (failed(verifyTypesAlongAllEdges(op, region, inputTypesForRegion))) + return failure(); } return success(); @@ -277,74 +272,31 @@ using StopConditionFn = function_ref visited)>; static bool traverseRegionGraph(Region *begin, StopConditionFn stopConditionFn) { auto op = cast(begin->getParentOp()); - LDBG() << "Starting region graph traversal from region #" - << begin->getRegionNumber() << " in operation " << op->getName(); - SmallVector visited(op->getNumRegions(), false); visited[begin->getRegionNumber()] = true; - LDBG() << "Initialized visited array with " << op->getNumRegions() - << " regions"; // Retrieve all successors of the region and enqueue them in the worklist. SmallVector worklist; auto enqueueAllSuccessors = [&](Region *region) { - LDBG() << "Enqueuing successors for region #" << region->getRegionNumber(); - SmallVector operandAttributes(op->getNumOperands()); - for (Block &block : *region) { - if (block.empty()) - continue; - auto terminator = - dyn_cast(block.back()); - if (!terminator) - continue; - SmallVector successors; - operandAttributes.resize(terminator->getNumOperands()); - terminator.getSuccessorRegions(operandAttributes, successors); - LDBG() << "Found " << successors.size() - << " successors from terminator in block"; - for (RegionSuccessor successor : successors) { - if (!successor.isParent()) { - worklist.push_back(successor.getSuccessor()); - LDBG() << "Added region #" - << successor.getSuccessor()->getRegionNumber() - << " to worklist"; - } else { - LDBG() << "Skipping parent successor"; - } - } - } + SmallVector successors; + op.getSuccessorRegions(region, successors); + for (RegionSuccessor successor : successors) + if (!successor.isParent()) + worklist.push_back(successor.getSuccessor()); }; enqueueAllSuccessors(begin); - LDBG() << "Initial worklist size: " << worklist.size(); // Process all regions in the worklist via DFS. while (!worklist.empty()) { Region *nextRegion = worklist.pop_back_val(); - LDBG() << "Processing region #" << nextRegion->getRegionNumber() - << " from worklist (remaining: " << worklist.size() << ")"; - - if (stopConditionFn(nextRegion, visited)) { - LDBG() << "Stop condition met for region #" - << nextRegion->getRegionNumber() << ", returning true"; + if (stopConditionFn(nextRegion, visited)) return true; - } - llvm::dbgs() << "Region: " << nextRegion << "\n"; - if (!nextRegion->getParentOp()) { - llvm::errs() << "Region " << *nextRegion << " has no parent op\n"; - return false; - } - if (visited[nextRegion->getRegionNumber()]) { - LDBG() << "Region #" << nextRegion->getRegionNumber() - << " already visited, skipping"; + if (visited[nextRegion->getRegionNumber()]) continue; - } visited[nextRegion->getRegionNumber()] = true; - LDBG() << "Marking region #" << nextRegion->getRegionNumber() - << " as visited"; enqueueAllSuccessors(nextRegion); } - LDBG() << "Traversal completed, returning false"; return false; } @@ -370,26 +322,18 @@ static bool isRegionReachable(Region *begin, Region *r) { /// mutually exclusive if they are not reachable from each other as per /// RegionBranchOpInterface::getSuccessorRegions. bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) { - LDBG() << "Checking if operations are in mutually exclusive regions: " - << a->getName() << " and " << b->getName(); - assert(a && "expected non-empty operation"); assert(b && "expected non-empty operation"); auto branchOp = a->getParentOfType(); while (branchOp) { - LDBG() << "Checking branch operation " << branchOp->getName(); - // Check if b is inside branchOp. (We already know that a is.) if (!branchOp->isProperAncestor(b)) { - LDBG() << "Operation b is not inside branchOp, checking next ancestor"; // Check next enclosing RegionBranchOpInterface. branchOp = branchOp->getParentOfType(); continue; } - LDBG() << "Both operations are inside branchOp, finding their regions"; - // b is contained in branchOp. Retrieve the regions in which `a` and `b` // are contained. Region *regionA = nullptr, *regionB = nullptr; @@ -397,136 +341,63 @@ bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) { if (r.findAncestorOpInRegion(*a)) { assert(!regionA && "already found a region for a"); regionA = &r; - LDBG() << "Found region #" << r.getRegionNumber() << " for operation a"; } if (r.findAncestorOpInRegion(*b)) { assert(!regionB && "already found a region for b"); regionB = &r; - LDBG() << "Found region #" << r.getRegionNumber() << " for operation b"; } } assert(regionA && regionB && "could not find region of op"); - LDBG() << "Region A: #" << regionA->getRegionNumber() << ", Region B: #" - << regionB->getRegionNumber(); - // `a` and `b` are in mutually exclusive regions if both regions are // distinct and neither region is reachable from the other region. - bool regionsAreDistinct = (regionA != regionB); - bool aNotReachableFromB = !isRegionReachable(regionA, regionB); - bool bNotReachableFromA = !isRegionReachable(regionB, regionA); - - LDBG() << "Regions distinct: " << regionsAreDistinct - << ", A not reachable from B: " << aNotReachableFromB - << ", B not reachable from A: " << bNotReachableFromA; - - bool mutuallyExclusive = - regionsAreDistinct && aNotReachableFromB && bNotReachableFromA; - LDBG() << "Operations are mutually exclusive: " << mutuallyExclusive; - - return mutuallyExclusive; + return regionA != regionB && !isRegionReachable(regionA, regionB) && + !isRegionReachable(regionB, regionA); } // Could not find a common RegionBranchOpInterface among a's and b's // ancestors. - LDBG() << "No common RegionBranchOpInterface found, operations are not " - "mutually exclusive"; return false; } bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) { - LDBG() << "Checking if region #" << index << " is repetitive in operation " - << getOperation()->getName(); - Region *region = &getOperation()->getRegion(index); - bool isRepetitive = isRegionReachable(region, region); - - LDBG() << "Region #" << index << " is repetitive: " << isRepetitive; - return isRepetitive; + return isRegionReachable(region, region); } bool RegionBranchOpInterface::hasLoop() { - LDBG() << "Checking if operation " << getOperation()->getName() - << " has loops"; - SmallVector entryRegions; getSuccessorRegions(RegionBranchPoint::parent(), entryRegions); - LDBG() << "Found " << entryRegions.size() << " entry regions"; - - for (RegionSuccessor successor : entryRegions) { - if (!successor.isParent()) { - LDBG() << "Checking entry region #" - << successor.getSuccessor()->getRegionNumber() << " for loops"; - - bool hasLoop = - traverseRegionGraph(successor.getSuccessor(), - [](Region *nextRegion, ArrayRef visited) { - // Interrupt traversal if the region was already - // visited. - return visited[nextRegion->getRegionNumber()]; - }); - - if (hasLoop) { - LDBG() << "Found loop in entry region #" - << successor.getSuccessor()->getRegionNumber(); - return true; - } - } else { - LDBG() << "Skipping parent successor"; - } - } - - LDBG() << "No loops found in operation"; + for (RegionSuccessor successor : entryRegions) + if (!successor.isParent() && + traverseRegionGraph(successor.getSuccessor(), + [](Region *nextRegion, ArrayRef visited) { + // Interrupt traversal if the region was already + // visited. + return visited[nextRegion->getRegionNumber()]; + })) + return true; return false; } Region *mlir::getEnclosingRepetitiveRegion(Operation *op) { - LDBG() << "Finding enclosing repetitive region for operation " - << op->getName(); - while (Region *region = op->getParentRegion()) { - LDBG() << "Checking region #" << region->getRegionNumber() - << " in operation " << region->getParentOp()->getName(); - op = region->getParentOp(); - if (auto branchOp = dyn_cast(op)) { - LDBG() - << "Found RegionBranchOpInterface, checking if region is repetitive"; - if (branchOp.isRepetitiveRegion(region->getRegionNumber())) { - LDBG() << "Found repetitive region #" << region->getRegionNumber(); + if (auto branchOp = dyn_cast(op)) + if (branchOp.isRepetitiveRegion(region->getRegionNumber())) return region; - } - } else { - LDBG() << "Parent operation does not implement RegionBranchOpInterface"; - } } - - LDBG() << "No enclosing repetitive region found"; return nullptr; } Region *mlir::getEnclosingRepetitiveRegion(Value value) { - LDBG() << "Finding enclosing repetitive region for value"; - Region *region = value.getParentRegion(); while (region) { - LDBG() << "Checking region #" << region->getRegionNumber() - << " in operation " << region->getParentOp()->getName(); - Operation *op = region->getParentOp(); - if (auto branchOp = dyn_cast(op)) { - LDBG() - << "Found RegionBranchOpInterface, checking if region is repetitive"; - if (branchOp.isRepetitiveRegion(region->getRegionNumber())) { - LDBG() << "Found repetitive region #" << region->getRegionNumber(); + if (auto branchOp = dyn_cast(op)) + if (branchOp.isRepetitiveRegion(region->getRegionNumber())) return region; - } - } else { - LDBG() << "Parent operation does not implement RegionBranchOpInterface"; - } region = op->getParentRegion(); } - - LDBG() << "No enclosing repetitive region found for value"; return nullptr; } diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 41f3f9d76a3b1..e0c65b0e09774 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -432,7 +432,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, // Return the successors of `region` if the latter is not null. Else return // the successors of `regionBranchOp`. - auto getSuccessors = [&](RegionBranchPoint point) { + auto getSuccessors = [&](Region *region = nullptr) { + auto point = region ? region : RegionBranchPoint::parent(); SmallVector successors; regionBranchOp.getSuccessorRegions(point, successors); return successors; @@ -455,8 +456,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, // `nonForwardedOperands`. auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) { nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true); - for (const RegionSuccessor &successor : - getSuccessors(RegionBranchPoint::parent())) { + for (const RegionSuccessor &successor : getSuccessors()) { for (OpOperand *opOperand : getForwardedOpOperands(successor)) nonForwardedOperands.reset(opOperand->getOperandNumber()); } @@ -469,13 +469,10 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, for (Region ®ion : regionBranchOp->getRegions()) { if (region.empty()) continue; - // TODO: this isn't correct in face of multiple terminators. Operation *terminator = region.front().getTerminator(); nonForwardedRets[terminator] = BitVector(terminator->getNumOperands(), true); - for (const RegionSuccessor &successor : - getSuccessors(RegionBranchPoint( - cast(terminator)))) { + for (const RegionSuccessor &successor : getSuccessors(®ion)) { for (OpOperand *opOperand : getForwardedOpOperands(successor, terminator)) nonForwardedRets[terminator].reset(opOperand->getOperandNumber()); @@ -492,13 +489,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, DenseMap &argsToKeep, Region *region = nullptr) { Operation *terminator = region ? region->front().getTerminator() : nullptr; - RegionBranchPoint point = - terminator - ? RegionBranchPoint( - cast(terminator)) - : RegionBranchPoint::parent(); - for (const RegionSuccessor &successor : getSuccessors(point)) { + for (const RegionSuccessor &successor : getSuccessors(region)) { Region *successorRegion = successor.getSuccessor(); for (auto [opOperand, input] : llvm::zip(getForwardedOpOperands(successor, terminator), @@ -525,8 +517,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, resultsOrArgsToKeepChanged = false; // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`. - for (const RegionSuccessor &successor : - getSuccessors(RegionBranchPoint::parent())) { + for (const RegionSuccessor &successor : getSuccessors()) { Region *successorRegion = successor.getSuccessor(); for (auto [opOperand, input] : llvm::zip(getForwardedOpOperands(successor), @@ -560,9 +551,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, if (region.empty()) continue; Operation *terminator = region.front().getTerminator(); - for (const RegionSuccessor &successor : - getSuccessors(RegionBranchPoint( - cast(terminator)))) { + for (const RegionSuccessor &successor : getSuccessors(®ion)) { Region *successorRegion = successor.getSuccessor(); for (auto [opOperand, input] : llvm::zip(getForwardedOpOperands(successor, terminator), diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir index 3f481ad5dbba7..37fc86b18e7f0 100644 --- a/mlir/test/Dialect/SCF/invalid.mlir +++ b/mlir/test/Dialect/SCF/invalid.mlir @@ -373,7 +373,7 @@ func.func @reduceReturn_not_inside_reduce(%arg0 : f32) { func.func @std_if_incorrect_yield(%arg0: i1, %arg1: f32) { - // expected-error@+1 {{region control flow edge from Operation scf.yield to parent results: source has 1 operands, but target successor needs 2}} + // expected-error@+1 {{region control flow edge from Region #0 to parent results: source has 1 operands, but target successor needs 2}} %x, %y = scf.if %arg0 -> (f32, f32) { %0 = arith.addf %arg1, %arg1 : f32 scf.yield %0 : f32 @@ -544,7 +544,7 @@ func.func @while_invalid_terminator() { func.func @while_cross_region_type_mismatch() { %true = arith.constant true - // expected-error@+1 {{region control flow edge from Operation scf.condition to Region #1: source has 0 operands, but target successor needs 1}} + // expected-error@+1 {{'scf.while' op region control flow edge from Region #0 to Region #1: source has 0 operands, but target successor needs 1}} scf.while : () -> () { scf.condition(%true) } do { @@ -557,7 +557,7 @@ func.func @while_cross_region_type_mismatch() { func.func @while_cross_region_type_mismatch() { %true = arith.constant true - // expected-error@+1 {{along control flow edge from Operation scf.condition to Region #1: source type #0 'i1' should match input type #0 'i32'}} + // expected-error@+1 {{'scf.while' op along control flow edge from Region #0 to Region #1: source type #0 'i1' should match input type #0 'i32'}} %0 = scf.while : () -> (i1) { scf.condition(%true) %true : i1 } do { @@ -570,7 +570,7 @@ func.func @while_cross_region_type_mismatch() { func.func @while_result_type_mismatch() { %true = arith.constant true - // expected-error@+1 {{region control flow edge from Operation scf.condition to parent results: source has 1 operands, but target successor needs 0}} + // expected-error@+1 {{'scf.while' op region control flow edge from Region #0 to parent results: source has 1 operands, but target successor needs 0}} scf.while : () -> () { scf.condition(%true) %true : i1 } do { diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp index 7a7a58384fbb8..eb0d9801e7d3f 100644 --- a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp @@ -66,7 +66,7 @@ class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis { void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch, RegionBranchPoint regionFrom, - RegionSuccessor regionTo, + RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) override; @@ -240,7 +240,7 @@ void NextAccessAnalysis::visitCallControlFlowTransfer( void NextAccessAnalysis::visitRegionBranchControlFlowTransfer( RegionBranchOpInterface branch, RegionBranchPoint regionFrom, - RegionSuccessor regionTo, const NextAccess &after, NextAccess *before) { + RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) { LDBG() << "visitRegionBranchControlFlowTransfer: " << OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions()); LDBG() << " regionFrom: " << (regionFrom.isParent() ? "parent" : "region"); diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index 4d4ec02546bc7..b211e243f234c 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -633,9 +633,8 @@ ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) { parser.getCurrentLocation(), result.operands); } -OperandRange RegionIfOp::getEntrySuccessorOperands(RegionSuccessor successor) { - assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, - successor.getSuccessor()) && +OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) { + assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) && "invalid region index"); return getOperands(); } @@ -644,11 +643,10 @@ void RegionIfOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { // We always branch to the join region. if (!point.isParent()) { - if (point.getTerminatorPredecessorOrNull()->getParentRegion() != - &getJoinRegion()) + if (point != getJoinRegion()) regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); else - regions.push_back(RegionSuccessor(getOperation(), getResults())); + regions.push_back(RegionSuccessor(getResults())); return; } @@ -675,7 +673,7 @@ void AnyCondOp::getSuccessorRegions(RegionBranchPoint point, if (point.isParent()) regions.emplace_back(&getRegion()); else - regions.emplace_back(getOperation(), getResults()); + regions.emplace_back(getResults()); } void AnyCondOp::getRegionInvocationBounds( @@ -1109,11 +1107,11 @@ void LoopBlockOp::getSuccessorRegions( if (point.isParent()) return; - regions.emplace_back(getOperation(), getOperation()->getResults()); + regions.emplace_back((*this)->getResults()); } -OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionSuccessor successor) { - assert(successor.getSuccessor() == &getBody()); +OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) { + assert(point == getBody()); return MutableOperandRange(getInitMutable()); } @@ -1122,8 +1120,8 @@ OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionSuccessor successor) { //===----------------------------------------------------------------------===// MutableOperandRange -LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionSuccessor successor) { - if (successor.isParent()) +LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) { + if (point.isParent()) return getExitArgMutable(); return getNextIterArgMutable(); } @@ -1215,7 +1213,7 @@ void TestStoreWithARegion::getSuccessorRegions( if (point.isParent()) regions.emplace_back(&getBody(), getBody().front().getArguments()); else - regions.emplace_back(getOperation(), getOperation()->getResults()); + regions.emplace_back(); } //===----------------------------------------------------------------------===// @@ -1229,7 +1227,7 @@ void TestStoreWithALoopRegion::getSuccessorRegions( // enter the body. regions.emplace_back( RegionSuccessor(&getBody(), getBody().front().getArguments())); - regions.emplace_back(getOperation(), getOperation()->getResults()); + regions.emplace_back(); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index a3430ba49a291..05a33cf1afd94 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2581,7 +2581,7 @@ def LoopBlockTerminatorOp : TEST_Op<"loop_block_term", def TestNoTerminatorOp : TEST_Op<"switch_with_no_break", [ NoTerminator, - DeclareOpInterfaceMethods + DeclareOpInterfaceMethods ]> { let arguments = (ins Index:$arg, DenseI64ArrayAttr:$cases); let regions = (region VariadicRegion>:$caseRegions); diff --git a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp index 2e6950fca6be2..f1aae15393fd3 100644 --- a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp +++ b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp @@ -13,24 +13,17 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Parser/Parser.h" -#include "llvm/Support/DebugLog.h" #include using namespace mlir; /// A dummy op that is also a terminator. -struct DummyOp : public Op { +struct DummyOp : public Op { using Op::Op; static ArrayRef getAttributeNames() { return {}; } static StringRef getOperationName() { return "cftest.dummy_op"; } - - MutableOperandRange getMutableSuccessorOperands(RegionSuccessor point) { - return MutableOperandRange(getOperation(), 0, 0); - } }; /// All regions of this op are mutually exclusive. @@ -46,8 +39,6 @@ struct MutuallyExclusiveRegionsOp // Regions have no successors. void getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) {} - using RegionBranchOpInterface::Trait< - MutuallyExclusiveRegionsOp>::getSuccessorRegions; }; /// All regions of this op call each other in a large circle. @@ -62,18 +53,13 @@ struct LoopRegionsOp void getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { - if (point.getTerminatorPredecessorOrNull()) { - Region *region = - point.getTerminatorPredecessorOrNull()->getParentRegion(); - if (region == &(*this)->getRegion(1)) + if (Region *region = point.getRegionOrNull()) { + if (point == (*this)->getRegion(1)) // This region also branches back to the parent. - regions.push_back( - RegionSuccessor(getOperation()->getParentOp(), - getOperation()->getParentOp()->getResults())); + regions.push_back(RegionSuccessor()); regions.push_back(RegionSuccessor(region)); } } - using RegionBranchOpInterface::Trait::getSuccessorRegions; }; /// Each region branches back it itself or the parent. @@ -89,17 +75,11 @@ struct DoubleLoopRegionsOp void getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { - if (point.getTerminatorPredecessorOrNull()) { - Region *region = - point.getTerminatorPredecessorOrNull()->getParentRegion(); - regions.push_back( - RegionSuccessor(getOperation()->getParentOp(), - getOperation()->getParentOp()->getResults())); + if (Region *region = point.getRegionOrNull()) { + regions.push_back(RegionSuccessor()); regions.push_back(RegionSuccessor(region)); } } - using RegionBranchOpInterface::Trait< - DoubleLoopRegionsOp>::getSuccessorRegions; }; /// Regions are executed sequentially. @@ -113,15 +93,11 @@ struct SequentialRegionsOp // Region 0 has Region 1 as a successor. void getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { - if (point.getTerminatorPredecessorOrNull() && - point.getTerminatorPredecessorOrNull()->getParentRegion() == - &(*this)->getRegion(0)) { + if (point == (*this)->getRegion(0)) { Operation *thisOp = this->getOperation(); regions.push_back(RegionSuccessor(&thisOp->getRegion(1))); } } - using RegionBranchOpInterface::Trait< - SequentialRegionsOp>::getSuccessorRegions; }; /// A dialect putting all the above together.