diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp index 0ee888308642..5a3500ec4278 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp @@ -504,7 +504,6 @@ LogicalResult QuantileQuantizedPerAxisType::verify( return failure(); } - const auto quantileArraySize = quantiles.size(); unsigned typeWidth{}; if (storageType.isa()) { typeWidth = llvm::dyn_cast(storageType).getWidth(); @@ -517,10 +516,17 @@ LogicalResult QuantileQuantizedPerAxisType::verify( "types, Float8E4M3FNType and Float8E5M2Type "; } - const size_t expectedSize = 1 << typeWidth; + const size_t storageTypeRange = storageTypeMax - storageTypeMin + 1; + const size_t typeWidthSize = 1 << typeWidth; + const size_t expectedSize = + (storageTypeRange < typeWidthSize) ? storageTypeRange : typeWidthSize; + + const auto quantileArraySize = quantiles.size(); if (quantileArraySize != expectedSize) { return emitError() << "quantiles array size needs to be equal to " - "2^(bit_size(storageType)), expected: " + "2^(bit_size(storageType)), or (storageTypeMax - " + "storageTypeMin + 1) when max and min differ from " + "the type limits; expected: " << expectedSize << ", found: " << quantileArraySize; } diff --git a/mlir/test/Dialect/Quant/parse-quantile-invalid.mlir b/mlir/test/Dialect/Quant/parse-quantile-invalid.mlir index 4367173d2dff..005faa60e3cb 100644 --- a/mlir/test/Dialect/Quant/parse-quantile-invalid.mlir +++ b/mlir/test/Dialect/Quant/parse-quantile-invalid.mlir @@ -27,6 +27,15 @@ func.func @parse() -> !qalias { return %0 : !qalias } +// ----- +// Illegal quantile array size (per axis type) +// expected-error@+1 {{quantiles array size needs to be equal to 2^(bit_size(storageType)), or (storageTypeMax - storageTypeMin + 1) when max and min differ from the type limits; expected: 256, found: 2}} +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + // ----- // Unrecognized token: trailing // expected-error@+1 {{expected '>'}}