@@ -339,6 +339,11 @@ LogicalResult UniformQuantizedType::verifyInvariants(
339339 return success ();
340340}
341341
342+ bool UniformQuantizedType::classof (mlir::Type type) {
343+ return type.getTypeID () == mlir::TypeID::get<UniformQuantizedType>() ||
344+ type.getTypeID () == mlir::TypeID::get<QuantileQuantizedType>();
345+ }
346+
342347double UniformQuantizedType::getScale () const { return getImpl ()->scale ; }
343348
344349int64_t UniformQuantizedType::getZeroPoint () const {
@@ -408,6 +413,11 @@ LogicalResult UniformQuantizedPerAxisType::verifyInvariants(
408413 return success ();
409414}
410415
416+ bool UniformQuantizedPerAxisType::classof (mlir::Type type) {
417+ return type.getTypeID () == mlir::TypeID::get<UniformQuantizedPerAxisType>() ||
418+ type.getTypeID () == mlir::TypeID::get<QuantileQuantizedPerAxisType>();
419+ }
420+
411421ArrayRef<double > UniformQuantizedPerAxisType::getScales () const {
412422 return getImpl ()->getScales ();
413423}
@@ -421,36 +431,35 @@ int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
421431}
422432
423433QuantileQuantizedType
424- QuantileQuantizedType::get (unsigned flags, Type storageType, Type expressedType,
425- ArrayRef<double > quantiles, double scale,
426- int64_t zeroPoint, int64_t storageTypeMin,
427- int64_t storageTypeMax) {
428- return Base::get (storageType.getContext (), flags, storageType, expressedType,
429- quantiles, scale, zeroPoint, storageTypeMin, storageTypeMax);
434+ QuantileQuantizedType::get (unsigned flags, Type storageType, Type quantileType,
435+ Type expressedType, ArrayRef<double > quantiles,
436+ double scale, int64_t zeroPoint,
437+ int64_t storageTypeMin, int64_t storageTypeMax) {
438+ return Base::get (storageType.getContext (), flags, storageType, quantileType,
439+ expressedType, quantiles, scale, zeroPoint, storageTypeMin,
440+ storageTypeMax);
430441}
431442
432443QuantileQuantizedType QuantileQuantizedType::getChecked (
433444 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
434- Type storageType, Type expressedType, ArrayRef<double> quantiles ,
435- double scale, int64_t zeroPoint , int64_t storageTypeMin ,
436- int64_t storageTypeMax) {
445+ Type storageType, Type quantileType, Type expressedType ,
446+ ArrayRef< double> quantiles, double scale , int64_t zeroPoint ,
447+ int64_t storageTypeMin, int64_t storageTypeMax) {
437448 return Base::getChecked (emitError, storageType.getContext (), flags,
438- storageType, expressedType, quantiles, scale ,
439- zeroPoint, storageTypeMin, storageTypeMax);
449+ storageType, quantileType, expressedType, quantiles ,
450+ scale, zeroPoint, storageTypeMin, storageTypeMax);
440451}
441- LogicalResult
442- QuantileQuantizedType::verifyInvariants (function_ref<InFlightDiagnostic()> emitError,
443- unsigned flags, Type storageType,
444- Type expressedType, ArrayRef<double> quantiles,
445- double scale, int64_t zeroPoint,
446- int64_t storageTypeMin, int64_t storageTypeMax) {
452+ LogicalResult QuantileQuantizedType::verifyInvariants (
453+ function_ref<InFlightDiagnostic()> emitError, unsigned flags,
454+ Type storageType, Type quantileType, Type expressedType,
455+ ArrayRef<double> quantiles, double scale, int64_t zeroPoint,
456+ int64_t storageTypeMin, int64_t storageTypeMax) {
447457 if (failed (UniformQuantizedType::verifyInvariants (emitError, flags, storageType,
448458 expressedType, scale, zeroPoint,
449459 storageTypeMin, storageTypeMax))) {
450460 return failure ();
451461 }
452462
453- const auto quantileArraySize = quantiles.size ();
454463 unsigned typeWidth{};
455464 if (storageType.isa <IntegerType>()) {
456465 typeWidth = llvm::dyn_cast<IntegerType>(storageType).getWidth ();
@@ -463,10 +472,17 @@ QuantileQuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitE
463472 " types, Float8E4M3FNType and Float8E5M2Type " ;
464473 }
465474
466- const size_t expectedSize = 1 << typeWidth;
475+ const size_t storageTypeRange = storageTypeMax - storageTypeMin + 1 ;
476+ const size_t typeWidthSize = 1 << typeWidth;
477+ const size_t expectedSize =
478+ (storageTypeRange < typeWidthSize) ? storageTypeRange : typeWidthSize;
479+
480+ const auto quantileArraySize = quantiles.size ();
467481 if (quantileArraySize != expectedSize) {
468482 return emitError () << " quantiles array size needs to be equal to "
469- " 2^(bit_size(storageType)), expected: "
483+ " 2^(bit_size(storageType)), or (storageTypeMax - "
484+ " storageTypeMin + 1) when max and min differ from "
485+ " the type limits; expected: "
470486 << expectedSize << " , found: " << quantileArraySize;
471487 }
472488
@@ -480,38 +496,50 @@ QuantileQuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitE
480496 return success ();
481497}
482498
499+ bool QuantileQuantizedType::classof (mlir::Type type) {
500+ return type.getTypeID () == mlir::TypeID::get<QuantileQuantizedType>();
501+ }
502+
503+ Type QuantileQuantizedType::getQuantileType () const {
504+ return getImpl ()->quantileType ;
505+ }
506+
507+ unsigned QuantileQuantizedType::getQuantileTypeIntegralWidth () const {
508+ return getImpl ()->getQuantileType ().getIntOrFloatBitWidth ();
509+ }
510+
483511ArrayRef<double > QuantileQuantizedType::getQuantiles () const {
484512 return getImpl ()->getQuantiles ();
485513}
486514
487515QuantileQuantizedPerAxisType QuantileQuantizedPerAxisType::get (
488- unsigned flags, Type storageType, Type expressedType,
516+ unsigned flags, Type storageType, Type quantileType, Type expressedType,
489517 ArrayRef<double > quantiles, ArrayRef<double > scales,
490518 ArrayRef<int64_t > zeroPoints, int32_t quantizedDimension,
491519 int64_t storageTypeMin, int64_t storageTypeMax) {
492- return Base::get (storageType.getContext (), flags, storageType, expressedType ,
493- quantiles, scales, zeroPoints, quantizedDimension ,
494- storageTypeMin, storageTypeMax);
520+ return Base::get (storageType.getContext (), flags, storageType, quantileType ,
521+ expressedType, quantiles, scales, zeroPoints,
522+ quantizedDimension, storageTypeMin, storageTypeMax);
495523}
496524
497525QuantileQuantizedPerAxisType QuantileQuantizedPerAxisType::getChecked (
498526 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
499- Type storageType, Type expressedType, ArrayRef<double> quantiles ,
500- ArrayRef<double> scales , ArrayRef<int64_t> zeroPoints ,
501- int32_t quantizedDimension, int64_t storageTypeMin ,
502- int64_t storageTypeMax) {
527+ Type storageType, Type quantileType, Type expressedType ,
528+ ArrayRef<double> quantiles , ArrayRef<double> scales ,
529+ ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension ,
530+ int64_t storageTypeMin, int64_t storageTypeMax) {
503531 return Base::getChecked (emitError, storageType.getContext (), flags,
504- storageType, expressedType, quantiles, scales ,
505- zeroPoints, quantizedDimension, storageTypeMin ,
506- storageTypeMax);
532+ storageType, quantileType, expressedType, quantiles ,
533+ scales, zeroPoints, quantizedDimension ,
534+ storageTypeMin, storageTypeMax);
507535}
508536
509537LogicalResult QuantileQuantizedPerAxisType::verifyInvariants (
510538 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
511- Type storageType, Type expressedType, ArrayRef<double> quantiles ,
512- ArrayRef<double> scales , ArrayRef<int64_t> zeroPoints ,
513- int32_t quantizedDimension, int64_t storageTypeMin ,
514- int64_t storageTypeMax) {
539+ Type storageType, Type quantileType, Type expressedType ,
540+ ArrayRef<double> quantiles , ArrayRef<double> scales ,
541+ ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension ,
542+ int64_t storageTypeMin, int64_t storageTypeMax) {
515543 if (failed (UniformQuantizedPerAxisType::verifyInvariants (
516544 emitError, flags, storageType, expressedType, scales, zeroPoints,
517545 quantizedDimension, storageTypeMin, storageTypeMax))) {
@@ -548,6 +576,18 @@ LogicalResult QuantileQuantizedPerAxisType::verifyInvariants(
548576 return success ();
549577}
550578
579+ bool QuantileQuantizedPerAxisType::classof (mlir::Type type) {
580+ return type.getTypeID () == mlir::TypeID::get<QuantileQuantizedPerAxisType>();
581+ }
582+
583+ Type QuantileQuantizedPerAxisType::getQuantileType () const {
584+ return getImpl ()->quantileType ;
585+ }
586+
587+ unsigned QuantileQuantizedPerAxisType::getQuantileTypeIntegralWidth () const {
588+ return getImpl ()->getQuantileType ().getIntOrFloatBitWidth ();
589+ }
590+
551591ArrayRef<double > QuantileQuantizedPerAxisType::getQuantiles () const {
552592 return getImpl ()->getQuantiles ();
553593}
0 commit comments