diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h index 10491f65d37af..e088eb31338dc 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h @@ -50,28 +50,66 @@ TargetEnvAttr getDefaultTargetEnv(MLIRContext *context); /// returned by getDefaultTargetEnv() if not provided. TargetEnvAttr lookupTargetEnvOrDefault(Operation *op); +/// A thin wrapper around the SpecificationVersion enum to represent +/// and provide utilities around the TOSA specification version. +class TosaSpecificationVersion { +public: + TosaSpecificationVersion() = default; + + TosaSpecificationVersion(uint32_t major, uint32_t minor) + : majorVersion(major), minorVersion(minor) {} + TosaSpecificationVersion(SpecificationVersion version) + : TosaSpecificationVersion(fromVersionEnum(version)) {} + + bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const { + return this->majorVersion == baseVersion.majorVersion && + this->minorVersion >= baseVersion.minorVersion; + } + + uint32_t getMajor() const { return majorVersion; } + uint32_t getMinor() const { return minorVersion; } + +private: + uint32_t majorVersion = 0; + uint32_t minorVersion = 0; + + static TosaSpecificationVersion + fromVersionEnum(SpecificationVersion version) { + switch (version) { + case SpecificationVersion::V_1_0: + return TosaSpecificationVersion(1, 0); + case SpecificationVersion::V_1_1_DRAFT: + return TosaSpecificationVersion(1, 1); + } + llvm_unreachable("Unknown TOSA version"); + } +}; + +TosaSpecificationVersion getMinVersion(const Profile &profile); +TosaSpecificationVersion getMinVersion(const Extension &extension); +TosaSpecificationVersion getMinVersion(const Level &level); + +llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version); + /// This class represents the capability enabled in the target implementation /// such as profile, extension, and level. It's a wrapper class around /// tosa::TargetEnvAttr. class TargetEnv { public: TargetEnv() {} - explicit TargetEnv(Level level, const ArrayRef &profiles, - const ArrayRef &extensions) - : level(level) { - enabledProfiles.insert_range(profiles); - enabledExtensions.insert_range(extensions); - } - explicit TargetEnv(TargetEnvAttr targetAttr) - : TargetEnv(targetAttr.getLevel(), targetAttr.getProfiles(), - targetAttr.getExtensions()) {} + static FailureOr + createTargetEnvFromAttr(TargetEnvAttr targetAttr, Location targetEnvAttrLoc); + + static LogicalResult verifyTargetInformation(TargetEnvAttr targetAttr, + Location targetAttrLoc); void addProfile(Profile p) { enabledProfiles.insert(p); } void addExtension(Extension e) { enabledExtensions.insert(e); } - // TODO implement the following utilities. - // Version getSpecVersion() const; + TosaSpecificationVersion getSpecVersion() const { + return specificationVersion; + } TosaLevel getLevel() const { if (level == Level::eightK) @@ -105,6 +143,17 @@ class TargetEnv { } private: + // Require target information is verified before constructing, via the use of + // `createTargetEnvFromAttr`. 
+ explicit TargetEnv(SpecificationVersion specificationVersion, Level level, + const ArrayRef &profiles, + const ArrayRef &extensions) + : specificationVersion(specificationVersion), level(level) { + enabledProfiles.insert_range(profiles); + enabledExtensions.insert_range(extensions); + } + + TosaSpecificationVersion specificationVersion; Level level; llvm::SmallSet enabledProfiles; llvm::SmallSet enabledExtensions; diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc index 1f718accabd15..9eaf0847802cb 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc @@ -2,441 +2,812 @@ // `tools/genspec.py` in https://git.mlplatform.org/tosa/specification.git profileComplianceMap = { {"tosa.argmax", - {{{Profile::pro_int}, {{i8T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, i32T}, {fp32T, i32T}}}}}, + {{{Profile::pro_int}, {{{i8T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, i32T}, SpecificationVersion::V_1_0}, + {{fp32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.avg_pool2d", - {{{Profile::pro_int}, {{i8T, i8T, i8T, i32T, i8T}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i8T, i32T, i8T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T, fp16T}, - {fp16T, fp16T, fp16T, fp32T, fp16T}, - {fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp16T, fp16T, fp16T, fp32T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.conv2d", - {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, - {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, - {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.conv3d", - {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, - {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, - {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.depthwise_conv2d", - {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, - {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, - {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}, + 
SpecificationVersion::V_1_0}}}}}, {"tosa.matmul", - {{{Profile::pro_int}, {{i8T, i8T, i8T, i8T, i32T}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i8T, i8T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T, fp16T}, - {fp16T, fp16T, fp16T, fp16T, fp32T}, - {fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp16T, fp16T, fp16T, fp16T, fp32T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.max_pool2d", - {{{Profile::pro_int}, {{i8T, i8T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, {{{i8T, i8T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.transpose_conv2d", - {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, - {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, - {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.clamp", - {{{Profile::pro_int}, {{i8T, i8T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.erf", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.sigmoid", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.tanh", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, {{{i8T, i8T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.erf", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.sigmoid", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.tanh", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.add", - {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.arithmetic_right_shift", {{{Profile::pro_int}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.bitwise_and", {{{Profile::pro_int}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.bitwise_or", {{{Profile::pro_int}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}}, + {{{i8T, i8T, i8T}, 
SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.bitwise_xor", {{{Profile::pro_int}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.intdiv", - {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.logical_and", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.logical_left_shift", {{{Profile::pro_int, Profile::pro_fp}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}, anyOf}}}, {"tosa.logical_right_shift", {{{Profile::pro_int, Profile::pro_fp}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}, anyOf}}}, {"tosa.logical_or", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.logical_xor", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.maximum", - {{{Profile::pro_int}, {{i32T, i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{Profile::pro_int}, {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.minimum", - {{{Profile::pro_int}, {{i32T, i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{Profile::pro_int}, {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.mul", - {{{Profile::pro_int}, {{i8T, i8T, i32T}, {i16T, i16T, i32T}}}, - {{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i32T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_int, Profile::pro_fp}, + {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.pow", - {{{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.sub", - {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, - {"tosa.table", {{{Profile::pro_int}, {{i8T, 
i8T, i8T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.table", + {{{Profile::pro_int}, {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}}}}}, {"tosa.abs", - {{{Profile::pro_int}, {{i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, {{{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.bitwise_not", - {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}}}, - {"tosa.ceil", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.clz", {{{Profile::pro_int}, {{i32T, i32T}}}}}, - {"tosa.cos", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.exp", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.floor", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.log", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.ceil", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.clz", + {{{Profile::pro_int}, {{{i32T, i32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.cos", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.exp", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.floor", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.log", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.logical_not", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.negate", {{{Profile::pro_int}, - {{i8T, i8T, i8T, i8T}, - {i16T, i16T, i16T, i16T}, - {i32T, i32T, i32T, i32T}}}, + {{{i8T, i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reciprocal", - {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.rsqrt", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.sin", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.rsqrt", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.sin", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.select", - {{{Profile::pro_int, Profile::pro_fp}, 
{{boolT, boolT, boolT}}, anyOf}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, {{Profile::pro_int}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.equal", - {{{Profile::pro_int}, {{i32T, i32T, boolT}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}}, + {{{Profile::pro_int}, + {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, boolT}, SpecificationVersion::V_1_0}}}}}, {"tosa.greater", - {{{Profile::pro_int}, {{i32T, i32T, boolT}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}}, + {{{Profile::pro_int}, + {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, boolT}, SpecificationVersion::V_1_0}}}}}, {"tosa.greater_equal", - {{{Profile::pro_int}, {{i32T, i32T, boolT}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}}, + {{{Profile::pro_int}, + {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, boolT}, SpecificationVersion::V_1_0}}}}}, {"tosa.reduce_all", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.reduce_any", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.reduce_max", - {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reduce_min", - {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reduce_product", - {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reduce_sum", - {{{Profile::pro_int}, {{i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, {{{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.concat", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}, - 
{{Profile::pro_int}, {{i8T, i8T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.pad", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, {{Profile::pro_int}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reshape", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}, - {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reverse", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}, - {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.slice", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}, - {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.tile", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}, - {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.transpose", - {{{Profile::pro_int, 
Profile::pro_fp}, {{boolT, boolT}}, anyOf}, - {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.gather", {{{Profile::pro_int}, - {{i8T, i32T, i8T}, {i16T, i32T, i16T}, {i32T, i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, i32T, fp16T}, {fp32T, i32T, fp32T}}}}}, + {{{i8T, i32T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i32T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, i32T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, i32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.scatter", {{{Profile::pro_int}, - {{i8T, i32T, i8T, i8T}, - {i16T, i32T, i16T, i16T}, - {i32T, i32T, i32T, i32T}}}, + {{{i8T, i32T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i32T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, i32T, fp16T, fp16T}, {fp32T, i32T, fp32T, fp32T}}}}}, + {{{fp16T, i32T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, i32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.resize", - {{{Profile::pro_int}, {{i8T, i32T}, {i8T, i8T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, + {{{i8T, i32T}, SpecificationVersion::V_1_0}, + {{i8T, i8T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.cast", {{{Profile::pro_int}, - {{boolT, i8T}, - {boolT, i16T}, - {boolT, i32T}, - {i8T, boolT}, - {i8T, i16T}, - {i8T, i32T}, - {i16T, boolT}, - {i16T, i8T}, - {i16T, i32T}, - {i32T, boolT}, - {i32T, i8T}, - {i32T, i16T}}}, - {{Profile::pro_fp}, - {{i8T, fp16T}, - {i8T, fp32T}, - {i16T, fp16T}, - {i16T, fp32T}, - {i32T, fp16T}, - {i32T, fp32T}, - {fp16T, i8T}, - {fp16T, i16T}, - {fp16T, i32T}, - {fp16T, fp32T}, - {fp32T, i8T}, - {fp32T, i16T}, - {fp32T, i32T}, - {fp32T, fp16T}}}}}, + {{{boolT, i8T}, SpecificationVersion::V_1_0}, + {{boolT, i16T}, SpecificationVersion::V_1_0}, + {{boolT, i32T}, SpecificationVersion::V_1_0}, + {{i8T, boolT}, SpecificationVersion::V_1_0}, + {{i8T, i16T}, SpecificationVersion::V_1_0}, + {{i8T, i32T}, SpecificationVersion::V_1_0}, + {{i16T, boolT}, SpecificationVersion::V_1_0}, + {{i16T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i32T}, SpecificationVersion::V_1_0}, + {{i32T, boolT}, SpecificationVersion::V_1_0}, + {{i32T, i8T}, SpecificationVersion::V_1_0}, + {{i32T, i16T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{i8T, fp16T}, SpecificationVersion::V_1_0}, + {{i8T, fp32T}, SpecificationVersion::V_1_0}, + {{i16T, fp16T}, SpecificationVersion::V_1_0}, + {{i16T, fp32T}, SpecificationVersion::V_1_0}, + {{i32T, fp16T}, SpecificationVersion::V_1_0}, + {{i32T, fp32T}, SpecificationVersion::V_1_0}, + {{fp16T, i8T}, SpecificationVersion::V_1_0}, + {{fp16T, i16T}, SpecificationVersion::V_1_0}, + {{fp16T, i32T}, SpecificationVersion::V_1_0}, + {{fp16T, fp32T}, SpecificationVersion::V_1_0}, + {{fp32T, i8T}, SpecificationVersion::V_1_0}, + {{fp32T, i16T}, 
SpecificationVersion::V_1_0}, + {{fp32T, i32T}, SpecificationVersion::V_1_0}, + {{fp32T, fp16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.rescale", {{{Profile::pro_int}, - {{i8T, i8T, i8T, i8T}, - {i8T, i8T, i16T, i16T}, - {i8T, i8T, i32T, i32T}, - {i16T, i16T, i8T, i8T}, - {i16T, i16T, i16T, i16T}, - {i16T, i16T, i32T, i32T}, - {i32T, i32T, i8T, i8T}, - {i32T, i32T, i16T, i16T}, - {i32T, i32T, i32T, i32T}}}}}, + {{{i8T, i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i8T, i8T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i32T, i32T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.const", {{{Profile::pro_int, Profile::pro_fp}, - {{boolT}, {i8T}, {i16T}, {i32T}}, + {{{boolT}, SpecificationVersion::V_1_0}, + {{i8T}, SpecificationVersion::V_1_0}, + {{i16T}, SpecificationVersion::V_1_0}, + {{i32T}, SpecificationVersion::V_1_0}}, anyOf}, - {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}}, + {{Profile::pro_fp}, + {{{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.identity", {{{Profile::pro_int, Profile::pro_fp}, - {{boolT, boolT}, {i8T, i8T}, {i16T, i16T}, {i32T, i32T}}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}, + {{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}, anyOf}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.variable", - {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}}, + {{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.variable_write", - {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}}, + {{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.variable_read", - {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}}, + {{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, }; extensionComplianceMap = { {"tosa.argmax", - {{{Extension::int16}, {{i16T, i32T}}}, - {{Extension::fp8e4m3}, {{fp8e4m3T, i32T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, i32T}}}, - {{Extension::bf16}, {{bf16T, i32T}}}}}, + {{{Extension::int16}, {{{i16T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e4m3}, {{{fp8e4m3T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, {{{fp8e5m2T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.avg_pool2d", - {{{Extension::int16}, {{i16T, i16T, i16T, i32T, i16T}}}, - {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}}, + {{{Extension::int16}, + {{{i16T, i16T, i16T, i32T, 
i16T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T}, + SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T}, + SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, + {{{bf16T, bf16T, bf16T, fp32T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.conv2d", - {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}}, - {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}}, + {{{Extension::int4}, + {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::int16}, + {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e4m3}, - {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}}, + {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e5m2}, - {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}}, + {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::bf16}, - {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}}, + {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.conv3d", - {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}}, - {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}}, + {{{Extension::int4}, + {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::int16}, + {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e4m3}, - {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}}, + {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e5m2}, - {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}}, + {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::bf16}, - {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}}, + {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.depthwise_conv2d", - {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}}, - {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}}, + {{{Extension::int4}, + {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::int16}, + {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e4m3}, - {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}}, + {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e5m2}, - {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}}, + {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::bf16}, - {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}}, - {"tosa.fft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T, fp32T}}}}}, + {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}, + SpecificationVersion::V_1_0}}}}}, + {"tosa.fft2d", + {{{Extension::fft}, + {{{fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.matmul", - {{{Extension::int16}, {{i16T, i16T, i16T, i16T, i48T}}}, + {{{Extension::int16}, + {{{i16T, i16T, i16T, i16T, i48T}, SpecificationVersion::V_1_0}}}, {{Extension::fp8e4m3}, - {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T}, - {fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, 
fp32T}}}, + {{{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp32T}, + SpecificationVersion::V_1_1_DRAFT}}}, {{Extension::fp8e5m2}, - {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T}, - {fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp32T}}}, + {{{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp32T}, + SpecificationVersion::V_1_1_DRAFT}}}, {{Extension::fp8e4m3, Extension::fp8e5m2}, - {{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp16T}, - {fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp32T}, - {fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp16T}, - {fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp32T}}, + {{{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp16T}, + SpecificationVersion::V_1_1_DRAFT}, + {{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp32T}, + SpecificationVersion::V_1_1_DRAFT}, + {{fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp16T}, + SpecificationVersion::V_1_1_DRAFT}, + {{fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp32T}, + SpecificationVersion::V_1_1_DRAFT}}, allOf}, - {{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T, fp32T}}}}}, + {{Extension::bf16}, + {{{bf16T, bf16T, bf16T, bf16T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.matmul_t_block_scaled", + {{{Extension::mxfp}, + {{{fp4e2m1T, fp8ue8m0T, fp4e2m1T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp6e2m3T, fp8ue8m0T, fp6e2m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp6e3m2T, fp8ue8m0T, fp6e3m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp8e4m3T, fp8ue8m0T, fp8e4m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp8e5m2T, fp8ue8m0T, fp8e5m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}}, {"tosa.max_pool2d", - {{{Extension::int16}, {{i16T, i16T}}}, - {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.rfft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T}}}}}, + {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.rfft2d", + {{{Extension::fft}, + {{{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.transpose_conv2d", - {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}}, - {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}}, + {{{Extension::int4}, + {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::int16}, + {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e4m3}, - {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}}, + {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e5m2}, - {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}}, + {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::bf16}, - {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}}, + {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.clamp", - {{{Extension::int16}, {{i16T, i16T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.erf", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - 
{"tosa.sigmoid", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.tanh", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.add", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.maximum", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.minimum", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.mul", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.pow", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.sub", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.table", {{{Extension::int16}, {{i16T, i16T, i32T}}}}}, - {"tosa.abs", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.ceil", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.cos", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.exp", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.floor", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.log", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T}}}}}, - {"tosa.reciprocal", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.rsqrt", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.sin", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.select", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}}, - {"tosa.greater", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}}, - {"tosa.greater_equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}}, - {"tosa.reduce_max", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.reduce_min", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.reduce_product", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.reduce_sum", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.erf", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.sigmoid", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.tanh", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.add", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.maximum", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.minimum", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.mul", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.pow", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.sub", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.table", + {{{Extension::int16}, + {{{i16T, i16T, i32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.abs", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.ceil", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.cos", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.exp", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.floor", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.log", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.negate", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.reciprocal", + {{{Extension::bf16}, {{{bf16T, bf16T}, 
SpecificationVersion::V_1_0}}}}}, + {"tosa.rsqrt", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.sin", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.select", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.equal", + {{{Extension::bf16}, + {{{bf16T, bf16T, boolT}, SpecificationVersion::V_1_0}}}}}, + {"tosa.greater", + {{{Extension::bf16}, + {{{bf16T, bf16T, boolT}, SpecificationVersion::V_1_0}}}}}, + {"tosa.greater_equal", + {{{Extension::bf16}, + {{{bf16T, bf16T, boolT}, SpecificationVersion::V_1_0}}}}}, + {"tosa.reduce_max", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.reduce_min", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.reduce_product", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.reduce_sum", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.concat", - {{{Extension::int16}, {{i16T, i16T}}}, - {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.pad", - {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reshape", - {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reverse", - {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.slice", - {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.tile", - {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, 
bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.transpose", - {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.gather", - {{{Extension::fp8e4m3}, {{fp8e4m3T, i32T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, i32T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, i32T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, i32T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, i32T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, + {{{bf16T, i32T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.scatter", - {{{Extension::fp8e4m3}, {{fp8e4m3T, i32T, fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, i32T, fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, i32T, bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, i32T, fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, i32T, fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, + {{{bf16T, i32T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.resize", - {{{Extension::int16}, {{i16T, i48T}, {i16T, i16T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::int16}, + {{{i16T, i48T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.cast", {{{Extension::bf16}, - {{i8T, bf16T}, - {i16T, bf16T}, - {i32T, bf16T}, - {bf16T, i8T}, - {bf16T, i16T}, - {bf16T, i32T}, - {bf16T, fp32T}, - {fp32T, bf16T}}}, + {{{i8T, bf16T}, SpecificationVersion::V_1_0}, + {{i16T, bf16T}, SpecificationVersion::V_1_0}, + {{i32T, bf16T}, SpecificationVersion::V_1_0}, + {{bf16T, i8T}, SpecificationVersion::V_1_0}, + {{bf16T, i16T}, SpecificationVersion::V_1_0}, + {{bf16T, i32T}, SpecificationVersion::V_1_0}, + {{bf16T, fp32T}, SpecificationVersion::V_1_0}, + {{fp32T, bf16T}, SpecificationVersion::V_1_0}}}, {{Extension::bf16, Extension::fp8e4m3}, - {{bf16T, fp8e4m3T}, {fp8e4m3T, bf16T}}, + {{{bf16T, fp8e4m3T}, SpecificationVersion::V_1_0}, + {{fp8e4m3T, bf16T}, SpecificationVersion::V_1_0}}, allOf}, {{Extension::bf16, Extension::fp8e5m2}, - {{bf16T, fp8e5m2T}, {fp8e5m2T, bf16T}}, + {{{bf16T, fp8e5m2T}, SpecificationVersion::V_1_0}, + {{fp8e5m2T, bf16T}, SpecificationVersion::V_1_0}}, allOf}, {{Extension::fp8e4m3}, - {{fp8e4m3T, fp16T}, - {fp8e4m3T, fp32T}, - {fp16T, fp8e4m3T}, - {fp32T, fp8e4m3T}}}, + {{{fp8e4m3T, fp16T}, SpecificationVersion::V_1_0}, + {{fp8e4m3T, fp32T}, SpecificationVersion::V_1_0}, + {{fp16T, fp8e4m3T}, SpecificationVersion::V_1_0}, + {{fp32T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, {{Extension::fp8e5m2}, - {{fp8e5m2T, fp16T}, - {fp8e5m2T, fp32T}, - {fp16T, fp8e5m2T}, - {fp32T, fp8e5m2T}}}}}, + {{{fp8e5m2T, fp16T}, SpecificationVersion::V_1_0}, + {{fp8e5m2T, fp32T}, SpecificationVersion::V_1_0}, + {{fp16T, fp8e5m2T}, SpecificationVersion::V_1_0}, + {{fp32T, fp8e5m2T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.cast_from_block_scaled", + {{{Extension::bf16, Extension::mxfp}, + {{{fp4e2m1T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp6e2m3T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp6e3m2T, fp8ue8m0T, bf16T}, 
SpecificationVersion::V_1_1_DRAFT}, + {{fp8e4m3T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp8e5m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}}, allOf}, + {{Extension::mxfp}, + {{{fp4e2m1T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp6e2m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp6e3m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp8e4m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp8e5m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}}, + {"tosa.cast_to_block_scaled", + {{{Extension::mxfp}, + {{{bf16T, fp4e2m1T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp32T, fp4e2m1T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp32T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp32T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp32T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp32T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}}, + {{Extension::bf16, Extension::mxfp}, + {{{bf16T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}, + {{bf16T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}, + {{bf16T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}, + {{bf16T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}, allOf}}}, {"tosa.rescale", {{{Extension::int16}, - {{i48T, i48T, i8T, i8T}, - {i48T, i48T, i16T, i16T}, - {i48T, i48T, i32T, i32T}}}}}, + {{{i48T, i48T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i48T, i48T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i48T, i48T, i32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.const", - {{{Extension::int4}, {{i4T}}}, - {{Extension::int16}, {{i48T}}}, - {{Extension::fp8e4m3}, {{fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T}}}}}, + {{{Extension::int4}, {{{i4T}, SpecificationVersion::V_1_0}}}, + {{Extension::int16}, {{{i48T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e4m3}, {{{fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, {{{fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.identity", - {{{Extension::int4}, {{i4T, i4T}}}, - {{Extension::int16}, {{i48T, i48T}}}, - {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.variable", {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}}, + {{{Extension::int4}, {{{i4T, i4T}, SpecificationVersion::V_1_0}}}, + {{Extension::int16}, {{{i48T, i48T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.variable", + {{{Extension::variable}, + {{{i8T}, SpecificationVersion::V_1_0}, + {{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.variable_write", - {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}}, + {{{Extension::variable}, + {{{i8T}, SpecificationVersion::V_1_0}, + {{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.variable_read", - {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}}, + {{{Extension::variable}, + {{{i8T}, SpecificationVersion::V_1_0}, + {{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, }; + // End of 
auto-generated metadata diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index 38cb2936ad8d9..8b6edef08db20 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -221,7 +221,7 @@ class Tosa_I32EnumAttr; def Tosa_EXT_DOUBLEROUND : I32EnumAttrCase<"doubleround", 9>; def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>; def Tosa_EXT_DYNAMIC : I32EnumAttrCase<"dynamic", 11>; +def Tosa_EXT_MXFP : I32EnumAttrCase<"mxfp", 12>; def Tosa_ExtensionAttr : Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [ Tosa_EXT_NONE, Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND, - Tosa_EXT_DYNAMIC + Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP ]> { let extraClassDeclaration = [{ static llvm::SmallVector getAllValues() { @@ -284,7 +285,7 @@ def Tosa_ExtensionAttr Extension::int16, Extension::int4, Extension::bf16, Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft, Extension::variable, Extension::controlflow, Extension::doubleround, - Extension::inexactround, Extension::dynamic + Extension::inexactround, Extension::dynamic, Extension::mxfp }; } }]; @@ -293,12 +294,6 @@ def Tosa_ExtensionAttr def Tosa_ExtensionArrayAttr : TypedArrayAttrBase; -def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>; -def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">; - -def Tosa_LevelAttr - : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>; - // The base class for defining op availability dimensions. class Availability { // The following are fields for controlling the generated C++ OpInterface. @@ -404,23 +399,46 @@ class Extension extensions> : Availability { let instance = "ref"; } +//===----------------------------------------------------------------------===// +// TOSA Levels +//===----------------------------------------------------------------------===// + +def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>; +def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">; + +def Tosa_LevelAttr + : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>; + +//===----------------------------------------------------------------------===// +// TOSA Specification versions +//===----------------------------------------------------------------------===// + +def Tosa_V_1_0 : I32EnumAttrCase<"V_1_0", 0, "1.0">; +def Tosa_V_1_1_DRAFT : I32EnumAttrCase<"V_1_1_DRAFT", 1, "1.1.draft">; + +def Tosa_SpecificationVersion : Tosa_I32EnumAttr< + "SpecificationVersion", "TOSA specification version", "specification_version", + [Tosa_V_1_0, Tosa_V_1_1_DRAFT]>; + //===----------------------------------------------------------------------===// // TOSA target environment. 
//===----------------------------------------------------------------------===// def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> { let summary = "Target environment information."; let parameters = ( ins + "SpecificationVersion": $specification_version, "Level": $level, ArrayRefParameter<"Profile">: $profiles, ArrayRefParameter<"Extension">: $extensions ); - let assemblyFormat = "`<` `level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` " + let assemblyFormat = "`<` `specification_version` `=` $specification_version `,` " + "`level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` " "`extensions` `=` `[` $extensions `]` `>`"; } //===----------------------------------------------------------------------===// -// Iterable attributes. +// Enum attributes. //===----------------------------------------------------------------------===// // Defined in `section 3. Enumerations` of the TOSA specification. @@ -446,6 +464,21 @@ def Tosa_RoundingModeAttr : Tosa_I32EnumAttr<"RoundingMode", "Supported rounding modes", "rounding_mode", [Tosa_ROUNDING_SINGLE_ROUND, Tosa_ROUNDING_INEXACT_ROUND, Tosa_ROUNDING_DOUBLE_ROUND]>; +def Tosa_BLOCK_SIZE_32 : I32EnumAttrCase<"BLOCK_SIZE_32", 1>; + +def Tosa_BlockSizeAttr + : Tosa_I32EnumAttr<"BlockSize", "Block size for the block_scaled formats", "block_size", + [Tosa_BLOCK_SIZE_32]> { + let extraClassDeclaration = [{ + static unsigned int getBlockSizeValue(BlockSize blockSize) { + switch (blockSize) { + case BlockSize::BLOCK_SIZE_32: + return 32; + } + } + }]; +} + //===----------------------------------------------------------------------===// // TOSA Interfaces. diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 48759f2a3c9e8..2f5fe6b347e33 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -348,6 +348,40 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> { "operands attr-dict `:` functional-type(operands, results)"; } +//===----------------------------------------------------------------------===// +// Operator: matmul_t_block_scaled +//===----------------------------------------------------------------------===// +def Tosa_MatmulTBlockScaledOp : Tosa_InferShapedTypeOp<"matmul_t_block_scaled"> { + let summary = "Performs two dimensional matrix multiplications using block scaled tensors."; + + let description = [{ + Performs two dimensional matrix multiplications using block scaled tensors. The block + dimension is always the last dimension of the tensor, so the result is effectively + a matrix multiply of A by the transposed B matrix. If the N dimension of input B is of + size 1, the B matrix will be broadcast.
+ }]; + + let arguments = (ins + Tosa_MXFPDataTensor3D:$a_data, + Tosa_MXFPScaleTensor3D:$a_scale, + Tosa_MXFPDataTensor3D:$b_data, + Tosa_MXFPScaleTensor3D:$b_scale, + Tosa_BlockSizeAttr:$block_size + ); + + let results = (outs + Tosa_Tensor3D:$output_data + ); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; + + list availability = [ + Profile<[Tosa_PRO_FP]>, + Extension<[Tosa_EXT_MXFP]> + ]; +} + //===----------------------------------------------------------------------===// // Operator: max_pool2d //===----------------------------------------------------------------------===// @@ -2438,6 +2472,69 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure, SameOperandsAndResultShape, let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// Operator: cast_from_block_scaled +//===----------------------------------------------------------------------===// +def Tosa_CastFromBlockScaledOp: Tosa_InferShapedTypeOp<"cast_from_block_scaled"> { + let summary = "Apply scales from a scale tensor to the values in a value tensor"; + + let description = [{ + Apply the scales from a scale tensor to the values in a value tensor, casting + the result to the output type. The block dimension must be the last dimension + of the tensor. + }]; + + let arguments = (ins + Tosa_MXFPDataTensorAtLeast1D:$input_data, + Tosa_MXFPScaleTensorAtLeast1D:$input_scale, + Tosa_BlockSizeAttr:$block_size + ); + + let results = (outs + Tosa_TensorAtLeast1D: $output_data + ); + + list availability = [ + Profile<[Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16, Tosa_EXT_MXFP]>, + ]; + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + +//===----------------------------------------------------------------------===// +// Operator: cast_to_block_scaled +//===----------------------------------------------------------------------===// +def Tosa_CastToBlockScaledOp : Tosa_InferShapedTypeOp<"cast_to_block_scaled"> { + let summary = "Calculate scale tensor values per block, output to separate scale and data tensors."; + + let description = [{ + Calculate a scale value per block of input values and use that to calculate + scaled data values from an input tensor. The output tensors are cast to the + specified scale and value types. The block dimension will be the last dimension + of the tensor. + }]; + + let arguments = (ins + Tosa_TensorAtLeast1D:$input_data, + Tosa_BlockSizeAttr:$block_size + ); + + let results = (outs + Tosa_MXFPDataTensorAtLeast1D:$output_data, + Tosa_MXFPScaleTensorAtLeast1D:$output_scale + ); + + list availability = [ + Profile<[Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16, Tosa_EXT_MXFP]> + ]; + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + //===----------------------------------------------------------------------===// // Operator: rescale //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h index 8f5c72bc5f7a9..4a899e3c787e6 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h @@ -36,12 +36,15 @@ enum CheckCondition { allOf }; +using VersionedTypeInfo = + std::pair, SpecificationVersion>; + template struct OpComplianceInfo { // Certain operations require multiple modes enabled. // e.g. cast bf16 to fp8e4m3 requires EXT-BF16 and EXT-FP8E4M3. 
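+  // Each entry of operandTypeInfoSet pairs an operand/result type combination
+  // with the minimum specification version that defines it (VersionedTypeInfo).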
SmallVector mode; - SmallVector> operandTypeInfoSet; + SmallVector operandTypeInfoSet; CheckCondition condition = CheckCondition::anyOf; }; @@ -76,7 +79,7 @@ class ProfileInfoDepot { LogicalResult populatationDispatch(Operation *op); - LogicalResult populateProfileInfo(ValueRange operands, Value output); + LogicalResult populateProfileInfo(ValueRange operands, ValueRange output); // Base template @@ -130,9 +133,8 @@ class TosaProfileCompliance { // Find the required profiles or extensions from the compliance info according // to the operand type combination. template - SmallVector findMatchedProfile(Operation *op, - SmallVector> compInfo, - CheckCondition &condition); + OpComplianceInfo + findMatchedEntry(Operation *op, SmallVector> compInfo); SmallVector getCooperativeProfiles(Extension ext) { switch (ext) { @@ -145,6 +147,7 @@ class TosaProfileCompliance { case Extension::fp8e4m3: case Extension::fp8e5m2: case Extension::fft: + case Extension::mxfp: return {Profile::pro_fp}; case Extension::variable: case Extension::controlflow: @@ -168,8 +171,7 @@ class TosaProfileCompliance { private: template - FailureOr> getOperatorDefinition(Operation *op, - CheckCondition &condition); + FailureOr> getOperatorDefinition(Operation *op); OperationProfileComplianceMap profileComplianceMap; OperationExtensionComplianceMap extensionComplianceMap; diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index 93ab120339d55..93843e86fd378 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -84,6 +84,10 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>, def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat], "number">; +def Tosa_MXFPNumber : AnyTypeOf<[F8E4M3FN, F8E5M2, F4E2M1FN, F6E2M3FN, F6E3M2FN], + "micro-scaling format number">; +def Tosa_MXFPScaleNumber : AnyTypeOf<[F8E8M0FNU], "micro-scaling format scale number">; + //===----------------------------------------------------------------------===// // TOSA Tensor Conformance //===----------------------------------------------------------------------===// @@ -187,6 +191,25 @@ def Tosa_Int32Tensor2D : AnyTypeOf<[ def Tosa_TensorAtLeast1D : AnyTypeOf<[ Tosa_UnrankedTensor, TosaRankedTensorOf<[Tosa_AnyNumber], [AtLeastRankOne]>], "tosa-conformant tensor of at least rank 1", "::mlir::TensorType">; +def Tosa_MXFPDataTensor3D : AnyTypeOf<[ + TosaUnrankedTensorOf<[Tosa_MXFPNumber]>, + TosaTensorRankOf<[Tosa_MXFPNumber], [3]> +]>; +def Tosa_MXFPScaleTensor3D : AnyTypeOf<[ + TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>, + TosaTensorRankOf<[Tosa_MXFPScaleNumber], [3]> +]>; +def Tosa_MXFPDataTensorAtLeast1D : AnyTypeOf<[ + TosaUnrankedTensorOf<[Tosa_MXFPNumber]>, + TosaRankedTensorOf<[Tosa_MXFPNumber], [AtLeastRankOne]>], + "tosa-conformant tensor of at least rank 1", "::mlir::TensorType" +>; +def Tosa_MXFPScaleTensorAtLeast1D : AnyTypeOf<[ + TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>, + TosaRankedTensorOf<[Tosa_MXFPScaleNumber], [AtLeastRankOne]>], + "tosa-conformant tensor of at least rank 1", "::mlir::TensorType" +>; + //===----------------------------------------------------------------------===// // Generic scalar, vector, or tensor of a particular type. 
//===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td index 6ae19d81e0820..14b00b04ccc18 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -137,6 +137,13 @@ def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> { ]; let options = [ + Option<"specificationVersion", "specification_version", "mlir::tosa::SpecificationVersion", + /*default=*/"mlir::tosa::SpecificationVersion::V_1_0", + "The specification version that TOSA operators should conform to.", + [{::llvm::cl::values( + clEnumValN(mlir::tosa::SpecificationVersion::V_1_0, "1.0", "TOSA Specification version 1.0"), + clEnumValN(mlir::tosa::SpecificationVersion::V_1_1_DRAFT, "1.1.draft", "TOSA Specification version 1.1.draft") + )}]>, Option<"level", "level", "mlir::tosa::Level", /*default=*/"mlir::tosa::Level::eightK", "The TOSA level that operators should conform to. A TOSA level defines " diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp index 5aad67173cc61..32eb286531d28 100644 --- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp @@ -7,10 +7,101 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tosa/IR/TargetEnv.h" +#include "llvm/Support/FormatVariadic.h" namespace mlir { namespace tosa { +llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) { + return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor()); +} + +TosaSpecificationVersion getMinVersion(const Profile &profile) { + switch (profile) { + case Profile::pro_int: + case Profile::pro_fp: + return TosaSpecificationVersion(1, 0); + case Profile::none: + return TosaSpecificationVersion(0, 0); + } + llvm_unreachable("Unknown TOSA profile"); +} + +TosaSpecificationVersion getMinVersion(const Extension &extension) { + switch (extension) { + case Extension::int16: + case Extension::int4: + case Extension::bf16: + case Extension::fp8e4m3: + case Extension::fp8e5m2: + case Extension::fft: + case Extension::variable: + case Extension::controlflow: + case Extension::doubleround: + case Extension::inexactround: + case Extension::dynamic: + return TosaSpecificationVersion(1, 0); + case Extension::mxfp: + return TosaSpecificationVersion(1, 1); + case Extension::none: + return TosaSpecificationVersion(0, 0); + } + llvm_unreachable("Unknown TOSA extension"); +} + +TosaSpecificationVersion getMinVersion(const Level &level) { + switch (level) { + case Level::eightK: + case Level::none: + return TosaSpecificationVersion(1, 0); + } + llvm_unreachable("Unknown TOSA level"); +} + +FailureOr +TargetEnv::createTargetEnvFromAttr(TargetEnvAttr targetAttr, + Location targetEnvAttrLoc) { + if (failed(verifyTargetInformation(targetAttr, targetEnvAttrLoc))) + return failure(); + + return TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(), + targetAttr.getProfiles(), targetAttr.getExtensions()); +} + +LogicalResult TargetEnv::verifyTargetInformation(TargetEnvAttr targetAttr, + Location targetAttrLoc) { + TosaSpecificationVersion targetVersion(targetAttr.getSpecificationVersion()); + + const auto isCompatibleWithTargetVersion = + [&](const auto &targetEnum, Location targetAttrLoc, + StringRef enumName) -> LogicalResult { + const TosaSpecificationVersion minRequiredVersion = + 
getMinVersion(targetEnum); + if (!targetVersion.isBackwardsCompatibleWith(minRequiredVersion)) + return emitError(targetAttrLoc, enumName) + << " '" << stringifyEnum(targetEnum) + << "' is not compatible with the target version " + << stringifyVersion(targetVersion) + << ", minimum required version is " + << stringifyVersion(minRequiredVersion); + return success(); + }; + + for (const auto &profile : targetAttr.getProfiles()) + if (failed( + isCompatibleWithTargetVersion(profile, targetAttrLoc, "profile"))) + return failure(); + for (const auto &extension : targetAttr.getExtensions()) + if (failed(isCompatibleWithTargetVersion(extension, targetAttrLoc, + "extension"))) + return failure(); + if (failed(isCompatibleWithTargetVersion(targetAttr.getLevel(), targetAttrLoc, + "level"))) + return failure(); + + return success(); +} + TargetEnvAttr lookupTargetEnv(Operation *op) { while (op) { op = SymbolTable::getNearestSymbolTable(op); @@ -27,7 +118,7 @@ TargetEnvAttr lookupTargetEnv(Operation *op) { } TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) { - return TargetEnvAttr::get(context, Level::eightK, + return TargetEnvAttr::get(context, SpecificationVersion::V_1_0, Level::eightK, {Profile::pro_int, Profile::pro_fp}, {}); } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index c51b5e9cbfc78..2e97f6f2989e3 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -321,6 +321,19 @@ ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser, } } + // special handling: block_size accepts a *bare* BlockSizeMode enum + if constexpr (std::is_same_v) { + if (name == "block_size" && succeeded(parser.parseOptionalKeyword(&kw))) { + auto sym = symbolizeBlockSize(kw); + if (!sym) + return parser.emitError(parser.getCurrentLocation()) + << "invalid block_size value: " << kw; + auto attr = BlockSizeAttr::get(parser.getContext(), sym.value()); + outAttrs.push_back(NamedAttribute(name, attr)); + return success(); + } + } + // Default path: parse any normal attribute literal, including fully qualified // enum keyword Attribute attr; @@ -357,7 +370,7 @@ ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) { result.operands))) return failure(); - result.addTypes(fnTy.getResult(0)); + result.addTypes(fnTy.getResults()); result.addAttributes(attrs); return success(); @@ -373,6 +386,8 @@ void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) { } else if (auto nanPropagationModeAttr = dyn_cast(attr)) { parser << nanPropagationModeAttr.getValue(); + } else if (auto blockSizeAttr = dyn_cast(attr)) { + parser << blockSizeAttr.getValue(); } else { parser.printAttribute(attr); } @@ -508,6 +523,33 @@ void ReduceMinOp::print(OpAsmPrinter &parser) { printWithNanPropagationHandling(parser, *this); } +ParseResult MatmulTBlockScaledOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseWithEnumHandling(parser, result); +} + +void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) { + printWithEnumHandling(parser, *this); +} + +ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseWithEnumHandling(parser, result); +} + +void CastFromBlockScaledOp::print(OpAsmPrinter &parser) { + printWithEnumHandling(parser, *this); +} + +ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseWithEnumHandling(parser, result); +} + +void CastToBlockScaledOp::print(OpAsmPrinter &parser) { + 
printWithEnumHandling(parser, *this); +} + //===----------------------------------------------------------------------===// // Tosa utilities. //===----------------------------------------------------------------------===// @@ -933,32 +975,35 @@ static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) { // verify that inType and outType have same element types template -static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) { - auto inputType = llvm::dyn_cast(inType); - auto outputType = llvm::dyn_cast(outType); - if (!inputType) { - op.emitOpError("expect shaped tensor for input, got ") << inType; +static LogicalResult verifySameElementTypes(T op, Type aType, Type bType, + StringRef aName = "input", + StringRef bName = "output") { + auto aTType = llvm::dyn_cast(aType); + auto bTType = llvm::dyn_cast(bType); + if (!aTType) { + op.emitOpError("expect shaped tensor for") << aName << ", got " << aType; return failure(); } - if (!outputType) { - op.emitOpError("expect shaped tensor for output, got ") << outType; + if (!bTType) { + op.emitOpError("expect shaped tensor for") << bName << ", got" << bType; return failure(); } - auto inputElementType = inputType.getElementType(); - auto outputElementType = outputType.getElementType(); - auto inputQuantType = - llvm::dyn_cast(inputElementType); - auto outputQuantType = - llvm::dyn_cast(outputElementType); - if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) && - (outputElementType.isIntOrIndexOrFloat() || outputQuantType) && - inputElementType != outputElementType) { + auto aElementType = aTType.getElementType(); + auto bElementType = bTType.getElementType(); + auto aQuantType = + llvm::dyn_cast(aElementType); + auto bQuantType = + llvm::dyn_cast(bElementType); + if ((aElementType.isIntOrIndexOrFloat() || aQuantType) && + (bElementType.isIntOrIndexOrFloat() || bQuantType) && + aElementType != bElementType) { // only check if both element types are int/index/float/UniformQuantized // eg, not sure how to check quant::QuantizedType // this happens in test_conv2d_q_grouped_convolution in // tfl-to-tosa-pipeline.mlir - op.emitOpError("expect input and output to have same element type, got ") - << inputElementType << " and " << outputElementType; + op.emitOpError("expect ") + << aName << " and " << bName << " to have same element type, got " + << aElementType << " and " << bElementType; return failure(); } return success(); @@ -1846,6 +1891,162 @@ LogicalResult MatMulOp::verify() { return success(); } +LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents( + MLIRContext *context, ::std::optional location, + MatmulTBlockScaledOp::Adaptor adaptor, + SmallVectorImpl &inferredReturnShapes) { + SmallVector outShape(3, ShapedType::kDynamic); + + const auto aDataShape = cast(adaptor.getAData().getType()); + if (aDataShape.hasRank()) { + outShape[0] = aDataShape.getDimSize(0); + outShape[1] = aDataShape.getDimSize(1); + } + + const auto aScaleShape = cast(adaptor.getAScale().getType()); + if (aScaleShape.hasRank()) { + outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0) + : outShape[0]; + outShape[1] = ShapedType::isDynamic(outShape[1]) ? 
aScaleShape.getDimSize(1) + : outShape[1]; + } + + // If B batch size is 1, it is broadcast across A's batch size + const auto bDataShape = cast(adaptor.getBData().getType()); + if (bDataShape.hasRank()) { + const int64_t bDataBatchSize = bDataShape.getDimSize(0); + if (bDataBatchSize != 1) + outShape[0] = + ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0]; + outShape[2] = bDataShape.getDimSize(1); + } + + const auto bScaleShape = cast(adaptor.getBScale().getType()); + if (bScaleShape.hasRank()) { + const int64_t bScaleBatchSize = bScaleShape.getDimSize(0); + if (bScaleBatchSize != 1) + outShape[0] = + ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0]; + outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1) + : outShape[2]; + } + + inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); + return success(); +} + +LogicalResult MatmulTBlockScaledOp::verify() { + // Verify same input data types + const Type aDataType = getAData().getType(); + const Type bDataType = getBData().getType(); + if (failed(verifySameElementTypes(*this, aDataType, bDataType, "A_data", + "B_data"))) + return failure(); + + auto tryUpdateDimOrFailure = [&](int64_t &currDim, const int64_t newDim, + const StringRef operandName, + const StringRef dimName) -> LogicalResult { + if (ShapedType::isDynamic(currDim)) { + currDim = newDim; + return success(); + } else if (ShapedType::isStatic(newDim) && currDim != newDim) { + return emitOpError("expected ") + << dimName << " of " << operandName << " to match size " << currDim + << ", got " << newDim; + } + return success(); + }; + + // Verify input shape compatibility + int64_t N = ShapedType::kDynamic; + int64_t D = ShapedType::kDynamic; + int64_t H = ShapedType::kDynamic; + int64_t W = ShapedType::kDynamic; + int64_t C = ShapedType::kDynamic; + int64_t multiplesOfC = ShapedType::kDynamic; + + const ShapeAdaptor aDataShape = ShapeAdaptor(aDataType); + if (aDataShape.hasRank()) { + N = aDataShape.getDimSize(0); + H = aDataShape.getDimSize(1); + C = aDataShape.getDimSize(2); + } + + const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType()); + if (aScaleShape.hasRank()) { + if (failed(tryUpdateDimOrFailure(N, aScaleShape.getDimSize(0), "a_scale", + "batch")) || + failed(tryUpdateDimOrFailure(H, aScaleShape.getDimSize(1), "a_scale", + "height"))) + return failure(); + multiplesOfC = aScaleShape.getDimSize(2); + } + + const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType); + if (bDataShape.hasRank()) { + if (failed(tryUpdateDimOrFailure(D, bDataShape.getDimSize(0), "b_data", + "batch")) || + failed(tryUpdateDimOrFailure(C, bDataShape.getDimSize(2), "b_data", + "channels"))) + return failure(); + W = bDataShape.getDimSize(1); + } + + const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType()); + if (bScaleShape.hasRank()) { + if (failed(tryUpdateDimOrFailure(D, bScaleShape.getDimSize(0), "b_scale", + "batch")) || + failed(tryUpdateDimOrFailure(W, bScaleShape.getDimSize(1), "b_scale", + "width")) || + failed(tryUpdateDimOrFailure(multiplesOfC, bScaleShape.getDimSize(2), + "b_scale", "C/block_size"))) + return failure(); + } + + // Verify batch size is broadcast compatible + if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1) + return emitOpError("expect B matrix batch size to be broadcast compatible " + "with A, got D=") + << D << " vs N=" << N; + + // Verify C is a multiple of block size + const unsigned int blockSize = + BlockSizeAttr::getBlockSizeValue(getBlockSize()); + if 
(ShapedType::isStatic(C) && C % blockSize != 0) + return emitOpError("expect C to be a multiple of block size, got C=") + << C << ", block_size=" << blockSize; + + // Verify multiplesOfC is C / block size + if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) && + multiplesOfC != C / blockSize) + return emitOpError( + "expect scale operands dimension 2 to equal C/block_size (") + << C << "/" << blockSize << ")" + << ", got " << multiplesOfC; + + // Verify output shape + N = ShapedType::isDynamic(N) ? D : N; + const SmallVector expectedOutputShape = {N, H, W}; + const auto outputType = cast(getResult().getType()); + if (outputType.hasRank() && + failed( + verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) { + InFlightDiagnostic opError = emitOpError("expected output shape "); + auto stringifyDim = [&](int64_t d) { + if (ShapedType::isDynamic(d)) + opError << "?"; + else + opError << d; + }; + llvm::interleaveComma(outputType.getShape(), opError, stringifyDim); + opError << " to be compatible with expected output shape "; + llvm::interleaveComma(expectedOutputShape, opError, stringifyDim); + return opError; + } + + return success(); +} + LogicalResult tosa::PadOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, PadOp::Adaptor adaptor, @@ -3761,6 +3962,145 @@ LogicalResult RescaleOp::inferReturnTypeComponents( return success(); } +LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents( + MLIRContext *context, ::std::optional location, + CastFromBlockScaledOp::Adaptor adaptor, + SmallVectorImpl &inferredReturnShapes) { + const ShapeAdaptor inputShape(adaptor.getInputData().getType()); + inferredReturnShapes.push_back(ShapedTypeComponents(inputShape)); + return success(); +} + +LogicalResult CastFromBlockScaledOp::verify() { + const Type inputDataType = getInputData().getType(); + const Type outputDataType = getResult().getType(); + if (failed(verifyCompatibleShape(inputDataType, outputDataType))) + return emitOpError() << "require compatible shapes for input_data (" + << inputDataType << ") and " + << "output_data (" << outputDataType << ")"; + + const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType); + + if (inputDataShape.hasRank()) { + const unsigned int blockSize = + BlockSizeAttr::getBlockSizeValue(getBlockSize()); + const int64_t inputDataLastDim = + inputDataShape.getDimSize(inputDataShape.getRank() - 1); + if (inputDataLastDim % blockSize != 0) + return emitOpError() << "expect last dimension of input_data (" + << inputDataLastDim + << ") to be divisible by block_size (" << blockSize + << ")"; + + const Type inputScaleType = getInputScale().getType(); + const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType); + + if (inputScaleShape.hasRank()) { + SmallVector inputDataDims, inputScaleDims; + inputDataShape.getDims(inputDataDims); + inputScaleShape.getDims(inputScaleDims); + + if (inputDataDims.size() != inputScaleDims.size() || + failed(verifyCompatibleShape( + ArrayRef(inputDataDims).drop_back(1), + ArrayRef(inputScaleDims).drop_back(1)))) + return emitOpError() << "require compatible shapes for input_data (" + << inputDataType << ") and " + << "input_scale (" << inputScaleType + << ") except for the last dimension"; + + const SmallVector dimsToCheck{inputDataLastDim / blockSize, + inputScaleDims.back()}; + if (ShapedType::isStatic(inputDataLastDim) && + failed(verifyCompatibleDims(dimsToCheck))) + return emitOpError() + << "expect last dimension of input_scale (" + << inputScaleDims.back() + << ") to 
be equal to last dimension of input_data / block_size (" + << inputDataDims.back() / blockSize << ")"; + } + } + + return success(); +} + +LogicalResult CastToBlockScaledOp::inferReturnTypeComponents( + MLIRContext *context, ::std::optional location, + CastToBlockScaledOp::Adaptor adaptor, + SmallVectorImpl &inferredReturnShapes) { + const ShapeAdaptor inputShape(adaptor.getInputData().getType()); + inferredReturnShapes.push_back(ShapedTypeComponents(inputShape)); + if (!inputShape.hasRank()) + return success(); + + // Calculate output_scale shape if ranked input provided + SmallVector outputScaleShape; + inputShape.getDims(outputScaleShape); + const int64_t lastDimLoc = inputShape.getRank() - 1; + const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc); + if (ShapedType::isStatic(lastDimSize)) { + const unsigned int blockSize = + BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize()); + outputScaleShape[lastDimLoc] = lastDimSize / blockSize; + } + inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape)); + return success(); +} + +LogicalResult CastToBlockScaledOp::verify() { + const Type inputDataType = getInputData().getType(); + const Type outputDataType = getResult(0).getType(); + if (failed(verifyCompatibleShape(inputDataType, outputDataType))) + return emitOpError() << "require compatible shapes for input_data (" + << inputDataType << ") and " + << "output_data (" << outputDataType << ")"; + + const unsigned int blockSize = + BlockSizeAttr::getBlockSizeValue(getBlockSize()); + const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType); + if (inputDataShape.hasRank()) { + const int64_t inputDataLastDim = + inputDataShape.getDimSize(inputDataShape.getRank() - 1); + if (ShapedType::isStatic(inputDataLastDim) && + inputDataLastDim % blockSize != 0) + return emitOpError() << "expect last dimension of input_data (" + << inputDataLastDim + << ") to be divisible by block_size (" << blockSize + << ")"; + } + + const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType); + const Type outputScaleType = getResult(1).getType(); + const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType); + if (outputDataShape.hasRank() && outputScaleShape.hasRank()) { + SmallVector outputDataDims, outputScaleDims; + outputDataShape.getDims(outputDataDims); + outputScaleShape.getDims(outputScaleDims); + + if (outputDataDims.size() != outputScaleDims.size() || + failed(verifyCompatibleShape( + ArrayRef(outputDataDims).drop_back(1), + ArrayRef(outputScaleDims).drop_back(1)))) + return emitOpError() << "require compatible shapes for output_data (" + << outputDataType << ") and " + << "output_scale (" << outputScaleType + << ") except for the last dimension"; + + const int64_t outputDataLastDim = outputDataDims.back(); + const SmallVector dimsToCheck{outputDataLastDim / blockSize, + outputScaleDims.back()}; + if (ShapedType::isStatic(outputDataLastDim) && + failed(verifyCompatibleDims(dimsToCheck))) + return emitOpError() + << "expect last dimension of output_scale (" + << outputScaleDims.back() + << ") to be equal to last dimension of output_data / block_size (" + << outputDataDims.back() / blockSize << ")"; + } + + return success(); +} + LogicalResult IfOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, IfOp::Adaptor adaptor, diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp index bcb880a808b36..a0661e4ee0bd2 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp 
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp @@ -61,8 +61,8 @@ class TosaAttachTarget ModuleOp mod = getOperation(); MLIRContext *ctx = &getContext(); - const auto targetEnvAttr = - TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions); + const auto targetEnvAttr = TargetEnvAttr::get( + ctx, specificationVersion, level, selectedProfiles, selectedExtensions); mod->setAttr(TargetEnvAttr::name, targetEnvAttr); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp index 20f9333e7c951..92d5bac9c2653 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp @@ -25,6 +25,12 @@ TosaProfileCompliance::TosaProfileCompliance() { const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8}; const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8}; + // micro-scaling formats + const TypeInfo fp6e2m3T = {mlir::Float6E2M3FNType::getTypeID(), 6}; + const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6}; + const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4}; + const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8}; + // The profile-based compliance content below is auto-generated by a script // in https://git.mlplatform.org/tosa/specification.git #include "mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc" @@ -44,10 +50,11 @@ TosaProfileCompliance::getProfileComplianceMap() { // Base populating function LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands, - Value output) { - for (auto operand : operands) + ValueRange outputs) { + for (const auto &operand : operands) addValue(operand); - addValue(output); + for (const auto &output : outputs) + addValue(output); return success(); } @@ -169,23 +176,6 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) { return success(); } -template <> -LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) { - addValue(op.getInputReal()); - addValue(op.getInputImag()); - addValue(op.getOutputReal()); - addValue(op.getOutputImag()); - return success(); -} - -template <> -LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) { - addValue(op.getInputReal()); - addValue(op.getOutputReal()); - addValue(op.getOutputImag()); - return success(); -} - template <> LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) { addValue(op.getOnTrue()); @@ -239,7 +229,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) { // This helper function populates the info for all operands. 
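+// Every result is added as well: ops such as tosa.cast_to_block_scaled and
+// tosa.rfft2d produce more than one result that is subject to compliance
+// checking.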
#define POPULATE_PROFILE_INFO_COMMON(tosaOp) \ if (isa(op)) { \ - return populateProfileInfo(op->getOperands(), op->getResult(0)); \ + return populateProfileInfo(op->getOperands(), op->getResults()); \ } // Skip irrelevant operands when they are independent and not tied to any @@ -250,8 +240,6 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) { POPULATE_PROFILE_INFO_CUSTOM(Conv3D) POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D) POPULATE_PROFILE_INFO_CUSTOM(Mul) - POPULATE_PROFILE_INFO_CUSTOM(FFT2d) - POPULATE_PROFILE_INFO_CUSTOM(RFFT2d) POPULATE_PROFILE_INFO_CUSTOM(Concat) POPULATE_PROFILE_INFO_CUSTOM(Pad) POPULATE_PROFILE_INFO_CUSTOM(Reshape) @@ -269,7 +257,12 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) { // For the most of tosa operators, all operands are profile/extension related // and hence are all considered in this profile-based compilance check. + POPULATE_PROFILE_INFO_COMMON(MatmulTBlockScaled) + POPULATE_PROFILE_INFO_COMMON(FFT2d) + POPULATE_PROFILE_INFO_COMMON(RFFT2d) POPULATE_PROFILE_INFO_COMMON(Cast) + POPULATE_PROFILE_INFO_COMMON(CastFromBlockScaled) + POPULATE_PROFILE_INFO_COMMON(CastToBlockScaled) POPULATE_PROFILE_INFO_COMMON(Const) POPULATE_PROFILE_INFO_COMMON(ArgMax) POPULATE_PROFILE_INFO_COMMON(Sub) @@ -335,16 +328,15 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) { //===----------------------------------------------------------------------===// template -FailureOr> -TosaProfileCompliance::getOperatorDefinition(Operation *op, - CheckCondition &condition) { +FailureOr> +TosaProfileCompliance::getOperatorDefinition(Operation *op) { const std::string opName = op->getName().getStringRef().str(); const auto complianceMap = getProfileComplianceMap(); const auto it = complianceMap.find(opName); if (it == complianceMap.end()) return {}; - return findMatchedProfile(op, it->second, condition); + return findMatchedEntry(op, it->second); } template @@ -356,22 +348,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension( if (specRequiredModeSet.size() == 0) return success(); - CheckCondition condition = CheckCondition::invalid; - const auto maybeOpRequiredMode = getOperatorDefinition(op, condition); - if (failed(maybeOpRequiredMode)) { + const auto maybeOpDefinition = getOperatorDefinition(op); + if (failed(maybeOpDefinition)) { // Operators such as control-flow and shape ops do not have an operand type // restriction. When the profile compliance information of operation is not // found, confirm if the target have enabled the profile required from the // specification. - int mode_count = 0; + int modeCount = 0; for (const auto &cands : specRequiredModeSet) { if (targetEnv.allowsAnyOf(cands)) return success(); - mode_count += cands.size(); + modeCount += cands.size(); } op->emitOpError() << "illegal: requires" - << (mode_count > 1 ? " any of " : " ") << "[" + << (modeCount > 1 ? " any of " : " ") << "[" << llvm::join(stringifyProfile(specRequiredModeSet), ", ") << "] but not enabled in target\n"; @@ -381,7 +372,10 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension( // Find the required profiles or extensions according to the operand type // combination. - const auto opRequiredMode = maybeOpRequiredMode.value(); + const auto opDefinition = maybeOpDefinition.value(); + const SmallVector opRequiredMode = opDefinition.mode; + const CheckCondition condition = opDefinition.condition; + if (opRequiredMode.size() == 0) { // No matched restriction found. 
return success(); @@ -437,6 +431,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension( } } + // Ensure the matched op compliance version does not exceed the target + // specification version. + const VersionedTypeInfo versionedTypeInfo = + opDefinition.operandTypeInfoSet[0]; + const TosaSpecificationVersion complianceVersion{versionedTypeInfo.second}; + const TosaSpecificationVersion targetVersion{targetEnv.getSpecVersion()}; + if (!targetVersion.isBackwardsCompatibleWith(complianceVersion)) { + op->emitOpError() << "illegal: the target specification version (" + << stringifyVersion(targetVersion) + << ") is not backwards compatible with the op compliance " + "specification version (" + << stringifyVersion(complianceVersion) << ")\n"; + return failure(); + } + return success(); } @@ -461,14 +470,14 @@ TosaProfileCompliance::checkExtension(Operation *op, } LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) { - CheckCondition condition = CheckCondition::invalid; - const auto maybeProfDef = getOperatorDefinition(op, condition); - const auto maybeExtDef = getOperatorDefinition(op, condition); + const auto maybeProfDef = getOperatorDefinition(op); + const auto maybeExtDef = getOperatorDefinition(op); if (failed(maybeProfDef) && failed(maybeExtDef)) return success(); - const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) || - (succeeded(maybeExtDef) && !maybeExtDef->empty()); + const bool hasEntry = + (succeeded(maybeProfDef) && !maybeProfDef->mode.empty()) || + (succeeded(maybeExtDef) && !maybeExtDef->mode.empty()); if (!hasEntry) { std::string message; llvm::raw_string_ostream os(message); @@ -488,7 +497,9 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) { SmallVector bestTypeInfo; const auto searchBestMatch = [&](auto map) { for (const auto &complianceInfos : map[opName]) { - for (const auto &typeInfos : complianceInfos.operandTypeInfoSet) { + for (const auto &versionedTypeInfos : + complianceInfos.operandTypeInfoSet) { + const SmallVector typeInfos = versionedTypeInfos.first; const int matches = llvm::count_if( llvm::zip_equal(current, typeInfos), [&](const auto zipType) { return isSameTypeInfo(std::get<0>(zipType), @@ -520,9 +531,8 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) { // Find the profiles or extensions requirement according to the signature of // type of the operand list. template -SmallVector TosaProfileCompliance::findMatchedProfile( - Operation *op, SmallVector> compInfo, - CheckCondition &condition) { +OpComplianceInfo TosaProfileCompliance::findMatchedEntry( + Operation *op, SmallVector> compInfo) { assert(compInfo.size() != 0 && "profile-based compliance information is empty"); @@ -533,27 +543,30 @@ SmallVector TosaProfileCompliance::findMatchedProfile( return {}; for (size_t i = 0; i < compInfo.size(); i++) { - SmallVector> sets = compInfo[i].operandTypeInfoSet; - for (SmallVector expected : sets) { + SmallVector sets = compInfo[i].operandTypeInfoSet; + for (const auto &set : sets) { + SmallVector expected = set.first; assert(present.size() == expected.size() && "the entries for profile-based compliance do not match between " "the generated metadata and the type definition retrieved from " " the operation"); - bool is_found = true; + bool isFound = true; // Compare the type signature between the given operation and the // compliance metadata. for (size_t j = 0; j < expected.size(); j++) { if (!isSameTypeInfo(present[j], expected[j])) { // Verify the next mode set from the list. 
- is_found = false; + isFound = false; break; } } - if (is_found == true) { - condition = compInfo[i].condition; - return compInfo[i].mode; + if (isFound == true) { + SmallVector typeInfoSet{set}; + OpComplianceInfo info{compInfo[i].mode, typeInfoSet, + compInfo[i].condition}; + return info; } } } @@ -603,6 +616,14 @@ TosaProfileCompliance::stringifyTypeInfo(const TypeInfo &typeInfo) { return {"fp8e4m3"}; } else if (typeInfo.typeID == mlir::Float8E5M2Type::getTypeID()) { return {"fp8e5m2"}; + } else if (typeInfo.typeID == mlir::Float6E2M3FNType::getTypeID()) { + return {"fp6e2m3"}; + } else if (typeInfo.typeID == mlir::Float6E3M2FNType::getTypeID()) { + return {"fp6e3m2"}; + } else if (typeInfo.typeID == mlir::Float4E2M1FNType::getTypeID()) { + return {"fp4e2m1"}; + } else if (typeInfo.typeID == mlir::Float8E8M0FNUType::getTypeID()) { + return {"fp8e8m0"}; } llvm_unreachable("unknown type"); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 82f2f7eb17af4..a142926bf87e2 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -635,6 +635,8 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) { CHECK_RANKS_AND_SIZES(Transpose); // Type Conversion CHECK_RANKS_AND_SIZES(Cast); + CHECK_RANKS_AND_SIZES(CastFromBlockScaled); + CHECK_RANKS_AND_SIZES(CastToBlockScaled); CHECK_RANKS_AND_SIZES(Rescale); // Control Flow Operators CHECK_RANKS_AND_SIZES(If); @@ -657,6 +659,7 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) { CHECK_SIZES(TransposeConv2D); CHECK_SIZES(FFT2d); CHECK_SIZES(MatMul); + CHECK_SIZES(MatmulTBlockScaled); CHECK_SIZES(MaxPool2d); CHECK_SIZES(RFFT2d); // Scatter/Gather Operators @@ -1192,9 +1195,9 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) { bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) { if (isa(type)) { return isa(type); - } - if (auto intTy = dyn_cast(type)) { + Float8E5M2Type, Float4E2M1FNType, Float6E2M3FNType, + Float6E3M2FNType, Float8E8M0FNUType>(type); + } else if (auto intTy = dyn_cast(type)) { if (intTy.isSignless()) { switch (intTy.getWidth()) { case 1: @@ -1220,13 +1223,19 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) { } void TosaValidation::runOnOperation() { + ModuleOp modOp = getOperation(); + const TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(modOp); + const auto maybeTargetEnv = + tosa::TargetEnv::createTargetEnvFromAttr(targetEnvAttr, modOp.getLoc()); + if (failed(maybeTargetEnv)) + return signalPassFailure(); + targetEnv = *maybeTargetEnv; + TosaDialect *tosaDialect = getContext().getLoadedDialect(); if (!tosaDialect) return; - targetEnv = tosa::TargetEnv(lookupTargetEnvOrDefault(getOperation())); - - getOperation().walk([&](Operation *op) { + modOp.walk([&](Operation *op) { if (op->getDialect() != tosaDialect) return; diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir index 600c4c717922a..d92d433a7d185 100644 --- a/mlir/test/Dialect/Tosa/availability.mlir +++ b/mlir/test/Dialect/Tosa/availability.mlir @@ -696,3 +696,21 @@ func.func @test_const_shape() -> !tosa.shape<4> { return %cst : !tosa.shape<4> } +// ----- +// CHECK-LABEL: test_cast_from_block_scaled +func.func @test_cast_from_block_scaled(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> { + // CHECK: profiles: [ [pro_fp] ] + // CHECK: extensions: [ 
[bf16, mxfp] ] + %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> + return %0 : tensor<4x32xf32> +} + +// ----- +// CHECK-LABEL: test_cast_to_block_scaled +func.func @test_cast_to_block_scaled(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) { + // CHECK: profiles: [ [pro_fp] ] + // CHECK: extensions: [ [bf16, mxfp] ] + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = BLOCK_SIZE_32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) + return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU> +} + diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir index e5c9402caaddc..fff31c294a3f7 100644 --- a/mlir/test/Dialect/Tosa/invalid_extension.mlir +++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir @@ -538,3 +538,27 @@ func.func @test_avg_pool2d_non_const_output_zp(%arg0: tensor<1x32x32x8xf32>, %ou (tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32> return %0 : tensor<1x32x32x8xf32> } + +// ----- + +func.func @test_matmul_t_block_scaled(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf8E4M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> { + // expected-error@+1 {{'tosa.matmul_t_block_scaled' op illegal: requires [mxfp] but not enabled in target}} + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> + return %0 : tensor<4x8x16xf32> +} + +// ----- + +func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> { + // expected-error@+1 {{'tosa.cast_from_block_scaled' op illegal: requires [mxfp] but not enabled in target}} + %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor<4x32xf8E5M2>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> + return %0 : tensor<4x32xf32> +} + +// ----- + +func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>) { + // expected-error@+1 {{'tosa.cast_to_block_scaled' op illegal: requires [mxfp] but not enabled in target}} + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>) + return %0#0, %0#1 : tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU> +} diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index 8cc357efa0c77..cd392fcc20ea1 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -1622,3 +1622,43 @@ func.func @test_unranked_weight_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array, pad = array, stride = array, local_bound = true} : (tensor<1x4x4x4xf32>, tensor<*xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } + +// ----- + +func.func @test_matmul_t_block_scaled_invalid_size(%arg0: tensor<4x8x536870912xf4E2M1FN>, %arg1: tensor<4x8x16777216xf8E8M0FNU>, %arg2: tensor<4x16x536870912xf4E2M1FN>, %arg3: tensor<4x16x16777216xf8E8M0FNU>) -> tensor<*xf32> { + // expected-error@+1 {{'tosa.matmul_t_block_scaled' op failed level check: operand 
tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}} + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<4x8x536870912xf4E2M1FN>, tensor<4x8x16777216xf8E8M0FNU>, tensor<4x16x536870912xf4E2M1FN>, tensor<4x16x16777216xf8E8M0FNU>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func.func @test_cast_from_block_scaled_invalid_size(%arg0: tensor<536870912x32xf6E2M3FN>, %arg1: tensor<536870912x1xf8E8M0FNU>) -> tensor<536870912x32xf32> { + // expected-error@+1 {{'tosa.cast_from_block_scaled' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}} + %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor<536870912x32xf6E2M3FN>, tensor<536870912x1xf8E8M0FNU>) -> tensor<536870912x32xf32> + return %0 : tensor<536870912x32xf32> +} + +// ----- + +func.func @test_cast_from_block_scaled_invalid_rank(%arg0: tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, %arg1: tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>) -> tensor<1x2x3x4x5x6x7x32xf32> { + // expected-error@+1 {{'tosa.cast_from_block_scaled' op failed level check: operand rank(shape) <= MAX_RANK}} + %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size} : (tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>) -> tensor<1x2x3x4x5x6x7x32xf32> + return %0 : tensor<1x2x3x4x5x6x7x32xf32> +} + +// ----- + +func.func @test_cast_to_block_scaled_invalid_size(%arg0: tensor<536870912x32xf32>) -> (tensor<536870912x32xf6E2M3FN>, tensor<536870912x1xf8E8M0FNU>) { + // expected-error@+1 {{'tosa.cast_to_block_scaled' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}} + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<536870912x32xf32>) -> (tensor<536870912x32xf6E2M3FN>, tensor<536870912x1xf8E8M0FNU>) + return %0#0, %0#1 : tensor<536870912x32xf6E2M3FN>, tensor<536870912x1xf8E8M0FNU> +} + +// ----- + +func.func @test_cast_to_block_scaled_invalid_rank(%arg0: tensor<1x2x3x4x5x6x7x32xf32>) -> (tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>) { + // expected-error@+1 {{'tosa.cast_to_block_scaled' op failed level check: operand rank(shape) <= MAX_RANK}} + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<1x2x3x4x5x6x7x32xf32>) -> (tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>) + return %0#0, %0#1 : tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU> +} diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 868b7b7a93335..865f712ce1a5a 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -1226,3 +1226,73 @@ func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x29x3xf8E4M3FN>, %arg1: tensor< %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x29x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN> return %0 : tensor<13x29x3xf8E4M3FN> } + +// ----- +// CHECK-LABEL: test_matmul_t_block_scaled_static +func.func @test_matmul_t_block_scaled_static(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf8E4M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> { + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> + return %0 : tensor<4x8x16xf32> +} + +// ----- +// CHECK-LABEL: 
test_matmul_t_block_scaled_unranked +func.func @test_matmul_t_block_scaled_unranked(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<*xf8E4M3FN>, %arg3: tensor<*xf8E8M0FNU>) -> tensor<*xf32> { + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- +// CHECK-LABEL: test_matmul_t_block_scaled_fp6e3m2 +func.func @test_matmul_t_block_scaled_fp6e3m2(%arg0: tensor<4x8x32xf6E3M2FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf6E3M2FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> { + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<4x8x32xf6E3M2FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E3M2FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> + return %0 : tensor<4x8x16xf32> +} + +// ----- +// CHECK-LABEL: test_matmul_t_block_scaled_fp6e2m3 +func.func @test_matmul_t_block_scaled_fp6e2m3(%arg0: tensor<4x8x32xf6E2M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf6E2M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> { + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<4x8x32xf6E2M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E2M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> + return %0 : tensor<4x8x16xf32> +} + +// ----- +// CHECK-LABEL: test_matmul_t_block_scaled_fp4e2m1 +func.func @test_matmul_t_block_scaled_fp4e2m1(%arg0: tensor<4x8x32xf4E2M1FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf4E2M1FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> { + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<4x8x32xf4E2M1FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf4E2M1FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> + return %0 : tensor<4x8x16xf32> +} + +// ----- +// CHECK-LABEL: test_matmul_t_block_scaled_broadcast +func.func @test_matmul_t_block_scaled_broadcast(%arg0: tensor, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor, %arg3: tensor<1x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> { + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor, tensor<4x8x1xf8E8M0FNU>, tensor, tensor<1x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> + return %0 : tensor<4x8x16xf32> +} + +// ----- +// CHECK-LABEL: test_cast_from_block_scaled_static +func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> { + %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> + return %0 : tensor<4x32xf32> +} + +// ----- +// CHECK-LABEL: test_cast_from_block_scaled_unranked +func.func @test_cast_from_block_scaled_unranked(%arg0: tensor<*xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>) -> tensor<*xf32> { + %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- +// CHECK-LABEL: test_cast_to_block_scaled_static +func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) { + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) + return %0#0, %0#1 
: tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU> +} + +// ----- +// CHECK-LABEL: test_cast_to_block_scaled_unranked +func.func @test_cast_to_block_scaled_unranked(%arg0: tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) { + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) + return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU> +} diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir index 7ff8065ee41fd..7de7b85bcaedf 100644 --- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir @@ -2,7 +2,7 @@ // Enable all supported extensions to focus the verification of expected profile requirement errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround,mxfp" -tosa-validate="strict-op-spec-alignment" // ----- func.func @test_const_f16() -> tensor<3x11x11x3xf16> { @@ -325,3 +325,24 @@ func.func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> { %1 = tosa.resize %arg0, %scale, %offset, %border { mode = BILINEAR } : (tensor<1x32x32x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x64x64x8xf32> return %1 : tensor<1x64x64x8xf32> } + +// ----- +func.func @test_matmul_t_block_scaled(%arg0: tensor<4x8x32xf6E3M2FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf6E3M2FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> { + // expected-error@+1 {{'tosa.matmul_t_block_scaled' op illegal: requires [pro_fp] but not enabled in target}} + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<4x8x32xf6E3M2FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E3M2FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> + return %0 : tensor<4x8x16xf32> +} + +// ----- +func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> { + // expected-error@+1 {{'tosa.cast_from_block_scaled' op illegal: requires [pro_fp] but not enabled in target}} + %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> + return %0 : tensor<4x32xf32> +} + +// ----- +func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) { + // expected-error@+1 {{'tosa.cast_to_block_scaled' op illegal: requires [pro_fp] but not enabled in target}} + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) + return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU> +} diff --git a/mlir/test/Dialect/Tosa/tosa-attach-target.mlir b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir index d6c886c44b013..a0c59c0c4bb3b 100644 --- a/mlir/test/Dialect/Tosa/tosa-attach-target.mlir +++ 
b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir @@ -1,12 +1,14 @@ // RUN: mlir-opt %s -split-input-file -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,dynamic level=none" | FileCheck %s --check-prefix=CHECK-ALL // RUN: mlir-opt %s -split-input-file -tosa-attach-target="level=8k" | FileCheck %s --check-prefix=CHECK-LVL-8K // RUN: mlir-opt %s -split-input-file -tosa-attach-target | FileCheck %s --check-prefix=CHECK-DEFAULT +// RUN: mlir-opt %s -split-input-file -tosa-attach-target="specification_version=1.1.draft" | FileCheck %s --check-prefix=CHECK-VERSION-1P1 // ----- -// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env} -// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env} -// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env} +// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env} +// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env} +// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env} +// CHECK-VERSION-1P1: module attributes {tosa.target_env = #tosa.target_env} // CHECK-LABEL: test_simple func.func @test_simple(%arg0 : tensor<1x1x1x1xf32>, %arg1 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> { %1 = tosa.add %arg0, %arg1 : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index 80f06f11fe4ad..54556a0eb08e0 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -1574,3 +1574,102 @@ func.func @test_mul_scalar(%arg0: tensor, %arg1: tensor) -> tensor<*xf %0 = tosa.mul %arg0, %arg1, %shift : (tensor, tensor, tensor<1xi8>) -> tensor<*xf32> return %0 : tensor<*xf32> } + +// ----- + +// CHECK-LABEL: test_matmul_t_block_scaled_static +func.func @test_matmul_t_block_scaled_static(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<1x16x32xf8E4M3FN>, %arg3: tensor<1x16x1xf8E8M0FNU>) -> tensor { + // CHECK: -> tensor<4x8x16xf32> + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<1x16x32xf8E4M3FN>, tensor<1x16x1xf8E8M0FNU>) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: test_matmul_t_block_scaled_unranked_a_data +func.func @test_matmul_t_block_scaled_unranked_a_data(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf8E4M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor { + // CHECK: -> tensor<4x8x16xf32> + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<*xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: test_matmul_t_block_scaled_unranked_b_data_and_scale +func.func @test_matmul_t_block_scaled_unranked_b_data_and_scale(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<*xf8E4M3FN>, %arg3: tensor<*xf8E8M0FNU>) -> tensor { + // CHECK: -> tensor<4x8x?xf32> + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: test_matmul_t_block_scaled_unranked_all +func.func 
+func.func @test_matmul_t_block_scaled_unranked_all(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<*xf8E4M3FN>, %arg3: tensor<*xf8E8M0FNU>) -> tensor {
+  // CHECK: -> tensor
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>) -> tensor
+  return %0 : tensor
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_broadcast_b_data
+func.func @test_matmul_t_block_scaled_broadcast_b_data(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<1x4x32xf8E4M3FN>, %arg3: tensor<1x4x1xf8E8M0FNU>) -> tensor {
+  // CHECK: -> tensor
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>, tensor<1x4x32xf8E4M3FN>, tensor<1x4x1xf8E8M0FNU>) -> tensor
+  return %0 : tensor
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_broadcast_b_scale
+func.func @test_matmul_t_block_scaled_broadcast_b_scale(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<*xf8E4M3FN>, %arg3: tensor<1x4x1xf8E8M0FNU>) -> tensor {
+  // CHECK: -> tensor
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<1x4x1xf8E8M0FNU>) -> tensor
+  return %0 : tensor
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_from_block_scaled_static
+func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<*xf32> {
+  // CHECK: -> tensor<4x32xf32>
+  %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_from_block_scaled_unranked_input_scale
+func.func @test_cast_from_block_scaled_unranked_input_scale(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>) -> tensor<*xf32> {
+  // CHECK: -> tensor<4x32xf32>
+  %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor<4x32xf4E2M1FN>, tensor<*xf8E8M0FNU>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_to_block_scaled_static
+func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) {
+  // CHECK: -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<4x32xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_to_block_scaled_unranked
+func.func @test_cast_to_block_scaled_unranked(%arg0: tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) {
+  // CHECK: -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_to_block_scaled_dynamic_scales
+func.func @test_cast_to_block_scaled_dynamic_scales(%arg0: tensor<4x?xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) {
+  // CHECK: -> (tensor<4x?xf4E2M1FN>, tensor<4x?xf8E8M0FNU>)
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<4x?xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
+}
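[Editorial sketch, not part of the patch] The two new test files that follow exercise the version gating end to end: `-tosa-attach-target` records the requested `specification_version` in the `tosa.target_env` module attribute, and `-tosa-validate` then compares it against each op's compliance entry. A minimal, assumed reproduction in the same style (the function name and the reduced flag set are illustrative only; they mirror the RUN lines of the new tests below):

// RUN: mlir-opt %s -split-input-file -tosa-attach-target="specification_version=1.0 profiles=pro_int,pro_fp" -tosa-validate="strict-op-spec-alignment"
func.func @version_gate_smoke_test(%arg0: tensor<1x1x1x1xf32>, %arg1: tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> {
  // tosa.add has been in the specification since 1.0, so it should validate under
  // either specification_version; the 1.1-gated cases appear in the tests below.
  %0 = tosa.add %arg0, %arg1 : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
  return %0 : tensor<1x1x1x1xf32>
}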
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
new file mode 100644
index 0000000000000..51089df238b84
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.0 profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
+
+// -----
+
+func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
+  %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+  %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+  // expected-error@+1 {{'tosa.matmul' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
+  %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E5M2>, tensor<1xf8E4M3FN>, tensor<1xf8E5M2>) -> tensor<1x14x28xf16>
+  return %0 : tensor<1x14x28xf16>
+}
+
+// -----
+
+func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf32> {
+  %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+  %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+  // expected-error@+1 {{'tosa.matmul' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
+  %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf32>
+  return %0 : tensor<1x14x28xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
new file mode 100644
index 0000000000000..8b6cdc07925f0
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -0,0 +1,52 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,mxfp" -tosa-validate="strict-op-spec-alignment" | FileCheck %s
+
+// -----
+
+func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
+  %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+  %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+  %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E5M2>, tensor<1xf8E4M3FN>, tensor<1xf8E5M2>) -> tensor<1x14x28xf16>
+  return %0 : tensor<1x14x28xf16>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_fp8_input_fp32_acc_type
+func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf32> {
+  %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+  %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+  %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf32>
+  return %0 : tensor<1x14x28xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_fp6e2m3
+func.func @test_matmul_t_block_scaled_fp6e2m3(%arg0: tensor<4x8x32xf6E2M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf6E2M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = BLOCK_SIZE_32} : (tensor<4x8x32xf6E2M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E2M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_from_block_scaled_fp8e5m2_fp32
+func.func @test_cast_from_block_scaled_fp8e5m2_fp32(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+  %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor<4x32xf8E5M2>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
+  return %0 : tensor<4x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_from_block_scaled_fp8e5m2_bf16
+func.func @test_cast_from_block_scaled_fp8e5m2_bf16(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xbf16> {
+  %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor<4x32xf8E5M2>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xbf16>
+  return %0 : tensor<4x32xbf16>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_to_block_scaled_static
+func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>) {
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>
+}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 430b06ad16c39..6cf76cdc7ad8e 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1033,7 +1033,6 @@ module {
 
 // -----
 
-// CHECK-LABEL: @scatter_invalid_indices_N
 func.func @scatter_invalid_indices_N(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<3x2xi32>, %arg2 : tensor<2x2x5xi32>) {
   // expected-error@+1 {{'tosa.scatter' op requires indices dimension 0 to have size 2, got 3}}
   %1 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<3x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x5xi32>
@@ -1042,7 +1041,6 @@ func.func @scatter_invalid_indices_N(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<3
 
 // -----
 
-// CHECK-LABEL: @scatter_invalid_input_N
 func.func @scatter_invalid_input_N(%arg0 : tensor, %arg1 : tensor<2x2xi32>, %arg2 : tensor<3x2x5xi32>) {
   // expected-error@+1 {{'tosa.scatter' op requires input dimension 0 to have size 2, got 3}}
   %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor, tensor<2x2xi32>, tensor<3x2x5xi32>) -> tensor<2x4x5xi32>
@@ -1051,7 +1049,6 @@ func.func @scatter_invalid_input_N(%arg0 : tensor, %arg1 : tensor<2x2
 
 // -----
 
-// CHECK-LABEL: @scatter_invalid_out_N
 func.func @scatter_invalid_out_N(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor<2x2x5xi32>) {
   // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 0 to have size 2, got 3}}
   %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor, tensor, tensor<2x2x5xi32>) -> tensor<3x4x5xi32>
@@ -1060,7 +1057,6 @@ func.func @scatter_invalid_out_N(%arg0 : tensor, %arg1 : tensor,
 %arg2 : tensor<2x2x5xi32>) {
   // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 1 to have size 4, got 3}}
   %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor, tensor, tensor<2x2x5xi32>) -> tensor<2x3x5xi32>
@@ -1069,7 +1065,6 @@ func.func @scatter_invalid_out_K(%arg0 : tensor, %arg1 : tensor,
 %arg1 : tensor, %arg2 : tensor<2x3x5xi32>) {
   // expected-error@+1 {{'tosa.scatter' op requires input dimension 1 to have size 2, got 3}}
   %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor, tensor, tensor<2x3x5xi32>) -> tensor<2x4x5xi32>
@@ -1078,7 +1073,6 @@ func.func @scatter_invalid_input_W(%arg0 : tensor, %arg1 : tensor,
 %arg1 : tensor, %arg2 : tensor<2x2x6xi32>) {
   // expected-error@+1 {{'tosa.scatter' op requires input dimension 2 to have size 5, got 6}}
   %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor, tensor, tensor<2x2x6xi32>) -> tensor<2x4x5xi32>
@@ -1087,7 +1081,6 @@ func.func @scatter_invalid_input_C(%arg0 : tensor, %arg1 : tensor,
 %arg1 : tensor, %arg2 : tensor<2x2x5xi32>) {
   // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 2 to have size 5, got 6}}
   %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor, tensor, tensor<2x2x5xi32>) -> tensor<2x4x6xi32>
@@ -1096,9 +1089,136 @@ func.func @scatter_invalid_out_C(%arg0 : tensor, %arg1 : tensor,
 %arg1 : tensor<2x6xi32>, %arg2 : tensor<2x6x5xi32>) {
   // expected-error@+1 {{'tosa.scatter' op requires dimensions K >= W, got K=4 and W=6}}
   %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<2x6xi32>, tensor<2x6x5xi32>) -> tensor<2x4x5xi32>
   return
 }
+
+// -----
+
+func.func @test_matmul_t_block_scaled_data_mismatch(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf8E5M2>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  // expected-error@+1 {{'tosa.matmul_t_block_scaled' op expect A_data and B_data to have same element type, got 'f8E4M3FN' and 'f8E5M2'}}
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size : i32} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E5M2>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+
+func.func @test_matmul_t_block_scaled_output_batch_mismatch(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor, %arg2: tensor<*xf8E4M3FN>, %arg3: tensor<4x?x?xf8E8M0FNU>) -> tensor<5x?x?xf32> {
+  // expected-error@+1 {{'tosa.matmul_t_block_scaled' op expected output shape 5, ?, ? to be compatible with expected output shape 4, 8, ?}}
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size : i32} : (tensor<*xf8E4M3FN>, tensor, tensor<*xf8E4M3FN>, tensor<4x?x?xf8E8M0FNU>) -> tensor<5x?x?xf32>
+  return %0 : tensor<5x?x?xf32>
+}
+
+// -----
+
+func.func @test_matmul_t_block_scaled_output_height_mismatch(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor, %arg2: tensor<*xf8E4M3FN>, %arg3: tensor<4x?x?xf8E8M0FNU>) -> tensor<4x8x?xf32> {
+  // expected-error@+1 {{'tosa.matmul_t_block_scaled' op expected output shape 4, 8, ? to be compatible with expected output shape 4, 9, ?}}
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size : i32} : (tensor<*xf8E4M3FN>, tensor, tensor<*xf8E4M3FN>, tensor<4x?x?xf8E8M0FNU>) -> tensor<4x8x?xf32>
+  return %0 : tensor<4x8x?xf32>
+}
+
+// -----
+
+func.func @test_matmul_t_block_scaled_output_width_mismatch(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor, %arg2: tensor, %arg3: tensor<*xf8E8M0FNU>) -> tensor {
+  // expected-error@+1 {{'tosa.matmul_t_block_scaled' op expected output shape ?, ?, 10 to be compatible with expected output shape ?, ?, 1}}
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size : i32} : (tensor<*xf8E4M3FN>, tensor, tensor, tensor<*xf8E8M0FNU>) -> tensor
+  return %0 : tensor
+}
+
+// -----
+
+func.func @test_matmul_t_block_scaled_channel_not_multiple_of_block_size(%arg0: tensor<4x8x55xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf8E4M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  // expected-error@+1 {{'tosa.matmul_t_block_scaled' op expected channels of b_data to match size 55, got 32}}
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size : i32} : (tensor<4x8x55xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+
+func.func @test_matmul_t_block_scaled_batch_mismatch(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<2x16x32xf8E4M3FN>, %arg3: tensor<2x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  // expected-error@+1 {{'tosa.matmul_t_block_scaled' op expect B matrix batch size to be broadcast compatible with A, got D=2 vs N=4}}
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size : i32} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<2x16x32xf8E4M3FN>, tensor<2x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+
+func.func @cast_from_block_scaled_incompatible_input_output_shape(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<5x32xf32> {
+  // expected-error@+1 {{'tosa.cast_from_block_scaled' op require compatible shapes for input_data ('tensor<4x32xf4E2M1FN>') and output_data ('tensor<5x32xf32>')}}
+  %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<5x32xf32>
+  return %0 : tensor<5x32xf32>
+}
+
+// -----
+
+func.func @cast_from_block_scaled_not_scalar(%arg0: tensor, %arg1: tensor) -> tensor {
+  // expected-error@+1 {{'tosa.cast_from_block_scaled' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor'}}
+  %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor, tensor) -> tensor
+  return %0 : tensor
+}
+
+// -----
+
+func.func @cast_from_block_scaled_not_divisible_by_block_size(%arg0: tensor<4x33xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x33xf32> {
+  // expected-error@+1 {{'tosa.cast_from_block_scaled' op expect last dimension of input_data (33) to be divisible by block_size (32)}}
+  %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor<4x33xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x33xf32>
+  return %0 : tensor<4x33xf32>
+}
+
+// -----
+
+func.func @cast_from_block_scaled_data_scale_mismatch(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<5x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+  // expected-error@+1 {{'tosa.cast_from_block_scaled' op require compatible shapes for input_data ('tensor<4x32xf4E2M1FN>') and input_scale ('tensor<5x1xf8E8M0FNU>') except for the last dimension}}
+  %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor<4x32xf4E2M1FN>, tensor<5x1xf8E8M0FNU>) -> tensor<4x32xf32>
+  return %0 : tensor<4x32xf32>
+}
+
+// -----
+
+func.func @cast_from_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x2xf8E8M0FNU>) -> tensor<4x32xf32> {
+  // expected-error@+1 {{'tosa.cast_from_block_scaled' op expect last dimension of input_scale (2) to be equal to last dimension of input_data / block_size (1)}}
+  %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size} : (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>) -> tensor<4x32xf32>
+  return %0 : tensor<4x32xf32>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_incompatible_input_output_shape(%arg0: tensor<4x32xf32>) -> (tensor<5x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
+  // expected-error@+1 {{'tosa.cast_to_block_scaled' op require compatible shapes for input_data ('tensor<4x32xf32>') and output_data ('tensor<5x32xf4E2M1FN>')}}
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<4x32xf32>) -> (tensor<5x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<5x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_not_scalar(%arg0: tensor) -> (tensor, tensor) {
+  // expected-error@+1 {{'tosa.cast_to_block_scaled' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor'}}
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor) -> (tensor, tensor)
+  return %0#0, %0#1 : tensor, tensor
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_not_divisible_by_block_size(%arg0: tensor<4x33xf32>) -> (tensor<4x33xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
+  // expected-error@+1 {{'tosa.cast_to_block_scaled' op expect last dimension of input_data (33) to be divisible by block_size (32)}}
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<4x33xf32>) -> (tensor<4x33xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<4x33xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_data_scale_mismatch(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<5x1xf8E8M0FNU>) {
+  // expected-error@+1 {{'tosa.cast_to_block_scaled' op require compatible shapes for output_data ('tensor<4x32xf4E2M1FN>') and output_scale ('tensor<5x1xf8E8M0FNU>') except for the last dimension}}
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<5x1xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<5x1xf8E8M0FNU>
+}
+
+// -----
+
+func.func @test_cast_to_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>) {
+  // expected-error@+1 {{'tosa.cast_to_block_scaled' op expect last dimension of output_scale (2) to be equal to last dimension of output_data / block_size (1)}}
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>
+}
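[Editorial sketch, not part of the patch] The last group of negative verifier tests encodes the scale-shape rule for the block-scaled casts: the data's last dimension must be divisible by the block size, and the scale operand's last dimension must equal the data's last dimension divided by the block size. A minimal positive sketch of that rule, assuming a 32-element block and reusing the block_size attribute spelling from the tests above:

func.func @cast_to_block_scaled_two_blocks_sketch(%arg0: tensor<4x64xf32>) -> (tensor<4x64xf4E2M1FN>, tensor<4x2xf8E8M0FNU>) {
  // 64 / 32 = 2, so each row of the data yields two scale values; any other last
  // dimension on the scale output would trigger the channel-mismatch diagnostic above.
  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<4x64xf32>) -> (tensor<4x64xf4E2M1FN>, tensor<4x2xf8E8M0FNU>)
  return %0#0, %0#1 : tensor<4x64xf4E2M1FN>, tensor<4x2xf8E8M0FNU>
}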