@@ -378,6 +378,138 @@ int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
378378 return getImpl ()->quantizedDimension ;
379379}
380380
381+ QuantileQuantizedType
382+ QuantileQuantizedType::get (unsigned flags, Type storageType, Type expressedType,
383+ ArrayRef<double > quantiles, double scale,
384+ int64_t zeroPoint, int64_t storageTypeMin,
385+ int64_t storageTypeMax) {
386+ return Base::get (storageType.getContext (), flags, storageType, expressedType,
387+ quantiles, scale, zeroPoint, storageTypeMin, storageTypeMax);
388+ }
389+
390+ QuantileQuantizedType QuantileQuantizedType::getChecked (
391+ function_ref<InFlightDiagnostic()> emitError, unsigned flags,
392+ Type storageType, Type expressedType, ArrayRef<double> quantiles,
393+ double scale, int64_t zeroPoint, int64_t storageTypeMin,
394+ int64_t storageTypeMax) {
395+ return Base::getChecked (emitError, storageType.getContext (), flags,
396+ storageType, expressedType, quantiles, scale,
397+ zeroPoint, storageTypeMin, storageTypeMax);
398+ }
399+ LogicalResult
400+ QuantileQuantizedType::verify (function_ref<InFlightDiagnostic()> emitError,
401+ unsigned flags, Type storageType,
402+ Type expressedType, ArrayRef<double> quantiles,
403+ double scale, int64_t zeroPoint,
404+ int64_t storageTypeMin, int64_t storageTypeMax) {
405+ if (failed (UniformQuantizedType::verify (emitError, flags, storageType,
406+ expressedType, scale, zeroPoint,
407+ storageTypeMin, storageTypeMax))) {
408+ return failure ();
409+ }
410+
411+ const auto quantileArraySize = quantiles.size ();
412+ unsigned typeWidth{};
413+ if (storageType.isa <IntegerType>()) {
414+ typeWidth = llvm::dyn_cast<IntegerType>(storageType).getWidth ();
415+ } else if (storageType.isa <Float8E5M2Type>() ||
416+ storageType.isa <Float8E4M3FNType>()) {
417+ // Both Float8E5M2Type and Float8E4M3FNType derive from FloatType.
418+ typeWidth = llvm::dyn_cast<FloatType>(storageType).getWidth ();
419+ } else {
420+ return emitError () << " illegal storage type, supported types are: integral "
421+ " types, Float8E4M3FNType and Float8E5M2Type " ;
422+ }
423+
424+ const size_t expectedSize = 1 << typeWidth;
425+ if (quantileArraySize != expectedSize) {
426+ return emitError () << " quantiles array size needs to be equal to "
427+ " 2^(bit_size(storageType)), expected: "
428+ << expectedSize << " , found: " << quantileArraySize;
429+ }
430+
431+ // Verify quantiles
432+ for (double quantile : quantiles) {
433+ if (std::isinf (quantile) || std::isnan (quantile)) {
434+ return emitError () << " illegal quantile value: " << quantile;
435+ }
436+ }
437+
438+ return success ();
439+ }
440+
441+ ArrayRef<double > QuantileQuantizedType::getQuantiles () const {
442+ return getImpl ()->getQuantiles ();
443+ }
444+
445+ QuantileQuantizedPerAxisType QuantileQuantizedPerAxisType::get (
446+ unsigned flags, Type storageType, Type expressedType,
447+ ArrayRef<double > quantiles, ArrayRef<double > scales,
448+ ArrayRef<int64_t > zeroPoints, int32_t quantizedDimension,
449+ int64_t storageTypeMin, int64_t storageTypeMax) {
450+ return Base::get (storageType.getContext (), flags, storageType, expressedType,
451+ quantiles, scales, zeroPoints, quantizedDimension,
452+ storageTypeMin, storageTypeMax);
453+ }
454+
455+ QuantileQuantizedPerAxisType QuantileQuantizedPerAxisType::getChecked (
456+ function_ref<InFlightDiagnostic()> emitError, unsigned flags,
457+ Type storageType, Type expressedType, ArrayRef<double> quantiles,
458+ ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
459+ int32_t quantizedDimension, int64_t storageTypeMin,
460+ int64_t storageTypeMax) {
461+ return Base::getChecked (emitError, storageType.getContext (), flags,
462+ storageType, expressedType, quantiles, scales,
463+ zeroPoints, quantizedDimension, storageTypeMin,
464+ storageTypeMax);
465+ }
466+
467+ LogicalResult QuantileQuantizedPerAxisType::verify (
468+ function_ref<InFlightDiagnostic()> emitError, unsigned flags,
469+ Type storageType, Type expressedType, ArrayRef<double> quantiles,
470+ ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
471+ int32_t quantizedDimension, int64_t storageTypeMin,
472+ int64_t storageTypeMax) {
473+ if (failed (UniformQuantizedPerAxisType::verify (
474+ emitError, flags, storageType, expressedType, scales, zeroPoints,
475+ quantizedDimension, storageTypeMin, storageTypeMax))) {
476+ return failure ();
477+ }
478+
479+ const auto quantileArraySize = quantiles.size ();
480+ unsigned typeWidth{};
481+ if (storageType.isa <IntegerType>()) {
482+ typeWidth = llvm::dyn_cast<IntegerType>(storageType).getWidth ();
483+ } else if (storageType.isa <Float8E5M2Type>() ||
484+ storageType.isa <Float8E4M3FNType>()) {
485+ // Both Float8E5M2Type and Float8E4M3FNType derive from FloatType.
486+ typeWidth = llvm::dyn_cast<FloatType>(storageType).getWidth ();
487+ } else {
488+ return emitError () << " illegal storage type, supported types are: integral "
489+ " types, Float8E4M3FNType and Float8E5M2Type " ;
490+ }
491+
492+ const size_t expectedSize = 1 << typeWidth;
493+ if (quantileArraySize != expectedSize) {
494+ return emitError () << " quantiles array size needs to be equal to "
495+ " 2^(bit_size(storageType)), expected: "
496+ << expectedSize << " , found: " << quantileArraySize;
497+ }
498+
499+ // Verify quantiles
500+ for (double quantile : quantiles) {
501+ if (std::isinf (quantile) || std::isnan (quantile)) {
502+ return emitError () << " illegal quantile value: " << quantile;
503+ }
504+ }
505+
506+ return success ();
507+ }
508+
509+ ArrayRef<double > QuantileQuantizedPerAxisType::getQuantiles () const {
510+ return getImpl ()->getQuantiles ();
511+ }
512+
381513CalibratedQuantizedType CalibratedQuantizedType::get (Type expressedType,
382514 double min, double max) {
383515 return Base::get (expressedType.getContext (), expressedType, min, max);
0 commit comments