Skip to content

Conversation

sjw36
Copy link
Contributor

@sjw36 sjw36 commented Sep 17, 2024

The computed loop iteration is zero based, so only check it is less than zero.
This fixes the case when lower bound is not zero.

@llvmbot
Copy link
Member

llvmbot commented Sep 17, 2024

@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir

Author: SJW (sjw36)

Changes

The computed loop iteration is zero based, so only check it is less than zero.
This fixes the case when lower bound is not zero.


Full diff: https://github.com/llvm/llvm-project/pull/108964.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp (+5-2)
  • (modified) mlir/test/Dialect/SCF/loop-pipelining.mlir (+3-3)
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 7cecd4942b640f..ad6f790a5ba02c 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -655,6 +655,9 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
   Value rangeDecr = rewriter.create<arith::AddIOp>(loc, rangeIncr, minus1);
   Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);
 
+  Value zero =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
+
   SmallVector<Value> predicates(maxStage + 1);
   for (int64_t i = 0; i < maxStage; i++) {
     // iterI = total_iters - 1 - i
@@ -671,9 +674,9 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
     setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
 
     if (dynamicLoop) {
-      // pred = iterI >= lb
+      // pred = iterI < 0
       predicates[i + 1] = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::sge, iterI, lb);
+          loc, arith::CmpIPredicate::slt, iterI, zero);
     }
   }
 
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 4a1406faabce1b..048786bad5d447 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -781,12 +781,12 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
 //        CHECK:   %[[ADDI_14:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
 //        CHECK:   %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]]
 //        CHECK:   %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]]
-//        CHECK:   %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %{{.*}}
+//        CHECK:   %[[CMPI_17:.*]] = arith.cmpi slt, %[[ADDI_14]], %{{.*}}
 //        CHECK:   %[[ADDI_18:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
 //        CHECK:   %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1
 //        CHECK:   %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
 //        CHECK:   %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
-//        CHECK:   %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %{{.*}}
+//        CHECK:   %[[CMPI_22:.*]] = arith.cmpi slt, %[[ADDI_19]], %{{.*}}
 //        CHECK:   scf.if %[[CMPI_17]] {
 //        CHECK:     memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_21]]]
 //        CHECK:   } else {
@@ -845,7 +845,7 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
 //       CHECK:     %[[ADDI_6:.*]] = arith.addi %[[ADDI_5]], %{{.*}}-1
 //       CHECK:     %[[DIVUI_7:.*]] = arith.divui %[[ADDI_6]], %{{.*}}
 //       CHECK:     %[[ADDI_8:.*]] = arith.addi %[[DIVUI_7]], %{{.*}}-1
-//       CHECK:     %[[CMPI_9:.*]] = arith.cmpi sge, %[[ADDI_8]], %{{.*}}
+//       CHECK:     %[[CMPI_9:.*]] = arith.cmpi slt, %[[ADDI_8]], %{{.*}}
 //       CHECK:     %[[IF_10:.*]] = scf.if %[[CMPI_9]]
 //       CHECK:       %[[ADDF_13:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
 //       CHECK:       scf.yield %[[ADDF_13]]

Comment on lines 677 to 679
// pred = iterI < 0
predicates[i + 1] = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, iterI, lb);
loc, arith::CmpIPredicate::slt, iterI, zero);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be pred = iterI >= 0? So the predicate is true only when iterI is negative seems odd considering iterI = total_iters - 1 - i?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, that's right.
So for:

scf.for (lb=3..ub=4, step=2) max_stage = 2                                                        

This computes:

  total_iterations = (4 - 3) + 2 - 1 / 2 = 1                                                    
  iterI(0) = total_iterations - 1 - 0 = 0                                                       
  newlastIter(0) = iterI(0) * 2 + 3 = 3                                                         
  pred(0) = iterI(0) >= 0 = true                                                                
  iterI(1) = total_iterations - 1 - 1 = -1                                                      
  newlastIter(1) = iterI(1) * 2 + 3 = 1                                                         
  pred(1) = iterI(1) >= 0 = false

I'll update now.

@ThomasRaoux
Copy link
Contributor

can we add a test or modify one to not start from 0 so that this can be tested?

@sjw36
Copy link
Contributor Author

sjw36 commented Sep 18, 2024

@ThomasRaoux updated tests.

@sjw36 sjw36 force-pushed the mlir-scf-pipeline-bugfix branch from d6c60bd to e7c5558 Compare September 18, 2024 18:03
antiagainst pushed a commit to triton-lang/triton that referenced this pull request Sep 18, 2024
The computed loop iteration is zero based, so only check it is less than zero.
This fixes the case when lower bound is not zero.
@sjw36 sjw36 force-pushed the mlir-scf-pipeline-bugfix branch from ca879f8 to 93b0115 Compare September 19, 2024 20:18
@sjw36
Copy link
Contributor Author

sjw36 commented Sep 20, 2024

@ThomasRaoux does this look adequate?

@antiagainst antiagainst merged commit fa089b0 into llvm:main Sep 24, 2024
8 checks passed
bertmaher pushed a commit to bertmaher/triton that referenced this pull request Dec 10, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants