Skip to content

Commit a9d0f5e

Browse files
authored
[mlir] Allow loop-like operations in AbstractDenseForwardDataFlowAnalysis (#66179)
Remove assertion violated by loop-like operations. Signed-off-by: Victor Perez <[email protected]>
1 parent aee8f87 commit a9d0f5e

File tree

5 files changed

+156
-11
lines changed

5 files changed

+156
-11
lines changed

mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,6 @@ void AbstractDenseForwardDataFlowAnalysis::visitRegionBranchOperation(
199199
op == branch ? std::optional<unsigned>()
200200
: op->getBlock()->getParent()->getRegionNumber();
201201
if (auto *toBlock = point.dyn_cast<Block *>()) {
202-
assert(op == branch ||
203-
toBlock->getParent() != op->getBlock()->getParent());
204202
unsigned regionTo = toBlock->getParent()->getRegionNumber();
205203
visitRegionBranchControlFlowTransfer(branch, regionFrom, regionTo,
206204
*before, after);

mlir/test/Analysis/DataFlow/test-last-modified.mlir

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,3 +229,122 @@ func.func @store_with_a_region_after_containing_a_store(%arg0: memref<f32>) -> m
229229
memref.store %1, %arg0[] {tag_name = "post"} : memref<f32>
230230
return {tag = "return"} %arg0 : memref<f32>
231231
}
232+
233+
// CHECK-LABEL: test_tag: store_with_a_loop_region_before::before:
234+
// CHECK: operand #0
235+
// CHECK: - pre
236+
// CHECK: test_tag: inside_region:
237+
// CHECK: operand #0
238+
// CHECK: - region
239+
// CHECK: test_tag: after:
240+
// CHECK: operand #0
241+
// CHECK: - region
242+
// CHECK: test_tag: return:
243+
// CHECK: operand #0
244+
// CHECK: - post
245+
func.func @store_with_a_loop_region_before(%arg0: memref<f32>) -> memref<f32> {
246+
%0 = arith.constant 0.0 : f32
247+
%1 = arith.constant 1.0 : f32
248+
memref.store %0, %arg0[] {tag_name = "pre"} : memref<f32>
249+
memref.load %arg0[] {tag = "store_with_a_loop_region_before::before"} : memref<f32>
250+
test.store_with_a_loop_region %arg0 attributes { tag_name = "region", store_before_region = true } {
251+
memref.load %arg0[] {tag = "inside_region"} : memref<f32>
252+
test.store_with_a_region_terminator
253+
} : memref<f32>
254+
memref.load %arg0[] {tag = "after"} : memref<f32>
255+
memref.store %1, %arg0[] {tag_name = "post"} : memref<f32>
256+
return {tag = "return"} %arg0 : memref<f32>
257+
}
258+
259+
// CHECK-LABEL: test_tag: store_with_a_loop_region_after::before:
260+
// CHECK: operand #0
261+
// CHECK: - pre
262+
// CHECK: test_tag: inside_region:
263+
// CHECK: operand #0
264+
// CHECK: - pre
265+
// CHECK: test_tag: after:
266+
// CHECK: operand #0
267+
// CHECK: - region
268+
// CHECK: test_tag: return:
269+
// CHECK: operand #0
270+
// CHECK: - post
271+
func.func @store_with_a_loop_region_after(%arg0: memref<f32>) -> memref<f32> {
272+
%0 = arith.constant 0.0 : f32
273+
%1 = arith.constant 1.0 : f32
274+
memref.store %0, %arg0[] {tag_name = "pre"} : memref<f32>
275+
memref.load %arg0[] {tag = "store_with_a_loop_region_after::before"} : memref<f32>
276+
test.store_with_a_loop_region %arg0 attributes { tag_name = "region", store_before_region = false } {
277+
memref.load %arg0[] {tag = "inside_region"} : memref<f32>
278+
test.store_with_a_region_terminator
279+
} : memref<f32>
280+
memref.load %arg0[] {tag = "after"} : memref<f32>
281+
memref.store %1, %arg0[] {tag_name = "post"} : memref<f32>
282+
return {tag = "return"} %arg0 : memref<f32>
283+
}
284+
285+
// CHECK-LABEL: test_tag: store_with_a_loop_region_before_containing_a_store::before:
286+
// CHECK: operand #0
287+
// CHECK: - pre
288+
// CHECK: test_tag: enter_region:
289+
// CHECK: operand #0
290+
// CHECK-DAG: - region
291+
// CHECK-DAG: - inner
292+
// CHECK: test_tag: exit_region:
293+
// CHECK: operand #0
294+
// CHECK: - inner
295+
// CHECK: test_tag: after:
296+
// CHECK: operand #0
297+
// CHECK-DAG: - region
298+
// CHECK-DAG: - inner
299+
// CHECK: test_tag: return:
300+
// CHECK: operand #0
301+
// CHECK: - post
302+
func.func @store_with_a_loop_region_before_containing_a_store(%arg0: memref<f32>) -> memref<f32> {
303+
%0 = arith.constant 0.0 : f32
304+
%1 = arith.constant 1.0 : f32
305+
memref.store %0, %arg0[] {tag_name = "pre"} : memref<f32>
306+
memref.load %arg0[] {tag = "store_with_a_loop_region_before_containing_a_store::before"} : memref<f32>
307+
test.store_with_a_loop_region %arg0 attributes { tag_name = "region", store_before_region = true } {
308+
memref.load %arg0[] {tag = "enter_region"} : memref<f32>
309+
%2 = arith.constant 2.0 : f32
310+
memref.store %2, %arg0[] {tag_name = "inner"} : memref<f32>
311+
memref.load %arg0[] {tag = "exit_region"} : memref<f32>
312+
test.store_with_a_region_terminator
313+
} : memref<f32>
314+
memref.load %arg0[] {tag = "after"} : memref<f32>
315+
memref.store %1, %arg0[] {tag_name = "post"} : memref<f32>
316+
return {tag = "return"} %arg0 : memref<f32>
317+
}
318+
319+
// CHECK-LABEL: test_tag: store_with_a_loop_region_after_containing_a_store::before:
320+
// CHECK: operand #0
321+
// CHECK: - pre
322+
// CHECK: test_tag: enter_region:
323+
// CHECK: operand #0
324+
// CHECK-DAG: - pre
325+
// CHECK-DAG: - inner
326+
// CHECK: test_tag: exit_region:
327+
// CHECK: operand #0
328+
// CHECK: - inner
329+
// CHECK: test_tag: after:
330+
// CHECK: operand #0
331+
// CHECK: - region
332+
// CHECK: test_tag: return:
333+
// CHECK: operand #0
334+
// CHECK: - post
335+
func.func @store_with_a_loop_region_after_containing_a_store(%arg0: memref<f32>) -> memref<f32> {
336+
%0 = arith.constant 0.0 : f32
337+
%1 = arith.constant 1.0 : f32
338+
memref.store %0, %arg0[] {tag_name = "pre"} : memref<f32>
339+
memref.load %arg0[] {tag = "store_with_a_loop_region_after_containing_a_store::before"} : memref<f32>
340+
test.store_with_a_loop_region %arg0 attributes { tag_name = "region", store_before_region = false } {
341+
memref.load %arg0[] {tag = "enter_region"} : memref<f32>
342+
%2 = arith.constant 2.0 : f32
343+
memref.store %2, %arg0[] {tag_name = "inner"} : memref<f32>
344+
memref.load %arg0[] {tag = "exit_region"} : memref<f32>
345+
test.store_with_a_region_terminator
346+
} : memref<f32>
347+
memref.load %arg0[] {tag = "after"} : memref<f32>
348+
memref.store %1, %arg0[] {tag_name = "post"} : memref<f32>
349+
return {tag = "return"} %arg0 : memref<f32>
350+
}

mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
1818
#include "mlir/Interfaces/SideEffectInterfaces.h"
1919
#include "mlir/Pass/Pass.h"
20+
#include "mlir/Support/LLVM.h"
21+
#include "llvm/ADT/TypeSwitch.h"
2022
#include <optional>
2123

2224
using namespace mlir;
@@ -133,15 +135,19 @@ void LastModifiedAnalysis::visitRegionBranchControlFlowTransfer(
133135
RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
134136
std::optional<unsigned> regionTo, const LastModification &before,
135137
LastModification *after) {
136-
auto testStoreWithARegion =
137-
dyn_cast<::test::TestStoreWithARegion>(branch.getOperation());
138-
if (testStoreWithARegion &&
139-
((!regionTo && !testStoreWithARegion.getStoreBeforeRegion()) ||
140-
(!regionFrom && testStoreWithARegion.getStoreBeforeRegion()))) {
141-
return visitOperation(branch, before, after);
142-
}
143-
AbstractDenseForwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
144-
branch, regionFrom, regionTo, before, after);
138+
auto defaultHandling = [&]() {
139+
AbstractDenseForwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
140+
branch, regionFrom, regionTo, before, after);
141+
};
142+
TypeSwitch<Operation *>(branch.getOperation())
143+
.Case<::test::TestStoreWithARegion, ::test::TestStoreWithALoopRegion>(
144+
[=](auto storeWithRegion) {
145+
if ((!regionTo && !storeWithRegion.getStoreBeforeRegion()) ||
146+
(!regionFrom && storeWithRegion.getStoreBeforeRegion()))
147+
visitOperation(branch, before, after);
148+
defaultHandling();
149+
})
150+
.Default([=](auto) { defaultHandling(); });
145151
}
146152

