From 90a7378cb60fa8150028ea581f75f43f9c334213 Mon Sep 17 00:00:00 2001 From: Luca Sarti Date: Fri, 20 Sep 2024 17:14:18 +0000 Subject: [PATCH 1/5] Extending QuantileQuantizedType and QuantileQuantizedPerAxisType with quantileType mlir::Type member --- .../Dialect/Quant/QuantDialectBytecode.td | 4 +- .../mlir/Dialect/Quant/QuantOpsBase.td | 20 +++ mlir/include/mlir/Dialect/Quant/QuantTypes.h | 61 +++++---- mlir/lib/Dialect/Quant/IR/QuantTypes.cpp | 92 ++++++++------ mlir/lib/Dialect/Quant/IR/TypeDetail.h | 39 ++++-- mlir/lib/Dialect/Quant/IR/TypeParser.cpp | 116 +++++++++++++++--- mlir/test/Dialect/Quant/Bytecode/types.mlir | 16 ++- .../Dialect/Quant/parse-quantile-invalid.mlir | 84 ++++++++----- mlir/test/Dialect/Quant/parse-quantile.mlir | 72 +++++------ 9 files changed, 344 insertions(+), 160 deletions(-) diff --git a/mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td b/mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td index 4ccfa08d8fe9..6c1e2b01f4ca 100644 --- a/mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td +++ b/mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td @@ -84,6 +84,7 @@ def UniformQuantizedPerAxisType: DialectType<(type def QuantileQuantizedType: DialectType<(type VarInt:$flags, Type:$storageType, + Type:$quantileType, Type:$expressedType, Array:$quantiles, DoubleAPFloat:$scale, @@ -95,6 +96,7 @@ def QuantileQuantizedType: DialectType<(type def QuantileQuantizedPerAxisType: DialectType<(type VarInt:$flags, Type:$storageType, + Type:$quantileType, Type:$expressedType, VarInt:$quantizedDimension, SignedVarInt:$storageTypeMin, @@ -105,7 +107,7 @@ def QuantileQuantizedPerAxisType: DialectType<(type )> { // Note: builder order differs from bytecode. let cBuilder = [{ - get<$_resultType>(context, flags, storageType, expressedType, quantiles, scales, + get<$_resultType>(context, flags, storageType, quantileType, expressedType, quantiles, scales, zeroPoints, quantizedDimension, storageTypeMin, storageTypeMax) }]; } diff --git a/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td b/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td index 4b6bc7c910ab..820219c1ed17 100644 --- a/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td +++ b/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td @@ -67,18 +67,38 @@ def quant_UniformQuantizedType : CPred<"::llvm::isa($_self)">, "UniformQuantizedType">; +// An implementation of UniformQuantizedPerAxisType. +def quant_UniformQuantizedPerAxisType : + DialectType($_self)">, + "UniformQuantizedPerAxisType">; + // An implementation of QuantileQuantizedType. def quant_QuantileQuantizedType : DialectType($_self)">, "QuantileQuantizedType">; +// An implementation of QuantileQuantizedPerAxisType. +def quant_QuantileQuantizedPerAxisType : + DialectType($_self)">, + "QuantileQuantizedPerAxisType">; + // Predicate for detecting a container or primitive of UniformQuantizedType. def quant_UniformQuantizedValueType : quant_TypedPrimitiveOrContainer; +// Predicate for detecting a container or primitive of UniformQuantizedPerAxisType. +def quant_UniformQuantizedPerAxisValueType : + quant_TypedPrimitiveOrContainer; + // Predicate for detecting a container or primitive of QuantileQuantizedType. def quant_QuantileQuantizedValueType : quant_TypedPrimitiveOrContainer; +// Predicate for detecting a container or primitive of QuantileQuantizedPerAxisType. +def quant_QuantileQuantizedPerAxisValueType : + quant_TypedPrimitiveOrContainer; + #endif // DIALECT_QUANT_QUANT_OPS_BASE_ diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h index 02eed19ddbe4..a53a342fe52a 100644 --- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h +++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h @@ -16,6 +16,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Types.h" #include "llvm/Support/MathExtras.h" +#include namespace mlir { namespace quant { @@ -397,11 +398,12 @@ class UniformQuantizedPerAxisType /// /// Syntax synopsis: /// Per-layer, all parameters expressed: -/// !quant +/// !quant /// Per-layer, optional parameters omitted: -/// !quant +/// !quant /// /// StorageType: 'i'|'u' NumBits +/// QuantileType: 'i'|'u' NumBits, 'hf8', 'bf8', 'f16', 'bf16', 'f32', 'f64' /// ExpressedType: 'f16', 'f32', 'bf16', 'f64' /// Quantiles: Quantile+ /// Quantile: A legal double value @@ -419,23 +421,30 @@ class QuantileQuantizedType /// Gets an instance of the type with all parameters specified but not /// checked. static QuantileQuantizedType get(unsigned flags, Type storageType, - Type expressedType, + Type quantileType, Type expressedType, ArrayRef quantiles, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax); static QuantileQuantizedType getChecked(function_ref emitError, unsigned flags, - Type storageType, Type expressedType, ArrayRef quantiles, - double scale, int64_t zeroPoint, int64_t storageTypeMin, - int64_t storageTypeMax); + Type storageType, Type quantileType, Type expressedType, + ArrayRef quantiles, double scale, int64_t zeroPoint, + int64_t storageTypeMin, int64_t storageTypeMax); /// Verifies construction invariants and issues errors/warnings. static LogicalResult verify(function_ref emitError, unsigned flags, Type storageType, - Type expressedType, ArrayRef quantiles, - double scale, int64_t zeroPoint, - int64_t storageTypeMin, int64_t storageTypeMax); + Type quantileType, Type expressedType, + ArrayRef quantiles, double scale, + int64_t zeroPoint, int64_t storageTypeMin, + int64_t storageTypeMax); + + /// Gets the quantileType + Type getQuantileType() const; + + /// Gets the quantileType bit width + unsigned getQuantileTypeIntegralWidth() const; /// Gets the quantile values ArrayRef getQuantiles() const; @@ -453,11 +462,12 @@ class QuantileQuantizedType /// /// Syntax synopsis: /// Per-axis, all parameters expressed: -/// !quant +/// !quant /// Per-axis, optional parameters omitted: -/// !quant +/// !quant /// /// StorageType: 'i'|'u' NumBits +/// QuantileType: 'i'|'u' NumBits, 'hf8', 'bf8', 'f16', 'bf16', 'f32', 'f64' /// ExpressedType: 'f16', 'f32', 'bf16', 'f64' /// QuantizedDim: An integer value /// Quantiles: Quantile+ @@ -478,7 +488,7 @@ class QuantileQuantizedPerAxisType /// Gets an instance of the type with all parameters specified but not /// checked. static QuantileQuantizedPerAxisType - get(unsigned flags, Type storageType, Type expressedType, + get(unsigned flags, Type storageType, Type quantileType, Type expressedType, ArrayRef quantiles, ArrayRef scales, ArrayRef zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax); @@ -487,19 +497,24 @@ class QuantileQuantizedPerAxisType /// Returns a nullptr convertible type on failure. static QuantileQuantizedPerAxisType getChecked(function_ref emitError, unsigned flags, - Type storageType, Type expressedType, ArrayRef quantiles, - ArrayRef scales, ArrayRef zeroPoints, - int32_t quantizedDimension, int64_t storageTypeMin, - int64_t storageTypeMax); + Type storageType, Type quantileType, Type expressedType, + ArrayRef quantiles, ArrayRef scales, + ArrayRef zeroPoints, int32_t quantizedDimension, + int64_t storageTypeMin, int64_t storageTypeMax); /// Verifies construction invariants and issues errors/warnings. - static LogicalResult verify(function_ref emitError, - unsigned flags, Type storageType, - Type expressedType, ArrayRef quantiles, - ArrayRef scales, - ArrayRef zeroPoints, - int32_t quantizedDimension, - int64_t storageTypeMin, int64_t storageTypeMax); + static LogicalResult + verify(function_ref emitError, unsigned flags, + Type storageType, Type quantileType, Type expressedType, + ArrayRef quantiles, ArrayRef scales, + ArrayRef zeroPoints, int32_t quantizedDimension, + int64_t storageTypeMin, int64_t storageTypeMax); + + /// Gets the quantileType + Type getQuantileType() const; + + /// Gets the quantileType bit width + unsigned getQuantileTypeIntegralWidth() const; /// Gets the quantile values ArrayRef getQuantiles() const; diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp index b7884a9631e5..23f1a10e49a4 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp @@ -379,36 +379,35 @@ int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const { } QuantileQuantizedType -QuantileQuantizedType::get(unsigned flags, Type storageType, Type expressedType, - ArrayRef quantiles, double scale, - int64_t zeroPoint, int64_t storageTypeMin, - int64_t storageTypeMax) { - return Base::get(storageType.getContext(), flags, storageType, expressedType, - quantiles, scale, zeroPoint, storageTypeMin, storageTypeMax); +QuantileQuantizedType::get(unsigned flags, Type storageType, Type quantileType, + Type expressedType, ArrayRef quantiles, + double scale, int64_t zeroPoint, + int64_t storageTypeMin, int64_t storageTypeMax) { + return Base::get(storageType.getContext(), flags, storageType, quantileType, + expressedType, quantiles, scale, zeroPoint, storageTypeMin, + storageTypeMax); } QuantileQuantizedType QuantileQuantizedType::getChecked( function_ref emitError, unsigned flags, - Type storageType, Type expressedType, ArrayRef quantiles, - double scale, int64_t zeroPoint, int64_t storageTypeMin, - int64_t storageTypeMax) { + Type storageType, Type quantileType, Type expressedType, + ArrayRef quantiles, double scale, int64_t zeroPoint, + int64_t storageTypeMin, int64_t storageTypeMax) { return Base::getChecked(emitError, storageType.getContext(), flags, - storageType, expressedType, quantiles, scale, - zeroPoint, storageTypeMin, storageTypeMax); + storageType, quantileType, expressedType, quantiles, + scale, zeroPoint, storageTypeMin, storageTypeMax); } -LogicalResult -QuantileQuantizedType::verify(function_ref emitError, - unsigned flags, Type storageType, - Type expressedType, ArrayRef quantiles, - double scale, int64_t zeroPoint, - int64_t storageTypeMin, int64_t storageTypeMax) { +LogicalResult QuantileQuantizedType::verify( + function_ref emitError, unsigned flags, + Type storageType, Type quantileType, Type expressedType, + ArrayRef quantiles, double scale, int64_t zeroPoint, + int64_t storageTypeMin, int64_t storageTypeMax) { if (failed(UniformQuantizedType::verify(emitError, flags, storageType, expressedType, scale, zeroPoint, storageTypeMin, storageTypeMax))) { return failure(); } - const auto quantileArraySize = quantiles.size(); unsigned typeWidth{}; if (storageType.isa()) { typeWidth = llvm::dyn_cast(storageType).getWidth(); @@ -421,10 +420,17 @@ QuantileQuantizedType::verify(function_ref emitError, "types, Float8E4M3FNType and Float8E5M2Type "; } - const size_t expectedSize = 1 << typeWidth; + const size_t storageTypeRange = storageTypeMax - storageTypeMin + 1; + const size_t typeWidthSize = 1 << typeWidth; + const size_t expectedSize = + (storageTypeRange < typeWidthSize) ? storageTypeRange : typeWidthSize; + + const auto quantileArraySize = quantiles.size(); if (quantileArraySize != expectedSize) { return emitError() << "quantiles array size needs to be equal to " - "2^(bit_size(storageType)), expected: " + "2^(bit_size(storageType)), or (storageTypeMax - " + "storageTypeMin + 1) when max and min differ from " + "the type limits; expected: " << expectedSize << ", found: " << quantileArraySize; } @@ -438,38 +444,46 @@ QuantileQuantizedType::verify(function_ref emitError, return success(); } +Type QuantileQuantizedType::getQuantileType() const { + return getImpl()->quantileType; +} + +unsigned QuantileQuantizedType::getQuantileTypeIntegralWidth() const { + return getImpl()->getQuantileType().getIntOrFloatBitWidth(); +} + ArrayRef QuantileQuantizedType::getQuantiles() const { return getImpl()->getQuantiles(); } QuantileQuantizedPerAxisType QuantileQuantizedPerAxisType::get( - unsigned flags, Type storageType, Type expressedType, + unsigned flags, Type storageType, Type quantileType, Type expressedType, ArrayRef quantiles, ArrayRef scales, ArrayRef zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax) { - return Base::get(storageType.getContext(), flags, storageType, expressedType, - quantiles, scales, zeroPoints, quantizedDimension, - storageTypeMin, storageTypeMax); + return Base::get(storageType.getContext(), flags, storageType, quantileType, + expressedType, quantiles, scales, zeroPoints, + quantizedDimension, storageTypeMin, storageTypeMax); } QuantileQuantizedPerAxisType QuantileQuantizedPerAxisType::getChecked( function_ref emitError, unsigned flags, - Type storageType, Type expressedType, ArrayRef quantiles, - ArrayRef scales, ArrayRef zeroPoints, - int32_t quantizedDimension, int64_t storageTypeMin, - int64_t storageTypeMax) { + Type storageType, Type quantileType, Type expressedType, + ArrayRef quantiles, ArrayRef scales, + ArrayRef zeroPoints, int32_t quantizedDimension, + int64_t storageTypeMin, int64_t storageTypeMax) { return Base::getChecked(emitError, storageType.getContext(), flags, - storageType, expressedType, quantiles, scales, - zeroPoints, quantizedDimension, storageTypeMin, - storageTypeMax); + storageType, quantileType, expressedType, quantiles, + scales, zeroPoints, quantizedDimension, + storageTypeMin, storageTypeMax); } LogicalResult QuantileQuantizedPerAxisType::verify( function_ref emitError, unsigned flags, - Type storageType, Type expressedType, ArrayRef quantiles, - ArrayRef scales, ArrayRef zeroPoints, - int32_t quantizedDimension, int64_t storageTypeMin, - int64_t storageTypeMax) { + Type storageType, Type quantileType, Type expressedType, + ArrayRef quantiles, ArrayRef scales, + ArrayRef zeroPoints, int32_t quantizedDimension, + int64_t storageTypeMin, int64_t storageTypeMax) { if (failed(UniformQuantizedPerAxisType::verify( emitError, flags, storageType, expressedType, scales, zeroPoints, quantizedDimension, storageTypeMin, storageTypeMax))) { @@ -506,6 +520,14 @@ LogicalResult QuantileQuantizedPerAxisType::verify( return success(); } +Type QuantileQuantizedPerAxisType::getQuantileType() const { + return getImpl()->quantileType; +} + +unsigned QuantileQuantizedPerAxisType::getQuantileTypeIntegralWidth() const { + return getImpl()->getQuantileType().getIntOrFloatBitWidth(); +} + ArrayRef QuantileQuantizedPerAxisType::getQuantiles() const { return getImpl()->getQuantiles(); } diff --git a/mlir/lib/Dialect/Quant/IR/TypeDetail.h b/mlir/lib/Dialect/Quant/IR/TypeDetail.h index fbb71448655a..dc578429734b 100644 --- a/mlir/lib/Dialect/Quant/IR/TypeDetail.h +++ b/mlir/lib/Dialect/Quant/IR/TypeDetail.h @@ -255,15 +255,17 @@ struct UniformQuantizedPerAxisTypeStorage : public QuantizedTypeStorage { struct QuantileQuantizedTypeStorage : public UniformQuantizedTypeStorage { struct KeyTy : public UniformQuantizedTypeStorage::KeyTy { - KeyTy(unsigned flags, Type storageType, Type expressedType, - ArrayRef quantiles, double scale, int64_t zeroPoint, - int64_t storageTypeMin, int64_t storageTypeMax) + KeyTy(unsigned flags, Type storageType, Type quantileType, + Type expressedType, ArrayRef quantiles, double scale, + int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax) : UniformQuantizedTypeStorage::KeyTy(flags, storageType, expressedType, scale, zeroPoint, storageTypeMin, storageTypeMax), - quantiles(quantiles) {} + quantileType(quantileType), quantiles(quantiles) {} + Type quantileType; ArrayRef quantiles; + Type getQuantileType() const { return quantileType; } ArrayRef getQuantiles() const { return quantiles; } // Check for equality of two structures that share KeyTy data members @@ -271,6 +273,7 @@ struct QuantileQuantizedTypeStorage : public UniformQuantizedTypeStorage { template static bool genericIsEqual(const T &lhs, const U &rhs) { return UniformQuantizedTypeStorage::KeyTy::genericIsEqual(lhs, rhs) && + lhs.getQuantileType() == rhs.getQuantileType() && lhs.getQuantiles() == rhs.getQuantiles(); } @@ -283,14 +286,15 @@ struct QuantileQuantizedTypeStorage : public UniformQuantizedTypeStorage { int64_t *quantilesCast = llvm::bit_cast(quantiles.data()); ArrayRef quantilesBits(quantilesCast, quantiles.size()); return llvm::hash_combine( - flags, storageType, expressedType, + flags, storageType, quantileType, expressedType, llvm::hash_combine_range(quantilesBits.begin(), quantilesBits.end()), scaleBits, zeroPoint, storageTypeMin, storageTypeMax); } }; QuantileQuantizedTypeStorage(const KeyTy &key, ArrayRef quantiles) - : UniformQuantizedTypeStorage(key), quantilesElements(quantiles.data()), + : UniformQuantizedTypeStorage(key), quantileType(key.getQuantileType()), + quantilesElements(quantiles.data()), quantilesParamsSize(quantiles.size()) {} bool operator==(const KeyTy &key) const { @@ -307,10 +311,13 @@ struct QuantileQuantizedTypeStorage : public UniformQuantizedTypeStorage { static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } + Type getQuantileType() const { return quantileType; } + ArrayRef getQuantiles() const { return ArrayRef(quantilesElements, quantilesParamsSize); } + Type quantileType; const double *quantilesElements; unsigned quantilesParamsSize; }; @@ -318,16 +325,19 @@ struct QuantileQuantizedTypeStorage : public UniformQuantizedTypeStorage { struct QuantileQuantizedPerAxisTypeStorage : public UniformQuantizedPerAxisTypeStorage { struct KeyTy : public UniformQuantizedPerAxisTypeStorage::KeyTy { - KeyTy(unsigned flags, Type storageType, Type expressedType, - ArrayRef quantiles, ArrayRef scales, - ArrayRef zeroPoints, int32_t quantizedDimension, - int64_t storageTypeMin, int64_t storageTypeMax) + KeyTy(unsigned flags, Type storageType, Type quantileType, + Type expressedType, ArrayRef quantiles, + ArrayRef scales, ArrayRef zeroPoints, + int32_t quantizedDimension, int64_t storageTypeMin, + int64_t storageTypeMax) : UniformQuantizedPerAxisTypeStorage::KeyTy( flags, storageType, expressedType, scales, zeroPoints, quantizedDimension, storageTypeMin, storageTypeMax), - quantiles(quantiles) {} + quantileType(quantileType), quantiles(quantiles) {} + Type quantileType; ArrayRef quantiles; + Type getQuantileType() const { return quantileType; } ArrayRef getQuantiles() const { return quantiles; } // Check for equality of two structures that share KeyTy data members @@ -336,6 +346,7 @@ struct QuantileQuantizedPerAxisTypeStorage static bool genericIsEqual(const T &lhs, const U &rhs) { return UniformQuantizedPerAxisTypeStorage::KeyTy::genericIsEqual(lhs, rhs) && + lhs.getQuantileType() == rhs.getQuantileType() && lhs.getQuantiles() == rhs.getQuantiles(); } @@ -349,7 +360,7 @@ struct QuantileQuantizedPerAxisTypeStorage int64_t *quantilesCast = llvm::bit_cast(quantiles.data()); ArrayRef quantilesBits(quantilesCast, quantiles.size()); return llvm::hash_combine( - flags, storageType, expressedType, + flags, storageType, quantileType, expressedType, llvm::hash_combine_range(quantilesBits.begin(), quantilesBits.end()), llvm::hash_combine_range(scalesBits.begin(), scalesBits.end()), llvm::hash_combine_range(zeroPoints.begin(), zeroPoints.end()), @@ -365,6 +376,7 @@ struct QuantileQuantizedPerAxisTypeStorage ArrayRef scales, ArrayRef zeroPoints) : UniformQuantizedPerAxisTypeStorage(key, scales, zeroPoints), + quantileType(key.getQuantileType()), quantilesElements(quantiles.data()), quantilesParamsSize(quantiles.size()) {} @@ -384,10 +396,13 @@ struct QuantileQuantizedPerAxisTypeStorage static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } + Type getQuantileType() const { return quantileType; } + ArrayRef getQuantiles() const { return ArrayRef(quantilesElements, quantilesParamsSize); } + Type quantileType; const double *quantilesElements; unsigned quantilesParamsSize; }; // namespace detail diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp index 019d84918699..24525628f3c0 100644 --- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp +++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp @@ -72,6 +72,43 @@ static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) { return type; } +static Type parseQuantileType(DialectAsmParser &parser) { + auto typeLoc = parser.getCurrentLocation(); + Type type; + + // Parse storage type (alpha_ident, integer_literal). + StringRef identifier; + unsigned storageTypeWidth = 0; + OptionalParseResult result = parser.parseOptionalType(type); + if (result.has_value()) { + if (!succeeded(*result)) + return nullptr; + + if (!type.isa() && !type.isa()) { + parser.emitError(typeLoc, "illegal quantile type alias"); + return nullptr; + } + } else if (succeeded(parser.parseKeyword(&identifier))) { + // Otherwise, this must be an unsigned integer (`u` integer-literal) + if (identifier.consume_front("u")) { + if (identifier.getAsInteger(10, storageTypeWidth)) { + parser.emitError(typeLoc, "expected quantile type width"); + return nullptr; + } + constexpr bool isSigned = false; + type = parser.getBuilder().getIntegerType(storageTypeWidth, isSigned); + + } else { + parser.emitError(typeLoc, "illegal quantile type alias"); + return nullptr; + } + } else { + return nullptr; + } + + return type; +} + static ParseResult checkStorageRange(DialectAsmParser &parser, int64_t storageTypeMin, int64_t storageTypeMax, int64_t defaultStorageTypeMin, @@ -228,20 +265,24 @@ static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale, /// /// quantile_type ::= quantile_per_layer /// | quantile_per_axis -/// quantile_per_layer ::= `quantile<` storage-spec expressed-type-spec -/// `,` quantiles-list `,` scale-zero `>` -/// quantile_per_axis ::= `quantile<` storage-spec expressed-type-spec -/// axis-spec `,` quantiles-list scale-zero-list `>` +/// quantile_per_layer ::= `quantile<` storage-spec quantile-type-spec +/// expressed-type-spec `,` quantiles-list `,` +/// scale-zero `>` +/// quantile_per_axis ::= `quantile<` storage-spec quantile-type-spec +/// expressed-type-spec axis-spec `,` quantiles-list +/// scale-zero-list `>` /// storage-spec ::= storage-type (`<` storage-range `>`)? /// storage-range ::= integer-literal `:` integer-literal /// storage-type ::= (`i` | `u`) integer-literal -/// expressed-type-spec ::= `:` `f` integer-literal -/// axis-spec ::= `:` integer-literal -/// quantiles-list ::= `{` quantile (`,` quantile)* `}` +/// quantile-type-spec ::= `:` ((`i` | `u` | `f`) integer-literal | `f8E5M2` | +/// `f8E4M3FN`) +/// expressed-type-spec ::= `:` `f` integer-literal axis-spec ::= +/// `:` integer-literal quantiles-list ::= `{` quantile (`,` quantile)* `}` /// scale-zero ::= `:` float-literal `:` integer-literal /// scale-zero-list ::= `:` `{` scale-zero (`,` scale-zero)* `}` static Type parseUniformType(DialectAsmParser &parser, bool isQuantile) { Type storageType; + Type quantileType; FloatType expressedType; unsigned typeFlags = 0; int64_t storageTypeMin; @@ -273,6 +314,19 @@ static Type parseUniformType(DialectAsmParser &parser, bool isQuantile) { return nullptr; } + // quantile type. + if (isQuantile) { + if (parser.parseColon()) { + return nullptr; + } + quantileType = parseQuantileType(parser); + if (!quantileType) { + return nullptr; + } + // mlir::emitRemark(parser.getEncodedSourceLoc(parser.getCurrentLocation())) + // << "Here Here " << quantileType; + } + // Expressed type. if (parser.parseColon() || parser.parseType(expressedType)) { return nullptr; @@ -353,14 +407,25 @@ static Type parseUniformType(DialectAsmParser &parser, bool isQuantile) { if (isPerAxis) { ArrayRef scalesRef(scales.begin(), scales.end()); ArrayRef zeroPointsRef(zeroPoints.begin(), zeroPoints.end()); + + // mlir::emitRemark(parser.getEncodedSourceLoc(parser.getCurrentLocation())) + // << "storageType: " << storageType + // << ", quantileType: " << quantileType + // << ", expressedType: " << expressedType + // << ", quantilesRef: " << quantilesRef << ", scalesRef: " << + // scalesRef + // << ", zeroPointsRef: " << zeroPointsRef + // << ", quantizedDimension: " << quantizedDimension; + return parser.getChecked( - typeFlags, storageType, expressedType, quantilesRef, scalesRef, - zeroPointsRef, quantizedDimension, storageTypeMin, storageTypeMax); + typeFlags, storageType, quantileType, expressedType, quantilesRef, + scalesRef, zeroPointsRef, quantizedDimension, storageTypeMin, + storageTypeMax); } return parser.getChecked( - typeFlags, storageType, expressedType, quantilesRef, scales.front(), - zeroPoints.front(), storageTypeMin, storageTypeMax); + typeFlags, storageType, quantileType, expressedType, quantilesRef, + scales.front(), zeroPoints.front(), storageTypeMin, storageTypeMax); } if (isPerAxis) { @@ -406,6 +471,7 @@ static Type parseCalibratedType(DialectAsmParser &parser) { /// Parse a type registered to this dialect. Type QuantizationDialect::parseType(DialectAsmParser &parser) const { + // All types start with an identifier that we switch on. StringRef typeNameSpelling; if (failed(parser.parseKeyword(&typeNameSpelling))) @@ -465,6 +531,24 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) { } } +static void printQuantileType(Type quantileType, DialectAsmPrinter &out) { + if (auto intType = llvm::dyn_cast(quantileType)) { + const unsigned storageTypeWidth = intType.getWidth(); + if (intType.isSigned()) { + out << ":i" << storageTypeWidth; + } else { + out << ":u" << storageTypeWidth; + } + } else if (quantileType.isa()) { + out << ":f8E5M2"; + } else if (quantileType.isa()) { + out << ":f8E4M3FN"; + } else { + // Float types + out << ":" << quantileType; + } +} + static void printQuantParams(double scale, int64_t zeroPoint, DialectAsmPrinter &out) { out << scale; @@ -523,6 +607,7 @@ static void printQuantileQuantizedType(QuantileQuantizedType type, DialectAsmPrinter &out) { out << "quantile<"; printStorageType(type, out); + printQuantileType(type.getQuantileType(), out); out << ":" << type.getExpressedType() << ", "; // scheme specific parameters @@ -542,6 +627,7 @@ static void printQuantileQuantizedPerAxisType(QuantileQuantizedPerAxisType type, DialectAsmPrinter &out) { out << "quantile<"; printStorageType(type, out); + printQuantileType(type.getQuantileType(), out); out << ":" << type.getExpressedType() << ":"; out << type.getQuantizedDimension(); out << ", "; @@ -578,15 +664,15 @@ static void printCalibratedQuantizedType(CalibratedQuantizedType type, void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const { if (auto anyType = llvm::dyn_cast(type)) printAnyQuantizedType(anyType, os); - else if (auto uniformType = llvm::dyn_cast(type)) - printUniformQuantizedType(uniformType, os); - else if (auto perAxisType = llvm::dyn_cast(type)) - printUniformQuantizedPerAxisType(perAxisType, os); else if (auto uniformType = llvm::dyn_cast(type)) printQuantileQuantizedType(uniformType, os); else if (auto perAxisType = llvm::dyn_cast(type)) printQuantileQuantizedPerAxisType(perAxisType, os); + else if (auto uniformType = llvm::dyn_cast(type)) + printUniformQuantizedType(uniformType, os); + else if (auto perAxisType = llvm::dyn_cast(type)) + printUniformQuantizedPerAxisType(perAxisType, os); else if (auto calibratedType = llvm::dyn_cast(type)) printCalibratedQuantizedType(calibratedType, os); else diff --git a/mlir/test/Dialect/Quant/Bytecode/types.mlir b/mlir/test/Dialect/Quant/Bytecode/types.mlir index 4d85b7d758d1..c784a3610e22 100644 --- a/mlir/test/Dialect/Quant/Bytecode/types.mlir +++ b/mlir/test/Dialect/Quant/Bytecode/types.mlir @@ -70,8 +70,8 @@ module @parseUniformPerAxisMixed attributes { // CHECK-LABEL: parseQuantilePerLayer module @parseQuantilePerLayer attributes { - // CHECK: !quant.quantile - bytecode.test = !quant.quantile + // CHECK: !quant.quantile + bytecode.test = !quant.quantile } {} //===----------------------------------------------------------------------===// @@ -81,17 +81,23 @@ module @parseQuantilePerLayer attributes { // CHECK-LABEL: parseQuantilePerAxisScaleZero module @parseQuantilePerAxisScaleZero attributes { // CHECK: !quant.quantile - bytecode.test = !quant.quantile:f32:1, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}:{2.000000e+02:-120,9.987200e-01:127}> + bytecode.test = !quant.quantile +} {} + +// CHECK-LABEL: parseQuantilePerAxisScaleZeroU4 +module @parseQuantilePerAxisScaleZeroU4 attributes { + // CHECK: !quant.quantile + bytecode.test = !quant.quantile:f16:f32:1, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}:{2.000000e+02:-120,9.987200e-01:127}> } {} // CHECK-LABEL: parseQuantilePerAxisScaleNoZero module @parseQuantilePerAxisScaleNoZero attributes { // CHECK: !quant.quantile - bytecode.test = !quant.quantile + bytecode.test = !quant.quantile } {} // CHECK-LABEL: parseQuantilePerAxisMixed module @parseQuantilePerAxisMixed attributes { // CHECK: !quant.quantile - bytecode.test = !quant.quantile + bytecode.test = !quant.quantile } {} diff --git a/mlir/test/Dialect/Quant/parse-quantile-invalid.mlir b/mlir/test/Dialect/Quant/parse-quantile-invalid.mlir index 8acfa2a587c1..8a1e9927d168 100644 --- a/mlir/test/Dialect/Quant/parse-quantile-invalid.mlir +++ b/mlir/test/Dialect/Quant/parse-quantile-invalid.mlir @@ -1,9 +1,27 @@ // RUN: mlir-opt %s -split-input-file -verify-diagnostics +// ----- +// Illegal missing quantileType +// expected-error@+1 {{expected ':'}} +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Illegal quantileType value +// expected-error@+1 {{illegal quantile type alias}} +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + // ----- // Illegal quantile array size -// expected-error@+1 {{quantiles array size needs to be equal to 2^(bit_size(storageType)), expected: 256, found: 2}} -!qalias = !quant.quantile +// expected-error@+1 {{quantiles array size needs to be equal to 2^(bit_size(storageType)), or (storageTypeMax - storageTypeMin + 1) when max and min differ from the type limits; expected: 256, found: 2}} +!qalias = !quant.quantile func.func @parse() -> !qalias { %0 = "foo"() : () -> !qalias return %0 : !qalias @@ -12,155 +30,155 @@ func.func @parse() -> !qalias { // ----- // Unrecognized token: trailing // expected-error@+1 {{expected '>'}} -!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127 23> +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127 23> // ----- // Unrecognized token: missing storage type maximum // expected-error@+1 {{expected ':'}} -!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127> +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> // ----- // Unrecognized token: missing closing angle bracket // expected-error@+1 {{unbalanced '<' character in pretty dialect name}} -!qalias = !quant> +!qalias = !quant> // ----- // Unrecognized token: missing type colon // expected-error@+1 {{expected ':'}} -!qalias = !quant.quantilef32, {-1.0,1.0}:0.99872:127> +!qalias = !quant.quantilef16:f32, {-1.0,1.0}:0.99872:127> // ----- // Unrecognized token: missing comma // expected-error@+1 {{expected ','}} -!qalias = !quant.quantile:f32 {-1.0,1.0}:0.99872:127> +!qalias = !quant.quantile // ----- // Unrecognized storage type: illegal prefix // expected-error@+1 {{illegal quantized storage type alias}} -!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127> +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> // ----- // Unrecognized storage type: no width // expected-error@+1 {{illegal quantized storage type alias}} -!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127> +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> // ----- // Unrecognized storage type: storage size > 32 // expected-error@+1 {{illegal storage type size: 33}} -!qalias = !quant.quantile +!qalias = !quant.quantile // ----- // Unrecognized storage type: storage size < 0 // expected-error@+1 {{illegal quantized storage type alias}} -!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127> +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> // ----- // Unrecognized storage type: storage size // expected-error@+1 {{invalid integer width}} -!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127> +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> // ----- // Illegal storage min/max: max - min < 0 // expected-error@+1 {{illegal storage min and storage max: (2:1)}} -!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127> +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> // ----- // Illegal storage min/max: max - min == 0 // expected-error@+1 {{illegal storage min and storage max: (1:1)}} -!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127> +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> // ----- // Illegal storage min/max: max > defaultMax // expected-error@+1 {{illegal storage type maximum: 9}} -!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127> +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> // ----- // Illegal storage min/max: min < defaultMin // expected-error@+1 {{illegal storage type minimum: -9}} -!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127> +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> // ----- // Illegal storage min/max: max > defaultMax // expected-error@+1 {{illegal storage type maximum: 60000}} -!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127> +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> // ----- // Illegal storage min/max: min < defaultMin // expected-error@+1 {{illegal storage type minimum: -60000}} -!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127> +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> // ----- // Illegal storage min/max: max > defaultMax // expected-error@+1 {{illegal storage type maximum: 500}} -!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127> +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> // ----- // Illegal storage min/max: min < defaultMin // expected-error@+1 {{illegal storage type minimum: -500}} -!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127> +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> // ----- // Illegal uniform params: invalid scale // expected-error@+1 {{expected floating point literal}} -!qalias = !quant.quantile:f32, {-1.0,1.0}:abc:127> +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:abc:127> // ----- // Illegal uniform params: invalid zero point separator // expected-error@+1 {{expected '>'}} -!qalias = !quant.quantile:f32, {-1.0,1.0}:0.1abc> +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.1abc> // ----- // Illegal uniform params: missing zero point // expected-error@+1 {{expected integer value}} -!qalias = !quant.quantile:f32, {-1.0,1.0}:0.1:> +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.1:> // ----- // Illegal uniform params: invalid zero point // expected-error@+1 {{expected integer value}} -!qalias = !quant.quantile:f32, {-1.0,1.0}:0.1:abc> +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.1:abc> // ----- // Illegal expressed type: f33 // expected-error@+1 {{expected non-function type}} -!qalias = !quant.quantile:f33, {-1.0,1.0}:0.99872:127> +!qalias = !quant.quantile:f16:f33, {-1.0,1.0}:0.99872:127> // ----- // Illegal scale: negative // expected-error@+1 {{illegal scale: -1.000000}} -!qalias = !quant.quantile:f32, {-1.0,1.0}:-1.0:127> +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:-1.0:127> // ----- // Illegal uniform params: missing quantized dimension // expected-error@+1 {{expected integer value}} -!qalias = !quant.quantile:f32:, {-1.0,1.0}:{2.000000e+02:-19.987200e-01:1}> +!qalias = !quant.quantile:f16:f32:, {-1.0,1.0}:{2.000000e+02:-19.987200e-01:1}> // ----- // Illegal uniform params: unspecified quantized dimension, when multiple scales // provided. // expected-error@+1 {{expected floating point literal}} -!qalias = !quant.quantile:f32, {-1.0,1.0}:{2.000000e+02,-19.987200e-01:1}> +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:{2.000000e+02,-19.987200e-01:1}> // ----- // Illegal quantile params: unspecified quantile values // expected-error@+1 {{expected floating point literal}} -!qalias = !quant.quantile:f32, {}:0.99872:127> +!qalias = !quant.quantile:f16:f32, {}:0.99872:127> // ----- // Illegal quantile params: missing quantile values // expected-error@+1 {{expected floating point literal}} -!qalias = !quant.quantile:f32, {-1.0,}:0.99872:127> +!qalias = !quant.quantile:f16:f32, {-1.0,}:0.99872:127> // ----- // Illegal quantile params: missing colon separator // expected-error@+1 {{expected ':'}} -!qalias = !quant.quantile:f32, {-1.0,1.0}0.99872:127> +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}0.99872:127> // ----- // Illegal quantile params: unbalanced } // expected-error@+1 {{unbalanced '{' character in pretty dialect name}} -!qalias = !quant.quantile:f32, {-1.0,1.0:0.99872:127> +!qalias = !quant.quantile:f16:f32, {-1.0,1.0:0.99872:127> // ----- // Illegal quantile params: missing { // expected-error@+1 {{unbalanced '<' character in pretty dialect name}} -!qalias = !quant.quantile:f32, -1.0,1.0}:0.99872:127> +!qalias = !quant.quantile:f16:f32, -1.0,1.0}:0.99872:127> diff --git a/mlir/test/Dialect/Quant/parse-quantile.mlir b/mlir/test/Dialect/Quant/parse-quantile.mlir index 0c5847c6b681..a239e0ca4fab 100644 --- a/mlir/test/Dialect/Quant/parse-quantile.mlir +++ b/mlir/test/Dialect/Quant/parse-quantile.mlir @@ -3,8 +3,8 @@ // ----- // All per-layer params specified: // [signed] storageType, storageTypeMin, storageTypeMax, expressedType, scale, zeroPoint -// CHECK: !quant.quantile:f32, {-1.000000e+00,-0.99219999999999997,-0.98429999999999995,-9.765000e-01,-9.686000e-01,-0.96079999999999998,-9.529000e-01,-9.451000e-01,-9.373000e-01,-9.294000e-01,-0.92159999999999997,-0.91369999999999995,-9.059000e-01,-8.980000e-01,-8.902000e-01,-8.824000e-01,-8.745000e-01,-8.667000e-01,-8.588000e-01,-8.510000e-01,-0.84309999999999996,-8.353000e-01,-8.275000e-01,-8.196000e-01,-8.118000e-01,-8.039000e-01,-7.961000e-01,-7.882000e-01,-0.78039999999999998,-7.725000e-01,-7.647000e-01,-7.569000e-01,-7.490000e-01,-7.412000e-01,-7.333000e-01,-7.255000e-01,-7.176000e-01,-7.098000e-01,-0.70199999999999996,-6.941000e-01,-6.863000e-01,-6.784000e-01,-6.706000e-01,-6.627000e-01,-6.549000e-01,-6.471000e-01,-0.63919999999999999,-0.63139999999999996,-6.235000e-01,-6.157000e-01,-6.078000e-01,-6.000000e-01,-5.922000e-01,-5.843000e-01,-5.765000e-01,-5.686000e-01,-5.608000e-01,-5.529000e-01,-5.451000e-01,-5.373000e-01,-5.294000e-01,-5.216000e-01,-5.137000e-01,-5.059000e-01,-4.980000e-01,-4.902000e-01,-4.824000e-01,-4.745000e-01,-4.667000e-01,-4.588000e-01,-4.510000e-01,-4.431000e-01,-4.353000e-01,-4.275000e-01,-4.196000e-01,-4.118000e-01,-4.039000e-01,-3.961000e-01,-3.882000e-01,-3.804000e-01,-3.725000e-01,-3.647000e-01,-3.569000e-01,-3.490000e-01,-3.412000e-01,-3.333000e-01,-3.255000e-01,-3.176000e-01,-3.098000e-01,-3.020000e-01,-2.941000e-01,-2.863000e-01,-2.784000e-01,-2.706000e-01,-2.627000e-01,-2.549000e-01,-2.471000e-01,-2.392000e-01,-2.314000e-01,-2.235000e-01,-2.157000e-01,-2.078000e-01,-2.000000e-01,-1.922000e-01,-1.843000e-01,-1.765000e-01,-1.686000e-01,-1.608000e-01,-1.529000e-01,-1.451000e-01,-1.373000e-01,-1.294000e-01,-1.216000e-01,-1.137000e-01,-1.059000e-01,-9.800000e-02,-9.020000e-02,-8.240000e-02,-0.074499999999999997,-0.066699999999999995,-5.880000e-02,-5.100000e-02,-4.310000e-02,-3.530000e-02,-2.750000e-02,-1.960000e-02,-1.180000e-02,-3.900000e-03,3.900000e-03,1.180000e-02,1.960000e-02,2.750000e-02,3.530000e-02,4.310000e-02,5.100000e-02,5.880000e-02,0.066699999999999995,0.074499999999999997,8.240000e-02,9.020000e-02,9.800000e-02,1.059000e-01,1.137000e-01,1.216000e-01,1.294000e-01,1.373000e-01,1.451000e-01,1.529000e-01,1.608000e-01,1.686000e-01,1.765000e-01,1.843000e-01,1.922000e-01,2.000000e-01,2.078000e-01,2.157000e-01,2.235000e-01,2.314000e-01,2.392000e-01,2.471000e-01,2.549000e-01,2.627000e-01,2.706000e-01,2.784000e-01,2.863000e-01,2.941000e-01,3.020000e-01,3.098000e-01,3.176000e-01,3.255000e-01,3.333000e-01,3.412000e-01,3.490000e-01,3.569000e-01,3.647000e-01,3.725000e-01,3.804000e-01,3.882000e-01,3.961000e-01,4.039000e-01,4.118000e-01,4.196000e-01,4.275000e-01,4.353000e-01,4.431000e-01,4.510000e-01,4.588000e-01,4.667000e-01,4.745000e-01,4.824000e-01,4.902000e-01,4.980000e-01,5.059000e-01,5.137000e-01,5.216000e-01,5.294000e-01,5.373000e-01,5.451000e-01,5.529000e-01,5.608000e-01,5.686000e-01,5.765000e-01,5.843000e-01,5.922000e-01,6.000000e-01,6.078000e-01,6.157000e-01,6.235000e-01,0.63139999999999996,0.63919999999999999,6.471000e-01,6.549000e-01,6.627000e-01,6.706000e-01,6.784000e-01,6.863000e-01,6.941000e-01,0.70199999999999996,7.098000e-01,7.176000e-01,7.255000e-01,7.333000e-01,7.412000e-01,7.490000e-01,7.569000e-01,7.647000e-01,7.725000e-01,0.78039999999999998,7.882000e-01,7.961000e-01,8.039000e-01,8.118000e-01,8.196000e-01,8.275000e-01,8.353000e-01,0.84309999999999996,8.510000e-01,8.588000e-01,8.667000e-01,8.745000e-01,8.824000e-01,8.902000e-01,8.980000e-01,9.059000e-01,0.91369999999999995,0.92159999999999997,9.294000e-01,9.373000e-01,9.451000e-01,9.529000e-01,0.96079999999999998,9.686000e-01,9.765000e-01,0.98429999999999995,0.99219999999999997,1.000000e+00}:9.987200e-01:127> -!qalias = !quant.quantile:f32, {-1.0000,-0.9922,-0.9843,-0.9765,-0.9686,-0.9608,-0.9529,-0.9451,-0.9373,-0.9294,-0.9216,-0.9137,-0.9059,-0.8980,-0.8902,-0.8824,-0.8745,-0.8667,-0.8588,-0.8510,-0.8431,-0.8353,-0.8275,-0.8196,-0.8118,-0.8039,-0.7961,-0.7882,-0.7804,-0.7725,-0.7647,-0.7569,-0.7490,-0.7412,-0.7333,-0.7255,-0.7176,-0.7098,-0.7020,-0.6941,-0.6863,-0.6784,-0.6706,-0.6627,-0.6549,-0.6471,-0.6392,-0.6314,-0.6235,-0.6157,-0.6078,-0.6000,-0.5922,-0.5843,-0.5765,-0.5686,-0.5608,-0.5529,-0.5451,-0.5373,-0.5294,-0.5216,-0.5137,-0.5059,-0.4980,-0.4902,-0.4824,-0.4745,-0.4667,-0.4588,-0.4510,-0.4431,-0.4353,-0.4275,-0.4196,-0.4118,-0.4039,-0.3961,-0.3882,-0.3804,-0.3725,-0.3647,-0.3569,-0.3490,-0.3412,-0.3333,-0.3255,-0.3176,-0.3098,-0.3020,-0.2941,-0.2863,-0.2784,-0.2706,-0.2627,-0.2549,-0.2471,-0.2392,-0.2314,-0.2235,-0.2157,-0.2078,-0.2000,-0.1922,-0.1843,-0.1765,-0.1686,-0.1608,-0.1529,-0.1451,-0.1373,-0.1294,-0.1216,-0.1137,-0.1059,-0.0980,-0.0902,-0.0824,-0.0745,-0.0667,-0.0588,-0.0510,-0.0431,-0.0353,-0.0275,-0.0196,-0.0118,-0.0039,0.0039,0.0118,0.0196,0.0275,0.0353,0.0431,0.0510,0.0588,0.0667,0.0745,0.0824,0.0902,0.0980,0.1059,0.1137,0.1216,0.1294,0.1373,0.1451,0.1529,0.1608,0.1686,0.1765,0.1843,0.1922,0.2000,0.2078,0.2157,0.2235,0.2314,0.2392,0.2471,0.2549,0.2627,0.2706,0.2784,0.2863,0.2941,0.3020,0.3098,0.3176,0.3255,0.3333,0.3412,0.3490,0.3569,0.3647,0.3725,0.3804,0.3882,0.3961,0.4039,0.4118,0.4196,0.4275,0.4353,0.4431,0.4510,0.4588,0.4667,0.4745,0.4824,0.4902,0.4980,0.5059,0.5137,0.5216,0.5294,0.5373,0.5451,0.5529,0.5608,0.5686,0.5765,0.5843,0.5922,0.6000,0.6078,0.6157,0.6235,0.6314,0.6392,0.6471,0.6549,0.6627,0.6706,0.6784,0.6863,0.6941,0.7020,0.7098,0.7176,0.7255,0.7333,0.7412,0.7490,0.7569,0.7647,0.7725,0.7804,0.7882,0.7961,0.8039,0.8118,0.8196,0.8275,0.8353,0.8431,0.8510,0.8588,0.8667,0.8745,0.8824,0.8902,0.8980,0.9059,0.9137,0.9216,0.9294,0.9373,0.9451,0.9529,0.9608,0.9686,0.9765,0.9843,0.9922,1.0000}:0.99872:127> +// CHECK: !quant.quantile:f16:f32, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-0.066699999999999995,0.066699999999999995,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}:9.987200e-01:127> +!qalias = !quant.quantile:f16:f32, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}:0.99872:127> func.func @parse() -> !qalias { %0 = "foo"() : () -> !qalias return %0 : !qalias @@ -12,8 +12,8 @@ func.func @parse() -> !qalias { // ----- // Trailing whitespace. -// CHECK: !quant.quantile -!qalias = !quant.quantile +// CHECK: !quant.quantile +!qalias = !quant.quantile func.func @parse() -> !qalias { %0 = "foo"() : () -> !qalias return %0 : !qalias @@ -21,8 +21,8 @@ func.func @parse() -> !qalias { // ----- // Default min/max value optimization for integers. -// CHECK: !quant.quantile -!qalias = !quant.quantile:f32, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}:0.99872:127 > +// CHECK: !quant.quantile +!qalias = !quant.quantile:f16:f32, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}:0.99872:127 > func.func @parse() -> !qalias { %0 = "foo"() : () -> !qalias return %0 : !qalias @@ -30,8 +30,8 @@ func.func @parse() -> !qalias { // ----- // Default min/max value optimization for f8E5M2. -// CHECK: !quant.quantile -!qalias = !quant.quantile:f32, {-1.0000,-0.9922,-0.9843,-0.9765,-0.9686,-0.9608,-0.9529,-0.9451,-0.9373,-0.9294,-0.9216,-0.9137,-0.9059,-0.8980,-0.8902,-0.8824,-0.8745,-0.8667,-0.8588,-0.8510,-0.8431,-0.8353,-0.8275,-0.8196,-0.8118,-0.8039,-0.7961,-0.7882,-0.7804,-0.7725,-0.7647,-0.7569,-0.7490,-0.7412,-0.7333,-0.7255,-0.7176,-0.7098,-0.7020,-0.6941,-0.6863,-0.6784,-0.6706,-0.6627,-0.6549,-0.6471,-0.6392,-0.6314,-0.6235,-0.6157,-0.6078,-0.6000,-0.5922,-0.5843,-0.5765,-0.5686,-0.5608,-0.5529,-0.5451,-0.5373,-0.5294,-0.5216,-0.5137,-0.5059,-0.4980,-0.4902,-0.4824,-0.4745,-0.4667,-0.4588,-0.4510,-0.4431,-0.4353,-0.4275,-0.4196,-0.4118,-0.4039,-0.3961,-0.3882,-0.3804,-0.3725,-0.3647,-0.3569,-0.3490,-0.3412,-0.3333,-0.3255,-0.3176,-0.3098,-0.3020,-0.2941,-0.2863,-0.2784,-0.2706,-0.2627,-0.2549,-0.2471,-0.2392,-0.2314,-0.2235,-0.2157,-0.2078,-0.2000,-0.1922,-0.1843,-0.1765,-0.1686,-0.1608,-0.1529,-0.1451,-0.1373,-0.1294,-0.1216,-0.1137,-0.1059,-0.0980,-0.0902,-0.0824,-0.0745,-0.0667,-0.0588,-0.0510,-0.0431,-0.0353,-0.0275,-0.0196,-0.0118,-0.0039,0.0039,0.0118,0.0196,0.0275,0.0353,0.0431,0.0510,0.0588,0.0667,0.0745,0.0824,0.0902,0.0980,0.1059,0.1137,0.1216,0.1294,0.1373,0.1451,0.1529,0.1608,0.1686,0.1765,0.1843,0.1922,0.2000,0.2078,0.2157,0.2235,0.2314,0.2392,0.2471,0.2549,0.2627,0.2706,0.2784,0.2863,0.2941,0.3020,0.3098,0.3176,0.3255,0.3333,0.3412,0.3490,0.3569,0.3647,0.3725,0.3804,0.3882,0.3961,0.4039,0.4118,0.4196,0.4275,0.4353,0.4431,0.4510,0.4588,0.4667,0.4745,0.4824,0.4902,0.4980,0.5059,0.5137,0.5216,0.5294,0.5373,0.5451,0.5529,0.5608,0.5686,0.5765,0.5843,0.5922,0.6000,0.6078,0.6157,0.6235,0.6314,0.6392,0.6471,0.6549,0.6627,0.6706,0.6784,0.6863,0.6941,0.7020,0.7098,0.7176,0.7255,0.7333,0.7412,0.7490,0.7569,0.7647,0.7725,0.7804,0.7882,0.7961,0.8039,0.8118,0.8196,0.8275,0.8353,0.8431,0.8510,0.8588,0.8667,0.8745,0.8824,0.8902,0.8980,0.9059,0.9137,0.9216,0.9294,0.9373,0.9451,0.9529,0.9608,0.9686,0.9765,0.9843,0.9922,1.0000}:0.99872:127 > +// CHECK: !quant.quantile +!qalias = !quant.quantile:f16:f32, {-1.0000,-0.9922,-0.9843,-0.9765,-0.9686,-0.9608,-0.9529,-0.9451,-0.9373,-0.9294,-0.9216,-0.9137,-0.9059,-0.8980,-0.8902,-0.8824,-0.8745,-0.8667,-0.8588,-0.8510,-0.8431,-0.8353,-0.8275,-0.8196,-0.8118,-0.8039,-0.7961,-0.7882,-0.7804,-0.7725,-0.7647,-0.7569,-0.7490,-0.7412,-0.7333,-0.7255,-0.7176,-0.7098,-0.7020,-0.6941,-0.6863,-0.6784,-0.6706,-0.6627,-0.6549,-0.6471,-0.6392,-0.6314,-0.6235,-0.6157,-0.6078,-0.6000,-0.5922,-0.5843,-0.5765,-0.5686,-0.5608,-0.5529,-0.5451,-0.5373,-0.5294,-0.5216,-0.5137,-0.5059,-0.4980,-0.4902,-0.4824,-0.4745,-0.4667,-0.4588,-0.4510,-0.4431,-0.4353,-0.4275,-0.4196,-0.4118,-0.4039,-0.3961,-0.3882,-0.3804,-0.3725,-0.3647,-0.3569,-0.3490,-0.3412,-0.3333,-0.3255,-0.3176,-0.3098,-0.3020,-0.2941,-0.2863,-0.2784,-0.2706,-0.2627,-0.2549,-0.2471,-0.2392,-0.2314,-0.2235,-0.2157,-0.2078,-0.2000,-0.1922,-0.1843,-0.1765,-0.1686,-0.1608,-0.1529,-0.1451,-0.1373,-0.1294,-0.1216,-0.1137,-0.1059,-0.0980,-0.0902,-0.0824,-0.0745,-0.0667,-0.0588,-0.0510,-0.0431,-0.0353,-0.0275,-0.0196,-0.0118,-0.0039,0.0039,0.0118,0.0196,0.0275,0.0353,0.0431,0.0510,0.0588,0.0667,0.0745,0.0824,0.0902,0.0980,0.1059,0.1137,0.1216,0.1294,0.1373,0.1451,0.1529,0.1608,0.1686,0.1765,0.1843,0.1922,0.2000,0.2078,0.2157,0.2235,0.2314,0.2392,0.2471,0.2549,0.2627,0.2706,0.2784,0.2863,0.2941,0.3020,0.3098,0.3176,0.3255,0.3333,0.3412,0.3490,0.3569,0.3647,0.3725,0.3804,0.3882,0.3961,0.4039,0.4118,0.4196,0.4275,0.4353,0.4431,0.4510,0.4588,0.4667,0.4745,0.4824,0.4902,0.4980,0.5059,0.5137,0.5216,0.5294,0.5373,0.5451,0.5529,0.5608,0.5686,0.5765,0.5843,0.5922,0.6000,0.6078,0.6157,0.6235,0.6314,0.6392,0.6471,0.6549,0.6627,0.6706,0.6784,0.6863,0.6941,0.7020,0.7098,0.7176,0.7255,0.7333,0.7412,0.7490,0.7569,0.7647,0.7725,0.7804,0.7882,0.7961,0.8039,0.8118,0.8196,0.8275,0.8353,0.8431,0.8510,0.8588,0.8667,0.8745,0.8824,0.8902,0.8980,0.9059,0.9137,0.9216,0.9294,0.9373,0.9451,0.9529,0.9608,0.9686,0.9765,0.9843,0.9922,1.0000}:0.99872:127 > func.func @parse() -> !qalias { %0 = "foo"() : () -> !qalias return %0 : !qalias @@ -39,8 +39,8 @@ func.func @parse() -> !qalias { // ----- // Default min/max value optimization for f8E4M3FN. -// CHECK: !quant.quantile -!qalias = !quant.quantile:f32, {-1.0000,-0.9922,-0.9843,-0.9765,-0.9686,-0.9608,-0.9529,-0.9451,-0.9373,-0.9294,-0.9216,-0.9137,-0.9059,-0.8980,-0.8902,-0.8824,-0.8745,-0.8667,-0.8588,-0.8510,-0.8431,-0.8353,-0.8275,-0.8196,-0.8118,-0.8039,-0.7961,-0.7882,-0.7804,-0.7725,-0.7647,-0.7569,-0.7490,-0.7412,-0.7333,-0.7255,-0.7176,-0.7098,-0.7020,-0.6941,-0.6863,-0.6784,-0.6706,-0.6627,-0.6549,-0.6471,-0.6392,-0.6314,-0.6235,-0.6157,-0.6078,-0.6000,-0.5922,-0.5843,-0.5765,-0.5686,-0.5608,-0.5529,-0.5451,-0.5373,-0.5294,-0.5216,-0.5137,-0.5059,-0.4980,-0.4902,-0.4824,-0.4745,-0.4667,-0.4588,-0.4510,-0.4431,-0.4353,-0.4275,-0.4196,-0.4118,-0.4039,-0.3961,-0.3882,-0.3804,-0.3725,-0.3647,-0.3569,-0.3490,-0.3412,-0.3333,-0.3255,-0.3176,-0.3098,-0.3020,-0.2941,-0.2863,-0.2784,-0.2706,-0.2627,-0.2549,-0.2471,-0.2392,-0.2314,-0.2235,-0.2157,-0.2078,-0.2000,-0.1922,-0.1843,-0.1765,-0.1686,-0.1608,-0.1529,-0.1451,-0.1373,-0.1294,-0.1216,-0.1137,-0.1059,-0.0980,-0.0902,-0.0824,-0.0745,-0.0667,-0.0588,-0.0510,-0.0431,-0.0353,-0.0275,-0.0196,-0.0118,-0.0039,0.0039,0.0118,0.0196,0.0275,0.0353,0.0431,0.0510,0.0588,0.0667,0.0745,0.0824,0.0902,0.0980,0.1059,0.1137,0.1216,0.1294,0.1373,0.1451,0.1529,0.1608,0.1686,0.1765,0.1843,0.1922,0.2000,0.2078,0.2157,0.2235,0.2314,0.2392,0.2471,0.2549,0.2627,0.2706,0.2784,0.2863,0.2941,0.3020,0.3098,0.3176,0.3255,0.3333,0.3412,0.3490,0.3569,0.3647,0.3725,0.3804,0.3882,0.3961,0.4039,0.4118,0.4196,0.4275,0.4353,0.4431,0.4510,0.4588,0.4667,0.4745,0.4824,0.4902,0.4980,0.5059,0.5137,0.5216,0.5294,0.5373,0.5451,0.5529,0.5608,0.5686,0.5765,0.5843,0.5922,0.6000,0.6078,0.6157,0.6235,0.6314,0.6392,0.6471,0.6549,0.6627,0.6706,0.6784,0.6863,0.6941,0.7020,0.7098,0.7176,0.7255,0.7333,0.7412,0.7490,0.7569,0.7647,0.7725,0.7804,0.7882,0.7961,0.8039,0.8118,0.8196,0.8275,0.8353,0.8431,0.8510,0.8588,0.8667,0.8745,0.8824,0.8902,0.8980,0.9059,0.9137,0.9216,0.9294,0.9373,0.9451,0.9529,0.9608,0.9686,0.9765,0.9843,0.9922,1.0000}:0.99872:127 > +// CHECK: !quant.quantile +!qalias = !quant.quantile:f16:f32, {-1.0000,-0.9922,-0.9843,-0.9765,-0.9686,-0.9608,-0.9529,-0.9451,-0.9373,-0.9294,-0.9216,-0.9137,-0.9059,-0.8980,-0.8902,-0.8824,-0.8745,-0.8667,-0.8588,-0.8510,-0.8431,-0.8353,-0.8275,-0.8196,-0.8118,-0.8039,-0.7961,-0.7882,-0.7804,-0.7725,-0.7647,-0.7569,-0.7490,-0.7412,-0.7333,-0.7255,-0.7176,-0.7098,-0.7020,-0.6941,-0.6863,-0.6784,-0.6706,-0.6627,-0.6549,-0.6471,-0.6392,-0.6314,-0.6235,-0.6157,-0.6078,-0.6000,-0.5922,-0.5843,-0.5765,-0.5686,-0.5608,-0.5529,-0.5451,-0.5373,-0.5294,-0.5216,-0.5137,-0.5059,-0.4980,-0.4902,-0.4824,-0.4745,-0.4667,-0.4588,-0.4510,-0.4431,-0.4353,-0.4275,-0.4196,-0.4118,-0.4039,-0.3961,-0.3882,-0.3804,-0.3725,-0.3647,-0.3569,-0.3490,-0.3412,-0.3333,-0.3255,-0.3176,-0.3098,-0.3020,-0.2941,-0.2863,-0.2784,-0.2706,-0.2627,-0.2549,-0.2471,-0.2392,-0.2314,-0.2235,-0.2157,-0.2078,-0.2000,-0.1922,-0.1843,-0.1765,-0.1686,-0.1608,-0.1529,-0.1451,-0.1373,-0.1294,-0.1216,-0.1137,-0.1059,-0.0980,-0.0902,-0.0824,-0.0745,-0.0667,-0.0588,-0.0510,-0.0431,-0.0353,-0.0275,-0.0196,-0.0118,-0.0039,0.0039,0.0118,0.0196,0.0275,0.0353,0.0431,0.0510,0.0588,0.0667,0.0745,0.0824,0.0902,0.0980,0.1059,0.1137,0.1216,0.1294,0.1373,0.1451,0.1529,0.1608,0.1686,0.1765,0.1843,0.1922,0.2000,0.2078,0.2157,0.2235,0.2314,0.2392,0.2471,0.2549,0.2627,0.2706,0.2784,0.2863,0.2941,0.3020,0.3098,0.3176,0.3255,0.3333,0.3412,0.3490,0.3569,0.3647,0.3725,0.3804,0.3882,0.3961,0.4039,0.4118,0.4196,0.4275,0.4353,0.4431,0.4510,0.4588,0.4667,0.4745,0.4824,0.4902,0.4980,0.5059,0.5137,0.5216,0.5294,0.5373,0.5451,0.5529,0.5608,0.5686,0.5765,0.5843,0.5922,0.6000,0.6078,0.6157,0.6235,0.6314,0.6392,0.6471,0.6549,0.6627,0.6706,0.6784,0.6863,0.6941,0.7020,0.7098,0.7176,0.7255,0.7333,0.7412,0.7490,0.7569,0.7647,0.7725,0.7804,0.7882,0.7961,0.8039,0.8118,0.8196,0.8275,0.8353,0.8431,0.8510,0.8588,0.8667,0.8745,0.8824,0.8902,0.8980,0.9059,0.9137,0.9216,0.9294,0.9373,0.9451,0.9529,0.9608,0.9686,0.9765,0.9843,0.9922,1.0000}:0.99872:127 > func.func @parse() -> !qalias { %0 = "foo"() : () -> !qalias return %0 : !qalias @@ -49,8 +49,8 @@ func.func @parse() -> !qalias { // ----- // Required per-layer params specified: // [unsigned] storageType, expressedType, scale -// CHECK: !quant.quantile -!qalias = !quant.quantile +// CHECK: !quant.quantile +!qalias = !quant.quantile func.func @parse() -> !qalias { %0 = "foo"() : () -> !qalias return %0 : !qalias @@ -58,8 +58,8 @@ func.func @parse() -> !qalias { // ----- // Exponential scale (-) -// CHECK: !quant.quantile -!qalias = !quant.quantile +// CHECK: !quant.quantile +!qalias = !quant.quantile func.func @parse() -> !qalias { %0 = "foo"() : () -> !qalias return %0 : !qalias @@ -67,8 +67,8 @@ func.func @parse() -> !qalias { // ----- // Exponential scale (+) -// CHECK: !quant.quantile -!qalias = !quant.quantile +// CHECK: !quant.quantile +!qalias = !quant.quantile func.func @parse() -> !qalias { %0 = "foo"() : () -> !qalias return %0 : !qalias @@ -76,8 +76,8 @@ func.func @parse() -> !qalias { // ----- // Storage type: f8E5M2 -// CHECK: !quant.quantile -!qalias = !quant.quantile +// CHECK: !quant.quantile +!qalias = !quant.quantile func.func @parse() -> !qalias { %0 = "foo"() : () -> !qalias return %0 : !qalias @@ -85,8 +85,8 @@ func.func @parse() -> !qalias { // ----- // Storage type: f8E4M3FN -// CHECK: !quant.quantile -!qalias = !quant.quantile +// CHECK: !quant.quantile +!qalias = !quant.quantile func.func @parse() -> !qalias { %0 = "foo"() : () -> !qalias return %0 : !qalias @@ -94,8 +94,8 @@ func.func @parse() -> !qalias { // ----- // Expressed type: f32 -// CHECK: !quant.quantile -!qalias = !quant.quantile +// CHECK: !quant.quantile +!qalias = !quant.quantile func.func @parse() -> !qalias { %0 = "foo"() : () -> !qalias return %0 : !qalias @@ -103,8 +103,8 @@ func.func @parse() -> !qalias { // ----- // Expressed type: f32 -// CHECK: !quant.quantile -!qalias = !quant.quantile +// CHECK: !quant.quantile +!qalias = !quant.quantile func.func @parse() -> !qalias { %0 = "foo"() : () -> !qalias return %0 : !qalias @@ -112,8 +112,8 @@ func.func @parse() -> !qalias { // ----- // Expressed type: f16 -// CHECK: !quant.quantile -!qalias = !quant.quantile +// CHECK: !quant.quantile +!qalias = !quant.quantile func.func @parse() -> !qalias { %0 = "foo"() : () -> !qalias return %0 : !qalias @@ -121,8 +121,8 @@ func.func @parse() -> !qalias { // ----- // Expressed type: f64 -// CHECK: !quant.quantile -!qalias = !quant.quantile +// CHECK: !quant.quantile +!qalias = !quant.quantile func.func @parse() -> !qalias { %0 = "foo"() : () -> !qalias return %0 : !qalias @@ -130,8 +130,8 @@ func.func @parse() -> !qalias { // ----- // Expressed type: bf16 -// CHECK: !quant.quantile -!qalias = !quant.quantile +// CHECK: !quant.quantile +!qalias = !quant.quantile func.func @parse() -> !qalias { %0 = "foo"() : () -> !qalias return %0 : !qalias @@ -139,8 +139,8 @@ func.func @parse() -> !qalias { // ----- // Per-axis scales and zero points (affine) -// CHECK: !quant.quantile -!qalias = !quant.quantile +// CHECK: !quant.quantile +!qalias = !quant.quantile func.func @parse() -> !qalias { %0 = "foo"() : () -> !qalias return %0 : !qalias @@ -148,8 +148,8 @@ func.func @parse() -> !qalias { // ----- // Per-axis scales and no zero points (fixedpoint) -// CHECK: !quant.quantile -!qalias = !quant.quantile +// CHECK: !quant.quantile +!qalias = !quant.quantile func.func @parse() -> !qalias { %0 = "foo"() : () -> !qalias return %0 : !qalias @@ -157,8 +157,8 @@ func.func @parse() -> !qalias { // ----- // Per-axis scales and zero points (mixed affine and fixedpoint) -// CHECK: !quant.quantile -!qalias = !quant.quantile +// CHECK: !quant.quantile +!qalias = !quant.quantile func.func @parse() -> !qalias { %0 = "foo"() : () -> !qalias return %0 : !qalias From d11755cc9467df857b73a5a46d964f1307e82a8a Mon Sep 17 00:00:00 2001 From: Luca Sarti Date: Mon, 30 Sep 2024 16:00:30 +0000 Subject: [PATCH 2/5] Fixing casting problem with Quantile types --- mlir/include/mlir/Dialect/Quant/QuantTypes.h | 8 ++++++++ mlir/lib/Dialect/Quant/IR/QuantTypes.cpp | 19 +++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h index a53a342fe52a..116611894b9b 100644 --- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h +++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h @@ -297,6 +297,8 @@ class UniformQuantizedType int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax); + static bool classof(mlir::Type type); + /// Gets the scale term. The scale designates the difference between the real /// values corresponding to consecutive quantized values differing by 1. double getScale() const; @@ -360,6 +362,8 @@ class UniformQuantizedPerAxisType int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax); + static bool classof(mlir::Type type); + /// Gets the quantization scales. The scales designate the difference between /// the real values corresponding to consecutive quantized values differing /// by 1. The ith scale corresponds to the ith slice in the @@ -440,6 +444,8 @@ class QuantileQuantizedType int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax); + static bool classof(mlir::Type type); + /// Gets the quantileType Type getQuantileType() const; @@ -510,6 +516,8 @@ class QuantileQuantizedPerAxisType ArrayRef zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax); + static bool classof(mlir::Type type); + /// Gets the quantileType Type getQuantileType() const; diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp index 23f1a10e49a4..4967adec41f1 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp @@ -12,6 +12,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/Support/TypeID.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/MathExtras.h" @@ -305,6 +306,11 @@ LogicalResult UniformQuantizedType::verify( return success(); } +bool UniformQuantizedType::classof(mlir::Type type) { + return type.getTypeID() == mlir::TypeID::get() || + type.getTypeID() == mlir::TypeID::get(); +} + double UniformQuantizedType::getScale() const { return getImpl()->scale; } int64_t UniformQuantizedType::getZeroPoint() const { @@ -366,6 +372,11 @@ LogicalResult UniformQuantizedPerAxisType::verify( return success(); } +bool UniformQuantizedPerAxisType::classof(mlir::Type type) { + return type.getTypeID() == mlir::TypeID::get() || + type.getTypeID() == mlir::TypeID::get(); +} + ArrayRef UniformQuantizedPerAxisType::getScales() const { return getImpl()->getScales(); } @@ -444,6 +455,10 @@ LogicalResult QuantileQuantizedType::verify( return success(); } +bool QuantileQuantizedType::classof(mlir::Type type) { + return type.getTypeID() == mlir::TypeID::get(); +} + Type QuantileQuantizedType::getQuantileType() const { return getImpl()->quantileType; } @@ -520,6 +535,10 @@ LogicalResult QuantileQuantizedPerAxisType::verify( return success(); } +bool QuantileQuantizedPerAxisType::classof(mlir::Type type) { + return type.getTypeID() == mlir::TypeID::get(); +} + Type QuantileQuantizedPerAxisType::getQuantileType() const { return getImpl()->quantileType; } From aa669b20c5d77a9d087cf3360fa22155fc06750e Mon Sep 17 00:00:00 2001 From: Luca Sarti Date: Mon, 30 Sep 2024 16:19:52 +0000 Subject: [PATCH 3/5] Minor fixes to types.mlir --- mlir/test/Dialect/Quant/Bytecode/types.mlir | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/test/Dialect/Quant/Bytecode/types.mlir b/mlir/test/Dialect/Quant/Bytecode/types.mlir index c784a3610e22..c797572056f7 100644 --- a/mlir/test/Dialect/Quant/Bytecode/types.mlir +++ b/mlir/test/Dialect/Quant/Bytecode/types.mlir @@ -70,7 +70,7 @@ module @parseUniformPerAxisMixed attributes { // CHECK-LABEL: parseQuantilePerLayer module @parseQuantilePerLayer attributes { - // CHECK: !quant.quantile + // CHECK: !quant.quantile bytecode.test = !quant.quantile } {} @@ -80,24 +80,24 @@ module @parseQuantilePerLayer attributes { // CHECK-LABEL: parseQuantilePerAxisScaleZero module @parseQuantilePerAxisScaleZero attributes { - // CHECK: !quant.quantile + // CHECK: !quant.quantile bytecode.test = !quant.quantile } {} // CHECK-LABEL: parseQuantilePerAxisScaleZeroU4 module @parseQuantilePerAxisScaleZeroU4 attributes { - // CHECK: !quant.quantile + // CHECK: !quant.quantile bytecode.test = !quant.quantile:f16:f32:1, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}:{2.000000e+02:-120,9.987200e-01:127}> } {} // CHECK-LABEL: parseQuantilePerAxisScaleNoZero module @parseQuantilePerAxisScaleNoZero attributes { - // CHECK: !quant.quantile + // CHECK: !quant.quantile bytecode.test = !quant.quantile } {} // CHECK-LABEL: parseQuantilePerAxisMixed module @parseQuantilePerAxisMixed attributes { - // CHECK: !quant.quantile + // CHECK: !quant.quantile bytecode.test = !quant.quantile } {} From d3d31ecbc4537e505719a722cdd2ae89804d4b74 Mon Sep 17 00:00:00 2001 From: Luca Sarti Date: Tue, 1 Oct 2024 10:38:16 +0000 Subject: [PATCH 4/5] Minor cleanup --- mlir/include/mlir/Dialect/Quant/QuantTypes.h | 1 - mlir/lib/Dialect/Quant/IR/QuantTypes.cpp | 1 - mlir/lib/Dialect/Quant/IR/TypeParser.cpp | 11 ----------- 3 files changed, 13 deletions(-) diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h index 116611894b9b..403e3187d7c5 100644 --- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h +++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h @@ -16,7 +16,6 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Types.h" #include "llvm/Support/MathExtras.h" -#include namespace mlir { namespace quant { diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp index 4967adec41f1..8b6d7a58f2be 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp @@ -12,7 +12,6 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" -#include "mlir/Support/TypeID.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/MathExtras.h" diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp index 24525628f3c0..9ce1accc4ea5 100644 --- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp +++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp @@ -323,8 +323,6 @@ static Type parseUniformType(DialectAsmParser &parser, bool isQuantile) { if (!quantileType) { return nullptr; } - // mlir::emitRemark(parser.getEncodedSourceLoc(parser.getCurrentLocation())) - // << "Here Here " << quantileType; } // Expressed type. @@ -408,15 +406,6 @@ static Type parseUniformType(DialectAsmParser &parser, bool isQuantile) { ArrayRef scalesRef(scales.begin(), scales.end()); ArrayRef zeroPointsRef(zeroPoints.begin(), zeroPoints.end()); - // mlir::emitRemark(parser.getEncodedSourceLoc(parser.getCurrentLocation())) - // << "storageType: " << storageType - // << ", quantileType: " << quantileType - // << ", expressedType: " << expressedType - // << ", quantilesRef: " << quantilesRef << ", scalesRef: " << - // scalesRef - // << ", zeroPointsRef: " << zeroPointsRef - // << ", quantizedDimension: " << quantizedDimension; - return parser.getChecked( typeFlags, storageType, quantileType, expressedType, quantilesRef, scalesRef, zeroPointsRef, quantizedDimension, storageTypeMin, From 571eed38f481e1c27affd2e48448c134cb096737 Mon Sep 17 00:00:00 2001 From: Luca Sarti Date: Thu, 3 Oct 2024 09:59:29 +0000 Subject: [PATCH 5/5] Expanding comment about supported quantileType member types --- mlir/include/mlir/Dialect/Quant/QuantTypes.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h index 403e3187d7c5..6a6d3a54891c 100644 --- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h +++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h @@ -397,7 +397,8 @@ class UniformQuantizedPerAxisType }; /// QuantileQuantizedType derives from UniformQuantizedType and adds to it a -/// look up table array of quantile values. +/// look up table array of quantile values. The type of the data in the look up table is determined by +/// the quantileType member: supported quantileType types are integer/unsigned/hf8/bf8/f16/bf16/f32/f64. /// /// Syntax synopsis: /// Per-layer, all parameters expressed: @@ -463,7 +464,8 @@ class QuantileQuantizedType }; /// Represents per-axis QuantileQuantizedType (also known as per-channel -/// quantization). +/// quantization). The type of the data in the look up table is determined by the +/// quantileType member: supported quantileType types are integer/unsigned/hf8/bf8/f16/bf16/f32/f64. /// /// Syntax synopsis: /// Per-axis, all parameters expressed: