Skip to content

Commit 224da97

Browse files
committed
[MLIR] Revamp RegionBranchOpInterface
This is still somehow a WIP, we have some issues with this interface that are not trivial to solve. This patch tries to make the concepts of RegionBranchPoint and RegionSuccessor more robust and aligned with their definition: - A `RegionBranchPoint` is either the parent (`RegionBranchOpInterface`) op or a `RegionBranchTerminatorOpInterface` operation in a nested region. - A `RegionSuccessor` is either one of the nested region or the parent `RegionBranchOpInterface` Some new methods with reasonnable default implementation are added to help resolving the flow of values across the RegionBranchOpInterface. It is still not trivial in the current state to walk the def-use chain backward with this interface. For example when you have the 3rd block argument in the entry block of a for-loop, finding the matching operands requires to know about the hidden loop iterator block argument and where the iterargs start. The API is designed around forward-tracking of the chain unfortunately.
1 parent eabfed8 commit 224da97

File tree

37 files changed

+915
-392
lines changed

37 files changed

+915
-392
lines changed

mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
397397
/// itself.
398398
virtual void visitRegionBranchControlFlowTransfer(
399399
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
400-
RegionBranchPoint regionTo, const AbstractDenseLattice &after,
400+
RegionSuccessor regionTo, const AbstractDenseLattice &after,
401401
AbstractDenseLattice *before) {
402402
meet(before, after);
403403
}
@@ -526,7 +526,7 @@ class DenseBackwardDataFlowAnalysis
526526
/// and "to" regions.
527527
virtual void visitRegionBranchControlFlowTransfer(
528528
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
529-
RegionBranchPoint regionTo, const LatticeT &after, LatticeT *before) {
529+
RegionSuccessor regionTo, const LatticeT &after, LatticeT *before) {
530530
AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
531531
branch, regionFrom, regionTo, after, before);
532532
}
@@ -571,7 +571,7 @@ class DenseBackwardDataFlowAnalysis
571571
}
572572
void visitRegionBranchControlFlowTransfer(
573573
RegionBranchOpInterface branch, RegionBranchPoint regionForm,
574-
RegionBranchPoint regionTo, const AbstractDenseLattice &after,
574+
RegionSuccessor regionTo, const AbstractDenseLattice &after,
575575
AbstractDenseLattice *before) final {
576576
visitRegionBranchControlFlowTransfer(branch, regionForm, regionTo,
577577
static_cast<const LatticeT &>(after),

mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
286286
/// and propagating therefrom.
287287
virtual void
288288
visitRegionSuccessors(ProgramPoint *point, RegionBranchOpInterface branch,
289-
RegionBranchPoint successor,
289+
RegionSuccessor successor,
290290
ArrayRef<AbstractSparseLattice *> lattices);
291291
};
292292

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,13 @@ def ForallOp : SCF_Op<"forall", [
644644

645645
/// Returns true if the mapping specified for this forall op is linear.
646646
bool usesLinearMapping();
647+
648+
/// RegionBranchOpInterface
649+
650+
OperandRange getEntrySuccessorOperands(RegionSuccessor successor) {
651+
return getInits();
652+
}
653+
647654
}];
648655
}
649656

mlir/include/mlir/IR/Diagnostics.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class MLIRContext;
2929
class Operation;
3030
class OperationName;
3131
class OpPrintingFlags;
32+
class OpWithFlags;
3233
class Type;
3334
class Value;
3435

@@ -199,6 +200,7 @@ class Diagnostic {
199200

200201
/// Stream in an Operation.
201202
Diagnostic &operator<<(Operation &op);
203+
Diagnostic &operator<<(OpWithFlags op);
202204
Diagnostic &operator<<(Operation *op) { return *this << *op; }
203205
/// Append an operation with the given printing flags.
204206
Diagnostic &appendOp(Operation &op, const OpPrintingFlags &flags);

mlir/include/mlir/IR/Operation.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,6 +1114,7 @@ class OpWithFlags {
11141114
: op(op), theFlags(flags) {}
11151115
OpPrintingFlags &flags() { return theFlags; }
11161116
const OpPrintingFlags &flags() const { return theFlags; }
1117+
Operation *getOperation() const { return op; }
11171118

11181119
private:
11191120
Operation *op;

mlir/include/mlir/IR/Region.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,8 @@ class RegionRange
379379
friend RangeBaseT;
380380
};
381381

382+
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, Region &region);
383+
382384
} // namespace mlir
383385

