From 0c68ac7bcab973e9e5b4d265379857f13ce49a35 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Mon, 6 May 2024 12:06:48 +0200 Subject: [PATCH] [mlir] [TOSA] Allow any floating point type --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 6 +++--- .../mlir/Dialect/Tosa/IR/TosaTypesBase.td | 21 ++++--------------- .../Tosa/Transforms/TosaValidation.cpp | 9 ++++---- mlir/test/Dialect/Tosa/invalid.mlir | 2 +- mlir/test/Dialect/Tosa/level_check.mlir | 8 +++++++ 5 files changed, 20 insertions(+), 26 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 97a36c49d01b3..7871b46724a03 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1857,11 +1857,11 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure, }]; let arguments = (ins - Tosa_Tensor_Plus_F64:$input + Tosa_Tensor:$input ); let results = (outs - Tosa_Tensor_Plus_F64:$output + Tosa_Tensor:$output ); let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; @@ -1944,7 +1944,7 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure, ); let results = (outs - TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Plus_F64]>]>:$output + TensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output ); let hasFolder = 1; diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index 3687891fe4b7c..14fc9c7a6730c 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -71,28 +71,16 @@ def Tosa_QuantizedInt : AnyTypeOf<[ Tosa_QuantizedType<"uint8", [8], 0>, Tosa_QuantizedType<"int16", [16, 0], 1>, Tosa_QuantizedType<"int32", [32, 0], 1>]>; -//===----------------------------------------------------------------------===// -// Floating-point types. -//===----------------------------------------------------------------------===// -def Tosa_Float : AnyTypeOf<[ - F32, - F16, - BF16]>; - //===----------------------------------------------------------------------===// // Multi-category types. //===----------------------------------------------------------------------===// -def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float], +def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat], "number">; -// Add F64 type support just for tosa::CastOp and tosa::ConstOp -def Tosa_AnyNumber_Plus_F64 : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float, F64], - "number_plus_f64">; - // For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp, // tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp, tosa::FullyConnectedOp def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8, - Tosa_QuantizedInt, Tosa_Float]>; + Tosa_QuantizedInt, AnyFloat]>; //===----------------------------------------------------------------------===// // Tensor types @@ -101,18 +89,17 @@ def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8, def Tosa_Int32Tensor : TensorOf<[Tosa_Int32]>; def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>; -def Tosa_FloatTensor : TensorOf<[Tosa_Float]>; +def Tosa_FloatTensor : TensorOf<[AnyFloat]>; // Either ranked or unranked tensor of TOSA supported element types. def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>; -def Tosa_Tensor_Plus_F64 : TensorOf<[Tosa_AnyNumber_Plus_F64]>; // Must be ranked but no further constraints def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>; // Any tensor element type allowed in Tosa ops. def Tosa_ElementType : Type, "tosa.dtype">; + AnyFloat.predicate]>, "tosa.dtype">; class Tosa_TensorOfOrNone allowedTypes, string description = ""> : AnyTypeOf<[TensorOf, NoneType], description>; diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 539501082fd3f..b78c372af77e6 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -506,11 +506,10 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) { } bool TosaValidation::isValidElementType(Type type) { - if ((profile == TosaProfileEnum::BaseInference) && isa(type)) { - return false; - } - if (type.isF64()) { - return false; + if (isa(type)) { + if (profile == TosaProfileEnum::BaseInference) + return false; + return type.isF32() || type.isF16() || type.isBF16(); } if (auto intTy = dyn_cast(type)) { if (intTy.isUnsigned()) { diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 730ac41dd7a8d..cb38d4d81ca2e 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -20,7 +20,7 @@ func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: t // ----- func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> { - // expected-error@+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or 32-bit float or 16-bit float or bfloat16 type values, but got 'tensor<*xi8>'}} + // expected-error@+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point values, but got 'tensor<*xi8>'}} %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array, pad = array, stride = array} : (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8> return %0 : tensor<1x27x27x16xi8> diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index d8dd878051f18..9b652f2d0bd14 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -131,6 +131,14 @@ func.func @test_const_ui32(%arg0 : tensor<1xui32>) { // ----- +func.func @test_const_f64(%arg0 : tensor<1xf64>) { + // expected-error@+1 {{'tosa.const' op is not profile-aligned: element type 'f64' is not legal}} + %0 = "tosa.const"() {value = dense<0.0> : tensor<1xf64>} : () -> tensor<1xf64> + return +} + +// ----- + func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}} %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array, acc_type = f32} :