Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def UniformQuantizedPerAxisType: DialectType<(type
def QuantileQuantizedType: DialectType<(type
VarInt:$flags,
Type:$storageType,
Type:$quantileType,
Type:$expressedType,
Array<DoubleAPFloatList>:$quantiles,
DoubleAPFloat:$scale,
Expand All @@ -95,6 +96,7 @@ def QuantileQuantizedType: DialectType<(type
def QuantileQuantizedPerAxisType: DialectType<(type
VarInt:$flags,
Type:$storageType,
Type:$quantileType,
Type:$expressedType,
VarInt:$quantizedDimension,
SignedVarInt:$storageTypeMin,
Expand All @@ -105,7 +107,7 @@ def QuantileQuantizedPerAxisType: DialectType<(type
)> {
// Note: builder order differs from bytecode.
let cBuilder = [{
get<$_resultType>(context, flags, storageType, expressedType, quantiles, scales,
get<$_resultType>(context, flags, storageType, quantileType, expressedType, quantiles, scales,
zeroPoints, quantizedDimension, storageTypeMin, storageTypeMax)
}];
}
Expand Down
20 changes: 20 additions & 0 deletions mlir/include/mlir/Dialect/Quant/QuantOpsBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,38 @@ def quant_UniformQuantizedType :
CPred<"::llvm::isa<UniformQuantizedType>($_self)">,
"UniformQuantizedType">;

// An implementation of UniformQuantizedPerAxisType.
def quant_UniformQuantizedPerAxisType :
DialectType<Quantization_Dialect,
CPred<"::llvm::isa<::mlir::quant::UniformQuantizedPerAxisType>($_self)">,
"UniformQuantizedPerAxisType">;

// An implementation of QuantileQuantizedType.
def quant_QuantileQuantizedType :
DialectType<Quantization_Dialect,
CPred<"::llvm::isa<QuantileQuantizedType>($_self)">,
"QuantileQuantizedType">;

// An implementation of QuantileQuantizedPerAxisType.
def quant_QuantileQuantizedPerAxisType :
DialectType<Quantization_Dialect,
CPred<"::llvm::isa<QuantileQuantizedPerAxisType>($_self)">,
"QuantileQuantizedPerAxisType">;

// Predicate for detecting a container or primitive of UniformQuantizedType.
def quant_UniformQuantizedValueType :
quant_TypedPrimitiveOrContainer<quant_UniformQuantizedType>;

// Predicate for detecting a container or primitive of UniformQuantizedPerAxisType.
def quant_UniformQuantizedPerAxisValueType :
quant_TypedPrimitiveOrContainer<quant_UniformQuantizedPerAxisType>;

// Predicate for detecting a container or primitive of QuantileQuantizedType.
def quant_QuantileQuantizedValueType :
quant_TypedPrimitiveOrContainer<quant_QuantileQuantizedType>;

// Predicate for detecting a container or primitive of QuantileQuantizedPerAxisType.
def quant_QuantileQuantizedPerAxisValueType :
quant_TypedPrimitiveOrContainer<quant_QuantileQuantizedPerAxisType>;

#endif // DIALECT_QUANT_QUANT_OPS_BASE_
74 changes: 49 additions & 25 deletions mlir/include/mlir/Dialect/Quant/QuantTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,8 @@ class UniformQuantizedType
int64_t zeroPoint, int64_t storageTypeMin,
int64_t storageTypeMax);

static bool classof(mlir::Type type);

/// Gets the scale term. The scale designates the difference between the real
/// values corresponding to consecutive quantized values differing by 1.
double getScale() const;
Expand Down Expand Up @@ -359,6 +361,8 @@ class UniformQuantizedPerAxisType
int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax);

static bool classof(mlir::Type type);

/// Gets the quantization scales. The scales designate the difference between
/// the real values corresponding to consecutive quantized values differing
/// by 1. The ith scale corresponds to the ith slice in the
Expand Down Expand Up @@ -393,15 +397,17 @@ class UniformQuantizedPerAxisType
};

/// QuantileQuantizedType derives from UniformQuantizedType and adds to it a
/// look up table array of quantile values.
/// look up table array of quantile values. The type of the data in the look up table is determined by
/// the quantileType member: supported quantileType types are integer/unsigned/hf8/bf8/f16/bf16/f32/f64.
///
/// Syntax synopsis:
/// Per-layer, all parameters expressed:
/// !quant<quantile[StorageType:ExpressedType]{Quantiles}:{Scale:ZeroPoint}>
/// !quant<quantile[StorageType:QuantileType:ExpressedType]{Quantiles}:{Scale:ZeroPoint}>
/// Per-layer, optional parameters omitted:
/// !quant<quantile[StorageType]{Quantiles}:{Scale}>
/// !quant<quantile[StorageType:QuantileType]{Quantiles}:{Scale}>
///
/// StorageType: 'i'|'u' NumBits
/// QuantileType: 'i'|'u' NumBits, 'hf8', 'bf8', 'f16', 'bf16', 'f32', 'f64'
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
/// Quantiles: Quantile+
/// Quantile: A legal double value
Expand All @@ -419,23 +425,32 @@ class QuantileQuantizedType
/// Gets an instance of the type with all parameters specified but not
/// checked.
static QuantileQuantizedType get(unsigned flags, Type storageType,
Type expressedType,
Type quantileType, Type expressedType,
ArrayRef<double> quantiles, double scale,
int64_t zeroPoint, int64_t storageTypeMin,
int64_t storageTypeMax);

static QuantileQuantizedType
getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, ArrayRef<double> quantiles,
double scale, int64_t zeroPoint, int64_t storageTypeMin,
int64_t storageTypeMax);
Type storageType, Type quantileType, Type expressedType,
ArrayRef<double> quantiles, double scale, int64_t zeroPoint,
int64_t storageTypeMin, int64_t storageTypeMax);

