Skip to content

Commit 9ad9c9f

Browse files
andrey-golubevRoman-Pevnyi
authored andcommitted
[mlir] Allow accessing DialectResourceBlobManager::blobMap (cherry-pick) (#142)
Add a new API to access all blobs that are stored in the blob manager. The main purpose (as of now) is to allow users of dialect resources to iterate over all blobs, especially when the blobs are no longer used in IR (e.g. the operation that uses the blob is deleted) and thus cannot be easily accessed without manual tracking of keys.
1 parent 544ee1a commit 9ad9c9f

File tree

10 files changed

+228
-61
lines changed

10 files changed

+228
-61
lines changed

mlir/cmake/modules/AddMLIR.cmake

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,14 @@ function(add_mlir_interface interface)
196196
add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
197197
endfunction()
198198

199+
# Declare a dialect in the include directory
200+
function(add_mlir_type_interface interface)
201+
set(LLVM_TARGET_DEFINITIONS ${interface}.td)
202+
mlir_tablegen(${interface}.h.inc -gen-type-interface-decls)
203+
mlir_tablegen(${interface}.cpp.inc -gen-type-interface-defs)
204+
add_public_tablegen_target(MLIR${interface}IncGen)
205+
add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
206+
endfunction()
199207

200208
# Generate Documentation
201209
function(add_mlir_doc doc_filename output_file output_directory command)

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ class FloatType : public Type {
9191
// Tablegen Type Declarations
9292
//===----------------------------------------------------------------------===//
9393

94+
// Include QuantizationInterface before BuiltinTypes to resolve dependencies
95+
#include "mlir/IR/QuantizationInterface.h"
96+
9497
#define GET_TYPEDEF_CLASSES
9598
#include "mlir/IR/BuiltinTypes.h.inc"
9699

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
include "mlir/IR/AttrTypeBase.td"
1818
include "mlir/IR/BuiltinDialect.td"
1919
include "mlir/IR/BuiltinTypeInterfaces.td"
20+
include "mlir/IR/QuantizationInterface.td"
2021

2122
// TODO: Currently the types defined in this file are prefixed with `Builtin_`.
2223
// This is to differentiate the types here with the ones in OpBase.td. We should
@@ -78,8 +79,8 @@ def Builtin_Complex : Builtin_Type<"Complex", "complex"> {
7879
//===----------------------------------------------------------------------===//
7980

8081
// Base class for Builtin dialect float types.
81-
class Builtin_FloatType<string name, string mnemonic>
82-
: Builtin_Type<name, mnemonic, /*traits=*/[], "::mlir::FloatType"> {
82+
class Builtin_FloatType<string name, string mnemonic, list<Trait> traits = []>
83+
: Builtin_Type<name, mnemonic, traits, "::mlir::FloatType"> {
8384
let extraClassDeclaration = [{
8485
static }] # name # [{Type get(MLIRContext *context);
8586
}];
@@ -88,7 +89,8 @@ class Builtin_FloatType<string name, string mnemonic>
8889
//===----------------------------------------------------------------------===//
8990
// Float8E5M2Type
9091

91-
def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2"> {
92+
def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2",
93+
[QuantizationInterface]> {
9294
let summary = "8-bit floating point with 2 bit mantissa";
9395
let description = [{
9496
An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits
@@ -104,6 +106,23 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2"> {
104106

105107
Described in: https://arxiv.org/abs/2209.05433
106108
}];
109+
110+
let extraClassDeclaration = [{
111+
static Float8E5M2Type get(MLIRContext *context);
112+
113+
/// QuantizationInterface method implementations
114+
bool isStorageSigned() const { return true; }
115+
unsigned getStorageWidth() const { return 8; }
116+
int64_t getDefaultMaximum([[maybe_unused]] bool isSigned, [[maybe_unused]] unsigned integralWidth) const {
117+
return 448;
118+
}
119+
int64_t getDefaultMinimum(bool isSigned, unsigned integralWidth) const {
120+
return -getDefaultMaximum(isSigned, integralWidth);
121+
}
122+
std::string printStorageType([[maybe_unused]] bool isSigned, [[maybe_unused]] unsigned storageWidth) const {
123+
return "f8E5M2";
124+
}
125+
}];
107126
}
108127

109128
//===----------------------------------------------------------------------===//
@@ -128,7 +147,8 @@ def Builtin_Float8E4M3 : Builtin_FloatType<"Float8E4M3", "f8E4M3"> {
128147
//===----------------------------------------------------------------------===//
129148
// Float8E4M3FNType
130149

131-
def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN"> {
150+
def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN",
151+
[QuantizationInterface]> {
132152
let summary = "8-bit floating point with 3 bit mantissa";
133153
let description = [{
134154
An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
@@ -145,6 +165,23 @@ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN"> {
145165

146166
Described in: https://arxiv.org/abs/2209.05433
147167
}];
168+
169+
let extraClassDeclaration = [{
170+
static Float8E4M3FNType get(MLIRContext *context);
171+
172+
/// QuantizationInterface method implementations
173+
bool isStorageSigned() const { return true; }
174+
unsigned getStorageWidth() const { return 8; }
175+
int64_t getDefaultMaximum([[maybe_unused]] bool isSigned, [[maybe_unused]] unsigned integralWidth) const {
176+
return 57344;
177+
}
178+
int64_t getDefaultMinimum(bool isSigned, unsigned integralWidth) const{
179+
return -getDefaultMaximum(isSigned, integralWidth);
180+
}
181+
std::string printStorageType([[maybe_unused]] bool isSigned, [[maybe_unused]] unsigned storageWidth) const {
182+
return "f8E4M3FN";
183+
}
184+
}];
148185
}
149186

150187
//===----------------------------------------------------------------------===//
@@ -358,7 +395,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
358395
// IntegerType
359396
//===----------------------------------------------------------------------===//
360397

361-
def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
398+
def Builtin_Integer : Builtin_Type<"Integer", "integer",
399+
[QuantizationInterface]> {
362400
let summary = "Integer type with arbitrary precision up to a fixed limit";
363401
let description = [{
364402
Syntax:
@@ -415,6 +453,25 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
415453
/// Integer representation maximal bitwidth.
416454
/// Note: This is aligned with the maximum width of llvm::IntegerType.
417455
static constexpr unsigned kMaxWidth = (1 << 24) - 1;
456+
457+
/// QuantizationInterface method implementations
458+
bool isStorageSigned() const { return !isUnsigned(); }
459+
unsigned getStorageWidth() const { return getWidth(); }
460+
int64_t getDefaultMinimum(bool isSigned, unsigned integralWidth) const {
461+
if (isSigned) {
462+
return llvm::minIntN(integralWidth);
463+
}
464+
return 0;
465+
}
466+
int64_t getDefaultMaximum(bool isSigned, unsigned integralWidth) const {
467+
if (isSigned) {
468+
return llvm::maxIntN(integralWidth);
469+
}
470+
return llvm::maxUIntN(integralWidth);
471+
}
472+
std::string printStorageType(bool isSigned, unsigned storageWidth) const {
473+
return (isSigned ? "i" : "u") + std::to_string(storageWidth);
474+
}
418475
}];
419476
}
420477

mlir/include/mlir/IR/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ add_mlir_interface(OpAsmInterface)
22
add_mlir_interface(SymbolInterfaces)
33
add_mlir_interface(RegionKindInterface)
44

5+
add_mlir_type_interface(QuantizationInterface)
6+
57
set(LLVM_TARGET_DEFINITIONS BuiltinAttributes.td)
68
mlir_tablegen(BuiltinAttributes.h.inc -gen-attrdef-decls)
79
mlir_tablegen(BuiltinAttributes.cpp.inc -gen-attrdef-defs)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//===- QuantizationInterface.h - Quantile Float Interfaces --------*- C++
2+
//-*-===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#ifndef MLIR_IR_QuantizationInterface_H
11+
#define MLIR_IR_QuantizationInterface_H
12+
13+
#include "mlir/IR/Types.h"
14+
15+
// Forward declarations for the types we need in the implementation
16+
namespace mlir {
17+
class IntegerType;
18+
class FloatType;
19+
} // namespace mlir
20+
21+
#include "mlir/IR/QuantizationInterface.h.inc"
22+
23+
#endif // MLIR_IR_QuantizationInterface_H
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#ifndef MLIR_IR_QUANTIZATIONINTERFACE
2+
#define MLIR_IR_QUANTIZATIONINTERFACE
3+
4+
include "mlir/IR/OpBase.td"
5+
6+
def QuantizationInterface : TypeInterface<"QuantizationInterface"> {
7+
let description = [{
8+
Interface for types that can be used as quantile storage types.
9+
This interface provides methods to determine storage characteristics
10+
like width and signedness for quantization purposes.
11+
}];
12+
let cppNamespace = "::mlir";
13+
14+
let methods = [
15+
InterfaceMethod<[{
16+
Get the storage type width in bits.
17+
Returns the number of bits used to store values of this type.
18+
}],
19+
"unsigned", "getStorageWidth", (ins)>,
20+
21+
InterfaceMethod<[{
22+
Check if the storage type is signed.
23+
Returns true if the type represents signed values, false for unsigned.
24+
}],
25+
"bool", "isStorageSigned", (ins)>,
26+
27+
InterfaceMethod<[{
28+
Get the default minimum value for the storage type.
29+
}],
30+
"int64_t", "getDefaultMinimum", (ins "bool":$isSigned, "unsigned":$integralWidth)>,
31+
32+
InterfaceMethod<[{
33+
Get the default maximum value for the storage type.
34+
}],
35+
"int64_t", "getDefaultMaximum", (ins "bool":$isSigned, "unsigned":$integralWidth)>,
36+
37+
InterfaceMethod<[{
38+
Get the name of the storage type.
39+
}],
40+
"std::string", "printStorageType", (ins "bool":$isSigned, "unsigned":$storageWidth)>
41+
];
42+
43+
}
44+
45+
#endif // MLIR_IR_QUANTIZATIONINTERFACE

mlir/lib/Dialect/Quant/IR/QuantTypes.cpp

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "mlir/IR/BuiltinTypes.h"
1414
#include "mlir/IR/MLIRContext.h"
15+
#include "mlir/IR/QuantizationInterface.h"
1516
#include "llvm/ADT/StringRef.h"
1617
#include "llvm/ADT/Twine.h"
1718
#include "llvm/Support/MathExtras.h"
@@ -32,7 +33,6 @@ LogicalResult
3233
QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
3334
unsigned flags, Type storageType, Type expressedType,
3435
int64_t storageTypeMin, int64_t storageTypeMax) {
35-
3636
bool isSigned =
3737
(flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
3838

@@ -46,16 +46,11 @@ QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
4646
}
4747

4848
int64_t defaultMin, defaultMax;
49-
if (storageType.isa<IntegerType>()) {
50-
const auto width = llvm::dyn_cast<IntegerType>(storageType).getWidth();
51-
defaultMin = QuantizedType::getDefaultMinimumForInteger(isSigned, width);
52-
defaultMax = QuantizedType::getDefaultMaximumForInteger(isSigned, width);
53-
} else if (storageType.isa<Float8E5M2Type>()) {
54-
defaultMin = QuantizedType::getDefaultMinimumForF8E5M2();
55-
defaultMax = QuantizedType::getDefaultMaximumForF8E5M2();
56-
} else if (storageType.isa<Float8E4M3FNType>()) {
57-
defaultMin = QuantizedType::getDefaultMinimumForF8E4M3FN();
58-
defaultMax = QuantizedType::getDefaultMaximumForF8E4M3FN();
49+
if (auto quantizationInterface =
50+
llvm::dyn_cast<QuantizationInterface>(storageType)) {
51+
const auto width = quantizationInterface.getStorageWidth();
52+
defaultMin = quantizationInterface.getDefaultMinimum(isSigned, width);
53+
defaultMax = quantizationInterface.getDefaultMaximum(isSigned, width);
5954
} else {
6055
return emitError() << "illegal storage type, supported types are: integral "
6156
"types, Float8E4M3FNType and Float8E5M2Type ";
@@ -75,17 +70,42 @@ Type QuantizedType::getStorageType() const {
7570
}
7671

7772
int64_t QuantizedType::getStorageTypeMin() const {
73+
Type storageType = static_cast<ImplType *>(impl)->storageType;
74+
75+
if (auto quantizationInterface =
76+
llvm::dyn_cast<QuantizationInterface>(storageType)) {
77+
unsigned storageWidth = quantizationInterface.getStorageWidth();
78+
bool isSigned = quantizationInterface.isStorageSigned();
79+
return quantizationInterface.getDefaultMinimum(isSigned, storageWidth);
80+
}
81+
7882
return static_cast<ImplType *>(impl)->storageTypeMin;
7983
}
8084

8185
int64_t QuantizedType::getStorageTypeMax() const {
86+
Type storageType = static_cast<ImplType *>(impl)->storageType;
87+
88+
if (auto quantizationInterface =
89+
llvm::dyn_cast<QuantizationInterface>(storageType)) {
90+
unsigned storageWidth = quantizationInterface.getStorageWidth();
91+
bool isSigned = quantizationInterface.isStorageSigned();
92+
return quantizationInterface.getDefaultMaximum(isSigned, storageWidth);
93+
}
94+
8295
return static_cast<ImplType *>(impl)->storageTypeMax;
8396
}
8497

8598
unsigned QuantizedType::getStorageTypeIntegralWidth() const {
8699
// NOTE: If ever supporting non-integral storage types, some other scheme
87100
// for determining the width will be needed.
88-
return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
101+
Type storageType = static_cast<ImplType *>(impl)->storageType;
102+
103+
if (auto quantizationInterface =
104+
llvm::dyn_cast<QuantizationInterface>(storageType)) {
105+
return quantizationInterface.getStorageWidth();
106+
}
107+
108+
return storageType.getIntOrFloatBitWidth();
89109
}
90110

91111
Type QuantizedType::getExpressedType() const {
@@ -282,6 +302,7 @@ LogicalResult UniformQuantizedType::verify(
282302
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
283303
Type storageType, Type expressedType, double scale, int64_t zeroPoint,
284304
int64_t storageTypeMin, int64_t storageTypeMax) {
305+
285306
if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
286307
storageTypeMin, storageTypeMax))) {
287308
return failure();

0 commit comments

Comments
 (0)