diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h index 347beb9e4c64f..cad5173599453 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h @@ -128,6 +128,13 @@ struct PipeliningOption { /// lambda to generate the predicated version of operations. bool peelEpilogue = true; + /// Control whether the transformation checks that the number of iterations is + /// greater or equal to the number of stages and skip the transformation if + /// this is not the case. If the loop is dynamic and this is set to true and + /// the loop bounds are not static the pipeliner will have to predicate + /// operations in the the prologue/epilogue. + bool supportDynamicLoops = false; + // Callback to predicate operations when the prologue or epilogue are not // peeled. This takes the original operation, an i1 predicate value and the // pattern rewriter. It is expected to replace the given operation with diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp index 20fa8089201aa..6c36600975a59 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp @@ -44,9 +44,10 @@ struct LoopPipelinerInternal { unsigned maxStage = 0; DenseMap stages; std::vector opOrder; - int64_t ub; - int64_t lb; - int64_t step; + Value ub; + Value lb; + Value step; + bool dynamicLoop; PipeliningOption::AnnotationlFnType annotateFn = nullptr; bool peelEpilogue; PipeliningOption::PredicateOpFn predicateFn = nullptr; @@ -96,25 +97,41 @@ bool LoopPipelinerInternal::initializeLoopInfo( ForOp op, const PipeliningOption &options) { LDBG("Start initializeLoopInfo"); forOp = op; - auto upperBoundCst = - forOp.getUpperBound().getDefiningOp(); - auto lowerBoundCst = - forOp.getLowerBound().getDefiningOp(); - auto stepCst = forOp.getStep().getDefiningOp(); + ub = forOp.getUpperBound(); + lb = forOp.getLowerBound(); + step = forOp.getStep(); + + dynamicLoop = true; + auto upperBoundCst = getConstantIntValue(ub); + auto lowerBoundCst = getConstantIntValue(lb); + auto stepCst = getConstantIntValue(step); if (!upperBoundCst || !lowerBoundCst || !stepCst) { - LDBG("--no constant bounds or step -> BAIL"); - return false; + if (!options.supportDynamicLoops) { + LDBG("--dynamic loop not supported -> BAIL"); + return false; + } + } else { + int64_t ubImm = upperBoundCst.value(); + int64_t lbImm = lowerBoundCst.value(); + int64_t stepImm = stepCst.value(); + int64_t numIteration = ceilDiv(ubImm - lbImm, stepImm); + if (numIteration > maxStage) { + dynamicLoop = false; + } else if (!options.supportDynamicLoops) { + LDBG("--fewer loop iterations than pipeline stages -> BAIL"); + return false; + } } - ub = upperBoundCst.value(); - lb = lowerBoundCst.value(); - step = stepCst.value(); peelEpilogue = options.peelEpilogue; predicateFn = options.predicateFn; - if (!peelEpilogue && predicateFn == nullptr) { + if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) { LDBG("--no epilogue or predicate set -> BAIL"); return false; } - int64_t numIteration = ceilDiv(ub - lb, step); + if (dynamicLoop && peelEpilogue) { + LDBG("--dynamic loop doesn't support epilogue yet -> BAIL"); + return false; + } std::vector> schedule; options.getScheduleFn(forOp, schedule); if (schedule.empty()) { @@ -128,10 +145,6 @@ bool LoopPipelinerInternal::initializeLoopInfo( stages[opSchedule.first] = opSchedule.second; opOrder.push_back(opSchedule.first); } - if (numIteration <= maxStage) { - LDBG("--fewer loop iterations than pipeline stages -> BAIL"); - return false; - } // All operations need to have a stage. for (Operation &op : forOp.getBody()->without_terminator()) { @@ -204,10 +217,31 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { setValueMapping(arg, operand.get(), 0); } auto yield = cast(forOp.getBody()->getTerminator()); + Location loc = forOp.getLoc(); + SmallVector predicates(maxStage); for (int64_t i = 0; i < maxStage; i++) { + if (dynamicLoop) { + Type t = ub.getType(); + // pred = ub > lb + (i * step) + Value iv = rewriter.create( + loc, lb, + rewriter.create( + loc, step, + rewriter.create( + loc, rewriter.getIntegerAttr(t, i)))); + predicates[i] = rewriter.create( + loc, arith::CmpIPredicate::slt, iv, ub); + } + // special handling for induction variable as the increment is implicit. - Value iv = - rewriter.create(forOp.getLoc(), lb + i * step); + // iv = lb + i * step + Type t = lb.getType(); + Value iv = rewriter.create( + loc, lb, + rewriter.create( + loc, step, + rewriter.create(loc, + rewriter.getIntegerAttr(t, i)))); setValueMapping(forOp.getInductionVar(), iv, i); for (Operation *op : opOrder) { if (stages[op] > i) @@ -220,6 +254,12 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { newOperand->set(replacement); } }); + int predicateIdx = i - stages[op]; + if (predicates[predicateIdx]) { + newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]); + assert(newOp && "failed to predicate op."); + } + rewriter.setInsertionPointAfter(newOp); if (annotateFn) annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i); for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { @@ -326,9 +366,16 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop( // `numStages - 1` iterations. Then we adjust the upper bound to remove those // iterations. Value newUb = forOp.getUpperBound(); - if (peelEpilogue) - newUb = rewriter.create(forOp.getLoc(), - ub - maxStage * step); + if (peelEpilogue) { + Type t = ub.getType(); + Location loc = forOp.getLoc(); + // newUb = ub - maxStage * step + Value maxStageValue = rewriter.create( + loc, rewriter.getIntegerAttr(t, maxStage)); + Value maxStageByStep = + rewriter.create(loc, step, maxStageValue); + newUb = rewriter.create(loc, ub, maxStageByStep); + } auto newForOp = rewriter.create(forOp.getLoc(), forOp.getLowerBound(), newUb, forOp.getStep(), newLoopArg); @@ -358,9 +405,17 @@ LogicalResult LoopPipelinerInternal::createKernel( SmallVector predicates(maxStage + 1, nullptr); if (!peelEpilogue) { // Create a predicate for each stage except the last stage. + Location loc = newForOp.getLoc(); + Type t = ub.getType(); for (unsigned i = 0; i < maxStage; i++) { - Value c = rewriter.create( - newForOp.getLoc(), ub - (maxStage - i) * step); + // c = ub - (maxStage - i) * step + Value c = rewriter.create( + loc, ub, + rewriter.create( + loc, step, + rewriter.create( + loc, rewriter.getIntegerAttr(t, int64_t(maxStage - i))))); + Value pred = rewriter.create( newForOp.getLoc(), arith::CmpIPredicate::slt, newForOp.getInductionVar(), c); @@ -383,8 +438,14 @@ LogicalResult LoopPipelinerInternal::createKernel( // version incremented based on the stage where it is used. if (operand->get() == forOp.getInductionVar()) { rewriter.setInsertionPoint(newOp); - Value offset = rewriter.create( - forOp.getLoc(), (maxStage - stages[op]) * step); + + // offset = (maxStage - stages[op]) * step + Type t = step.getType(); + Value offset = rewriter.create( + forOp.getLoc(), step, + rewriter.create( + forOp.getLoc(), + rewriter.getIntegerAttr(t, maxStage - stages[op]))); Value iv = rewriter.create( forOp.getLoc(), newForOp.getInductionVar(), offset); nestedNewOp->setOperand(operand->getOperandNumber(), iv); @@ -508,8 +569,24 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter) { // Emit different versions of the induction variable. They will be // removed by dead code if not used. for (int64_t i = 0; i < maxStage; i++) { - Value newlastIter = rewriter.create( - forOp.getLoc(), lb + step * ((((ub - 1) - lb) / step) - i)); + Location loc = forOp.getLoc(); + Type t = lb.getType(); + Value minusOne = + rewriter.create(loc, rewriter.getIntegerAttr(t, -1)); + // number of iterations = ((ub - 1) - lb) / step + Value totalNumIteration = rewriter.create( + loc, + rewriter.create( + loc, rewriter.create(loc, ub, minusOne), lb), + step); + // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i) + Value minusI = + rewriter.create(loc, rewriter.getIntegerAttr(t, -i)); + Value newlastIter = rewriter.create( + loc, lb, + rewriter.create( + loc, step, + rewriter.create(loc, totalNumIteration, minusI))); setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i); } // Emit `maxStage - 1` epilogue part that includes operations from stages diff --git a/mlir/test/Dialect/NVGPU/transform-pipeline-shared.mlir b/mlir/test/Dialect/NVGPU/transform-pipeline-shared.mlir index 42b072374261e..e959949babd9e 100644 --- a/mlir/test/Dialect/NVGPU/transform-pipeline-shared.mlir +++ b/mlir/test/Dialect/NVGPU/transform-pipeline-shared.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics | FileCheck %s +// RUN: mlir-opt %s --transform-interpreter -canonicalize --split-input-file --verify-diagnostics | FileCheck %s func.func @simple_depth_2_unpeeled(%global: memref, %result: memref ) { %c0 = arith.constant 0 : index @@ -78,15 +78,19 @@ module attributes {transform.with_named_sequence} { // CHECK-LABEL: @async_depth_2_predicated // CHECK-SAME: %[[GLOBAL:.+]]: memref -func.func @async_depth_2_predicated(%global: memref) { +func.func @async_depth_2_predicated(%global: memref, %alloc_size: index) { %c0 = arith.constant 0 : index %c98 = arith.constant 98 : index %c100 = arith.constant 100 : index - %c200 = arith.constant 200 : index - // CHECK: %[[C4:.+]] = arith.constant 4 + // CHECK-DAG: %[[C4:.+]] = arith.constant 4 + // CHECK-DAG: %[[C90:.+]] = arith.constant 90 + // CHECK-DAG: %[[C96:.+]] = arith.constant 96 + // CHECK-DAG: %[[C8:.+]] = arith.constant 8 + // CHECK-DAG: %[[C2:.+]] = arith.constant 2 + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 %c4 = arith.constant 4 : index // CHECK: %[[SHARED:.+]] = memref.alloc{{.*}} #gpu.address_space - %shared = memref.alloc(%c200) : memref> + %shared = memref.alloc(%alloc_size) : memref> %c0f = arith.constant 0.0 : f32 // CHECK: %[[TOKEN0:.+]] = nvgpu.device_async_copy // CHECK: %[[TOKEN1:.+]] = nvgpu.device_async_copy @@ -95,16 +99,11 @@ func.func @async_depth_2_predicated(%global: memref) { // CHECK-SAME: %[[ITER_ARG1:.+]] = %[[TOKEN1]] scf.for %i = %c0 to %c98 step %c4 { // Condition for the predication "select" below. - // CHECK: %[[C90:.+]] = arith.constant 90 // CHECK: %[[CMP0:.+]] = arith.cmpi slt, %[[I]], %[[C90]] // CHECK: nvgpu.device_async_wait %[[ITER_ARG0]] {numGroups = 1 - // Original "select" with updated induction variable. - // CHECK: %[[C96:.+]] = arith.constant 96 - // CHECK: %[[C8:.+]] = arith.constant 8 // CHECK: %[[I_PLUS_8:.+]] = arith.addi %[[I]], %[[C8]] // CHECK: %[[CMP1:.+]] = arith.cmpi slt, %[[I_PLUS_8]], %[[C96]] - // CHECK: %[[C2:.+]] = arith.constant 2 // CHECK: %[[SELECTED0:.+]] = arith.select %[[CMP1]], %[[C4]], %[[C2]] %c96 = arith.constant 96 : index %cond = arith.cmpi slt, %i, %c96 : index @@ -113,14 +112,11 @@ func.func @async_depth_2_predicated(%global: memref) { // Updated induction variables (two more) for the device_async_copy below. // These are generated repeatedly by the pipeliner. - // CHECK: %[[C8_2:.+]] = arith.constant 8 - // CHECK: %[[I_PLUS_8_2:.+]] = arith.addi %[[I]], %[[C8_2]] - // CHECK: %[[C8_3:.+]] = arith.constant 8 - // CHECK: %[[I_PLUS_8_3:.+]] = arith.addi %[[I]], %[[C8_3]] + // CHECK: %[[I_PLUS_8_2:.+]] = arith.addi %[[I]], %[[C8]] + // CHECK: %[[I_PLUS_8_3:.+]] = arith.addi %[[I]], %[[C8]] // The second "select" is generated by predication and selects 0 for // the two last iterations. - // CHECK: %[[C0:.+]] = arith.constant 0 // CHECK: %[[SELECTED1:.+]] = arith.select %[[CMP0]], %[[SELECTED0]], %[[C0]] // CHECK: %[[ASYNC_TOKEN:.+]] = nvgpu.device_async_copy %[[GLOBAL]][%[[I_PLUS_8_3]]], %[[SHARED]][%[[I_PLUS_8_2]]], 4, %[[SELECTED1]] %token = nvgpu.device_async_copy %global[%i], %shared[%i], 4, %read_size diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir index 4cd686d2cdb86..8a57ddccfee66 100644 --- a/mlir/test/Dialect/SCF/loop-pipelining.mlir +++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir @@ -723,3 +723,50 @@ func.func @stage_0_value_escape(%A: memref, %result: memref) { memref.store %r, %result[%c1] : memref return } + +// ----- + +// NOEPILOGUE-LABEL: dynamic_loop( +// NOEPILOGUE-SAME: %[[A:.*]]: memref, %[[R:.*]]: memref, %[[LB:.+]]: index, %[[UB:.+]]: index, %[[STEP:.+]]: index) { +// NOEPILOGUE-DAG: %[[C2:.+]] = arith.constant 2 : index +// NOEPILOGUE-DAG: %[[CSTF:.+]] = arith.constant 1.000000e+00 : f32 +// Prologue: +// NOEPILOGUE: %[[P_I0:.+]] = arith.cmpi slt, %[[LB]], %[[UB]] : index +// NOEPILOGUE: %[[L0:.+]] = scf.if %[[P_I0]] -> (f32) { +// NOEPILOGUE-NEXT: memref.load %[[A]][%[[LB]]] : memref +// NOEPILOGUE: %[[IV1:.+]] = arith.addi %[[LB]], %[[STEP]] : index +// NOEPILOGUE: %[[P_I1:.+]] = arith.cmpi slt, %[[IV1]], %[[UB]] : index +// NOEPILOGUE: %[[IV1_2:.+]] = arith.addi %[[LB]], %[[STEP]] : index +// NOEPILOGUE: %[[V0:.+]] = scf.if %[[P_I0]] -> (f32) { +// NOEPILOGUE-NEXT: arith.addf %[[L0]], %[[CSTF]] : f32 +// NOEPILOGUE: %[[L1:.+]] = scf.if %[[P_I1]] -> (f32) { +// NOEPILOGUE-NEXT: memref.load %[[A]][%[[IV1_2]]] : memref +// NOEPILOGUE: scf.for %[[IV2:.+]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[V1:.+]] = %[[V0]], %[[L2:.+]] = %[[L1]]) -> (f32, f32) { +// NOEPILOGUE-DAG: %[[S2:.+]] = arith.muli %[[STEP]], %[[C2]] : index +// NOEPILOGUE-DAG: %[[IT2:.+]] = arith.subi %[[UB]], %[[S2]] : index +// NOEPILOGUE-DAG: %[[P_I2:.+]] = arith.cmpi slt, %[[IV2]], %[[IT2]] : index +// NOEPILOGUE-DAG: %[[IT3:.+]] = arith.subi %[[UB]], %[[STEP]] : index +// NOEPILOGUE-DAG: %[[P_I3:.+]] = arith.cmpi slt, %[[IV2]], %[[IT3]] : index +// NOEPILOGUE: memref.store %[[V1]], %[[R]][%[[IV2]]] : memref +// NOEPILOGUE: %[[V2:.+]] = scf.if %[[P_I3]] -> (f32) { +// NOEPILOGUE: arith.addf %[[L2]], %[[CSTF]] : f32 +// NOEPILOGUE: %[[IT4:.+]] = arith.muli %[[STEP]], %[[C2]] : index +// NOEPILOGUE: %[[IV3:.+]] = arith.addi %[[IV2]], %[[IT4]] : index +// NOEPILOGUE: %[[L3:.+]] = scf.if %[[P_I2]] -> (f32) { +// NOEPILOGUE: memref.load %[[A]][%[[IV3]]] : memref +// NOEPILOGUE: scf.yield %[[V2]], %[[L3]] : f32, f32 + +// In case dynamic loop pipelining is off check that the transformation didn't +// apply. +// CHECK-LABEL: dynamic_loop( +// CHECK-NOT: memref.load +// CHECK: scf.for +func.func @dynamic_loop(%A: memref, %result: memref, %lb: index, %ub: index, %step: index) { + %cf = arith.constant 1.0 : f32 + scf.for %i0 = %lb to %ub step %step { + %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref + %A1_elem = arith.addf %A_elem, %cf { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32 + memref.store %A1_elem, %result[%i0] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 0 } : memref + } { __test_pipelining_loop__ } + return +} diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp index 565d07669792f..a8a808424b690 100644 --- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp +++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp @@ -217,6 +217,7 @@ struct TestSCFPipeliningPass if (annotatePipeline) options.annotateFn = annotate; if (noEpiloguePeeling) { + options.supportDynamicLoops = true; options.peelEpilogue = false; options.predicateFn = predicateOp; }