2121using namespace mlir ;
2222using namespace quant ;
2323
24- static IntegerType parseStorageType (DialectAsmParser &parser, bool &isSigned) {
24+ static Type parseStorageType (DialectAsmParser &parser, bool &isSigned) {
2525 auto typeLoc = parser.getCurrentLocation ();
26- IntegerType type;
26+ Type type;
2727
2828 // Parse storage type (alpha_ident, integer_literal).
2929 StringRef identifier;
@@ -32,20 +32,32 @@ static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
3232 if (result.has_value ()) {
3333 if (!succeeded (*result))
3434 return nullptr ;
35- isSigned = !type.isUnsigned ();
36- storageTypeWidth = type.getWidth ();
37- } else if (succeeded (parser.parseKeyword (&identifier))) {
38- // Otherwise, this must be an unsigned integer (`u` integer-literal).
39- if (!identifier.consume_front (" u" )) {
40- parser.emitError (typeLoc, " illegal storage type prefix" );
35+ if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
36+ isSigned = !intType.isUnsigned ();
37+ storageTypeWidth = intType.getWidth ();
38+ } else if (llvm::dyn_cast<Float8E5M2Type>(type) ||
39+ llvm::dyn_cast<Float8E4M3FNType>(type)) {
40+ storageTypeWidth = 8 ;
41+ isSigned = true ;
42+ } else {
43+ parser.emitError (typeLoc, " illegal quantized storage type alias" );
4144 return nullptr ;
4245 }
43- if (identifier.getAsInteger (10 , storageTypeWidth)) {
44- parser.emitError (typeLoc, " expected storage type width" );
46+ } else if (succeeded (parser.parseKeyword (&identifier))) {
47+ // Otherwise, this must be an unsigned integer (`u` integer-literal)
48+ if (identifier.consume_front (" u" )) {
49+ if (identifier.getAsInteger (10 , storageTypeWidth)) {
50+ parser.emitError (typeLoc, " expected storage type width" );
51+ return nullptr ;
52+ }
53+ isSigned = false ;
54+ type = parser.getBuilder ().getIntegerType (storageTypeWidth);
55+
56+ } else {
57+ parser.emitError (typeLoc, " illegal quantized storage type alias" );
4558 return nullptr ;
4659 }
47- isSigned = false ;
48- type = parser.getBuilder ().getIntegerType (storageTypeWidth);
60+
4961 } else {
5062 return nullptr ;
5163 }
@@ -60,35 +72,56 @@ static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
6072 return type;
6173}
6274
63- static ParseResult parseStorageRange (DialectAsmParser &parser,
64- IntegerType storageType, bool isSigned,
65- int64_t &storageTypeMin,
75+ static ParseResult
76+ checkStorageRange (DialectAsmParser &parser, int64_t storageTypeMin,
77+ int64_t storageTypeMax, int64_t defaultStorageTypeMin,
78+ int64_t defaultStorageTypeMax, SMLoc minLoc, SMLoc maxLoc) {
79+ if (storageTypeMin < defaultStorageTypeMin) {
80+ return parser.emitError (minLoc, " illegal storage type minimum: " )
81+ << storageTypeMin;
82+ }
83+ if (storageTypeMax > defaultStorageTypeMax) {
84+ return parser.emitError (maxLoc, " illegal storage type maximum: " )
85+ << storageTypeMax;
86+ }
87+ return success ();
88+ }
89+
90+ static ParseResult parseStorageRange (DialectAsmParser &parser, Type storageType,
91+ bool isSigned, int64_t &storageTypeMin,
6692 int64_t &storageTypeMax) {
67- int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger (
68- isSigned, storageType.getWidth ());
69- int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger (
70- isSigned, storageType.getWidth ());
93+ int64_t defaultMin, defaultMax;
94+ if (storageType.isa <IntegerType>()) {
95+ const auto width = llvm::dyn_cast<IntegerType>(storageType).getWidth ();
96+ defaultMin = QuantizedType::getDefaultMinimumForInteger (isSigned, width);
97+ defaultMax = QuantizedType::getDefaultMaximumForInteger (isSigned, width);
98+ } else if (storageType.isa <Float8E5M2Type>()) {
99+ defaultMin = QuantizedType::getDefaultMinimumForF8E5M2 ();
100+ defaultMax = QuantizedType::getDefaultMaximumForF8E5M2 ();
101+ } else if (storageType.isa <Float8E4M3FNType>()) {
102+ defaultMin = QuantizedType::getDefaultMinimumForF8E4M3FN ();
103+ defaultMax = QuantizedType::getDefaultMaximumForF8E4M3FN ();
104+ } else {
105+ defaultMin = std::numeric_limits<int64_t >::max ();
106+ defaultMax = std::numeric_limits<int64_t >::min ();
107+ }
108+
71109 if (failed (parser.parseOptionalLess ())) {
72- storageTypeMin = defaultIntegerMin ;
73- storageTypeMax = defaultIntegerMax ;
110+ storageTypeMin = defaultMin ;
111+ storageTypeMax = defaultMax ;
74112 return success ();
75113 }
76114
77115 // Explicit storage min and storage max.
116+ // F8 min and max values are integers, so parseInteger() is used.
78117 SMLoc minLoc = parser.getCurrentLocation (), maxLoc;
79118 if (parser.parseInteger (storageTypeMin) || parser.parseColon () ||
80119 parser.getCurrentLocation (&maxLoc) ||
81120 parser.parseInteger (storageTypeMax) || parser.parseGreater ())
82121 return failure ();
83- if (storageTypeMin < defaultIntegerMin) {
84- return parser.emitError (minLoc, " illegal storage type minimum: " )
85- << storageTypeMin;
86- }
87- if (storageTypeMax > defaultIntegerMax) {
88- return parser.emitError (maxLoc, " illegal storage type maximum: " )
89- << storageTypeMax;
90- }
91- return success ();
122+
123+ return checkStorageRange (parser, storageTypeMin, storageTypeMax, defaultMin,
124+ defaultMax, minLoc, maxLoc);
92125}
93126
94127static FloatType parseExpressedTypeAndRange (DialectAsmParser &parser,
@@ -118,7 +151,7 @@ static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser,
118151// / storage-type ::= (`i` | `u`) integer-literal
119152// / expressed-type-spec ::= `:` `f` integer-literal
120153static Type parseAnyType (DialectAsmParser &parser) {
121- IntegerType storageType;
154+ Type storageType;
122155 FloatType expressedType;
123156 unsigned typeFlags = 0 ;
124157 int64_t storageTypeMin;
@@ -192,7 +225,7 @@ static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale,
192225// / scale-zero ::= float-literal `:` integer-literal
193226// / scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}`
194227static Type parseUniformType (DialectAsmParser &parser) {
195- IntegerType storageType;
228+ Type storageType;
196229 FloatType expressedType;
197230 unsigned typeFlags = 0 ;
198231 int64_t storageTypeMin;
@@ -339,14 +372,37 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
339372 // storage type
340373 unsigned storageWidth = type.getStorageTypeIntegralWidth ();
341374 bool isSigned = type.isSigned ();
342- if (isSigned) {
375+ if (type.getStorageType ().isa <Float8E5M2Type>()) {
376+ out << " f8E5M2" ;
377+ } else if (type.getStorageType ().isa <Float8E4M3FNType>()) {
378+ out << " f8E4M3FN" ;
379+ } else if (isSigned) {
343380 out << " i" << storageWidth;
344381 } else {
345382 out << " u" << storageWidth;
346383 }
347384
348385 // storageTypeMin and storageTypeMax if not default.
349- if (type.hasStorageTypeBounds ()) {
386+ int64_t defaultMin =
387+ type.getStorageType ().isa <IntegerType>()
388+ ? QuantizedType::getDefaultMinimumForInteger (isSigned, storageWidth)
389+ : type.getStorageType ().isa <Float8E5M2Type>()
390+ ? QuantizedType::getDefaultMinimumForF8E5M2 ()
391+ : type.getStorageType ().isa <Float8E4M3FNType>()
392+ ? QuantizedType::getDefaultMinimumForF8E4M3FN ()
393+ : std::numeric_limits<int64_t >::max ();
394+
395+ int64_t defaultMax =
396+ type.getStorageType ().isa <IntegerType>()
397+ ? QuantizedType::getDefaultMaximumForInteger (isSigned, storageWidth)
398+ : type.getStorageType ().isa <Float8E5M2Type>()
399+ ? QuantizedType::getDefaultMaximumForF8E5M2 ()
400+ : type.getStorageType ().isa <Float8E4M3FNType>()
401+ ? QuantizedType::getDefaultMaximumForF8E4M3FN ()
402+ : std::numeric_limits<int64_t >::min ();
403+
404+ if (defaultMin != type.getStorageTypeMin () ||
405+ defaultMax != type.getStorageTypeMax ()) {
350406 out << " <" << type.getStorageTypeMin () << " :" << type.getStorageTypeMax ()
351407 << " >" ;
352408 }
0 commit comments