Skip to content

Commit 7f86fdf

Browse files
sartilsramasit
authored andcommitted
Extending QuantileQuantizedType with quantileType mlir::Type member (#60)
1 parent be1e643 commit 7f86fdf

File tree

9 files changed

+365
-165
lines changed

9 files changed

+365
-165
lines changed

mlir/include/mlir/Dialect/Quant/IR/QuantBase.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,26 @@ def quant_QuantileQuantizedType :
235235
def quant_QuantileQuantizedValueType :
236236
quant_ScalarOrTensorOf<quant_QuantileQuantizedType>;
237237

238+
// UniformQuantizedPerAxisType
239+
def quant_UniformQuantizedPerAxisType :
240+
DialectType<Quant_Dialect,
241+
CPred<"::llvm::isa<::mlir::quant::UniformQuantizedPerAxisType>($_self)">,
242+
"UniformQuantizedPerAxisType">;
243+
244+
// QuantileQuantizedPerAxisType
245+
def quant_QuantileQuantizedPerAxisType :
246+
DialectType<Quant_Dialect,
247+
CPred<"::llvm::isa<::mlir::quant::QuantileQuantizedPerAxisType>($_self)">,
248+
"QuantileQuantizedPerAxisType">;
249+
250+
// Predicate for detecting a container or primitive of UniformQuantizedPerAxisType.
251+
def quant_UniformQuantizedPerAxisValueType :
252+
quant_ScalarOrTensorOf<quant_UniformQuantizedPerAxisType>;
253+
254+
// Predicate for detecting a container or primitive of QuantileQuantizedPerAxisType.
255+
def quant_QuantileQuantizedPerAxisValueType :
256+
quant_ScalarOrTensorOf<quant_QuantileQuantizedPerAxisType>;
257+
238258
def quant_ScalarType :
239259
Type<Or<[
240260
AnySignlessInteger.predicate,

mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def UniformQuantizedPerAxisType: DialectType<(type
8484
def QuantileQuantizedType: DialectType<(type
8585
VarInt:$flags,
8686
Type:$storageType,
87+
Type:$quantileType,
8788
Type:$expressedType,
8889
Array<DoubleAPFloatList>:$quantiles,
8990
DoubleAPFloat:$scale,
@@ -95,6 +96,7 @@ def QuantileQuantizedType: DialectType<(type
9596
def QuantileQuantizedPerAxisType: DialectType<(type
9697
VarInt:$flags,
9798
Type:$storageType,
99+
Type:$quantileType,
98100
Type:$expressedType,
99101
VarInt:$quantizedDimension,
100102
SignedVarInt:$storageTypeMin,
@@ -105,7 +107,7 @@ def QuantileQuantizedPerAxisType: DialectType<(type
105107
)> {
106108
// Note: builder order differs from bytecode.
107109
let cBuilder = [{
108-
get<$_resultType>(context, flags, storageType, expressedType, quantiles, scales,
110+
get<$_resultType>(context, flags, storageType, quantileType, expressedType, quantiles, scales,
109111
zeroPoints, quantizedDimension, storageTypeMin, storageTypeMax)
110112
}];
111113
}

mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,8 @@ class UniformQuantizedType
300300
int64_t zeroPoint, int64_t storageTypeMin,
301301
int64_t storageTypeMax);
302302

303+
static bool classof(mlir::Type type);
304+
303305
/// Gets the scale term. The scale designates the difference between the real
304306
/// values corresponding to consecutive quantized values differing by 1.
305307
double getScale() const;
@@ -363,6 +365,8 @@ class UniformQuantizedPerAxisType
363365
int32_t quantizedDimension, int64_t storageTypeMin,
364366
int64_t storageTypeMax);
365367

368+
static bool classof(mlir::Type type);
369+
366370
/// Gets the quantization scales. The scales designate the difference between
367371
/// the real values corresponding to consecutive quantized values differing
368372
/// by 1. The ith scale corresponds to the ith slice in the
@@ -397,15 +401,17 @@ class UniformQuantizedPerAxisType
397401
};
398402

399403
/// QuantileQuantizedType derives from UniformQuantizedType and adds to it a
400-
/// look up table array of quantile values.
404+
/// look up table array of quantile values. The type of the data in the look up table is determined by
405+
/// the quantileType member: supported quantileType types are integer/unsigned/hf8/bf8/f16/bf16/f32/f64.
401406
///
402407
/// Syntax synopsis:
403408
/// Per-layer, all parameters expressed:
404-
/// !quant<quantile[StorageType:ExpressedType]{Quantiles}:{Scale:ZeroPoint}>
409+
/// !quant<quantile[StorageType:QuantileType:ExpressedType]{Quantiles}:{Scale:ZeroPoint}>
405410
/// Per-layer, optional parameters omitted:
406-
/// !quant<quantile[StorageType]{Quantiles}:{Scale}>
411+
/// !quant<quantile[StorageType:QuantileType]{Quantiles}:{Scale}>
407412
///
408413
/// StorageType: 'i'|'u' NumBits
414+
/// QuantileType: 'i'|'u' NumBits, 'hf8', 'bf8', 'f16', 'bf16', 'f32', 'f64'
409415
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
410416
/// Quantiles: Quantile+
411417
/// Quantile: A legal double value
@@ -423,23 +429,32 @@ class QuantileQuantizedType
423429
/// Gets an instance of the type with all parameters specified but not
424430
/// checked.
425431
static QuantileQuantizedType get(unsigned flags, Type storageType,
426-
Type expressedType,
432+
Type quantileType, Type expressedType,
427433
ArrayRef<double> quantiles, double scale,
428434
int64_t zeroPoint, int64_t storageTypeMin,
429435
int64_t storageTypeMax);
430436

431437
static QuantileQuantizedType
432438
getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
433-
Type storageType, Type expressedType, ArrayRef<double> quantiles,
434-
double scale, int64_t zeroPoint, int64_t storageTypeMin,
435-
int64_t storageTypeMax);
439+
Type storageType, Type quantileType, Type expressedType,
440+
ArrayRef<double> quantiles, double scale, int64_t zeroPoint,
441+
int64_t storageTypeMin, int64_t storageTypeMax);
436442

437443
/// Verifies construction invariants and issues errors/warnings.
438444
static LogicalResult verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
439445
unsigned flags, Type storageType,
440-
Type expressedType, ArrayRef<double> quantiles,
441-
double scale, int64_t zeroPoint,
442-
int64_t storageTypeMin, int64_t storageTypeMax);
446+
Type quantileType, Type expressedType,
447+
ArrayRef<double> quantiles, double scale,
448+
int64_t zeroPoint, int64_t storageTypeMin,
449+
int64_t storageTypeMax);
450+
451+
static bool classof(mlir::Type type);
452+
453+
/// Gets the quantileType
454+
Type getQuantileType() const;
455+
456+
/// Gets the quantileType bit width
457+
unsigned getQuantileTypeIntegralWidth() const;
443458

