-
Notifications
You must be signed in to change notification settings - Fork 15k
[SCF] Fixed epilogue predicates in loop pipelining #108964
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-scf @llvm/pr-subscribers-mlir Author: SJW (sjw36) ChangesThe computed loop iteration is zero based, so only check it is less than zero. Full diff: https://github.com/llvm/llvm-project/pull/108964.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 7cecd4942b640f..ad6f790a5ba02c 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -655,6 +655,9 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
Value rangeDecr = rewriter.create<arith::AddIOp>(loc, rangeIncr, minus1);
Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);
+ Value zero =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
+
SmallVector<Value> predicates(maxStage + 1);
for (int64_t i = 0; i < maxStage; i++) {
// iterI = total_iters - 1 - i
@@ -671,9 +674,9 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
if (dynamicLoop) {
- // pred = iterI >= lb
+ // pred = iterI < 0
predicates[i + 1] = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, iterI, lb);
+ loc, arith::CmpIPredicate::slt, iterI, zero);
}
}
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 4a1406faabce1b..048786bad5d447 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -781,12 +781,12 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
// CHECK: %[[ADDI_14:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
// CHECK: %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]]
// CHECK: %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]]
-// CHECK: %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %{{.*}}
+// CHECK: %[[CMPI_17:.*]] = arith.cmpi slt, %[[ADDI_14]], %{{.*}}
// CHECK: %[[ADDI_18:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
// CHECK: %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1
// CHECK: %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
// CHECK: %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
-// CHECK: %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %{{.*}}
+// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ADDI_19]], %{{.*}}
// CHECK: scf.if %[[CMPI_17]] {
// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_21]]]
// CHECK: } else {
@@ -845,7 +845,7 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
// CHECK: %[[ADDI_6:.*]] = arith.addi %[[ADDI_5]], %{{.*}}-1
// CHECK: %[[DIVUI_7:.*]] = arith.divui %[[ADDI_6]], %{{.*}}
// CHECK: %[[ADDI_8:.*]] = arith.addi %[[DIVUI_7]], %{{.*}}-1
-// CHECK: %[[CMPI_9:.*]] = arith.cmpi sge, %[[ADDI_8]], %{{.*}}
+// CHECK: %[[CMPI_9:.*]] = arith.cmpi slt, %[[ADDI_8]], %{{.*}}
// CHECK: %[[IF_10:.*]] = scf.if %[[CMPI_9]]
// CHECK: %[[ADDF_13:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
// CHECK: scf.yield %[[ADDF_13]]
|
// pred = iterI < 0 | ||
predicates[i + 1] = rewriter.create<arith::CmpIOp>( | ||
loc, arith::CmpIPredicate::sge, iterI, lb); | ||
loc, arith::CmpIPredicate::slt, iterI, zero); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shouldn't this be pred = iterI >= 0
? So the predicate is true only when iterI is negative seems odd considering iterI = total_iters - 1 - i
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, that's right.
So for:
scf.for (lb=3..ub=4, step=2) max_stage = 2
This computes:
total_iterations = (4 - 3) + 2 - 1 / 2 = 1
iterI(0) = total_iterations - 1 - 0 = 0
newlastIter(0) = iterI(0) * 2 + 3 = 3
pred(0) = iterI(0) >= 0 = true
iterI(1) = total_iterations - 1 - 1 = -1
newlastIter(1) = iterI(1) * 2 + 3 = 1
pred(1) = iterI(1) >= 0 = false
I'll update now.
can we add a test or modify one to not start from 0 so that this can be tested? |
@ThomasRaoux updated tests. |
d6c60bd
to
e7c5558
Compare
This mirrors upstream patch llvm/llvm-project#108964
The computed loop iteration is zero based, so only check it is less than zero. This fixes the case when lower bound is not zero.
ca879f8
to
93b0115
Compare
@ThomasRaoux does this look adequate? |
This mirrors upstream patch llvm/llvm-project#108964
The computed loop iteration is zero based, so only check it is less than zero.
This fixes the case when lower bound is not zero.