Skip to content

Commit 2f24fd9

Browse files
sartilsramasit
authored and committed
Extend Quant dialect with Quantile Quantization type (#53)
* Expanding Quant dialect with Quantile Quantized type
* Adding quantile MLIR tests
* Adding check on quantiles array size and updated MLIR tests
1 parent 6728aa1 commit 2f24fd9

File tree

10 files changed

+923
-16
lines changed

10 files changed

+923
-16
lines changed

mlir/include/mlir/Dialect/Quant/IR/QuantBase.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,14 @@ class quant_ScalarOrTensorOf<Type etype> :
226226

227227
// Type constraint satisfied by any value whose type is a
// mlir::quant::QuantizedType.
def quant_QuantizedType
    : Type<CPred<"::llvm::isa<mlir::quant::QuantizedType>($_self)">,
           "quantized type">;
229+
230+
// Quant-dialect type constraint satisfied by mlir::quant::QuantileQuantizedType.
//
// NOTE(review): the predicate tests for QuantileQuantizedType only.
// QuantileQuantizedPerAxisType is declared with UniformQuantizedPerAxisType as
// its base, so per-axis quantile types do not satisfy this constraint --
// confirm that this is intended.
def quant_QuantileQuantizedType
    : DialectType<Quant_Dialect,
                  CPred<"::llvm::isa<mlir::quant::QuantileQuantizedType>($_self)">,
                  "QuantileQuantizedType">;
234+
235+
// Scalar-or-tensor constraint whose element type is quant_QuantileQuantizedType.
def quant_QuantileQuantizedValueType
    : quant_ScalarOrTensorOf<quant_QuantileQuantizedType>;
229237

230238
def quant_ScalarType :
231239
Type<Or<[

mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,35 @@ def UniformQuantizedPerAxisType: DialectType<(type
8181
}];
8282
}
8383

84+
// Bytecode layout of the per-layer quantile quantized type. No custom builder
// is given, so the fields are listed in the same order as the parameters of
// QuantileQuantizedType::get.
def QuantileQuantizedType : DialectType<(type
  VarInt:$flags,
  Type:$storageType,
  Type:$expressedType,
  Array<DoubleAPFloatList>:$quantiles,
  DoubleAPFloat:$scale,
  SignedVarInt:$zeroPoint,
  SignedVarInt:$storageTypeMin,
  SignedVarInt:$storageTypeMax
)>;
94+
95+
// Bytecode layout of the per-axis quantile quantized type.
def QuantileQuantizedPerAxisType : DialectType<(type
  VarInt:$flags,
  Type:$storageType,
  Type:$expressedType,
  VarInt:$quantizedDimension,
  SignedVarInt:$storageTypeMin,
  SignedVarInt:$storageTypeMax,
  Array<DoubleAPFloatList>:$quantiles,
  Array<DoubleAPFloatList>:$scales,
  Array<SignedVarIntList>:$zeroPoints
)> {
  // The serialized field order above is not the argument order of the C++
  // builder, so the call is spelled out explicitly.
  let cBuilder = [{
      get<$_resultType>(context, flags, storageType, expressedType, quantiles,
        scales, zeroPoints, quantizedDimension, storageTypeMin, storageTypeMax)
  }];
}
112+
84113
/// This enum contains marker codes used to indicate which attribute is
85114
/// currently being decoded, and how it should be decoded. The order of these
86115
/// codes should generally be unchanged, as any changes will inevitably break
@@ -93,7 +122,9 @@ def QuantDialectTypes : DialectTypes<"Quant"> {
93122
AnyQuantizedTypeWithExpressedType,
94123
CalibratedQuantizedType,
95124
UniformQuantizedType,
96-
UniformQuantizedPerAxisType
125+
UniformQuantizedPerAxisType,
126+
QuantileQuantizedType,
127+
QuantileQuantizedPerAxisType
97128
];
98129
}
99130

mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ struct QuantizedTypeStorage;
2525
struct AnyQuantizedTypeStorage;
2626
struct UniformQuantizedTypeStorage;
2727
struct UniformQuantizedPerAxisTypeStorage;
28+
struct QuantileQuantizedTypeStorage;
29+
struct QuantileQuantizedPerAxisTypeStorage;
2830
struct CalibratedQuantizedTypeStorage;
2931

3032
} // namespace detail
@@ -394,6 +396,128 @@ class UniformQuantizedPerAxisType
394396
}
395397
};
396398

399+
/// QuantileQuantizedType derives from UniformQuantizedType and adds to it a
400+
/// look up table array of quantile values.
401+
///
402+
/// Syntax synopsis:
403+
/// Per-layer, all parameters expressed:
404+
/// !quant<quantile[StorageType:ExpressedType]{Quantiles}:{Scale:ZeroPoint}>
405+
/// Per-layer, optional parameters omitted:
406+
/// !quant<quantile[StorageType]{Quantiles}:{Scale}>
407+
///
408+
/// StorageType: 'i'|'u' NumBits
409+
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
410+
/// Quantiles: Quantile+
411+
/// Quantile: A legal double value
412+
/// Scale: A legal double value
413+
/// ZeroPoint: An integer value
414+
class QuantileQuantizedType
415+
: public Type::TypeBase<QuantileQuantizedType, UniformQuantizedType,
416+
detail::QuantileQuantizedTypeStorage> {
417+
public:
418+
using Base::Base;
419+
using Base::getChecked;
420+
421+
static constexpr StringLiteral name = "quant.quantile";
422+
423+
/// Gets an instance of the type with all parameters specified but not
424+
/// checked.
425+
static QuantileQuantizedType get(unsigned flags, Type storageType,
426+
Type expressedType,
427+
ArrayRef<double> quantiles, double scale,
428+
int64_t zeroPoint, int64_t storageTypeMin,
429+
int64_t storageTypeMax);
430+
431+
static QuantileQuantizedType
432+
getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
433+
Type storageType, Type expressedType, ArrayRef<double> quantiles,
434+
double scale, int64_t zeroPoint, int64_t storageTypeMin,
435+
int64_t storageTypeMax);
436+
437+
/// Verifies construction invariants and issues errors/warnings.
438+
static LogicalResult verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
439+
unsigned flags, Type storageType,
440+
Type expressedType, ArrayRef<double> quantiles,
441+
double scale, int64_t zeroPoint,
442+
int64_t storageTypeMin, int64_t storageTypeMax);
443+
444+
/// Gets the quantile values
445+
ArrayRef<double> getQuantiles() const;
446+
447+
// Fixed point values are real numbers divided by a scale.
448+
// Currently, only signed storage types are treated as fixed point.
449+
// A fixed point value can be obtained from an affine value by subtracting
450+
// the zeroPoint.
451+
// In the future, this may be explicit versus implied by type and zeroPoint.
452+
bool isFixedPoint() const { return isSigned() && getZeroPoint() == 0; }
453+
};
454+
455+
/// Represents per-axis QuantileQuantizedType (also known as per-channel
456+
/// quantization).
457+
///
458+
/// Syntax synopsis:
459+
/// Per-axis, all parameters expressed:
460+
/// !quant<quantile[StorageType:ExpressedType:QuantizedDim]{Quantiles}:{QuantParams}>
461+
/// Per-axis, optional parameters omitted:
462+
/// !quant<quantile[StorageType]{Quantiles}:{Scale}>
463+
///
464+
/// StorageType: 'i'|'u' NumBits
465+
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
466+
/// QuantizedDim: An integer value
467+
/// Quantiles: Quantile+
468+
/// Quantile: A legal double value
469+
/// QuantParams: (Scale ':' ZeroPoint)+
470+
/// Scale: A legal double value
471+
/// ZeroPoint: An integer value
472+
class QuantileQuantizedPerAxisType
473+
: public Type::TypeBase<QuantileQuantizedPerAxisType,
474+
UniformQuantizedPerAxisType,
475+
detail::QuantileQuantizedPerAxisTypeStorage> {
476+
public:
477+
using Base::Base;
478+
using Base::getChecked;
479+
480+
static constexpr StringLiteral name = "quant.quantile_per_axis";
481+
482+
/// Gets an instance of the type with all parameters specified but not
483+
/// checked.
484+
static QuantileQuantizedPerAxisType
485+
get(unsigned flags, Type storageType, Type expressedType,
486+
ArrayRef<double> quantiles, ArrayRef<double> scales,
487+
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
488+
int64_t storageTypeMin, int64_t storageTypeMax);
489+
490+
/// Gets an instance of the type with all specified parameters checked.
491+
/// Returns a nullptr convertible type on failure.
492+
static QuantileQuantizedPerAxisType
493+
getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
494+
Type storageType, Type expressedType, ArrayRef<double> quantiles,
495+
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
496+
int32_t quantizedDimension, int64_t storageTypeMin,
497+
int64_t storageTypeMax);
498+
499+
/// Verifies construction invariants and issues errors/warnings.
500+
static LogicalResult verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
501+
unsigned flags, Type storageType,
502+
Type expressedType, ArrayRef<double> quantiles,
503+
ArrayRef<double> scales,
504+
ArrayRef<int64_t> zeroPoints,
505+
int32_t quantizedDimension,
506+
int64_t storageTypeMin, int64_t storageTypeMax);
507+
508+
/// Gets the quantile values
509+
ArrayRef<double> getQuantiles() const;
510+
511+
/// Fixed point values are real numbers divided by a scale.
512+
/// Currently, only signed storage types are treated as fixed point.
513+
/// A fixed point value can be obtained from an affine value by subtracting
514+
/// the zeroPoint.
515+
/// In the future, this may be explicit versus implied by type and zeroPoint.
516+
bool isFixedPoint() const {
517+
return isSigned() && !llvm::is_contained(getZeroPoints(), 0);
518+
}
519+
};
520+
397521
/// A quantized type that infers its range from given min/max values.
398522
///
399523
/// Typical syntax:

