Skip to content

Commit b3ccdd1

Browse files
authored
Extending QuantileQuantizedType with quantileType mlir::Type member (#60)
1 parent 8f27842 commit b3ccdd1

File tree

9 files changed

+365
-165
lines changed

9 files changed

+365
-165
lines changed

mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def UniformQuantizedPerAxisType: DialectType<(type
8484
def QuantileQuantizedType: DialectType<(type
8585
VarInt:$flags,
8686
Type:$storageType,
87+
Type:$quantileType,
8788
Type:$expressedType,
8889
Array<DoubleAPFloatList>:$quantiles,
8990
DoubleAPFloat:$scale,
@@ -95,6 +96,7 @@ def QuantileQuantizedType: DialectType<(type
9596
def QuantileQuantizedPerAxisType: DialectType<(type
9697
VarInt:$flags,
9798
Type:$storageType,
99+
Type:$quantileType,
98100
Type:$expressedType,
99101
VarInt:$quantizedDimension,
100102
SignedVarInt:$storageTypeMin,
@@ -105,7 +107,7 @@ def QuantileQuantizedPerAxisType: DialectType<(type
105107
)> {
106108
// Note: builder order differs from bytecode.
107109
let cBuilder = [{
108-
get<$_resultType>(context, flags, storageType, expressedType, quantiles, scales,
110+
get<$_resultType>(context, flags, storageType, quantileType, expressedType, quantiles, scales,
109111
zeroPoints, quantizedDimension, storageTypeMin, storageTypeMax)
110112
}];
111113
}

mlir/include/mlir/Dialect/Quant/QuantOpsBase.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,38 @@ def quant_UniformQuantizedType :
6767
CPred<"::llvm::isa<UniformQuantizedType>($_self)">,
6868
"UniformQuantizedType">;
6969

70+
// An implementation of UniformQuantizedPerAxisType.
71+
def quant_UniformQuantizedPerAxisType :
72+
DialectType<Quantization_Dialect,
73+
CPred<"::llvm::isa<::mlir::quant::UniformQuantizedPerAxisType>($_self)">,
74+
"UniformQuantizedPerAxisType">;
75+
7076
// An implementation of QuantileQuantizedType.
7177
def quant_QuantileQuantizedType :
7278
DialectType<Quantization_Dialect,
7379
CPred<"::llvm::isa<QuantileQuantizedType>($_self)">,
7480
"QuantileQuantizedType">;
7581

82+
// An implementation of QuantileQuantizedPerAxisType.
83+
def quant_QuantileQuantizedPerAxisType :
84+
DialectType<Quantization_Dialect,
85+
CPred<"::llvm::isa<QuantileQuantizedPerAxisType>($_self)">,
86+
"QuantileQuantizedPerAxisType">;
87+
7688
// Predicate for detecting a container or primitive of UniformQuantizedType.
7789
def quant_UniformQuantizedValueType :
7890
quant_TypedPrimitiveOrContainer<quant_UniformQuantizedType>;
7991

92+
// Predicate for detecting a container or primitive of UniformQuantizedPerAxisType.
93+
def quant_UniformQuantizedPerAxisValueType :
94+
quant_TypedPrimitiveOrContainer<quant_UniformQuantizedPerAxisType>;
95+
8096
// Predicate for detecting a container or primitive of QuantileQuantizedType.
8197
def quant_QuantileQuantizedValueType :
8298
quant_TypedPrimitiveOrContainer<quant_QuantileQuantizedType>;
8399

100+
// Predicate for detecting a container or primitive of QuantileQuantizedPerAxisType.
101+
def quant_QuantileQuantizedPerAxisValueType :
102+
quant_TypedPrimitiveOrContainer<quant_QuantileQuantizedPerAxisType>;
103+
84104
#endif // DIALECT_QUANT_QUANT_OPS_BASE_

mlir/include/mlir/Dialect/Quant/QuantTypes.h

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,8 @@ class UniformQuantizedType
296296
int64_t zeroPoint, int64_t storageTypeMin,
297297
int64_t storageTypeMax);
298298

299+
static bool classof(mlir::Type type);
300+
299301
/// Gets the scale term. The scale designates the difference between the real
300302
/// values corresponding to consecutive quantized values differing by 1.
301303
double getScale() const;
@@ -359,6 +361,8 @@ class UniformQuantizedPerAxisType
359361
int32_t quantizedDimension,
360362
int64_t storageTypeMin, int64_t storageTypeMax);
361363

364+
static bool classof(mlir::Type type);
365+
362366
/// Gets the quantization scales. The scales designate the difference between
363367
/// the real values corresponding to consecutive quantized values differing
364368
/// by 1. The ith scale corresponds to the ith slice in the
@@ -393,15 +397,17 @@ class UniformQuantizedPerAxisType
393397
};
394398

395399
/// QuantileQuantizedType derives from UniformQuantizedType and adds to it a
396-
/// look up table array of quantile values.
400+
/// lookup table array of quantile values. The element type of the lookup table is determined by
401+
/// the quantileType member. Supported quantile types are 'i'/'u' integers, hf8, bf8, f16, bf16, f32 and f64.
397402
///
398403
/// Syntax synopsis:
399404
/// Per-layer, all parameters expressed:
400-
/// !quant<quantile[StorageType:ExpressedType]{Quantiles}:{Scale:ZeroPoint}>
405+
/// !quant<quantile[StorageType:QuantileType:ExpressedType]{Quantiles}:{Scale:ZeroPoint}>
401406
/// Per-layer, optional parameters omitted:
402-
/// !quant<quantile[StorageType]{Quantiles}:{Scale}>
407+
/// !quant<quantile[StorageType:QuantileType]{Quantiles}:{Scale}>
403408
///
404409
/// StorageType: 'i'|'u' NumBits
410+
/// QuantileType: 'i'|'u' NumBits, 'hf8', 'bf8', 'f16', 'bf16', 'f32', 'f64'
405411
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
406412
/// Quantiles: Quantile+
407413
/// Quantile: A legal double value
@@ -419,23 +425,32 @@ class QuantileQuantizedType
419425
/// Gets an instance of the type with all parameters specified but not
420426
/// checked.
421427
static QuantileQuantizedType get(unsigned flags, Type storageType,
422-
Type expressedType,
428+
Type quantileType, Type expressedType,
423429
ArrayRef<double> quantiles, double scale,
424430
int64_t zeroPoint, int64_t storageTypeMin,
425431
int64_t storageTypeMax);
426432

427433
static QuantileQuantizedType
428434
getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
429-
Type storageType, Type expressedType, ArrayRef<double> quantiles,
430-
double scale, int64_t zeroPoint, int64_t storageTypeMin,
431-
int64_t storageTypeMax);
435+
Type storageType, Type quantileType, Type expressedType,
436+
ArrayRef<double> quantiles, double scale, int64_t zeroPoint,
437+
int64_t storageTypeMin, int64_t storageTypeMax);
432438

433439
/// Verifies construction invariants and issues errors/warnings.
434440
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
435441
unsigned flags, Type storageType,
436-
Type expressedType, ArrayRef<double> quantiles,
437-
double scale, int64_t zeroPoint,
438-
int64_t storageTypeMin, int64_t storageTypeMax);
442+
Type quantileType, Type expressedType,
443+
ArrayRef<double> quantiles, double scale,
444+
int64_t zeroPoint, int64_t storageTypeMin,
445+
int64_t storageTypeMax);
446+
447+
static bool classof(mlir::Type type);
448+
449+
/// Gets the quantileType
450+
Type getQuantileType() const;
451+
452+
/// Gets the bit width of the quantileType (integer or float).
453+
unsigned getQuantileTypeIntegralWidth() const;
439454

