diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index b4a5461f4405d..94f9ead9e1665 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1989,6 +1989,45 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) { return fromElementsOp.getElements()[flatIndex]; } +/// If the dynamic indices of `extractOp` or `insertOp` are in fact constants, +/// then fold it. +template <typename OpType, typename AdaptorType> +static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor, + SmallVectorImpl<Value> &operands) { + std::vector<int64_t> staticPosition = op.getStaticPosition().vec(); + OperandRange dynamicPosition = op.getDynamicPosition(); + ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition(); + + // If the dynamic operands is empty, it is returned directly. + if (!dynamicPosition.size()) + return {}; + + // `index` is used to iterate over the `dynamicPosition`. + unsigned index = 0; + + // `opChange` is a flag. If it is true, it means to update `op` in place. + bool opChange = false; + for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) { + if (!ShapedType::isDynamic(staticPosition[i])) + continue; + Attribute positionAttr = dynamicPositionAttr[index]; + Value position = dynamicPosition[index++]; + if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) { + staticPosition[i] = attr.getInt(); + opChange = true; + continue; + } + operands.push_back(position); + } + + if (opChange) { + op.setStaticPosition(staticPosition); + op.getOperation()->setOperands(operands); + return op.getResult(); + } + return {}; +} + /// Fold an insert or extract operation into an poison value when a poison index /// is found at any dimension of the static position. 
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context, @@ -2035,6 +2074,9 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) { return val; if (auto val = foldScalarExtractFromFromElements(*this)) return val; + SmallVector<Value> operands = {getVector()}; + if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands)) + return val; return OpFoldResult(); } @@ -3094,6 +3136,9 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) { // (type mismatch). if (getNumIndices() == 0 && getSourceType() == getType()) return getSource(); + SmallVector<Value> operands = {getSource(), getDest()}; + if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands)) + return val; if (auto res = foldPoisonIndexInsertExtractOp( getContext(), adaptor.getStaticPosition(), kPoisonIndex)) return res; diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir index d319b9043b4b8..d261327ec005f 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir @@ -530,6 +530,25 @@ func.func @extract_scalar_from_vec_0d_index(%arg0: vector<index>) -> index { // ----- +func.func @extract_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const(%arg : vector<32x1xi32>) -> i32 { + %0 = arith.constant 0 : index + %1 = vector.extract %arg[%0, %0] : i32 from vector<32x1xi32> + return %1 : i32 +} + +// At compile time, since the indices of extractOp are constants, +// they will be collapsed and folded away; therefore, the lowering works. 
+ +// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const +// CHECK-SAME: %[[ARG:.*]]: vector<32x1xi32>) -> i32 { +// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>> +// CHECK: %[[VEC_0:.*]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<32 x vector<1xi32>> +// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[RES:.*]] = llvm.extractelement %[[VEC_0]]{{\[}}%[[C0]] : i64] : vector<1xi32> +// CHECK: return %[[RES]] : i32 + +// ----- + //===----------------------------------------------------------------------===// // vector.insertelement //===----------------------------------------------------------------------===// @@ -781,6 +800,29 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[1 // ----- +func.func @insert_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const(%arg : vector<4x1xi32>) -> vector<4x1xi32> { + %0 = arith.constant 0 : index + %1 = arith.constant 1 : i32 + %res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32> + return %res : vector<4x1xi32> +} + +// At compile time, since the indices of insertOp are constants, +// they will be collapsed and folded away; therefore, the lowering works. 
+ +// CHECK-LABEL: @insert_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const +// CHECK-SAME: %[[ARG:.*]]: vector<4x1xi32>) -> vector<4x1xi32> { +// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x1xi32> to !llvm.array<4 x vector<1xi32>> +// CHECK: %[[C1:.*]] = arith.constant 1 : i32 +// CHECK: %[[VEC_0:.*]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<4 x vector<1xi32>> +// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[VEC_1:.*]] = llvm.insertelement %[[C1]], %[[VEC_0]]{{\[}}%[[C0]] : i64] : vector<1xi32> +// CHECK: %[[VEC_2:.*]] = llvm.insertvalue %[[VEC_1]], %[[CAST]][0] : !llvm.array<4 x vector<1xi32>> +// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[VEC_2]] : !llvm.array<4 x vector<1xi32>> to vector<4x1xi32> +// CHECK: return %[[RES]] : vector<4x1xi32> + +// ----- + //===----------------------------------------------------------------------===// // vector.type_cast // diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index a74e562ad2f68..93581cbfbe5e4 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -3171,3 +3171,29 @@ func.func @contiguous_scatter_step(%base: memref<?xf32>, memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> return } + +// ----- + +// CHECK-LABEL: @fold_extract_constant_indices +// CHECK-SAME: %[[ARG:.*]]: vector<32x1xi32>) -> i32 { +// CHECK: %[[RES:.*]] = vector.extract %[[ARG]][0, 0] : i32 from vector<32x1xi32> +// CHECK: return %[[RES]] : i32 +func.func @fold_extract_constant_indices(%arg : vector<32x1xi32>) -> i32 { + %0 = arith.constant 0 : index + %1 = vector.extract %arg[%0, %0] : i32 from vector<32x1xi32> + return %1 : i32 +} + +// ----- + +// CHECK-LABEL: @fold_insert_constant_indices +// CHECK-SAME: %[[ARG:.*]]: vector<4x1xi32>) -> vector<4x1xi32> { +// CHECK: %[[VAL:.*]] = arith.constant 1 : i32 +// CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] [0, 
0] : i32 into vector<4x1xi32> +// CHECK: return %[[RES]] : vector<4x1xi32> +func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi32> { + %0 = arith.constant 0 : index + %1 = arith.constant 1 : i32 + %res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32> + return %res : vector<4x1xi32> +} diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir index dbe0b39422369..38771f2593449 100644 --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -778,12 +778,11 @@ func.func @warp_constant(%laneid: index) -> (vector<1xf32>) { // CHECK-PROP-LABEL: func.func @vector_extract_1d( // CHECK-PROP-DAG: %[[C5_I32:.*]] = arith.constant 5 : i32 -// CHECK-PROP-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-PROP: %[[R:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2xf32>) { // CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<64xf32> // CHECK-PROP: gpu.yield %[[V]] : vector<64xf32> // CHECK-PROP: } -// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][%[[C1]]] : f32 from vector<2xf32> +// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][1] : f32 from vector<2xf32> // CHECK-PROP: %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle idx %[[E]], %[[C5_I32]] // CHECK-PROP: return %[[SHUFFLED]] : f32 func.func @vector_extract_1d(%laneid: index) -> (f32) {