diff --git a/mlir/docs/Dialects/LLVM.md b/mlir/docs/Dialects/LLVM.md index d0509e036682f..468f69c419071 100644 --- a/mlir/docs/Dialects/LLVM.md +++ b/mlir/docs/Dialects/LLVM.md @@ -334,8 +334,6 @@ compatible with the LLVM dialect: - `bool LLVM::isCompatibleVectorType(Type)` - checks whether a type is a vector type compatible with the LLVM dialect; -- `Type LLVM::getVectorElementType(Type)` - returns the element type of any - vector type compatible with the LLVM dialect; - `llvm::ElementCount LLVM::getVectorNumElements(Type)` - returns the number of elements in any vector type compatible with the LLVM dialect; - `Type LLVM::getFixedVectorType(Type, unsigned)` - gets a fixed vector type diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td index 2debd09f78b34..ab928c9e2d0e7 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -874,7 +874,7 @@ def LLVM_MatrixColumnMajorLoadOp : LLVM_OneResultIntrOp<"matrix.column.major.loa const llvm::DataLayout &dl = builder.GetInsertBlock()->getModule()->getDataLayout(); llvm::Type *ElemTy = moduleTranslation.convertType( - getVectorElementType(op.getType())); + op.getType().getElementType()); llvm::Align align = dl.getABITypeAlign(ElemTy); $res = mb.CreateColumnMajorLoad( ElemTy, $data, align, $stride, $isVolatile, $rows, @@ -907,7 +907,7 @@ def LLVM_MatrixColumnMajorStoreOp : LLVM_ZeroResultIntrOp<"matrix.column.major.s llvm::MatrixBuilder mb(builder); const llvm::DataLayout &dl = builder.GetInsertBlock()->getModule()->getDataLayout(); - Type elementType = getVectorElementType(op.getMatrix().getType()); + Type elementType = op.getMatrix().getType().getElementType(); llvm::Align align = dl.getABITypeAlign( moduleTranslation.convertType(elementType)); mb.CreateColumnMajorStore( @@ -1164,7 +1164,8 @@ def LLVM_vector_insert let extraClassDeclaration = [{ uint64_t getVectorBitWidth(Type vector) { return getVectorNumElements(vector).getKnownMinValue() * - getVectorElementType(vector).getIntOrFloatBitWidth(); + ::llvm::cast(vector).getElementType() + .getIntOrFloatBitWidth(); } uint64_t getSrcVectorBitWidth() { return getVectorBitWidth(getSrcvec().getType()); @@ -1196,7 +1197,8 @@ def LLVM_vector_extract let extraClassDeclaration = [{ uint64_t getVectorBitWidth(Type vector) { return getVectorNumElements(vector).getKnownMinValue() * - getVectorElementType(vector).getIntOrFloatBitWidth(); + ::llvm::cast(vector).getElementType() + .getIntOrFloatBitWidth(); } uint64_t getSrcVectorBitWidth() { return getVectorBitWidth(getSrcvec().getType()); @@ -1216,8 +1218,8 @@ def LLVM_vector_interleave2 "result has twice as many elements as 'vec1'", And<[CPred<"getVectorNumElements($res.getType()) == " "getVectorNumElements($vec1.getType()) * 2">, - CPred<"getVectorElementType($vec1.getType()) == " - "getVectorElementType($res.getType())">]>>, + CPred<"::llvm::cast($vec1.getType()).getElementType() == " + "::llvm::cast($res.getType()).getElementType()">]>>, ]>, Arguments<(ins LLVM_AnyVector:$vec1, LLVM_AnyVector:$vec2)>; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td index 1fa1d3be557db..b97b5ac932c97 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -113,17 +113,20 @@ def LLVM_AnyNonAggregate : Type, - "LLVM dialect-compatible vector type">; + "LLVM dialect-compatible vector type", + "::mlir::VectorType">; // Type constraint accepting any LLVM fixed-length vector type. def LLVM_AnyFixedVector : Type, - "LLVM dialect-compatible fixed-length vector type">; + "LLVM dialect-compatible fixed-length vector type", + "::mlir::VectorType">; // Type constraint accepting any LLVM scalable vector type. def LLVM_AnyScalableVector : Type, - "LLVM dialect-compatible scalable vector type">; + "LLVM dialect-compatible scalable vector type", + "::mlir::VectorType">; // Type constraint accepting an LLVM vector type with an additional constraint // on the vector element type. @@ -131,9 +134,10 @@ class LLVM_VectorOf : Type< And<[LLVM_AnyVector.predicate, SubstLeaves< "$_self", - "::mlir::LLVM::getVectorElementType($_self)", + "::llvm::cast<::mlir::VectorType>($_self).getElementType()", element.predicate>]>, - "LLVM dialect-compatible vector of " # element.summary>; + "LLVM dialect-compatible vector of " # element.summary, + "::mlir::VectorType">; // Type constraint accepting a constrained type, or a vector of such types. class LLVM_ScalarOrVectorOf : diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index b107b64e55b46..6602318b07b85 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -820,8 +820,9 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call", //===----------------------------------------------------------------------===// def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [Pure, - TypesMatchWith<"result type matches vector element type", "vector", "res", - "LLVM::getVectorElementType($_self)">]> { + TypesMatchWith< + "result type matches vector element type", "vector", "res", + "::llvm::cast<::mlir::VectorType>($_self).getElementType()">]> { let summary = "Extract an element from an LLVM vector."; let arguments = (ins LLVM_AnyVector:$vector, AnySignlessInteger:$position); @@ -881,7 +882,8 @@ def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [Pure]> { def LLVM_InsertElementOp : LLVM_Op<"insertelement", [Pure, TypesMatchWith<"argument type matches vector element type", "vector", - "value", "LLVM::getVectorElementType($_self)">, + "value", + "::llvm::cast<::mlir::VectorType>($_self).getElementType()">, AllTypesMatch<["res", "vector"]>]> { let summary = "Insert an element into an LLVM vector."; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h index 03c246e589643..a2a76c49a2bda 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -111,10 +111,6 @@ bool isCompatibleFloatingPointType(Type type); /// dialect pointers and LLVM dialect scalable vector types. bool isCompatibleVectorType(Type type); -/// Returns the element type of any vector type compatible with the LLVM -/// dialect. -Type getVectorElementType(Type type); - /// Returns the element count of any LLVM-compatible vector type. llvm::ElementCount getVectorNumElements(Type type); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index 6e0adfc1e0ff3..93979e0f73324 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -78,10 +78,9 @@ static unsigned getBitWidth(Type type) { /// Returns the bit width of LLVMType integer or vector. static unsigned getLLVMTypeBitWidth(Type type) { - return cast((LLVM::isCompatibleVectorType(type) - ? LLVM::getVectorElementType(type) - : type)) - .getWidth(); + if (auto vecTy = dyn_cast(type)) + type = vecTy.getElementType(); + return cast(type).getWidth(); } /// Creates `IntegerAttribute` with all bits set for given type diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 78eb4c9b3481f..33a1686541996 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -2734,9 +2734,9 @@ void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1, Value v2, DenseI32ArrayAttr mask, ArrayRef attrs) { auto containerType = v1.getType(); - auto vType = LLVM::getVectorType(LLVM::getVectorElementType(containerType), - mask.size(), - LLVM::isScalableVectorType(containerType)); + auto vType = LLVM::getVectorType( + cast(containerType).getElementType(), mask.size(), + LLVM::isScalableVectorType(containerType)); build(builder, state, vType, v1, v2, mask); state.addAttributes(attrs); } @@ -2752,8 +2752,9 @@ static ParseResult parseShuffleType(AsmParser &parser, Type v1Type, if (!LLVM::isCompatibleVectorType(v1Type)) return parser.emitError(parser.getCurrentLocation(), "expected an LLVM compatible vector type"); - resType = LLVM::getVectorType(LLVM::getVectorElementType(v1Type), mask.size(), - LLVM::isScalableVectorType(v1Type)); + resType = + LLVM::getVectorType(cast(v1Type).getElementType(), + mask.size(), LLVM::isScalableVectorType(v1Type)); return success(); } @@ -3318,7 +3319,7 @@ LogicalResult AtomicRMWOp::verify() { if (isCompatibleVectorType(valType)) { if (isScalableVectorType(valType)) return emitOpError("expected LLVM IR fixed vector type"); - Type elemType = getVectorElementType(valType); + Type elemType = llvm::cast(valType).getElementType(); if (!isCompatibleFloatingPointType(elemType)) return emitOpError( "expected LLVM IR floating point type for vector element"); @@ -3423,9 +3424,10 @@ static LogicalResult verifyExtOp(ExtOp op) { return op.emitError("input and output vectors are of incompatible shape"); // Because this is a CastOp, the element of vectors is guaranteed to be an // integer. - inputType = cast(getVectorElementType(op.getArg().getType())); - outputType = - cast(getVectorElementType(op.getResult().getType())); + inputType = cast( + cast(op.getArg().getType()).getElementType()); + outputType = cast( + cast(op.getResult().getType()).getElementType()); } else { // Because this is a CastOp and arg is not a vector, arg is guaranteed to be // an integer. diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index 663adc3c34256..b3c2a29309528 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -821,12 +821,6 @@ bool mlir::LLVM::isCompatibleVectorType(Type type) { return false; } -Type mlir::LLVM::getVectorElementType(Type type) { - auto vecTy = dyn_cast(type); - assert(vecTy && "incompatible with LLVM vector type"); - return vecTy.getElementType(); -} - llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) { auto vecTy = dyn_cast(type); assert(vecTy && "incompatible with LLVM vector type"); diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 2859abdb41772..187f1bdf7af6e 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -826,7 +826,7 @@ static Type getVectorTypeForAttr(Type type, ArrayRef arrayShape = {}) { } // An LLVM dialect vector can only contain scalars. - Type elementType = LLVM::getVectorElementType(type); + Type elementType = cast(type).getElementType(); if (!elementType.isIntOrFloat()) return {}; diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index db55088d812e6..0cd6b1f20a1bf 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -515,21 +515,21 @@ func.func @extractvalue_wrong_nesting() { // ----- func.func @invalid_vector_type_1(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) { - // expected-error@+1 {{'vector' must be LLVM dialect-compatible vector}} + // expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'f32'}} %0 = llvm.extractelement %arg2[%arg1 : i32] : f32 } // ----- func.func @invalid_vector_type_2(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) { - // expected-error@+1 {{'vector' must be LLVM dialect-compatible vector}} + // expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'f32'}} %0 = llvm.insertelement %arg2, %arg2[%arg1 : i32] : f32 } // ----- func.func @invalid_vector_type_3(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) { - // expected-error@+2 {{expected an LLVM compatible vector type}} + // expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'f32'}} %0 = llvm.shufflevector %arg2, %arg2 [0, 0, 0, 0, 7] : f32 } diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir index 7bb64542accdf..90c0f5ac55cb1 100644 --- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir @@ -211,7 +211,7 @@ llvm.func @vec_reduce_fmax_intr_wrong_type(%arg0 : vector<4xi32>) -> i32 { // ----- llvm.func @matrix_load_intr_wrong_type(%ptr : !llvm.ptr, %stride : i32) -> f32 { - // expected-error @below{{op result #0 must be LLVM dialect-compatible vector type, but got 'f32'}} + // expected-error @+2{{invalid kind of type specified: expected builtin.vector, but found 'f32'}} %0 = llvm.intr.matrix.column.major.load %ptr, { isVolatile = 0: i1, rows = 3: i32, columns = 16: i32} : f32 from !llvm.ptr stride i32 llvm.return %0 : f32 @@ -229,7 +229,7 @@ llvm.func @matrix_store_intr_wrong_type(%matrix : vector<48xf32>, %ptr : i32, %s // ----- llvm.func @matrix_multiply_intr_wrong_type(%arg0 : vector<64xf32>, %arg1 : f32) -> vector<12xf32> { - // expected-error @below{{op operand #1 must be LLVM dialect-compatible vector type, but got 'f32'}} + // expected-error @+2{{invalid kind of type specified: expected builtin.vector, but found 'f32'}} %0 = llvm.intr.matrix.multiply %arg0, %arg1 { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32} : (vector<64xf32>, f32) -> vector<12xf32> llvm.return %0 : vector<12xf32> @@ -238,7 +238,7 @@ llvm.func @matrix_multiply_intr_wrong_type(%arg0 : vector<64xf32>, %arg1 : f32) // ----- llvm.func @matrix_transpose_intr_wrong_type(%matrix : f32) -> vector<48xf32> { - // expected-error @below{{op operand #0 must be LLVM dialect-compatible vector type, but got 'f32'}} + // expected-error @below{{invalid kind of type specified: expected builtin.vector, but found 'f32'}} %0 = llvm.intr.matrix.transpose %matrix {rows = 3: i32, columns = 16: i32} : f32 into vector<48xf32> llvm.return %0 : vector<48xf32> } @@ -286,7 +286,7 @@ llvm.func @masked_gather_intr_wrong_type_scalable(%ptrs : vector<7x!llvm.ptr>, % // ----- llvm.func @masked_scatter_intr_wrong_type(%vec : f32, %ptrs : vector<7x!llvm.ptr>, %mask : vector<7xi1>) { - // expected-error @below{{op operand #0 must be LLVM dialect-compatible vector type, but got 'f32'}} + // expected-error @below{{invalid kind of type specified: expected builtin.vector, but found 'f32'}} llvm.intr.masked.scatter %vec, %ptrs, %mask { alignment = 1: i32} : f32, vector<7xi1> into vector<7x!llvm.ptr> llvm.return }