Skip to content

Commit 20c3bab

Browse files
authored
Extend Quant dialect with Quantile Quantization type (#53)
* Expanding Quant dialect with Quantile Quantized type * Adding quantile mlir tests * Adding check on quantiles array size and updated mlir tests
1 parent 7892f00 commit 20c3bab

File tree

10 files changed

+925
-16
lines changed

10 files changed

+925
-16
lines changed

mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,35 @@ def UniformQuantizedPerAxisType: DialectType<(type
8181
}];
8282
}
8383

84+
def QuantileQuantizedType: DialectType<(type
85+
VarInt:$flags,
86+
Type:$storageType,
87+
Type:$expressedType,
88+
Array<DoubleAPFloatList>:$quantiles,
89+
DoubleAPFloat:$scale,
90+
SignedVarInt:$zeroPoint,
91+
SignedVarInt:$storageTypeMin,
92+
SignedVarInt:$storageTypeMax
93+
)>;
94+
95+
def QuantileQuantizedPerAxisType: DialectType<(type
96+
VarInt:$flags,
97+
Type:$storageType,
98+
Type:$expressedType,
99+
VarInt:$quantizedDimension,
100+
SignedVarInt:$storageTypeMin,
101+
SignedVarInt:$storageTypeMax,
102+
Array<DoubleAPFloatList>:$quantiles,
103+
Array<DoubleAPFloatList>:$scales,
104+
Array<SignedVarIntList>:$zeroPoints
105+
)> {
106+
// Note: builder order differs from bytecode.
107+
let cBuilder = [{
108+
get<$_resultType>(context, flags, storageType, expressedType, quantiles, scales,
109+
zeroPoints, quantizedDimension, storageTypeMin, storageTypeMax)
110+
}];
111+
}
112+
84113
/// This enum contains marker codes used to indicate which attribute is
85114
/// currently being decoded, and how it should be decoded. The order of these
86115
/// codes should generally be unchanged, as any changes will inevitably break
@@ -93,7 +122,9 @@ def QuantDialectTypes : DialectTypes<"Quant"> {
93122
AnyQuantizedTypeWithExpressedType,
94123
CalibratedQuantizedType,
95124
UniformQuantizedType,
96-
UniformQuantizedPerAxisType
125+
UniformQuantizedPerAxisType,
126+
QuantileQuantizedType,
127+
QuantileQuantizedPerAxisType
97128
];
98129
}
99130

mlir/include/mlir/Dialect/Quant/QuantOpsBase.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,18 @@ def quant_UniformQuantizedType :
6767
CPred<"::llvm::isa<UniformQuantizedType>($_self)">,
6868
"UniformQuantizedType">;
6969

70+
// An implementation of QuantileQuantizedType.
71+
def quant_QuantileQuantizedType :
72+
DialectType<Quantization_Dialect,
73+
CPred<"::llvm::isa<QuantileQuantizedType>($_self)">,
74+
"QuantileQuantizedType">;
75+
7076
// Predicate for detecting a container or primitive of UniformQuantizedType.
7177
def quant_UniformQuantizedValueType :
7278
quant_TypedPrimitiveOrContainer<quant_UniformQuantizedType>;
7379

80+
// Predicate for detecting a container or primitive of QuantileQuantizedType.
81+
def quant_QuantileQuantizedValueType :
82+
quant_TypedPrimitiveOrContainer<quant_QuantileQuantizedType>;
83+
7484
#endif // DIALECT_QUANT_QUANT_OPS_BASE_

mlir/include/mlir/Dialect/Quant/QuantTypes.h

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ struct QuantizedTypeStorage;
2525
struct AnyQuantizedTypeStorage;
2626
struct UniformQuantizedTypeStorage;
2727
struct UniformQuantizedPerAxisTypeStorage;
28+
struct QuantileQuantizedTypeStorage;
29+
struct QuantileQuantizedPerAxisTypeStorage;
2830
struct CalibratedQuantizedTypeStorage;
2931

3032
} // namespace detail
@@ -390,6 +392,128 @@ class UniformQuantizedPerAxisType
390392
}
391393
};
392394

