@@ -874,7 +874,7 @@ def LLVM_MatrixColumnMajorLoadOp : LLVM_OneResultIntrOp<"matrix.column.major.loa
874874 const llvm::DataLayout &dl =
875875 builder.GetInsertBlock()->getModule()->getDataLayout();
876876 llvm::Type *ElemTy = moduleTranslation.convertType(
877- getVectorElementType( op.getType()));
877+ op.getType().getElementType( ));
878878 llvm::Align align = dl.getABITypeAlign(ElemTy);
879879 $res = mb.CreateColumnMajorLoad(
880880 ElemTy, $data, align, $stride, $isVolatile, $rows,
@@ -907,7 +907,7 @@ def LLVM_MatrixColumnMajorStoreOp : LLVM_ZeroResultIntrOp<"matrix.column.major.s
907907 llvm::MatrixBuilder mb(builder);
908908 const llvm::DataLayout &dl =
909909 builder.GetInsertBlock()->getModule()->getDataLayout();
910- Type elementType = getVectorElementType( op.getMatrix().getType());
910+ Type elementType = op.getMatrix().getType().getElementType( );
911911 llvm::Align align = dl.getABITypeAlign(
912912 moduleTranslation.convertType(elementType));
913913 mb.CreateColumnMajorStore(
@@ -1164,7 +1164,8 @@ def LLVM_vector_insert
11641164 let extraClassDeclaration = [{
11651165 uint64_t getVectorBitWidth(Type vector) {
11661166 return getVectorNumElements(vector).getKnownMinValue() *
1167- getVectorElementType(vector).getIntOrFloatBitWidth();
1167+ ::llvm::cast<VectorType>(vector).getElementType()
1168+ .getIntOrFloatBitWidth();
11681169 }
11691170 uint64_t getSrcVectorBitWidth() {
11701171 return getVectorBitWidth(getSrcvec().getType());
@@ -1196,7 +1197,8 @@ def LLVM_vector_extract
11961197 let extraClassDeclaration = [{
11971198 uint64_t getVectorBitWidth(Type vector) {
11981199 return getVectorNumElements(vector).getKnownMinValue() *
1199- getVectorElementType(vector).getIntOrFloatBitWidth();
1200+ ::llvm::cast<VectorType>(vector).getElementType()
1201+ .getIntOrFloatBitWidth();
12001202 }
12011203 uint64_t getSrcVectorBitWidth() {
12021204 return getVectorBitWidth(getSrcvec().getType());
@@ -1216,8 +1218,8 @@ def LLVM_vector_interleave2
12161218 "result has twice as many elements as 'vec1'",
12171219 And<[CPred<"getVectorNumElements($res.getType()) == "
12181220 "getVectorNumElements($vec1.getType()) * 2">,
1219- CPred<"getVectorElementType ($vec1.getType()) == "
1220- "getVectorElementType ($res.getType())">]>>,
1221+ CPred<"::llvm::cast<VectorType> ($vec1.getType()).getElementType( ) == "
1222+ "::llvm::cast<VectorType> ($res.getType()).getElementType( )">]>>,
12211223 ]>,
12221224 Arguments<(ins LLVM_AnyVector:$vec1, LLVM_AnyVector:$vec2)>;
12231225
0 commit comments