444459
/// Gets the quantile values
445460
ArrayRef<double> getQuantiles() const;
@@ -453,15 +468,17 @@ class QuantileQuantizedType
453468
};
454469

455470
/// Represents per-axis QuantileQuantizedType (also known as per-channel
456-
/// quantization).
471+
/// quantization). The type of the data in the look up table is determined by the
472+
/// quantileType member: supported quantileType types are integer/unsigned/hf8/bf8/f16/bf16/f32/f64.
457473
///
458474
/// Syntax synopsis:
459475
/// Per-axis, all parameters expressed:
460-
/// !quant<quantile[StorageType:ExpressedType:QuantizedDim]{Quantiles}:{QuantParams}>
476+
/// !quant<quantile[StorageType:QuantileType:ExpressedType:QuantizedDim]{Quantiles}:{QuantParams}>
461477
/// Per-axis, optional parameters omitted:
462-
/// !quant<quantile[StorageType]{Quantiles}:{Scale}>
478+
/// !quant<quantile[StorageType:QuantileType]{Quantiles}:{Scale}>
463479
///
464480
/// StorageType: 'i'|'u' NumBits
481+
/// QuantileType: 'i'|'u' NumBits, 'hf8', 'bf8', 'f16', 'bf16', 'f32', 'f64'
465482
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
466483
/// QuantizedDim: An integer value
467484
/// Quantiles: Quantile+
@@ -482,7 +499,7 @@ class QuantileQuantizedPerAxisType
482499
/// Gets an instance of the type with all parameters specified but not
483500
/// checked.
484501
static QuantileQuantizedPerAxisType
485-
get(unsigned flags, Type storageType, Type expressedType,
502+
get(unsigned flags, Type storageType, Type quantileType, Type expressedType,
486503
ArrayRef<double> quantiles, ArrayRef<double> scales,
487504
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
488505
int64_t storageTypeMin, int64_t storageTypeMax);
@@ -491,19 +508,26 @@ class QuantileQuantizedPerAxisType
491508
/// Returns a nullptr convertible type on failure.
492509
static QuantileQuantizedPerAxisType
493510
getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
494-
Type storageType, Type expressedType, ArrayRef<double> quantiles,
495-
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
496-
int32_t quantizedDimension, int64_t storageTypeMin,
497-
int64_t storageTypeMax);
511+
Type storageType, Type quantileType, Type expressedType,
512+
ArrayRef<double> quantiles, ArrayRef<double> scales,
513+
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
514+
int64_t storageTypeMin, int64_t storageTypeMax);
498515

