@@ -142,6 +142,19 @@ TypeDescType parseTypeDesc(mlir::DialectAsmParser &parser, mlir::Location loc) {
142142 return parseTypeSingleton<TypeDescType>(parser, loc);
143143}
144144
145+ // `vector` `<` len `:` type `>`
146+ fir::VectorType parseVector (mlir::DialectAsmParser &parser,
147+ mlir::Location loc) {
148+ int64_t len = 0 ;
149+ mlir::Type eleTy;
150+ if (parser.parseLess () || parser.parseInteger (len) || parser.parseColon () ||
151+ parser.parseType (eleTy) || parser.parseGreater ()) {
152+ parser.emitError (parser.getNameLoc (), " invalid vector type" );
153+ return {};
154+ }
155+ return fir::VectorType::get (len, eleTy);
156+ }
157+
145158// `void`
146159mlir::Type parseVoid (mlir::DialectAsmParser &parser) {
147160 return parser.getBuilder ().getNoneType ();
@@ -346,6 +359,8 @@ mlir::Type fir::parseFirType(FIROpsDialect *, mlir::DialectAsmParser &parser) {
346359 return parseDerived (parser, loc);
347360 if (typeNameLit == " void" )
348361 return parseVoid (parser);
362+ if (typeNameLit == " vector" )
363+ return parseVector (parser, loc);
349364
350365 parser.emitError (parser.getNameLoc (), " unknown FIR type " + typeNameLit);
351366 return {};
@@ -790,6 +805,39 @@ struct TypeDescTypeStorage : public mlir::TypeStorage {
790805 explicit TypeDescTypeStorage (mlir::Type ofTy) : ofTy{ofTy} {}
791806};
792807
808+ // / Vector type storage
809+ struct VectorTypeStorage : public mlir ::TypeStorage {
810+ using KeyTy = std::tuple<uint64_t , mlir::Type>;
811+
812+ static unsigned hashKey (const KeyTy &key) {
813+ return llvm::hash_combine (std::get<uint64_t >(key),
814+ std::get<mlir::Type>(key));
815+ }
816+
817+ bool operator ==(const KeyTy &key) const {
818+ return key == KeyTy{getLen (), getEleTy ()};
819+ }
820+
821+ static VectorTypeStorage *construct (mlir::TypeStorageAllocator &allocator,
822+ const KeyTy &key) {
823+ auto *storage = allocator.allocate <VectorTypeStorage>();
824+ return new (storage)
825+ VectorTypeStorage{std::get<uint64_t >(key), std::get<mlir::Type>(key)};
826+ }
827+
828+ uint64_t getLen () const { return len; }
829+ mlir::Type getEleTy () const { return eleTy; }
830+
831+ protected:
832+ uint64_t len;
833+ mlir::Type eleTy;
834+
835+ private:
836+ VectorTypeStorage () = delete ;
837+ explicit VectorTypeStorage (uint64_t len, mlir::Type eleTy)
838+ : len{len}, eleTy{eleTy} {}
839+ };
840+
793841} // namespace detail
794842
795843template <typename A, typename B>
@@ -1069,12 +1117,34 @@ mlir::LogicalResult fir::SequenceType::verifyConstructionInvariants(
10691117 eleTy.isa <BoxProcType>() || eleTy.isa <FieldType>() ||
10701118 eleTy.isa <LenType>() || eleTy.isa <HeapType>() ||
10711119 eleTy.isa <PointerType>() || eleTy.isa <ReferenceType>() ||
1072- eleTy.isa <TypeDescType>() || eleTy.isa <SequenceType>())
1120+ eleTy.isa <TypeDescType>() || eleTy.isa <fir::VectorType>() ||
1121+ eleTy.isa <SequenceType>())
10731122 return mlir::emitError (loc, " cannot build an array of this element type: " )
10741123 << eleTy << ' \n ' ;
10751124 return mlir::success ();
10761125}
10771126
1127+ // ===----------------------------------------------------------------------===//
1128+ // Vector type
1129+ // ===----------------------------------------------------------------------===//
1130+
1131+ fir::VectorType fir::VectorType::get (uint64_t len, mlir::Type eleTy) {
1132+ return Base::get (eleTy.getContext (), len, eleTy);
1133+ }
1134+
1135+ mlir::Type fir::VectorType::getEleTy () const { return getImpl ()->getEleTy (); }
1136+
1137+ uint64_t fir::VectorType::getLen () const { return getImpl ()->getLen (); }
1138+
1139+ mlir::LogicalResult
1140+ fir::VectorType::verifyConstructionInvariants (mlir::Location loc, uint64_t len,
1141+ mlir::Type eleTy) {
1142+ if (!(fir::isa_real (eleTy) || fir::isa_integer (eleTy)))
1143+ return mlir::emitError (loc, " cannot build a vector of type " )
1144+ << eleTy << ' \n ' ;
1145+ return mlir::success ();
1146+ }
1147+
10781148// compare if two shapes are equivalent
10791149bool fir::operator ==(const SequenceType::Shape &sh_1,
10801150 const SequenceType::Shape &sh_2) {
@@ -1302,4 +1372,10 @@ void fir::printFirType(FIROpsDialect *, mlir::Type ty,
13021372 os << ' >' ;
13031373 return ;
13041374 }
1375+ if (auto type = ty.dyn_cast <fir::VectorType>()) {
1376+ os << " vector<" << type.getLen () << ' :' ;
1377+ p.printType (type.getEleTy ());
1378+ os << ' >' ;
1379+ return ;
1380+ }
13051381}
0 commit comments