440455
/// Gets the quantile values
441456
ArrayRef<double> getQuantiles() const;
@@ -449,15 +464,17 @@ class QuantileQuantizedType
449464
};
450465

451466
/// Represents per-axis QuantileQuantizedType (also known as per-channel
452-
/// quantization).
467+
/// quantization). The element type of the quantile lookup table is determined by the
468+
/// quantileType member. Supported quantile types are 'i'/'u' integers, hf8, bf8, f16, bf16, f32 and f64.
453469
///
454470
/// Syntax synopsis:
455471
/// Per-axis, all parameters expressed:
456-
/// !quant<quantile[StorageType:ExpressedType:QuantizedDim]{Quantiles}:{QuantParams}>
472+
/// !quant<quantile[StorageType:QuantileType:ExpressedType:QuantizedDim]{Quantiles}:{QuantParams}>
457473
/// Per-axis, optional parameters omitted:
458-
/// !quant<quantile[StorageType]{Quantiles}:{Scale}>
474+
/// !quant<quantile[StorageType:QuantileType]{Quantiles}:{Scale}>
459475
///
460476
/// StorageType: 'i'|'u' NumBits
477+
/// QuantileType: 'i'|'u' NumBits, 'hf8', 'bf8', 'f16', 'bf16', 'f32', 'f64'
461478
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
462479
/// QuantizedDim: An integer value
463480
/// Quantiles: Quantile+
@@ -478,7 +495,7 @@ class QuantileQuantizedPerAxisType
478495
/// Gets an instance of the type with all parameters specified but not
479496
/// checked.
480497
static QuantileQuantizedPerAxisType
481-
get(unsigned flags, Type storageType, Type expressedType,
498+
get(unsigned flags, Type storageType, Type quantileType, Type expressedType,
482499
ArrayRef<double> quantiles, ArrayRef<double> scales,
483500
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
484501
int64_t storageTypeMin, int64_t storageTypeMax);
@@ -487,19 +504,26 @@ class QuantileQuantizedPerAxisType
487504
/// Returns a nullptr convertible type on failure.
488505
static QuantileQuantizedPerAxisType
489506
getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
490-
Type storageType, Type expressedType, ArrayRef<double> quantiles,
491-
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
492-
int32_t quantizedDimension, int64_t storageTypeMin,
493-
int64_t storageTypeMax);
507+
Type storageType, Type quantileType, Type expressedType,
508+
ArrayRef<double> quantiles, ArrayRef<double> scales,
509+
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
510+
int64_t storageTypeMin, int64_t storageTypeMax);
494511

