1515#define MLIR_INTERFACES_CONTROLFLOWINTERFACES_H
1616
1717#include " mlir/IR/OpDefinition.h"
18+ #include " mlir/IR/Operation.h"
19+ #include " llvm/ADT/PointerUnion.h"
20+ #include " llvm/ADT/STLExtras.h"
21+ #include " llvm/Support/DebugLog.h"
22+ #include " llvm/Support/raw_ostream.h"
1823
1924namespace mlir {
2025class BranchOpInterface ;
2126class RegionBranchOpInterface ;
27+ class RegionBranchTerminatorOpInterface ;
2228
2329// / This class models how operands are forwarded to block arguments in control
2430// / flow. It consists of a number, denoting how many of the successors block
@@ -186,35 +192,49 @@ class RegionSuccessor {
186192public:
187193 // / Initialize a successor that branches to another region of the parent
188194 // / operation.
195+ // / TODO: the default value for the regionInputs is somehow broken.
196+ // / A region successor should have its input correctly set.
189197 RegionSuccessor (Region *region, Block::BlockArgListType regionInputs = {})
190- : region(region), inputs(regionInputs) {}
198+ : successor(region), inputs(regionInputs) {
199+ assert (region && " Region must not be null" );
200+ }
191201 // / Initialize a successor that branches back to/out of the parent operation.
192- RegionSuccessor (Operation::result_range results)
193- : inputs(ValueRange(results)) {}
194- // / Constructor with no arguments.
195- RegionSuccessor () : inputs(ValueRange()) {}
202+ // / The target must be one of the recursive parent operations.
203+ RegionSuccessor (Operation *successorOp, Operation::result_range results)
204+ : successor(successorOp), inputs(ValueRange(results)) {
205+ assert (successorOp && " Successor op must not be null" );
206+ }
196207
197208 // / Return the given region successor. Returns nullptr if the successor is the
198209 // / parent operation.
199- Region *getSuccessor () const { return region ; }
210+ Region *getSuccessor () const { return dyn_cast<Region *>(successor) ; }
200211
201212 // / Return true if the successor is the parent operation.
202- bool isParent () const { return region == nullptr ; }
213+ bool isParent () const { return isa<Operation *>(successor) ; }
203214
204215 // / Return the inputs to the successor that are remapped by the exit values of
205216 // / the current region.
206217 ValueRange getSuccessorInputs () const { return inputs; }
207218
219+ bool operator ==(RegionSuccessor rhs) const {
220+ return successor == rhs.successor && inputs == rhs.inputs ;
221+ }
222+
223+ friend bool operator !=(RegionSuccessor lhs, RegionSuccessor rhs) {
224+ return !(lhs == rhs);
225+ }
226+
208227private:
209- Region *region {nullptr };
228+ llvm::PointerUnion< Region *, Operation *> successor {nullptr };
210229 ValueRange inputs;
211230};
212231
213232// / This class represents a point being branched from in the methods of the
214233// / `RegionBranchOpInterface`.
215234// / One can branch from one of two kinds of places:
216235// / * The parent operation (aka the `RegionBranchOpInterface` implementation)
217- // / * A region within the parent operation.
236+ // / * A RegionBranchTerminatorOpInterface inside a region within the parent
237+ // operation.
218238class RegionBranchPoint {
219239public:
220240 // / Returns an instance of `RegionBranchPoint` representing the parent
@@ -223,55 +243,57 @@ class RegionBranchPoint {
223243
224244 // / Creates a `RegionBranchPoint` that branches from the given region.
225245 // / The pointer must not be null.
226- RegionBranchPoint (Region *region) : maybeRegion(region) {
227- assert (region && " Region must not be null" );
228- }
229-
230- RegionBranchPoint (Region ®ion) : RegionBranchPoint(®ion) {}
246+ inline RegionBranchPoint (RegionBranchTerminatorOpInterface predecessor);
231247
232248 // / Explicitly stops users from constructing with `nullptr`.
233249 RegionBranchPoint (std::nullptr_t ) = delete ;
234250
235- // / Constructs a `RegionBranchPoint` from the the target of a
236- // / `RegionSuccessor` instance.
237- RegionBranchPoint (RegionSuccessor successor) {
238- if (successor.isParent ())
239- maybeRegion = nullptr ;
240- else
241- maybeRegion = successor.getSuccessor ();
242- }
243-
244- // / Assigns a region being branched from.
245- RegionBranchPoint &operator =(Region ®ion) {
246- maybeRegion = ®ion;
247- return *this ;
248- }
249-
250251 // / Returns true if branching from the parent op.
251- bool isParent () const { return maybeRegion == nullptr ; }
252+ bool isParent () const { return predecessor == nullptr ; }
252253
253254 // / Returns the region if branching from a region.
254255 // / A null pointer otherwise.
255- Region * getRegionOrNull () const { return maybeRegion ; }
256+ Operation * getPredecessorOrNull () const { return predecessor ; }
256257
257258 // / Returns true if the two branch points are equal.
258259 friend bool operator ==(RegionBranchPoint lhs, RegionBranchPoint rhs) {
259- return lhs.maybeRegion == rhs.maybeRegion ;
260+ return lhs.predecessor == rhs.predecessor ;
260261 }
261262
262263private:
263264 // Private constructor to encourage the use of `RegionBranchPoint::parent`.
264- constexpr RegionBranchPoint () : maybeRegion( nullptr ) {}
265+ constexpr RegionBranchPoint () = default;
265266
266267 // / Internal encoding. Uses nullptr for representing branching from the parent
267- // / op and the region being branched from otherwise.
268- Region *maybeRegion ;
268+ // / op and the region terminator being branched from otherwise.
269+ Operation *predecessor = nullptr ;
269270};
270271
271272inline bool operator !=(RegionBranchPoint lhs, RegionBranchPoint rhs) {
272273 return !(lhs == rhs);
273274}
274275
276+ inline llvm::raw_ostream &operator <<(llvm::raw_ostream &os,
277+ RegionBranchPoint point) {
278+ if (point.isParent ())
279+ return os << " <from parent>" ;
280+ return os
281+ << " <region #"
282+ << point.getPredecessorOrNull ()->getParentRegion ()->getRegionNumber ()
283+ << " , terminator "
284+ << OpWithFlags (point.getPredecessorOrNull (),
285+ OpPrintingFlags ().skipRegions ())
286+ << " >" ;
287+ }
288+
289+ inline llvm::raw_ostream &operator <<(llvm::raw_ostream &os,
290+ RegionSuccessor successor) {
291+ if (successor.isParent ())
292+ return os << " <to parent>" ;
293+ return os << " <to region #" << successor.getSuccessor ()->getRegionNumber ()
294+ << " with " << successor.getSuccessorInputs ().size () << " inputs>" ;
295+ }
296+
275297// / This class represents upper and lower bounds on the number of times a region
276298// / of a `RegionBranchOpInterface` can be invoked. The lower bound is at least
277299// / zero, but the upper bound may not be known.
@@ -348,4 +370,10 @@ struct ReturnLike : public TraitBase<ConcreteType, ReturnLike> {
348370// / Include the generated interface declarations.
349371#include " mlir/Interfaces/ControlFlowInterfaces.h.inc"
350372
373+ namespace mlir {
374+ inline RegionBranchPoint::RegionBranchPoint (
375+ RegionBranchTerminatorOpInterface predecessor)
376+ : predecessor(predecessor.getOperation()) {}
377+ } // namespace mlir
378+
351379#endif // MLIR_INTERFACES_CONTROLFLOWINTERFACES_H
0 commit comments