diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 9657f583c375b..4af03126fa1ed 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1096,6 +1096,26 @@ class VectorExtractOpConversion SmallVector positionVec = getMixedValues( adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter); + for (unsigned idx = 0; idx < positionVec.size(); ++idx) { + if (auto position = llvm::dyn_cast(positionVec[idx])) { + auto defOp = position.getDefiningOp(); + while (defOp) { + if (llvm::isa(defOp)) { + Attribute value = + defOp->getAttr(arith::ConstantOp::getAttributeNames()[0]); + positionVec[idx] = OpFoldResult{ + rewriter.getI64IntegerAttr(cast(value).getInt())}; + break; + } else if (auto unrealizedCastOp = + llvm::dyn_cast(defOp)) { + defOp = unrealizedCastOp.getOperand(0).getDefiningOp(); + } else { + break; + } + } + } + } + // The Vector -> LLVM lowering models N-D vectors as nested aggregates of // 1-d vectors. This nesting is modeled using arrays. We do this conversion // from a N-d vector extract to a nested aggregate vector extract in two @@ -1231,6 +1251,25 @@ class VectorInsertOpConversion SmallVector positionVec = getMixedValues( adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter); + for (unsigned idx = 0; idx < positionVec.size(); ++idx) { + if (auto position = llvm::dyn_cast(positionVec[idx])) { + auto defOp = position.getDefiningOp(); + while (defOp) { + if (llvm::isa(defOp)) { + Attribute value = + defOp->getAttr(arith::ConstantOp::getAttributeNames()[0]); + positionVec[idx] = OpFoldResult{ + rewriter.getI64IntegerAttr(cast(value).getInt())}; + break; + } else if (auto unrealizedCastOp = + llvm::dyn_cast(defOp)) { + defOp = unrealizedCastOp.getOperand(0).getDefiningOp(); + } else { + break; + } + } + } + } // Overwrite entire vector with value. Should be handled by folder, but // just to be safe. @@ -1242,8 +1281,9 @@ class VectorInsertOpConversion // One-shot insertion of a vector into an array (only requires insertvalue). if (isa(sourceType)) { - if (insertOp.hasDynamicPosition()) + if (!llvm::all_of(position, llvm::IsaPred)) { return failure(); + } Value inserted = rewriter.create( loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position)); @@ -1255,8 +1295,9 @@ class VectorInsertOpConversion Value extracted = adaptor.getDest(); auto oneDVectorType = destVectorType; if (position.size() > 1) { - if (insertOp.hasDynamicPosition()) + if (!llvm::all_of(position, llvm::IsaPred)) { return failure(); + } oneDVectorType = reducedVectorTypeBack(destVectorType); extracted = rewriter.create( @@ -1270,8 +1311,9 @@ class VectorInsertOpConversion // Potential insertion of resulting 1-D vector into array. if (position.size() > 1) { - if (insertOp.hasDynamicPosition()) + if (!llvm::all_of(position, llvm::IsaPred)) { return failure(); + } inserted = rewriter.create( loc, adaptor.getDest(), inserted, diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index f95e943250bd4..d16d78556da10 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -4094,3 +4094,93 @@ func.func @step_scalable() -> vector<[4]xindex> { %0 = vector.step : vector<[4]xindex> return %0 : vector<[4]xindex> } + +// ----- + +// CHECK-LABEL: @extract_arith_constnt +func.func @extract_arith_constnt() -> i32 { + %v = arith.constant dense<0> : vector<32x1xi32> + %c_0 = arith.constant 0 : index + %elem = vector.extract %v[%c_0, %c_0] : i32 from vector<32x1xi32> + return %elem : i32 +} + +// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32> +// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>> +// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>> +// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[VAL_5:.*]] = llvm.extractelement %[[VAL_3]]{{\[}}%[[VAL_4]] : i64] : vector<1xi32> +// CHECK: return %[[VAL_5]] : i32 + +// ----- + +// CHECK-LABEL: @extract_llvm_constnt() + +module { + func.func @extract_llvm_constnt() -> i32 { + %0 = llvm.mlir.constant(dense<0> : vector<32x1xi32>) : !llvm.array<32 x vector<1xi32>> + %1 = builtin.unrealized_conversion_cast %0 : !llvm.array<32 x vector<1xi32>> to vector<32x1xi32> + %2 = llvm.mlir.constant(0 : index) : i64 + %3 = builtin.unrealized_conversion_cast %2 : i64 to index + %4 = vector.extract %1[%3, %3] : i32 from vector<32x1xi32> + return %4 : i32 + } +} + +// CHECK: %[[VAL_0:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(dense<0> : vector<32x1xi32>) : !llvm.array<32 x vector<1xi32>> +// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>> +// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[VAL_4:.*]] = llvm.extractelement %[[VAL_2]]{{\[}}%[[VAL_3]] : i64] : vector<1xi32> +// CHECK: return %[[VAL_4]] : i32 + +// ----- + +// CHECK-LABEL: @insert_arith_constnt() + +func.func @insert_arith_constnt() -> vector<32x1xi32> { + %v = arith.constant dense<0> : vector<32x1xi32> + %c_0 = arith.constant 0 : index + %c_1 = arith.constant 1 : i32 + %v_1 = vector.insert %c_1, %v[%c_0, %c_0] : i32 into vector<32x1xi32> + return %v_1 : vector<32x1xi32> +} + +// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32> +// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>> +// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_4:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>> +// CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[VAL_6:.*]] = llvm.insertelement %[[VAL_3]], %[[VAL_4]]{{\[}}%[[VAL_5]] : i64] : vector<1xi32> +// CHECK: %[[VAL_7:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>> +// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_7]] : !llvm.array<32 x vector<1xi32>> to vector<32x1xi32> +// CHECK: return %[[VAL_8]] : vector<32x1xi32> + +// ----- + +// CHECK-LABEL: @insert_llvm_constnt() + +module { + func.func @insert_llvm_constnt() -> vector<32x1xi32> { + %0 = llvm.mlir.constant(dense<0> : vector<32x1xi32>) : !llvm.array<32 x vector<1xi32>> + %1 = builtin.unrealized_conversion_cast %0 : !llvm.array<32 x vector<1xi32>> to vector<32x1xi32> + %2 = llvm.mlir.constant(0 : index) : i64 + %3 = builtin.unrealized_conversion_cast %2 : i64 to index + %4 = llvm.mlir.constant(1 : i32) : i32 + %5 = vector.insert %4, %1 [%3, %3] : i32 into vector<32x1xi32> + return %5 : vector<32x1xi32> + } +} + +// CHECK: %[[VAL_0:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(dense<0> : vector<32x1xi32>) : !llvm.array<32 x vector<1xi32>> +// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_2]][0] : !llvm.array<32 x vector<1xi32>> +// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[VAL_5:.*]] = llvm.insertelement %[[VAL_0]], %[[VAL_3]]{{\[}}%[[VAL_4]] : i64] : vector<1xi32> +// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[VAL_5]], %[[VAL_2]][0] : !llvm.array<32 x vector<1xi32>> +// CHECK: %[[VAL_7:.*]] = builtin.unrealized_conversion_cast %[[VAL_6]] : !llvm.array<32 x vector<1xi32>> to vector<32x1xi32> +// CHECK: return %[[VAL_7]] : vector<32x1xi32> +// CHECK: }