495512
/// Verifies construction invariants and issues errors/warnings.
496-
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
497-
unsigned flags, Type storageType,
498-
Type expressedType, ArrayRef<double> quantiles,
499-
ArrayRef<double> scales,
500-
ArrayRef<int64_t> zeroPoints,
501-
int32_t quantizedDimension,
502-
int64_t storageTypeMin, int64_t storageTypeMax);
513+
static LogicalResult
514+
verify(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
515+
Type storageType, Type quantileType, Type expressedType,
516+
ArrayRef<double> quantiles, ArrayRef<double> scales,
517+
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
518+
int64_t storageTypeMin, int64_t storageTypeMax);
519+
520+
static bool classof(mlir::Type type);
521+
522+
/// Gets the quantileType
523+
Type getQuantileType() const;
524+
525+
/// Gets the bit width of the quantileType (integer or float).
526+
unsigned getQuantileTypeIntegralWidth() const;
503527

504528
/// Gets the quantile values
505529
ArrayRef<double> getQuantiles() const;

mlir/lib/Dialect/Quant/IR/QuantTypes.cpp

Lines changed: 75 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,11 @@ LogicalResult UniformQuantizedType::verify(
305305
return success();
306306
}
307307

308+
bool UniformQuantizedType::classof(mlir::Type type) {
309+
return type.getTypeID() == mlir::TypeID::get<UniformQuantizedType>() ||
310+
type.getTypeID() == mlir::TypeID::get<QuantileQuantizedType>();
311+
}
312+
308313
double UniformQuantizedType::getScale() const { return getImpl()->scale; }
309314

310315
int64_t UniformQuantizedType::getZeroPoint() const {
@@ -366,6 +371,11 @@ LogicalResult UniformQuantizedPerAxisType::verify(
366371
return success();
367372
}
368373

374+
bool UniformQuantizedPerAxisType::classof(mlir::Type type) {
375+
return type.getTypeID() == mlir::TypeID::get<UniformQuantizedPerAxisType>() ||
376+
type.getTypeID() == mlir::TypeID::get<QuantileQuantizedPerAxisType>();
377+
}
378+
369379
ArrayRef<double> UniformQuantizedPerAxisType::getScales() const {
370380
return getImpl()->getScales();
371381
}
@@ -379,36 +389,35 @@ int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
379389
}
380390

