diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h index 4ecf03c34c1a5..e088eb31338dc 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h @@ -54,6 +54,8 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op); /// and provide utilities around the TOSA specification version. class TosaSpecificationVersion { public: + TosaSpecificationVersion() = default; + TosaSpecificationVersion(uint32_t major, uint32_t minor) : majorVersion(major), minorVersion(minor) {} TosaSpecificationVersion(SpecificationVersion version) @@ -83,6 +85,10 @@ class TosaSpecificationVersion { } }; +TosaSpecificationVersion getMinVersion(const Profile &profile); +TosaSpecificationVersion getMinVersion(const Extension &extension); +TosaSpecificationVersion getMinVersion(const Level &level); + llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version); /// This class represents the capability enabled in the target implementation @@ -91,22 +97,19 @@ llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version); class TargetEnv { public: TargetEnv() {} - explicit TargetEnv(SpecificationVersion specificationVersion, Level level, - const ArrayRef &profiles, - const ArrayRef &extensions) - : specificationVersion(specificationVersion), level(level) { - enabledProfiles.insert_range(profiles); - enabledExtensions.insert_range(extensions); - } - explicit TargetEnv(TargetEnvAttr targetAttr) - : TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(), - targetAttr.getProfiles(), targetAttr.getExtensions()) {} + static FailureOr + createTargetEnvFromAttr(TargetEnvAttr targetAttr, Location targetEnvAttrLoc); + + static LogicalResult verifyTargetInformation(TargetEnvAttr targetAttr, + Location targetAttrLoc); void addProfile(Profile p) { enabledProfiles.insert(p); } void addExtension(Extension e) { enabledExtensions.insert(e); } - SpecificationVersion getSpecVersion() const { return specificationVersion; } + TosaSpecificationVersion getSpecVersion() const { + return specificationVersion; + } TosaLevel getLevel() const { if (level == Level::eightK) @@ -140,7 +143,17 @@ class TargetEnv { } private: - SpecificationVersion specificationVersion; + // Require target information is verified before constructing, via the use of + // `createTargetEnvFromAttr`. + explicit TargetEnv(SpecificationVersion specificationVersion, Level level, + const ArrayRef &profiles, + const ArrayRef &extensions) + : specificationVersion(specificationVersion), level(level) { + enabledProfiles.insert_range(profiles); + enabledExtensions.insert_range(extensions); + } + + TosaSpecificationVersion specificationVersion; Level level; llvm::SmallSet enabledProfiles; llvm::SmallSet enabledExtensions; diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc index c1b5e785bd739..294fb9d99fdb6 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc @@ -554,6 +554,18 @@ extensionComplianceMap = { allOf}, {{Extension::bf16}, {{{bf16T, bf16T, bf16T, bf16T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.matmul_t_block_scaled", + {{{Extension::mxfp}, + {{{fp4e2m1T, fp8ue8m0T, fp4e2m1T, fp8ue8m0T, fp32T}, + SpecificationVersion::V_1_1_DRAFT}, + {{fp6e2m3T, fp8ue8m0T, fp6e2m3T, fp8ue8m0T, fp32T}, + SpecificationVersion::V_1_1_DRAFT}, + {{fp6e3m2T, fp8ue8m0T, fp6e3m2T, fp8ue8m0T, fp32T}, + SpecificationVersion::V_1_1_DRAFT}, + {{fp8e4m3T, fp8ue8m0T, fp8e4m3T, fp8ue8m0T, fp32T}, + SpecificationVersion::V_1_1_DRAFT}, + {{fp8e5m2T, fp8ue8m0T, fp8e5m2T, fp8ue8m0T, fp32T}, + SpecificationVersion::V_1_1_DRAFT}}}}}, {"tosa.max_pool2d", {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}}, {{Extension::fp8e4m3}, diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index 8376a4c87dbf2..48e0073c76ab6 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -270,13 +270,14 @@ def Tosa_EXT_CONTROLFLOW : I32EnumAttrCase<"controlflow", 8>; def Tosa_EXT_DOUBLEROUND : I32EnumAttrCase<"doubleround", 9>; def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>; def Tosa_EXT_DYNAMIC : I32EnumAttrCase<"dynamic", 11>; +def Tosa_EXT_MXFP : I32EnumAttrCase<"mxfp", 12>; def Tosa_ExtensionAttr : Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [ Tosa_EXT_NONE, Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND, - Tosa_EXT_DYNAMIC + Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP ]> { let extraClassDeclaration = [{ static llvm::SmallVector getAllValues() { @@ -284,7 +285,7 @@ def Tosa_ExtensionAttr Extension::int16, Extension::int4, Extension::bf16, Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft, Extension::variable, Extension::controlflow, Extension::doubleround, - Extension::inexactround, Extension::dynamic + Extension::inexactround, Extension::dynamic, Extension::mxfp }; } }]; @@ -437,7 +438,7 @@ def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> { } //===----------------------------------------------------------------------===// -// Iterable attributes. +// Enum attributes. //===----------------------------------------------------------------------===// // Defined in `section 3. Enumerations` of the TOSA specification. @@ -463,6 +464,18 @@ def Tosa_RoundingModeAttr : Tosa_I32EnumAttr<"RoundingMode", "Supported rounding modes", "rounding_mode", [Tosa_ROUNDING_SINGLE_ROUND, Tosa_ROUNDING_INEXACT_ROUND, Tosa_ROUNDING_DOUBLE_ROUND]>; +def Tosa_BLOCK_SIZE_32 : I32EnumAttrCase<"BLOCK_SIZE_32", 32>; + +def Tosa_BlockSizeAttr + : Tosa_I32EnumAttr<"BlockSize", "Block size for the block_scaled formats", "block_size", + [Tosa_BLOCK_SIZE_32]> { + let extraClassDeclaration = [{ + static uint32_t getBlockSizeValue(BlockSize blockSize) { + return static_cast(blockSize); + } + }]; +} + //===----------------------------------------------------------------------===// // TOSA Interfaces. diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 48759f2a3c9e8..a5251fcada4c9 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -348,6 +348,40 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> { "operands attr-dict `:` functional-type(operands, results)"; } +//===----------------------------------------------------------------------===// +// Operator: matmul_t_block_scaled +//===----------------------------------------------------------------------===// +def Tosa_MatmulTBlockScaledOp : Tosa_InferShapedTypeOp<"matmul_t_block_scaled"> { + let summary = "Performs two dimensional matrix multiplications using block scaled tensors."; + + let description = [{ + Performs two dimensional matrix multiplications using block scaled tensors. The block + dimension is always the the last dimension of the tensor, so the result is effectively + a matrix multiply of A by the transposed B matrix. If the N dimension of input B is of + size 1, the B matrix will be broadcast. + }]; + + let arguments = (ins + Tosa_MXFPDataTensor3D:$a_data, + Tosa_MXFPScaleTensor3D:$a_scale, + Tosa_MXFPDataTensor3D:$b_data, + Tosa_MXFPScaleTensor3D:$b_scale, + Tosa_BlockSizeAttr:$block_size + ); + + let results = (outs + Tosa_Tensor3D:$output_data + ); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; + + list availability = [ + Profile<[Tosa_PRO_FP]>, + Extension<[Tosa_EXT_MXFP]> + ]; +} + //===----------------------------------------------------------------------===// // Operator: max_pool2d //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h index 7b946ad6c6a89..79df1b888b40e 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h @@ -147,6 +147,7 @@ class TosaProfileCompliance { case Extension::fp8e4m3: case Extension::fp8e5m2: case Extension::fft: + case Extension::mxfp: return {Profile::pro_fp}; case Extension::variable: case Extension::controlflow: diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index 93ab120339d55..20bb961482ad8 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -84,6 +84,10 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>, def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat], "number">; +def Tosa_MXFPNumber : AnyTypeOf<[F8E4M3FN, F8E5M2, F4E2M1FN, F6E2M3FN, F6E3M2FN], + "micro-scaling format number">; +def Tosa_MXFPScaleNumber : AnyTypeOf<[F8E8M0FNU], "micro-scaling format scale number">; + //===----------------------------------------------------------------------===// // TOSA Tensor Conformance //===----------------------------------------------------------------------===// @@ -187,6 +191,15 @@ def Tosa_Int32Tensor2D : AnyTypeOf<[ def Tosa_TensorAtLeast1D : AnyTypeOf<[ Tosa_UnrankedTensor, TosaRankedTensorOf<[Tosa_AnyNumber], [AtLeastRankOne]>], "tosa-conformant tensor of at least rank 1", "::mlir::TensorType">; +def Tosa_MXFPDataTensor3D : AnyTypeOf<[ + TosaUnrankedTensorOf<[Tosa_MXFPNumber]>, + TosaTensorRankOf<[Tosa_MXFPNumber], [3]> +]>; +def Tosa_MXFPScaleTensor3D : AnyTypeOf<[ + TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>, + TosaTensorRankOf<[Tosa_MXFPScaleNumber], [3]> +]>; + //===----------------------------------------------------------------------===// // Generic scalar, vector, or tensor of a particular type. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp index 1cba1bb540c02..32eb286531d28 100644 --- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp @@ -12,6 +12,96 @@ namespace mlir { namespace tosa { +llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) { + return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor()); +} + +TosaSpecificationVersion getMinVersion(const Profile &profile) { + switch (profile) { + case Profile::pro_int: + case Profile::pro_fp: + return TosaSpecificationVersion(1, 0); + case Profile::none: + return TosaSpecificationVersion(0, 0); + } + llvm_unreachable("Unknown TOSA profile"); +} + +TosaSpecificationVersion getMinVersion(const Extension &extension) { + switch (extension) { + case Extension::int16: + case Extension::int4: + case Extension::bf16: + case Extension::fp8e4m3: + case Extension::fp8e5m2: + case Extension::fft: + case Extension::variable: + case Extension::controlflow: + case Extension::doubleround: + case Extension::inexactround: + case Extension::dynamic: + return TosaSpecificationVersion(1, 0); + case Extension::mxfp: + return TosaSpecificationVersion(1, 1); + case Extension::none: + return TosaSpecificationVersion(0, 0); + } + llvm_unreachable("Unknown TOSA extension"); +} + +TosaSpecificationVersion getMinVersion(const Level &level) { + switch (level) { + case Level::eightK: + case Level::none: + return TosaSpecificationVersion(1, 0); + } + llvm_unreachable("Unknown TOSA level"); +} + +FailureOr +TargetEnv::createTargetEnvFromAttr(TargetEnvAttr targetAttr, + Location targetEnvAttrLoc) { + if (failed(verifyTargetInformation(targetAttr, targetEnvAttrLoc))) + return failure(); + + return TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(), + targetAttr.getProfiles(), targetAttr.getExtensions()); +} + +LogicalResult TargetEnv::verifyTargetInformation(TargetEnvAttr targetAttr, + Location targetAttrLoc) { + TosaSpecificationVersion targetVersion(targetAttr.getSpecificationVersion()); + + const auto isCompatibleWithTargetVersion = + [&](const auto &targetEnum, Location targetAttrLoc, + StringRef enumName) -> LogicalResult { + const TosaSpecificationVersion minRequiredVersion = + getMinVersion(targetEnum); + if (!targetVersion.isBackwardsCompatibleWith(minRequiredVersion)) + return emitError(targetAttrLoc, enumName) + << " '" << stringifyEnum(targetEnum) + << "' is not compatible with the target version " + << stringifyVersion(targetVersion) + << ", minimum required version is " + << stringifyVersion(minRequiredVersion); + return success(); + }; + + for (const auto &profile : targetAttr.getProfiles()) + if (failed( + isCompatibleWithTargetVersion(profile, targetAttrLoc, "profile"))) + return failure(); + for (const auto &extension : targetAttr.getExtensions()) + if (failed(isCompatibleWithTargetVersion(extension, targetAttrLoc, + "extension"))) + return failure(); + if (failed(isCompatibleWithTargetVersion(targetAttr.getLevel(), targetAttrLoc, + "level"))) + return failure(); + + return success(); +} + TargetEnvAttr lookupTargetEnv(Operation *op) { while (op) { op = SymbolTable::getNearestSymbolTable(op); @@ -39,9 +129,5 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) { return getDefaultTargetEnv(op->getContext()); } -llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) { - return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor()); -} - } // namespace tosa } // namespace mlir diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 00f84bc43f444..6cd0eaea3ce6c 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -321,6 +321,19 @@ ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser, } } + // special handling: block_size accepts a *bare* BlockSizeMode enum + if constexpr (std::is_same_v) { + if (name == "block_size" && succeeded(parser.parseOptionalKeyword(&kw))) { + auto sym = symbolizeBlockSize(kw); + if (!sym) + return parser.emitError(parser.getCurrentLocation()) + << "invalid block_size value: " << kw; + auto attr = BlockSizeAttr::get(parser.getContext(), sym.value()); + outAttrs.push_back(NamedAttribute(name, attr)); + return success(); + } + } + // Default path: parse any normal attribute literal, including fully qualified // enum keyword Attribute attr; @@ -373,6 +386,8 @@ void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) { } else if (auto nanPropagationModeAttr = dyn_cast(attr)) { parser << nanPropagationModeAttr.getValue(); + } else if (auto blockSizeAttr = dyn_cast(attr)) { + parser << blockSizeAttr.getValue(); } else { parser.printAttribute(attr); } @@ -508,6 +523,15 @@ void ReduceMinOp::print(OpAsmPrinter &parser) { printWithNanPropagationHandling(parser, *this); } +ParseResult MatmulTBlockScaledOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseWithEnumHandling(parser, result); +} + +void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) { + printWithEnumHandling(parser, *this); +} + //===----------------------------------------------------------------------===// // Tosa utilities. //===----------------------------------------------------------------------===// @@ -933,32 +957,35 @@ static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) { // verify that inType and outType have same element types template -static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) { - auto inputType = llvm::dyn_cast(inType); - auto outputType = llvm::dyn_cast(outType); - if (!inputType) { - op.emitOpError("expect shaped tensor for input, got ") << inType; +static LogicalResult verifySameElementTypes(T op, Type aType, Type bType, + StringRef aName = "input", + StringRef bName = "output") { + auto aTType = llvm::dyn_cast(aType); + auto bTType = llvm::dyn_cast(bType); + if (!aTType) { + op.emitOpError("expect shaped tensor for") << aName << ", got " << aType; return failure(); } - if (!outputType) { - op.emitOpError("expect shaped tensor for output, got ") << outType; + if (!bTType) { + op.emitOpError("expect shaped tensor for") << bName << ", got" << bType; return failure(); } - auto inputElementType = inputType.getElementType(); - auto outputElementType = outputType.getElementType(); - auto inputQuantType = - llvm::dyn_cast(inputElementType); - auto outputQuantType = - llvm::dyn_cast(outputElementType); - if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) && - (outputElementType.isIntOrIndexOrFloat() || outputQuantType) && - inputElementType != outputElementType) { + auto aElementType = aTType.getElementType(); + auto bElementType = bTType.getElementType(); + auto aQuantType = + llvm::dyn_cast(aElementType); + auto bQuantType = + llvm::dyn_cast(bElementType); + if ((aElementType.isIntOrIndexOrFloat() || aQuantType) && + (bElementType.isIntOrIndexOrFloat() || bQuantType) && + aElementType != bElementType) { // only check if both element types are int/index/float/UniformQuantized // eg, not sure how to check quant::QuantizedType // this happens in test_conv2d_q_grouped_convolution in // tfl-to-tosa-pipeline.mlir - op.emitOpError("expect input and output to have same element type, got ") - << inputElementType << " and " << outputElementType; + op.emitOpError("expect ") + << aName << " and " << bName << " to have same element type, got " + << aElementType << " and " << bElementType; return failure(); } return success(); @@ -1846,6 +1873,161 @@ LogicalResult MatMulOp::verify() { return success(); } +LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents( + MLIRContext *context, ::std::optional location, + MatmulTBlockScaledOp::Adaptor adaptor, + SmallVectorImpl &inferredReturnShapes) { + SmallVector outShape(3, ShapedType::kDynamic); + + const auto aDataShape = cast(adaptor.getAData().getType()); + if (aDataShape.hasRank()) { + outShape[0] = aDataShape.getDimSize(0); + outShape[1] = aDataShape.getDimSize(1); + } + + const auto aScaleShape = cast(adaptor.getAScale().getType()); + if (aScaleShape.hasRank()) { + outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0) + : outShape[0]; + outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1) + : outShape[1]; + } + + // If B batch size is 1, it is broadcast across A's batch size + const auto bDataShape = cast(adaptor.getBData().getType()); + if (bDataShape.hasRank()) { + const int64_t bDataBatchSize = bDataShape.getDimSize(0); + if (bDataBatchSize != 1) + outShape[0] = + ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0]; + outShape[2] = bDataShape.getDimSize(1); + } + + const auto bScaleShape = cast(adaptor.getBScale().getType()); + if (bScaleShape.hasRank()) { + const int64_t bScaleBatchSize = bScaleShape.getDimSize(0); + if (bScaleBatchSize != 1) + outShape[0] = + ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0]; + outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1) + : outShape[2]; + } + + inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); + return success(); +} + +LogicalResult MatmulTBlockScaledOp::verify() { + // Verify same input data types + const Type aDataType = getAData().getType(); + const Type bDataType = getBData().getType(); + if (failed(verifySameElementTypes(*this, aDataType, bDataType, "A_data", + "B_data"))) + return failure(); + + auto tryUpdateDimOrFailure = [&](int64_t &currDim, const int64_t newDim, + const StringRef operandName, + const StringRef dimName) -> LogicalResult { + if (ShapedType::isDynamic(currDim)) { + currDim = newDim; + return success(); + } else if (ShapedType::isStatic(newDim) && currDim != newDim) { + return emitOpError("expected ") + << dimName << " of " << operandName << " to match size " << currDim + << ", got " << newDim; + } + return success(); + }; + + // Verify input shape compatibility + int64_t N = ShapedType::kDynamic; + int64_t D = ShapedType::kDynamic; + int64_t H = ShapedType::kDynamic; + int64_t W = ShapedType::kDynamic; + int64_t C = ShapedType::kDynamic; + int64_t multiplesOfC = ShapedType::kDynamic; + + const ShapeAdaptor aDataShape = ShapeAdaptor(aDataType); + if (aDataShape.hasRank()) { + N = aDataShape.getDimSize(0); + H = aDataShape.getDimSize(1); + C = aDataShape.getDimSize(2); + } + + const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType()); + if (aScaleShape.hasRank()) { + if (failed(tryUpdateDimOrFailure(N, aScaleShape.getDimSize(0), "a_scale", + "batch")) || + failed(tryUpdateDimOrFailure(H, aScaleShape.getDimSize(1), "a_scale", + "height"))) + return failure(); + multiplesOfC = aScaleShape.getDimSize(2); + } + + const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType); + if (bDataShape.hasRank()) { + if (failed(tryUpdateDimOrFailure(D, bDataShape.getDimSize(0), "b_data", + "batch")) || + failed(tryUpdateDimOrFailure(C, bDataShape.getDimSize(2), "b_data", + "channels"))) + return failure(); + W = bDataShape.getDimSize(1); + } + + const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType()); + if (bScaleShape.hasRank()) { + if (failed(tryUpdateDimOrFailure(D, bScaleShape.getDimSize(0), "b_scale", + "batch")) || + failed(tryUpdateDimOrFailure(W, bScaleShape.getDimSize(1), "b_scale", + "width")) || + failed(tryUpdateDimOrFailure(multiplesOfC, bScaleShape.getDimSize(2), + "b_scale", "C/block_size"))) + return failure(); + } + + // Verify batch size is broadcast compatible + if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1) + return emitOpError("expect B matrix batch size to be broadcast compatible " + "with A, got D=") + << D << " vs N=" << N; + + // Verify C is a multiple of block size + const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize()); + if (ShapedType::isStatic(C) && C % blockSize != 0) + return emitOpError("expect C to be a multiple of block size, got C=") + << C << ", block_size=" << blockSize; + + // Verify multiplesOfC is C / block size + if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) && + multiplesOfC != C / blockSize) + return emitOpError( + "expect scale operands dimension 2 to equal C/block_size (") + << C << "/" << blockSize << ")" + << ", got " << multiplesOfC; + + // Verify output shape + N = ShapedType::isDynamic(N) ? D : N; + const SmallVector expectedOutputShape = {N, H, W}; + const auto outputType = cast(getResult().getType()); + if (outputType.hasRank() && + failed( + verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) { + InFlightDiagnostic opError = emitOpError("expected output shape "); + auto stringifyDim = [&](int64_t d) { + if (ShapedType::isDynamic(d)) + opError << "?"; + else + opError << d; + }; + llvm::interleaveComma(outputType.getShape(), opError, stringifyDim); + opError << " to be compatible with expected output shape "; + llvm::interleaveComma(expectedOutputShape, opError, stringifyDim); + return opError; + } + + return success(); +} + LogicalResult tosa::PadOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, PadOp::Adaptor adaptor, diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp index f072e3eff1975..e965ae0cf9888 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp @@ -25,6 +25,12 @@ TosaProfileCompliance::TosaProfileCompliance() { const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8}; const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8}; + // micro-scaling formats + const TypeInfo fp6e2m3T = {mlir::Float6E2M3FNType::getTypeID(), 6}; + const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6}; + const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4}; + const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8}; + // The profile-based compliance content below is auto-generated by a script // in https://git.mlplatform.org/tosa/specification.git #include "mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc" @@ -269,6 +275,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) { // For the most of tosa operators, all operands are profile/extension related // and hence are all considered in this profile-based compilance check. + POPULATE_PROFILE_INFO_COMMON(MatmulTBlockScaled) POPULATE_PROFILE_INFO_COMMON(Cast) POPULATE_PROFILE_INFO_COMMON(Const) POPULATE_PROFILE_INFO_COMMON(ArgMax) @@ -623,6 +630,14 @@ TosaProfileCompliance::stringifyTypeInfo(const TypeInfo &typeInfo) { return {"fp8e4m3"}; } else if (typeInfo.typeID == mlir::Float8E5M2Type::getTypeID()) { return {"fp8e5m2"}; + } else if (typeInfo.typeID == mlir::Float6E2M3FNType::getTypeID()) { + return {"fp6e2m3"}; + } else if (typeInfo.typeID == mlir::Float6E3M2FNType::getTypeID()) { + return {"fp6e3m2"}; + } else if (typeInfo.typeID == mlir::Float4E2M1FNType::getTypeID()) { + return {"fp4e2m1"}; + } else if (typeInfo.typeID == mlir::Float8E8M0FNUType::getTypeID()) { + return {"fp8e8m0"}; } llvm_unreachable("unknown type"); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 82f2f7eb17af4..3f874d94ab9be 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -657,6 +657,7 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) { CHECK_SIZES(TransposeConv2D); CHECK_SIZES(FFT2d); CHECK_SIZES(MatMul); + CHECK_SIZES(MatmulTBlockScaled); CHECK_SIZES(MaxPool2d); CHECK_SIZES(RFFT2d); // Scatter/Gather Operators @@ -1192,9 +1193,9 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) { bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) { if (isa(type)) { return isa(type); - } - if (auto intTy = dyn_cast(type)) { + Float8E5M2Type, Float4E2M1FNType, Float6E2M3FNType, + Float6E3M2FNType, Float8E8M0FNUType>(type); + } else if (auto intTy = dyn_cast(type)) { if (intTy.isSignless()) { switch (intTy.getWidth()) { case 1: @@ -1220,13 +1221,19 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) { } void TosaValidation::runOnOperation() { + ModuleOp modOp = getOperation(); + const TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(modOp); + const auto maybeTargetEnv = + tosa::TargetEnv::createTargetEnvFromAttr(targetEnvAttr, modOp.getLoc()); + if (failed(maybeTargetEnv)) + return signalPassFailure(); + targetEnv = *maybeTargetEnv; + TosaDialect *tosaDialect = getContext().getLoadedDialect(); if (!tosaDialect) return; - targetEnv = tosa::TargetEnv(lookupTargetEnvOrDefault(getOperation())); - - getOperation().walk([&](Operation *op) { + modOp.walk([&](Operation *op) { if (op->getDialect() != tosaDialect) return; diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir index e5c9402caaddc..005601d4017b8 100644 --- a/mlir/test/Dialect/Tosa/invalid_extension.mlir +++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir @@ -538,3 +538,11 @@ func.func @test_avg_pool2d_non_const_output_zp(%arg0: tensor<1x32x32x8xf32>, %ou (tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32> return %0 : tensor<1x32x32x8xf32> } + +// ----- + +func.func @test_matmul_t_block_scaled(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf8E4M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> { + // expected-error@+1 {{'tosa.matmul_t_block_scaled' op illegal: requires [mxfp] but not enabled in target}} + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> + return %0 : tensor<4x8x16xf32> +} diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index 8cc357efa0c77..8771e6e2476e4 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -1622,3 +1622,12 @@ func.func @test_unranked_weight_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array, pad = array, stride = array, local_bound = true} : (tensor<1x4x4x4xf32>, tensor<*xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } + +// ----- + +// CHECK-LABEL: test_matmul_t_block_scaled_invalid_size +func.func @test_matmul_t_block_scaled_invalid_size(%arg0: tensor<4x8x536870912xf4E2M1FN>, %arg1: tensor<4x8x16777216xf8E8M0FNU>, %arg2: tensor<4x16x536870912xf4E2M1FN>, %arg3: tensor<4x16x16777216xf8E8M0FNU>) -> tensor<*xf32> { + // expected-error@+1 {{'tosa.matmul_t_block_scaled' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}} + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<4x8x536870912xf4E2M1FN>, tensor<4x8x16777216xf8E8M0FNU>, tensor<4x16x536870912xf4E2M1FN>, tensor<4x16x16777216xf8E8M0FNU>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 868b7b7a93335..9bf36b5fd4c7d 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -1226,3 +1226,45 @@ func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x29x3xf8E4M3FN>, %arg1: tensor< %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x29x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN> return %0 : tensor<13x29x3xf8E4M3FN> } + +// ----- +// CHECK-LABEL: test_matmul_t_block_scaled_static +func.func @test_matmul_t_block_scaled_static(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf8E4M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> { + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> + return %0 : tensor<4x8x16xf32> +} + +// ----- +// CHECK-LABEL: test_matmul_t_block_scaled_unranked +func.func @test_matmul_t_block_scaled_unranked(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<*xf8E4M3FN>, %arg3: tensor<*xf8E8M0FNU>) -> tensor<*xf32> { + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- +// CHECK-LABEL: test_matmul_t_block_scaled_fp6e3m2 +func.func @test_matmul_t_block_scaled_fp6e3m2(%arg0: tensor<4x8x32xf6E3M2FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf6E3M2FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> { + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<4x8x32xf6E3M2FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E3M2FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> + return %0 : tensor<4x8x16xf32> +} + +// ----- +// CHECK-LABEL: test_matmul_t_block_scaled_fp6e2m3 +func.func @test_matmul_t_block_scaled_fp6e2m3(%arg0: tensor<4x8x32xf6E2M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf6E2M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> { + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<4x8x32xf6E2M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E2M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> + return %0 : tensor<4x8x16xf32> +} + +// ----- +// CHECK-LABEL: test_matmul_t_block_scaled_fp4e2m1 +func.func @test_matmul_t_block_scaled_fp4e2m1(%arg0: tensor<4x8x32xf4E2M1FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf4E2M1FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> { + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<4x8x32xf4E2M1FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf4E2M1FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> + return %0 : tensor<4x8x16xf32> +} + +// ----- +// CHECK-LABEL: test_matmul_t_block_scaled_broadcast +func.func @test_matmul_t_block_scaled_broadcast(%arg0: tensor, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor, %arg3: tensor<1x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> { + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor, tensor<4x8x1xf8E8M0FNU>, tensor, tensor<1x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> + return %0 : tensor<4x8x16xf32> +} diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir index 7ff8065ee41fd..0271d71561a52 100644 --- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir @@ -2,7 +2,7 @@ // Enable all supported extensions to focus the verification of expected profile requirement errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround,mxfp" -tosa-validate="strict-op-spec-alignment" // ----- func.func @test_const_f16() -> tensor<3x11x11x3xf16> { @@ -325,3 +325,10 @@ func.func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> { %1 = tosa.resize %arg0, %scale, %offset, %border { mode = BILINEAR } : (tensor<1x32x32x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x64x64x8xf32> return %1 : tensor<1x64x64x8xf32> } + +// ----- +func.func @test_matmul_t_block_scaled(%arg0: tensor<4x8x32xf6E3M2FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf6E3M2FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> { + // expected-error@+1 {{'tosa.matmul_t_block_scaled' op illegal: requires [pro_fp] but not enabled in target}} + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<4x8x32xf6E3M2FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E3M2FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> + return %0 : tensor<4x8x16xf32> +} diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index 80f06f11fe4ad..72479fe21ade8 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -1574,3 +1574,57 @@ func.func @test_mul_scalar(%arg0: tensor, %arg1: tensor) -> tensor<*xf %0 = tosa.mul %arg0, %arg1, %shift : (tensor, tensor, tensor<1xi8>) -> tensor<*xf32> return %0 : tensor<*xf32> } + +// ----- + +// CHECK-LABEL: test_matmul_t_block_scaled_static +func.func @test_matmul_t_block_scaled_static(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<1x16x32xf8E4M3FN>, %arg3: tensor<1x16x1xf8E8M0FNU>) -> tensor { + // CHECK: -> tensor<4x8x16xf32> + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<1x16x32xf8E4M3FN>, tensor<1x16x1xf8E8M0FNU>) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: test_matmul_t_block_scaled_unranked_a_data +func.func @test_matmul_t_block_scaled_unranked_a_data(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf8E4M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor { + // CHECK: -> tensor<4x8x16xf32> + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<*xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: test_matmul_t_block_scaled_unranked_b_data_and_scale +func.func @test_matmul_t_block_scaled_unranked_b_data_and_scale(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<*xf8E4M3FN>, %arg3: tensor<*xf8E8M0FNU>) -> tensor { + // CHECK: -> tensor<4x8x?xf32> + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: test_matmul_t_block_scaled_unranked_all +func.func @test_matmul_t_block_scaled_unranked_all(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<*xf8E4M3FN>, %arg3: tensor<*xf8E8M0FNU>) -> tensor { + // CHECK: -> tensor + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: test_matmul_t_block_scaled_broadcast_b_data +func.func @test_matmul_t_block_scaled_broadcast_b_data(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<1x4x32xf8E4M3FN>, %arg3: tensor<1x4x1xf8E8M0FNU>) -> tensor { + // CHECK: -> tensor + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>, tensor<1x4x32xf8E4M3FN>, tensor<1x4x1xf8E8M0FNU>) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: test_matmul_t_block_scaled_broadcast_b_scale +func.func @test_matmul_t_block_scaled_broadcast_b_scale(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<*xf8E4M3FN>, %arg3: tensor<1x4x1xf8E8M0FNU>) -> tensor { + // CHECK: -> tensor + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<1x4x1xf8E8M0FNU>) -> tensor + return %0 : tensor +} diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir index 81645092bf195..2040a4bc7e6af 100644 --- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir +++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" | FileCheck %s +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,mxfp" -tosa-validate="strict-op-spec-alignment" | FileCheck %s // ----- @@ -18,3 +18,11 @@ func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>, %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf32> return %0 : tensor<1x14x28xf32> } + +// ----- + +// CHECK-LABEL: test_matmul_t_block_scaled_fp6e2m3 +func.func @test_matmul_t_block_scaled_fp6e2m3(%arg0: tensor<4x8x32xf6E2M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf6E2M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> { + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = BLOCK_SIZE_32} : (tensor<4x8x32xf6E2M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E2M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> + return %0 : tensor<4x8x16xf32> +} diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir index 430b06ad16c39..4be5d725ad612 100644 --- a/mlir/test/Dialect/Tosa/verifier.mlir +++ b/mlir/test/Dialect/Tosa/verifier.mlir @@ -1102,3 +1102,51 @@ func.func @scatter_invalid_K_W(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<2x6xi32 %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<2x6xi32>, tensor<2x6x5xi32>) -> tensor<2x4x5xi32> return } + +// ----- + +func.func @test_matmul_t_block_scaled_data_mismatch(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf8E5M2>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> { + // expected-error@+1 {{'tosa.matmul_t_block_scaled' op expect A_data and B_data to have same element type, got 'f8E4M3FN' and 'f8E5M2'}} + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size : i32} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E5M2>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> + return %0 : tensor<4x8x16xf32> +} + +// ----- + +func.func @test_matmul_t_block_scaled_output_batch_mismatch(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor, %arg2: tensor<*xf8E4M3FN>, %arg3: tensor<4x?x?xf8E8M0FNU>) -> tensor<5x?x?xf32> { + // expected-error@+1 {{'tosa.matmul_t_block_scaled' op expected output shape 5, ?, ? to be compatible with expected output shape 4, 8, ?}} + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size : i32} : (tensor<*xf8E4M3FN>, tensor, tensor<*xf8E4M3FN>, tensor<4x?x?xf8E8M0FNU>) -> tensor<5x?x?xf32> + return %0 : tensor<5x?x?xf32> +} + +// ----- + +func.func @test_matmul_t_block_scaled_output_height_mismatch(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor, %arg2: tensor<*xf8E4M3FN>, %arg3: tensor<4x?x?xf8E8M0FNU>) -> tensor<4x8x?xf32> { + // expected-error@+1 {{'tosa.matmul_t_block_scaled' op expected output shape 4, 8, ? to be compatible with expected output shape 4, 9, ?}} + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size : i32} : (tensor<*xf8E4M3FN>, tensor, tensor<*xf8E4M3FN>, tensor<4x?x?xf8E8M0FNU>) -> tensor<4x8x?xf32> + return %0 : tensor<4x8x?xf32> +} + +// ----- + +func.func @test_matmul_t_block_scaled_output_width_mismatch(%arg0: tensor<*xf8E4M3FN>, %arg1: tensor, %arg2: tensor, %arg3: tensor<*xf8E8M0FNU>) -> tensor { + // expected-error@+1 {{'tosa.matmul_t_block_scaled' op expected output shape ?, ?, 10 to be compatible with expected output shape ?, ?, 1}} + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size : i32} : (tensor<*xf8E4M3FN>, tensor, tensor, tensor<*xf8E8M0FNU>) -> tensor + return %0 : tensor +} + +// ----- + +func.func @test_matmul_t_block_scaled_channel_not_multiple_of_block_size(%arg0: tensor<4x8x55xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf8E4M3FN>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> { + // expected-error@+1 {{'tosa.matmul_t_block_scaled' op expected channels of b_data to match size 55, got 32}} + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size : i32} : (tensor<4x8x55xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E4M3FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> + return %0 : tensor<4x8x16xf32> +} + +// ----- + +func.func @test_matmul_t_block_scaled_batch_mismatch(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<2x16x32xf8E4M3FN>, %arg3: tensor<2x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> { + // expected-error@+1 {{'tosa.matmul_t_block_scaled' op expect B matrix batch size to be broadcast compatible with A, got D=2 vs N=4}} + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size : i32} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<2x16x32xf8E4M3FN>, tensor<2x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> + return %0 : tensor<4x8x16xf32> +}