diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 342634364cc10..2e280dba469c9 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -471,16 +471,33 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, } if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) { - auto intMin = rewriter.create( + Value intMin = rewriter.create( loc, rewriter.getF32FloatAttr( APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue())); - auto intMax = rewriter.create( + Value intMax = rewriter.create( loc, rewriter.getF32FloatAttr( APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue())); + // Since F32 constants are created, we may still need to convert them to + // the correct type. + auto convertType = [&](Type ty, Value arg) { + auto argTy = arg.getType(); + bool bitExtend = + argTy.getIntOrFloatBitWidth() < ty.getIntOrFloatBitWidth(); + if (ty != argTy) { + if (!bitExtend) + arg = rewriter.create(loc, ty, arg); + else + arg = rewriter.create(loc, ty, arg); + } + return arg; + }; + intMin = convertType(srcTy, intMin); + intMax = convertType(srcTy, intMax); + auto rounded = rewriter.create(loc, args[0]); auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter); diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 6483e29e7a9c2..70d09cde7bc7f 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -270,6 +270,17 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () { // CHECK: arith.extf %0 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xf32> + // CHECK: linalg.generic + // CHECK: %[[C_LOWEST:.+]] = arith.constant -2.14748365E+9 + // CHECK: %[[C_MAX:.+]] = arith.constant 2.14748365E+9 + // CHECK: arith.truncf %[[C_LOWEST]] : f32 to f16 + // CHECK: arith.truncf %[[C_MAX]] : f32 to f16 + // CHECK: math.roundeven + // CHECK: arith.minf + // CHECK: arith.maxf + // CHECK: arith.fptosi + %1 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi32> + return }