1717include "mlir/IR/AttrTypeBase.td"
1818include "mlir/IR/BuiltinDialect.td"
1919include "mlir/IR/BuiltinTypeInterfaces.td"
20+ include "mlir/IR/QuantizationInterface.td"
2021
2122// TODO: Currently the types defined in this file are prefixed with `Builtin_`.
2223// This is to differentiate the types here with the ones in OpBase.td. We should
@@ -78,8 +79,8 @@ def Builtin_Complex : Builtin_Type<"Complex", "complex"> {
7879//===----------------------------------------------------------------------===//
7980
8081// Base class for Builtin dialect float types.
81- class Builtin_FloatType<string name, string mnemonic>
82- : Builtin_Type<name, mnemonic, /* traits=*/[] , "::mlir::FloatType"> {
82+ class Builtin_FloatType<string name, string mnemonic, list<Trait> traits = [] >
83+ : Builtin_Type<name, mnemonic, traits, "::mlir::FloatType"> {
8384 let extraClassDeclaration = [{
8485 static }] # name # [{Type get(MLIRContext *context);
8586 }];
@@ -88,7 +89,8 @@ class Builtin_FloatType<string name, string mnemonic>
8889//===----------------------------------------------------------------------===//
8990// Float8E5M2Type
9091
91- def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2"> {
92+ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2",
93+ [QuantizationInterface]> {
9294 let summary = "8-bit floating point with 2 bit mantissa";
9395 let description = [{
9496 An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits
@@ -104,6 +106,23 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2"> {
104106
105107 Described in: https://arxiv.org/abs/2209.05433
106108 }];
109+
110+ let extraClassDeclaration = [{
111+ static Float8E5M2Type get(MLIRContext *context);
112+
113+ /// QuantizationInterface method implementations
114+ bool isStorageSigned() const { return true; }
115+ unsigned getStorageWidth() const { return 8; }
116+ int64_t getDefaultMaximum([[maybe_unused]] bool isSigned, [[maybe_unused]] unsigned integralWidth) const {
117+ return 448;
118+ }
119+ int64_t getDefaultMinimum(bool isSigned, unsigned integralWidth) const {
120+ return -getDefaultMaximum(isSigned, integralWidth);
121+ }
122+ std::string printStorageType([[maybe_unused]] bool isSigned, [[maybe_unused]] unsigned storageWidth) const {
123+ return "f8E5M2";
124+ }
125+ }];
107126}
108127
109128//===----------------------------------------------------------------------===//
@@ -128,7 +147,8 @@ def Builtin_Float8E4M3 : Builtin_FloatType<"Float8E4M3", "f8E4M3"> {
128147//===----------------------------------------------------------------------===//
129148// Float8E4M3FNType
130149
131- def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN"> {
150+ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN",
151+ [QuantizationInterface]> {
132152 let summary = "8-bit floating point with 3 bit mantissa";
133153 let description = [{
134154 An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
@@ -145,6 +165,23 @@ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN"> {
145165
146166 Described in: https://arxiv.org/abs/2209.05433
147167 }];
168+
169+ let extraClassDeclaration = [{
170+ static Float8E4M3FNType get(MLIRContext *context);
171+
172+ /// QuantizationInterface method implementations
173+ bool isStorageSigned() const { return true; }
174+ unsigned getStorageWidth() const { return 8; }
175+ int64_t getDefaultMaximum([[maybe_unused]] bool isSigned, [[maybe_unused]] unsigned integralWidth) const {
176+ return 57344;
177+ }
178+ int64_t getDefaultMinimum(bool isSigned, unsigned integralWidth) const{
179+ return -getDefaultMaximum(isSigned, integralWidth);
180+ }
181+ std::string printStorageType([[maybe_unused]] bool isSigned, [[maybe_unused]] unsigned storageWidth) const {
182+ return "f8E4M3FN";
183+ }
184+ }];
148185}
149186
150187//===----------------------------------------------------------------------===//
@@ -358,7 +395,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
358395// IntegerType
359396//===----------------------------------------------------------------------===//
360397
361- def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
398+ def Builtin_Integer : Builtin_Type<"Integer", "integer",
399+ [QuantizationInterface]> {
362400 let summary = "Integer type with arbitrary precision up to a fixed limit";
363401 let description = [{
364402 Syntax:
@@ -415,6 +453,25 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
415453 /// Integer representation maximal bitwidth.
416454 /// Note: This is aligned with the maximum width of llvm::IntegerType.
417455 static constexpr unsigned kMaxWidth = (1 << 24) - 1;
456+
457+ /// QuantizationInterface method implementations
458+ bool isStorageSigned() const { return !isUnsigned(); }
459+ unsigned getStorageWidth() const { return getWidth(); }
460+ int64_t getDefaultMinimum(bool isSigned, unsigned integralWidth) const {
461+ if (isSigned) {
462+ return llvm::minIntN(integralWidth);
463+ }
464+ return 0;
465+ }
466+ int64_t getDefaultMaximum(bool isSigned, unsigned integralWidth) const {
467+ if (isSigned) {
468+ return llvm::maxIntN(integralWidth);
469+ }
470+ return llvm::maxUIntN(integralWidth);
471+ }
472+ std::string printStorageType(bool isSigned, unsigned storageWidth) const {
473+ return (isSigned ? "i" : "u") + std::to_string(storageWidth);
474+ }
418475 }];
419476}
420477
0 commit comments