Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,35 @@ def UniformQuantizedPerAxisType: DialectType<(type
}];
}

// Bytecode record for the per-layer quantile quantized type.
//
// The record carries, in this order: the flags word, the storage and
// expressed types, the table of quantile values (a list of doubles), a single
// scale and zero point, and the minimum/maximum of the storage range.
// No custom `cBuilder` is specified for this record.
def QuantileQuantizedType: DialectType<(type
VarInt:$flags,
Type:$storageType,
Type:$expressedType,
Array<DoubleAPFloatList>:$quantiles,
DoubleAPFloat:$scale,
SignedVarInt:$zeroPoint,
SignedVarInt:$storageTypeMin,
SignedVarInt:$storageTypeMax
)>;

// Bytecode record for the per-axis quantile quantized type.
//
// The record carries, in this order: the flags word, the storage and
// expressed types, the quantized dimension, the storage range, the table of
// quantile values, and one list each of per-axis scales and zero points.
def QuantileQuantizedPerAxisType: DialectType<(type
VarInt:$flags,
Type:$storageType,
Type:$expressedType,
VarInt:$quantizedDimension,
SignedVarInt:$storageTypeMin,
SignedVarInt:$storageTypeMax,
Array<DoubleAPFloatList>:$quantiles,
Array<DoubleAPFloatList>:$scales,
Array<SignedVarIntList>:$zeroPoints
)> {
// Note: the argument order passed to `get` below differs from the order in
// which the fields are listed in the bytecode record above: `quantiles`,
// `scales` and `zeroPoints` come before `quantizedDimension` and the storage
// range in the builder call.
let cBuilder = [{
get<$_resultType>(context, flags, storageType, expressedType, quantiles, scales,
zeroPoints, quantizedDimension, storageTypeMin, storageTypeMax)
}];
}

/// This enum contains marker codes used to indicate which attribute is
/// currently being decoded, and how it should be decoded. The order of these
/// codes should generally be unchanged, as any changes will inevitably break
Expand All @@ -93,7 +122,9 @@ def QuantDialectTypes : DialectTypes<"Quant"> {
AnyQuantizedTypeWithExpressedType,
CalibratedQuantizedType,
UniformQuantizedType,
UniformQuantizedPerAxisType
UniformQuantizedPerAxisType,
QuantileQuantizedType,
QuantileQuantizedPerAxisType
];
}

Expand Down
10 changes: 10 additions & 0 deletions mlir/include/mlir/Dialect/Quant/QuantOpsBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,18 @@ def quant_UniformQuantizedType :
CPred<"::llvm::isa<UniformQuantizedType>($_self)">,
"UniformQuantizedType">;

// Type constraint that matches values whose type is a QuantileQuantizedType.
// The check is an `llvm::isa<QuantileQuantizedType>` test on the type under
// consideration (`$_self`).
def quant_QuantileQuantizedType :
DialectType<Quantization_Dialect,
CPred<"::llvm::isa<QuantileQuantizedType>($_self)">,
"QuantileQuantizedType">;

// Predicate for detecting a container or primitive of UniformQuantizedType.
// Built by instantiating quant_TypedPrimitiveOrContainer with
// quant_UniformQuantizedType.
def quant_UniformQuantizedValueType :
quant_TypedPrimitiveOrContainer<quant_UniformQuantizedType>;

// Predicate for detecting a container or primitive of QuantileQuantizedType.
// Built by instantiating quant_TypedPrimitiveOrContainer with
// quant_QuantileQuantizedType.
def quant_QuantileQuantizedValueType :
quant_TypedPrimitiveOrContainer<quant_QuantileQuantizedType>;

#endif // DIALECT_QUANT_QUANT_OPS_BASE_
124 changes: 124 additions & 0 deletions mlir/include/mlir/Dialect/Quant/QuantTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ struct QuantizedTypeStorage;
struct AnyQuantizedTypeStorage;
struct UniformQuantizedTypeStorage;
struct UniformQuantizedPerAxisTypeStorage;
struct QuantileQuantizedTypeStorage;
struct QuantileQuantizedPerAxisTypeStorage;
struct CalibratedQuantizedTypeStorage;

} // namespace detail
Expand Down Expand Up @@ -390,6 +392,128 @@ class UniformQuantizedPerAxisType
}
};

/// QuantileQuantizedType derives from UniformQuantizedType and adds to it a
/// look up table array of quantile values.
///
/// Syntax synopsis:
/// Per-layer, all parameters expressed:
/// !quant<quantile[StorageType:ExpressedType]{Quantiles}:{Scale:ZeroPoint}>
/// Per-layer, optional parameters omitted:
/// !quant<quantile[StorageType]{Quantiles}:{Scale}>
///
/// StorageType: 'i'|'u' NumBits
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
/// Quantiles: Quantile+
/// Quantile: A legal double value
/// Scale: A legal double value
/// ZeroPoint: An integer value
class QuantileQuantizedType
    : public Type::TypeBase<QuantileQuantizedType, UniformQuantizedType,
                            detail::QuantileQuantizedTypeStorage> {
public:
  using Base::Base;
  using Base::getChecked;

  static constexpr StringLiteral name = "quant.quantile";

  /// Creates an instance from a fully specified parameter set. The
  /// parameters are not checked.
  static QuantileQuantizedType get(unsigned flags, Type storageType,
                                   Type expressedType,
                                   ArrayRef<double> quantiles, double scale,
                                   int64_t zeroPoint, int64_t storageTypeMin,
                                   int64_t storageTypeMax);

  /// Counterpart of `get` that additionally takes an `emitError` callback
  /// used to report diagnostics.
  static QuantileQuantizedType
  getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
             Type storageType, Type expressedType, ArrayRef<double> quantiles,
             double scale, int64_t zeroPoint, int64_t storageTypeMin,
             int64_t storageTypeMax);

  /// Checks the construction invariants for the given parameter set and
  /// reports problems through `emitError`.
  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
                              unsigned flags, Type storageType,
                              Type expressedType, ArrayRef<double> quantiles,
                              double scale, int64_t zeroPoint,
                              int64_t storageTypeMin, int64_t storageTypeMax);

  /// Returns the quantile look-up table of this type.
  ArrayRef<double> getQuantiles() const;

  /// Returns true when this type is treated as fixed point.
  ///
  /// Fixed point values are real numbers divided by a scale. Only signed
  /// storage types are currently treated as fixed point, and a fixed point
  /// value is obtained from an affine value by subtracting the zero point;
  /// hence the type qualifies only when it is signed and its zero point is 0.
  /// In the future this may become explicit rather than implied by the type
  /// and zero point.
  bool isFixedPoint() const {
    if (!isSigned())
      return false;
    return getZeroPoint() == 0;
  }
};

/// Represents per-axis QuantileQuantizedType (also known as per-channel
/// quantization).
///
/// Syntax synopsis:
/// Per-axis, all parameters expressed:
/// !quant<quantile[StorageType:ExpressedType:QuantizedDim]{Quantiles}:{QuantParams}>
/// Per-axis, optional parameters omitted:
/// !quant<quantile[StorageType]{Quantiles}:{Scale}>
///
/// StorageType: 'i'|'u' NumBits
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
/// QuantizedDim: An integer value
/// Quantiles: Quantile+
/// Quantile: A legal double value
/// QuantParams: (Scale ':' ZeroPoint)+
/// Scale: A legal double value
/// ZeroPoint: An integer value
class QuantileQuantizedPerAxisType
    : public Type::TypeBase<QuantileQuantizedPerAxisType,
                            UniformQuantizedPerAxisType,
                            detail::QuantileQuantizedPerAxisTypeStorage> {
public:
  using Base::Base;
  using Base::getChecked;

  static constexpr StringLiteral name = "quant.quantile_per_axis";

  /// Gets an instance of the type with all parameters specified but not
  /// checked.
  static QuantileQuantizedPerAxisType
  get(unsigned flags, Type storageType, Type expressedType,
      ArrayRef<double> quantiles, ArrayRef<double> scales,
      ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
      int64_t storageTypeMin, int64_t storageTypeMax);

  /// Gets an instance of the type with all specified parameters checked.
  /// Returns a nullptr convertible type on failure.
  static QuantileQuantizedPerAxisType
  getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
             Type storageType, Type expressedType, ArrayRef<double> quantiles,
             ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
             int32_t quantizedDimension, int64_t storageTypeMin,
             int64_t storageTypeMax);

  /// Verifies construction invariants and issues errors/warnings through
  /// `emitError`.
  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
                              unsigned flags, Type storageType,
                              Type expressedType, ArrayRef<double> quantiles,
                              ArrayRef<double> scales,
                              ArrayRef<int64_t> zeroPoints,
                              int32_t quantizedDimension,
                              int64_t storageTypeMin, int64_t storageTypeMax);

  /// Gets the quantile values.
  ArrayRef<double> getQuantiles() const;

  /// Fixed point values are real numbers divided by a scale.
  /// Currently, only signed storage types are treated as fixed point.
  /// A fixed point value can be obtained from an affine value by subtracting
  /// the zeroPoint.
  /// In the future, this may be explicit versus implied by type and zeroPoint.
  ///
  /// For the per-axis case this means the storage type is signed and *every*
  /// per-axis zero point is 0. The previous formulation,
  /// `!llvm::is_contained(getZeroPoints(), 0)`, had the test inverted: it
  /// answered true exactly when no axis had a zero point of 0.
  /// NOTE(review): this declaration hides the `isFixedPoint` of the base
  /// class, if it has one; confirm whether the base should get the same
  /// correction.
  bool isFixedPoint() const {
    return isSigned() && llvm::all_of(getZeroPoints(), [](int64_t zeroPoint) {
             return zeroPoint == 0;
           });
  }
};

