@@ -420,6 +420,138 @@ int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
420420 return getImpl ()->quantizedDimension ;
421421}
422422
423+ QuantileQuantizedType
424+ QuantileQuantizedType::get (unsigned flags, Type storageType, Type expressedType,
425+ ArrayRef<double > quantiles, double scale,
426+ int64_t zeroPoint, int64_t storageTypeMin,
427+ int64_t storageTypeMax) {
428+ return Base::get (storageType.getContext (), flags, storageType, expressedType,
429+ quantiles, scale, zeroPoint, storageTypeMin, storageTypeMax);
430+ }
431+
432+ QuantileQuantizedType QuantileQuantizedType::getChecked (
433+ function_ref<InFlightDiagnostic()> emitError, unsigned flags,
434+ Type storageType, Type expressedType, ArrayRef<double> quantiles,
435+ double scale, int64_t zeroPoint, int64_t storageTypeMin,
436+ int64_t storageTypeMax) {
437+ return Base::getChecked (emitError, storageType.getContext (), flags,
438+ storageType, expressedType, quantiles, scale,
439+ zeroPoint, storageTypeMin, storageTypeMax);
440+ }
441+ LogicalResult
442+ QuantileQuantizedType::verifyInvariants (function_ref<InFlightDiagnostic()> emitError,
443+ unsigned flags, Type storageType,
444+ Type expressedType, ArrayRef<double> quantiles,
445+ double scale, int64_t zeroPoint,
446+ int64_t storageTypeMin, int64_t storageTypeMax) {
447+ if (failed (UniformQuantizedType::verifyInvariants (emitError, flags, storageType,
448+ expressedType, scale, zeroPoint,
449+ storageTypeMin, storageTypeMax))) {
450+ return failure ();
451+ }
452+
453+ const auto quantileArraySize = quantiles.size ();
454+ unsigned typeWidth{};
455+ if (storageType.isa <IntegerType>()) {
456+ typeWidth = llvm::dyn_cast<IntegerType>(storageType).getWidth ();
457+ } else if (storageType.isa <Float8E5M2Type>() ||
458+ storageType.isa <Float8E4M3FNType>()) {
459+ // Both Float8E5M2Type and Float8E4M3FNType derive from FloatType.
460+ typeWidth = llvm::dyn_cast<FloatType>(storageType).getWidth ();
461+ } else {
462+ return emitError () << " illegal storage type, supported types are: integral "
463+ " types, Float8E4M3FNType and Float8E5M2Type " ;
464+ }
465+
466+ const size_t expectedSize = 1 << typeWidth;
467+ if (quantileArraySize != expectedSize) {
468+ return emitError () << " quantiles array size needs to be equal to "
469+ " 2^(bit_size(storageType)), expected: "
470+ << expectedSize << " , found: " << quantileArraySize;
471+ }
472+
473+ // Verify quantiles
474+ for (double quantile : quantiles) {
475+ if (std::isinf (quantile) || std::isnan (quantile)) {
476+ return emitError () << " illegal quantile value: " << quantile;
477+ }
478+ }
479+
480+ return success ();
481+ }
482+
483+ ArrayRef<double > QuantileQuantizedType::getQuantiles () const {
484+ return getImpl ()->getQuantiles ();
485+ }
486+
487+ QuantileQuantizedPerAxisType QuantileQuantizedPerAxisType::get (
488+ unsigned flags, Type storageType, Type expressedType,
489+ ArrayRef<double > quantiles, ArrayRef<double > scales,
490+ ArrayRef<int64_t > zeroPoints, int32_t quantizedDimension,
491+ int64_t storageTypeMin, int64_t storageTypeMax) {
492+ return Base::get (storageType.getContext (), flags, storageType, expressedType,
493+ quantiles, scales, zeroPoints, quantizedDimension,
494+ storageTypeMin, storageTypeMax);
495+ }
496+
497+ QuantileQuantizedPerAxisType QuantileQuantizedPerAxisType::getChecked (
498+ function_ref<InFlightDiagnostic()> emitError, unsigned flags,
499+ Type storageType, Type expressedType, ArrayRef<double> quantiles,
500+ ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
501+ int32_t quantizedDimension, int64_t storageTypeMin,
502+ int64_t storageTypeMax) {
503+ return Base::getChecked (emitError, storageType.getContext (), flags,
504+ storageType, expressedType, quantiles, scales,
505+ zeroPoints, quantizedDimension, storageTypeMin,
506+ storageTypeMax);
507+ }
508+
509+ LogicalResult QuantileQuantizedPerAxisType::verifyInvariants (
510+ function_ref<InFlightDiagnostic()> emitError, unsigned flags,
511+ Type storageType, Type expressedType, ArrayRef<double> quantiles,
512+ ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
513+ int32_t quantizedDimension, int64_t storageTypeMin,
514+ int64_t storageTypeMax) {
515+ if (failed (UniformQuantizedPerAxisType::verifyInvariants (
516+ emitError, flags, storageType, expressedType, scales, zeroPoints,
517+ quantizedDimension, storageTypeMin, storageTypeMax))) {
518+ return failure ();
519+ }
520+
521+ const auto quantileArraySize = quantiles.size ();
522+ unsigned typeWidth{};
523+ if (storageType.isa <IntegerType>()) {
524+ typeWidth = llvm::dyn_cast<IntegerType>(storageType).getWidth ();
525+ } else if (storageType.isa <Float8E5M2Type>() ||
526+ storageType.isa <Float8E4M3FNType>()) {
527+ // Both Float8E5M2Type and Float8E4M3FNType derive from FloatType.
528+ typeWidth = llvm::dyn_cast<FloatType>(storageType).getWidth ();
529+ } else {
530+ return emitError () << " illegal storage type, supported types are: integral "
531+ " types, Float8E4M3FNType and Float8E5M2Type " ;
532+ }
533+
534+ const size_t expectedSize = 1 << typeWidth;
535+ if (quantileArraySize != expectedSize) {
536+ return emitError () << " quantiles array size needs to be equal to "
537+ " 2^(bit_size(storageType)), expected: "
538+ << expectedSize << " , found: " << quantileArraySize;
539+ }
540+
541+ // Verify quantiles
542+ for (double quantile : quantiles) {
543+ if (std::isinf (quantile) || std::isnan (quantile)) {
544+ return emitError () << " illegal quantile value: " << quantile;
545+ }
546+ }
547+
548+ return success ();
549+ }
550+
551+ ArrayRef<double > QuantileQuantizedPerAxisType::getQuantiles () const {
552+ return getImpl ()->getQuantiles ();
553+ }
554+
423555CalibratedQuantizedType CalibratedQuantizedType::get (Type expressedType,
424556 double min, double max) {
425557 return Base::get (expressedType.getContext (), expressedType, min, max);
0 commit comments