mlir/lib/Dialect/Quant/IR/QuantOps.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
9393

9494
void QuantDialect::initialize() {
9595
addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
96-
UniformQuantizedPerAxisType>();
96+
UniformQuantizedPerAxisType, QuantileQuantizedType,
97+
QuantileQuantizedPerAxisType>();
9798
addOperations<
9899
#define GET_OP_LIST
99100
#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"

mlir/lib/Dialect/Quant/IR/QuantTypes.cpp

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,138 @@ int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
420420
return getImpl()->quantizedDimension;
421421
}
422422

423+
/// Creates a per-layer quantile quantized type through the unchecked
/// `Base::get` path. The MLIR context is taken from `storageType`.
QuantileQuantizedType QuantileQuantizedType::get(
    unsigned flags, Type storageType, Type expressedType,
    ArrayRef<double> quantiles, double scale, int64_t zeroPoint,
    int64_t storageTypeMin, int64_t storageTypeMax) {
  auto *context = storageType.getContext();
  return Base::get(context, flags, storageType, expressedType, quantiles, scale,
                   zeroPoint, storageTypeMin, storageTypeMax);
}
431+
432+
/// Creates a per-layer quantile quantized type through `Base::getChecked`,
/// passing `emitError` along as the diagnostic hook. The MLIR context is taken
/// from `storageType`.
QuantileQuantizedType QuantileQuantizedType::getChecked(
    function_ref<InFlightDiagnostic()> emitError, unsigned flags,
    Type storageType, Type expressedType, ArrayRef<double> quantiles,
    double scale, int64_t zeroPoint, int64_t storageTypeMin,
    int64_t storageTypeMax) {
  auto *context = storageType.getContext();
  return Base::getChecked(emitError, context, flags, storageType,
                          expressedType, quantiles, scale, zeroPoint,
                          storageTypeMin, storageTypeMax);
}
441+
/// Verifies the construction invariants of a per-layer quantile quantized
/// type.
///
/// The parameters shared with the uniform type are first checked by
/// UniformQuantizedType::verifyInvariants. On top of that this verifier
/// requires that:
///   - `storageType` is an IntegerType, Float8E5M2Type or Float8E4M3FNType;
///   - `quantiles` holds exactly 2^bitwidth(storageType) entries;
///   - every quantile is finite (neither infinite nor NaN).
///
/// Returns failure() after emitting a diagnostic through `emitError` for the
/// first violated invariant, success() otherwise.
LogicalResult QuantileQuantizedType::verifyInvariants(
    function_ref<InFlightDiagnostic()> emitError, unsigned flags,
    Type storageType, Type expressedType, ArrayRef<double> quantiles,
    double scale, int64_t zeroPoint, int64_t storageTypeMin,
    int64_t storageTypeMax) {
  if (failed(UniformQuantizedType::verifyInvariants(
          emitError, flags, storageType, expressedType, scale, zeroPoint,
          storageTypeMin, storageTypeMax)))
    return failure();

  unsigned typeWidth = 0;
  if (auto integerType = llvm::dyn_cast<IntegerType>(storageType)) {
    typeWidth = integerType.getWidth();
  } else if (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(storageType)) {
    // Both Float8E5M2Type and Float8E4M3FNType derive from FloatType.
    typeWidth = llvm::cast<FloatType>(storageType).getWidth();
  } else {
    return emitError() << "illegal storage type, supported types are: integral "
                          "types, Float8E4M3FNType and Float8E5M2Type ";
  }

  // The expected table size is computed in 64 bits: `1 << typeWidth` on a
  // plain `int` overflows (undefined behaviour) for widths of 31 bits and
  // above, e.g. an i32 storage type. Widths that cannot be shifted even in
  // 64 bits are rejected up front.
  constexpr unsigned kMaxSupportedWidth = 63;
  if (typeWidth > kMaxSupportedWidth) {
    return emitError() << "quantiles array size needs to be equal to "
                          "2^(bit_size(storageType)), but storage type width "
                       << typeWidth << " is not supported";
  }

  const uint64_t expectedSize = uint64_t{1} << typeWidth;
  const uint64_t quantileArraySize = quantiles.size();
  if (quantileArraySize != expectedSize) {
    return emitError() << "quantiles array size needs to be equal to "
                          "2^(bit_size(storageType)), expected: "
                       << expectedSize << ", found: " << quantileArraySize;
  }

  // Every entry of the look-up table must be a finite value.
  for (double quantile : quantiles) {
    if (!std::isfinite(quantile))
      return emitError() << "illegal quantile value: " << quantile;
  }

  return success();
}
482+
483+
/// Returns the quantile values held by this type's storage object.
ArrayRef<double> QuantileQuantizedType::getQuantiles() const {
  auto *storage = getImpl();
  return storage->getQuantiles();
}
486+
487+
/// Creates a per-axis quantile quantized type through the unchecked
/// `Base::get` path. The MLIR context is taken from `storageType`.
QuantileQuantizedPerAxisType QuantileQuantizedPerAxisType::get(
    unsigned flags, Type storageType, Type expressedType,
    ArrayRef<double> quantiles, ArrayRef<double> scales,
    ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
    int64_t storageTypeMin, int64_t storageTypeMax) {
  auto *context = storageType.getContext();
  return Base::get(context, flags, storageType, expressedType, quantiles,
                   scales, zeroPoints, quantizedDimension, storageTypeMin,
                   storageTypeMax);
}
496+
497+
/// Creates a per-axis quantile quantized type through `Base::getChecked`,
/// passing `emitError` along as the diagnostic hook. The MLIR context is taken
/// from `storageType`.
QuantileQuantizedPerAxisType QuantileQuantizedPerAxisType::getChecked(
    function_ref<InFlightDiagnostic()> emitError, unsigned flags,
    Type storageType, Type expressedType, ArrayRef<double> quantiles,
    ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
    int32_t quantizedDimension, int64_t storageTypeMin,
    int64_t storageTypeMax) {
  auto *context = storageType.getContext();
  return Base::getChecked(emitError, context, flags, storageType,
                          expressedType, quantiles, scales, zeroPoints,
                          quantizedDimension, storageTypeMin, storageTypeMax);
}
508+
509+
/// Verifies the construction invariants of a per-axis quantile quantized
/// type.
///
/// The parameters shared with the uniform per-axis type are first checked by
/// UniformQuantizedPerAxisType::verifyInvariants. On top of that this verifier
/// requires that:
///   - `storageType` is an IntegerType, Float8E5M2Type or Float8E4M3FNType;
///   - `quantiles` holds exactly 2^bitwidth(storageType) entries;
///   - every quantile is finite (neither infinite nor NaN).
///
/// Returns failure() after emitting a diagnostic through `emitError` for the
/// first violated invariant, success() otherwise.
LogicalResult QuantileQuantizedPerAxisType::verifyInvariants(
    function_ref<InFlightDiagnostic()> emitError, unsigned flags,
    Type storageType, Type expressedType, ArrayRef<double> quantiles,
    ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
    int32_t quantizedDimension, int64_t storageTypeMin,
    int64_t storageTypeMax) {
  if (failed(UniformQuantizedPerAxisType::verifyInvariants(
          emitError, flags, storageType, expressedType, scales, zeroPoints,
          quantizedDimension, storageTypeMin, storageTypeMax)))
    return failure();

  unsigned typeWidth = 0;
  if (auto integerType = llvm::dyn_cast<IntegerType>(storageType)) {
    typeWidth = integerType.getWidth();
  } else if (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(storageType)) {
    // Both Float8E5M2Type and Float8E4M3FNType derive from FloatType.
    typeWidth = llvm::cast<FloatType>(storageType).getWidth();
  } else {
    return emitError() << "illegal storage type, supported types are: integral "
                          "types, Float8E4M3FNType and Float8E5M2Type ";
  }

  // The expected table size is computed in 64 bits: `1 << typeWidth` on a
  // plain `int` overflows (undefined behaviour) for widths of 31 bits and
  // above, e.g. an i32 storage type. Widths that cannot be shifted even in
  // 64 bits are rejected up front.
  constexpr unsigned kMaxSupportedWidth = 63;
  if (typeWidth > kMaxSupportedWidth) {
    return emitError() << "quantiles array size needs to be equal to "
                          "2^(bit_size(storageType)), but storage type width "
                       << typeWidth << " is not supported";
  }

  const uint64_t expectedSize = uint64_t{1} << typeWidth;
  const uint64_t quantileArraySize = quantiles.size();
  if (quantileArraySize != expectedSize) {
    return emitError() << "quantiles array size needs to be equal to "
                          "2^(bit_size(storageType)), expected: "
                       << expectedSize << ", found: " << quantileArraySize;
  }

  // Every entry of the look-up table must be a finite value.
  for (double quantile : quantiles) {
    if (!std::isfinite(quantile))
      return emitError() << "illegal quantile value: " << quantile;
  }

  return success();
}
550+
551+
/// Returns the quantile values held by this type's storage object.
ArrayRef<double> QuantileQuantizedPerAxisType::getQuantiles() const {
  auto *storage = getImpl();
  return storage->getQuantiles();
}
554+
423555
CalibratedQuantizedType CalibratedQuantizedType::get(Type expressedType,
424556
double min, double max) {
425557
return Base::get(expressedType.getContext(), expressedType, min, max);

0 commit comments

Comments
 (0)