diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake
index a3324705c525..9301ce4a4170 100644
--- a/mlir/cmake/modules/AddMLIR.cmake
+++ b/mlir/cmake/modules/AddMLIR.cmake
@@ -196,6 +196,14 @@ function(add_mlir_interface interface)
   add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
 endfunction()
 
+# Declare a type interface in the include directory
+function(add_mlir_type_interface interface)
+  set(LLVM_TARGET_DEFINITIONS ${interface}.td)
+  mlir_tablegen(${interface}.h.inc -gen-type-interface-decls)
+  mlir_tablegen(${interface}.cpp.inc -gen-type-interface-defs)
+  add_public_tablegen_target(MLIR${interface}IncGen)
+  add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
+endfunction()
 
 # Generate Documentation
 function(add_mlir_doc doc_filename output_file output_directory command)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 93de0919460c..e92ccb17ac11 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -91,6 +91,9 @@ class FloatType : public Type {
 // Tablegen Type Declarations
 //===----------------------------------------------------------------------===//
 
+// Include QuantizationInterface before BuiltinTypes to resolve dependencies
+#include "mlir/IR/QuantizationInterface.h"
+
 #define GET_TYPEDEF_CLASSES
 #include "mlir/IR/BuiltinTypes.h.inc"
 
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 56ec6d97433f..689a854feffb 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -17,6 +17,7 @@
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/BuiltinDialect.td"
 include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/QuantizationInterface.td"
 
 // TODO: Currently the types defined in this file are prefixed with `Builtin_`.
 // This is to differentiate the types here with the ones in OpBase.td. We should
@@ -78,8 +79,8 @@ def Builtin_Complex : Builtin_Type<"Complex", "complex"> {
 //===----------------------------------------------------------------------===//
 
 // Base class for Builtin dialect float types.
-class Builtin_FloatType<string name, string mnemonic>
-    : Builtin_Type<name, mnemonic> {
+class Builtin_FloatType<string name, string mnemonic, list<Trait> traits = []>
+    : Builtin_Type<name, mnemonic, traits> {
   let extraClassDeclaration = [{
     static }] # name # [{Type get(MLIRContext *context);
   }];
@@ -88,7 +89,8 @@ class Builtin_FloatType<string name, string mnemonic>
 //===----------------------------------------------------------------------===//
 // Float8E5M2Type
 
-def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2"> {
+def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2",
+    [QuantizationInterface]> {
   let summary = "8-bit floating point with 2 bit mantissa";
   let description = [{
     An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits
@@ -104,6 +106,23 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2"> {
 
     Described in: https://arxiv.org/abs/2209.05433
   }];
+
+  let extraClassDeclaration = [{
+    static Float8E5M2Type get(MLIRContext *context);
+
+    /// QuantizationInterface method implementations
+    bool isStorageSigned() const { return true; }
+    unsigned getStorageWidth() const { return 8; }
+    int64_t getDefaultMaximum([[maybe_unused]] bool isSigned, [[maybe_unused]] unsigned integralWidth) const {
+      return 57344; // Largest finite f8E5M2 value (1.75 * 2^15).
+    }
+    int64_t getDefaultMinimum(bool isSigned, unsigned integralWidth) const {
+      return -getDefaultMaximum(isSigned, integralWidth);
+    }
+    std::string printStorageType([[maybe_unused]] bool isSigned, [[maybe_unused]] unsigned storageWidth) const {
+      return "f8E5M2";
+    }
+  }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -128,7 +147,8 @@ def Builtin_Float8E4M3 : Builtin_FloatType<"Float8E4M3", "f8E4M3"> {
 //===----------------------------------------------------------------------===//
 // Float8E4M3FNType
 
-def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN"> {
+def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN",
+    [QuantizationInterface]> {
   let summary = "8-bit floating point with 3 bit mantissa";
   let description = [{
     An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
@@ -145,6 +165,23 @@ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN"> {
 
     Described in: https://arxiv.org/abs/2209.05433
   }];
+
+  let extraClassDeclaration = [{
+    static Float8E4M3FNType get(MLIRContext *context);
+
+    /// QuantizationInterface method implementations
+    bool isStorageSigned() const { return true; }
+    unsigned getStorageWidth() const { return 8; }
+    int64_t getDefaultMaximum([[maybe_unused]] bool isSigned, [[maybe_unused]] unsigned integralWidth) const {
+      return 448; // Largest finite f8E4M3FN value (1.75 * 2^8).
+    }
+    int64_t getDefaultMinimum(bool isSigned, unsigned integralWidth) const {
+      return -getDefaultMaximum(isSigned, integralWidth);
+    }
+    std::string printStorageType([[maybe_unused]] bool isSigned, [[maybe_unused]] unsigned storageWidth) const {
+      return "f8E4M3FN";
+    }
+  }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -358,7 +395,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
 // IntegerType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
+def Builtin_Integer : Builtin_Type<"Integer", "integer",
+    [QuantizationInterface]> {
   let summary = "Integer type with arbitrary precision up to a fixed limit";
   let description = [{
     Syntax:
@@ -415,6 +453,25 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
     /// Integer representation maximal bitwidth.
     /// Note: This is aligned with the maximum width of llvm::IntegerType.
     static constexpr unsigned kMaxWidth = (1 << 24) - 1;
+
+    /// QuantizationInterface method implementations
+    bool isStorageSigned() const { return !isUnsigned(); }
+    unsigned getStorageWidth() const { return getWidth(); }
+    int64_t getDefaultMinimum(bool isSigned, unsigned integralWidth) const {
+      if (isSigned) {
+        return llvm::minIntN(integralWidth);
+      }
+      return 0;
+    }
+    int64_t getDefaultMaximum(bool isSigned, unsigned integralWidth) const {
+      if (isSigned) {
+        return llvm::maxIntN(integralWidth);
+      }
+      return llvm::maxUIntN(integralWidth);
+    }
+    std::string printStorageType(bool isSigned, unsigned storageWidth) const {
+      return (isSigned ? "i" : "u") + std::to_string(storageWidth);
+    }
   }];
 }
 
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 04a57d26a068..1bc163c25bfe 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -2,6 +2,8 @@ add_mlir_interface(OpAsmInterface)
 add_mlir_interface(SymbolInterfaces)
 add_mlir_interface(RegionKindInterface)
 
+add_mlir_type_interface(QuantizationInterface)
+
 set(LLVM_TARGET_DEFINITIONS BuiltinAttributes.td)
 mlir_tablegen(BuiltinAttributes.h.inc -gen-attrdef-decls)
 mlir_tablegen(BuiltinAttributes.cpp.inc -gen-attrdef-defs)
diff --git a/mlir/include/mlir/IR/QuantizationInterface.h b/mlir/include/mlir/IR/QuantizationInterface.h
new file mode 100644
index 000000000000..00a3a9484808
--- /dev/null
+++ b/mlir/include/mlir/IR/QuantizationInterface.h
@@ -0,0 +1,23 @@
+//===- QuantizationInterface.h - Quantization Interface ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_QUANTIZATIONINTERFACE_H
+#define MLIR_IR_QUANTIZATIONINTERFACE_H
+
+#include "mlir/IR/Types.h"
+#include <string>
+
+// Forward declarations for the types we need in the implementation
+namespace mlir {
+class IntegerType;
+class FloatType;
+} // namespace mlir
+
+#include "mlir/IR/QuantizationInterface.h.inc"
+
+#endif // MLIR_IR_QUANTIZATIONINTERFACE_H
diff --git a/mlir/include/mlir/IR/QuantizationInterface.td b/mlir/include/mlir/IR/QuantizationInterface.td
new file mode 100644
index 000000000000..9be6753cff2b
--- /dev/null
+++ b/mlir/include/mlir/IR/QuantizationInterface.td
@@ -0,0 +1,45 @@
+#ifndef MLIR_IR_QUANTIZATIONINTERFACE
+#define MLIR_IR_QUANTIZATIONINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def QuantizationInterface : TypeInterface<"QuantizationInterface"> {
+  let description = [{
+    Interface for types that can be used as quantile storage types.
+    This interface provides methods to determine storage characteristics
+    like width and signedness for quantization purposes.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+      Get the storage type width in bits.
+      Returns the number of bits used to store values of this type.
+    }],
+    "unsigned", "getStorageWidth", (ins)>,
+
+    InterfaceMethod<[{
+      Check if the storage type is signed.
+      Returns true if the type represents signed values, false for unsigned.
+    }],
+    "bool", "isStorageSigned", (ins)>,
+
+    InterfaceMethod<[{
+      Get the default minimum value for the storage type.
+    }],
+    "int64_t", "getDefaultMinimum", (ins "bool":$isSigned, "unsigned":$integralWidth)>,
+
+    InterfaceMethod<[{
+      Get the default maximum value for the storage type.
+    }],
+    "int64_t", "getDefaultMaximum", (ins "bool":$isSigned, "unsigned":$integralWidth)>,
+
+    InterfaceMethod<[{
+      Get the name of the storage type.
+    }],
+    "std::string", "printStorageType", (ins "bool":$isSigned, "unsigned":$storageWidth)>
+  ];
+
+}
+
+#endif // MLIR_IR_QUANTIZATIONINTERFACE
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index 5a3500ec4278..7bdbe54f7f0d 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/QuantizationInterface.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/MathExtras.h"
@@ -32,7 +33,6 @@ LogicalResult
 QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
                       unsigned flags, Type storageType, Type expressedType,
                       int64_t storageTypeMin, int64_t storageTypeMax) {
-
   bool isSigned =
       (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
 
@@ -46,16 +46,11 @@ QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
   }
 
   int64_t defaultMin, defaultMax;
-  if (storageType.isa<IntegerType>()) {
-    const auto width = llvm::dyn_cast<IntegerType>(storageType).getWidth();
-    defaultMin = QuantizedType::getDefaultMinimumForInteger(isSigned, width);
-    defaultMax = QuantizedType::getDefaultMaximumForInteger(isSigned, width);
-  } else if (storageType.isa<Float8E5M2Type>()) {
-    defaultMin = QuantizedType::getDefaultMinimumForF8E5M2();
-    defaultMax = QuantizedType::getDefaultMaximumForF8E5M2();
-  } else if (storageType.isa<Float8E4M3FNType>()) {
-    defaultMin = QuantizedType::getDefaultMinimumForF8E4M3FN();
-    defaultMax = QuantizedType::getDefaultMaximumForF8E4M3FN();
+  if (auto quantizationInterface =
+          llvm::dyn_cast<QuantizationInterface>(storageType)) {
+    const auto width = quantizationInterface.getStorageWidth();
+    defaultMin = quantizationInterface.getDefaultMinimum(isSigned, width);
+    defaultMax = quantizationInterface.getDefaultMaximum(isSigned, width);
   } else {
     return emitError() << "illegal storage type, supported types are: integral "
                           "types, Float8E4M3FNType and Float8E5M2Type ";
@@ -85,7 +80,14 @@ int64_t QuantizedType::getStorageTypeMax() const {
 unsigned QuantizedType::getStorageTypeIntegralWidth() const {
   // NOTE: If ever supporting non-integral storage types, some other scheme
   // for determining the width will be needed.
-  return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
+  Type storageType = static_cast<ImplType *>(impl)->storageType;
+
+  if (auto quantizationInterface =
+          llvm::dyn_cast<QuantizationInterface>(storageType)) {
+    return quantizationInterface.getStorageWidth();
+  }
+
+  return storageType.getIntOrFloatBitWidth();
 }
 
 Type QuantizedType::getExpressedType() const {
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index 1fd148dd4736..a2823c1d7533 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Location.h"
+#include "mlir/IR/QuantizationInterface.h"
 #include "mlir/IR/Types.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/Support/Format.h"
@@ -28,17 +29,16 @@ static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) {
   // Parse storage type (alpha_ident, integer_literal).
   StringRef identifier;
   unsigned storageTypeWidth = 0;
+
   OptionalParseResult result = parser.parseOptionalType(type);
   if (result.has_value()) {
     if (!succeeded(*result))
       return nullptr;
-    if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
-      isSigned = !intType.isUnsigned();
-      storageTypeWidth = intType.getWidth();
-    } else if (llvm::dyn_cast<Float8E5M2Type>(type) ||
-               llvm::dyn_cast<Float8E4M3FNType>(type)) {
-      storageTypeWidth = 8;
-      isSigned = true;
+
+    if (auto quantizationInterface =
+            llvm::dyn_cast<QuantizationInterface>(type)) {
+      isSigned = quantizationInterface.isStorageSigned();
+      storageTypeWidth = quantizationInterface.getStorageWidth();
     } else {
       parser.emitError(typeLoc, "illegal quantized storage type alias");
       return nullptr;
@@ -128,16 +128,12 @@ static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType,
                                      bool isSigned, int64_t &storageTypeMin,
                                      int64_t &storageTypeMax) {
   int64_t defaultMin, defaultMax;
-  if (storageType.isa<IntegerType>()) {
-    const auto width = llvm::dyn_cast<IntegerType>(storageType).getWidth();
-    defaultMin = QuantizedType::getDefaultMinimumForInteger(isSigned, width);
-    defaultMax = QuantizedType::getDefaultMaximumForInteger(isSigned, width);
-  } else if (storageType.isa<Float8E5M2Type>()) {
-    defaultMin = QuantizedType::getDefaultMinimumForF8E5M2();
-    defaultMax = QuantizedType::getDefaultMaximumForF8E5M2();
-  } else if (storageType.isa<Float8E4M3FNType>()) {
-    defaultMin = QuantizedType::getDefaultMinimumForF8E4M3FN();
-    defaultMax = QuantizedType::getDefaultMaximumForF8E4M3FN();
+  if (auto quantizationInterface =
+          llvm::dyn_cast<QuantizationInterface>(storageType)) {
+    const auto width = quantizationInterface.getStorageWidth();
+
+    defaultMin = quantizationInterface.getDefaultMinimum(isSigned, width);
+    defaultMax = quantizationInterface.getDefaultMaximum(isSigned, width);
   } else {
     defaultMin = std::numeric_limits<int64_t>::max();
     defaultMax = std::numeric_limits<int64_t>::min();
@@ -484,34 +480,21 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
   // storage type
   unsigned storageWidth = type.getStorageTypeIntegralWidth();
   bool isSigned = type.isSigned();
-  if (type.getStorageType().isa<Float8E5M2Type>()) {
-    out << "f8E5M2";
-  } else if (type.getStorageType().isa<Float8E4M3FNType>()) {
-    out << "f8E4M3FN";
-  } else if (isSigned) {
-    out << "i" << storageWidth;
+  int64_t defaultMin, defaultMax;
+
+  if (auto quantizationInterface =
+          llvm::dyn_cast<QuantizationInterface>(type.getStorageType())) {
+    out << quantizationInterface.printStorageType(isSigned, storageWidth);
+
+    defaultMin =
+        quantizationInterface.getDefaultMinimum(isSigned, storageWidth);
+    defaultMax =
+        quantizationInterface.getDefaultMaximum(isSigned, storageWidth);
+
   } else {
-    out << "u" << storageWidth;
-  }
-
-  // storageTypeMin and storageTypeMax if not default.
-  int64_t defaultMin =
-      type.getStorageType().isa<IntegerType>()
-          ? QuantizedType::getDefaultMinimumForInteger(isSigned, storageWidth)
-      : type.getStorageType().isa<Float8E5M2Type>()
-          ? QuantizedType::getDefaultMinimumForF8E5M2()
-      : type.getStorageType().isa<Float8E4M3FNType>()
-          ? QuantizedType::getDefaultMinimumForF8E4M3FN()
-          : std::numeric_limits<int64_t>::max();
-
-  int64_t defaultMax =
-      type.getStorageType().isa<IntegerType>()
-          ? QuantizedType::getDefaultMaximumForInteger(isSigned, storageWidth)
-      : type.getStorageType().isa<Float8E5M2Type>()
-          ? QuantizedType::getDefaultMaximumForF8E5M2()
-      : type.getStorageType().isa<Float8E4M3FNType>()
-          ? QuantizedType::getDefaultMaximumForF8E4M3FN()
-          : std::numeric_limits<int64_t>::min();
+    defaultMin = std::numeric_limits<int64_t>::max();
+    defaultMax = std::numeric_limits<int64_t>::min();
+  }
 
   if (defaultMin != type.getStorageTypeMin() ||
       defaultMax != type.getStorageTypeMax()) {
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index c38ce6c058a0..e8aa2b83b380 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -30,6 +30,7 @@ add_mlir_library(MLIRIR
   Operation.cpp
   OperationSupport.cpp
   PatternMatch.cpp
+  QuantizationInterface.cpp
   Region.cpp
   RegionKindInterface.cpp
   SymbolTable.cpp
@@ -64,6 +65,7 @@ add_mlir_library(MLIRIR
   MLIRSideEffectInterfacesIncGen
   MLIRSymbolInterfacesIncGen
   MLIRTensorEncodingIncGen
+  MLIRQuantizationInterfaceIncGen
 
   LINK_LIBS PUBLIC
   MLIRSupport
diff --git a/mlir/lib/IR/QuantizationInterface.cpp b/mlir/lib/IR/QuantizationInterface.cpp
new file mode 100644
index 000000000000..fba89183da59
--- /dev/null
+++ b/mlir/lib/IR/QuantizationInterface.cpp
@@ -0,0 +1,21 @@
+//===- QuantizationInterface.cpp - Quantization Interface -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/QuantizationInterface.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "llvm/ADT/Sequence.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+//===----------------------------------------------------------------------===//
+/// Tablegen Interface Definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/QuantizationInterface.cpp.inc"