From bb80efa34863bc699d2e1d5ba7922da98b27f0e5 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Mon, 15 Sep 2025 21:52:24 +0000 Subject: [PATCH] [mlir][tosa] Add support for cast_from/to_block_scaled This commit adds support for the cast_from/to_block_scaled operations from the ext-mxfp extension. This includes: - Operation definition in TosaOps.td - Micro-scaling supported types definition - Shape inference and verifiers - Validation pass checks to ensure usage is only valid when the target environment includes ext-mxfp and at least v1.1.draft of the specification. Note: currently it excludes support for mxint8. This will be added in a later commit. Note: this commit adds support as defined in the spec in https://review.mlplatform.org/c/tosa/specification/+/15362. EXT_MXFP extension is considered experimental and subject to breaking change. Co-authored-by: Tat Wai Chong Change-Id: I490645ce99b7ccd7021ed06acaf1530b4fbf6dfd --- .../Dialect/Tosa/IR/TosaComplianceData.h.inc | 28 +++ mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 63 +++++++ .../Dialect/Tosa/IR/TosaProfileCompliance.h | 3 +- .../mlir/Dialect/Tosa/IR/TosaTypesBase.td | 10 ++ mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 159 +++++++++++++++++- .../Tosa/Transforms/TosaProfileCompliance.cpp | 32 +--- .../Tosa/Transforms/TosaValidation.cpp | 2 + mlir/test/Dialect/Tosa/availability.mlir | 18 ++ mlir/test/Dialect/Tosa/invalid_extension.mlir | 17 +- mlir/test/Dialect/Tosa/level_check.mlir | 33 +++- mlir/test/Dialect/Tosa/ops.mlir | 28 +++ .../Tosa/profile_pro_fp_unsupported.mlir | 14 ++ mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 45 +++++ .../tosa-validation-version-1p1-valid.mlir | 24 +++ mlir/test/Dialect/Tosa/verifier.mlir | 88 +++++++++- 15 files changed, 529 insertions(+), 35 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc index 6e78b75f37d10..8b5934ff0630e 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc @@ -864,6 +864,34 @@ extensionComplianceMap = { {{bf16T, fp6e3m2T}, SpecificationVersion::V_1_1_DRAFT}, {{bf16T, fp6e2m3T}, SpecificationVersion::V_1_1_DRAFT}}, allOf}}}, + {"tosa.cast_from_block_scaled", + {{{Extension::bf16, Extension::mxfp}, + {{{fp4e2m1T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp6e2m3T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp6e3m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp8e4m3T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp8e5m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}}, + allOf}, + {{Extension::mxfp}, + {{{fp4e2m1T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp6e2m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp6e3m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp8e4m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp8e5m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}}, + {"tosa.cast_to_block_scaled", + {{{Extension::mxfp}, + {{{bf16T, fp4e2m1T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp32T, fp4e2m1T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp32T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp32T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp32T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp32T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}}, + {{Extension::bf16, Extension::mxfp}, + {{{bf16T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}, + {{bf16T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}, + {{bf16T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}, + {{bf16T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}, + allOf}}}, {"tosa.rescale", {{{Extension::int16}, {{{i48T, i48T, i8T, i8T}, SpecificationVersion::V_1_0}, diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 0e3df21f43804..6e1759119a621 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -2470,6 +2470,69 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure, SameOperandsAndResultShape, let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// Operator: cast_from_block_scaled +//===----------------------------------------------------------------------===// +def Tosa_CastFromBlockScaledOp: Tosa_InferShapedTypeOp<"cast_from_block_scaled"> { + let summary = "Apply scales from a scale tensor to the values in a value tensor"; + + let description = [{ + Apply the scales from a scale tensor to the values in a value tensor, casting + the result to the output type. The block dimension must be the last dimension + of the tensor. + }]; + + let arguments = (ins + Tosa_MXFPDataTensorAtLeast1D:$input_data, + Tosa_MXFPScaleTensorAtLeast1D:$input_scale, + Tosa_BlockSizeAttr:$block_size + ); + + let results = (outs + Tosa_TensorAtLeast1D: $output_data + ); + + list availability = [ + Profile<[Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16, Tosa_EXT_MXFP]>, + ]; + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + +//===----------------------------------------------------------------------===// +// Operator: cast_to_block_scaled +//===----------------------------------------------------------------------===// +def Tosa_CastToBlockScaledOp : Tosa_InferShapedTypeOp<"cast_to_block_scaled"> { + let summary = "Calculate scale tensor values per block, output to separate scale and data tensors."; + + let description = [{ + Calculate a scale value per block of input values and use that to calculate + scaled data values from an input tensor. The output tensors are cast to the + specified scale and value types. The block dimension will be the last dimension + of the tensor. + }]; + + let arguments = (ins + Tosa_TensorAtLeast1D:$input_data, + Tosa_BlockSizeAttr:$block_size + ); + + let results = (outs + Tosa_MXFPDataTensorAtLeast1D:$output_data, + Tosa_MXFPScaleTensorAtLeast1D:$output_scale + ); + + list availability = [ + Profile<[Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16, Tosa_EXT_MXFP]> + ]; + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + //===----------------------------------------------------------------------===// // Operator: rescale //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h index b76228fac9f33..45d380c1b2e6c 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h @@ -79,7 +79,8 @@ class ProfileInfoDepot { LogicalResult populatationDispatch(Operation *op); - LogicalResult populateProfileInfo(ValueRange operands, Value output); + // Add input operands and output results to the profile type info list + LogicalResult populateProfileInfo(ValueRange operands, ValueRange results); // Base template diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index 20bb961482ad8..93843e86fd378 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -199,6 +199,16 @@ def Tosa_MXFPScaleTensor3D : AnyTypeOf<[ TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>, TosaTensorRankOf<[Tosa_MXFPScaleNumber], [3]> ]>; +def Tosa_MXFPDataTensorAtLeast1D : AnyTypeOf<[ + TosaUnrankedTensorOf<[Tosa_MXFPNumber]>, + TosaRankedTensorOf<[Tosa_MXFPNumber], [AtLeastRankOne]>], + "tosa-conformant tensor of at least rank 1", "::mlir::TensorType" +>; +def Tosa_MXFPScaleTensorAtLeast1D : AnyTypeOf<[ + TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>, + TosaRankedTensorOf<[Tosa_MXFPScaleNumber], [AtLeastRankOne]>], + "tosa-conformant tensor of at least rank 1", "::mlir::TensorType" +>; //===----------------------------------------------------------------------===// // Generic scalar, vector, or tensor of a particular type. diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 6cd0eaea3ce6c..0aff67f0b5eba 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -370,7 +370,7 @@ ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) { result.operands))) return failure(); - result.addTypes(fnTy.getResult(0)); + result.addTypes(fnTy.getResults()); result.addAttributes(attrs); return success(); @@ -532,6 +532,24 @@ void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) { printWithEnumHandling(parser, *this); } +ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseWithEnumHandling(parser, result); +} + +void CastFromBlockScaledOp::print(OpAsmPrinter &parser) { + printWithEnumHandling(parser, *this); +} + +ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseWithEnumHandling(parser, result); +} + +void CastToBlockScaledOp::print(OpAsmPrinter &parser) { + printWithEnumHandling(parser, *this); +} + //===----------------------------------------------------------------------===// // Tosa utilities. //===----------------------------------------------------------------------===// @@ -3944,6 +3962,145 @@ LogicalResult RescaleOp::inferReturnTypeComponents( return success(); } +LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents( + MLIRContext *context, ::std::optional location, + CastFromBlockScaledOp::Adaptor adaptor, + SmallVectorImpl &inferredReturnShapes) { + const ShapeAdaptor inputShape(adaptor.getInputData().getType()); + inferredReturnShapes.push_back(ShapedTypeComponents(inputShape)); + return success(); +} + +LogicalResult CastFromBlockScaledOp::verify() { + const Type inputDataType = getInputData().getType(); + const Type outputDataType = getResult().getType(); + if (failed(verifyCompatibleShape(inputDataType, outputDataType))) + return emitOpError() << "require compatible shapes for input_data (" + << inputDataType << ") and " + << "output_data (" << outputDataType << ")"; + + const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType); + + if (inputDataShape.hasRank()) { + const unsigned int blockSize = + BlockSizeAttr::getBlockSizeValue(getBlockSize()); + const int64_t inputDataLastDim = + inputDataShape.getDimSize(inputDataShape.getRank() - 1); + if (inputDataLastDim % blockSize != 0) + return emitOpError() << "expect last dimension of input_data (" + << inputDataLastDim + << ") to be divisible by block_size (" << blockSize + << ")"; + + const Type inputScaleType = getInputScale().getType(); + const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType); + + if (inputScaleShape.hasRank()) { + SmallVector inputDataDims, inputScaleDims; + inputDataShape.getDims(inputDataDims); + inputScaleShape.getDims(inputScaleDims); + + if (inputDataDims.size() != inputScaleDims.size() || + failed(verifyCompatibleShape( + ArrayRef(inputDataDims).drop_back(1), + ArrayRef(inputScaleDims).drop_back(1)))) + return emitOpError() << "require compatible shapes for input_data (" + << inputDataType << ") and " + << "input_scale (" << inputScaleType + << ") except for the last dimension"; + + const SmallVector dimsToCheck{inputDataLastDim / blockSize, + inputScaleDims.back()}; + if (ShapedType::isStatic(inputDataLastDim) && + failed(verifyCompatibleDims(dimsToCheck))) + return emitOpError() + << "expect last dimension of input_scale (" + << inputScaleDims.back() + << ") to be equal to last dimension of input_data / block_size (" + << inputDataDims.back() / blockSize << ")"; + } + } + + return success(); +} + +LogicalResult CastToBlockScaledOp::inferReturnTypeComponents( + MLIRContext *context, ::std::optional location, + CastToBlockScaledOp::Adaptor adaptor, + SmallVectorImpl &inferredReturnShapes) { + const ShapeAdaptor inputShape(adaptor.getInputData().getType()); + inferredReturnShapes.push_back(ShapedTypeComponents(inputShape)); + if (!inputShape.hasRank()) + return success(); + + // Calculate output_scale shape if ranked input provided + SmallVector outputScaleShape; + inputShape.getDims(outputScaleShape); + const int64_t lastDimLoc = inputShape.getRank() - 1; + const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc); + if (ShapedType::isStatic(lastDimSize)) { + const unsigned int blockSize = + BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize()); + outputScaleShape[lastDimLoc] = lastDimSize / blockSize; + } + inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape)); + return success(); +} + +LogicalResult CastToBlockScaledOp::verify() { + const Type inputDataType = getInputData().getType(); + const Type outputDataType = getResult(0).getType(); + if (failed(verifyCompatibleShape(inputDataType, outputDataType))) + return emitOpError() << "require compatible shapes for input_data (" + << inputDataType << ") and " + << "output_data (" << outputDataType << ")"; + + const unsigned int blockSize = + BlockSizeAttr::getBlockSizeValue(getBlockSize()); + const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType); + if (inputDataShape.hasRank()) { + const int64_t inputDataLastDim = + inputDataShape.getDimSize(inputDataShape.getRank() - 1); + if (ShapedType::isStatic(inputDataLastDim) && + inputDataLastDim % blockSize != 0) + return emitOpError() << "expect last dimension of input_data (" + << inputDataLastDim + << ") to be divisible by block_size (" << blockSize + << ")"; + } + + const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType); + const Type outputScaleType = getResult(1).getType(); + const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType); + if (outputDataShape.hasRank() && outputScaleShape.hasRank()) { + SmallVector outputDataDims, outputScaleDims; + outputDataShape.getDims(outputDataDims); + outputScaleShape.getDims(outputScaleDims); + + if (outputDataDims.size() != outputScaleDims.size() || + failed(verifyCompatibleShape( + ArrayRef(outputDataDims).drop_back(1), + ArrayRef(outputScaleDims).drop_back(1)))) + return emitOpError() << "require compatible shapes for output_data (" + << outputDataType << ") and " + << "output_scale (" << outputScaleType + << ") except for the last dimension"; + + const int64_t outputDataLastDim = outputDataDims.back(); + const SmallVector dimsToCheck{outputDataLastDim / blockSize, + outputScaleDims.back()}; + if (ShapedType::isStatic(outputDataLastDim) && + failed(verifyCompatibleDims(dimsToCheck))) + return emitOpError() + << "expect last dimension of output_scale (" + << outputScaleDims.back() + << ") to be equal to last dimension of output_data / block_size (" + << outputDataDims.back() / blockSize << ")"; + } + + return success(); +} + LogicalResult IfOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, IfOp::Adaptor adaptor, diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp index 53afc5d9f01a6..ab363ee6b4d2a 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp @@ -51,10 +51,11 @@ TosaProfileCompliance::getProfileComplianceMap() { // Base populating function LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands, - Value output) { - for (auto operand : operands) + ValueRange results) { + for (const auto &operand : operands) addValue(operand); - addValue(output); + for (const auto &result : results) + addValue(result); return success(); } @@ -176,23 +177,6 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) { return success(); } -template <> -LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) { - addValue(op.getInputReal()); - addValue(op.getInputImag()); - addValue(op.getOutputReal()); - addValue(op.getOutputImag()); - return success(); -} - -template <> -LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) { - addValue(op.getInputReal()); - addValue(op.getOutputReal()); - addValue(op.getOutputImag()); - return success(); -} - template <> LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) { addValue(op.getOnTrue()); @@ -246,7 +230,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) { // This helper function populates the info for all operands. #define POPULATE_PROFILE_INFO_COMMON(tosaOp) \ if (isa(op)) { \ - return populateProfileInfo(op->getOperands(), op->getResult(0)); \ + return populateProfileInfo(op->getOperands(), op->getResults()); \ } // Skip irrelevant operands when they are independent and not tied to any @@ -257,8 +241,6 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) { POPULATE_PROFILE_INFO_CUSTOM(Conv3D) POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D) POPULATE_PROFILE_INFO_CUSTOM(Mul) - POPULATE_PROFILE_INFO_CUSTOM(FFT2d) - POPULATE_PROFILE_INFO_CUSTOM(RFFT2d) POPULATE_PROFILE_INFO_CUSTOM(Concat) POPULATE_PROFILE_INFO_CUSTOM(Pad) POPULATE_PROFILE_INFO_CUSTOM(Reshape) @@ -277,7 +259,11 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) { // For the most of tosa operators, all operands are profile/extension related // and hence are all considered in this profile-based compilance check. POPULATE_PROFILE_INFO_COMMON(MatmulTBlockScaled) + POPULATE_PROFILE_INFO_COMMON(FFT2d) + POPULATE_PROFILE_INFO_COMMON(RFFT2d) POPULATE_PROFILE_INFO_COMMON(Cast) + POPULATE_PROFILE_INFO_COMMON(CastFromBlockScaled) + POPULATE_PROFILE_INFO_COMMON(CastToBlockScaled) POPULATE_PROFILE_INFO_COMMON(Const) POPULATE_PROFILE_INFO_COMMON(ArgMax) POPULATE_PROFILE_INFO_COMMON(Sub) diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index cb544adb02d09..4d0b61acc4ea4 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -635,6 +635,8 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) { CHECK_RANKS_AND_SIZES(Transpose); // Type Conversion CHECK_RANKS_AND_SIZES(Cast); + CHECK_RANKS_AND_SIZES(CastFromBlockScaled); + CHECK_RANKS_AND_SIZES(CastToBlockScaled); CHECK_RANKS_AND_SIZES(Rescale); // Control Flow Operators CHECK_RANKS_AND_SIZES(If); diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir index c138ac9bab2c4..a05f42395778a 100644 --- a/mlir/test/Dialect/Tosa/availability.mlir +++ b/mlir/test/Dialect/Tosa/availability.mlir @@ -696,3 +696,21 @@ func.func @test_const_shape() -> !tosa.shape<4> { return %cst : !tosa.shape<4> } +// ----- +// CHECK-LABEL: test_cast_from_block_scaled +func.func @test_cast_from_block_scaled(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> { + // CHECK: profiles: [ [pro_fp] ] + // CHECK: extensions: [ [bf16, mxfp] ] + %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> + return %0 : tensor<4x32xf32> +} + +// ----- +// CHECK-LABEL: test_cast_to_block_scaled +func.func @test_cast_to_block_scaled(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) { + // CHECK: profiles: [ [pro_fp] ] + // CHECK: extensions: [ [bf16, mxfp] ] + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = BLOCK_SIZE_32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) + return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU> +} + diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir index ab048afd1ca0b..68a95787b81c7 100644 --- a/mlir/test/Dialect/Tosa/invalid_extension.mlir +++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir @@ -549,7 +549,6 @@ func.func @test_matmul_t_block_scaled(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: ten // ----- -// CHECK-LABEL: test_argmax_int64 func.func @test_argmax_int64(%arg0: tensor<1x13x13x5xf32>) -> tensor<1x13x13xi64> { // expected-error@+1 {{'tosa.argmax' op illegal: requires [int64] but not enabled in target}} %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x13x13x5xf32>) -> tensor<1x13x13xi64> @@ -569,3 +568,19 @@ func.func @test_cast_f4e2m1(%arg0: tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xb %0 = tosa.cast %arg0 : (tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> return %0 : tensor<13x21x3xbf16> } + +// ----- + +func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> { + // expected-error@+1 {{'tosa.cast_from_block_scaled' op illegal: requires [mxfp] but not enabled in target}} + %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor<4x32xf8E5M2>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> + return %0 : tensor<4x32xf32> +} + +// ----- + +func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>) { + // expected-error@+1 {{'tosa.cast_to_block_scaled' op illegal: requires [mxfp] but not enabled in target}} + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>) + return %0#0, %0#1 : tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU> +} diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index 8771e6e2476e4..a7087647e542b 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -1625,9 +1625,40 @@ func.func @test_unranked_weight_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor // ----- -// CHECK-LABEL: test_matmul_t_block_scaled_invalid_size func.func @test_matmul_t_block_scaled_invalid_size(%arg0: tensor<4x8x536870912xf4E2M1FN>, %arg1: tensor<4x8x16777216xf8E8M0FNU>, %arg2: tensor<4x16x536870912xf4E2M1FN>, %arg3: tensor<4x16x16777216xf8E8M0FNU>) -> tensor<*xf32> { // expected-error@+1 {{'tosa.matmul_t_block_scaled' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}} %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<4x8x536870912xf4E2M1FN>, tensor<4x8x16777216xf8E8M0FNU>, tensor<4x16x536870912xf4E2M1FN>, tensor<4x16x16777216xf8E8M0FNU>) -> tensor<*xf32> return %0 : tensor<*xf32> } + +// ----- + +func.func @test_cast_from_block_scaled_invalid_size(%arg0: tensor<67108864x32xf6E2M3FN>, %arg1: tensor<67108864x1xf8E8M0FNU>) -> tensor<67108864x32xf32> { + // expected-error@+1 {{'tosa.cast_from_block_scaled' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}} + %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor<67108864x32xf6E2M3FN>, tensor<67108864x1xf8E8M0FNU>) -> tensor<67108864x32xf32> + return %0 : tensor<67108864x32xf32> +} + +// ----- + +func.func @test_cast_from_block_scaled_invalid_rank(%arg0: tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, %arg1: tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>) -> tensor<1x2x3x4x5x6x7x32xf32> { + // expected-error@+1 {{'tosa.cast_from_block_scaled' op failed level check: operand rank(shape) <= MAX_RANK}} + %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size} : (tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>) -> tensor<1x2x3x4x5x6x7x32xf32> + return %0 : tensor<1x2x3x4x5x6x7x32xf32> +} + +// ----- + +func.func @test_cast_to_block_scaled_invalid_size(%arg0: tensor<67108864x32xf32>) -> (tensor<67108864x32xf6E2M3FN>, tensor<67108864x1xf8E8M0FNU>) { + // expected-error@+1 {{'tosa.cast_to_block_scaled' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}} + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<67108864x32xf32>) -> (tensor<67108864x32xf6E2M3FN>, tensor<67108864x1xf8E8M0FNU>) + return %0#0, %0#1 : tensor<67108864x32xf6E2M3FN>, tensor<67108864x1xf8E8M0FNU> +} + +// ----- + +func.func @test_cast_to_block_scaled_invalid_rank(%arg0: tensor<1x2x3x4x5x6x7x32xf32>) -> (tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>) { + // expected-error@+1 {{'tosa.cast_to_block_scaled' op failed level check: operand rank(shape) <= MAX_RANK}} + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<1x2x3x4x5x6x7x32xf32>) -> (tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>) + return %0#0, %0#1 : tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU> +} diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 9bf36b5fd4c7d..865f712ce1a5a 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -1268,3 +1268,31 @@ func.func @test_matmul_t_block_scaled_broadcast(%arg0: tensor, %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor, tensor<4x8x1xf8E8M0FNU>, tensor, tensor<1x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> return %0 : tensor<4x8x16xf32> } + +// ----- +// CHECK-LABEL: test_cast_from_block_scaled_static +func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> { + %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> + return %0 : tensor<4x32xf32> +} + +// ----- +// CHECK-LABEL: test_cast_from_block_scaled_unranked +func.func @test_cast_from_block_scaled_unranked(%arg0: tensor<*xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>) -> tensor<*xf32> { + %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- +// CHECK-LABEL: test_cast_to_block_scaled_static +func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) { + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) + return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU> +} + +// ----- +// CHECK-LABEL: test_cast_to_block_scaled_unranked +func.func @test_cast_to_block_scaled_unranked(%arg0: tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) { + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) + return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU> +} diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir index 0271d71561a52..7de7b85bcaedf 100644 --- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir @@ -332,3 +332,17 @@ func.func @test_matmul_t_block_scaled(%arg0: tensor<4x8x32xf6E3M2FN>, %arg1: ten %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<4x8x32xf6E3M2FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf6E3M2FN>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> return %0 : tensor<4x8x16xf32> } + +// ----- +func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> { + // expected-error@+1 {{'tosa.cast_from_block_scaled' op illegal: requires [pro_fp] but not enabled in target}} + %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> + return %0 : tensor<4x32xf32> +} + +// ----- +func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) { + // expected-error@+1 {{'tosa.cast_to_block_scaled' op illegal: requires [pro_fp] but not enabled in target}} + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) + return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU> +} diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index 72479fe21ade8..54556a0eb08e0 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -1628,3 +1628,48 @@ func.func @test_matmul_t_block_scaled_broadcast_b_scale(%arg0: tensor<*xf8E4M3FN %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<*xf8E4M3FN>, tensor<*xf8E8M0FNU>, tensor<*xf8E4M3FN>, tensor<1x4x1xf8E8M0FNU>) -> tensor return %0 : tensor } + +// ----- + +// CHECK-LABEL: test_cast_from_block_scaled_static +func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<*xf32> { + // CHECK: -> tensor<4x32xf32> + %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +// CHECK-LABEL: test_cast_from_block_scaled_unranked_input_scale +func.func @test_cast_from_block_scaled_unranked_input_scale(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>) -> tensor<*xf32> { + // CHECK: -> tensor<4x32xf32> + %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor<4x32xf4E2M1FN>, tensor<*xf8E8M0FNU>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +// CHECK-LABEL: test_cast_to_block_scaled_static +func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) { + // CHECK: -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<4x32xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) + return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU> +} + +// ----- + +// CHECK-LABEL: test_cast_to_block_scaled_unranked +func.func @test_cast_to_block_scaled_unranked(%arg0: tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) { + // CHECK: -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) + return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU> +} + +// ----- + +// CHECK-LABEL: test_cast_to_block_scaled_dynamic_scales +func.func @test_cast_to_block_scaled_dynamic_scales(%arg0: tensor<4x?xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) { + // CHECK: -> (tensor<4x?xf4E2M1FN>, tensor<4x?xf8E8M0FNU>) + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<4x?xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) + return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU> +} diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir index 926e7f2798c23..f3d8dab2f6b0f 100644 --- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir +++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir @@ -58,3 +58,27 @@ func.func @test_cast_f4e2m1(%arg0: tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xb %0 = tosa.cast %arg0 : (tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> return %0 : tensor<13x21x3xbf16> } + +// ----- + +// CHECK-LABEL: test_cast_from_block_scaled_fp8e5m2_fp32 +func.func @test_cast_from_block_scaled_fp8e5m2_fp32(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> { + %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor<4x32xf8E5M2>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> + return %0 : tensor<4x32xf32> +} + +// ----- + +// CHECK-LABEL: test_cast_from_block_scaled_fp8e5m2_bf16 +func.func @test_cast_from_block_scaled_fp8e5m2_bf16(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xbf16> { + %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor<4x32xf8E5M2>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xbf16> + return %0 : tensor<4x32xbf16> +} + +// ----- + +// CHECK-LABEL: test_cast_to_block_scaled_static +func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>) { + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>) + return %0#0, %0#1 : tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU> +} diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir index 4be5d725ad612..6cf76cdc7ad8e 100644 --- a/mlir/test/Dialect/Tosa/verifier.mlir +++ b/mlir/test/Dialect/Tosa/verifier.mlir @@ -1033,7 +1033,6 @@ module { // ----- -// CHECK-LABEL: @scatter_invalid_indices_N func.func @scatter_invalid_indices_N(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<3x2xi32>, %arg2 : tensor<2x2x5xi32>) { // expected-error@+1 {{'tosa.scatter' op requires indices dimension 0 to have size 2, got 3}} %1 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<3x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x5xi32> @@ -1042,7 +1041,6 @@ func.func @scatter_invalid_indices_N(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<3 // ----- -// CHECK-LABEL: @scatter_invalid_input_N func.func @scatter_invalid_input_N(%arg0 : tensor, %arg1 : tensor<2x2xi32>, %arg2 : tensor<3x2x5xi32>) { // expected-error@+1 {{'tosa.scatter' op requires input dimension 0 to have size 2, got 3}} %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor, tensor<2x2xi32>, tensor<3x2x5xi32>) -> tensor<2x4x5xi32> @@ -1051,7 +1049,6 @@ func.func @scatter_invalid_input_N(%arg0 : tensor, %arg1 : tensor<2x2 // ----- -// CHECK-LABEL: @scatter_invalid_out_N func.func @scatter_invalid_out_N(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor<2x2x5xi32>) { // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 0 to have size 2, got 3}} %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor, tensor, tensor<2x2x5xi32>) -> tensor<3x4x5xi32> @@ -1060,7 +1057,6 @@ func.func @scatter_invalid_out_N(%arg0 : tensor, %arg1 : tensor, %arg1 : tensor, %arg2 : tensor<2x2x5xi32>) { // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 1 to have size 4, got 3}} %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor, tensor, tensor<2x2x5xi32>) -> tensor<2x3x5xi32> @@ -1069,7 +1065,6 @@ func.func @scatter_invalid_out_K(%arg0 : tensor, %arg1 : tensor, %arg1 : tensor, %arg2 : tensor<2x3x5xi32>) { // expected-error@+1 {{'tosa.scatter' op requires input dimension 1 to have size 2, got 3}} %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor, tensor, tensor<2x3x5xi32>) -> tensor<2x4x5xi32> @@ -1078,7 +1073,6 @@ func.func @scatter_invalid_input_W(%arg0 : tensor, %arg1 : tensor, %arg1 : tensor, %arg2 : tensor<2x2x6xi32>) { // expected-error@+1 {{'tosa.scatter' op requires input dimension 2 to have size 5, got 6}} %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor, tensor, tensor<2x2x6xi32>) -> tensor<2x4x5xi32> @@ -1087,7 +1081,6 @@ func.func @scatter_invalid_input_C(%arg0 : tensor, %arg1 : tensor, %arg1 : tensor, %arg2 : tensor<2x2x5xi32>) { // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 2 to have size 5, got 6}} %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor, tensor, tensor<2x2x5xi32>) -> tensor<2x4x6xi32> @@ -1096,7 +1089,6 @@ func.func @scatter_invalid_out_C(%arg0 : tensor, %arg1 : tensor, %arg1 : tensor<2x6xi32>, %arg2 : tensor<2x6x5xi32>) { // expected-error@+1 {{'tosa.scatter' op requires dimensions K >= W, got K=4 and W=6}} %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<2x6xi32>, tensor<2x6x5xi32>) -> tensor<2x4x5xi32> @@ -1150,3 +1142,83 @@ func.func @test_matmul_t_block_scaled_batch_mismatch(%arg0: tensor<4x8x32xf8E4M3 %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size : i32} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<2x16x32xf8E4M3FN>, tensor<2x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> return %0 : tensor<4x8x16xf32> } + +// ----- + +func.func @cast_from_block_scaled_incompatible_input_output_shape(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<5x32xf32> { + // expected-error@+1 {{'tosa.cast_from_block_scaled' op require compatible shapes for input_data ('tensor<4x32xf4E2M1FN>') and output_data ('tensor<5x32xf32>')}} + %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<5x32xf32> + return %0 : tensor<5x32xf32> +} + +// ----- + +func.func @cast_from_block_scaled_not_scalar(%arg0: tensor, %arg1: tensor) -> tensor { + // expected-error@+1 {{'tosa.cast_from_block_scaled' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor'}} + %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +func.func @cast_from_block_scaled_not_divisible_by_block_size(%arg0: tensor<4x33xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x33xf32> { + // expected-error@+1 {{'tosa.cast_from_block_scaled' op expect last dimension of input_data (33) to be divisible by block_size (32)}} + %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor<4x33xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x33xf32> + return %0 : tensor<4x33xf32> +} + +// ----- + +func.func @cast_from_block_scaled_data_scale_mismatch(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<5x1xf8E8M0FNU>) -> tensor<4x32xf32> { + // expected-error@+1 {{'tosa.cast_from_block_scaled' op require compatible shapes for input_data ('tensor<4x32xf4E2M1FN>') and input_scale ('tensor<5x1xf8E8M0FNU>') except for the last dimension}} + %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size : i32} : (tensor<4x32xf4E2M1FN>, tensor<5x1xf8E8M0FNU>) -> tensor<4x32xf32> + return %0 : tensor<4x32xf32> +} + +// ----- + +func.func @cast_from_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x2xf8E8M0FNU>) -> tensor<4x32xf32> { + // expected-error@+1 {{'tosa.cast_from_block_scaled' op expect last dimension of input_scale (2) to be equal to last dimension of input_data / block_size (1)}} + %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size} : (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>) -> tensor<4x32xf32> + return %0 : tensor<4x32xf32> +} + +// ----- + +func.func @test_cast_to_block_scaled_incompatible_input_output_shape(%arg0: tensor<4x32xf32>) -> (tensor<5x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) { + // expected-error@+1 {{'tosa.cast_to_block_scaled' op require compatible shapes for input_data ('tensor<4x32xf32>') and output_data ('tensor<5x32xf4E2M1FN>')}} + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<4x32xf32>) -> (tensor<5x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) + return %0#0, %0#1 : tensor<5x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU> +} + +// ----- + +func.func @test_cast_to_block_scaled_not_scalar(%arg0: tensor) -> (tensor, tensor) { + // expected-error@+1 {{'tosa.cast_to_block_scaled' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor'}} + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor) -> (tensor, tensor) + return %0#0, %0#1 : tensor, tensor +} + +// ----- + +func.func @test_cast_to_block_scaled_not_divisible_by_block_size(%arg0: tensor<4x33xf32>) -> (tensor<4x33xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) { + // expected-error@+1 {{'tosa.cast_to_block_scaled' op expect last dimension of input_data (33) to be divisible by block_size (32)}} + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<4x33xf32>) -> (tensor<4x33xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) + return %0#0, %0#1 : tensor<4x33xf4E2M1FN>, tensor<4x1xf8E8M0FNU> +} + +// ----- + +func.func @test_cast_to_block_scaled_data_scale_mismatch(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<5x1xf8E8M0FNU>) { + // expected-error@+1 {{'tosa.cast_to_block_scaled' op require compatible shapes for output_data ('tensor<4x32xf4E2M1FN>') and output_scale ('tensor<5x1xf8E8M0FNU>') except for the last dimension}} + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<5x1xf8E8M0FNU>) + return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<5x1xf8E8M0FNU> +} + +// ----- + +func.func @test_cast_to_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>) { + // expected-error@+1 {{'tosa.cast_to_block_scaled' op expect last dimension of output_scale (2) to be equal to last dimension of output_data / block_size (1)}} + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>) + return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU> +}