/// A quantized type that infers its range from given min/max values.
///
/// Typical syntax:
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/Quant/IR/QuantOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ using namespace mlir::quant::detail;

void QuantizationDialect::initialize() {
addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
UniformQuantizedPerAxisType>();
UniformQuantizedPerAxisType, QuantileQuantizedType,
QuantileQuantizedPerAxisType>();
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Quant/QuantOps.cpp.inc"
Expand Down
132 changes: 132 additions & 0 deletions mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,138 @@ int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
return getImpl()->quantizedDimension;
}

/// Builds a per-layer quantile quantized type from the given parameters,
/// uniqued in the context that owns `storageType`.
QuantileQuantizedType
QuantileQuantizedType::get(unsigned flags, Type storageType, Type expressedType,
                           ArrayRef<double> quantiles, double scale,
                           int64_t zeroPoint, int64_t storageTypeMin,
                           int64_t storageTypeMax) {
  // The storage type determines which context the new type lives in.
  auto *ctx = storageType.getContext();
  return Base::get(ctx, flags, storageType, expressedType, quantiles, scale,
                   zeroPoint, storageTypeMin, storageTypeMax);
}

/// Same as `get`, but forwards to `Base::getChecked` together with the
/// `emitError` callback so that diagnostics can be reported to the caller.
QuantileQuantizedType QuantileQuantizedType::getChecked(
    function_ref<InFlightDiagnostic()> emitError, unsigned flags,
    Type storageType, Type expressedType, ArrayRef<double> quantiles,
    double scale, int64_t zeroPoint, int64_t storageTypeMin,
    int64_t storageTypeMax) {
  // The storage type determines which context the new type lives in.
  auto *ctx = storageType.getContext();
  return Base::getChecked(emitError, ctx, flags, storageType, expressedType,
                          quantiles, scale, zeroPoint, storageTypeMin,
                          storageTypeMax);
}
/// Verifies the construction invariants of a per-layer quantile quantized
/// type.
///
/// The uniform (scale / zero point / storage range) parameters are checked by
/// `UniformQuantizedType::verify` first. On top of that:
///   * the storage type must be an integer type, Float8E5M2Type or
///     Float8E4M3FNType;
///   * `quantiles` must contain exactly 2^(bit width of storageType) entries;
///   * every quantile must be a finite value (neither infinity nor NaN).
///
/// Returns `failure()` after emitting a diagnostic through `emitError` when
/// any of the checks does not hold, `success()` otherwise.
LogicalResult
QuantileQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
                              unsigned flags, Type storageType,
                              Type expressedType, ArrayRef<double> quantiles,
                              double scale, int64_t zeroPoint,
                              int64_t storageTypeMin, int64_t storageTypeMax) {
  if (failed(UniformQuantizedType::verify(emitError, flags, storageType,
                                          expressedType, scale, zeroPoint,
                                          storageTypeMin, storageTypeMax)))
    return failure();

  unsigned typeWidth = 0;
  if (auto intType = llvm::dyn_cast<IntegerType>(storageType)) {
    typeWidth = intType.getWidth();
  } else if (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(storageType)) {
    // Both Float8E5M2Type and Float8E4M3FNType derive from FloatType.
    typeWidth = llvm::cast<FloatType>(storageType).getWidth();
  } else {
    return emitError() << "illegal storage type, supported types are: integral "
                          "types, Float8E4M3FNType and Float8E5M2Type ";
  }

  // The table size is computed in 64 bits: a plain `1 << typeWidth` is an
  // `int` shift and is already wrong/undefined for 31- and 32-bit storage.
  // Widths that do not even fit a 64-bit shift are rejected outright.
  // NOTE(review): the base verifier presumably rejects such widths earlier;
  // this guard only keeps the shift below well defined.
  constexpr unsigned kMaxShift = 64;
  if (typeWidth >= kMaxShift)
    return emitError() << "storage type width " << typeWidth
                       << " is too large for a quantile table";

  const uint64_t expectedSize = uint64_t{1} << typeWidth;
  const uint64_t quantileArraySize = quantiles.size();
  if (quantileArraySize != expectedSize) {
    return emitError() << "quantiles array size needs to be equal to "
                          "2^(bit_size(storageType)), expected: "
                       << expectedSize << ", found: " << quantileArraySize;
  }

  // Verify quantiles: only finite values are legal.
  for (double quantile : quantiles) {
    if (!std::isfinite(quantile))
      return emitError() << "illegal quantile value: " << quantile;
  }

  return success();
}

/// Returns the quantile values held by this type's storage.
ArrayRef<double> QuantileQuantizedType::getQuantiles() const {
  auto *impl = getImpl();
  return impl->getQuantiles();
}

/// Builds a per-axis quantile quantized type from the given parameters,
/// uniqued in the context that owns `storageType`.
QuantileQuantizedPerAxisType QuantileQuantizedPerAxisType::get(
    unsigned flags, Type storageType, Type expressedType,
    ArrayRef<double> quantiles, ArrayRef<double> scales,
    ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
    int64_t storageTypeMin, int64_t storageTypeMax) {
  // The storage type determines which context the new type lives in.
  auto *ctx = storageType.getContext();
  return Base::get(ctx, flags, storageType, expressedType, quantiles, scales,
                   zeroPoints, quantizedDimension, storageTypeMin,
                   storageTypeMax);
}

/// Same as `get`, but forwards to `Base::getChecked` together with the
/// `emitError` callback so that diagnostics can be reported to the caller.
QuantileQuantizedPerAxisType QuantileQuantizedPerAxisType::getChecked(
    function_ref<InFlightDiagnostic()> emitError, unsigned flags,
    Type storageType, Type expressedType, ArrayRef<double> quantiles,
    ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
    int32_t quantizedDimension, int64_t storageTypeMin,
    int64_t storageTypeMax) {
  // The storage type determines which context the new type lives in.
  auto *ctx = storageType.getContext();
  return Base::getChecked(emitError, ctx, flags, storageType, expressedType,
                          quantiles, scales, zeroPoints, quantizedDimension,
                          storageTypeMin, storageTypeMax);
}

/// Verifies the construction invariants of a per-axis quantile quantized
/// type.
///
/// The per-axis uniform parameters (scales / zero points / quantized
/// dimension / storage range) are checked by
/// `UniformQuantizedPerAxisType::verify` first. On top of that:
///   * the storage type must be an integer type, Float8E5M2Type or
///     Float8E4M3FNType;
///   * `quantiles` must contain exactly 2^(bit width of storageType) entries;
///   * every quantile must be a finite value (neither infinity nor NaN).
///
/// Returns `failure()` after emitting a diagnostic through `emitError` when
/// any of the checks does not hold, `success()` otherwise.
LogicalResult QuantileQuantizedPerAxisType::verify(
    function_ref<InFlightDiagnostic()> emitError, unsigned flags,
    Type storageType, Type expressedType, ArrayRef<double> quantiles,
    ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
    int32_t quantizedDimension, int64_t storageTypeMin,
    int64_t storageTypeMax) {
  if (failed(UniformQuantizedPerAxisType::verify(
          emitError, flags, storageType, expressedType, scales, zeroPoints,
          quantizedDimension, storageTypeMin, storageTypeMax)))
    return failure();

  unsigned typeWidth = 0;
  if (auto intType = llvm::dyn_cast<IntegerType>(storageType)) {
    typeWidth = intType.getWidth();
  } else if (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(storageType)) {
    // Both Float8E5M2Type and Float8E4M3FNType derive from FloatType.
    typeWidth = llvm::cast<FloatType>(storageType).getWidth();
  } else {
    return emitError() << "illegal storage type, supported types are: integral "
                          "types, Float8E4M3FNType and Float8E5M2Type ";
  }

  // The table size is computed in 64 bits: a plain `1 << typeWidth` is an
  // `int` shift and is already wrong/undefined for 31- and 32-bit storage.
  // Widths that do not even fit a 64-bit shift are rejected outright.
  // NOTE(review): the base verifier presumably rejects such widths earlier;
  // this guard only keeps the shift below well defined.
  constexpr unsigned kMaxShift = 64;
  if (typeWidth >= kMaxShift)
    return emitError() << "storage type width " << typeWidth
                       << " is too large for a quantile table";

  const uint64_t expectedSize = uint64_t{1} << typeWidth;
  const uint64_t quantileArraySize = quantiles.size();
  if (quantileArraySize != expectedSize) {
    return emitError() << "quantiles array size needs to be equal to "
                          "2^(bit_size(storageType)), expected: "
                       << expectedSize << ", found: " << quantileArraySize;
  }

  // Verify quantiles: only finite values are legal.
  for (double quantile : quantiles) {
    if (!std::isfinite(quantile))
      return emitError() << "illegal quantile value: " << quantile;
  }

  return success();
}

/// Returns the quantile values held by this type's storage.
ArrayRef<double> QuantileQuantizedPerAxisType::getQuantiles() const {
  auto *impl = getImpl();
  return impl->getQuantiles();
}

CalibratedQuantizedType CalibratedQuantizedType::get(Type expressedType,
double min, double max) {
return Base::get(expressedType.getContext(), expressedType, min, max);
Expand Down
Loading