Skip to content

Commit 56ba930

Browse files
sartilsramasit
authored andcommitted
Adding missing check in QuantileQuantizedPerAxisType::verify (#66)
* Adding missing check in QuantileQuantizedPerAxisType::verify * Adding minor test for illegal quantile array size for per axis type
1 parent d73f0c6 commit 56ba930

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

mlir/lib/Dialect/Quant/IR/QuantTypes.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,6 @@ LogicalResult QuantileQuantizedPerAxisType::verifyInvariants(
540540
return failure();
541541
}
542542

543-
const auto quantileArraySize = quantiles.size();
544543
unsigned typeWidth{};
545544
if (storageType.isa<IntegerType>()) {
546545
typeWidth = llvm::dyn_cast<IntegerType>(storageType).getWidth();
@@ -553,10 +552,17 @@ LogicalResult QuantileQuantizedPerAxisType::verifyInvariants(
553552
"types, Float8E4M3FNType and Float8E5M2Type ";
554553
}
555554

556-
const size_t expectedSize = 1 << typeWidth;
555+
const size_t storageTypeRange = storageTypeMax - storageTypeMin + 1;
556+
const size_t typeWidthSize = 1 << typeWidth;
557+
const size_t expectedSize =
558+
(storageTypeRange < typeWidthSize) ? storageTypeRange : typeWidthSize;
559+
560+
const auto quantileArraySize = quantiles.size();
557561
if (quantileArraySize != expectedSize) {
558562
return emitError() << "quantiles array size needs to be equal to "
559-
"2^(bit_size(storageType)), expected: "
563+
"2^(bit_size(storageType)), or (storageTypeMax - "
564+
"storageTypeMin + 1) when max and min differ from "
565+
"the type limits; expected: "
560566
<< expectedSize << ", found: " << quantileArraySize;
561567
}
562568

mlir/test/Dialect/Quant/parse-quantile-invalid.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,15 @@ func.func @parse() -> !qalias {
2727
return %0 : !qalias
2828
}
2929

30+
// -----
31+
// Illegal quantile array size (per axis type)
32+
// expected-error@+1 {{quantiles array size needs to be equal to 2^(bit_size(storageType)), or (storageTypeMax - storageTypeMin + 1) when max and min differ from the type limits; expected: 256, found: 2}}
33+
!qalias = !quant.quantile<i8:f16:f32:1, {-1.0,1.0}:{-2.0e+2,-0.99872:120}>
34+
func.func @parse() -> !qalias {
35+
%0 = "foo"() : () -> !qalias
36+
return %0 : !qalias
37+
}
38+
3039
// -----
3140
// Unrecognized token: trailing
3241
// expected-error@+1 {{expected '>'}}

0 commit comments

Comments
 (0)