395+
/// QuantileQuantizedType derives from UniformQuantizedType and adds to it a
396+
/// look up table array of quantile values.
397+
///
398+
/// Syntax synopsis:
399+
/// Per-layer, all parameters expressed:
400+
/// !quant<quantile[StorageType:ExpressedType]{Quantiles}:{Scale:ZeroPoint}>
401+
/// Per-layer, optional parameters omitted:
402+
/// !quant<quantile[StorageType]{Quantiles}:{Scale}>
403+
///
404+
/// StorageType: 'i'|'u' NumBits
405+
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
406+
/// Quantiles: Quantile+
407+
/// Quantile: A legal double value
408+
/// Scale: A legal double value
409+
/// ZeroPoint: An integer value
410+
class QuantileQuantizedType
411+
: public Type::TypeBase<QuantileQuantizedType, UniformQuantizedType,
412+
detail::QuantileQuantizedTypeStorage> {
413+
public:
414+
using Base::Base;
415+
using Base::getChecked;
416+
417+
static constexpr StringLiteral name = "quant.quantile";
418+
419+
/// Gets an instance of the type with all parameters specified but not
420+
/// checked.
421+
static QuantileQuantizedType get(unsigned flags, Type storageType,
422+
Type expressedType,
423+
ArrayRef<double> quantiles, double scale,
424+
int64_t zeroPoint, int64_t storageTypeMin,
425+
int64_t storageTypeMax);
426+
427+
static QuantileQuantizedType
428+
getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
429+
Type storageType, Type expressedType, ArrayRef<double> quantiles,
430+
double scale, int64_t zeroPoint, int64_t storageTypeMin,
431+
int64_t storageTypeMax);
432+
433+
/// Verifies construction invariants and issues errors/warnings.
434+
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
435+
unsigned flags, Type storageType,
436+
Type expressedType, ArrayRef<double> quantiles,
437+
double scale, int64_t zeroPoint,
438+
int64_t storageTypeMin, int64_t storageTypeMax);
439+
440+
/// Gets the quantile values
441+
ArrayRef<double> getQuantiles() const;
442+
443+
// Fixed point values are real numbers divided by a scale.
444+
// Currently, only signed storage types are treated as fixed point.
445+
// A fixed point value can be obtained from an affine value by subtracting
446+
// the zeroPoint.
447+
// In the future, this may be explicit versus implied by type and zeroPoint.
448+
bool isFixedPoint() const { return isSigned() && getZeroPoint() == 0; }
449+
};
450+
451+
/// Represents per-axis QuantileQuantizedType (also known as per-channel
452+
/// quantization).
453+
///
454+
/// Syntax synopsis:
455+
/// Per-axis, all parameters expressed:
456+
/// !quant<quantile[StorageType:ExpressedType:QuantizedDim]{Quantiles}:{QuantParams}>
457+
/// Per-axis, optional parameters omitted:
458+
/// !quant<quantile[StorageType]{Quantiles}:{Scale}>
459+
///
460+
/// StorageType: 'i'|'u' NumBits
461+
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
462+
/// QuantizedDim: An integer value
463+
/// Quantiles: Quantile+
464+
/// Quantile: A legal double value
465+
/// QuantParams: (Scale ':' ZeroPoint)+
466+
/// Scale: A legal double value
467+
/// ZeroPoint: An integer value
468+
class QuantileQuantizedPerAxisType
469+
: public Type::TypeBase<QuantileQuantizedPerAxisType,
470+
UniformQuantizedPerAxisType,
471+
detail::QuantileQuantizedPerAxisTypeStorage> {
472+
public:
473+
using Base::Base;
474+
using Base::getChecked;
475+
476+
static constexpr StringLiteral name = "quant.quantile_per_axis";
477+
478+
/// Gets an instance of the type with all parameters specified but not
479+
/// checked.
480+
static QuantileQuantizedPerAxisType
481+
get(unsigned flags, Type storageType, Type expressedType,
482+
ArrayRef<double> quantiles, ArrayRef<double> scales,
483+
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
484+
int64_t storageTypeMin, int64_t storageTypeMax);
485+
486+
/// Gets an instance of the type with all specified parameters checked.
487+
/// Returns a nullptr convertible type on failure.
488+
static QuantileQuantizedPerAxisType
489+
getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
490+
Type storageType, Type expressedType, ArrayRef<double> quantiles,
491+
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
492+
int32_t quantizedDimension, int64_t storageTypeMin,
493+
int64_t storageTypeMax);
494+
495+
/// Verifies construction invariants and issues errors/warnings.
496+
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
497+
unsigned flags, Type storageType,
498+
Type expressedType, ArrayRef<double> quantiles,
499+
ArrayRef<double> scales,
500+
ArrayRef<int64_t> zeroPoints,
501+
int32_t quantizedDimension,
502+
int64_t storageTypeMin, int64_t storageTypeMax);
503+
504+
/// Gets the quantile values
505+
ArrayRef<double> getQuantiles() const;
506+
507+
/// Fixed point values are real numbers divided by a scale.
508+
/// Currently, only signed storage types are treated as fixed point.
509+
/// A fixed point value can be obtained from an affine value by subtracting
510+
/// the zeroPoint.
511+
/// In the future, this may be explicit versus implied by type and zeroPoint.
512+
bool isFixedPoint() const {
513+
return isSigned() && !llvm::is_contained(getZeroPoints(), 0);
514+
}
515+
};
516+
393517
/// A quantized type that infers its range from given min/max values.
394518
///
395519
/// Typical syntax:

