-
Notifications
You must be signed in to change notification settings - Fork 34
Extending UniformQuantizedType with interface-based support for new storage types in Quant dialect #149
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: npu/release/19.x
Are you sure you want to change the base?
Extending UniformQuantizedType with interface-based support for new storage types in Quant dialect #149
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
| include "mlir/IR/AttrTypeBase.td" | ||
| include "mlir/IR/BuiltinDialect.td" | ||
| include "mlir/IR/BuiltinTypeInterfaces.td" | ||
| include "mlir/IR/QuantizationInterface.td" | ||
|
|
||
| // TODO: Currently the types defined in this file are prefixed with `Builtin_`. | ||
| // This is to differentiate the types here with the ones in OpBase.td. We should | ||
|
|
@@ -78,8 +79,8 @@ def Builtin_Complex : Builtin_Type<"Complex", "complex"> { | |
| //===----------------------------------------------------------------------===// | ||
|
|
||
| // Base class for Builtin dialect float types. | ||
| class Builtin_FloatType<string name, string mnemonic> | ||
| : Builtin_Type<name, mnemonic, /*traits=*/[], "::mlir::FloatType"> { | ||
| class Builtin_FloatType<string name, string mnemonic, list<Trait> traits = []> | ||
| : Builtin_Type<name, mnemonic, traits, "::mlir::FloatType"> { | ||
| let extraClassDeclaration = [{ | ||
| static }] # name # [{Type get(MLIRContext *context); | ||
| }]; | ||
|
|
@@ -88,7 +89,8 @@ class Builtin_FloatType<string name, string mnemonic> | |
| //===----------------------------------------------------------------------===// | ||
| // Float8E5M2Type | ||
|
|
||
| def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2"> { | ||
| def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2", | ||
| [QuantizationInterface]> { | ||
| let summary = "8-bit floating point with 2 bit mantissa"; | ||
| let description = [{ | ||
| An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits | ||
|
|
@@ -104,6 +106,23 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2"> { | |
|
|
||
| Described in: https://arxiv.org/abs/2209.05433 | ||
| }]; | ||
|
|
||
| let extraClassDeclaration = [{ | ||
| static Float8E5M2Type get(MLIRContext *context); | ||
|
|
||
| /// QuantizationInterface method implementations | ||
| /// f8E5M2 carries a sign bit, so its storage is always reported as signed. | ||
| bool isStorageSigned() const { return true; } | ||
| /// Storage width in bits; fixed at 8 for this 8-bit float format. | ||
| unsigned getStorageWidth() const { return 8; } | ||
| /// Largest finite value representable in f8E5M2: 1.75 * 2^15 = 57344. | ||
| /// Both arguments are ignored because the range is fixed by the float | ||
| /// format itself rather than by a caller-chosen signedness or width. | ||
| int64_t getDefaultMaximum([[maybe_unused]] bool isSigned, [[maybe_unused]] unsigned integralWidth) const { | ||
| return 57344; | ||
| } | ||
| /// Default minimum, defined as the negation of getDefaultMaximum() so the | ||
| /// default range is symmetric around zero. | ||
| int64_t getDefaultMinimum(bool isSigned, unsigned integralWidth) const { | ||
| return -getDefaultMaximum(isSigned, integralWidth); | ||
| } | ||
| std::string printStorageType([[maybe_unused]] bool isSigned, [[maybe_unused]] unsigned storageWidth) const { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not sure we should have a print method here. Isn't there a more canonical way to stringify the type name? |
||
| return "f8E5M2"; | ||
| } | ||
| }]; | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
|
|
@@ -128,7 +147,8 @@ def Builtin_Float8E4M3 : Builtin_FloatType<"Float8E4M3", "f8E4M3"> { | |
| //===----------------------------------------------------------------------===// | ||
| // Float8E4M3FNType | ||
|
|
||
| def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN"> { | ||
| def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN", | ||
| [QuantizationInterface]> { | ||
| let summary = "8-bit floating point with 3 bit mantissa"; | ||
| let description = [{ | ||
| An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits | ||
|
|
@@ -145,6 +165,23 @@ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN"> { | |
|
|
||
| Described in: https://arxiv.org/abs/2209.05433 | ||
| }]; | ||
|
|
||
| let extraClassDeclaration = [{ | ||
| static Float8E4M3FNType get(MLIRContext *context); | ||
|
|
||
| /// QuantizationInterface method implementations | ||
| /// f8E4M3FN carries a sign bit, so its storage is always reported as signed. | ||
| bool isStorageSigned() const { return true; } | ||
| /// Storage width in bits; fixed at 8 for this 8-bit float format. | ||
| unsigned getStorageWidth() const { return 8; } | ||
| /// Largest finite value representable in f8E4M3FN: 1.75 * 2^8 = 448. | ||
| /// Both arguments are ignored because the range is fixed by the float | ||
| /// format itself rather than by a caller-chosen signedness or width. | ||
| int64_t getDefaultMaximum([[maybe_unused]] bool isSigned, [[maybe_unused]] unsigned integralWidth) const { | ||
| return 448; | ||
| } | ||
| /// Default minimum, defined as the negation of getDefaultMaximum() so the | ||
| /// default range is symmetric around zero. | ||
| int64_t getDefaultMinimum(bool isSigned, unsigned integralWidth) const{ | ||
| return -getDefaultMaximum(isSigned, integralWidth); | ||
| } | ||
| /// Returns the textual name of this storage type. The name is a fixed | ||
| /// string, so both arguments are ignored. | ||
| std::string printStorageType([[maybe_unused]] bool isSigned, [[maybe_unused]] unsigned storageWidth) const { | ||
| return "f8E4M3FN"; | ||
| } | ||
| }]; | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
|
|
@@ -358,7 +395,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> { | |
| // IntegerType | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| def Builtin_Integer : Builtin_Type<"Integer", "integer"> { | ||
| def Builtin_Integer : Builtin_Type<"Integer", "integer", | ||
| [QuantizationInterface]> { | ||
| let summary = "Integer type with arbitrary precision up to a fixed limit"; | ||
| let description = [{ | ||
| Syntax: | ||
|
|
@@ -415,6 +453,25 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> { | |
| /// Integer representation maximal bitwidth. | ||
| /// Note: This is aligned with the maximum width of llvm::IntegerType. | ||
| static constexpr unsigned kMaxWidth = (1 << 24) - 1; | ||
|
|
||
| /// QuantizationInterface method implementations | ||
| /// Reports signed storage unless the integer type is explicitly unsigned | ||
| /// (isUnsigned()); presumably signless integers therefore count as signed. | ||
| bool isStorageSigned() const { return !isUnsigned(); } | ||
| /// Storage width in bits; simply the integer type's own width (getWidth()). | ||
| unsigned getStorageWidth() const { return getWidth(); } | ||
| /// Smallest default value for integer storage: the minimum of an | ||
| /// `integralWidth`-bit signed integer (llvm::minIntN) when `isSigned` is | ||
| /// set, and zero for unsigned storage. | ||
| int64_t getDefaultMinimum(bool isSigned, unsigned integralWidth) const { | ||
| return isSigned ? llvm::minIntN(integralWidth) : 0; | ||
| } | ||
| /// Largest default value for integer storage of `integralWidth` bits: | ||
| /// llvm::maxIntN when `isSigned` is set, llvm::maxUIntN otherwise. | ||
| /// NOTE(review): the unsigned maximum is returned through int64_t; | ||
| /// presumably callers keep the width small enough for it to fit -- confirm. | ||
| int64_t getDefaultMaximum(bool isSigned, unsigned integralWidth) const { | ||
| if (isSigned) { | ||
| return llvm::maxIntN(integralWidth); | ||
| } | ||
| return llvm::maxUIntN(integralWidth); | ||
| } | ||
| std::string printStorageType(bool isSigned, unsigned storageWidth) const { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we need to pass these arguments from outside? |
||
| return (isSigned ? "i" : "u") + std::to_string(storageWidth); | ||
| } | ||
| }]; | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| //===- QuantizationInterface.h - Quantization type interface ----*- C++ -*-===// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
| // | ||
| // Declares QuantizationInterface, the type interface whose class definition | ||
| // is pulled in from QuantizationInterface.h.inc at the bottom of this file. | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #ifndef MLIR_IR_QUANTIZATIONINTERFACE_H | ||
| #define MLIR_IR_QUANTIZATIONINTERFACE_H | ||
|
|
||
| #include "mlir/IR/Types.h" | ||
|
|
||
| // The interface methods return int64_t and std::string; include what we use. | ||
| #include <cstdint> | ||
| #include <string> | ||
|
|
||
| // Forward declarations for the types we need in the implementation | ||
| namespace mlir { | ||
| class IntegerType; | ||
| class FloatType; | ||
| } // namespace mlir | ||
|
|
||
| #include "mlir/IR/QuantizationInterface.h.inc" | ||
|
|
||
| #endif // MLIR_IR_QUANTIZATIONINTERFACE_H |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| #ifndef MLIR_IR_QUANTIZATIONINTERFACE | ||
| #define MLIR_IR_QUANTIZATIONINTERFACE | ||
|
|
||
| include "mlir/IR/OpBase.td" | ||
|
|
||
| def QuantizationInterface : TypeInterface<"QuantizationInterface"> { | ||
| let description = [{ | ||
| Interface for types that can be used as storage types of quantized types. | ||
| This interface provides methods to determine storage characteristics | ||
| like width and signedness for quantization purposes. | ||
| }]; | ||
| let cppNamespace = "::mlir"; | ||
|
|
||
| let methods = [ | ||
| InterfaceMethod<[{ | ||
| Get the storage type width in bits. | ||
| Returns the number of bits used to store values of this type. | ||
| }], | ||
| "unsigned", "getStorageWidth", (ins)>, | ||
|
|
||
| InterfaceMethod<[{ | ||
| Check if the storage type is signed. | ||
| Returns true if the type represents signed values, false for unsigned. | ||
| }], | ||
| "bool", "isStorageSigned", (ins)>, | ||
|
|
||
| InterfaceMethod<[{ | ||
| Get the default minimum value for the storage type. | ||
| }], | ||
| "int64_t", "getDefaultMinimum", (ins "bool":$isSigned, "unsigned":$integralWidth)>, | ||
|
|
||
| InterfaceMethod<[{ | ||
| Get the default maximum value for the storage type. | ||
| }], | ||
| "int64_t", "getDefaultMaximum", (ins "bool":$isSigned, "unsigned":$integralWidth)>, | ||
|
|
||
| InterfaceMethod<[{ | ||
| Get the name of the storage type. | ||
| }], | ||
| "std::string", "printStorageType", (ins "bool":$isSigned, "unsigned":$storageWidth)> | ||
| ]; | ||
|
|
||
| } | ||
|
|
||
| #endif // MLIR_IR_QUANTIZATIONINTERFACE |
Uh oh!
There was an error while loading. Please reload this page.