499516
/// Verifies construction invariants and issues errors/warnings.
500-
static LogicalResult verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
501-
unsigned flags, Type storageType,
502-
Type expressedType, ArrayRef<double> quantiles,
503-
ArrayRef<double> scales,
504-
ArrayRef<int64_t> zeroPoints,
505-
int32_t quantizedDimension,
506-
int64_t storageTypeMin, int64_t storageTypeMax);
517+
static LogicalResult
518+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
519+
Type storageType, Type quantileType, Type expressedType,
520+
ArrayRef<double> quantiles, ArrayRef<double> scales,
521+
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
522+
int64_t storageTypeMin, int64_t storageTypeMax);
523+
524+
static bool classof(mlir::Type type);
525+
526+
/// Gets the quantileType
527+
Type getQuantileType() const;
528+
529+
/// Gets the quantileType bit width
530+
unsigned getQuantileTypeIntegralWidth() const;
507531

508532
/// Gets the quantile values
509533
ArrayRef<double> getQuantiles() const;

mlir/lib/Dialect/Quant/IR/QuantTypes.cpp

Lines changed: 75 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,11 @@ LogicalResult UniformQuantizedType::verifyInvariants(
339339
return success();
340340
}
341341

342+
bool UniformQuantizedType::classof(mlir::Type type) {
343+
return type.getTypeID() == mlir::TypeID::get<UniformQuantizedType>() ||
344+
type.getTypeID() == mlir::TypeID::get<QuantileQuantizedType>();
345+
}
346+
342347
double UniformQuantizedType::getScale() const { return getImpl()->scale; }
343348

344349
int64_t UniformQuantizedType::getZeroPoint() const {
@@ -408,6 +413,11 @@ LogicalResult UniformQuantizedPerAxisType::verifyInvariants(
408413
return success();
409414
}
410415

416+
bool UniformQuantizedPerAxisType::classof(mlir::Type type) {
417+
return type.getTypeID() == mlir::TypeID::get<UniformQuantizedPerAxisType>() ||
418+
type.getTypeID() == mlir::TypeID::get<QuantileQuantizedPerAxisType>();
419+
}
420+
411421
ArrayRef<double> UniformQuantizedPerAxisType::getScales() const {
412422
return getImpl()->getScales();
413423
}
@@ -421,36 +431,35 @@ int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
421431
}
422432