mlir/lib/Dialect/Quant/IR/QuantOps.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ using namespace mlir::quant::detail;
2828

2929
void QuantizationDialect::initialize() {
3030
addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
31-
UniformQuantizedPerAxisType>();
31+
UniformQuantizedPerAxisType, QuantileQuantizedType,
32+
QuantileQuantizedPerAxisType>();
3233
addOperations<
3334
#define GET_OP_LIST
3435
#include "mlir/Dialect/Quant/QuantOps.cpp.inc"

mlir/lib/Dialect/Quant/IR/QuantTypes.cpp

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,138 @@ int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
378378
return getImpl()->quantizedDimension;
379379
}
380380

381+
QuantileQuantizedType
382+
QuantileQuantizedType::get(unsigned flags, Type storageType, Type expressedType,
383+
ArrayRef<double> quantiles, double scale,
384+
int64_t zeroPoint, int64_t storageTypeMin,
385+
int64_t storageTypeMax) {
386+
return Base::get(storageType.getContext(), flags, storageType, expressedType,
387+
quantiles, scale, zeroPoint, storageTypeMin, storageTypeMax);
388+
}
389+
390+
QuantileQuantizedType QuantileQuantizedType::getChecked(
391+
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
392+
Type storageType, Type expressedType, ArrayRef<double> quantiles,
393+
double scale, int64_t zeroPoint, int64_t storageTypeMin,
394+
int64_t storageTypeMax) {
395+
return Base::getChecked(emitError, storageType.getContext(), flags,
396+
storageType, expressedType, quantiles, scale,
397+
zeroPoint, storageTypeMin, storageTypeMax);
398+
}
399+
LogicalResult
400+
QuantileQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
401+
unsigned flags, Type storageType,
402+
Type expressedType, ArrayRef<double> quantiles,
403+
double scale, int64_t zeroPoint,
404+
int64_t storageTypeMin, int64_t storageTypeMax) {
405+
if (failed(UniformQuantizedType::verify(emitError, flags, storageType,
406+
expressedType, scale, zeroPoint,
407+
storageTypeMin, storageTypeMax))) {
408+
return failure();
409+
}
410+
411+
const auto quantileArraySize = quantiles.size();
412+
unsigned typeWidth{};
413+
if (storageType.isa<IntegerType>()) {
414+
typeWidth = llvm::dyn_cast<IntegerType>(storageType).getWidth();
415+
} else if (storageType.isa<Float8E5M2Type>() ||
416+
storageType.isa<Float8E4M3FNType>()) {
417+
// Both Float8E5M2Type and Float8E4M3FNType derive from FloatType.
418+
typeWidth = llvm::dyn_cast<FloatType>(storageType).getWidth();
419+
} else {
420+
return emitError() << "illegal storage type, supported types are: integral "
421+
"types, Float8E4M3FNType and Float8E5M2Type ";
422+
}
423+
424+
const size_t expectedSize = 1 << typeWidth;
425+
if (quantileArraySize != expectedSize) {
426+
return emitError() << "quantiles array size needs to be equal to "
427+
"2^(bit_size(storageType)), expected: "
428+
<< expectedSize << ", found: " << quantileArraySize;
429+
}
430+
431+
// Verify quantiles
432+
for (double quantile : quantiles) {
433+
if (std::isinf(quantile) || std::isnan(quantile)) {
434+
return emitError() << "illegal quantile value: " << quantile;
435+
}
436+
}
437+
438+
return success();
439+
}
440+
441+
ArrayRef<double> QuantileQuantizedType::getQuantiles() const {
442+
return getImpl()->getQuantiles();
443+
}
444+
445+
QuantileQuantizedPerAxisType QuantileQuantizedPerAxisType::get(
446+
unsigned flags, Type storageType, Type expressedType,
447+
ArrayRef<double> quantiles, ArrayRef<double> scales,
448+
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
449+
int64_t storageTypeMin, int64_t storageTypeMax) {
450+
return Base::get(storageType.getContext(), flags, storageType, expressedType,
451+
quantiles, scales, zeroPoints, quantizedDimension,
452+
storageTypeMin, storageTypeMax);
453+
}
454+
455+
QuantileQuantizedPerAxisType QuantileQuantizedPerAxisType::getChecked(
456+
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
457+
Type storageType, Type expressedType, ArrayRef<double> quantiles,
458+
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
459+
int32_t quantizedDimension, int64_t storageTypeMin,
460+
int64_t storageTypeMax) {
461+
return Base::getChecked(emitError, storageType.getContext(), flags,
462+
storageType, expressedType, quantiles, scales,
463+
zeroPoints, quantizedDimension, storageTypeMin,
464+
storageTypeMax);
465+
}
466+
467+
LogicalResult QuantileQuantizedPerAxisType::verify(
468+
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
469+
Type storageType, Type expressedType, ArrayRef<double> quantiles,
470+
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
471+
int32_t quantizedDimension, int64_t storageTypeMin,
472+
int64_t storageTypeMax) {
473+
if (failed(UniformQuantizedPerAxisType::verify(
474+
emitError, flags, storageType, expressedType, scales, zeroPoints,
475+
quantizedDimension, storageTypeMin, storageTypeMax))) {
476+
return failure();
477+
}
478+
479+
const auto quantileArraySize = quantiles.size();
480+
unsigned typeWidth{};
481+
if (storageType.isa<IntegerType>()) {
482+
typeWidth = llvm::dyn_cast<IntegerType>(storageType).getWidth();
483+
} else if (storageType.isa<Float8E5M2Type>() ||
484+
storageType.isa<Float8E4M3FNType>()) {
485+
// Both Float8E5M2Type and Float8E4M3FNType derive from FloatType.
486+
typeWidth = llvm::dyn_cast<FloatType>(storageType).getWidth();
487+
} else {
488+
return emitError() << "illegal storage type, supported types are: integral "
489+
"types, Float8E4M3FNType and Float8E5M2Type ";
490+
}
491+
492+
const size_t expectedSize = 1 << typeWidth;
493+
if (quantileArraySize != expectedSize) {
494+
return emitError() << "quantiles array size needs to be equal to "
495+
"2^(bit_size(storageType)), expected: "
496+
<< expectedSize << ", found: " << quantileArraySize;
497+
}
498+
499+
// Verify quantiles
500+
for (double quantile : quantiles) {
501+
if (std::isinf(quantile) || std::isnan(quantile)) {
502+
return emitError() << "illegal quantile value: " << quantile;
503+
}
504+
}
505+
506+
return success();
507+
}
508+
509+
ArrayRef<double> QuantileQuantizedPerAxisType::getQuantiles() const {
510+
return getImpl()->getQuantiles();
511+
}
512+
381513
CalibratedQuantizedType CalibratedQuantizedType::get(Type expressedType,
382514
double min, double max) {
383515
return Base::get(expressedType.getContext(), expressedType, min, max);

0 commit comments

Comments
 (0)