Skip to content

Commit d2aee97

Browse files
authored
Adding missing check in QuantileQuantizedPerAxisType::verify (#66)
* Adding missing check in QuantileQuantizedPerAxisType::verify * Adding minor test for illegal quantile array size for per axis type
1 parent 22b15c8 commit d2aee97

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

mlir/lib/Dialect/Quant/IR/QuantTypes.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,6 @@ LogicalResult QuantileQuantizedPerAxisType::verify(
504504
return failure();
505505
}
506506

507-
const auto quantileArraySize = quantiles.size();
508507
unsigned typeWidth{};
509508
if (storageType.isa<IntegerType>()) {
510509
typeWidth = llvm::dyn_cast<IntegerType>(storageType).getWidth();
@@ -517,10 +516,17 @@ LogicalResult QuantileQuantizedPerAxisType::verify(
517516
"types, Float8E4M3FNType and Float8E5M2Type ";
518517
}
519518

520-
const size_t expectedSize = 1 << typeWidth;
519+
const size_t storageTypeRange = storageTypeMax - storageTypeMin + 1;
520+
const size_t typeWidthSize = 1 << typeWidth;
521+
const size_t expectedSize =
522+
(storageTypeRange < typeWidthSize) ? storageTypeRange : typeWidthSize;
523+
524+
const auto quantileArraySize = quantiles.size();
521525
if (quantileArraySize != expectedSize) {
522526
return emitError() << "quantiles array size needs to be equal to "
523-
"2^(bit_size(storageType)), expected: "
527+
"2^(bit_size(storageType)), or (storageTypeMax - "
528+
"storageTypeMin + 1) when max and min differ from "
529+
"the type limits; expected: "
524530
<< expectedSize << ", found: " << quantileArraySize;
525531
}
526532

mlir/test/Dialect/Quant/parse-quantile-invalid.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,15 @@ func.func @parse() -> !qalias {
2727
return %0 : !qalias
2828
}
2929

30+
// -----
31+
// Illegal quantile array size (per axis type)
32+
// expected-error@+1 {{quantiles array size needs to be equal to 2^(bit_size(storageType)), or (storageTypeMax - storageTypeMin + 1) when max and min differ from the type limits; expected: 256, found: 2}}
33+
!qalias = !quant.quantile<i8:f16:f32:1, {-1.0,1.0}:{-2.0e+2,-0.99872:120}>
34+
func.func @parse() -> !qalias {
35+
%0 = "foo"() : () -> !qalias
36+
return %0 : !qalias
37+
}
38+
3039
// -----
3140
// Unrecognized token: trailing
3241
// expected-error@+1 {{expected '>'}}

0 commit comments

Comments
 (0)