423433
QuantileQuantizedType
424-
QuantileQuantizedType::get(unsigned flags, Type storageType, Type expressedType,
425-
ArrayRef<double> quantiles, double scale,
426-
int64_t zeroPoint, int64_t storageTypeMin,
427-
int64_t storageTypeMax) {
428-
return Base::get(storageType.getContext(), flags, storageType, expressedType,
429-
quantiles, scale, zeroPoint, storageTypeMin, storageTypeMax);
434+
QuantileQuantizedType::get(unsigned flags, Type storageType, Type quantileType,
435+
Type expressedType, ArrayRef<double> quantiles,
436+
double scale, int64_t zeroPoint,
437+
int64_t storageTypeMin, int64_t storageTypeMax) {
438+
return Base::get(storageType.getContext(), flags, storageType, quantileType,
439+
expressedType, quantiles, scale, zeroPoint, storageTypeMin,
440+
storageTypeMax);
430441
}
431442

432443
QuantileQuantizedType QuantileQuantizedType::getChecked(
433444
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
434-
Type storageType, Type expressedType, ArrayRef<double> quantiles,
435-
double scale, int64_t zeroPoint, int64_t storageTypeMin,
436-
int64_t storageTypeMax) {
445+
Type storageType, Type quantileType, Type expressedType,
446+
ArrayRef<double> quantiles, double scale, int64_t zeroPoint,
447+
int64_t storageTypeMin, int64_t storageTypeMax) {
437448
return Base::getChecked(emitError, storageType.getContext(), flags,
438-
storageType, expressedType, quantiles, scale,
439-
zeroPoint, storageTypeMin, storageTypeMax);
449+
storageType, quantileType, expressedType, quantiles,
450+
scale, zeroPoint, storageTypeMin, storageTypeMax);
440451
}
441-
LogicalResult
442-
QuantileQuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
443-
unsigned flags, Type storageType,
444-
Type expressedType, ArrayRef<double> quantiles,
445-
double scale, int64_t zeroPoint,
446-
int64_t storageTypeMin, int64_t storageTypeMax) {
452+
LogicalResult QuantileQuantizedType::verifyInvariants(
453+
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
454+
Type storageType, Type quantileType, Type expressedType,
455+
ArrayRef<double> quantiles, double scale, int64_t zeroPoint,
456+
int64_t storageTypeMin, int64_t storageTypeMax) {
447457
if (failed(UniformQuantizedType::verifyInvariants(emitError, flags, storageType,
448458
expressedType, scale, zeroPoint,
449459
storageTypeMin, storageTypeMax))) {
450460
return failure();
451461
}
452462

453-
const auto quantileArraySize = quantiles.size();
454463
unsigned typeWidth{};
455464
if (storageType.isa<IntegerType>()) {
456465
typeWidth = llvm::dyn_cast<IntegerType>(storageType).getWidth();
@@ -463,10 +472,17 @@ QuantileQuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitE
463472
"types, Float8E4M3FNType and Float8E5M2Type ";
464473
}
465474

466-
const size_t expectedSize = 1 << typeWidth;
475+
const size_t storageTypeRange = storageTypeMax - storageTypeMin + 1;
476+
const size_t typeWidthSize = 1 << typeWidth;
477+
const size_t expectedSize =
478+
(storageTypeRange < typeWidthSize) ? storageTypeRange : typeWidthSize;
479+
480+
const auto quantileArraySize = quantiles.size();
467481
if (quantileArraySize != expectedSize) {
468482
return emitError() << "quantiles array size needs to be equal to "
469-
"2^(bit_size(storageType)), expected: "
483+
"2^(bit_size(storageType)), or (storageTypeMax - "
484+
"storageTypeMin + 1) when max and min differ from "
485+
"the type limits; expected: "
470486
<< expectedSize << ", found: " << quantileArraySize;
471487
}
472488

@@ -480,38 +496,50 @@ QuantileQuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitE
480496
return success();
481497
}
482498

499+
bool QuantileQuantizedType::classof(mlir::Type type) {
500+
return type.getTypeID() == mlir::TypeID::get<QuantileQuantizedType>();
501+
}
502+
503+
Type QuantileQuantizedType::getQuantileType() const {
504+
return getImpl()->quantileType;
505+
}
506+
507+
unsigned QuantileQuantizedType::getQuantileTypeIntegralWidth() const {
508+
return getImpl()->getQuantileType().getIntOrFloatBitWidth();
509+
}
510+
483511
ArrayRef<double> QuantileQuantizedType::getQuantiles() const {
484512
return getImpl()->getQuantiles();
485513
}
486514

487515
QuantileQuantizedPerAxisType QuantileQuantizedPerAxisType::get(
488-
unsigned flags, Type storageType, Type expressedType,
516+
unsigned flags, Type storageType, Type quantileType, Type expressedType,
489517
ArrayRef<double> quantiles, ArrayRef<double> scales,
490518
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
491519
int64_t storageTypeMin, int64_t storageTypeMax) {
492-
return Base::get(storageType.getContext(), flags, storageType, expressedType,
493-
quantiles, scales, zeroPoints, quantizedDimension,
494-
storageTypeMin, storageTypeMax);
520+
return Base::get(storageType.getContext(), flags, storageType, quantileType,
521+
expressedType, quantiles, scales, zeroPoints,
522+
quantizedDimension, storageTypeMin, storageTypeMax);
495523
}
496524

497525
QuantileQuantizedPerAxisType QuantileQuantizedPerAxisType::getChecked(
498526
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
499-
Type storageType, Type expressedType, ArrayRef<double> quantiles,
500-
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
501-
int32_t quantizedDimension, int64_t storageTypeMin,
502-
int64_t storageTypeMax) {
527+
Type storageType, Type quantileType, Type expressedType,
528+
ArrayRef<double> quantiles, ArrayRef<double> scales,
529+
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
530+
int64_t storageTypeMin, int64_t storageTypeMax) {
503531
return Base::getChecked(emitError, storageType.getContext(), flags,
504-
storageType, expressedType, quantiles, scales,
505-
zeroPoints, quantizedDimension, storageTypeMin,
506-
storageTypeMax);
532+
storageType, quantileType, expressedType, quantiles,
533+
scales, zeroPoints, quantizedDimension,
534+
storageTypeMin, storageTypeMax);
507535
}
508536

