Skip to content

Commit 7961ba8

Browse files
AndreiFilipIntelsramasit
authored andcommitted
Added f8E4M3FN and f8E5M2 support to existing QuantizedTypes (17.x) (#37)
* added integral type check * changed to f8e_m_ format, removed isF8 helper method, formatting * added negative tests, removed duplicate test * Added endline in parse-uniform-invalid.mlir * changed to f8E5M2/f8E4M3FN format, enabled default type parsing * updated parse-any-invalid error message checks
1 parent 78577c6 commit 7961ba8

File tree

6 files changed

+202
-59
lines changed

6 files changed

+202
-59
lines changed

mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,18 @@ class QuantizedType : public Type {
8282
return llvm::maxUIntN(integralWidth);
8383
}
8484

85+
static constexpr int64_t getDefaultMaximumForF8E4M3FN() { return 448; }
86+
87+
static constexpr int64_t getDefaultMinimumForF8E4M3FN() {
88+
return -getDefaultMaximumForF8E4M3FN();
89+
}
90+
91+
static constexpr int64_t getDefaultMaximumForF8E5M2() { return 57344; }
92+
93+
static constexpr int64_t getDefaultMinimumForF8E5M2() {
94+
return -getDefaultMaximumForF8E5M2();
95+
}
96+
8597
/// Gets the original expressed type that this quantized type approximates.
8698
/// Note that this presumes that the quantized type was always derived from
8799
/// a floating point type, which in the broadest definition, is not true (i.e.

mlir/lib/Dialect/Quant/IR/QuantTypes.cpp

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -49,28 +49,38 @@ QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
4949
unsigned flags, Type storageType,
5050
Type expressedType, int64_t storageTypeMin,
5151
int64_t storageTypeMax) {
52-
// Verify that the storage type is integral.
53-
// This restriction may be lifted at some point in favor of using bf16
54-
// or f16 as exact representations on hardware where that is advantageous.
55-
auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
56-
if (!intStorageType)
57-
return emitError() << "storage type must be integral";
58-
unsigned integralWidth = intStorageType.getWidth();
59-
60-
// Verify storage width.
61-
if (integralWidth == 0 || integralWidth > MaxStorageBits)
62-
return emitError() << "illegal storage type size: " << integralWidth;
63-
64-
// Verify storageTypeMin and storageTypeMax.
52+
6553
bool isSigned =
6654
(flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
67-
int64_t defaultIntegerMin =
68-
getDefaultMinimumForInteger(isSigned, integralWidth);
69-
int64_t defaultIntegerMax =
70-
getDefaultMaximumForInteger(isSigned, integralWidth);
71-
if (storageTypeMax - storageTypeMin <= 0 ||
72-
storageTypeMin < defaultIntegerMin ||
73-
storageTypeMax > defaultIntegerMax) {
55+
56+
// Integral storage type width checks
57+
if (storageType.isa<IntegerType>()) {
58+
unsigned integralWidth =
59+
llvm::dyn_cast<IntegerType>(storageType).getWidth();
60+
61+
if (integralWidth == 0 || integralWidth > MaxStorageBits)
62+
return emitError() << "illegal storage type size: " << integralWidth;
63+
}
64+
65+
int64_t defaultMin, defaultMax;
66+
if (storageType.isa<IntegerType>()) {
67+
const auto width = llvm::dyn_cast<IntegerType>(storageType).getWidth();
68+
defaultMin = QuantizedType::getDefaultMinimumForInteger(isSigned, width);
69+
defaultMax = QuantizedType::getDefaultMaximumForInteger(isSigned, width);
70+
} else if (storageType.isa<Float8E5M2Type>()) {
71+
defaultMin = QuantizedType::getDefaultMinimumForF8E5M2();
72+
defaultMax = QuantizedType::getDefaultMaximumForF8E5M2();
73+
} else if (storageType.isa<Float8E4M3FNType>()) {
74+
defaultMin = QuantizedType::getDefaultMinimumForF8E4M3FN();
75+
defaultMax = QuantizedType::getDefaultMaximumForF8E4M3FN();
76+
} else {
77+
return emitError() << "illegal storage type, supported types are: integral "
78+
"types, Float8E4M3FNType and Float8E5M2Type ";
79+
}
80+
81+
// Verify storageTypeMin and storageTypeMax.
82+
if (storageTypeMax - storageTypeMin <= 0 || storageTypeMin < defaultMin ||
83+
storageTypeMax > defaultMax) {
7484
return emitError() << "illegal storage min and storage max: ("
7585
<< storageTypeMin << ":" << storageTypeMax << ")";
7686
}

mlir/lib/Dialect/Quant/IR/TypeParser.cpp

Lines changed: 90 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
using namespace mlir;
2222
using namespace quant;
2323

24-
static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
24+
static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) {
2525
auto typeLoc = parser.getCurrentLocation();
26-
IntegerType type;
26+
Type type;
2727

2828
// Parse storage type (alpha_ident, integer_literal).
2929
StringRef identifier;
@@ -32,20 +32,32 @@ static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
3232
if (result.has_value()) {
3333
if (!succeeded(*result))
3434
return nullptr;
35-
isSigned = !type.isUnsigned();
36-
storageTypeWidth = type.getWidth();
37-
} else if (succeeded(parser.parseKeyword(&identifier))) {
38-
// Otherwise, this must be an unsigned integer (`u` integer-literal).
39-
if (!identifier.consume_front("u")) {
40-
parser.emitError(typeLoc, "illegal storage type prefix");
35+
if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
36+
isSigned = !intType.isUnsigned();
37+
storageTypeWidth = intType.getWidth();
38+
} else if (llvm::dyn_cast<Float8E5M2Type>(type) ||
39+
llvm::dyn_cast<Float8E4M3FNType>(type)) {
40+
storageTypeWidth = 8;
41+
isSigned = true;
42+
} else {
43+
parser.emitError(typeLoc, "illegal quantized storage type alias");
4144
return nullptr;
4245
}
43-
if (identifier.getAsInteger(10, storageTypeWidth)) {
44-
parser.emitError(typeLoc, "expected storage type width");
46+
} else if (succeeded(parser.parseKeyword(&identifier))) {
47+
// Otherwise, this must be an unsigned integer (`u` integer-literal)
48+
if (identifier.consume_front("u")) {
49+
if (identifier.getAsInteger(10, storageTypeWidth)) {
50+
parser.emitError(typeLoc, "expected storage type width");
51+
return nullptr;
52+
}
53+
isSigned = false;
54+
type = parser.getBuilder().getIntegerType(storageTypeWidth);
55+
56+
} else {
57+
parser.emitError(typeLoc, "illegal quantized storage type alias");
4558
return nullptr;
4659
}
47-
isSigned = false;
48-
type = parser.getBuilder().getIntegerType(storageTypeWidth);
60+
4961
} else {
5062
return nullptr;
5163
}
@@ -60,35 +72,56 @@ static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
6072
return type;
6173
}
6274

63-
static ParseResult parseStorageRange(DialectAsmParser &parser,
64-
IntegerType storageType, bool isSigned,
65-
int64_t &storageTypeMin,
75+
static ParseResult
76+
checkStorageRange(DialectAsmParser &parser, int64_t storageTypeMin,
77+
int64_t storageTypeMax, int64_t defaultStorageTypeMin,
78+
int64_t defaultStorageTypeMax, SMLoc minLoc, SMLoc maxLoc) {
79+
if (storageTypeMin < defaultStorageTypeMin) {
80+
return parser.emitError(minLoc, "illegal storage type minimum: ")
81+
<< storageTypeMin;
82+
}
83+
if (storageTypeMax > defaultStorageTypeMax) {
84+
return parser.emitError(maxLoc, "illegal storage type maximum: ")
85+
<< storageTypeMax;
86+
}
87+
return success();
88+
}
89+
90+
static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType,
91+
bool isSigned, int64_t &storageTypeMin,
6692
int64_t &storageTypeMax) {
67-
int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
68-
isSigned, storageType.getWidth());
69-
int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
70-
isSigned, storageType.getWidth());
93+
int64_t defaultMin, defaultMax;
94+
if (storageType.isa<IntegerType>()) {
95+
const auto width = llvm::dyn_cast<IntegerType>(storageType).getWidth();
96+
defaultMin = QuantizedType::getDefaultMinimumForInteger(isSigned, width);
97+
defaultMax = QuantizedType::getDefaultMaximumForInteger(isSigned, width);
98+
} else if (storageType.isa<Float8E5M2Type>()) {
99+
defaultMin = QuantizedType::getDefaultMinimumForF8E5M2();
100+
defaultMax = QuantizedType::getDefaultMaximumForF8E5M2();
101+
} else if (storageType.isa<Float8E4M3FNType>()) {
102+
defaultMin = QuantizedType::getDefaultMinimumForF8E4M3FN();
103+
defaultMax = QuantizedType::getDefaultMaximumForF8E4M3FN();
104+
} else {
105+
defaultMin = std::numeric_limits<int64_t>::max();
106+
defaultMax = std::numeric_limits<int64_t>::min();
107+
}
108+
71109
if (failed(parser.parseOptionalLess())) {
72-
storageTypeMin = defaultIntegerMin;
73-
storageTypeMax = defaultIntegerMax;
110+
storageTypeMin = defaultMin;
111+
storageTypeMax = defaultMax;
74112
return success();
75113
}
76114

77115
// Explicit storage min and storage max.
116+
// F8 min and max values are integers, so parseInteger() is used.
78117
SMLoc minLoc = parser.getCurrentLocation(), maxLoc;
79118
if (parser.parseInteger(storageTypeMin) || parser.parseColon() ||
80119
parser.getCurrentLocation(&maxLoc) ||
81120
parser.parseInteger(storageTypeMax) || parser.parseGreater())
82121
return failure();
83-
if (storageTypeMin < defaultIntegerMin) {
84-
return parser.emitError(minLoc, "illegal storage type minimum: ")
85-
<< storageTypeMin;
86-
}
87-
if (storageTypeMax > defaultIntegerMax) {
88-
return parser.emitError(maxLoc, "illegal storage type maximum: ")
89-
<< storageTypeMax;
90-
}
91-
return success();
122+
123+
return checkStorageRange(parser, storageTypeMin, storageTypeMax, defaultMin,
124+
defaultMax, minLoc, maxLoc);
92125
}
93126

94127
static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser,
@@ -118,7 +151,7 @@ static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser,
118151
/// storage-type ::= (`i` | `u`) integer-literal
119152
/// expressed-type-spec ::= `:` `f` integer-literal
120153
static Type parseAnyType(DialectAsmParser &parser) {
121-
IntegerType storageType;
154+
Type storageType;
122155
FloatType expressedType;
123156
unsigned typeFlags = 0;
124157
int64_t storageTypeMin;
@@ -192,7 +225,7 @@ static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale,
192225
/// scale-zero ::= float-literal `:` integer-literal
193226
/// scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}`
194227
static Type parseUniformType(DialectAsmParser &parser) {
195-
IntegerType storageType;
228+
Type storageType;
196229
FloatType expressedType;
197230
unsigned typeFlags = 0;
198231
int64_t storageTypeMin;
@@ -339,14 +372,37 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
339372
// storage type
340373
unsigned storageWidth = type.getStorageTypeIntegralWidth();
341374
bool isSigned = type.isSigned();
342-
if (isSigned) {
375+
if (type.getStorageType().isa<Float8E5M2Type>()) {
376+
out << "f8E5M2";
377+
} else if (type.getStorageType().isa<Float8E4M3FNType>()) {
378+
out << "f8E4M3FN";
379+
} else if (isSigned) {
343380
out << "i" << storageWidth;
344381
} else {
345382
out << "u" << storageWidth;
346383
}
347384

348385
// storageTypeMin and storageTypeMax if not default.
349-
if (type.hasStorageTypeBounds()) {
386+
int64_t defaultMin =
387+
type.getStorageType().isa<IntegerType>()
388+
? QuantizedType::getDefaultMinimumForInteger(isSigned, storageWidth)
389+
: type.getStorageType().isa<Float8E5M2Type>()
390+
? QuantizedType::getDefaultMinimumForF8E5M2()
391+
: type.getStorageType().isa<Float8E4M3FNType>()
392+
? QuantizedType::getDefaultMinimumForF8E4M3FN()
393+
: std::numeric_limits<int64_t>::max();
394+
395+
int64_t defaultMax =
396+
type.getStorageType().isa<IntegerType>()
397+
? QuantizedType::getDefaultMaximumForInteger(isSigned, storageWidth)
398+
: type.getStorageType().isa<Float8E5M2Type>()
399+
? QuantizedType::getDefaultMaximumForF8E5M2()
400+
: type.getStorageType().isa<Float8E4M3FNType>()
401+
? QuantizedType::getDefaultMaximumForF8E4M3FN()
402+
: std::numeric_limits<int64_t>::min();
403+
404+
if (defaultMin != type.getStorageTypeMin() ||
405+
defaultMax != type.getStorageTypeMax()) {
350406
out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax()
351407
<< ">";
352408
}

mlir/test/Dialect/Quant/parse-any-invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@
1717

1818
// -----
1919
// Unrecognized storage type: illegal prefix
20-
// expected-error@+1 {{illegal storage type prefix}}
20+
// expected-error@+1 {{illegal quantized storage type alias}}
2121
!qalias = !quant.any<int8<-4:3>:f32>
2222

2323
// -----
2424
// Unrecognized storage type: no width
25-
// expected-error@+1 {{illegal storage type prefix}}
25+
// expected-error@+1 {{illegal quantized storage type alias}}
2626
!qalias = !quant.any<i<-4:3>:f32>
2727

2828
// -----

mlir/test/Dialect/Quant/parse-uniform-invalid.mlir

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,12 @@
3737

3838
// -----
3939
// Unrecognized storage type: illegal prefix
40-
// expected-error@+1 {{illegal storage type prefix}}
40+
// expected-error@+1 {{illegal quantized storage type alias}}
4141
!qalias = !quant.uniform<int8<-4:3>:f32, 0.99872:127>
4242

4343
// -----
4444
// Unrecognized storage type: no width
45-
// expected-error@+1 {{illegal storage type prefix}}
45+
// expected-error@+1 {{illegal quantized storage type alias}}
4646
!qalias = !quant.uniform<i<-4:3>:f32, 0.99872:127>
4747

4848
// -----
@@ -52,7 +52,7 @@
5252

5353
// -----
5454
// Unrecognized storage type: storage size < 0
55-
// expected-error@+1 {{illegal storage type prefix}}
55+
// expected-error@+1 {{illegal quantized storage type alias}}
5656
!qalias = !quant.uniform<i-1<-4:3>:f32, 0.99872:127>
5757

5858
// -----
@@ -80,6 +80,26 @@
8080
// expected-error@+1 {{illegal storage type minimum: -9}}
8181
!qalias = !quant.uniform<i4<-9:1>:f32, 0.99872:127>
8282

83+
// -----
84+
// Illegal storage min/max: max > defaultMax
85+
// expected-error@+1 {{illegal storage type maximum: 60000}}
86+
!qalias = !quant.uniform<f8E5M2<-57344:60000>:f32, 0.99872:127>
87+
88+
// -----
89+
// Illegal storage min/max: min < defaultMin
90+
// expected-error@+1 {{illegal storage type minimum: -60000}}
91+
!qalias = !quant.uniform<f8E5M2<-60000:57344>:f32, 0.99872:127>
92+
93+
// -----
94+
// Illegal storage min/max: max > defaultMax
95+
// expected-error@+1 {{illegal storage type maximum: 500}}
96+
!qalias = !quant.uniform<f8E4M3FN<-448:500>:f32, 0.99872:127>
97+
98+
// -----
99+
// Illegal storage min/max: min < defaultMin
100+
// expected-error@+1 {{illegal storage type minimum: -500}}
101+
!qalias = !quant.uniform<f8E4M3FN<-500:448>:f32, 0.99872:127>
102+
83103
// -----
84104
// Illegal uniform params: invalid scale
85105
// expected-error@+1 {{expected floating point literal}}

mlir/test/Dialect/Quant/parse-uniform.mlir

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,33 @@ func.func @parse() -> !qalias {
1919
return %0 : !qalias
2020
}
2121

22+
// -----
23+
// Default min/max value optimization for integers.
24+
// CHECK: !quant.uniform<i8:f32, 9.987200e-01:127>
25+
!qalias = !quant.uniform<i8<-128:127>:f32, 0.99872:127 >
26+
func.func @parse() -> !qalias {
27+
%0 = "foo"() : () -> !qalias
28+
return %0 : !qalias
29+
}
30+
31+
// -----
32+
// Default min/max value optimization for f8E5M2.
33+
// CHECK: !quant.uniform<f8E5M2:f32, 9.987200e-01:127>
34+
!qalias = !quant.uniform<f8E5M2<-57344:57344>:f32, 0.99872:127 >
35+
func.func @parse() -> !qalias {
36+
%0 = "foo"() : () -> !qalias
37+
return %0 : !qalias
38+
}
39+
40+
// -----
41+
// Default min/max value optimization for f8E4M3FN.
42+
// CHECK: !quant.uniform<f8E4M3FN:f32, 9.987200e-01:127>
43+
!qalias = !quant.uniform<f8E4M3FN<-448:448>:f32, 0.99872:127 >
44+
func.func @parse() -> !qalias {
45+
%0 = "foo"() : () -> !qalias
46+
return %0 : !qalias
47+
}
48+
2249
// -----
2350
// Required per-layer params specified:
2451
// [unsigned] storageType, expressedType, scale
@@ -47,6 +74,24 @@ func.func @parse() -> !qalias {
4774
return %0 : !qalias
4875
}
4976

77+
// -----
78+
// Storage type: f8E5M2
79+
// CHECK: !quant.uniform<f8E5M2:f32, 2.000000e+02>
80+
!qalias = !quant.uniform<f8E5M2:f32, 2.0e+2>
81+
func.func @parse() -> !qalias {
82+
%0 = "foo"() : () -> !qalias
83+
return %0 : !qalias
84+
}
85+
86+
// -----
87+
// Storage type: f8E4M3FN
88+
// CHECK: !quant.uniform<f8E4M3FN:f32, 2.000000e+02>
89+
!qalias = !quant.uniform<f8E4M3FN:f32, 2.0e+2>
90+
func.func @parse() -> !qalias {
91+
%0 = "foo"() : () -> !qalias
92+
return %0 : !qalias
93+
}
94+
5095
// -----
5196
// Storage type: i16
5297
// CHECK: !quant.uniform<i16:f32, 2.000000e+02>

0 commit comments

Comments
 (0)