381391
QuantileQuantizedType
382-
QuantileQuantizedType::get(unsigned flags, Type storageType, Type expressedType,
383-
ArrayRef<double> quantiles, double scale,
384-
int64_t zeroPoint, int64_t storageTypeMin,
385-
int64_t storageTypeMax) {
386-
return Base::get(storageType.getContext(), flags, storageType, expressedType,
387-
quantiles, scale, zeroPoint, storageTypeMin, storageTypeMax);
392+
QuantileQuantizedType::get(unsigned flags, Type storageType, Type quantileType,
393+
Type expressedType, ArrayRef<double> quantiles,
394+
double scale, int64_t zeroPoint,
395+
int64_t storageTypeMin, int64_t storageTypeMax) {
396+
return Base::get(storageType.getContext(), flags, storageType, quantileType,
397+
expressedType, quantiles, scale, zeroPoint, storageTypeMin,
398+
storageTypeMax);
388399
}
389400

390401
QuantileQuantizedType QuantileQuantizedType::getChecked(
391402
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
392-
Type storageType, Type expressedType, ArrayRef<double> quantiles,
393-
double scale, int64_t zeroPoint, int64_t storageTypeMin,
394-
int64_t storageTypeMax) {
403+
Type storageType, Type quantileType, Type expressedType,
404+
ArrayRef<double> quantiles, double scale, int64_t zeroPoint,
405+
int64_t storageTypeMin, int64_t storageTypeMax) {
395406
return Base::getChecked(emitError, storageType.getContext(), flags,
396-
storageType, expressedType, quantiles, scale,
397-
zeroPoint, storageTypeMin, storageTypeMax);
407+
storageType, quantileType, expressedType, quantiles,
408+
scale, zeroPoint, storageTypeMin, storageTypeMax);
398409
}
399-
LogicalResult
400-
QuantileQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
401-
unsigned flags, Type storageType,
402-
Type expressedType, ArrayRef<double> quantiles,
403-
double scale, int64_t zeroPoint,
404-
int64_t storageTypeMin, int64_t storageTypeMax) {
410+
LogicalResult QuantileQuantizedType::verify(
411+
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
412+
Type storageType, Type quantileType, Type expressedType,
413+
ArrayRef<double> quantiles, double scale, int64_t zeroPoint,
414+
int64_t storageTypeMin, int64_t storageTypeMax) {
405415
if (failed(UniformQuantizedType::verify(emitError, flags, storageType,
406416
expressedType, scale, zeroPoint,
407417
storageTypeMin, storageTypeMax))) {
408418
return failure();
409419
}
410420

411-
const auto quantileArraySize = quantiles.size();
412421
unsigned typeWidth{};
413422
if (storageType.isa<IntegerType>()) {
414423
typeWidth = llvm::dyn_cast<IntegerType>(storageType).getWidth();
@@ -421,10 +430,17 @@ QuantileQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
421430
"types, Float8E4M3FNType and Float8E5M2Type ";
422431
}
423432

424-
const size_t expectedSize = 1 << typeWidth;
433+
const size_t storageTypeRange = storageTypeMax - storageTypeMin + 1;
434+
const size_t typeWidthSize = 1 << typeWidth;
435+
const size_t expectedSize =
436+
(storageTypeRange < typeWidthSize) ? storageTypeRange : typeWidthSize;
437+
438+
const auto quantileArraySize = quantiles.size();
425439
if (quantileArraySize != expectedSize) {
426440
return emitError() << "quantiles array size needs to be equal to "
427-
"2^(bit_size(storageType)), expected: "
441+
"2^(bit_size(storageType)), or (storageTypeMax - "
442+
"storageTypeMin + 1) when max and min differ from "
443+
"the type limits; expected: "
428444
<< expectedSize << ", found: " << quantileArraySize;
429445
}
430446

@@ -438,38 +454,50 @@ QuantileQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
438454
return success();
439455
}
440456

457+
bool QuantileQuantizedType::classof(mlir::Type type) {
458+
return type.getTypeID() == mlir::TypeID::get<QuantileQuantizedType>();
459+
}
460+
461+
Type QuantileQuantizedType::getQuantileType() const {
462+
return getImpl()->quantileType;
463+
}
464+
465+
unsigned QuantileQuantizedType::getQuantileTypeIntegralWidth() const {
466+
return getImpl()->getQuantileType().getIntOrFloatBitWidth();
467+
}
468+
441469
ArrayRef<double> QuantileQuantizedType::getQuantiles() const {
442470
return getImpl()->getQuantiles();
443471
}
444472

445473
QuantileQuantizedPerAxisType QuantileQuantizedPerAxisType::get(
446-
unsigned flags, Type storageType, Type expressedType,
474+
unsigned flags, Type storageType, Type quantileType, Type expressedType,
447475
ArrayRef<double> quantiles, ArrayRef<double> scales,
448476
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
449477
int64_t storageTypeMin, int64_t storageTypeMax) {
450-
return Base::get(storageType.getContext(), flags, storageType, expressedType,
451-
quantiles, scales, zeroPoints, quantizedDimension,
452-
storageTypeMin, storageTypeMax);
478+
return Base::get(storageType.getContext(), flags, storageType, quantileType,
479+
expressedType, quantiles, scales, zeroPoints,
480+
quantizedDimension, storageTypeMin, storageTypeMax);
453481
}
454482

455483
QuantileQuantizedPerAxisType QuantileQuantizedPerAxisType::getChecked(
456484
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
457-
Type storageType, Type expressedType, ArrayRef<double> quantiles,
458-
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
459-
int32_t quantizedDimension, int64_t storageTypeMin,
460-
int64_t storageTypeMax) {
485+
Type storageType, Type quantileType, Type expressedType,
486+
ArrayRef<double> quantiles, ArrayRef<double> scales,
487+
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
488+
int64_t storageTypeMin, int64_t storageTypeMax) {
461489
return Base::getChecked(emitError, storageType.getContext(), flags,
462-
storageType, expressedType, quantiles, scales,
463-
zeroPoints, quantizedDimension, storageTypeMin,
464-
storageTypeMax);
490+
storageType, quantileType, expressedType, quantiles,
491+
scales, zeroPoints, quantizedDimension,
492+
storageTypeMin, storageTypeMax);
465493
}
466494