/// Verifies construction invariants and issues errors/warnings.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, ArrayRef<double> quantiles,
double scale, int64_t zeroPoint,
int64_t storageTypeMin, int64_t storageTypeMax);
Type quantileType, Type expressedType,
ArrayRef<double> quantiles, double scale,
int64_t zeroPoint, int64_t storageTypeMin,
int64_t storageTypeMax);

static bool classof(mlir::Type type);

/// Gets the quantileType
Type getQuantileType() const;

/// Gets the quantileType bit width
unsigned getQuantileTypeIntegralWidth() const;

/// Gets the quantile values
ArrayRef<double> getQuantiles() const;
Expand All @@ -449,15 +464,17 @@ class QuantileQuantizedType
};

/// Represents per-axis QuantileQuantizedType (also known as per-channel
/// quantization).
/// quantization). The type of the data in the look up table is determined by the
/// quantileType member: supported quantileType types are integer/unsigned/hf8/bf8/f16/bf16/f32/f64.
///
/// Syntax synopsis:
/// Per-axis, all parameters expressed:
/// !quant<quantile[StorageType:ExpressedType:QuantizedDim]{Quantiles}:{QuantParams}>
/// !quant<quantile[StorageType:QuantileType:ExpressedType:QuantizedDim]{Quantiles}:{QuantParams}>
/// Per-axis, optional parameters omitted:
/// !quant<quantile[StorageType]{Quantiles}:{Scale}>
/// !quant<quantile[StorageType:QuantileType]{Quantiles}:{Scale}>
///
/// StorageType: 'i'|'u' NumBits
/// QuantileType: 'i'|'u' NumBits, 'hf8', 'bf8', 'f16', 'bf16', 'f32', 'f64'
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
/// QuantizedDim: An integer value
/// Quantiles: Quantile+
Expand All @@ -478,7 +495,7 @@ class QuantileQuantizedPerAxisType
/// Gets an instance of the type with all parameters specified but not
/// checked.
static QuantileQuantizedPerAxisType
get(unsigned flags, Type storageType, Type expressedType,
get(unsigned flags, Type storageType, Type quantileType, Type expressedType,
ArrayRef<double> quantiles, ArrayRef<double> scales,
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax);
Expand All @@ -487,19 +504,26 @@ class QuantileQuantizedPerAxisType
/// Returns a nullptr convertible type on failure.
static QuantileQuantizedPerAxisType
getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, ArrayRef<double> quantiles,
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
int32_t quantizedDimension, int64_t storageTypeMin,
int64_t storageTypeMax);
Type storageType, Type quantileType, Type expressedType,
ArrayRef<double> quantiles, ArrayRef<double> scales,
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax);

