diff --git a/mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td b/mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td index bd9cdf823822..4ccfa08d8fe9 100644 --- a/mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td +++ b/mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td @@ -81,6 +81,35 @@ def UniformQuantizedPerAxisType: DialectType<(type }]; } +def QuantileQuantizedType: DialectType<(type + VarInt:$flags, + Type:$storageType, + Type:$expressedType, + Array:$quantiles, + DoubleAPFloat:$scale, + SignedVarInt:$zeroPoint, + SignedVarInt:$storageTypeMin, + SignedVarInt:$storageTypeMax +)>; + +def QuantileQuantizedPerAxisType: DialectType<(type + VarInt:$flags, + Type:$storageType, + Type:$expressedType, + VarInt:$quantizedDimension, + SignedVarInt:$storageTypeMin, + SignedVarInt:$storageTypeMax, + Array:$quantiles, + Array:$scales, + Array:$zeroPoints +)> { + // Note: builder order differs from bytecode. + let cBuilder = [{ + get<$_resultType>(context, flags, storageType, expressedType, quantiles, scales, + zeroPoints, quantizedDimension, storageTypeMin, storageTypeMax) + }]; +} + /// This enum contains marker codes used to indicate which attribute is /// currently being decoded, and how it should be decoded. 
The order of these /// codes should generally be unchanged, as any changes will inevitably break @@ -93,7 +122,9 @@ def QuantDialectTypes : DialectTypes<"Quant"> { AnyQuantizedTypeWithExpressedType, CalibratedQuantizedType, UniformQuantizedType, - UniformQuantizedPerAxisType + UniformQuantizedPerAxisType, + QuantileQuantizedType, + QuantileQuantizedPerAxisType ]; } diff --git a/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td b/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td index da822d0a61de..4b6bc7c910ab 100644 --- a/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td +++ b/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td @@ -67,8 +67,18 @@ def quant_UniformQuantizedType : CPred<"::llvm::isa($_self)">, "UniformQuantizedType">; +// An implementation of QuantileQuantizedType. +def quant_QuantileQuantizedType : + DialectType($_self)">, + "QuantileQuantizedType">; + // Predicate for detecting a container or primitive of UniformQuantizedType. def quant_UniformQuantizedValueType : quant_TypedPrimitiveOrContainer; +// Predicate for detecting a container or primitive of QuantileQuantizedType. +def quant_QuantileQuantizedValueType : + quant_TypedPrimitiveOrContainer; + #endif // DIALECT_QUANT_QUANT_OPS_BASE_ diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h index 8d21955caaaf..02eed19ddbe4 100644 --- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h +++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h @@ -25,6 +25,8 @@ struct QuantizedTypeStorage; struct AnyQuantizedTypeStorage; struct UniformQuantizedTypeStorage; struct UniformQuantizedPerAxisTypeStorage; +struct QuantileQuantizedTypeStorage; +struct QuantileQuantizedPerAxisTypeStorage; struct CalibratedQuantizedTypeStorage; } // namespace detail @@ -390,6 +392,128 @@ class UniformQuantizedPerAxisType } }; +/// QuantileQuantizedType derives from UniformQuantizedType and adds to it a +/// look up table array of quantile values. 
+/// +/// Syntax synopsis: +/// Per-layer, all parameters expressed: +/// !quant +/// Per-layer, optional parameters omitted: +/// !quant +/// +/// StorageType: 'i'|'u' NumBits +/// ExpressedType: 'f16', 'f32', 'bf16', 'f64' +/// Quantiles: Quantile+ +/// Quantile: A legal double value +/// Scale: A legal double value +/// ZeroPoint: An integer value +class QuantileQuantizedType + : public Type::TypeBase { +public: + using Base::Base; + using Base::getChecked; + + static constexpr StringLiteral name = "quant.quantile"; + + /// Gets an instance of the type with all parameters specified but not + /// checked. + static QuantileQuantizedType get(unsigned flags, Type storageType, + Type expressedType, + ArrayRef quantiles, double scale, + int64_t zeroPoint, int64_t storageTypeMin, + int64_t storageTypeMax); + + static QuantileQuantizedType + getChecked(function_ref emitError, unsigned flags, + Type storageType, Type expressedType, ArrayRef quantiles, + double scale, int64_t zeroPoint, int64_t storageTypeMin, + int64_t storageTypeMax); + + /// Verifies construction invariants and issues errors/warnings. + static LogicalResult verify(function_ref emitError, + unsigned flags, Type storageType, + Type expressedType, ArrayRef quantiles, + double scale, int64_t zeroPoint, + int64_t storageTypeMin, int64_t storageTypeMax); + + /// Gets the quantile values + ArrayRef getQuantiles() const; + + // Fixed point values are real numbers divided by a scale. + // Currently, only signed storage types are treated as fixed point. + // A fixed point value can be obtained from an affine value by subtracting + // the zeroPoint. + // In the future, this may be explicit versus implied by type and zeroPoint. + bool isFixedPoint() const { return isSigned() && getZeroPoint() == 0; } +}; + +/// Represents per-axis QuantileQuantizedType (also known as per-channel +/// quantization). 
+/// +/// Syntax synopsis: +/// Per-axis, all parameters expressed: +/// !quant +/// Per-axis, optional parameters omitted: +/// !quant +/// +/// StorageType: 'i'|'u' NumBits +/// ExpressedType: 'f16', 'f32', 'bf16', 'f64' +/// QuantizedDim: An integer value +/// Quantiles: Quantile+ +/// Quantile: A legal double value +/// QuantParams: (Scale ':' ZeroPoint)+ +/// Scale: A legal double value +/// ZeroPoint: An integer value +class QuantileQuantizedPerAxisType + : public Type::TypeBase { +public: + using Base::Base; + using Base::getChecked; + + static constexpr StringLiteral name = "quant.quantile_per_axis"; + + /// Gets an instance of the type with all parameters specified but not + /// checked. + static QuantileQuantizedPerAxisType + get(unsigned flags, Type storageType, Type expressedType, + ArrayRef quantiles, ArrayRef scales, + ArrayRef zeroPoints, int32_t quantizedDimension, + int64_t storageTypeMin, int64_t storageTypeMax); + + /// Gets an instance of the type with all specified parameters checked. + /// Returns a nullptr convertible type on failure. + static QuantileQuantizedPerAxisType + getChecked(function_ref emitError, unsigned flags, + Type storageType, Type expressedType, ArrayRef quantiles, + ArrayRef scales, ArrayRef zeroPoints, + int32_t quantizedDimension, int64_t storageTypeMin, + int64_t storageTypeMax); + + /// Verifies construction invariants and issues errors/warnings. + static LogicalResult verify(function_ref emitError, + unsigned flags, Type storageType, + Type expressedType, ArrayRef quantiles, + ArrayRef scales, + ArrayRef zeroPoints, + int32_t quantizedDimension, + int64_t storageTypeMin, int64_t storageTypeMax); + + /// Gets the quantile values + ArrayRef getQuantiles() const; + + /// Fixed point values are real numbers divided by a scale. + /// Currently, only signed storage types are treated as fixed point. + /// A fixed point value can be obtained from an affine value by subtracting + /// the zeroPoint. 
+ /// In the future, this may be explicit versus implied by type and zeroPoint. + bool isFixedPoint() const { + return isSigned() && llvm::all_of(getZeroPoints(), [](int64_t zp) { return zp == 0; }); + } +}; + /// A quantized type that infers its range from given min/max values. /// /// Typical syntax: diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp index c9a6bbc9ceee..124733286ce8 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp @@ -28,7 +28,8 @@ using namespace mlir::quant::detail; void QuantizationDialect::initialize() { addTypes(); + UniformQuantizedPerAxisType, QuantileQuantizedType, + QuantileQuantizedPerAxisType>(); addOperations< #define GET_OP_LIST #include "mlir/Dialect/Quant/QuantOps.cpp.inc" diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp index 164b3bf61ec7..b7884a9631e5 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp @@ -378,6 +378,138 @@ int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const { return getImpl()->quantizedDimension; } +QuantileQuantizedType +QuantileQuantizedType::get(unsigned flags, Type storageType, Type expressedType, + ArrayRef quantiles, double scale, + int64_t zeroPoint, int64_t storageTypeMin, + int64_t storageTypeMax) { + return Base::get(storageType.getContext(), flags, storageType, expressedType, + quantiles, scale, zeroPoint, storageTypeMin, storageTypeMax); +} + +QuantileQuantizedType QuantileQuantizedType::getChecked( + function_ref emitError, unsigned flags, + Type storageType, Type expressedType, ArrayRef quantiles, + double scale, int64_t zeroPoint, int64_t storageTypeMin, + int64_t storageTypeMax) { + return Base::getChecked(emitError, storageType.getContext(), flags, + storageType, expressedType, quantiles, scale, + zeroPoint, storageTypeMin, storageTypeMax); +} +LogicalResult +QuantileQuantizedType::verify(function_ref emitError, + unsigned
flags, Type storageType, + Type expressedType, ArrayRef quantiles, + double scale, int64_t zeroPoint, + int64_t storageTypeMin, int64_t storageTypeMax) { + if (failed(UniformQuantizedType::verify(emitError, flags, storageType, + expressedType, scale, zeroPoint, + storageTypeMin, storageTypeMax))) { + return failure(); + } + + const auto quantileArraySize = quantiles.size(); + unsigned typeWidth{}; + if (storageType.isa()) { + typeWidth = llvm::dyn_cast(storageType).getWidth(); + } else if (storageType.isa() || + storageType.isa()) { + // Both Float8E5M2Type and Float8E4M3FNType derive from FloatType. + typeWidth = llvm::dyn_cast(storageType).getWidth(); + } else { + return emitError() << "illegal storage type, supported types are: integral " + "types, Float8E4M3FNType and Float8E5M2Type "; + } + + const size_t expectedSize = size_t{1} << typeWidth; // int `1` would overflow + if (quantileArraySize != expectedSize) { + return emitError() << "quantiles array size needs to be equal to " + "2^(bit_size(storageType)), expected: " + << expectedSize << ", found: " << quantileArraySize; + } + + // Verify quantiles + for (double quantile : quantiles) { + if (std::isinf(quantile) || std::isnan(quantile)) { + return emitError() << "illegal quantile value: " << quantile; + } + } + + return success(); +} + +ArrayRef QuantileQuantizedType::getQuantiles() const { + return getImpl()->getQuantiles(); +} + +QuantileQuantizedPerAxisType QuantileQuantizedPerAxisType::get( + unsigned flags, Type storageType, Type expressedType, + ArrayRef quantiles, ArrayRef scales, + ArrayRef zeroPoints, int32_t quantizedDimension, + int64_t storageTypeMin, int64_t storageTypeMax) { + return Base::get(storageType.getContext(), flags, storageType, expressedType, + quantiles, scales, zeroPoints, quantizedDimension, + storageTypeMin, storageTypeMax); +} + +QuantileQuantizedPerAxisType QuantileQuantizedPerAxisType::getChecked( + function_ref emitError, unsigned flags, + Type storageType, Type expressedType, ArrayRef quantiles, + ArrayRef
scales, ArrayRef zeroPoints, + int32_t quantizedDimension, int64_t storageTypeMin, + int64_t storageTypeMax) { + return Base::getChecked(emitError, storageType.getContext(), flags, + storageType, expressedType, quantiles, scales, + zeroPoints, quantizedDimension, storageTypeMin, + storageTypeMax); +} + +LogicalResult QuantileQuantizedPerAxisType::verify( + function_ref emitError, unsigned flags, + Type storageType, Type expressedType, ArrayRef quantiles, + ArrayRef scales, ArrayRef zeroPoints, + int32_t quantizedDimension, int64_t storageTypeMin, + int64_t storageTypeMax) { + if (failed(UniformQuantizedPerAxisType::verify( + emitError, flags, storageType, expressedType, scales, zeroPoints, + quantizedDimension, storageTypeMin, storageTypeMax))) { + return failure(); + } + + const auto quantileArraySize = quantiles.size(); + unsigned typeWidth{}; + if (storageType.isa()) { + typeWidth = llvm::dyn_cast(storageType).getWidth(); + } else if (storageType.isa() || + storageType.isa()) { + // Both Float8E5M2Type and Float8E4M3FNType derive from FloatType. 
+ typeWidth = llvm::dyn_cast(storageType).getWidth(); + } else { + return emitError() << "illegal storage type, supported types are: integral " + "types, Float8E4M3FNType and Float8E5M2Type "; + } + + const size_t expectedSize = size_t{1} << typeWidth; // int `1` would overflow + if (quantileArraySize != expectedSize) { + return emitError() << "quantiles array size needs to be equal to " + "2^(bit_size(storageType)), expected: " + << expectedSize << ", found: " << quantileArraySize; + } + + // Verify quantiles + for (double quantile : quantiles) { + if (std::isinf(quantile) || std::isnan(quantile)) { + return emitError() << "illegal quantile value: " << quantile; + } + } + + return success(); +} + +ArrayRef QuantileQuantizedPerAxisType::getQuantiles() const { + return getImpl()->getQuantiles(); +} + CalibratedQuantizedType CalibratedQuantizedType::get(Type expressedType, double min, double max) { return Base::get(expressedType.getContext(), expressedType, min, max); diff --git a/mlir/lib/Dialect/Quant/IR/TypeDetail.h b/mlir/lib/Dialect/Quant/IR/TypeDetail.h index ef098811927c..fbb71448655a 100644 --- a/mlir/lib/Dialect/Quant/IR/TypeDetail.h +++ b/mlir/lib/Dialect/Quant/IR/TypeDetail.h @@ -253,6 +253,145 @@ struct UniformQuantizedPerAxisTypeStorage : public QuantizedTypeStorage { int32_t quantizedDimension; }; +struct QuantileQuantizedTypeStorage : public UniformQuantizedTypeStorage { + struct KeyTy : public UniformQuantizedTypeStorage::KeyTy { + KeyTy(unsigned flags, Type storageType, Type expressedType, + ArrayRef quantiles, double scale, int64_t zeroPoint, + int64_t storageTypeMin, int64_t storageTypeMax) + : UniformQuantizedTypeStorage::KeyTy(flags, storageType, expressedType, + scale, zeroPoint, storageTypeMin, + storageTypeMax), + quantiles(quantiles) {} + + ArrayRef quantiles; + ArrayRef getQuantiles() const { return quantiles; } + + // Check for equality of two structures that share KeyTy data members + // (by name).
+ template + static bool genericIsEqual(const T &lhs, const U &rhs) { + return UniformQuantizedTypeStorage::KeyTy::genericIsEqual(lhs, rhs) && + lhs.getQuantiles() == rhs.getQuantiles(); + } + + bool operator==(const KeyTy &other) const { + return genericIsEqual(*this, other); + } + + unsigned getHashValue() const { + int64_t scaleBits = llvm::bit_cast(scale); + int64_t *quantilesCast = llvm::bit_cast(quantiles.data()); + ArrayRef quantilesBits(quantilesCast, quantiles.size()); + return llvm::hash_combine( + flags, storageType, expressedType, + llvm::hash_combine_range(quantilesBits.begin(), quantilesBits.end()), + scaleBits, zeroPoint, storageTypeMin, storageTypeMax); + } + }; + + QuantileQuantizedTypeStorage(const KeyTy &key, ArrayRef quantiles) + : UniformQuantizedTypeStorage(key), quantilesElements(quantiles.data()), + quantilesParamsSize(quantiles.size()) {} + + bool operator==(const KeyTy &key) const { + return KeyTy::genericIsEqual(*this, key); + } + + /// Construction. + static QuantileQuantizedTypeStorage * + construct(TypeStorageAllocator &allocator, const KeyTy &key) { + ArrayRef quantiles = allocator.copyInto(key.quantiles); + return new (allocator.allocate()) + QuantileQuantizedTypeStorage(key, quantiles); + } + + static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } + + ArrayRef getQuantiles() const { + return ArrayRef(quantilesElements, quantilesParamsSize); + } + + const double *quantilesElements; + unsigned quantilesParamsSize; +}; + +struct QuantileQuantizedPerAxisTypeStorage + : public UniformQuantizedPerAxisTypeStorage { + struct KeyTy : public UniformQuantizedPerAxisTypeStorage::KeyTy { + KeyTy(unsigned flags, Type storageType, Type expressedType, + ArrayRef quantiles, ArrayRef scales, + ArrayRef zeroPoints, int32_t quantizedDimension, + int64_t storageTypeMin, int64_t storageTypeMax) + : UniformQuantizedPerAxisTypeStorage::KeyTy( + flags, storageType, expressedType, scales, zeroPoints, + quantizedDimension, storageTypeMin, 
storageTypeMax), + quantiles(quantiles) {} + + ArrayRef quantiles; + ArrayRef getQuantiles() const { return quantiles; } + + // Check for equality of two structures that share KeyTy data members + // (by name). + template + static bool genericIsEqual(const T &lhs, const U &rhs) { + return UniformQuantizedPerAxisTypeStorage::KeyTy::genericIsEqual(lhs, + rhs) && + lhs.getQuantiles() == rhs.getQuantiles(); + } + + bool operator==(const KeyTy &other) const { + return genericIsEqual(*this, other); + } + + unsigned getHashValue() const { + int64_t *scalesCast = llvm::bit_cast(scales.data()); + ArrayRef scalesBits(scalesCast, scales.size()); + int64_t *quantilesCast = llvm::bit_cast(quantiles.data()); + ArrayRef quantilesBits(quantilesCast, quantiles.size()); + return llvm::hash_combine( + flags, storageType, expressedType, + llvm::hash_combine_range(quantilesBits.begin(), quantilesBits.end()), + llvm::hash_combine_range(scalesBits.begin(), scalesBits.end()), + llvm::hash_combine_range(zeroPoints.begin(), zeroPoints.end()), + storageTypeMin, storageTypeMax); + } + }; + + // We pass quantiles, scales and zeroPoints in directly rather than relying on + // KeyTy because we have to create new reallocated versions in `construct` + // below. + QuantileQuantizedPerAxisTypeStorage(const KeyTy &key, + ArrayRef quantiles, + ArrayRef scales, + ArrayRef zeroPoints) + : UniformQuantizedPerAxisTypeStorage(key, scales, zeroPoints), + quantilesElements(quantiles.data()), + quantilesParamsSize(quantiles.size()) {} + + bool operator==(const KeyTy &key) const { + return KeyTy::genericIsEqual(*this, key); + } + + /// Construction. 
+ static QuantileQuantizedPerAxisTypeStorage * + construct(TypeStorageAllocator &allocator, const KeyTy &key) { + ArrayRef quantiles = allocator.copyInto(key.quantiles); + ArrayRef scales = allocator.copyInto(key.scales); + ArrayRef zeroPoints = allocator.copyInto(key.zeroPoints); + return new (allocator.allocate()) + QuantileQuantizedPerAxisTypeStorage(key, quantiles, scales, zeroPoints); + } + + static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } + + ArrayRef getQuantiles() const { + return ArrayRef(quantilesElements, quantilesParamsSize); + } + + const double *quantilesElements; + unsigned quantilesParamsSize; +}; + struct CalibratedQuantizedTypeStorage : public QuantizedTypeStorage { struct KeyTy { KeyTy(Type expressedType, double min, double max) diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp index fc4c6d909d01..019d84918699 100644 --- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp +++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp @@ -196,8 +196,9 @@ static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale, int64_t &zeroPoint) { // scale[:zeroPoint]? // scale. - if (parser.parseFloat(scale)) + if (parser.parseFloat(scale)) { return failure(); + } // zero point. zeroPoint = 0; @@ -209,7 +210,7 @@ static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale, return parser.parseInteger(zeroPoint); } -/// Parses a UniformQuantizedType. +/// Parses a UniformQuantizedType or a QuantileQuantizedType.
/// /// uniform_type ::= uniform_per_layer /// | uniform_per_axis @@ -224,7 +225,22 @@ static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale, /// axis-spec ::= `:` integer-literal /// scale-zero ::= float-literal `:` integer-literal /// scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}` -static Type parseUniformType(DialectAsmParser &parser) { +/// +/// quantile_type ::= quantile_per_layer +/// | quantile_per_axis +/// quantile_per_layer ::= `quantile<` storage-spec expressed-type-spec +/// `,` quantiles-list `:` scale-zero `>` +/// quantile_per_axis ::= `quantile<` storage-spec expressed-type-spec +/// axis-spec `,` quantiles-list `:` scale-zero-list `>` +/// storage-spec ::= storage-type (`<` storage-range `>`)? +/// storage-range ::= integer-literal `:` integer-literal +/// storage-type ::= (`i` | `u`) integer-literal +/// expressed-type-spec ::= `:` `f` integer-literal +/// axis-spec ::= `:` integer-literal +/// quantiles-list ::= `{` float-literal (`,` float-literal)* `}` +/// scale-zero ::= float-literal `:` integer-literal +/// scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}` +static Type parseUniformType(DialectAsmParser &parser, bool isQuantile) { Type storageType; FloatType expressedType; unsigned typeFlags = 0; @@ -232,6 +248,7 @@ static Type parseUniformType(DialectAsmParser &parser) { int64_t storageTypeMax; bool isPerAxis = false; int32_t quantizedDimension; + SmallVector quantiles; SmallVector scales; SmallVector zeroPoints; @@ -273,6 +290,28 @@ static Type parseUniformType(DialectAsmParser &parser) { return nullptr; } + // Quantile list: `{` float (`,` float)* `}` followed by a `:` separator. + if (isQuantile) { + if (parser.parseLBrace()) { + return nullptr; + } + + do { + quantiles.emplace_back(); + if (parser.parseFloat(quantiles.back())) { + return nullptr; + } + } while (succeeded(parser.parseOptionalComma())); + + if (parser.parseRBrace()) { + return nullptr; + } + + if (parser.parseColon()) { + return nullptr; + } + } + // Parameter specification.
// For per-axis, ranges are in a {} delimitted list. if (isPerAxis) { @@ -308,6 +347,22 @@ static Type parseUniformType(DialectAsmParser &parser) { nullptr); } + if (isQuantile) { + ArrayRef quantilesRef(quantiles.begin(), quantiles.end()); + + if (isPerAxis) { + ArrayRef scalesRef(scales.begin(), scales.end()); + ArrayRef zeroPointsRef(zeroPoints.begin(), zeroPoints.end()); + return parser.getChecked( + typeFlags, storageType, expressedType, quantilesRef, scalesRef, + zeroPointsRef, quantizedDimension, storageTypeMin, storageTypeMax); + } + + return parser.getChecked( + typeFlags, storageType, expressedType, quantilesRef, scales.front(), + zeroPoints.front(), storageTypeMin, storageTypeMax); + } + if (isPerAxis) { ArrayRef scalesRef(scales.begin(), scales.end()); ArrayRef zeroPointsRef(zeroPoints.begin(), zeroPoints.end()); @@ -357,7 +412,9 @@ Type QuantizationDialect::parseType(DialectAsmParser &parser) const { return nullptr; if (typeNameSpelling == "uniform") - return parseUniformType(parser); + return parseUniformType(parser, false); + if (typeNameSpelling == "quantile") + return parseUniformType(parser, true); if (typeNameSpelling == "any") return parseAnyType(parser); if (typeNameSpelling == "calibrated") @@ -386,20 +443,20 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) { int64_t defaultMin = type.getStorageType().isa() ? QuantizedType::getDefaultMinimumForInteger(isSigned, storageWidth) - : type.getStorageType().isa() - ? QuantizedType::getDefaultMinimumForF8E5M2() - : type.getStorageType().isa() - ? QuantizedType::getDefaultMinimumForF8E4M3FN() - : std::numeric_limits::max(); + : type.getStorageType().isa() + ? QuantizedType::getDefaultMinimumForF8E5M2() + : type.getStorageType().isa() + ? QuantizedType::getDefaultMinimumForF8E4M3FN() + : std::numeric_limits::max(); int64_t defaultMax = type.getStorageType().isa() ? QuantizedType::getDefaultMaximumForInteger(isSigned, storageWidth) - : type.getStorageType().isa() - ? 
QuantizedType::getDefaultMaximumForF8E5M2() - : type.getStorageType().isa() - ? QuantizedType::getDefaultMaximumForF8E4M3FN() - : std::numeric_limits::min(); + : type.getStorageType().isa() + ? QuantizedType::getDefaultMaximumForF8E5M2() + : type.getStorageType().isa() + ? QuantizedType::getDefaultMaximumForF8E4M3FN() + : std::numeric_limits::min(); if (defaultMin != type.getStorageTypeMin() || defaultMax != type.getStorageTypeMax()) { @@ -461,6 +518,54 @@ static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type, out << "}>"; } +/// Helper that prints a QuantileQuantizedType. +static void printQuantileQuantizedType(QuantileQuantizedType type, + DialectAsmPrinter &out) { + out << "quantile<"; + printStorageType(type, out); + out << ":" << type.getExpressedType() << ", "; + + // scheme specific parameters + ArrayRef quantiles = type.getQuantiles(); + out << "{"; + llvm::interleave( + llvm::seq(0, quantiles.size()), out, + [&](size_t index) { out << quantiles[index]; }, ","); + out << "}:"; + + printQuantParams(type.getScale(), type.getZeroPoint(), out); + out << ">"; +} + +/// Helper that prints a QuantileQuantizedPerAxisType. +static void printQuantileQuantizedPerAxisType(QuantileQuantizedPerAxisType type, + DialectAsmPrinter &out) { + out << "quantile<"; + printStorageType(type, out); + out << ":" << type.getExpressedType() << ":"; + out << type.getQuantizedDimension(); + out << ", "; + + // scheme specific parameters + ArrayRef quantiles = type.getQuantiles(); + out << "{"; + llvm::interleave( + llvm::seq(0, quantiles.size()), out, + [&](size_t index) { out << quantiles[index]; }, ","); + out << "}:"; + + ArrayRef scales = type.getScales(); + ArrayRef zeroPoints = type.getZeroPoints(); + out << "{"; + llvm::interleave( + llvm::seq(0, scales.size()), out, + [&](size_t index) { + printQuantParams(scales[index], zeroPoints[index], out); + }, + ","); + out << "}>"; +} + /// Helper that prints a CalibratedQuantizedType. 
static void printCalibratedQuantizedType(CalibratedQuantizedType type, DialectAsmPrinter &out) { @@ -477,6 +582,11 @@ void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const { printUniformQuantizedType(uniformType, os); else if (auto perAxisType = llvm::dyn_cast(type)) printUniformQuantizedPerAxisType(perAxisType, os); + else if (auto uniformType = llvm::dyn_cast(type)) + printQuantileQuantizedType(uniformType, os); + else if (auto perAxisType = + llvm::dyn_cast(type)) + printQuantileQuantizedPerAxisType(perAxisType, os); else if (auto calibratedType = llvm::dyn_cast(type)) printCalibratedQuantizedType(calibratedType, os); else diff --git a/mlir/test/Dialect/Quant/Bytecode/types.mlir b/mlir/test/Dialect/Quant/Bytecode/types.mlir index 359a58557087..4d85b7d758d1 100644 --- a/mlir/test/Dialect/Quant/Bytecode/types.mlir +++ b/mlir/test/Dialect/Quant/Bytecode/types.mlir @@ -64,3 +64,34 @@ module @parseUniformPerAxisMixed attributes { bytecode.test = !quant.uniform } {} +//===----------------------------------------------------------------------===// +// QuantileQuantized +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: parseQuantilePerLayer +module @parseQuantilePerLayer attributes { + // CHECK: !quant.quantile + bytecode.test = !quant.quantile +} {} + +//===----------------------------------------------------------------------===// +// QuantileQuantizedPerAxis +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: parseQuantilePerAxisScaleZero +module @parseQuantilePerAxisScaleZero attributes { + // CHECK: !quant.quantile + bytecode.test = !quant.quantile:f32:1, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}:{2.000000e+02:-120,9.987200e-01:127}> +} {} + +// CHECK-LABEL: parseQuantilePerAxisScaleNoZero +module @parseQuantilePerAxisScaleNoZero attributes { + // CHECK: 
!quant.quantile + bytecode.test = !quant.quantile +} {} + +// CHECK-LABEL: parseQuantilePerAxisMixed +module @parseQuantilePerAxisMixed attributes { + // CHECK: !quant.quantile + bytecode.test = !quant.quantile +} {} diff --git a/mlir/test/Dialect/Quant/parse-quantile-invalid.mlir b/mlir/test/Dialect/Quant/parse-quantile-invalid.mlir new file mode 100644 index 000000000000..8acfa2a587c1 --- /dev/null +++ b/mlir/test/Dialect/Quant/parse-quantile-invalid.mlir @@ -0,0 +1,166 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +// ----- +// Illegal quantile array size +// expected-error@+1 {{quantiles array size needs to be equal to 2^(bit_size(storageType)), expected: 256, found: 2}} +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Unrecognized token: trailing +// expected-error@+1 {{expected '>'}} +!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127 23> + +// ----- +// Unrecognized token: missing storage type maximum +// expected-error@+1 {{expected ':'}} +!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Unrecognized token: missing closing angle bracket +// expected-error@+1 {{unbalanced '<' character in pretty dialect name}} +!qalias = !quant> + +// ----- +// Unrecognized token: missing type colon +// expected-error@+1 {{expected ':'}} +!qalias = !quant.quantilef32, {-1.0,1.0}:0.99872:127> + +// ----- +// Unrecognized token: missing comma +// expected-error@+1 {{expected ','}} +!qalias = !quant.quantile:f32 {-1.0,1.0}:0.99872:127> + +// ----- +// Unrecognized storage type: illegal prefix +// expected-error@+1 {{illegal quantized storage type alias}} +!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Unrecognized storage type: no width +// expected-error@+1 {{illegal quantized storage type alias}} +!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Unrecognized storage type: storage size > 32 +// 
expected-error@+1 {{illegal storage type size: 33}} +!qalias = !quant.quantile + +// ----- +// Unrecognized storage type: storage size < 0 +// expected-error@+1 {{illegal quantized storage type alias}} +!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Unrecognized storage type: storage size +// expected-error@+1 {{invalid integer width}} +!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Illegal storage min/max: max - min < 0 +// expected-error@+1 {{illegal storage min and storage max: (2:1)}} +!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Illegal storage min/max: max - min == 0 +// expected-error@+1 {{illegal storage min and storage max: (1:1)}} +!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Illegal storage min/max: max > defaultMax +// expected-error@+1 {{illegal storage type maximum: 9}} +!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Illegal storage min/max: min < defaultMin +// expected-error@+1 {{illegal storage type minimum: -9}} +!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Illegal storage min/max: max > defaultMax +// expected-error@+1 {{illegal storage type maximum: 60000}} +!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Illegal storage min/max: min < defaultMin +// expected-error@+1 {{illegal storage type minimum: -60000}} +!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Illegal storage min/max: max > defaultMax +// expected-error@+1 {{illegal storage type maximum: 500}} +!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Illegal storage min/max: min < defaultMin +// expected-error@+1 {{illegal storage type minimum: -500}} +!qalias = !quant.quantile:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Illegal uniform params: invalid scale +// expected-error@+1 {{expected floating point literal}} +!qalias = !quant.quantile:f32, {-1.0,1.0}:abc:127> + +// 
----- +// Illegal uniform params: invalid zero point separator +// expected-error@+1 {{expected '>'}} +!qalias = !quant.quantile:f32, {-1.0,1.0}:0.1abc> + +// ----- +// Illegal uniform params: missing zero point +// expected-error@+1 {{expected integer value}} +!qalias = !quant.quantile:f32, {-1.0,1.0}:0.1:> + +// ----- +// Illegal uniform params: invalid zero point +// expected-error@+1 {{expected integer value}} +!qalias = !quant.quantile:f32, {-1.0,1.0}:0.1:abc> + +// ----- +// Illegal expressed type: f33 +// expected-error@+1 {{expected non-function type}} +!qalias = !quant.quantile:f33, {-1.0,1.0}:0.99872:127> + +// ----- +// Illegal scale: negative +// expected-error@+1 {{illegal scale: -1.000000}} +!qalias = !quant.quantile:f32, {-1.0,1.0}:-1.0:127> + +// ----- +// Illegal uniform params: missing quantized dimension +// expected-error@+1 {{expected integer value}} +!qalias = !quant.quantile:f32:, {-1.0,1.0}:{2.000000e+02:-19.987200e-01:1}> + +// ----- +// Illegal uniform params: unspecified quantized dimension, when multiple scales +// provided. 
+// expected-error@+1 {{expected floating point literal}} +!qalias = !quant.quantile:f32, {-1.0,1.0}:{2.000000e+02,-19.987200e-01:1}> + +// ----- +// Illegal quantile params: unspecified quantile values +// expected-error@+1 {{expected floating point literal}} +!qalias = !quant.quantile:f32, {}:0.99872:127> + +// ----- +// Illegal quantile params: missing quantile values +// expected-error@+1 {{expected floating point literal}} +!qalias = !quant.quantile:f32, {-1.0,}:0.99872:127> + +// ----- +// Illegal quantile params: missing colon separator +// expected-error@+1 {{expected ':'}} +!qalias = !quant.quantile:f32, {-1.0,1.0}0.99872:127> + +// ----- +// Illegal quantile params: unbalanced } +// expected-error@+1 {{unbalanced '{' character in pretty dialect name}} +!qalias = !quant.quantile:f32, {-1.0,1.0:0.99872:127> + +// ----- +// Illegal quantile params: missing { +// expected-error@+1 {{unbalanced '<' character in pretty dialect name}} +!qalias = !quant.quantile:f32, -1.0,1.0}:0.99872:127> diff --git a/mlir/test/Dialect/Quant/parse-quantile.mlir b/mlir/test/Dialect/Quant/parse-quantile.mlir new file mode 100644 index 000000000000..0c5847c6b681 --- /dev/null +++ b/mlir/test/Dialect/Quant/parse-quantile.mlir @@ -0,0 +1,165 @@ +// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file | FileCheck %s + +// ----- +// All per-layer params specified: +// [signed] storageType, storageTypeMin, storageTypeMax, expressedType, scale, zeroPoint +// CHECK: !quant.quantile:f32, 
{-1.000000e+00,-0.99219999999999997,-0.98429999999999995,-9.765000e-01,-9.686000e-01,-0.96079999999999998,-9.529000e-01,-9.451000e-01,-9.373000e-01,-9.294000e-01,-0.92159999999999997,-0.91369999999999995,-9.059000e-01,-8.980000e-01,-8.902000e-01,-8.824000e-01,-8.745000e-01,-8.667000e-01,-8.588000e-01,-8.510000e-01,-0.84309999999999996,-8.353000e-01,-8.275000e-01,-8.196000e-01,-8.118000e-01,-8.039000e-01,-7.961000e-01,-7.882000e-01,-0.78039999999999998,-7.725000e-01,-7.647000e-01,-7.569000e-01,-7.490000e-01,-7.412000e-01,-7.333000e-01,-7.255000e-01,-7.176000e-01,-7.098000e-01,-0.70199999999999996,-6.941000e-01,-6.863000e-01,-6.784000e-01,-6.706000e-01,-6.627000e-01,-6.549000e-01,-6.471000e-01,-0.63919999999999999,-0.63139999999999996,-6.235000e-01,-6.157000e-01,-6.078000e-01,-6.000000e-01,-5.922000e-01,-5.843000e-01,-5.765000e-01,-5.686000e-01,-5.608000e-01,-5.529000e-01,-5.451000e-01,-5.373000e-01,-5.294000e-01,-5.216000e-01,-5.137000e-01,-5.059000e-01,-4.980000e-01,-4.902000e-01,-4.824000e-01,-4.745000e-01,-4.667000e-01,-4.588000e-01,-4.510000e-01,-4.431000e-01,-4.353000e-01,-4.275000e-01,-4.196000e-01,-4.118000e-01,-4.039000e-01,-3.961000e-01,-3.882000e-01,-3.804000e-01,-3.725000e-01,-3.647000e-01,-3.569000e-01,-3.490000e-01,-3.412000e-01,-3.333000e-01,-3.255000e-01,-3.176000e-01,-3.098000e-01,-3.020000e-01,-2.941000e-01,-2.863000e-01,-2.784000e-01,-2.706000e-01,-2.627000e-01,-2.549000e-01,-2.471000e-01,-2.392000e-01,-2.314000e-01,-2.235000e-01,-2.157000e-01,-2.078000e-01,-2.000000e-01,-1.922000e-01,-1.843000e-01,-1.765000e-01,-1.686000e-01,-1.608000e-01,-1.529000e-01,-1.451000e-01,-1.373000e-01,-1.294000e-01,-1.216000e-01,-1.137000e-01,-1.059000e-01,-9.800000e-02,-9.020000e-02,-8.240000e-02,-0.074499999999999997,-0.066699999999999995,-5.880000e-02,-5.100000e-02,-4.310000e-02,-3.530000e-02,-2.750000e-02,-1.960000e-02,-1.180000e-02,-3.900000e-03,3.900000e-03,1.180000e-02,1.960000e-02,2.750000e-02,3.530000e-02,4.310000e-02,5.100000e-02,5.880000e-02,0.066699999999999
995,0.074499999999999997,8.240000e-02,9.020000e-02,9.800000e-02,1.059000e-01,1.137000e-01,1.216000e-01,1.294000e-01,1.373000e-01,1.451000e-01,1.529000e-01,1.608000e-01,1.686000e-01,1.765000e-01,1.843000e-01,1.922000e-01,2.000000e-01,2.078000e-01,2.157000e-01,2.235000e-01,2.314000e-01,2.392000e-01,2.471000e-01,2.549000e-01,2.627000e-01,2.706000e-01,2.784000e-01,2.863000e-01,2.941000e-01,3.020000e-01,3.098000e-01,3.176000e-01,3.255000e-01,3.333000e-01,3.412000e-01,3.490000e-01,3.569000e-01,3.647000e-01,3.725000e-01,3.804000e-01,3.882000e-01,3.961000e-01,4.039000e-01,4.118000e-01,4.196000e-01,4.275000e-01,4.353000e-01,4.431000e-01,4.510000e-01,4.588000e-01,4.667000e-01,4.745000e-01,4.824000e-01,4.902000e-01,4.980000e-01,5.059000e-01,5.137000e-01,5.216000e-01,5.294000e-01,5.373000e-01,5.451000e-01,5.529000e-01,5.608000e-01,5.686000e-01,5.765000e-01,5.843000e-01,5.922000e-01,6.000000e-01,6.078000e-01,6.157000e-01,6.235000e-01,0.63139999999999996,0.63919999999999999,6.471000e-01,6.549000e-01,6.627000e-01,6.706000e-01,6.784000e-01,6.863000e-01,6.941000e-01,0.70199999999999996,7.098000e-01,7.176000e-01,7.255000e-01,7.333000e-01,7.412000e-01,7.490000e-01,7.569000e-01,7.647000e-01,7.725000e-01,0.78039999999999998,7.882000e-01,7.961000e-01,8.039000e-01,8.118000e-01,8.196000e-01,8.275000e-01,8.353000e-01,0.84309999999999996,8.510000e-01,8.588000e-01,8.667000e-01,8.745000e-01,8.824000e-01,8.902000e-01,8.980000e-01,9.059000e-01,0.91369999999999995,0.92159999999999997,9.294000e-01,9.373000e-01,9.451000e-01,9.529000e-01,0.96079999999999998,9.686000e-01,9.765000e-01,0.98429999999999995,0.99219999999999997,1.000000e+00}:9.987200e-01:127> +!qalias = !quant.quantile:f32, 
{-1.0000,-0.9922,-0.9843,-0.9765,-0.9686,-0.9608,-0.9529,-0.9451,-0.9373,-0.9294,-0.9216,-0.9137,-0.9059,-0.8980,-0.8902,-0.8824,-0.8745,-0.8667,-0.8588,-0.8510,-0.8431,-0.8353,-0.8275,-0.8196,-0.8118,-0.8039,-0.7961,-0.7882,-0.7804,-0.7725,-0.7647,-0.7569,-0.7490,-0.7412,-0.7333,-0.7255,-0.7176,-0.7098,-0.7020,-0.6941,-0.6863,-0.6784,-0.6706,-0.6627,-0.6549,-0.6471,-0.6392,-0.6314,-0.6235,-0.6157,-0.6078,-0.6000,-0.5922,-0.5843,-0.5765,-0.5686,-0.5608,-0.5529,-0.5451,-0.5373,-0.5294,-0.5216,-0.5137,-0.5059,-0.4980,-0.4902,-0.4824,-0.4745,-0.4667,-0.4588,-0.4510,-0.4431,-0.4353,-0.4275,-0.4196,-0.4118,-0.4039,-0.3961,-0.3882,-0.3804,-0.3725,-0.3647,-0.3569,-0.3490,-0.3412,-0.3333,-0.3255,-0.3176,-0.3098,-0.3020,-0.2941,-0.2863,-0.2784,-0.2706,-0.2627,-0.2549,-0.2471,-0.2392,-0.2314,-0.2235,-0.2157,-0.2078,-0.2000,-0.1922,-0.1843,-0.1765,-0.1686,-0.1608,-0.1529,-0.1451,-0.1373,-0.1294,-0.1216,-0.1137,-0.1059,-0.0980,-0.0902,-0.0824,-0.0745,-0.0667,-0.0588,-0.0510,-0.0431,-0.0353,-0.0275,-0.0196,-0.0118,-0.0039,0.0039,0.0118,0.0196,0.0275,0.0353,0.0431,0.0510,0.0588,0.0667,0.0745,0.0824,0.0902,0.0980,0.1059,0.1137,0.1216,0.1294,0.1373,0.1451,0.1529,0.1608,0.1686,0.1765,0.1843,0.1922,0.2000,0.2078,0.2157,0.2235,0.2314,0.2392,0.2471,0.2549,0.2627,0.2706,0.2784,0.2863,0.2941,0.3020,0.3098,0.3176,0.3255,0.3333,0.3412,0.3490,0.3569,0.3647,0.3725,0.3804,0.3882,0.3961,0.4039,0.4118,0.4196,0.4275,0.4353,0.4431,0.4510,0.4588,0.4667,0.4745,0.4824,0.4902,0.4980,0.5059,0.5137,0.5216,0.5294,0.5373,0.5451,0.5529,0.5608,0.5686,0.5765,0.5843,0.5922,0.6000,0.6078,0.6157,0.6235,0.6314,0.6392,0.6471,0.6549,0.6627,0.6706,0.6784,0.6863,0.6941,0.7020,0.7098,0.7176,0.7255,0.7333,0.7412,0.7490,0.7569,0.7647,0.7725,0.7804,0.7882,0.7961,0.8039,0.8118,0.8196,0.8275,0.8353,0.8431,0.8510,0.8588,0.8667,0.8745,0.8824,0.8902,0.8980,0.9059,0.9137,0.9216,0.9294,0.9373,0.9451,0.9529,0.9608,0.9686,0.9765,0.9843,0.9922,1.0000}:0.99872:127> +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias 
+ return %0 : !qalias +} + +// ----- +// Trailing whitespace. +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Default min/max value optimization for integers. +// CHECK: !quant.quantile +!qalias = !quant.quantile:f32, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}:0.99872:127 > +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Default min/max value optimization for f8E5M2. +// CHECK: !quant.quantile +!qalias = !quant.quantile:f32, {-1.0000,-0.9922,-0.9843,-0.9765,-0.9686,-0.9608,-0.9529,-0.9451,-0.9373,-0.9294,-0.9216,-0.9137,-0.9059,-0.8980,-0.8902,-0.8824,-0.8745,-0.8667,-0.8588,-0.8510,-0.8431,-0.8353,-0.8275,-0.8196,-0.8118,-0.8039,-0.7961,-0.7882,-0.7804,-0.7725,-0.7647,-0.7569,-0.7490,-0.7412,-0.7333,-0.7255,-0.7176,-0.7098,-0.7020,-0.6941,-0.6863,-0.6784,-0.6706,-0.6627,-0.6549,-0.6471,-0.6392,-0.6314,-0.6235,-0.6157,-0.6078,-0.6000,-0.5922,-0.5843,-0.5765,-0.5686,-0.5608,-0.5529,-0.5451,-0.5373,-0.5294,-0.5216,-0.5137,-0.5059,-0.4980,-0.4902,-0.4824,-0.4745,-0.4667,-0.4588,-0.4510,-0.4431,-0.4353,-0.4275,-0.4196,-0.4118,-0.4039,-0.3961,-0.3882,-0.3804,-0.3725,-0.3647,-0.3569,-0.3490,-0.3412,-0.3333,-0.3255,-0.3176,-0.3098,-0.3020,-0.2941,-0.2863,-0.2784,-0.2706,-0.2627,-0.2549,-0.2471,-0.2392,-0.2314,-0.2235,-0.2157,-0.2078,-0.2000,-0.1922,-0.1843,-0.1765,-0.1686,-0.1608,-0.1529,-0.1451,-0.1373,-0.1294,-0.1216,-0.1137,-0.1059,-0.0980,-0.0902,-0.0824,-0.0745,-0.0667,-0.0588,-0.0510,-0.0431,-0.0353,-0.0275,-0.0196,-0.0118,-0.0039,0.0039,0.0118,0.0196,0.0275,0.0353,0.0431,0.0510,0.0588,0.0667,0.0745,0.0824,0.0902,0.0980,0.1059,0.1137,0.1216,0.1294,0.1373,0.1451,0.1529,0.1608,0.1686,0.1765,0.1843,0.1922,0.2000,0.2078,0.2157,0.2235,0.2314,0.2392,0.2471,0.2549,0.2627,0.2706,0.2784,0.2863,0.2941,0.3020,0.3098,0.3176,0.3255,0.333
3,0.3412,0.3490,0.3569,0.3647,0.3725,0.3804,0.3882,0.3961,0.4039,0.4118,0.4196,0.4275,0.4353,0.4431,0.4510,0.4588,0.4667,0.4745,0.4824,0.4902,0.4980,0.5059,0.5137,0.5216,0.5294,0.5373,0.5451,0.5529,0.5608,0.5686,0.5765,0.5843,0.5922,0.6000,0.6078,0.6157,0.6235,0.6314,0.6392,0.6471,0.6549,0.6627,0.6706,0.6784,0.6863,0.6941,0.7020,0.7098,0.7176,0.7255,0.7333,0.7412,0.7490,0.7569,0.7647,0.7725,0.7804,0.7882,0.7961,0.8039,0.8118,0.8196,0.8275,0.8353,0.8431,0.8510,0.8588,0.8667,0.8745,0.8824,0.8902,0.8980,0.9059,0.9137,0.9216,0.9294,0.9373,0.9451,0.9529,0.9608,0.9686,0.9765,0.9843,0.9922,1.0000}:0.99872:127 > +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Default min/max value optimization for f8E4M3FN. +// CHECK: !quant.quantile +!qalias = !quant.quantile:f32, {-1.0000,-0.9922,-0.9843,-0.9765,-0.9686,-0.9608,-0.9529,-0.9451,-0.9373,-0.9294,-0.9216,-0.9137,-0.9059,-0.8980,-0.8902,-0.8824,-0.8745,-0.8667,-0.8588,-0.8510,-0.8431,-0.8353,-0.8275,-0.8196,-0.8118,-0.8039,-0.7961,-0.7882,-0.7804,-0.7725,-0.7647,-0.7569,-0.7490,-0.7412,-0.7333,-0.7255,-0.7176,-0.7098,-0.7020,-0.6941,-0.6863,-0.6784,-0.6706,-0.6627,-0.6549,-0.6471,-0.6392,-0.6314,-0.6235,-0.6157,-0.6078,-0.6000,-0.5922,-0.5843,-0.5765,-0.5686,-0.5608,-0.5529,-0.5451,-0.5373,-0.5294,-0.5216,-0.5137,-0.5059,-0.4980,-0.4902,-0.4824,-0.4745,-0.4667,-0.4588,-0.4510,-0.4431,-0.4353,-0.4275,-0.4196,-0.4118,-0.4039,-0.3961,-0.3882,-0.3804,-0.3725,-0.3647,-0.3569,-0.3490,-0.3412,-0.3333,-0.3255,-0.3176,-0.3098,-0.3020,-0.2941,-0.2863,-0.2784,-0.2706,-0.2627,-0.2549,-0.2471,-0.2392,-0.2314,-0.2235,-0.2157,-0.2078,-0.2000,-0.1922,-0.1843,-0.1765,-0.1686,-0.1608,-0.1529,-0.1451,-0.1373,-0.1294,-0.1216,-0.1137,-0.1059,-0.0980,-0.0902,-0.0824,-0.0745,-0.0667,-0.0588,-0.0510,-0.0431,-0.0353,-0.0275,-0.0196,-0.0118,-0.0039,0.0039,0.0118,0.0196,0.0275,0.0353,0.0431,0.0510,0.0588,0.0667,0.0745,0.0824,0.0902,0.0980,0.1059,0.1137,0.1216,0.1294,0.1373,0.1451,0.1529,0.1608,0.1
686,0.1765,0.1843,0.1922,0.2000,0.2078,0.2157,0.2235,0.2314,0.2392,0.2471,0.2549,0.2627,0.2706,0.2784,0.2863,0.2941,0.3020,0.3098,0.3176,0.3255,0.3333,0.3412,0.3490,0.3569,0.3647,0.3725,0.3804,0.3882,0.3961,0.4039,0.4118,0.4196,0.4275,0.4353,0.4431,0.4510,0.4588,0.4667,0.4745,0.4824,0.4902,0.4980,0.5059,0.5137,0.5216,0.5294,0.5373,0.5451,0.5529,0.5608,0.5686,0.5765,0.5843,0.5922,0.6000,0.6078,0.6157,0.6235,0.6314,0.6392,0.6471,0.6549,0.6627,0.6706,0.6784,0.6863,0.6941,0.7020,0.7098,0.7176,0.7255,0.7333,0.7412,0.7490,0.7569,0.7647,0.7725,0.7804,0.7882,0.7961,0.8039,0.8118,0.8196,0.8275,0.8353,0.8431,0.8510,0.8588,0.8667,0.8745,0.8824,0.8902,0.8980,0.9059,0.9137,0.9216,0.9294,0.9373,0.9451,0.9529,0.9608,0.9686,0.9765,0.9843,0.9922,1.0000}:0.99872:127 > +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Required per-layer params specified: +// [unsigned] storageType, expressedType, scale +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Exponential scale (-) +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Exponential scale (+) +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Storage type: f8E5M2 +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Storage type: f8E4M3FN +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Expressed type: f32 +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + 
+// ----- +// Expressed type: f32 +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Expressed type: f16 +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Expressed type: f64 +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Expressed type: bf16 +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Per-axis scales and zero points (affine) +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Per-axis scales and no zero points (fixedpoint) +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Per-axis scales and zero points (mixed affine and fixedpoint) +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +}