diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 12e6475fa66e3..7c019e7d25bf2 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -2032,11 +2032,19 @@ struct WarpOpScfForOp : public WarpDistributionPattern { } // Newly created `WarpOp` will yield values in following order: - // 1. All init args of the `ForOp`. - // 2. All escaping values. - // 3. All non-`ForOp` yielded values. + // 1. Loop bounds (lower bound, upper bound, step). + // 2. All init args of the `ForOp`. + // 3. All escaping values. + // 4. All non-`ForOp` yielded values. SmallVector newWarpOpYieldValues; SmallVector newWarpOpDistTypes; + newWarpOpYieldValues.insert( + newWarpOpYieldValues.end(), + {forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()}); + newWarpOpDistTypes.insert(newWarpOpDistTypes.end(), + {forOp.getLowerBound().getType(), + forOp.getUpperBound().getType(), + forOp.getStep().getType()}); for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) { newWarpOpYieldValues.push_back(initArg); // Compute the distributed type for this init arg. @@ -2072,20 +2080,24 @@ struct WarpOpScfForOp : public WarpDistributionPattern { // Next, we create a new `ForOp` with the init args yielded by the new // `WarpOp`. + const unsigned initArgsStartIdx = 3; // After lb, ub and step. const unsigned escapingValuesStartIdx = + initArgsStartIdx + forOp.getInitArgs().size(); // `ForOp` init args are positioned before // escaping values in the new `WarpOp`. SmallVector newForOpOperands; - for (size_t i = 0; i < escapingValuesStartIdx; ++i) + for (size_t i = initArgsStartIdx; i < escapingValuesStartIdx; ++i) newForOpOperands.push_back(newWarpOp.getResult(newIndices[i])); // Create a new `ForOp` outside the new `WarpOp` region. 
OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(newWarpOp); auto newForOp = scf::ForOp::create( - rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), newForOpOperands, /*bodyBuilder=*/nullptr, - forOp.getUnsignedCmp()); + rewriter, forOp.getLoc(), + /*lowerBound=*/ newWarpOp.getResult(newIndices[0]), + /*upperBound=*/ newWarpOp.getResult(newIndices[1]), + /*step=*/ newWarpOp.getResult(newIndices[2]), newForOpOperands, + /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp()); // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the // newly created `ForOp`. This `WarpOp` will contain all ops that were // contained within the original `ForOp` body. diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir index 401cdd29b281c..0cf6dd151e16c 100644 --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -473,6 +473,41 @@ func.func @warp_scf_for_use_from_above(%arg0: index) { return } +// ----- +// CHECK-PROP-LABEL: func.func @warp_scf_for_local_loop_bounds +// CHECK-PROP: (%{{.*}}: index, %[[ARG1:[a-zA-Z0-9]+]]: index) { +// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%[[ARG1]] : index) -> (vector<4xf32>) { +// CHECK-PROP: ^bb0(%{{.*}}: index): +// CHECK-PROP: %[[T2:.*]] = "some_def"() : () -> vector<128xf32> +// CHECK-PROP: gpu.yield %[[T2]] : vector<128xf32> +// CHECK-PROP: } +// CHECK-PROP: %[[FOR:.*]] = scf.for %{{.*}} to %[[ARG1]] step %{{.*}} iter_args(%{{.*}}) -> (vector<4xf32>) { +// CHECK-PROP: %[[W2:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] +// CHECK-PROP-SAME: args(%{{.*}} : vector<4xf32>) -> (vector<4xf32>) { +// CHECK-PROP: ^bb0(%{{.*}}: vector<128xf32>): +// CHECK-PROP: gpu.yield %{{.*}} : vector<128xf32> +// CHECK-PROP: } +// CHECK-PROP: scf.yield %[[W2]] : vector<4xf32> +// CHECK-PROP: } +// CHECK-PROP: "some_use"(%[[FOR]]) : 
(vector<4xf32>) -> () +// CHECK-PROP: return +func.func @warp_scf_for_local_loop_bounds(%arg0: index, %bound: index) { + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = gpu.warp_execute_on_lane_0(%arg0)[32] + args(%bound : index) -> (vector<4xf32>) { + ^bb0(%arg1: index): + %ini = "some_def"() : () -> (vector<128xf32>) + %3 = scf.for %arg3 = %c0 to %arg1 step %c1 iter_args(%arg4 = %ini) -> (vector<128xf32>) { + %acc = "some_def"(%arg4) : (vector<128xf32>) -> (vector<128xf32>) + scf.yield %acc : vector<128xf32> + } + gpu.yield %3 : vector<128xf32> + } + "some_use"(%0) : (vector<4xf32>) -> () + return +} + // ----- // CHECK-PROP-LABEL: func @warp_scf_for_swap(