From d9ad36dfc4c47e64e2ca0170f13569c3e904dd79 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 14 Oct 2025 20:01:35 +0000 Subject: [PATCH 1/2] add fix --- .../Vector/Transforms/VectorDistribute.cpp | 25 +++++++++---- .../Vector/vector-warp-distribute.mlir | 35 +++++++++++++++++++ 2 files changed, 53 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index e95338f7d18be..2ee65dc0f902a 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -2038,11 +2038,19 @@ struct WarpOpScfForOp : public WarpDistributionPattern { } // Newly created `WarpOp` will yield values in following order: - // 1. All init args of the `ForOp`. - // 2. All escaping values. - // 3. All non-`ForOp` yielded values. + // 1. Loop bounds. + // 2. All init args of the `ForOp`. + // 3. All escaping values. + // 4. All non-`ForOp` yielded values. SmallVector newWarpOpYieldValues; SmallVector newWarpOpDistTypes; + newWarpOpYieldValues.insert( + newWarpOpYieldValues.end(), + {forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()}); + newWarpOpDistTypes.insert(newWarpOpDistTypes.end(), + {forOp.getLowerBound().getType(), + forOp.getUpperBound().getType(), + forOp.getStep().getType()}); for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) { newWarpOpYieldValues.push_back(initArg); // Compute the distributed type for this init arg. @@ -2081,20 +2089,23 @@ struct WarpOpScfForOp : public WarpDistributionPattern { // Next, we create a new `ForOp` with the init args yielded by the new // `WarpOp`. + const unsigned initArgsStartIdx = 3; // After loop bounds. const unsigned escapingValuesStartIdx = + initArgsStartIdx + forOp.getInitArgs().size(); // `ForOp` init args are positioned before // escaping values in the new `WarpOp`. SmallVector newForOpOperands; - for (size_t i = 0; i < escapingValuesStartIdx; ++i) + for (size_t i = initArgsStartIdx; i < escapingValuesStartIdx; ++i) newForOpOperands.push_back(newWarpOp.getResult(i)); // Create a new `ForOp` outside the new `WarpOp` region. OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(newWarpOp); auto newForOp = scf::ForOp::create( - rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), newForOpOperands, /*bodyBuilder=*/nullptr, - forOp.getUnsignedCmp()); + rewriter, forOp.getLoc(), /**LowerBound=**/ newWarpOp.getResult(0), + /**UpperBound=**/ newWarpOp.getResult(1), + /**Step=**/ newWarpOp.getResult(2), newForOpOperands, + /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp()); // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the // newly created `ForOp`. This `WarpOp` will contain all ops that were // contained within the original `ForOp` body. diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir index bb7639204022f..ab87684dbb01a 100644 --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -473,6 +473,41 @@ func.func @warp_scf_for_use_from_above(%arg0: index) { return } +// ----- +// CHECK-PROP-LABEL: func.func @warp_scf_for_local_loop_bounds +// CHECK-PROP: (%{{.*}}: index, %[[ARG1:[a-zA-Z0-9]+]]: index) { +// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%[[ARG1]] : index) -> (vector<4xf32>) { +// CHECK-PROP: ^bb0(%{{.*}}: index): +// CHECK-PROP: %[[T2:.*]] = "some_def"() : () -> vector<128xf32> +// CHECK-PROP: gpu.yield %[[T2]] : vector<128xf32> +// CHECK-PROP: } +// CHECK-PROP: %[[FOR:.*]] = scf.for %{{.*}} to %[[ARG1]] step %{{.*}} iter_args(%{{.*}}) -> (vector<4xf32>) { +// CHECK-PROP: %[[W2:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] +// CHECK-PROP-SAME: args(%{{.*}} : vector<4xf32>) -> (vector<4xf32>) { +// CHECK-PROP: ^bb0(%{{.*}}: vector<128xf32>): +// CHECK-PROP: gpu.yield %{{.*}} : vector<128xf32> +// CHECK-PROP: } +// CHECK-PROP: scf.yield %[[W2]] : vector<4xf32> +// CHECK-PROP: } +// CHECK-PROP: "some_use"(%[[FOR]]) : (vector<4xf32>) -> () +// CHECK-PROP: return +func.func @warp_scf_for_local_loop_bounds(%arg0: index, %bound: index) { + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = gpu.warp_execute_on_lane_0(%arg0)[32] + args(%bound : index) -> (vector<4xf32>) { + ^bb0(%arg1: index): + %ini = "some_def"() : () -> (vector<128xf32>) + %3 = scf.for %arg3 = %c0 to %arg1 step %c1 iter_args(%arg4 = %ini) -> (vector<128xf32>) { + %acc = "some_def"(%arg4) : (vector<128xf32>) -> (vector<128xf32>) + scf.yield %acc : vector<128xf32> + } + gpu.yield %3 : vector<128xf32> + } + "some_use"(%0) : (vector<4xf32>) -> () + return +} + // ----- // CHECK-PROP-LABEL: func @warp_scf_for_swap( From 7224394205570ee7381e202f1e5239480f25fc65 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 15 Oct 2025 22:39:07 +0000 Subject: [PATCH 2/2] fix --- mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 7752f9e0a263b..7c019e7d25bf2 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -2087,15 +2087,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern { // escaping values in the new `WarpOp`. SmallVector newForOpOperands; for (size_t i = initArgsStartIdx; i < escapingValuesStartIdx; ++i) - newForOpOperands.push_back(newWarpOp.getResult(i)); + newForOpOperands.push_back(newWarpOp.getResult(newIndices[i])); // Create a new `ForOp` outside the new `WarpOp` region. OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(newWarpOp); auto newForOp = scf::ForOp::create( - rewriter, forOp.getLoc(), /**LowerBound=**/ newWarpOp.getResult(0), - /**UpperBound=**/ newWarpOp.getResult(1), - /**Step=**/ newWarpOp.getResult(2), newForOpOperands, + rewriter, forOp.getLoc(), + /**LowerBound=**/ newWarpOp.getResult(newIndices[0]), + /**UpperBound=**/ newWarpOp.getResult(newIndices[1]), + /**Step=**/ newWarpOp.getResult(newIndices[2]), newForOpOperands, /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp()); // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the // newly created `ForOp`. This `WarpOp` will contain all ops that were