/// Verifies construction invariants and issues errors/warnings.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, ArrayRef<double> quantiles,
ArrayRef<double> scales,
ArrayRef<int64_t> zeroPoints,
int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax);
static LogicalResult
verify(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type quantileType, Type expressedType,
ArrayRef<double> quantiles, ArrayRef<double> scales,
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax);

static bool classof(mlir::Type type);

/// Gets the quantileType
Type getQuantileType() const;

/// Gets the quantileType bit width
unsigned getQuantileTypeIntegralWidth() const;

/// Gets the quantile values
ArrayRef<double> getQuantiles() const;
Expand Down
110 changes: 75 additions & 35 deletions mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,11 @@ LogicalResult UniformQuantizedType::verify(
return success();
}

bool UniformQuantizedType::classof(mlir::Type type) {
return type.getTypeID() == mlir::TypeID::get<UniformQuantizedType>() ||
type.getTypeID() == mlir::TypeID::get<QuantileQuantizedType>();
}

double UniformQuantizedType::getScale() const { return getImpl()->scale; }

int64_t UniformQuantizedType::getZeroPoint() const {
Expand Down Expand Up @@ -366,6 +371,11 @@ LogicalResult UniformQuantizedPerAxisType::verify(
return success();
}

bool UniformQuantizedPerAxisType::classof(mlir::Type type) {
return type.getTypeID() == mlir::TypeID::get<UniformQuantizedPerAxisType>() ||
type.getTypeID() == mlir::TypeID::get<QuantileQuantizedPerAxisType>();
}

ArrayRef<double> UniformQuantizedPerAxisType::getScales() const {
return getImpl()->getScales();
}
Expand All @@ -379,36 +389,35 @@ int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
}

