From 96193efdd1957358f700fde15f4dbc47cde58eaa Mon Sep 17 00:00:00 2001 From: Spenser Bauman Date: Wed, 10 Jul 2024 13:21:45 -0400 Subject: [PATCH] [mlir][tosa] Fix for incorrect cannonicalization of tosa.pad The current fold method for tosa.pad can produce invalid IR by replacing the padded value with the tosa.pad is a noop. When the type of the input value does not match the type of the tosa.pad, the canonicalizer detects the change in types and asserts. --- mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 2 +- mlir/test/Dialect/Tosa/canonicalize.mlir | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 8687be075ea67..866ab0d2228f7 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -859,7 +859,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { OpFoldResult PadOp::fold(FoldAdaptor adaptor) { // If the pad is all zeros we can fold this operation away. - if (adaptor.getPadding()) { + if (adaptor.getPadding() && getInput1().getType() == getType()) { auto densePad = llvm::cast(adaptor.getPadding()); if (densePad.isSplat() && densePad.getSplatValue().isZero()) { return getInput1(); diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index accc792c8f2ac..3bcf58015831b 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -217,6 +217,20 @@ func.func @pad_noop(%arg0: tensor) -> tensor { // ----- +// CHECK-LABEL: @pad_noop_type_mismatch_nofold +func.func @pad_noop_type_mismatch_nofold(%arg0: tensor<10xf32>) -> tensor { + // CHECK: %[[PAD:.+]] = tosa.pad + // CHECK: return %[[PAD]] + + %c0_i32 = arith.constant 0 : i32 + %shape = tensor.from_elements %c0_i32, %c0_i32 : tensor<1x2xi32> + + %0 = tosa.pad %arg0, %shape : (tensor<10xf32>, tensor<1x2xi32>) -> tensor + return %0 : tensor +} + +// ----- + // CHECK-LABEL: @pad_determine_val_i32 func.func @pad_determine_val_i32(%arg0: tensor, %arg1 : tensor<2x2xi32>) -> tensor { // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor}