|
15 | 15 | #define MLIR_INTERFACES_CONTROLFLOWINTERFACES_H |
16 | 16 |
|
17 | 17 | #include "mlir/IR/OpDefinition.h" |
| 18 | +#include "mlir/IR/Operation.h" |
| 19 | +#include "llvm/ADT/PointerUnion.h" |
| 20 | +#include "llvm/ADT/STLExtras.h" |
| 21 | +#include "llvm/Support/DebugLog.h" |
| 22 | +#include "llvm/Support/raw_ostream.h" |
18 | 23 |
|
19 | 24 | namespace mlir { |
20 | 25 | class BranchOpInterface; |
21 | 26 | class RegionBranchOpInterface; |
| 27 | +class RegionBranchTerminatorOpInterface; |
22 | 28 |
|
23 | 29 | /// This class models how operands are forwarded to block arguments in control |
24 | 30 | /// flow. It consists of a number, denoting how many of the successors block |
@@ -186,92 +192,107 @@ class RegionSuccessor { |
186 | 192 | public: |
187 | 193 | /// Initialize a successor that branches to another region of the parent |
188 | 194 | /// operation. |
| 195 | + /// TODO: the default value for the regionInputs is somehow broken. |
| 196 | + /// A region successor should have its input correctly set. |
189 | 197 | RegionSuccessor(Region *region, Block::BlockArgListType regionInputs = {}) |
190 | | - : region(region), inputs(regionInputs) {} |
| 198 | + : successor(region), inputs(regionInputs) { |
| 199 | + assert(region && "Region must not be null"); |
| 200 | + } |
191 | 201 | /// Initialize a successor that branches back to/out of the parent operation. |
192 | | - RegionSuccessor(Operation::result_range results) |
193 | | - : inputs(ValueRange(results)) {} |
194 | | - /// Constructor with no arguments. |
195 | | - RegionSuccessor() : inputs(ValueRange()) {} |
| 202 | + /// The target must be one of the recursive parent operations. |
| 203 | + RegionSuccessor(Operation *successorOp, Operation::result_range results) |
| 204 | + : successor(successorOp), inputs(ValueRange(results)) { |
| 205 | + assert(successorOp && "Successor op must not be null"); |
| 206 | + } |
196 | 207 |
|
197 | 208 | /// Return the given region successor. Returns nullptr if the successor is the |
198 | 209 | /// parent operation. |
199 | | - Region *getSuccessor() const { return region; } |
| 210 | + Region *getSuccessor() const { return dyn_cast<Region *>(successor); } |
200 | 211 |
|
201 | 212 | /// Return true if the successor is the parent operation. |
202 | | - bool isParent() const { return region == nullptr; } |
| 213 | + bool isParent() const { return isa<Operation *>(successor); } |
203 | 214 |
|
204 | 215 | /// Return the inputs to the successor that are remapped by the exit values of |
205 | 216 | /// the current region. |
206 | 217 | ValueRange getSuccessorInputs() const { return inputs; } |
207 | 218 |
|
| 219 | + bool operator==(RegionSuccessor rhs) const { |
| 220 | + return successor == rhs.successor && inputs == rhs.inputs; |
| 221 | + } |
| 222 | + |
| 223 | + friend bool operator!=(RegionSuccessor lhs, RegionSuccessor rhs) { |
| 224 | + return !(lhs == rhs); |
| 225 | + } |
| 226 | + |
208 | 227 | private: |
209 | | - Region *region{nullptr}; |
| 228 | + llvm::PointerUnion<Region *, Operation *> successor{nullptr}; |
210 | 229 | ValueRange inputs; |
211 | 230 | }; |
212 | 231 |
|
213 | 232 | /// This class represents a point being branched from in the methods of the |
214 | 233 | /// `RegionBranchOpInterface`. |
215 | 234 | /// One can branch from one of two kinds of places: |
216 | 235 | /// * The parent operation (aka the `RegionBranchOpInterface` implementation) |
217 | | -/// * A region within the parent operation. |
| 236 | +/// * A RegionBranchTerminatorOpInterface inside a region within the parent |
| 237 | +// operation. |
218 | 238 | class RegionBranchPoint { |
219 | 239 | public: |
220 | 240 | /// Returns an instance of `RegionBranchPoint` representing the parent |
221 | 241 | /// operation. |
222 | 242 | static constexpr RegionBranchPoint parent() { return RegionBranchPoint(); } |
223 | 243 |
|
224 | | - /// Creates a `RegionBranchPoint` that branches from the given region. |
225 | | - /// The pointer must not be null. |
226 | | - RegionBranchPoint(Region *region) : maybeRegion(region) { |
227 | | - assert(region && "Region must not be null"); |
228 | | - } |
229 | | - |
230 | | - RegionBranchPoint(Region ®ion) : RegionBranchPoint(®ion) {} |
| 244 | + /// Creates a `RegionBranchPoint` that branches from the given terminator. |
| 245 | + inline RegionBranchPoint(RegionBranchTerminatorOpInterface predecessor); |
231 | 246 |
|
232 | 247 | /// Explicitly stops users from constructing with `nullptr`. |
233 | 248 | RegionBranchPoint(std::nullptr_t) = delete; |
234 | 249 |
|
235 | | - /// Constructs a `RegionBranchPoint` from the the target of a |
236 | | - /// `RegionSuccessor` instance. |
237 | | - RegionBranchPoint(RegionSuccessor successor) { |
238 | | - if (successor.isParent()) |
239 | | - maybeRegion = nullptr; |
240 | | - else |
241 | | - maybeRegion = successor.getSuccessor(); |
242 | | - } |
243 | | - |
244 | | - /// Assigns a region being branched from. |
245 | | - RegionBranchPoint &operator=(Region ®ion) { |
246 | | - maybeRegion = ®ion; |
247 | | - return *this; |
248 | | - } |
249 | | - |
250 | 250 | /// Returns true if branching from the parent op. |
251 | | - bool isParent() const { return maybeRegion == nullptr; } |
| 251 | + bool isParent() const { return predecessor == nullptr; } |
252 | 252 |
|
253 | 253 | /// Returns the region if branching from a region. |
254 | 254 | /// A null pointer otherwise. |
255 | | - Region *getRegionOrNull() const { return maybeRegion; } |
| 255 | + Operation *getPredecessorOrNull() const { return predecessor; } |
256 | 256 |
|
257 | 257 | /// Returns true if the two branch points are equal. |
258 | 258 | friend bool operator==(RegionBranchPoint lhs, RegionBranchPoint rhs) { |
259 | | - return lhs.maybeRegion == rhs.maybeRegion; |
| 259 | + return lhs.predecessor == rhs.predecessor; |
260 | 260 | } |
261 | 261 |
|
262 | 262 | private: |
263 | 263 | // Private constructor to encourage the use of `RegionBranchPoint::parent`. |
264 | | - constexpr RegionBranchPoint() : maybeRegion(nullptr) {} |
| 264 | + constexpr RegionBranchPoint() = default; |
265 | 265 |
|
266 | 266 | /// Internal encoding. Uses nullptr for representing branching from the parent |
267 | | - /// op and the region being branched from otherwise. |
268 | | - Region *maybeRegion; |
| 267 | + /// op and the region terminator being branched from otherwise. |
| 268 | + Operation *predecessor = nullptr; |
269 | 269 | }; |
270 | 270 |
|
271 | 271 | inline bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs) { |
272 | 272 | return !(lhs == rhs); |
273 | 273 | } |
274 | 274 |
|
| 275 | +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, |
| 276 | + RegionBranchPoint point) { |
| 277 | + if (point.isParent()) |
| 278 | + return os << "<from parent>"; |
| 279 | + return os |
| 280 | + << "<region #" |
| 281 | + << point.getPredecessorOrNull()->getParentRegion()->getRegionNumber() |
| 282 | + << ", terminator " |
| 283 | + << OpWithFlags(point.getPredecessorOrNull(), |
| 284 | + OpPrintingFlags().skipRegions()) |
| 285 | + << ">"; |
| 286 | +} |
| 287 | + |
| 288 | +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, |
| 289 | + RegionSuccessor successor) { |
| 290 | + if (successor.isParent()) |
| 291 | + return os << "<to parent>"; |
| 292 | + return os << "<to region #" << successor.getSuccessor()->getRegionNumber() |
| 293 | + << " with " << successor.getSuccessorInputs().size() << " inputs>"; |
| 294 | +} |
| 295 | + |
275 | 296 | /// This class represents upper and lower bounds on the number of times a region |
276 | 297 | /// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least |
277 | 298 | /// zero, but the upper bound may not be known. |
@@ -348,4 +369,10 @@ struct ReturnLike : public TraitBase<ConcreteType, ReturnLike> { |
348 | 369 | /// Include the generated interface declarations. |
349 | 370 | #include "mlir/Interfaces/ControlFlowInterfaces.h.inc" |
350 | 371 |
|
| 372 | +namespace mlir { |
| 373 | +inline RegionBranchPoint::RegionBranchPoint( |
| 374 | + RegionBranchTerminatorOpInterface predecessor) |
| 375 | + : predecessor(predecessor.getOperation()) {} |
| 376 | +} // namespace mlir |
| 377 | + |
351 | 378 | #endif // MLIR_INTERFACES_CONTROLFLOWINTERFACES_H |
0 commit comments