QuantileQuantizedType
QuantileQuantizedType::get(unsigned flags, Type storageType, Type expressedType,
ArrayRef<double> quantiles, double scale,
int64_t zeroPoint, int64_t storageTypeMin,
int64_t storageTypeMax) {
return Base::get(storageType.getContext(), flags, storageType, expressedType,
quantiles, scale, zeroPoint, storageTypeMin, storageTypeMax);
QuantileQuantizedType::get(unsigned flags, Type storageType, Type quantileType,
Type expressedType, ArrayRef<double> quantiles,
double scale, int64_t zeroPoint,
int64_t storageTypeMin, int64_t storageTypeMax) {
return Base::get(storageType.getContext(), flags, storageType, quantileType,
expressedType, quantiles, scale, zeroPoint, storageTypeMin,
storageTypeMax);
}

QuantileQuantizedType QuantileQuantizedType::getChecked(
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, ArrayRef<double> quantiles,
double scale, int64_t zeroPoint, int64_t storageTypeMin,
int64_t storageTypeMax) {
Type storageType, Type quantileType, Type expressedType,
ArrayRef<double> quantiles, double scale, int64_t zeroPoint,
int64_t storageTypeMin, int64_t storageTypeMax) {
return Base::getChecked(emitError, storageType.getContext(), flags,
storageType, expressedType, quantiles, scale,
zeroPoint, storageTypeMin, storageTypeMax);
storageType, quantileType, expressedType, quantiles,
scale, zeroPoint, storageTypeMin, storageTypeMax);
}
LogicalResult
QuantileQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, ArrayRef<double> quantiles,
double scale, int64_t zeroPoint,
int64_t storageTypeMin, int64_t storageTypeMax) {
LogicalResult QuantileQuantizedType::verify(
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type quantileType, Type expressedType,
ArrayRef<double> quantiles, double scale, int64_t zeroPoint,
int64_t storageTypeMin, int64_t storageTypeMax) {
if (failed(UniformQuantizedType::verify(emitError, flags, storageType,
expressedType, scale, zeroPoint,
storageTypeMin, storageTypeMax))) {
return failure();
}

const auto quantileArraySize = quantiles.size();
unsigned typeWidth{};
if (storageType.isa<IntegerType>()) {
typeWidth = llvm::dyn_cast<IntegerType>(storageType).getWidth();
Expand All @@ -421,10 +430,17 @@ QuantileQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
"types, Float8E4M3FNType and Float8E5M2Type ";
}

const size_t expectedSize = 1 << typeWidth;
const size_t storageTypeRange = storageTypeMax - storageTypeMin + 1;
const size_t typeWidthSize = 1 << typeWidth;
const size_t expectedSize =
(storageTypeRange < typeWidthSize) ? storageTypeRange : typeWidthSize;

const auto quantileArraySize = quantiles.size();
if (quantileArraySize != expectedSize) {
return emitError() << "quantiles array size needs to be equal to "
"2^(bit_size(storageType)), expected: "
"2^(bit_size(storageType)), or (storageTypeMax - "
"storageTypeMin + 1) when max and min differ from "
"the type limits; expected: "
<< expectedSize << ", found: " << quantileArraySize;
}

Expand All @@ -438,38 +454,50 @@ QuantileQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}

bool QuantileQuantizedType::classof(mlir::Type type) {
return type.getTypeID() == mlir::TypeID::get<QuantileQuantizedType>();
}

Type QuantileQuantizedType::getQuantileType() const {
return getImpl()->quantileType;
}

unsigned QuantileQuantizedType::getQuantileTypeIntegralWidth() const {
return getImpl()->getQuantileType().getIntOrFloatBitWidth();
}

ArrayRef<double> QuantileQuantizedType::getQuantiles() const {
return getImpl()->getQuantiles();
}

QuantileQuantizedPerAxisType QuantileQuantizedPerAxisType::get(
unsigned flags, Type storageType, Type expressedType,
unsigned flags, Type storageType, Type quantileType, Type expressedType,
ArrayRef<double> quantiles, ArrayRef<double> scales,
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax) {
return Base::get(storageType.getContext(), flags, storageType, expressedType,
quantiles, scales, zeroPoints, quantizedDimension,
storageTypeMin, storageTypeMax);
return Base::get(storageType.getContext(), flags, storageType, quantileType,
expressedType, quantiles, scales, zeroPoints,
quantizedDimension, storageTypeMin, storageTypeMax);
}

QuantileQuantizedPerAxisType QuantileQuantizedPerAxisType::getChecked(
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, ArrayRef<double> quantiles,
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
int32_t quantizedDimension, int64_t storageTypeMin,
int64_t storageTypeMax) {
Type storageType, Type quantileType, Type expressedType,
ArrayRef<double> quantiles, ArrayRef<double> scales,
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax) {
return Base::getChecked(emitError, storageType.getContext(), flags,
storageType, expressedType, quantiles, scales,
zeroPoints, quantizedDimension, storageTypeMin,
storageTypeMax);
storageType, quantileType, expressedType, quantiles,
scales, zeroPoints, quantizedDimension,
storageTypeMin, storageTypeMax);
}

LogicalResult QuantileQuantizedPerAxisType::verify(
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, ArrayRef<double> quantiles,
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
int32_t quantizedDimension, int64_t storageTypeMin,
int64_t storageTypeMax) {
Type storageType, Type quantileType, Type expressedType,
ArrayRef<double> quantiles, ArrayRef<double> scales,
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax) {
if (failed(UniformQuantizedPerAxisType::verify(
emitError, flags, storageType, expressedType, scales, zeroPoints,
quantizedDimension, storageTypeMin, storageTypeMax))) {
Expand Down Expand Up @@ -506,6 +534,18 @@ LogicalResult QuantileQuantizedPerAxisType::verify(
return success();
}

bool QuantileQuantizedPerAxisType::classof(mlir::Type type) {
return type.getTypeID() == mlir::TypeID::get<QuantileQuantizedPerAxisType>();
}

Type QuantileQuantizedPerAxisType::getQuantileType() const {
return getImpl()->quantileType;
}

unsigned QuantileQuantizedPerAxisType::getQuantileTypeIntegralWidth() const {
return getImpl()->getQuantileType().getIntOrFloatBitWidth();
}

ArrayRef<double> QuantileQuantizedPerAxisType::getQuantiles() const {
return getImpl()->getQuantiles();
}
Expand Down
Loading