From 492439a533a552b2859951d6e331663f1c3244b2 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Mon, 21 Apr 2025 21:26:41 -0400 Subject: [PATCH 1/4] [mlir][linalg] Add folder for `linalg.index` We know that the index of unit dims is always 0. --- .../mlir/Dialect/Linalg/IR/LinalgOps.td | 1 + mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 29 +++++++ mlir/test/Dialect/Linalg/canonicalize.mlir | 80 +++++++++++++++++++ 3 files changed, 110 insertions(+) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index f8df828f74851..1b48bf5fcb237 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -88,6 +88,7 @@ def Linalg_IndexOp : Linalg_Op<"index", [Pure]>, let assemblyFormat = [{ $dim attr-dict `:` type($result) }]; let hasVerifier = 1; + let hasFolder = 1; } def Linalg_SoftmaxOp : Linalg_Op<"softmax", diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 6c680498af2ad..a3787f101afa3 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2283,6 +2283,35 @@ LogicalResult IndexOp::verify() { return success(); } +OpFoldResult IndexOp::fold(FoldAdaptor adaptor) { + auto linalgOp = cast((*this)->getParentOp()); + int64_t flatDimPos = + cast(linalgOp.getShapesToLoopsMap().getResult(getDim())) + .getPosition(); + + // Find the flat dimension position among the operands. + int64_t flatPosOffset = 0; + for (Value operand : linalgOp->getOperands()) { + assert(flatDimPos >= flatPosOffset && "invalid position"); + auto shapedType = dyn_cast(operand.getType()); + if (!shapedType) + break; + + int64_t rank = shapedType.getRank(); + if (flatDimPos < flatPosOffset + rank) { + // Found the dimension within this shape. Now we can either fold if the + // dim size is 1, or bail out otherwise. + int64_t pos = flatDimPos - flatPosOffset; + if (shapedType.getDimSize(pos) != 1) + break; + + return IntegerAttr::get(IndexType::get(getContext()), 0); + } + flatPosOffset += rank; + } + return OpFoldResult{}; +} + /////// Operations corresponding to library calls defined with Tablegen //////// #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc" diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 86cb8f58abe02..3daf221f4402d 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -305,6 +305,86 @@ func.func @self_copy(%arg0 : memref<2x3x?x4xf32>) { } // ----- + +// CHECK: func @fold_linalg_index_tensor_static +func.func @fold_linalg_index_tensor_static(%0: tensor<4x16xi32>, %1: tensor<1x16xi32>, + %2: tensor<4x1xi32>) -> tensor<4x1xi32> { +// CHECK-NEXT: linalg.generic +// CHECK: %[[IDX_0:.+]] = linalg.index 0 : index +// CHECK-NOT: linalg.index 1 +// CHECK: %[[IDX_2:.+]] = linalg.index 2 : index +// CHECK: %[[ADD:.+]] = arith.addi %[[IDX_0]], %[[IDX_2]] +// CHECK: %[[CAST:.+]] = arith.index_cast %[[ADD]] +// CHECK: linalg.yield %[[CAST]] + %res = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%0, %1 : tensor<4x16xi32>, tensor<1x16xi32>) + outs(%2 : tensor<4x1xi32>) { + ^bb0(%lhs: i32, %rhs: i32, %out: i32): + %idx0 = linalg.index 0 : index + %idx1 = linalg.index 1 : index + %idx2 = linalg.index 2 : index + %add0 = arith.addi %idx0, %idx1 : index + %add1 = arith.addi %add0, %idx2 : index + %int = arith.index_cast %add1 : index to i32 + linalg.yield %int : i32 + } -> tensor<4x1xi32> + return %res : tensor<4x1xi32> +} + +// ----- + +// CHECK: func @fold_linalg_index_tensor_dynamic +func.func @fold_linalg_index_tensor_dynamic(%0: tensor, + %1: tensor) -> tensor { +// CHECK-NEXT: linalg.generic +// CHECK: %[[IDX_0:.+]] = linalg.index 0 : index +// CHECK-NOT: linalg.index 1 +// CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX_0]] +// CHECK: linalg.yield %[[CAST]] + %res = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d1, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%0 : tensor) + outs(%1 : tensor) { + ^bb0(%lhs: i32, %out: i32): + %idx0 = linalg.index 0 : index + %idx1 = linalg.index 1 : index + %add = arith.addi %idx0, %idx1 : index + %int = arith.index_cast %add : index to i32 + linalg.yield %int : i32 + } -> tensor + return %res : tensor +} + +// ----- + +// CHECK: func @fold_linalg_index_memref +func.func @fold_linalg_index_memref(%0: memref<1x?xi32>, %1: memref<1x?xi32>) { +// CHECK-NEXT: linalg.generic +// CHECK-NOT: linalg.index 0 +// CHECK: %[[IDX_1:.+]] = linalg.index 1 : index +// CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX_1]] +// CHECK: linalg.yield %[[CAST]] + linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d1, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%0 : memref<1x?xi32>) + outs(%1 : memref<1x?xi32>) { + ^bb0(%lhs: i32, %out: i32): + %idx0 = linalg.index 0 : index + %idx1 = linalg.index 1 : index + %add = arith.addi %idx0, %idx1 : index + %int = arith.index_cast %add : index to i32 + linalg.yield %int : i32 + } + return +} + +// ----- + // CHECK-LABEL: func @fold_fill_reshape() func.func @fold_fill_reshape() -> tensor<6x4xf32> { %zero = arith.constant 0.0 : f32 From b3e5afe9fbf42b095eaf232bc7f3ffa52b4864f1 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Mon, 21 Apr 2025 21:32:52 -0400 Subject: [PATCH 2/4] Update test --- mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir index 375fa37bd84b0..01eafafc8ea29 100644 --- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir +++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir @@ -278,12 +278,11 @@ func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load(% // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32> // CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex> // CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index -// CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<1xindex> // CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex> // CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex> // CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex> -// CHECK: %[[B3:.*]] = vector.broadcast %[[B2]] : vector<1xindex> to vector<8x1xindex> -// CHECK: %[[ADDI:.*]] = arith.addi %[[B3]], %[[T]] : vector<8x1xindex> +// CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<8x1xindex> +// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[T]] : vector<8x1xindex> // CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32> // CHECK: vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32> From cbcdb3346e6296552fec9bbcbc3838764c59e6ec Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Mon, 21 Apr 2025 21:51:05 -0400 Subject: [PATCH 3/4] Simplify --- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 28 +++++------------------- 1 file changed, 6 insertions(+), 22 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index a3787f101afa3..967b7685cd89c 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2285,30 +2285,14 @@ LogicalResult IndexOp::verify() { OpFoldResult IndexOp::fold(FoldAdaptor adaptor) { auto linalgOp = cast((*this)->getParentOp()); - int64_t flatDimPos = - cast(linalgOp.getShapesToLoopsMap().getResult(getDim())) - .getPosition(); - - // Find the flat dimension position among the operands. - int64_t flatPosOffset = 0; - for (Value operand : linalgOp->getOperands()) { - assert(flatDimPos >= flatPosOffset && "invalid position"); - auto shapedType = dyn_cast(operand.getType()); - if (!shapedType) - break; - int64_t rank = shapedType.getRank(); - if (flatDimPos < flatPosOffset + rank) { - // Found the dimension within this shape. Now we can either fold if the - // dim size is 1, or bail out otherwise. - int64_t pos = flatDimPos - flatPosOffset; - if (shapedType.getDimSize(pos) != 1) - break; + // Index of unit dims is always 0. + SmallVector loopBounds = linalgOp.getStaticLoopRanges(); + uint64_t dim = getDim(); + assert(dim < loopBounds.size()); + if (loopBounds[dim] == 1) + return IntegerAttr::get(IndexType::get(getContext()), 0); - return IntegerAttr::get(IndexType::get(getContext()), 0); - } - flatPosOffset += rank; - } return OpFoldResult{}; } From 12ba5a3e8dbc9f80b67601df54da9482dc39ddf1 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Tue, 22 Apr 2025 11:29:06 -0400 Subject: [PATCH 4/4] Add assert message --- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 967b7685cd89c..72fb3308a2549 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2289,7 +2289,7 @@ OpFoldResult IndexOp::fold(FoldAdaptor adaptor) { // Index of unit dims is always 0. SmallVector loopBounds = linalgOp.getStaticLoopRanges(); uint64_t dim = getDim(); - assert(dim < loopBounds.size()); + assert(dim < loopBounds.size() && "Dim is out of bounds"); if (loopBounds[dim] == 1) return IntegerAttr::get(IndexType::get(getContext()), 0);