@@ -305,6 +305,11 @@ LogicalResult UniformQuantizedType::verify(
305305 return success ();
306306}
307307
308+ bool UniformQuantizedType::classof (mlir::Type type) {
309+ return type.getTypeID () == mlir::TypeID::get<UniformQuantizedType>() ||
310+ type.getTypeID () == mlir::TypeID::get<QuantileQuantizedType>();
311+ }
312+
308313double UniformQuantizedType::getScale () const { return getImpl ()->scale ; }
309314
310315int64_t UniformQuantizedType::getZeroPoint () const {
@@ -366,6 +371,11 @@ LogicalResult UniformQuantizedPerAxisType::verify(
366371 return success ();
367372}
368373
374+ bool UniformQuantizedPerAxisType::classof (mlir::Type type) {
375+ return type.getTypeID () == mlir::TypeID::get<UniformQuantizedPerAxisType>() ||
376+ type.getTypeID () == mlir::TypeID::get<QuantileQuantizedPerAxisType>();
377+ }
378+
369379ArrayRef<double > UniformQuantizedPerAxisType::getScales () const {
370380 return getImpl ()->getScales ();
371381}
@@ -379,36 +389,35 @@ int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
379389}
380390
381391QuantileQuantizedType
382- QuantileQuantizedType::get (unsigned flags, Type storageType, Type expressedType,
383- ArrayRef<double > quantiles, double scale,
384- int64_t zeroPoint, int64_t storageTypeMin,
385- int64_t storageTypeMax) {
386- return Base::get (storageType.getContext (), flags, storageType, expressedType,
387- quantiles, scale, zeroPoint, storageTypeMin, storageTypeMax);
392+ QuantileQuantizedType::get (unsigned flags, Type storageType, Type quantileType,
393+ Type expressedType, ArrayRef<double > quantiles,
394+ double scale, int64_t zeroPoint,
395+ int64_t storageTypeMin, int64_t storageTypeMax) {
396+ return Base::get (storageType.getContext (), flags, storageType, quantileType,
397+ expressedType, quantiles, scale, zeroPoint, storageTypeMin,
398+ storageTypeMax);
388399}
389400
390401QuantileQuantizedType QuantileQuantizedType::getChecked (
391402 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
392- Type storageType, Type expressedType, ArrayRef<double> quantiles ,
393- double scale, int64_t zeroPoint , int64_t storageTypeMin ,
394- int64_t storageTypeMax) {
403+ Type storageType, Type quantileType, Type expressedType ,
404+ ArrayRef< double> quantiles, double scale , int64_t zeroPoint ,
405+ int64_t storageTypeMin, int64_t storageTypeMax) {
395406 return Base::getChecked (emitError, storageType.getContext (), flags,
396- storageType, expressedType, quantiles, scale ,
397- zeroPoint, storageTypeMin, storageTypeMax);
407+ storageType, quantileType, expressedType, quantiles ,
408+ scale, zeroPoint, storageTypeMin, storageTypeMax);
398409}
399- LogicalResult
400- QuantileQuantizedType::verify (function_ref<InFlightDiagnostic()> emitError,
401- unsigned flags, Type storageType,
402- Type expressedType, ArrayRef<double> quantiles,
403- double scale, int64_t zeroPoint,
404- int64_t storageTypeMin, int64_t storageTypeMax) {
410+ LogicalResult QuantileQuantizedType::verify (
411+ function_ref<InFlightDiagnostic()> emitError, unsigned flags,
412+ Type storageType, Type quantileType, Type expressedType,
413+ ArrayRef<double> quantiles, double scale, int64_t zeroPoint,
414+ int64_t storageTypeMin, int64_t storageTypeMax) {
405415 if (failed (UniformQuantizedType::verify (emitError, flags, storageType,
406416 expressedType, scale, zeroPoint,
407417 storageTypeMin, storageTypeMax))) {
408418 return failure ();
409419 }
410420
411- const auto quantileArraySize = quantiles.size ();
412421 unsigned typeWidth{};
413422 if (storageType.isa <IntegerType>()) {
414423 typeWidth = llvm::dyn_cast<IntegerType>(storageType).getWidth ();
@@ -421,10 +430,17 @@ QuantileQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
421430 " types, Float8E4M3FNType and Float8E5M2Type " ;
422431 }
423432
424- const size_t expectedSize = 1 << typeWidth;
433+ const size_t storageTypeRange = storageTypeMax - storageTypeMin + 1 ;
434+ const size_t typeWidthSize = 1 << typeWidth;
435+ const size_t expectedSize =
436+ (storageTypeRange < typeWidthSize) ? storageTypeRange : typeWidthSize;
437+
438+ const auto quantileArraySize = quantiles.size ();
425439 if (quantileArraySize != expectedSize) {
426440 return emitError () << " quantiles array size needs to be equal to "
427- " 2^(bit_size(storageType)), expected: "
441+ " 2^(bit_size(storageType)), or (storageTypeMax - "
442+ " storageTypeMin + 1) when max and min differ from "
443+ " the type limits; expected: "
428444 << expectedSize << " , found: " << quantileArraySize;
429445 }
430446
@@ -438,38 +454,50 @@ QuantileQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
438454 return success ();
439455}
440456
457+ bool QuantileQuantizedType::classof (mlir::Type type) {
458+ return type.getTypeID () == mlir::TypeID::get<QuantileQuantizedType>();
459+ }
460+
461+ Type QuantileQuantizedType::getQuantileType () const {
462+ return getImpl ()->quantileType ;
463+ }
464+
465+ unsigned QuantileQuantizedType::getQuantileTypeIntegralWidth () const {
466+ return getImpl ()->getQuantileType ().getIntOrFloatBitWidth ();
467+ }
468+
441469ArrayRef<double > QuantileQuantizedType::getQuantiles () const {
442470 return getImpl ()->getQuantiles ();
443471}
444472
445473QuantileQuantizedPerAxisType QuantileQuantizedPerAxisType::get (
446- unsigned flags, Type storageType, Type expressedType,
474+ unsigned flags, Type storageType, Type quantileType, Type expressedType,
447475 ArrayRef<double > quantiles, ArrayRef<double > scales,
448476 ArrayRef<int64_t > zeroPoints, int32_t quantizedDimension,
449477 int64_t storageTypeMin, int64_t storageTypeMax) {
450- return Base::get (storageType.getContext (), flags, storageType, expressedType ,
451- quantiles, scales, zeroPoints, quantizedDimension ,
452- storageTypeMin, storageTypeMax);
478+ return Base::get (storageType.getContext (), flags, storageType, quantileType ,
479+ expressedType, quantiles, scales, zeroPoints,
480+ quantizedDimension, storageTypeMin, storageTypeMax);
453481}
454482
455483QuantileQuantizedPerAxisType QuantileQuantizedPerAxisType::getChecked (
456484 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
457- Type storageType, Type expressedType, ArrayRef<double> quantiles ,
458- ArrayRef<double> scales , ArrayRef<int64_t> zeroPoints ,
459- int32_t quantizedDimension, int64_t storageTypeMin ,
460- int64_t storageTypeMax) {
485+ Type storageType, Type quantileType, Type expressedType ,
486+ ArrayRef<double> quantiles , ArrayRef<double> scales ,
487+ ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension ,
488+ int64_t storageTypeMin, int64_t storageTypeMax) {
461489 return Base::getChecked (emitError, storageType.getContext (), flags,
462- storageType, expressedType, quantiles, scales ,
463- zeroPoints, quantizedDimension, storageTypeMin ,
464- storageTypeMax);
490+ storageType, quantileType, expressedType, quantiles ,
491+ scales, zeroPoints, quantizedDimension ,
492+ storageTypeMin, storageTypeMax);
465493}
466494
467495LogicalResult QuantileQuantizedPerAxisType::verify (
468496 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
469- Type storageType, Type expressedType, ArrayRef<double> quantiles ,
470- ArrayRef<double> scales , ArrayRef<int64_t> zeroPoints ,
471- int32_t quantizedDimension, int64_t storageTypeMin ,
472- int64_t storageTypeMax) {
497+ Type storageType, Type quantileType, Type expressedType ,
498+ ArrayRef<double> quantiles , ArrayRef<double> scales ,
499+ ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension ,
500+ int64_t storageTypeMin, int64_t storageTypeMax) {
473501 if (failed (UniformQuantizedPerAxisType::verify (
474502 emitError, flags, storageType, expressedType, scales, zeroPoints,
475503 quantizedDimension, storageTypeMin, storageTypeMax))) {
@@ -506,6 +534,18 @@ LogicalResult QuantileQuantizedPerAxisType::verify(
506534 return success ();
507535}
508536
537+ bool QuantileQuantizedPerAxisType::classof (mlir::Type type) {
538+ return type.getTypeID () == mlir::TypeID::get<QuantileQuantizedPerAxisType>();
539+ }
540+
541+ Type QuantileQuantizedPerAxisType::getQuantileType () const {
542+ return getImpl ()->quantileType ;
543+ }
544+
545+ unsigned QuantileQuantizedPerAxisType::getQuantileTypeIntegralWidth () const {
546+ return getImpl ()->getQuantileType ().getIntOrFloatBitWidth ();
547+ }
548+
509549ArrayRef<double > QuantileQuantizedPerAxisType::getQuantiles () const {
510550 return getImpl ()->getQuantiles ();
511551}
0 commit comments