147153
namespace {

mlir/test/lib/Dialect/Test/TestDialect.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,6 +1315,16 @@ void TestStoreWithARegion::getSuccessorRegions(
13151315
regions.emplace_back();
13161316
}
13171317

1318+
void TestStoreWithALoopRegion::getSuccessorRegions(
1319+
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1320+
// Both the operation itself and the region may be branching into the body or
1321+
// back into the operation itself. It is possible for the operation not to
1322+
// enter the body.
1323+
regions.emplace_back(
1324+
RegionSuccessor(&getBody(), getBody().front().getArguments()));
1325+
regions.emplace_back();
1326+
}
1327+
13181328
LogicalResult
13191329
TestVersionedOpA::readProperties(::mlir::DialectBytecodeReader &reader,
13201330
::mlir::OperationState &state) {

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2953,6 +2953,18 @@ def TestStoreWithARegion : TEST_Op<"store_with_a_region",
29532953
"$address attr-dict-with-keyword regions `:` type($address)";
29542954
}
29552955

2956+
def TestStoreWithALoopRegion : TEST_Op<"store_with_a_loop_region",
2957+
[DeclareOpInterfaceMethods<RegionBranchOpInterface>,
2958+
SingleBlock]> {
2959+
let arguments = (ins
2960+
Arg<AnyMemRef, "", [MemWrite]>:$address,
2961+
BoolAttr:$store_before_region
2962+
);
2963+
let regions = (region AnyRegion:$body);
2964+
let assemblyFormat =
2965+
"$address attr-dict-with-keyword regions `:` type($address)";
2966+
}
2967+
29562968
def TestStoreWithARegionTerminator : TEST_Op<"store_with_a_region_terminator",
29572969
[ReturnLike, Terminator, NoMemoryEffect]> {
29582970
let assemblyFormat = "attr-dict";

0 commit comments

Comments
 (0)