Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mlir/cmake/modules/AddMLIR.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,14 @@ function(add_mlir_interface interface)
add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
endfunction()

# Declare a dialect in the include directory
function(add_mlir_type_interface interface)
set(LLVM_TARGET_DEFINITIONS ${interface}.td)
mlir_tablegen(${interface}.h.inc -gen-type-interface-decls)
mlir_tablegen(${interface}.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(MLIR${interface}IncGen)
add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
endfunction()

# Generate Documentation
function(add_mlir_doc doc_filename output_file output_directory command)
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/IR/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ class FloatType : public Type {
// Tablegen Type Declarations
//===----------------------------------------------------------------------===//

// Include QuantizationInterface before BuiltinTypes to resolve dependencies
#include "mlir/IR/QuantizationInterface.h"

#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.h.inc"

Expand Down
67 changes: 62 additions & 5 deletions mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinDialect.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/IR/QuantizationInterface.td"

// TODO: Currently the types defined in this file are prefixed with `Builtin_`.
// This is to differentiate the types here with the ones in OpBase.td. We should
Expand Down Expand Up @@ -78,8 +79,8 @@ def Builtin_Complex : Builtin_Type<"Complex", "complex"> {
//===----------------------------------------------------------------------===//

// Base class for Builtin dialect float types.
class Builtin_FloatType<string name, string mnemonic>
: Builtin_Type<name, mnemonic, /*traits=*/[], "::mlir::FloatType"> {
class Builtin_FloatType<string name, string mnemonic, list<Trait> traits = []>
: Builtin_Type<name, mnemonic, traits, "::mlir::FloatType"> {
let extraClassDeclaration = [{
static }] # name # [{Type get(MLIRContext *context);
}];
Expand All @@ -88,7 +89,8 @@ class Builtin_FloatType<string name, string mnemonic>
//===----------------------------------------------------------------------===//
// Float8E5M2Type

def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2"> {
def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2",
[QuantizationInterface]> {
let summary = "8-bit floating point with 2 bit mantissa";
let description = [{
An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits
Expand All @@ -104,6 +106,23 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2"> {

Described in: https://arxiv.org/abs/2209.05433
}];

let extraClassDeclaration = [{
static Float8E5M2Type get(MLIRContext *context);

/// QuantizationInterface method implementations
bool isStorageSigned() const { return true; }
unsigned getStorageWidth() const { return 8; }
int64_t getDefaultMaximum([[maybe_unused]] bool isSigned, [[maybe_unused]] unsigned integralWidth) const {
return 448;
}
int64_t getDefaultMinimum(bool isSigned, unsigned integralWidth) const {
return -getDefaultMaximum(isSigned, integralWidth);
}
std::string printStorageType([[maybe_unused]] bool isSigned, [[maybe_unused]] unsigned storageWidth) const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure we should have print method here. Isn't there more canonical way to stringify type name?
It's more getStorageTypeName

return "f8E5M2";
}
}];
}

//===----------------------------------------------------------------------===//
Expand All @@ -128,7 +147,8 @@ def Builtin_Float8E4M3 : Builtin_FloatType<"Float8E4M3", "f8E4M3"> {
//===----------------------------------------------------------------------===//
// Float8E4M3FNType

def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN"> {
def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN",
[QuantizationInterface]> {
let summary = "8-bit floating point with 3 bit mantissa";
let description = [{
An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
Expand All @@ -145,6 +165,23 @@ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN"> {

Described in: https://arxiv.org/abs/2209.05433
}];

let extraClassDeclaration = [{
static Float8E4M3FNType get(MLIRContext *context);

/// QuantizationInterface method implementations
bool isStorageSigned() const { return true; }
unsigned getStorageWidth() const { return 8; }
int64_t getDefaultMaximum([[maybe_unused]] bool isSigned, [[maybe_unused]] unsigned integralWidth) const {
return 57344;
}
int64_t getDefaultMinimum(bool isSigned, unsigned integralWidth) const{
return -getDefaultMaximum(isSigned, integralWidth);
}
std::string printStorageType([[maybe_unused]] bool isSigned, [[maybe_unused]] unsigned storageWidth) const {
return "f8E4M3FN";
}
}];
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -358,7 +395,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
// IntegerType
//===----------------------------------------------------------------------===//

def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
def Builtin_Integer : Builtin_Type<"Integer", "integer",
[QuantizationInterface]> {
let summary = "Integer type with arbitrary precision up to a fixed limit";
let description = [{
Syntax:
Expand Down Expand Up @@ -415,6 +453,25 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
/// Integer representation maximal bitwidth.
/// Note: This is aligned with the maximum width of llvm::IntegerType.
static constexpr unsigned kMaxWidth = (1 << 24) - 1;

/// QuantizationInterface method implementations
bool isStorageSigned() const { return !isUnsigned(); }
unsigned getStorageWidth() const { return getWidth(); }
int64_t getDefaultMinimum(bool isSigned, unsigned integralWidth) const {
if (isSigned) {
return llvm::minIntN(integralWidth);
}
return 0;
}
int64_t getDefaultMaximum(bool isSigned, unsigned integralWidth) const {
if (isSigned) {
return llvm::maxIntN(integralWidth);
}
return llvm::maxUIntN(integralWidth);
}
std::string printStorageType(bool isSigned, unsigned storageWidth) const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to pass these argument from outside?
Isn't there some existing interface we can use here to get these directly from this?

return (isSigned ? "i" : "u") + std::to_string(storageWidth);
}
}];
}

Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ add_mlir_interface(OpAsmInterface)
add_mlir_interface(SymbolInterfaces)
add_mlir_interface(RegionKindInterface)

add_mlir_type_interface(QuantizationInterface)

set(LLVM_TARGET_DEFINITIONS BuiltinAttributes.td)
mlir_tablegen(BuiltinAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(BuiltinAttributes.cpp.inc -gen-attrdef-defs)
Expand Down
23 changes: 23 additions & 0 deletions mlir/include/mlir/IR/QuantizationInterface.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//===- QuantizationInterface.h - Quantile Float Interfaces --------*- C++
//-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_IR_QuantizationInterface_H
#define MLIR_IR_QuantizationInterface_H

#include "mlir/IR/Types.h"

// Forward declarations for the types we need in the implementation
namespace mlir {
class IntegerType;
class FloatType;
} // namespace mlir

#include "mlir/IR/QuantizationInterface.h.inc"

#endif // MLIR_IR_QuantizationInterface_H
45 changes: 45 additions & 0 deletions mlir/include/mlir/IR/QuantizationInterface.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#ifndef MLIR_IR_QUANTIZATIONINTERFACE
#define MLIR_IR_QUANTIZATIONINTERFACE

include "mlir/IR/OpBase.td"

def QuantizationInterface : TypeInterface<"QuantizationInterface"> {
let description = [{
Interface for types that can be used as quantile storage types.
This interface provides methods to determine storage characteristics
like width and signedness for quantization purposes.
}];
let cppNamespace = "::mlir";

let methods = [
InterfaceMethod<[{
Get the storage type width in bits.
Returns the number of bits used to store values of this type.
}],
"unsigned", "getStorageWidth", (ins)>,

InterfaceMethod<[{
Check if the storage type is signed.
Returns true if the type represents signed values, false for unsigned.
}],
"bool", "isStorageSigned", (ins)>,

InterfaceMethod<[{
Get the default minimum value for the storage type.
}],
"int64_t", "getDefaultMinimum", (ins "bool":$isSigned, "unsigned":$integralWidth)>,

InterfaceMethod<[{
Get the default maximum value for the storage type.
}],
"int64_t", "getDefaultMaximum", (ins "bool":$isSigned, "unsigned":$integralWidth)>,

InterfaceMethod<[{
Get the name of the storage type.
}],
"std::string", "printStorageType", (ins "bool":$isSigned, "unsigned":$storageWidth)>
];

}

#endif // MLIR_IR_QUANTIZATIONINTERFACE
45 changes: 33 additions & 12 deletions mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/QuantizationInterface.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/MathExtras.h"
Expand All @@ -32,7 +33,6 @@ LogicalResult
QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType, Type expressedType,
int64_t storageTypeMin, int64_t storageTypeMax) {

bool isSigned =
(flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;

Expand All @@ -46,16 +46,11 @@ QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
}

int64_t defaultMin, defaultMax;
if (storageType.isa<IntegerType>()) {
const auto width = llvm::dyn_cast<IntegerType>(storageType).getWidth();
defaultMin = QuantizedType::getDefaultMinimumForInteger(isSigned, width);
defaultMax = QuantizedType::getDefaultMaximumForInteger(isSigned, width);
} else if (storageType.isa<Float8E5M2Type>()) {
defaultMin = QuantizedType::getDefaultMinimumForF8E5M2();
defaultMax = QuantizedType::getDefaultMaximumForF8E5M2();
} else if (storageType.isa<Float8E4M3FNType>()) {
defaultMin = QuantizedType::getDefaultMinimumForF8E4M3FN();
defaultMax = QuantizedType::getDefaultMaximumForF8E4M3FN();
if (auto quantizationInterface =
llvm::dyn_cast<QuantizationInterface>(storageType)) {
const auto width = quantizationInterface.getStorageWidth();
defaultMin = quantizationInterface.getDefaultMinimum(isSigned, width);
defaultMax = quantizationInterface.getDefaultMaximum(isSigned, width);
} else {
return emitError() << "illegal storage type, supported types are: integral "
"types, Float8E4M3FNType and Float8E5M2Type ";
Expand All @@ -75,17 +70,42 @@ Type QuantizedType::getStorageType() const {
}

int64_t QuantizedType::getStorageTypeMin() const {
Type storageType = static_cast<ImplType *>(impl)->storageType;

if (auto quantizationInterface =
llvm::dyn_cast<QuantizationInterface>(storageType)) {
unsigned storageWidth = quantizationInterface.getStorageWidth();
bool isSigned = quantizationInterface.isStorageSigned();
return quantizationInterface.getDefaultMinimum(isSigned, storageWidth);
}

return static_cast<ImplType *>(impl)->storageTypeMin;
}

int64_t QuantizedType::getStorageTypeMax() const {
Type storageType = static_cast<ImplType *>(impl)->storageType;

if (auto quantizationInterface =
llvm::dyn_cast<QuantizationInterface>(storageType)) {
unsigned storageWidth = quantizationInterface.getStorageWidth();
bool isSigned = quantizationInterface.isStorageSigned();
return quantizationInterface.getDefaultMaximum(isSigned, storageWidth);
}

return static_cast<ImplType *>(impl)->storageTypeMax;
}

unsigned QuantizedType::getStorageTypeIntegralWidth() const {
// NOTE: If ever supporting non-integral storage types, some other scheme
// for determining the width will be needed.
return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
Type storageType = static_cast<ImplType *>(impl)->storageType;

if (auto quantizationInterface =
llvm::dyn_cast<QuantizationInterface>(storageType)) {
return quantizationInterface.getStorageWidth();
}

return storageType.getIntOrFloatBitWidth();
}

Type QuantizedType::getExpressedType() const {
Expand Down Expand Up @@ -282,6 +302,7 @@ LogicalResult UniformQuantizedType::verify(
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, double scale, int64_t zeroPoint,
int64_t storageTypeMin, int64_t storageTypeMax) {

if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
storageTypeMin, storageTypeMax))) {
return failure();
Expand Down
Loading