From 5b88b095922deb5041b5f79dd3407784fcb719c2 Mon Sep 17 00:00:00 2001 From: Luca Sarti Date: Wed, 9 Oct 2024 15:53:48 +0000 Subject: [PATCH 1/2] Adding missing check in QuantileQuantizedPerAxisType::verify --- mlir/lib/Dialect/Quant/IR/QuantTypes.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp index 0ee888308642..5a3500ec4278 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp @@ -504,7 +504,6 @@ LogicalResult QuantileQuantizedPerAxisType::verify( return failure(); } - const auto quantileArraySize = quantiles.size(); unsigned typeWidth{}; if (storageType.isa()) { typeWidth = llvm::dyn_cast(storageType).getWidth(); @@ -517,10 +516,17 @@ LogicalResult QuantileQuantizedPerAxisType::verify( "types, Float8E4M3FNType and Float8E5M2Type "; } - const size_t expectedSize = 1 << typeWidth; + const size_t storageTypeRange = storageTypeMax - storageTypeMin + 1; + const size_t typeWidthSize = 1 << typeWidth; + const size_t expectedSize = + (storageTypeRange < typeWidthSize) ? storageTypeRange : typeWidthSize; + + const auto quantileArraySize = quantiles.size(); if (quantileArraySize != expectedSize) { return emitError() << "quantiles array size needs to be equal to " - "2^(bit_size(storageType)), expected: " + "2^(bit_size(storageType)), or (storageTypeMax - " + "storageTypeMin + 1) when max and min differ from " + "the type limits; expected: " << expectedSize << ", found: " << quantileArraySize; } From f2db4dd8435837bb0ba127ae6c850b1e6c784de6 Mon Sep 17 00:00:00 2001 From: Luca Sarti Date: Fri, 11 Oct 2024 07:49:36 +0000 Subject: [PATCH 2/2] Adding minor test for illegal quantile array size for per axis type --- mlir/test/Dialect/Quant/parse-quantile-invalid.mlir | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/mlir/test/Dialect/Quant/parse-quantile-invalid.mlir b/mlir/test/Dialect/Quant/parse-quantile-invalid.mlir index 4367173d2dff..005faa60e3cb 100644 --- a/mlir/test/Dialect/Quant/parse-quantile-invalid.mlir +++ b/mlir/test/Dialect/Quant/parse-quantile-invalid.mlir @@ -27,6 +27,15 @@ func.func @parse() -> !qalias { return %0 : !qalias } +// ----- +// Illegal quantile array size (per axis type) +// expected-error@+1 {{quantiles array size needs to be equal to 2^(bit_size(storageType)), or (storageTypeMax - storageTypeMin + 1) when max and min differ from the type limits; expected: 256, found: 2}} +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + // ----- // Unrecognized token: trailing // expected-error@+1 {{expected '>'}}