@sjw36 sjw36 commented Oct 15, 2024

When pipelining an scf.for with dynamic loop bounds, the epilogue ramp-down must align with the prologue when num_stages > total_iterations.

For example:

scf.for (0..ub) {
  load(i)
  add(i)
  store(i)
}

When num_stages=3 the pipeline follows:

load(0)  -  add(0)      -  scf.for (0..ub-2)    -  store(ub-2)
            load(1)     -                       -  add(ub-1)     -  store(ub-1)

The trailing store(ub-2) at i=ub-2 must align with the prologue ramp-up at i=0 when ub < num_stages-1, so the first epilogue index should be max(0, ub-2), and each subsequent index increments it by one. The predicate must handle the same case, so it becomes predicate[0] = total_iterations > epilogue_stage.
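
To make the indexing concrete, here is a minimal sketch that models the new epilogue arithmetic with plain integers rather than MLIR values. It assumes a dynamic loop, so predicates are always produced. `EpilogueStage` and `modelEpilogue` are hypothetical names for illustration only and are not part of MLIR.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical integer model of one epilogue stage (not an MLIR type).
struct EpilogueStage {
  int64_t inductionValue; // lb + step * iterI
  bool enabled;           // total_iters >= stage index
};

// Mirrors the arithmetic described above for a dynamic loop:
//   start_iter = max(0, total_iters - max_stage)
//   stage i (1..max_stage) uses start_iter + (i - 1) and is enabled
//   only when total_iters >= i.
std::vector<EpilogueStage> modelEpilogue(int64_t lb, int64_t step,
                                         int64_t totalIters,
                                         int64_t maxStage) {
  std::vector<EpilogueStage> stages;
  int64_t iterI = std::max<int64_t>(0, totalIters - maxStage);
  for (int64_t i = 1; i <= maxStage; ++i) {
    stages.push_back({lb + step * iterI, totalIters >= i});
    ++iterI; // advance to the next epilogue iteration
  }
  return stages;
}
```

With num_stages=3 (max_stage=2), lb=0 and step=1, a loop of 10 iterations yields indices 8 and 9 with both stages enabled, while a loop of 1 iteration yields indices 0 and 1 with only the first stage enabled.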

llvmbot commented Oct 15, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-scf

Author: SJW (sjw36)

Full diff: https://github.com/llvm/llvm-project/pull/112418.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp (+31-22)
  • (modified) mlir/test/Dialect/SCF/loop-pipelining.mlir (+34-33)
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 83c9cf69ba0364..be75640b44bd9a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -642,22 +642,25 @@ LogicalResult
 LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
                                     llvm::SmallVector<Value> &returnValues) {
   Location loc = forOp.getLoc();
+  Type t = lb.getType();
+
   // Emit different versions of the induction variable. They will be
   // removed by dead code if not used.
 
-  // bounds_range = ub - lb
-  // total_iterations = (bounds_range + step - 1) / step
-  Type t = lb.getType();
-  Value zero =
-      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
-  Value one =
-      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 1));
-  Value minusOne =
-      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
+  auto getConst = [&](int v) {
+    return rewriter.create<arith::ConstantOp>(loc,
+                                              rewriter.getIntegerAttr(t, v));
+  };
+
+  // total_iterations = cdiv(range_diff, step);
+  // - range_diff = ub - lb
+  // - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
+  Value zero = getConst(0);
+  Value one = getConst(1);
   Value stepLessZero = rewriter.create<arith::CmpIOp>(
       loc, arith::CmpIPredicate::slt, step, zero);
   Value stepDecr =
-      rewriter.create<arith::SelectOp>(loc, stepLessZero, one, minusOne);
+      rewriter.create<arith::SelectOp>(loc, stepLessZero, one, getConst(-1));
 
   Value rangeDiff = rewriter.create<arith::SubIOp>(loc, ub, lb);
   Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step);
@@ -665,25 +668,31 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
       rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr);
   Value totalIterations = rewriter.create<arith::DivSIOp>(loc, rangeDecr, step);
 
+  // If total_iters < max_stage, start the epilogue at zero to match the
+  // ramp-up in the prologue.
+  // start_iter = max(0, total_iters - max_stage)
+  Value iterI =
+      rewriter.create<arith::SubIOp>(loc, totalIterations, getConst(maxStage));
+  iterI = rewriter.create<arith::MaxSIOp>(loc, zero, iterI);
+
+  // Capture predicates for dynamic loops.
   SmallVector<Value> predicates(maxStage + 1);
-  for (int64_t i = 0; i < maxStage; i++) {
-    // iterI = total_iters - 1 - i
-    // May go negative...
-    Value minusI =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
-    Value iterI = rewriter.create<arith::AddIOp>(
-        loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minusOne),
-        minusI);
+
+  for (int64_t i = 1; i <= maxStage; i++) {
     // newLastIter = lb + step * iterI
     Value newlastIter = rewriter.create<arith::AddIOp>(
         loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
 
-    setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
+    setValueMapping(forOp.getInductionVar(), newlastIter, i);
+
+    // increment to next iterI
+    iterI = rewriter.create<arith::AddIOp>(loc, iterI, one);
 
     if (dynamicLoop) {
-      // pred = iterI >= 0
-      predicates[i + 1] = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::sge, iterI, zero);
+      // Disable stages when `i` is greater than total_iters.
+      // pred = total_iters >= i
+      predicates[i] = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::sge, totalIterations, getConst(i));
     }
   }
 
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index af49d2afc049ba..c879c83275bf86 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -767,6 +767,7 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
 // Check for predicated epilogue for dynamic loop.
 // CHECK-LABEL: dynamic_loop(
 //    CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//    CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
 //    CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //    CHECK-DAG:   %[[CM1:.*]] = arith.constant -1 : index
 //        CHECK:   %[[UBM:.*]] = arith.subi %[[UB:.*]], %{{.*}}
@@ -779,32 +780,32 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
 //        CHECK:       scf.yield %[[ADDF_24]], %[[LOAD_27]]
 //        CHECK:   }
 //        CHECK:   %[[CMPI_10:.*]] = arith.cmpi slt, %[[STEP]], %[[C0]]
-//        CHECK:   %[[SEL_10:.*]] = arith.select %[[CMPI_10]], %[[C1]], %[[CM1]]
-//        CHECK:   %[[SUBI_10:.*]] = arith.subi %[[UB]], %[[LB]]
-//        CHECK:   %[[ADDI_11:.*]] = arith.addi %[[SUBI_10]], %[[STEP]]
-//        CHECK:   %[[ADDI_12:.*]] = arith.addi %[[ADDI_11]], %[[SEL_10]]
-//        CHECK:   %[[DIVSI_13:.*]] = arith.divsi %[[ADDI_12]], %[[STEP]]
-//        CHECK:   %[[ADDI_14:.*]] = arith.addi %[[DIVSI_13]], %[[CM1]]
-//        CHECK:   %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]]
-//        CHECK:   %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]]
-//        CHECK:   %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %[[C0]]
-//        CHECK:   %[[ADDI_18:.*]] = arith.addi %[[DIVSI_13]], %{{.*}}-1
-//        CHECK:   %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1
-//        CHECK:   %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
-//        CHECK:   %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
-//        CHECK:   %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %[[C0]]
-//        CHECK:   scf.if %[[CMPI_17]] {
-//        CHECK:     memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_21]]]
+//        CHECK:   %[[SELECT_11:.*]] = arith.select %[[CMPI_10]], %[[C1]], %[[CM1]]
+//        CHECK:   %[[SUBI_12:.*]] = arith.subi %[[UB]], %[[LB]]
+//        CHECK:   %[[ADDI_13:.*]] = arith.addi %[[SUBI_12]], %[[STEP]]
+//        CHECK:   %[[ADDI_14:.*]] = arith.addi %[[ADDI_13]], %[[SELECT_11]]
+//        CHECK:   %[[DIVSI_15:.*]] = arith.divsi %[[ADDI_14]], %[[STEP]]
+//        CHECK:   %[[SUBI_17:.*]] = arith.subi %[[DIVSI_15]], %[[C2]]
+//        CHECK:   %[[MAXSI_18:.*]] = arith.maxsi %[[SUBI_17]], %[[C0]]
+//        CHECK:   %[[MULI_19:.*]] = arith.muli %[[STEP]], %[[MAXSI_18]]
+//        CHECK:   %[[ADDI_20:.*]] = arith.addi %[[LB]], %[[MULI_19]]
+//        CHECK:   %[[ADDI_21:.*]] = arith.addi %[[MAXSI_18]], %[[C1]]
+//        CHECK:   %[[CMPI_22:.*]] = arith.cmpi sge, %[[DIVSI_15]], %[[C1]]
+//        CHECK:   %[[MULI_23:.*]] = arith.muli %[[STEP]], %[[ADDI_21]]
+//        CHECK:   %[[ADDI_24:.*]] = arith.addi %[[LB]], %[[MULI_23]]
+//        CHECK:   %[[CMPI_25:.*]] = arith.cmpi sge, %[[DIVSI_15]], %[[C2]]
+//        CHECK:   scf.if %[[CMPI_22]] {
+//        CHECK:     memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_20]]]
 //        CHECK:   } else {
 //        CHECK:   }
-//        CHECK:   %[[IF_23:.*]] = scf.if %[[CMPI_22]] -> (f32) {
-//        CHECK:     %[[ADDF_24:.*]] = arith.addf %{{.*}}#1, %{{.*}}
-//        CHECK:     scf.yield %[[ADDF_24]]
+//        CHECK:   %[[IF_26:.*]] = scf.if %[[CMPI_25]]
+//        CHECK:     %[[ADDF_27:.*]] = arith.addf %{{.*}}#1, %{{.*}}
+//        CHECK:     scf.yield %[[ADDF_27]]
 //        CHECK:   } else {
 //        CHECK:     scf.yield %{{.*}}
 //        CHECK:   }
-//        CHECK:   scf.if %[[CMPI_22]] {
-//        CHECK:     memref.store %[[IF_23]], %{{.*}}[%[[ADDI_16]]]
+//        CHECK:   scf.if %[[CMPI_25]] {
+//        CHECK:     memref.store %[[IF_26]], %{{.*}}[%[[ADDI_24]]]
 //        CHECK:   } else {
 //        CHECK:   }
 //        CHECK:   return
@@ -842,6 +843,7 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
 //   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
 //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG:   %[[CM1:.*]] = arith.constant -1 : index
+//   CHECK-DAG:   %[[CF0:.*]] = arith.constant 0.000000e+00
 //       CHECK:   %[[UBM:.*]] = arith.subi %[[UB:.*]], %{{.*}}
 //       CHECK:   %{{.*}}:2 = scf.for %[[ARG5:.*]] = %[[LB:.*]] to %[[UBM]] step %[[STEP:.*]] iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
 //       CHECK:       %[[ADDF_13:.*]] = arith.addf %[[ARG7]], %[[ARG6]]
@@ -856,22 +858,21 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
 //       CHECK:     %[[ADDI_7:.*]] = arith.addi %[[SUBI_6]], %[[STEP]]
 //       CHECK:     %[[ADDI_8:.*]] = arith.addi %[[ADDI_7]], %[[SELECT_5]]
 //       CHECK:     %[[DIVSI_9:.*]] = arith.divsi %[[ADDI_8]], %[[STEP]]
-//       CHECK:     %[[ADDI_10:.*]] = arith.addi %[[DIVSI_9]], %[[CM1]]
-//       CHECK:     %[[CMPI_11:.*]] = arith.cmpi sge, %[[ADDI_10]], %[[C0]]
-//       CHECK:     %[[IF_10:.*]] = scf.if %[[CMPI_11]]
-//       CHECK:       %[[ADDF_13:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
-//       CHECK:       scf.yield %[[ADDF_13]]
+//       CHECK:     %[[CMPI_10:.*]] = arith.cmpi sge, %[[DIVSI_9]], %[[C1]]
+//       CHECK:     %[[IF_11:.*]] = scf.if %[[CMPI_10]]
+//       CHECK:       %[[ADDF_14:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
+//       CHECK:       scf.yield %[[ADDF_14]]
 //       CHECK:     } else {
-//       CHECK:       scf.yield %{{.*}}
+//       CHECK:       scf.yield %[[CF0]]
 //       CHECK:     }
-//       CHECK:     %[[IF_11:.*]] = scf.if %[[CMPI_11]]
-//       CHECK:       %[[MULF_13:.*]] = arith.mulf %[[IF_10]], %{{.*}}
-//       CHECK:       scf.yield %[[MULF_13]]
+//       CHECK:     %[[IF_12:.*]] = scf.if %[[CMPI_10]]
+//       CHECK:       %[[MULF_14:.*]] = arith.mulf %[[IF_11]], %{{.*}}
+//       CHECK:       scf.yield %[[MULF_14]]
 //       CHECK:     } else {
-//       CHECK:       scf.yield %{{.*}}
+//       CHECK:       scf.yield %[[CF0]]
 //       CHECK:     }
-//       CHECK:     %[[SELECT_12:.*]] = arith.select %[[CMPI_11]], %[[IF_11]], %{{.*}}#0
-//       CHECK:     memref.store %[[SELECT_12]], %{{.*}}[%{{.*}}]
+//       CHECK:     %[[SELECT_13:.*]] = arith.select %[[CMPI_10]], %[[IF_12]], %{{.*}}#0
+//       CHECK:     memref.store %[[SELECT_13]], %{{.*}}[%[[C0]]]
 func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
   %cf0 = arith.constant 1.0 : f32
   %cf1 = arith.constant 33.0 : f32
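
The total_iterations comment in emitEpilogue can be checked with ordinary integers. The sketch below is an illustration under the assumption that C++ signed division, which truncates toward zero, matches arith.divsi for these inputs. `totalIterations` is a hypothetical helper name, not a function in the patch.

```cpp
#include <cstdint>

// Hypothetical integer model of the trip-count computation:
//   range_diff       = ub - lb
//   total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
// C++ '/' on signed integers truncates toward zero.
int64_t totalIterations(int64_t lb, int64_t ub, int64_t step) {
  int64_t rangeDiff = ub - lb;
  int64_t stepDecr = step < 0 ? 1 : -1; // adjust toward zero by one
  return (rangeDiff + step + stepDecr) / step;
}
```

For example, lb=0, ub=10, step=3 gives 4 iterations (0, 3, 6, 9), and lb=10, ub=0, step=-3 also gives 4 (10, 7, 4, 1).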


@ThomasRaoux ThomasRaoux left a comment

LGTM, thanks

rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 1));
Value minusOne =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
auto getConst = [&](int v) {

nit: maybe rename to createConst?

@antiagainst antiagainst merged commit 8da5aa1 into llvm:main Oct 15, 2024
6 of 7 checks passed
antiagainst pushed a commit to triton-lang/triton that referenced this pull request Oct 15, 2024
- The epilogue ramp-down indexing must start at zero or greater
  (total_iterations - max_stage) to ensure alignment with the prologue
  ramp-up stages.
- If total_iterations < max_stage, the trailing stages will be masked.

This commit mirrors upstream llvm/llvm-project#112418
and adds a functional test for correctness with num_stages=1,2,3,4.
DanielCChen pushed a commit to DanielCChen/llvm-project that referenced this pull request Oct 16, 2024
…iters (llvm#112418)

Luosuu pushed a commit to Luosuu/triton that referenced this pull request Nov 13, 2024
guacamoleo pushed a commit to guacamoleo/triton that referenced this pull request Nov 14, 2024
bertmaher pushed a commit to bertmaher/triton that referenced this pull request Dec 10, 2024
liuyunqi20 pushed a commit to flagos-ai/flagtree that referenced this pull request Oct 21, 2025