509537
LogicalResult QuantileQuantizedPerAxisType::verifyInvariants(
510538
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
511-
Type storageType, Type expressedType, ArrayRef<double> quantiles,
512-
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
513-
int32_t quantizedDimension, int64_t storageTypeMin,
514-
int64_t storageTypeMax) {
539+
Type storageType, Type quantileType, Type expressedType,
540+
ArrayRef<double> quantiles, ArrayRef<double> scales,
541+
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
542+
int64_t storageTypeMin, int64_t storageTypeMax) {
515543
if (failed(UniformQuantizedPerAxisType::verifyInvariants(
516544
emitError, flags, storageType, expressedType, scales, zeroPoints,
517545
quantizedDimension, storageTypeMin, storageTypeMax))) {
@@ -548,6 +576,18 @@ LogicalResult QuantileQuantizedPerAxisType::verifyInvariants(
548576
return success();
549577
}
550578

579+
bool QuantileQuantizedPerAxisType::classof(mlir::Type type) {
580+
return type.getTypeID() == mlir::TypeID::get<QuantileQuantizedPerAxisType>();
581+
}
582+
583+
Type QuantileQuantizedPerAxisType::getQuantileType() const {
584+
return getImpl()->quantileType;
585+
}
586+
587+
unsigned QuantileQuantizedPerAxisType::getQuantileTypeIntegralWidth() const {
588+
return getImpl()->getQuantileType().getIntOrFloatBitWidth();
589+
}
590+
551591
ArrayRef<double> QuantileQuantizedPerAxisType::getQuantiles() const {
552592
return getImpl()->getQuantiles();
553593
}

0 commit comments

Comments
 (0)