384386
#endif // MLIR_IR_REGION_H

mlir/include/mlir/Interfaces/ControlFlowInterfaces.h

Lines changed: 63 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,16 @@
1515
#define MLIR_INTERFACES_CONTROLFLOWINTERFACES_H
1616

1717
#include "mlir/IR/OpDefinition.h"
18+
#include "mlir/IR/Operation.h"
19+
#include "llvm/ADT/PointerUnion.h"
20+
#include "llvm/ADT/STLExtras.h"
21+
#include "llvm/Support/DebugLog.h"
22+
#include "llvm/Support/raw_ostream.h"
1823

1924
namespace mlir {
2025
class BranchOpInterface;
2126
class RegionBranchOpInterface;
27+
class RegionBranchTerminatorOpInterface;
2228

2329
/// This class models how operands are forwarded to block arguments in control
2430
/// flow. It consists of a number, denoting how many of the successors block
@@ -186,35 +192,49 @@ class RegionSuccessor {
186192
public:
187193
/// Initialize a successor that branches to another region of the parent
188194
/// operation.
195+
/// TODO: the default value for the regionInputs is somehow broken.
196+
/// A region successor should have its input correctly set.
189197
RegionSuccessor(Region *region, Block::BlockArgListType regionInputs = {})
190-
: region(region), inputs(regionInputs) {}
198+
: successor(region), inputs(regionInputs) {
199+
assert(region && "Region must not be null");
200+
}
191201
/// Initialize a successor that branches back to/out of the parent operation.
192-
RegionSuccessor(Operation::result_range results)
193-
: inputs(ValueRange(results)) {}
194-
/// Constructor with no arguments.
195-
RegionSuccessor() : inputs(ValueRange()) {}
202+
/// The target must be one of the recursive parent operations.
203+
RegionSuccessor(Operation *successorOp, Operation::result_range results)
204+
: successor(successorOp), inputs(ValueRange(results)) {
205+
assert(successorOp && "Successor op must not be null");
206+
}
196207

197208
/// Return the given region successor. Returns nullptr if the successor is the
198209
/// parent operation.
199-
Region *getSuccessor() const { return region; }
210+
Region *getSuccessor() const { return dyn_cast<Region *>(successor); }
200211

201212
/// Return true if the successor is the parent operation.
202-
bool isParent() const { return region == nullptr; }
213+
bool isParent() const { return isa<Operation *>(successor); }
203214

204215
/// Return the inputs to the successor that are remapped by the exit values of
205216
/// the current region.
206217
ValueRange getSuccessorInputs() const { return inputs; }
207218

219+
bool operator==(RegionSuccessor rhs) const {
220+
return successor == rhs.successor && inputs == rhs.inputs;
221+
}
222+
223+
friend bool operator!=(RegionSuccessor lhs, RegionSuccessor rhs) {
224+
return !(lhs == rhs);
225+
}
226+
208227
private:
209-
Region *region{nullptr};
228+
llvm::PointerUnion<Region *, Operation *> successor{nullptr};
210229
ValueRange inputs;
211230
};
212231

213232
/// This class represents a point being branched from in the methods of the
214233
/// `RegionBranchOpInterface`.
215234
/// One can branch from one of two kinds of places:
216235
/// * The parent operation (aka the `RegionBranchOpInterface` implementation)
217-
/// * A region within the parent operation.
236+
/// * A RegionBranchTerminatorOpInterface inside a region within the parent
237+
// operation.
218238
class RegionBranchPoint {
219239
public:
220240
/// Returns an instance of `RegionBranchPoint` representing the parent
@@ -223,55 +243,57 @@ class RegionBranchPoint {
223243

224244
/// Creates a `RegionBranchPoint` that branches from the given region.
225245
/// The pointer must not be null.
226-
RegionBranchPoint(Region *region) : maybeRegion(region) {
227-
assert(region && "Region must not be null");
228-
}
229-
230-
RegionBranchPoint(Region &region) : RegionBranchPoint(&region) {}
246+
inline RegionBranchPoint(RegionBranchTerminatorOpInterface predecessor);
231247

232248
/// Explicitly stops users from constructing with `nullptr`.
233249
RegionBranchPoint(std::nullptr_t) = delete;
234250

235-
/// Constructs a `RegionBranchPoint` from the the target of a
236-
/// `RegionSuccessor` instance.
237-
RegionBranchPoint(RegionSuccessor successor) {
238-
if (successor.isParent())
239-
maybeRegion = nullptr;
240-
else
241-
maybeRegion = successor.getSuccessor();
242-
}
243-
244-
/// Assigns a region being branched from.
245-
RegionBranchPoint &operator=(Region &region) {
246-
maybeRegion = &region;
247-
return *this;
248-
}
249-
250251
/// Returns true if branching from the parent op.
251-
bool isParent() const { return maybeRegion == nullptr; }
252+
bool isParent() const { return predecessor == nullptr; }
252253

253254
/// Returns the region if branching from a region.
254255
/// A null pointer otherwise.
255-
Region *getRegionOrNull() const { return maybeRegion; }
256+
Operation *getPredecessorOrNull() const { return predecessor; }
256257

257258
/// Returns true if the two branch points are equal.
258259
friend bool operator==(RegionBranchPoint lhs, RegionBranchPoint rhs) {
259-
return lhs.maybeRegion == rhs.maybeRegion;
260+
return lhs.predecessor == rhs.predecessor;
260261
}
261262

262263
private:
263264
// Private constructor to encourage the use of `RegionBranchPoint::parent`.
264-
constexpr RegionBranchPoint() : maybeRegion(nullptr) {}
265+
constexpr RegionBranchPoint() = default;
265266

266267
/// Internal encoding. Uses nullptr for representing branching from the parent
267-
/// op and the region being branched from otherwise.
268-
Region *maybeRegion;
268+
/// op and the region terminator being branched from otherwise.
269+
Operation *predecessor = nullptr;
269270
};
270271

271272
inline bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs) {
272273
return !(lhs == rhs);
273274
}
274275

276+
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
277+
RegionBranchPoint point) {
278+
if (point.isParent())
279+
return os << "<from parent>";
280+
return os
281+
<< "<region #"
282+
<< point.getPredecessorOrNull()->getParentRegion()->getRegionNumber()
283+
<< ", terminator "
284+
<< OpWithFlags(point.getPredecessorOrNull(),
285+
OpPrintingFlags().skipRegions())
286+
<< ">";
287+
}
288+
289+
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
290+
RegionSuccessor successor) {
291+
if (successor.isParent())
292+
return os << "<to parent>";
293+
return os << "<to region #" << successor.getSuccessor()->getRegionNumber()
294+
<< " with " << successor.getSuccessorInputs().size() << " inputs>";
295+
}
296+
275297
/// This class represents upper and lower bounds on the number of times a region
276298
/// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least
277299
/// zero, but the upper bound may not be known.
@@ -348,4 +370,10 @@ struct ReturnLike : public TraitBase<ConcreteType, ReturnLike> {
348370
/// Include the generated interface declarations.
349371
#include "mlir/Interfaces/ControlFlowInterfaces.h.inc"
350372

373+
namespace mlir {
374+
inline RegionBranchPoint::RegionBranchPoint(
375+
RegionBranchTerminatorOpInterface predecessor)
376+
: predecessor(predecessor.getOperation()) {}
377+
} // namespace mlir
378+
351379
#endif // MLIR_INTERFACES_CONTROLFLOWINTERFACES_H

0 commit comments

Comments
 (0)