diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 60cf897b00de3..50593b08ad74b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1656,8 +1656,8 @@ void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, } void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) { - // TODO: Add and test patterns for tensor.unpack patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); } void linalg::populateDecomposePadPatterns(RewritePatternSet &patterns) { diff --git a/mlir/test/Dialect/Linalg/decompose-tensor-unpack-tile.mlir b/mlir/test/Dialect/Linalg/decompose-tensor-unpack-tile.mlir index 6d9709caf7093..0dbdf470bbfc9 100644 --- a/mlir/test/Dialect/Linalg/decompose-tensor-unpack-tile.mlir +++ b/mlir/test/Dialect/Linalg/decompose-tensor-unpack-tile.mlir @@ -1,4 +1,7 @@ -// RUN: mlir-opt -split-input-file --transform-interpreter --canonicalize --test-linalg-transform-patterns="test-decompose-tensor-unpack" %s | FileCheck %s +// RUN: mlir-opt -split-input-file -transform-interpreter --canonicalize \ +// RUN: -transform-preload-library='transform-library-paths=%p/td/decompose-unpack.mlir' \ +// RUN: -transform-interpreter=entry-point=decompose_unpack \ +// RUN: -transform-interpreter %s | FileCheck %s func.func @KCRSsr_to_KCRS(%arg0: tensor<1x1x4x8x8x32xf32>, %arg1: tensor<1x1x128x64xf32>) -> tensor<1x1x128x64xf32> { %0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x4x8x8x32xf32> -> tensor<1x1x128x64xf32> diff --git a/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir b/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir index bd60504f53345..ba1f214952562 100644 --- a/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir +++ b/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir @@ -1,4 +1,6 @@ -// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-decompose-tensor-unpack" %s | FileCheck %s +// RUN: mlir-opt -split-input-file \ +// RUN: -transform-preload-library='transform-library-paths=%p/td/decompose-unpack.mlir' \ +// RUN: -transform-interpreter=entry-point=decompose_unpack %s | FileCheck %s func.func @simple_KCRSsr_to_KCRS(%arg0: tensor<1x1x1x1x8x32xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32> { %0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x1x1x8x32xf32> -> tensor<1x1x32x8xf32> diff --git a/mlir/test/Dialect/Linalg/td/decompose-unpack.mlir b/mlir/test/Dialect/Linalg/td/decompose-unpack.mlir new file mode 100644 index 0000000000000..11243634262e0 --- /dev/null +++ b/mlir/test/Dialect/Linalg/td/decompose-unpack.mlir @@ -0,0 +1,12 @@ +module @transforms attributes { transform.with_named_sequence } { + transform.named_sequence @decompose_unpack(%module: !transform.any_op {transform.readonly}) { + %pack = transform.structured.match ops{["tensor.unpack"]} in %module : (!transform.any_op) -> !transform.any_op + + %1 = transform.get_parent_op %pack {isolated_from_above} : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %1 { + transform.apply_patterns.linalg.decompose_pack_unpack + } : !transform.any_op + + transform.yield + } +}