467495
LogicalResult QuantileQuantizedPerAxisType::verify(
468496
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
469-
Type storageType, Type expressedType, ArrayRef<double> quantiles,
470-
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
471-
int32_t quantizedDimension, int64_t storageTypeMin,
472-
int64_t storageTypeMax) {
497+
Type storageType, Type quantileType, Type expressedType,
498+
ArrayRef<double> quantiles, ArrayRef<double> scales,
499+
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
500+
int64_t storageTypeMin, int64_t storageTypeMax) {
473501
if (failed(UniformQuantizedPerAxisType::verify(
474502
emitError, flags, storageType, expressedType, scales, zeroPoints,
475503
quantizedDimension, storageTypeMin, storageTypeMax))) {
@@ -506,6 +534,18 @@ LogicalResult QuantileQuantizedPerAxisType::verify(
506534
return success();
507535
}
508536

537+
bool QuantileQuantizedPerAxisType::classof(mlir::Type type) {
538+
return type.getTypeID() == mlir::TypeID::get<QuantileQuantizedPerAxisType>();
539+
}
540+
541+
Type QuantileQuantizedPerAxisType::getQuantileType() const {
542+
return getImpl()->quantileType;
543+
}
544+
545+
unsigned QuantileQuantizedPerAxisType::getQuantileTypeIntegralWidth() const {
546+
return getImpl()->getQuantileType().getIntOrFloatBitWidth();
547+
}
548+
509549
ArrayRef<double> QuantileQuantizedPerAxisType::getQuantiles() const {
510550
return getImpl()->getQuantiles();
511551
}

0 commit comments

Comments
 (0)