From 0418e51cf33bc59cc6f19ed00edc8c2d62e4d9df Mon Sep 17 00:00:00 2001 From: Sam Date: Sat, 15 Jun 2024 10:46:44 -0500 Subject: [PATCH 01/14] implement canonicalizer for batched linalg operations --- .../Linalg/IR/LinalgNamedStructuredOps.yaml | 49 ++----- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 121 ++++++++++++++++++ .../linalg/opdsl/ops/core_named_ops.py | 5 + 3 files changed, 138 insertions(+), 37 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml index fad234a9dcae9..3cbfb58ed8506 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -304,41 +304,6 @@ structured_op: !LinalgStructuredOpConfig - !ScalarExpression scalar_arg: I --- !LinalgOpConfig -metadata: !LinalgOpMetadata - name: reciprocal - cpp_class_name: ReciprocalOp - doc: |- - Applies reciprocal(x) elementwise. - - No numeric casting is performed on the input operand. -structured_op: !LinalgStructuredOpConfig - args: - - !LinalgOperandDefConfig - name: I - kind: input_tensor - type_var: T1 - shape_map: affine_map<() -> ()> - - !LinalgOperandDefConfig - name: O - kind: output_tensor - type_var: T1 - shape_map: affine_map<() -> ()> - indexing_maps: !LinalgIndexingMapsConfig - static_indexing_maps: - - affine_map<() -> ()> - - affine_map<() -> ()> - iterator_types: [] - assignments: - - !ScalarAssign - arg: O - value: !ScalarExpression - scalar_fn: - kind: unary - fn_name: reciprocal - operands: - - !ScalarExpression - scalar_arg: I ---- !LinalgOpConfig metadata: !LinalgOpMetadata name: round cpp_class_name: RoundOp @@ -516,7 +481,7 @@ structured_op: !LinalgStructuredOpConfig --- !LinalgOpConfig metadata: !LinalgOpMetadata name: erf - cpp_class_name: erfOp + cpp_class_name: ErfOp doc: |- Applies erf(x) elementwise. @@ -959,7 +924,7 @@ structured_op: !LinalgStructuredOpConfig --- !LinalgOpConfig metadata: !LinalgOpMetadata name: powf - cpp_class_name: PowFOp + cpp_class_name: PowfOp doc: |- Takes the powf(lhs, rhs) between two inputs, elementwise. For powf(arg, 2) use `linalg.square`. @@ -1622,6 +1587,8 @@ metadata: !LinalgOpMetadata them to the same data type as the accumulator/output. implements: - LinalgContractionOpInterface + defines: + - hasCanonicalizer structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -1692,6 +1659,8 @@ metadata: !LinalgOpMetadata them to the same data type as the accumulator/output. implements: - LinalgContractionOpInterface + defines: + - hasCanonicalizer structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -1762,6 +1731,8 @@ metadata: !LinalgOpMetadata them to the same data type as the accumulator/output. implements: - LinalgContractionOpInterface + defines: + - hasCanonicalizer structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -2140,6 +2111,8 @@ metadata: !LinalgOpMetadata them to the same data type as the accumulator/output. implements: - LinalgContractionOpInterface + defines: + - hasCanonicalizer structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig @@ -2208,6 +2181,8 @@ metadata: !LinalgOpMetadata them to the same data type as the accumulator/output. 
implements: - LinalgContractionOpInterface + defines: + - hasCanonicalizer structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index b79afebfa8158..ecd669165efc7 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -42,6 +43,7 @@ #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" +#include #include using namespace mlir; @@ -578,6 +580,125 @@ class RegionBuilderHelper { } // namespace +//===----------------------------------------------------------------------===// +// BatchMatmulOp +//===----------------------------------------------------------------------===// + +namespace { + +template +struct BatchMatmulToMatmul : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(BatchOpTy batchMatmulOp, + PatternRewriter &rewriter) const override { + + auto loc = batchMatmulOp.getLoc(); + auto inputs = batchMatmulOp.getDpsInputs(); + auto inits = batchMatmulOp.getDpsInits(); + if (inputs.size() != 2 || inits.size() != 1) + return rewriter.notifyMatchFailure(batchMatmulOp, + "expected 2 inputs and 1 init"); + auto lhs = inputs[0]; + auto rhs = inputs[1]; + auto init = inits[0]; + + auto lhsType = cast(lhs.getType()); + auto rhsType = cast(rhs.getType()); + auto initType = cast(init.getType()); + if (ShapedType::isDynamic(lhsType.getShape()[0]) || + lhsType.getShape()[0] != rhsType.getShape()[0] || + rhsType.getShape()[0] != initType.getShape()[0]) + return rewriter.notifyMatchFailure( + batchMatmulOp, "expected batch sizes of all operands to be same"); + + auto results = batchMatmulOp.getResults(); + if (results.size() > 1) + return rewriter.notifyMatchFailure(batchMatmulOp, + "expected at most one result"); + + SmallVector resultType; + if (results.size() == 1) { + auto oldResultType = cast(results[0].getType()); + resultType.push_back( + RankedTensorType::get(oldResultType.getShape().drop_front(1), + oldResultType.getElementType())); + } + + auto collapseSingletonDim = [&](Value val) -> Value { + SmallVector reassociation({{0, 1}}); + auto valType = cast(val.getType()); + for (auto i = 2; i < valType.getRank(); i++) + reassociation.push_back({i}); + if (isa(valType)) { + RankedTensorType collapsedType = RankedTensorType::get( + valType.getShape().drop_front(1), valType.getElementType()); + return rewriter.create(loc, collapsedType, val, + reassociation); + } + MemRefType collapsedType = MemRefType::get( + valType.getShape().drop_front(1), valType.getElementType()); + return rewriter.create(loc, collapsedType, val, + reassociation); + }; + + auto collapsedLhs = collapseSingletonDim(lhs); + auto collapsedRhs = collapseSingletonDim(rhs); + auto collapsedInit = collapseSingletonDim(init); + + auto collapsedOp = rewriter.create( + loc, resultType, ValueRange{collapsedLhs, collapsedRhs}, + ValueRange{collapsedInit}); + for (auto attr : batchMatmulOp->getAttrs()) { + if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName) + continue; + collapsedOp->setAttr(attr.getName(), attr.getValue()); + } + + if 
(results.size() < 1) { + rewriter.replaceOp(batchMatmulOp, collapsedOp); + } else { + SmallVector reassociation({{0, 1}}); + auto resultType = cast(results[0].getType()); + for (auto i = 2; i < resultType.getRank(); i++) + reassociation.push_back({i}); + Value expandedResult = rewriter.create( + loc, resultType, collapsedOp.getResultTensors()[0], reassociation); + rewriter.replaceOp(batchMatmulOp, expandedResult); + } + + return success(); + } +}; + +} // namespace + +void BatchMatmulOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add>(context); +} + +void BatchMatmulTransposeAOp::getCanonicalizationPatterns( + RewritePatternSet &results, MLIRContext *context) { + results.add>( + context); +} + +void BatchMatmulTransposeBOp::getCanonicalizationPatterns( + RewritePatternSet &results, MLIRContext *context) { + results.add>( + context); +} + +void BatchMatvecOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add>(context); +} + +void BatchVecmatOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add>(context); +} + //===----------------------------------------------------------------------===// // CopyOp //===----------------------------------------------------------------------===// diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 43410aaa6af1b..b4b36ba0bfe51 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -518,6 +518,7 @@ def batch_matmul( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + defines(Canonicalizer) domain(D.b, D.m, D.n, D.k) implements(ContractionOpInterface) C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( @@ -537,6 +538,7 @@ def batch_matmul_transpose_a( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + defines(Canonicalizer) domain(D.b, D.m, D.n, D.k) implements(ContractionOpInterface) C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) * TypeFn.cast_signed( @@ -556,6 +558,7 @@ def batch_matmul_transpose_b( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + defines(Canonicalizer) domain(D.b, D.m, D.n, D.k) implements(ContractionOpInterface) C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( @@ -642,6 +645,7 @@ def batch_matvec( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + defines(Canonicalizer) domain(D.b, D.m, D.k) implements(ContractionOpInterface) C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( @@ -660,6 +664,7 @@ def batch_vecmat( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. 
""" + defines(Canonicalizer) domain(D.b, D.n, D.k) implements(ContractionOpInterface) C[D.b, D.n] += TypeFn.cast_signed(U, A[D.b, D.k]) * TypeFn.cast_signed( From 02b2ca083d145fc88a9498480e4a831affdebf10 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 16 Jun 2024 12:01:33 -0500 Subject: [PATCH 02/14] add tests --- .../Linalg/IR/LinalgNamedStructuredOps.yaml | 34 +++++ mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 12 +- mlir/test/Dialect/Linalg/canonicalize.mlir | 137 +++++++++++++++++- 3 files changed, 174 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml index 3cbfb58ed8506..41f90483c93b3 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -304,6 +304,40 @@ structured_op: !LinalgStructuredOpConfig - !ScalarExpression scalar_arg: I --- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: reciprocal + cpp_class_name: ReciprocalOp + doc: |- + Applies reciprocal(x) elementwise. + No numeric casting is performed on the input operand. +structured_op: !LinalgStructuredOpConfig + args: + - !LinalgOperandDefConfig + name: I + kind: input_tensor + type_var: T1 + shape_map: affine_map<() -> ()> + - !LinalgOperandDefConfig + name: O + kind: output_tensor + type_var: T1 + shape_map: affine_map<() -> ()> + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<() -> ()> + - affine_map<() -> ()> + iterator_types: [] + assignments: + - !ScalarAssign + arg: O + value: !ScalarExpression + scalar_fn: + kind: unary + fn_name: reciprocal + operands: + - !ScalarExpression + scalar_arg: I +--- !LinalgOpConfig metadata: !LinalgOpMetadata name: round cpp_class_name: RoundOp diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index ecd669165efc7..4e47b6018c445 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -605,16 +605,12 @@ struct BatchMatmulToMatmul : OpRewritePattern { auto lhsType = cast(lhs.getType()); auto rhsType = cast(rhs.getType()); auto initType = cast(init.getType()); - if (ShapedType::isDynamic(lhsType.getShape()[0]) || - lhsType.getShape()[0] != rhsType.getShape()[0] || - rhsType.getShape()[0] != initType.getShape()[0]) - return rewriter.notifyMatchFailure( - batchMatmulOp, "expected batch sizes of all operands to be same"); + if (lhsType.getShape()[0] != 1 || rhsType.getShape()[0] != 1 || + initType.getShape()[0] != 1) + return rewriter.notifyMatchFailure(batchMatmulOp, "batch size is not 1"); auto results = batchMatmulOp.getResults(); - if (results.size() > 1) - return rewriter.notifyMatchFailure(batchMatmulOp, - "expected at most one result"); + assert(results.size() < 2 && "expected at most one result"); SmallVector resultType; if (results.size() == 1) { diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 928030a81dc02..8514bcb089891 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -1017,7 +1017,7 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>) return %0 : tensor<2x3xf32> } -// ---- +// ----- func.func @transpose_1d(%input: tensor<16xf32>, %init: tensor<16xf32>) -> tensor<16xf32> { @@ -1096,3 +1096,138 @@ func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>, func.return %transpose2 : 
tensor<3x4x5xf32>
 }
 
+// -----
+
+func.func @singleton_batch_matmul_tensor(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?x?xf32>) -> tensor<1x?x?xf32> {
+  // CHECK-LABEL: @singleton_batch_matmul_tensor
+  // CHECK-SAME:  %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  // CHECK-SAME:  %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  // CHECK-SAME:  %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1
+  // CHECK-DAG:   %[[C2:.*]] = arith.constant 2
+  // CHECK-NEXT:  %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  %[[MATMUL:.+]] = linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?x?xf32>)
+  // CHECK-NEXT:  %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+  // CHECK-NEXT:  %[[DIM2:.*]] = tensor.dim %[[INIT]], %[[C2]]
+  // CHECK-NEXT:  %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1], [2]] output_shape [1, %[[DIM1]], %[[DIM2]]]
+  // CHECK-NEXT:  return %[[RES]]
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
+        outs(%arg2 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32>
+  return %1 : tensor<1x?x?xf32>
+}
+
+// -----
+
+func.func @singleton_batch_matmul_memref(%arg0 : memref<1x?x?xf32>, %arg1 : memref<1x?x?xf32>, %arg2: memref<1x?x?xf32>) {
+  // CHECK-LABEL: @singleton_batch_matmul_memref
+  // CHECK-SAME:  %[[LHS:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+  // CHECK-SAME:  %[[RHS:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+  // CHECK-SAME:  %[[INIT:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+  // CHECK-NEXT:  %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<?x?xf32>, memref<?x?xf32>) outs(%[[COLLAPSED_INIT]] : memref<?x?xf32>)
+  // CHECK-NEXT:  return
+  linalg.batch_matmul ins(%arg0, %arg1 : memref<1x?x?xf32>, memref<1x?x?xf32>)
+        outs(%arg2 : memref<1x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @singleton_batch_matvec(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
+  // CHECK-LABEL: @singleton_batch_matvec
+  // CHECK-SAME:  %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  // CHECK-SAME:  %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  // CHECK-SAME:  %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1
+  // CHECK-NEXT:  %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]]
+  // CHECK-NEXT:  %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+  // CHECK-NEXT:  %[[MATMUL:.+]] = linalg.matvec ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+  // CHECK-NEXT:  %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+  // CHECK-NEXT:  %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
+  // CHECK-NEXT:  return %[[RES]]
+  %1 = linalg.batch_matvec ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?xf32>)
+        outs(%arg2 : tensor<1x?xf32>) -> tensor<1x?xf32>
+  return %1 : tensor<1x?xf32>
+}
+
+// -----
+
+func.func @singleton_batch_vecmat(%arg0 : tensor<1x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
+  // CHECK-LABEL: @singleton_batch_vecmat
+  // CHECK-SAME:  %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  // CHECK-SAME:  %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  // CHECK-SAME:  %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1
+  // CHECK-NEXT:  %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
+  // CHECK-NEXT:  %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+  // CHECK-NEXT:  %[[MATMUL:.+]] = linalg.vecmat ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+  // CHECK-NEXT:  %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+  // CHECK-NEXT:  %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
+  // CHECK-NEXT:  return %[[RES]]
+  %1 = linalg.batch_vecmat ins(%arg0, %arg1 : tensor<1x?xf32>, tensor<1x?x?xf32>)
+        outs(%arg2 : tensor<1x?xf32>) -> tensor<1x?xf32>
+  return %1 : tensor<1x?xf32>
+}
+
+// -----
+
+func.func @singleton_batchmatmul_transpose_a(%arg0: memref<1x5x3xf32>, %arg1: memref<1x5x7xf32>, %arg2: memref<1x3x7xf32>) {
+  // CHECK-LABEL: @singleton_batchmatmul_transpose_a
+  // CHECK-SAME:  %[[LHS:[a-zA-Z0-9]+]]: memref<1x5x3xf32>
+  // CHECK-SAME:  %[[RHS:[a-zA-Z0-9]+]]: memref<1x5x7xf32>
+  // CHECK-SAME:  %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
+  // CHECK-NEXT:  %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  linalg.matmul_transpose_a ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<5x3xf32>, memref<5x7xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
+  // CHECK-NEXT:  return
+  linalg.batch_matmul_transpose_a ins(%arg0, %arg1 : memref<1x5x3xf32>, memref<1x5x7xf32>) outs(%arg2: memref<1x3x7xf32>)
+  return
+}
+
+// -----
+
+func.func @singleton_batchmatmul_transpose_b(%arg0: memref<1x3x5xf32>, %arg1: memref<1x7x5xf32>, %arg2: memref<1x3x7xf32>) {
+  // CHECK-LABEL: @singleton_batchmatmul_transpose_b
+  // CHECK-SAME:  %[[LHS:[a-zA-Z0-9]+]]: memref<1x3x5xf32>
+  // CHECK-SAME:  %[[RHS:[a-zA-Z0-9]+]]: memref<1x7x5xf32>
+  // CHECK-SAME:  %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
+  // CHECK-NEXT:  %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  linalg.matmul_transpose_b ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
+  // CHECK-NEXT:  return
+  linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : memref<1x3x5xf32>, memref<1x7x5xf32>) outs(%arg2: memref<1x3x7xf32>)
+  return
+}
+
+// -----
+
+func.func @nonsingleton_batch_matmul(%arg0 : tensor<2x?x?xf32>, %arg1 : tensor<2x?x?xf32>, %arg2: tensor<2x?x?xf32>) -> tensor<2x?x?xf32> {
+  // CHECK-LABEL: @nonsingleton_batch_matmul
+  // CHECK-NOT:   collapse_shape
+  // CHECK:       linalg.batch_matmul
+  // CHECK-NOT:   expand_shape
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<2x?x?xf32>, tensor<2x?x?xf32>)
+        outs(%arg2 : tensor<2x?x?xf32>) -> tensor<2x?x?xf32>
+  return %1 : tensor<2x?x?xf32>
+}
+
+// -----
+
+func.func @nonsingleton_batch_matmul_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  // CHECK-LABEL: @nonsingleton_batch_matmul_dynamic
+  // CHECK-NOT:   collapse_shape
+  // CHECK:       linalg.batch_matmul
+  // CHECK-NOT:   expand_shape
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+        outs(%arg2 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+

From 543b0d643506d12c658e2984d943873ed4c8b78b Mon Sep 17 00:00:00 2001
From: Sam
Date: Sun, 16 Jun 2024 12:12:35 -0500
Subject: [PATCH 03/14] remove unnecessary changes

---
 .../mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml | 1 +
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp                 | 2 --
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 41f90483c93b3..3f0aa33767a75 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -309,6 +309,7 @@ metadata: !LinalgOpMetadata
   cpp_class_name: ReciprocalOp
   doc: |-
     Applies reciprocal(x) elementwise.
+
     No numeric casting is performed on the input operand.
 structured_op: !LinalgStructuredOpConfig
   args:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 4e47b6018c445..8df33a107c2cb 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -17,7 +17,6 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -43,7 +42,6 @@
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
-#include
 #include
 
 using namespace mlir;

From 5732d87375c942306ddf5c7a6661b8123f423b1c Mon Sep 17 00:00:00 2001
From: Sam
Date: Tue, 18 Jun 2024 20:28:31 -0500
Subject: [PATCH 04/14] Move patterns to a populate function and implement test pass

---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   |  14 +-
 .../Dialect/Linalg/Transforms/Transforms.h    |   7 +
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 115 ---------------
 .../Linalg/Transforms/DropUnitDims.cpp        |  98 +++++++++++++
 .../linalg/opdsl/ops/core_named_ops.py        |   5 -
 mlir/test/Dialect/Linalg/canonicalize.mlir    | 137 +-----------------
 mlir/test/lib/Dialect/Linalg/CMakeLists.txt   |   1 +
 .../TestLinalgRankReduceContractionOps.cpp    |  68 +++++++++
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 +
 9 files changed, 179 insertions(+), 268 deletions(-)
 create mode 100644 mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 3f0aa33767a75..fad234a9dcae9 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -516,7 +516,7 @@ structured_op: !LinalgStructuredOpConfig
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: erf
-  cpp_class_name: ErfOp
+  cpp_class_name: erfOp
   doc: |-
     Applies erf(x) elementwise.
@@ -959,7 +959,7 @@ structured_op: !LinalgStructuredOpConfig
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: powf
-  cpp_class_name: PowfOp
+  cpp_class_name: PowFOp
   doc: |-
     Takes the powf(lhs, rhs) between two inputs, elementwise.
    For powf(arg, 2) use `linalg.square`.
@@ -1622,8 +1622,6 @@ metadata: !LinalgOpMetadata
     them to the same data type as the accumulator/output.
   implements:
   - LinalgContractionOpInterface
-  defines:
-  - hasCanonicalizer
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
@@ -1694,8 +1692,6 @@ metadata: !LinalgOpMetadata
     them to the same data type as the accumulator/output.
   implements:
   - LinalgContractionOpInterface
-  defines:
-  - hasCanonicalizer
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
@@ -1766,8 +1762,6 @@ metadata: !LinalgOpMetadata
     them to the same data type as the accumulator/output.
   implements:
   - LinalgContractionOpInterface
-  defines:
-  - hasCanonicalizer
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
@@ -2146,8 +2140,6 @@ metadata: !LinalgOpMetadata
     them to the same data type as the accumulator/output.
   implements:
   - LinalgContractionOpInterface
-  defines:
-  - hasCanonicalizer
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
@@ -2216,8 +2208,6 @@ metadata: !LinalgOpMetadata
     them to the same data type as the accumulator/output.
   implements:
   - LinalgContractionOpInterface
-  defines:
-  - hasCanonicalizer
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 308ce92e35520..c49383c600a57 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1692,6 +1692,13 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
 void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);
 
+/// Adds patterns that reduce the rank of named contraction ops that have
+/// unit dimensions in the operand(s) by converting to a sequence of
+/// `collapse_shape`, `<corresponding rank-reduced op>`, `expand_shape`
+/// (if on tensors). For example, a `linalg.batch_matmul` with unit batch size
+/// will convert to `linalg.matmul`, and a `linalg.matvec` with unit spatial
+/// dim in lhs will convert to a `linalg.dot`.
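+///
+/// As an illustrative sketch only (shapes chosen arbitrarily here; the exact
+/// reassociations and dynamic-shape handling depend on the op), a batched
+/// matmul with unit batch size:
+///
+///   %0 = linalg.batch_matmul
+///          ins(%lhs, %rhs : tensor<1x5x7xf32>, tensor<1x7x3xf32>)
+///          outs(%init : tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
+///
+/// becomes roughly:
+///
+///   %l = tensor.collapse_shape %lhs [[0, 1], [2]]
+///          : tensor<1x5x7xf32> into tensor<5x7xf32>
+///   %r = tensor.collapse_shape %rhs [[0, 1], [2]]
+///          : tensor<1x7x3xf32> into tensor<7x3xf32>
+///   %i = tensor.collapse_shape %init [[0, 1], [2]]
+///          : tensor<1x5x3xf32> into tensor<5x3xf32>
+///   %m = linalg.matmul ins(%l, %r : tensor<5x7xf32>, tensor<7x3xf32>)
+///          outs(%i : tensor<5x3xf32>) -> tensor<5x3xf32>
+///   %0 = tensor.expand_shape %m [[0, 1], [2]] output_shape [1, 5, 3]
+///          : tensor<5x3xf32> into tensor<1x5x3xf32>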
+void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns); + } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 8df33a107c2cb..b79afebfa8158 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -578,121 +578,6 @@ class RegionBuilderHelper { } // namespace -//===----------------------------------------------------------------------===// -// BatchMatmulOp -//===----------------------------------------------------------------------===// - -namespace { - -template -struct BatchMatmulToMatmul : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(BatchOpTy batchMatmulOp, - PatternRewriter &rewriter) const override { - - auto loc = batchMatmulOp.getLoc(); - auto inputs = batchMatmulOp.getDpsInputs(); - auto inits = batchMatmulOp.getDpsInits(); - if (inputs.size() != 2 || inits.size() != 1) - return rewriter.notifyMatchFailure(batchMatmulOp, - "expected 2 inputs and 1 init"); - auto lhs = inputs[0]; - auto rhs = inputs[1]; - auto init = inits[0]; - - auto lhsType = cast(lhs.getType()); - auto rhsType = cast(rhs.getType()); - auto initType = cast(init.getType()); - if (lhsType.getShape()[0] != 1 || rhsType.getShape()[0] != 1 || - initType.getShape()[0] != 1) - return rewriter.notifyMatchFailure(batchMatmulOp, "batch size is not 1"); - - auto results = batchMatmulOp.getResults(); - assert(results.size() < 2 && "expected at most one result"); - - SmallVector resultType; - if (results.size() == 1) { - auto oldResultType = cast(results[0].getType()); - resultType.push_back( - RankedTensorType::get(oldResultType.getShape().drop_front(1), - oldResultType.getElementType())); - } - - auto collapseSingletonDim = [&](Value val) -> Value { - SmallVector reassociation({{0, 1}}); - auto valType = cast(val.getType()); - for (auto i = 2; i < valType.getRank(); i++) - reassociation.push_back({i}); - if (isa(valType)) { - RankedTensorType collapsedType = RankedTensorType::get( - valType.getShape().drop_front(1), valType.getElementType()); - return rewriter.create(loc, collapsedType, val, - reassociation); - } - MemRefType collapsedType = MemRefType::get( - valType.getShape().drop_front(1), valType.getElementType()); - return rewriter.create(loc, collapsedType, val, - reassociation); - }; - - auto collapsedLhs = collapseSingletonDim(lhs); - auto collapsedRhs = collapseSingletonDim(rhs); - auto collapsedInit = collapseSingletonDim(init); - - auto collapsedOp = rewriter.create( - loc, resultType, ValueRange{collapsedLhs, collapsedRhs}, - ValueRange{collapsedInit}); - for (auto attr : batchMatmulOp->getAttrs()) { - if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName) - continue; - collapsedOp->setAttr(attr.getName(), attr.getValue()); - } - - if (results.size() < 1) { - rewriter.replaceOp(batchMatmulOp, collapsedOp); - } else { - SmallVector reassociation({{0, 1}}); - auto resultType = cast(results[0].getType()); - for (auto i = 2; i < resultType.getRank(); i++) - reassociation.push_back({i}); - Value expandedResult = rewriter.create( - loc, resultType, collapsedOp.getResultTensors()[0], reassociation); - rewriter.replaceOp(batchMatmulOp, expandedResult); - } - - return success(); - } -}; - -} // namespace - -void BatchMatmulOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add>(context); -} - -void BatchMatmulTransposeAOp::getCanonicalizationPatterns( - 
RewritePatternSet &results, MLIRContext *context) { - results.add>( - context); -} - -void BatchMatmulTransposeBOp::getCanonicalizationPatterns( - RewritePatternSet &results, MLIRContext *context) { - results.add>( - context); -} - -void BatchMatvecOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add>(context); -} - -void BatchVecmatOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add>(context); -} - //===----------------------------------------------------------------------===// // CopyOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index c0829397f1f85..9248710d5afc9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -812,6 +812,103 @@ void mlir::linalg::populateMoveInitOperandsToInputPattern( patterns.add(patterns.getContext()); } +namespace { + +template +struct BatchMatmulToMatmul : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(BatchOpTy batchMatmulOp, + PatternRewriter &rewriter) const override { + + auto loc = batchMatmulOp.getLoc(); + auto inputs = batchMatmulOp.getDpsInputs(); + auto inits = batchMatmulOp.getDpsInits(); + if (inputs.size() != 2 || inits.size() != 1) + return rewriter.notifyMatchFailure(batchMatmulOp, + "expected 2 inputs and 1 init"); + auto lhs = inputs[0]; + auto rhs = inputs[1]; + auto init = inits[0]; + + auto lhsType = cast(lhs.getType()); + auto rhsType = cast(rhs.getType()); + auto initType = cast(init.getType()); + if (lhsType.getShape()[0] != 1 || rhsType.getShape()[0] != 1 || + initType.getShape()[0] != 1) + return rewriter.notifyMatchFailure(batchMatmulOp, "batch size is not 1"); + + auto results = batchMatmulOp.getResults(); + assert(results.size() < 2 && "expected at most one result"); + + SmallVector resultType; + if (results.size() == 1) { + auto oldResultType = cast(results[0].getType()); + resultType.push_back( + RankedTensorType::get(oldResultType.getShape().drop_front(1), + oldResultType.getElementType())); + } + + auto collapseSingletonDim = [&](Value val) -> Value { + SmallVector reassociation({{0, 1}}); + auto valType = cast(val.getType()); + for (auto i = 2; i < valType.getRank(); i++) + reassociation.push_back({i}); + if (isa(valType)) { + RankedTensorType collapsedType = RankedTensorType::get( + valType.getShape().drop_front(1), valType.getElementType()); + return rewriter.create(loc, collapsedType, val, + reassociation); + } + MemRefType collapsedType = MemRefType::get( + valType.getShape().drop_front(1), valType.getElementType()); + return rewriter.create(loc, collapsedType, val, + reassociation); + }; + + auto collapsedLhs = collapseSingletonDim(lhs); + auto collapsedRhs = collapseSingletonDim(rhs); + auto collapsedInit = collapseSingletonDim(init); + + auto collapsedOp = rewriter.create( + loc, resultType, ValueRange{collapsedLhs, collapsedRhs}, + ValueRange{collapsedInit}); + for (auto attr : batchMatmulOp->getAttrs()) { + if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName) + continue; + collapsedOp->setAttr(attr.getName(), attr.getValue()); + } + + if (results.size() < 1) { + rewriter.replaceOp(batchMatmulOp, collapsedOp); + } else { + SmallVector reassociation({{0, 1}}); + auto resultType = cast(results[0].getType()); + for (auto i = 2; i < 
resultType.getRank(); i++) + reassociation.push_back({i}); + Value expandedResult = rewriter.create( + loc, resultType, collapsedOp.getResultTensors()[0], reassociation); + rewriter.replaceOp(batchMatmulOp, expandedResult); + } + + return success(); + } +}; +} // namespace + +void mlir::linalg::populateContractionOpRankReducingPatterns( + RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.add>(context); + patterns + .add>( + context); + patterns + .add>( + context); + patterns.add>(context); + patterns.add>(context); +} + namespace { /// Pass that removes unit-extent dims within generic ops. struct LinalgFoldUnitExtentDimsPass @@ -833,4 +930,5 @@ struct LinalgFoldUnitExtentDimsPass (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } }; + } // namespace diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index b4b36ba0bfe51..43410aaa6af1b 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -518,7 +518,6 @@ def batch_matmul( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ - defines(Canonicalizer) domain(D.b, D.m, D.n, D.k) implements(ContractionOpInterface) C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( @@ -538,7 +537,6 @@ def batch_matmul_transpose_a( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ - defines(Canonicalizer) domain(D.b, D.m, D.n, D.k) implements(ContractionOpInterface) C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) * TypeFn.cast_signed( @@ -558,7 +556,6 @@ def batch_matmul_transpose_b( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ - defines(Canonicalizer) domain(D.b, D.m, D.n, D.k) implements(ContractionOpInterface) C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( @@ -645,7 +642,6 @@ def batch_matvec( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ - defines(Canonicalizer) domain(D.b, D.m, D.k) implements(ContractionOpInterface) C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( @@ -664,7 +660,6 @@ def batch_vecmat( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. 
""" - defines(Canonicalizer) domain(D.b, D.n, D.k) implements(ContractionOpInterface) C[D.b, D.n] += TypeFn.cast_signed(U, A[D.b, D.k]) * TypeFn.cast_signed( diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 8514bcb089891..928030a81dc02 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -1017,7 +1017,7 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>) return %0 : tensor<2x3xf32> } -// ----- +// ---- func.func @transpose_1d(%input: tensor<16xf32>, %init: tensor<16xf32>) -> tensor<16xf32> { @@ -1096,138 +1096,3 @@ func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>, func.return %transpose2 : tensor<3x4x5xf32> } -// ----- - -func.func @singleton_batch_matmul_tensor(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?x?xf32>) -> tensor<1x?x?xf32> { - // CHECK-LABEL: @singleton_batch_matmul_tensor - // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32> - // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32> - // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?x?xf32> - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 - // CHECK-DAG: %[[C2:.*]] = arith.constant 2 - // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]] - // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]] - // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]] - // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor, tensor) outs(%[[COLLAPSED_INIT]] : tensor) - // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]] - // CHECK-NEXT: %[[DIM2:.*]] = tensor.dim %[[INIT]], %[[C2]] - // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1], [2]] output_shape [1, %[[DIM1]], %[[DIM2]]] - // CHECK-NEXT: return %[[RES]] - %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?x?xf32>) - outs(%arg2 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32> - return %1 : tensor<1x?x?xf32> -} - -// ----- - -func.func @singletone_batch_matmul_memref(%arg0 : memref<1x?x?xf32>, %arg1 : memref<1x?x?xf32>, %arg2: memref<1x?x?xf32>) { - // CHECK-LABEL: @singletone_batch_matmul_memref - // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<1x?x?xf32> - // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<1x?x?xf32> - // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: memref<1x?x?xf32> - // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]] - // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]] - // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]] - // CHECK-NEXT: linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref, memref) outs(%[[COLLAPSED_INIT]] : memref) - // CHECK-NEXT: return - linalg.batch_matmul ins(%arg0, %arg1 : memref<1x?x?xf32>, memref<1x?x?xf32>) - outs(%arg2 : memref<1x?x?xf32>) - return -} - -// ----- - -func.func @singletone_batch_matvec(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> { - // CHECK-LABEL: @singletone_batch_matvec - // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32> - // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?xf32> - // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32> - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 - // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]] - // CHECK-NEXT: 
%[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]] - // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]] - // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor, tensor) outs(%[[COLLAPSED_INIT]] : tensor) - // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]] - // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]] - // CHECK-NEXT: return %[[RES]] - %1 = linalg.batch_matvec ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?xf32>) - outs(%arg2 : tensor<1x?xf32>) -> tensor<1x?xf32> - return %1 : tensor<1x?xf32> -} - -// ----- - -func.func @singletone_batch_vecmat(%arg0 : tensor<1x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> { - // CHECK-LABEL: @singletone_batch_vecmat - // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32> - // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32> - // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32> - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 - // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]] - // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]] - // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]] - // CHECK-NEXT: %[[MATMUL:.+]] = linalg.vecmat ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor, tensor) outs(%[[COLLAPSED_INIT]] : tensor) - // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]] - // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]] - // CHECK-NEXT: return %[[RES]] - %1 = linalg.batch_vecmat ins(%arg0, %arg1 : tensor<1x?xf32>, tensor<1x?x?xf32>) - outs(%arg2 : tensor<1x?xf32>) -> tensor<1x?xf32> - return %1 : tensor<1x?xf32> -} - -// ----- - -func.func @singletone_batchmatmul_transpose_a(%arg0: memref<1x5x3xf32>, %arg1: memref<1x5x7xf32>, %arg2: memref<1x3x7xf32>) { - // CHECK-LABEL: @singletone_batchmatmul_transpose_a - // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<1x5x3xf32> - // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<1x5x7xf32> - // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32> - // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]] - // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]] - // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]] - // CHECK-NEXT: linalg.matmul_transpose_a ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<5x3xf32>, memref<5x7xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>) - // CHECK-NEXT: return - linalg.batch_matmul_transpose_a ins(%arg0, %arg1 : memref<1x5x3xf32>, memref<1x5x7xf32>) outs(%arg2: memref<1x3x7xf32>) - return -} - -// ----- - -func.func @singletone_batchmatmul_transpose_b(%arg0: memref<1x3x5xf32>, %arg1: memref<1x7x5xf32>, %arg2: memref<1x3x7xf32>) { - // CHECK-LABEL: @singletone_batchmatmul_transpose_b - // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<1x3x5xf32> - // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<1x7x5xf32> - // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32> - // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]] - // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]] - // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]] - // CHECK-NEXT: linalg.matmul_transpose_b ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : 
memref<3x5xf32>, memref<7x5xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>) - // CHECK-NEXT: return - linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : memref<1x3x5xf32>, memref<1x7x5xf32>) outs(%arg2: memref<1x3x7xf32>) - return -} - -// ----- - -func.func @nonsingleton_batch_matmul(%arg0 : tensor<2x?x?xf32>, %arg1 : tensor<2x?x?xf32>, %arg2: tensor<2x?x?xf32>) -> tensor<2x?x?xf32> { - // CHECK-LABEL: @nonsingleton_batch_matmul - // CHECK-NOT: collapse_shape - // CHECK: linalg.batch_matmul - // CHECK-NOT: expand_shape - %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<2x?x?xf32>, tensor<2x?x?xf32>) - outs(%arg2 : tensor<2x?x?xf32>) -> tensor<2x?x?xf32> - return %1 : tensor<2x?x?xf32> -} - -// ----- - -func.func @nonsingleton_batch_matmul_dynamic(%arg0 : tensor, %arg1 : tensor, %arg2: tensor) -> tensor { - // CHECK-LABEL: @nonsingleton_batch_matmul_dynamic - // CHECK-NOT: collapse_shape - // CHECK: linalg.batch_matmul - // CHECK-NOT: expand_shape - %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor, tensor) - outs(%arg2 : tensor) -> tensor - return %1 : tensor -} - diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt index b28f2b3564662..283e426b4e594 100644 --- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt @@ -5,6 +5,7 @@ add_mlir_library(MLIRLinalgTestPasses TestLinalgDropUnitDims.cpp TestLinalgElementwiseFusion.cpp TestLinalgFusionTransforms.cpp + TestLinalgRankReduceContractionOps.cpp TestLinalgTransforms.cpp TestPadFusion.cpp diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp new file mode 100644 index 0000000000000..5ca27be30a687 --- /dev/null +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp @@ -0,0 +1,68 @@ +//===- TestLinalgRankReduceContractionOps.cpp - Test Linalg rank reduce +//contractions ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass for testing rank reduing patterns for named +// contraction ops with unit dims. 
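+//
+// A typical invocation (illustrative; the flag matches getArgument() below):
+//
+//   mlir-opt -test-linalg-rank-reduce-contraction-ops -split-input-file <input.mlir>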
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestLinalgRankReduceContractionOps
+    : public PassWrapper<TestLinalgRankReduceContractionOps,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestLinalgRankReduceContractionOps)
+
+  TestLinalgRankReduceContractionOps() = default;
+  TestLinalgRankReduceContractionOps(
+      const TestLinalgRankReduceContractionOps &pass)
+      : PassWrapper(pass) {}
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<affine::AffineDialect, linalg::LinalgDialect,
+                    memref::MemRefDialect, tensor::TensorDialect>();
+  }
+  StringRef getArgument() const final {
+    return "test-linalg-rank-reduce-contraction-ops";
+  }
+  StringRef getDescription() const final {
+    return "Test Linalg rank reduce contraction ops with unit dims";
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &this->getContext();
+    func::FuncOp funcOp = this->getOperation();
+
+    RewritePatternSet patterns(context);
+    linalg::populateContractionOpRankReducingPatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                            std::move(patterns))))
+      return signalPassFailure();
+    return;
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestLinalgRankReduceContractionOps() {
+  PassRegistration<TestLinalgRankReduceContractionOps>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 0e8b161d51345..d4ea7a9cae0d2 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -106,6 +106,7 @@ void registerTestLinalgDecomposeOps();
 void registerTestLinalgDropUnitDims();
 void registerTestLinalgElementwiseFusion();
 void registerTestLinalgGreedyFusion();
+void registerTestLinalgRankReduceContractionOps();
 void registerTestLinalgTransforms();
 void registerTestLivenessAnalysisPass();
 void registerTestLivenessPass();
@@ -235,6 +236,7 @@ void registerTestPasses() {
   mlir::test::registerTestLinalgDropUnitDims();
   mlir::test::registerTestLinalgElementwiseFusion();
   mlir::test::registerTestLinalgGreedyFusion();
+  mlir::test::registerTestLinalgRankReduceContractionOps();
   mlir::test::registerTestLinalgTransforms();
   mlir::test::registerTestLivenessAnalysisPass();
   mlir::test::registerTestLivenessPass();

From 28078405788b799cf64bcf5a7a4059c0eb739875 Mon Sep 17 00:00:00 2001
From: Sam
Date: Wed, 19 Jun 2024 10:24:39 -0500
Subject: [PATCH 05/14] refactor common logic into abstract base class

---
 .../Linalg/Transforms/DropUnitDims.cpp        | 230 +++++++++++++-----
 1 file changed, 166 insertions(+), 64 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 9248710d5afc9..07b0cdea40c92 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -814,10 +814,66 @@ void mlir::linalg::populateMoveInitOperandsToInputPattern(
 
 namespace {
 
-template <typename BatchOpTy, typename OpTy>
-struct BatchMatmulToMatmul : OpRewritePattern<BatchOpTy> {
-  using OpRewritePattern<BatchOpTy>::OpRewritePattern;
-  LogicalResult matchAndRewrite(BatchOpTy batchMatmulOp,
+static SmallVector<ReassociationIndices>
+getReassociationsForTrailingDims(int64_t rank) {
+  SmallVector<ReassociationIndices> reassociation(rank - 1, {});
+  if (rank > 1) {
+    reassociation[rank - 2] =
+        (rank == 1) ? ReassociationIndices{0} : ReassociationIndices{0, 1};
+    for (int64_t i = 0; i < rank - 2; i++)
+      reassociation[i] = {i};
+  }
+  return reassociation;
+}
+
+static SmallVector<ReassociationIndices>
+getReassociationsForLeadingDims(int64_t rank) {
+  SmallVector<ReassociationIndices> reassociation(rank - 1, {});
+  if (rank > 1) {
+    reassociation[0] =
+        (rank == 1) ? ReassociationIndices{0} : ReassociationIndices{0, 1};
+    for (int64_t i = 1; i < rank - 1; i++)
+      reassociation[i] = {i + rank - 2};
+  }
+  return reassociation;
+}
+
+static Value collapseLeadingSingletonDim(PatternRewriter &rewriter, Value val) {
+  auto valType = cast<ShapedType>(val.getType());
+  return collapseValue(
+      rewriter, val.getLoc(), val, valType.getShape().drop_front(1),
+      getReassociationsForLeadingDims(valType.getRank()),
+      ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
+}
+
+static Value collapseTrailingSingletonDim(PatternRewriter &rewriter,
+                                          Value val) {
+  auto valType = cast<ShapedType>(val.getType());
+  return collapseValue(
+      rewriter, val.getLoc(), val, valType.getShape().drop_back(1),
+      getReassociationsForTrailingDims(valType.getRank()),
+      ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
+}
+
+static Value expandLeadingSingletonDim(PatternRewriter &rewriter, Value val,
+                                       RankedTensorType expandedType) {
+  return rewriter.create<tensor::ExpandShapeOp>(
+      val.getLoc(), expandedType, val,
+      getReassociationsForLeadingDims(expandedType.getRank()));
+}
+
+static Value expandTrailingSingletonDim(PatternRewriter &rewriter, Value val,
+                                        RankedTensorType expandedType) {
+  return rewriter.create<tensor::ExpandShapeOp>(
+      val.getLoc(), expandedType, val,
+      getReassociationsForTrailingDims(expandedType.getRank()));
+}
+
+template <typename FromOpTy, typename ToOpTy>
+struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
+  using OpRewritePattern<FromOpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(FromOpTy batchMatmulOp,
+                                PatternRewriter &rewriter) const override {
+
+    auto loc = batchMatmulOp.getLoc();
+    auto inputs = batchMatmulOp.getDpsInputs();
+    auto inits = batchMatmulOp.getDpsInits();
+    if (inputs.size() != 2 || inits.size() != 1)
+      return rewriter.notifyMatchFailure(batchMatmulOp,
+                                         "expected 2 inputs and 1 init");
+    auto lhs = inputs[0];
+    auto rhs = inputs[1];
+    auto init = inits[0];
+
+    if (!checkTypes(lhs, rhs, init))
+      return rewriter.notifyMatchFailure(batchMatmulOp,
+                                         "no reducible dims found");
+
+    auto collapsedOperands = collapseOperands(rewriter, lhs, rhs, init);
+    auto collapsedLhs = collapsedOperands[0];
+    auto collapsedRhs = collapsedOperands[1];
+    auto collapsedInit = collapsedOperands[2];
+    SmallVector<Type, 1> collapsedResultTy;
+    if (isa<RankedTensorType>(collapsedInit.getType()))
+      collapsedResultTy.push_back(collapsedInit.getType());
+    auto collapsedOp = rewriter.create<ToOpTy>(
+        loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
+        ValueRange{collapsedInit});
+    for (auto attr : batchMatmulOp->getAttrs()) {
+      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
+        continue;
+      collapsedOp->setAttr(attr.getName(), attr.getValue());
+    }
+
+    auto results = batchMatmulOp.getResults();
+    assert(results.size() < 2 && "expected at most one result");
+    if (results.size() < 1)
+      rewriter.replaceOp(batchMatmulOp, collapsedOp);
+    else
+      rewriter.replaceOp(
+          batchMatmulOp,
+          expandResult(rewriter, collapsedOp.getResultTensors()[0],
+                       cast<RankedTensorType>(results[0].getType())));
+
+    return success();
+  }
+
+  virtual bool checkTypes(Value lhs, Value rhs, Value init) const = 0;
+  virtual SmallVector<Value, 3> collapseOperands(PatternRewriter &rewriter,
+                                                 Value lhs, Value rhs,
+                                                 Value init) const = 0;
+  virtual Value expandResult(PatternRewriter &rewriter, Value result,
+                             RankedTensorType expandedType) const = 0;
+};
+
+template <typename FromOpTy, typename ToOpTy>
+struct RankReduceBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
+  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
+
+  bool checkTypes(Value lhs, Value rhs, Value init) const override {
+    auto lhsType = cast<ShapedType>(lhs.getType());
+    auto rhsType = cast<ShapedType>(rhs.getType());
+    auto initType = cast<ShapedType>(init.getType());
+    return lhsType.getShape()[0] == 1 && rhsType.getShape()[0] == 1 &&
+           initType.getShape()[0] == 1;
+  }
+
+  SmallVector<Value, 3> collapseOperands(PatternRewriter &rewriter, Value lhs,
+                                         Value rhs, Value init) const override {
+    auto collapsedLhs = collapseLeadingSingletonDim(rewriter, lhs);
+    auto collapsedRhs = collapseLeadingSingletonDim(rewriter, rhs);
+    auto collapsedInit = collapseLeadingSingletonDim(rewriter, init);
+    return SmallVector<Value, 3>{collapsedLhs, collapsedRhs, collapsedInit};
+  }
+  Value expandResult(PatternRewriter &rewriter, Value result,
+                     RankedTensorType expandedType) const override {
+    return expandLeadingSingletonDim(rewriter, result, expandedType);
+  }
+};
+
+template <typename FromOpTy, typename ToOpTy>
+struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
+  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
+
+  static bool constexpr reduceLeading =
+      (std::is_same<FromOpTy, MatmulOp>::value &&
+       std::is_same<ToOpTy, VecmatOp>::value) ||
+      (std::is_same<FromOpTy, MatvecOp>::value &&
+       std::is_same<ToOpTy, DotOp>::value);
+
+  bool checkTypes(Value lhs, Value rhs, Value init) const override {
+    auto lhsType = cast<ShapedType>(lhs.getType());
+    auto rhsType = cast<ShapedType>(rhs.getType());
+    auto initType = cast<ShapedType>(init.getType());
+    if (reduceLeading)
+      return lhsType.getShape()[0] == 1 && initType.getShape()[0] == 1;
+    else
+      return rhsType.getShape().back() == 1 && initType.getShape().back() == 1;
+  }
+
+  SmallVector<Value, 3> collapseOperands(PatternRewriter &rewriter, Value lhs,
+                                         Value rhs, Value init) const override {
+    if (reduceLeading) {
+      auto collapsedLhs = collapseLeadingSingletonDim(rewriter, lhs);
+      auto collapsedInit = collapseLeadingSingletonDim(rewriter, init);
+      return SmallVector<Value, 3>{collapsedLhs, rhs, collapsedInit};
+    } else {
+      auto collapsedRhs = collapseTrailingSingletonDim(rewriter, rhs);
+      auto collapsedInit = collapseTrailingSingletonDim(rewriter, init);
+      return SmallVector<Value, 3>{lhs, collapsedRhs, collapsedInit};
+    }
+  }
+  Value expandResult(PatternRewriter &rewriter, Value result,
+                     RankedTensorType expandedType) const override {
+    if (reduceLeading)
+      return expandLeadingSingletonDim(rewriter, result, expandedType);
+    else
+      return expandTrailingSingletonDim(rewriter, result, expandedType);
+  }
+};
+
 } // namespace
 
 void mlir::linalg::populateContractionOpRankReducingPatterns(
     RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
-  patterns.add<BatchMatmulToMatmul<BatchMatmulOp, MatmulOp>>(context);
-  patterns
-      .add<BatchMatmulToMatmul<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
-          context);
-  patterns
-      .add<BatchMatmulToMatmul<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
-          context);
-  patterns.add<BatchMatmulToMatmul<BatchMatvecOp, MatvecOp>>(context);
-  patterns.add<BatchMatmulToMatmul<BatchVecmatOp, VecmatOp>>(context);
+  patterns.add<RankReduceBatched<BatchMatmulOp, MatmulOp>>(context);
+  patterns.add<RankReduceBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
+      context);
+  patterns.add<RankReduceBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
+      context);
+  patterns.add<RankReduceBatched<BatchMatvecOp, MatvecOp>>(context);
+  patterns.add<RankReduceBatched<BatchVecmatOp, VecmatOp>>(context);
+  patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
+  patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
+  patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
+  patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
 }
 
 namespace {

From 679192b56ac231fce54824d0189a52f39ea2a63b Mon Sep 17 00:00:00 2001
From: Sam
Date: Wed, 19 Jun 2024 17:07:46 -0500
Subject: [PATCH 06/14] add regression test

---
 .../Linalg/rank-reduce-contraction-ops.mlir   | 197 ++++++++++++++++++
 1 file changed, 197 insertions(+)
 create mode 100644 mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir

diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
new file mode 100644
index 0000000000000..279a1d52ae72b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -0,0 +1,197 @@
+//RUN: mlir-opt -test-linalg-rank-reduce-contraction-ops --canonicalize -split-input-file %s | FileCheck %s
+
+func.func @singleton_batch_matmul_tensor(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?x?xf32>) -> tensor<1x?x?xf32> {
+  // CHECK-LABEL: @singleton_batch_matmul_tensor
+  // CHECK-SAME:  %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  // CHECK-SAME:  %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  // CHECK-SAME:  %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1
+  // CHECK-DAG:   %[[C2:.*]] = arith.constant 2
+  // CHECK-NEXT:  %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  %[[MATMUL:.+]] = linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?x?xf32>)
+  // CHECK-NEXT:  %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+  // CHECK-NEXT:  %[[DIM2:.*]] = tensor.dim %[[INIT]], %[[C2]]
+  // CHECK-NEXT:  %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1], [2]] output_shape [1, %[[DIM1]], %[[DIM2]]]
+  // CHECK-NEXT:  return %[[RES]]
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
+        outs(%arg2 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32>
+  return %1 : tensor<1x?x?xf32>
+}
+
+// -----
+
+func.func @singleton_batch_matmul_memref(%arg0 : memref<1x?x?xf32>, %arg1 : memref<1x?x?xf32>, %arg2: memref<1x?x?xf32>) {
+  // CHECK-LABEL: @singleton_batch_matmul_memref
+  // CHECK-SAME:  %[[LHS:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+  // CHECK-SAME:  %[[RHS:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+  // CHECK-SAME:  %[[INIT:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+  // CHECK-NEXT:  %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<?x?xf32>, memref<?x?xf32>) outs(%[[COLLAPSED_INIT]] : memref<?x?xf32>)
+  // CHECK-NEXT:  return
+  linalg.batch_matmul ins(%arg0, %arg1 : memref<1x?x?xf32>, memref<1x?x?xf32>)
+        outs(%arg2 : memref<1x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @singleton_batch_matvec(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
+  // CHECK-LABEL: @singleton_batch_matvec
+  // CHECK-SAME:  %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  // CHECK-SAME:  %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  // CHECK-SAME:  %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1
+  // CHECK-NEXT:  %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]]
+  // CHECK-NEXT:  %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+  // CHECK-NEXT:  %[[MATMUL:.+]] = linalg.matvec ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+  // CHECK-NEXT:  %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+  // CHECK-NEXT:  %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
+  // CHECK-NEXT:  return %[[RES]]
+  %1 = linalg.batch_matvec ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?xf32>)
+        outs(%arg2 : tensor<1x?xf32>) -> tensor<1x?xf32>
+  return %1 : tensor<1x?xf32>
+}
+
+// -----
+
+func.func @singleton_batch_vecmat(%arg0 : tensor<1x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
+  // CHECK-LABEL: @singleton_batch_vecmat
+  // CHECK-SAME:  %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  // CHECK-SAME:  %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  // CHECK-SAME:  %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1
+  // CHECK-NEXT:  %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
+  // CHECK-NEXT:  %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+  // CHECK-NEXT:  %[[MATMUL:.+]] = linalg.vecmat ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+  // CHECK-NEXT:  %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+  // CHECK-NEXT:  %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
+  // CHECK-NEXT:  return %[[RES]]
+  %1 = linalg.batch_vecmat ins(%arg0, %arg1 : tensor<1x?xf32>, tensor<1x?x?xf32>)
+        outs(%arg2 : tensor<1x?xf32>) -> tensor<1x?xf32>
+  return %1 : tensor<1x?xf32>
+}
+
+// -----
+
+func.func @singleton_batchmatmul_transpose_a(%arg0: memref<1x5x3xf32>, %arg1: memref<1x5x7xf32>, %arg2: memref<1x3x7xf32>) {
+  // CHECK-LABEL: @singleton_batchmatmul_transpose_a
+  // CHECK-SAME:  %[[LHS:[a-zA-Z0-9]+]]: memref<1x5x3xf32>
+  // CHECK-SAME:  %[[RHS:[a-zA-Z0-9]+]]: memref<1x5x7xf32>
+  // CHECK-SAME:  %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
+  // CHECK-NEXT:  %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  linalg.matmul_transpose_a ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<5x3xf32>, memref<5x7xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
+  // CHECK-NEXT:  return
+  linalg.batch_matmul_transpose_a ins(%arg0, %arg1 : memref<1x5x3xf32>, memref<1x5x7xf32>) outs(%arg2: memref<1x3x7xf32>)
+  return
+}
+
+// -----
+
+func.func @singleton_batchmatmul_transpose_b(%arg0: memref<1x3x5xf32>, %arg1: memref<1x7x5xf32>, %arg2: memref<1x3x7xf32>) {
+  // CHECK-LABEL: @singleton_batchmatmul_transpose_b
+  // CHECK-SAME:  %[[LHS:[a-zA-Z0-9]+]]: memref<1x3x5xf32>
+  // CHECK-SAME:  %[[RHS:[a-zA-Z0-9]+]]: memref<1x7x5xf32>
+  // CHECK-SAME:  %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
+  // CHECK-NEXT:  %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT:  linalg.matmul_transpose_b ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
+  // CHECK-NEXT:  return
+  linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : memref<1x3x5xf32>, memref<1x7x5xf32>) outs(%arg2: memref<1x3x7xf32>)
+  return
+}
+
+// -----
+
+func.func @matmul_to_vecmat(%arg0: memref<1x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<1x?xf32>) {
+  // CHECK-LABEL: @matmul_to_vecmat
+  // CHECK: linalg.vecmat
+  linalg.matmul ins(%arg0, %arg1: memref<1x?xf32>, memref<?x?xf32>) outs(%arg2: memref<1x?xf32>)
+  return
+}
+
+// -----
+
+func.func @batch_matmul_to_vecmat(%arg0: memref<1x1x?xf32>, %arg1: memref<1x?x?xf32>, %arg2: memref<1x1x?xf32>) {
+  // CHECK-LABEL: @batch_matmul_to_vecmat
+  // CHECK: linalg.vecmat
+  linalg.batch_matmul ins(%arg0, %arg1: memref<1x1x?xf32>, memref<1x?x?xf32>) outs(%arg2: memref<1x1x?xf32>)
+  return
+}
+
+// -----
+
+func.func @matvec_to_dot(%arg0: memref<1x?xf32>, %arg1: memref<?xf32>, %arg2: memref<1xf32>) {
+  // CHECK-LABEL: @matvec_to_dot
+  // CHECK: linalg.dot
+  linalg.matvec ins(%arg0, %arg1: memref<1x?xf32>, memref<?xf32>) outs(%arg2: memref<1xf32>)
+  return
+}
+
+// -----
+
+func.func @vecmat_to_dot(%arg0: memref<?xf32>, %arg1: memref<?x1xf32>, %arg2: memref<1xf32>) {
+  // CHECK-LABEL: @vecmat_to_dot
+  // CHECK: linalg.dot
+  linalg.vecmat ins(%arg0, %arg1: memref<?xf32>, memref<?x1xf32>) outs(%arg2: memref<1xf32>)
+  return
+}
+
+// -----
+
+func.func @matvec_to_dot_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<1xf32>) -> tensor<1xf32> {
+  // CHECK-LABEL: @matvec_to_dot_tensor
+  // CHECK: linalg.dot
+  %0 = linalg.matvec ins(%arg0, %arg1: tensor<1x?xf32>, tensor<?xf32>) outs(%arg2: tensor<1xf32>) -> tensor<1xf32>
+  return %0 : tensor<1xf32>
+}
+
+// -----
+
+func.func @matmul_to_matvec_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> tensor<?x1xf32> {
+  // CHECK-LABEL: @matmul_to_matvec_tensor
+  // CHECK: linalg.matvec
+  %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x1xf32>) outs(%arg2: tensor<?x1xf32>) -> tensor<?x1xf32>
+  return %0 : tensor<?x1xf32>
+}
+
+// -----
+
+func.func @matmul_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<?x1xf32>, %arg2: memref<?x1xf32>) {
+  // CHECK-LABEL: @matmul_to_matvec
+  // CHECK: linalg.matvec
+  linalg.matmul ins(%arg0, %arg1: memref<?x?xf32>, memref<?x1xf32>) outs(%arg2: memref<?x1xf32>)
+  return
+}
+
+// -----
+
+func.func @nonsingleton_batch_matmul(%arg0 : tensor<2x?x?xf32>, %arg1 : tensor<2x?x?xf32>, %arg2: tensor<2x?x?xf32>) -> tensor<2x?x?xf32> {
+  // CHECK-LABEL:
@nonsingleton_batch_matmul + // CHECK-NOT: collapse_shape + // CHECK: linalg.batch_matmul + // CHECK-NOT: expand_shape + %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<2x?x?xf32>, tensor<2x?x?xf32>) + outs(%arg2 : tensor<2x?x?xf32>) -> tensor<2x?x?xf32> + return %1 : tensor<2x?x?xf32> +} + +// ----- + +func.func @nonsingleton_batch_matmul_dynamic(%arg0 : tensor, %arg1 : tensor, %arg2: tensor) -> tensor { + // CHECK-LABEL: @nonsingleton_batch_matmul_dynamic + // CHECK-NOT: collapse_shape + // CHECK: linalg.batch_matmul + // CHECK-NOT: expand_shape + %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor, tensor) + outs(%arg2 : tensor) -> tensor + return %1 : tensor +} From ce02e9d206cf718927854338636b6f357c62101c Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 19 Jun 2024 18:13:39 -0500 Subject: [PATCH 07/14] flesh out some tests --- .../Linalg/rank-reduce-contraction-ops.mlir | 69 ++++++++++++------- 1 file changed, 46 insertions(+), 23 deletions(-) diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir index 279a1d52ae72b..79003670d2726 100644 --- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir +++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir @@ -111,15 +111,51 @@ func.func @singleton_batchmatmul_transpose_b(%arg0: memref<1x3x5xf32>, %arg1: me // ----- -func.func @matmul_to_vecmat(%arg0: memref<1x?xf32>, %arg1: memref, %arg2: memref<1x?xf32>) { - // CHECK-LABEL: @matmul_to_vecmat - // CHECK: linalg.vecmat - linalg.matmul ins(%arg0, %arg1: memref<1x?xf32>, memref) outs(%arg2: memref<1x?xf32>) +func.func @matmul_to_matvec_tensor(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK-LABEL: @matmul_to_matvec_tensor + // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor + // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor + // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 + // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]] + // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]] + // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec ins(%[[LHS]], %[[COLLAPSED_RHS]] : tensor, tensor) outs(%[[COLLAPSED_INIT]] : tensor) + // CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[INIT]], %[[C0]] + // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [%[[DIM0]], 1] + // CHECK-NEXT: return %[[RES]] + %0 = linalg.matmul ins(%arg0, %arg1: tensor, tensor) outs(%arg2: tensor) -> tensor + return %0 : tensor +} + +// ----- + +func.func @matmul_to_matvec(%arg0: memref, %arg1: memref, %arg2: memref) { + // CHECK-LABEL: @matmul_to_matvec + // CHECK: linalg.matvec + linalg.matmul ins(%arg0, %arg1: memref, memref) outs(%arg2: memref) return } // ----- +func.func @matmul_to_vecmat_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> { + // CHECK-LABEL: @matmul_to_vecmat + // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32> + // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor + // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32> + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 + // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]] + // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]] + // CHECK-NEXT: %[[RESULT:.*]] = linalg.vecmat ins(%[[COLLAPSED_LHS]], %[[RHS]] : tensor, tensor) outs(%[[COLLAPSED_INIT]] : tensor) + // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]] + // CHECK-NEXT: 
%[[RES:.*]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]] + // CHECK-NEXT: return %[[RES]] + %0 = linalg.matmul ins(%arg0, %arg1: tensor<1x?xf32>, tensor) outs(%arg2: tensor<1x?xf32>) -> tensor<1x?xf32> + return %0 : tensor<1x?xf32> +} + +// ----- + func.func @batch_matmul_to_vecmat(%arg0: memref<1x1x?xf32>, %arg1: memref<1x?x?xf32>, %arg2: memref<1x1x?xf32>) { // CHECK-LABEL: @batch_matmul_to_vecmat // CHECK: linalg.vecmat @@ -131,7 +167,12 @@ func.func @batch_matmul_to_vecmat(%arg0: memref<1x1x?xf32>, %arg1: memref<1x?x?x func.func @matvec_to_dot(%arg0: memref<1x?xf32>, %arg1: memref, %arg2: memref<1xf32>) { // CHECK-LABEL: @matvec_to_dot - // CHECK: linalg.dot + // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<1x?xf32> + // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref + // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: memref<1xf32> + // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1]] + // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] [] + // CHECK-NEXT: linalg.dot ins(%[[COLLAPSED_LHS]], %[[RHS]] : memref, memref) outs(%[[COLLAPSED_INIT]] : memref) linalg.matvec ins(%arg0, %arg1: memref<1x?xf32>, memref) outs(%arg2: memref<1xf32>) return } @@ -156,24 +197,6 @@ func.func @matvec_to_dot_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor, %a // ----- -func.func @matmul_to_matvec_tensor(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - // CHECK-LABEL: @matmul_to_matvec_tensor - // CHECK: linalg.matvec - %0 = linalg.matmul ins(%arg0, %arg1: tensor, tensor) outs(%arg2: tensor) -> tensor - return %0 : tensor -} - -// ----- - -func.func @matmul_to_matvec(%arg0: memref, %arg1: memref, %arg2: memref) { - // CHECK-LABEL: @matmul_to_matvec - // CHECK: linalg.matvec - linalg.matmul ins(%arg0, %arg1: memref, memref) outs(%arg2: memref) - return -} - -// ----- - func.func @nonsingleton_batch_matmul(%arg0 : tensor<2x?x?xf32>, %arg1 : tensor<2x?x?xf32>, %arg2: tensor<2x?x?xf32>) -> tensor<2x?x?xf32> { // CHECK-LABEL: @nonsingleton_batch_matmul // CHECK-NOT: collapse_shape From 9b98efdf911f806f3ae15f2f9e74d99af1d8775e Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 19 Jun 2024 20:13:19 -0500 Subject: [PATCH 08/14] support transpose matmul conversion --- .../Linalg/Transforms/DropUnitDims.cpp | 119 ++++++++++-------- .../Linalg/rank-reduce-contraction-ops.mlir | 18 +++ 2 files changed, 84 insertions(+), 53 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 07b0cdea40c92..d9230c6127e00 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -812,6 +812,30 @@ void mlir::linalg::populateMoveInitOperandsToInputPattern( patterns.add(patterns.getContext()); } +namespace { +/// Pass that removes unit-extent dims within generic ops. 
+struct LinalgFoldUnitExtentDimsPass + : public impl::LinalgFoldUnitExtentDimsPassBase< + LinalgFoldUnitExtentDimsPass> { + using impl::LinalgFoldUnitExtentDimsPassBase< + LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase; + void runOnOperation() override { + Operation *op = getOperation(); + MLIRContext *context = op->getContext(); + RewritePatternSet patterns(context); + ControlDropUnitDims options; + if (useRankReducingSlices) { + options.rankReductionStrategy = linalg::ControlDropUnitDims:: + RankReductionStrategy::ExtractInsertSlice; + } + linalg::populateFoldUnitExtentDimsPatterns(patterns, options); + populateMoveInitOperandsToInputPattern(patterns); + (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); + } +}; + +} // namespace + namespace { static SmallVector @@ -855,20 +879,6 @@ static Value collapseTrailingSingletonDim(PatternRewriter &rewriter, ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape); } -static Value expandLeadingSingletonDim(PatternRewriter &rewriter, Value val, - RankedTensorType expandedType) { - return rewriter.create( - val.getLoc(), expandedType, val, - getReassociationsForLeadingDims(expandedType.getRank())); -} - -static Value expandTrailingSingletonDim(PatternRewriter &rewriter, Value val, - RankedTensorType expandedType) { - return rewriter.create( - val.getLoc(), expandedType, val, - getReassociationsForTrailingDims(expandedType.getRank())); -} - template struct RankReduceContractionOps : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -948,7 +958,9 @@ struct RankReduceBatched : RankReduceContractionOps { } Value expandResult(PatternRewriter &rewriter, Value result, RankedTensorType expandedType) const override { - return expandLeadingSingletonDim(rewriter, result, expandedType); + return rewriter.create( + result.getLoc(), expandedType, result, + getReassociationsForLeadingDims(expandedType.getRank())); } }; @@ -956,9 +968,15 @@ template struct RankReduceMatmul : RankReduceContractionOps { using RankReduceContractionOps::RankReduceContractionOps; - static bool constexpr reduceLeading = + static bool constexpr isTranspose = + std::is_same::value || + std::is_same::value; + + static bool constexpr reduceLeft = (std::is_same::value && std::is_same::value) || + (std::is_same::value && + std::is_same::value) || (std::is_same::value && std::is_same::value); @@ -966,30 +984,47 @@ struct RankReduceMatmul : RankReduceContractionOps { auto lhsType = cast(lhs.getType()); auto rhsType = cast(rhs.getType()); auto initType = cast(init.getType()); - if (reduceLeading) - return lhsType.getShape()[0] == 1 && initType.getShape()[0] == 1; + int constexpr offset = (int)isTranspose; + if (reduceLeft) + return lhsType.getShape().begin()[offset] == 1 && + initType.getShape().begin()[offset] == 1; else - return rhsType.getShape().back() == 1 && initType.getShape().back() == 1; + return rhsType.getShape().rbegin()[offset] == 1 && + initType.getShape().rbegin()[offset] == 1; } SmallVector collapseOperands(PatternRewriter &rewriter, Value lhs, Value rhs, Value init) const override { - if (reduceLeading) { - auto collapsedLhs = collapseLeadingSingletonDim(rewriter, lhs); - auto collapsedInit = collapseLeadingSingletonDim(rewriter, init); - return SmallVector{collapsedLhs, rhs, collapsedInit}; + if (reduceLeft) { + if (isTranspose) { + lhs = collapseTrailingSingletonDim(rewriter, lhs); + init = collapseTrailingSingletonDim(rewriter, init); + } else { + lhs = collapseLeadingSingletonDim(rewriter, lhs); + init = 
collapseLeadingSingletonDim(rewriter, init); + } } else { - auto collapsedRhs = collapseTrailingSingletonDim(rewriter, rhs); - auto collapsedInit = collapseTrailingSingletonDim(rewriter, init); - return SmallVector{lhs, collapsedRhs, collapsedInit}; + if (isTranspose) { + rhs = collapseLeadingSingletonDim(rewriter, rhs); + init = collapseLeadingSingletonDim(rewriter, init); + } else { + rhs = collapseTrailingSingletonDim(rewriter, rhs); + init = collapseTrailingSingletonDim(rewriter, init); + } } + return SmallVector{lhs, rhs, init}; } + Value expandResult(PatternRewriter &rewriter, Value result, RankedTensorType expandedType) const override { - if (reduceLeading) - return expandLeadingSingletonDim(rewriter, result, expandedType); + if (reduceLeft) + return rewriter.create( + result.getLoc(), expandedType, result, + getReassociationsForLeadingDims(expandedType.getRank())); else - return expandTrailingSingletonDim(rewriter, result, expandedType); + return rewriter.create( + result.getLoc(), expandedType, result, + getReassociationsForTrailingDims(expandedType.getRank())); } }; @@ -1007,30 +1042,8 @@ void mlir::linalg::populateContractionOpRankReducingPatterns( patterns.add>(context); patterns.add>(context); patterns.add>(context); + patterns.add>(context); + patterns.add>(context); patterns.add>(context); patterns.add>(context); } - -namespace { -/// Pass that removes unit-extent dims within generic ops. -struct LinalgFoldUnitExtentDimsPass - : public impl::LinalgFoldUnitExtentDimsPassBase< - LinalgFoldUnitExtentDimsPass> { - using impl::LinalgFoldUnitExtentDimsPassBase< - LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase; - void runOnOperation() override { - Operation *op = getOperation(); - MLIRContext *context = op->getContext(); - RewritePatternSet patterns(context); - ControlDropUnitDims options; - if (useRankReducingSlices) { - options.rankReductionStrategy = linalg::ControlDropUnitDims:: - RankReductionStrategy::ExtractInsertSlice; - } - linalg::populateFoldUnitExtentDimsPatterns(patterns, options); - populateMoveInitOperandsToInputPattern(patterns); - (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); - } -}; - -} // namespace diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir index 79003670d2726..0548f3f860a89 100644 --- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir +++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir @@ -197,6 +197,24 @@ func.func @matvec_to_dot_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor, %a // ----- +func.func @matmul_transpose_a_to_vecmat(%arg0: memref, %arg1: memref, %arg2: memref) { + // CHECK-LABEL: @matmul_transpose_a_to_vecmat + // CHECK: linalg.vecmat + linalg.matmul_transpose_a ins(%arg0, %arg1: memref, memref) outs(%arg2: memref) + return +} + +// ----- + +func.func @matmul_transpose_b_to_matvec(%arg0: memref, %arg1: memref<1x?xf32>, %arg2: memref<1x?xf32>) { + // CHECK-LABEL: @matmul_transpose_b_to_matvec + // CHECK: linalg.matvec + linalg.matmul_transpose_b ins(%arg0, %arg1: memref, memref<1x?xf32>) outs(%arg2: memref<1x?xf32>) + return +} + +// ----- + func.func @nonsingleton_batch_matmul(%arg0 : tensor<2x?x?xf32>, %arg1 : tensor<2x?x?xf32>, %arg2: tensor<2x?x?xf32>) -> tensor<2x?x?xf32> { // CHECK-LABEL: @nonsingleton_batch_matmul // CHECK-NOT: collapse_shape From cbf8eddc377b7970cd12846a8d95ea740b5e3bf1 Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 19 Jun 2024 20:28:48 -0500 Subject: [PATCH 09/14] const conditionals --- 
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index d9230c6127e00..327c46b965c87 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -985,7 +985,7 @@ struct RankReduceMatmul : RankReduceContractionOps { auto rhsType = cast(rhs.getType()); auto initType = cast(init.getType()); int constexpr offset = (int)isTranspose; - if (reduceLeft) + if constexpr (reduceLeft) return lhsType.getShape().begin()[offset] == 1 && initType.getShape().begin()[offset] == 1; else @@ -995,8 +995,8 @@ struct RankReduceMatmul : RankReduceContractionOps { SmallVector collapseOperands(PatternRewriter &rewriter, Value lhs, Value rhs, Value init) const override { - if (reduceLeft) { - if (isTranspose) { + if constexpr (reduceLeft) { + if constexpr (isTranspose) { lhs = collapseTrailingSingletonDim(rewriter, lhs); init = collapseTrailingSingletonDim(rewriter, init); } else { @@ -1004,7 +1004,7 @@ struct RankReduceMatmul : RankReduceContractionOps { init = collapseLeadingSingletonDim(rewriter, init); } } else { - if (isTranspose) { + if constexpr (isTranspose) { rhs = collapseLeadingSingletonDim(rewriter, rhs); init = collapseLeadingSingletonDim(rewriter, init); } else { @@ -1017,7 +1017,7 @@ struct RankReduceMatmul : RankReduceContractionOps { Value expandResult(PatternRewriter &rewriter, Value result, RankedTensorType expandedType) const override { - if (reduceLeft) + if constexpr (reduceLeft) return rewriter.create( result.getLoc(), expandedType, result, getReassociationsForLeadingDims(expandedType.getRank())); From 29ac128442379767bca2583409f58bd605bb06d1 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 20 Jun 2024 14:19:38 -0500 Subject: [PATCH 10/14] refactor and add more patterns --- .../Linalg/Transforms/DropUnitDims.cpp | 143 +++++++++++------- 1 file changed, 85 insertions(+), 58 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 327c46b965c87..26a85e8225b24 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -839,43 +839,31 @@ struct LinalgFoldUnitExtentDimsPass namespace { static SmallVector -getReassociationsForTrailingDims(int64_t rank) { - SmallVector reassociation(rank - 1, {}); - if (rank > 1) { - reassociation[rank - 2] = - (rank == 1) ? ReassociationIndices{0} : ReassociationIndices{0, 1}; - for (int64_t i = 0; i < rank - 2; i++) - reassociation[i] = {i}; - } - return reassociation; -} - -static SmallVector -getReassociationsForLeadingDims(int64_t rank) { - SmallVector reassociation(rank - 1, {}); - if (rank > 1) { - reassociation[0] = - (rank == 1) ? 
ReassociationIndices{0} : ReassociationIndices{0, 1}; - for (int64_t i = 1; i < rank - 1; i++) - reassociation[i] = {i + rank - 2}; +getReassociationForReshapeAtDim(int64_t rank, int64_t pos, + bool fromRight = false) { + SmallVector reassociation(rank - 1, {0, 1}); + if (rank > 2) { + int64_t offsetPos = pos - (int64_t)fromRight; + for (int64_t i = 0; i < rank - 1; i++) { + if (i == offsetPos) + reassociation[i] = ReassociationIndices{i, i + 1}; + else if (i < offsetPos) + reassociation[i] = ReassociationIndices{i}; + else + reassociation[i] = ReassociationIndices{i + offsetPos + 1}; + } } return reassociation; } -static Value collapseLeadingSingletonDim(PatternRewriter &rewriter, Value val) { +static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val, + int64_t pos, bool fromRight = false) { auto valType = cast(val.getType()); + SmallVector collapsedShape(valType.getShape()); + collapsedShape.erase(collapsedShape.begin() + pos); return collapseValue( - rewriter, val.getLoc(), val, valType.getShape().drop_front(1), - getReassociationsForLeadingDims(valType.getRank()), - ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape); -} - -static Value collapseTrailingSingletonDim(PatternRewriter &rewriter, - Value val) { - auto valType = cast(val.getType()); - return collapseValue( - rewriter, val.getLoc(), val, valType.getShape().drop_back(1), - getReassociationsForTrailingDims(valType.getRank()), + rewriter, val.getLoc(), val, collapsedShape, + getReassociationForReshapeAtDim(valType.getRank(), pos, fromRight), ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape); } @@ -951,16 +939,16 @@ struct RankReduceBatched : RankReduceContractionOps { SmallVector collapseOperands(PatternRewriter &rewriter, Value lhs, Value rhs, Value init) const override { - auto collapsedLhs = collapseLeadingSingletonDim(rewriter, lhs); - auto collapsedRhs = collapseLeadingSingletonDim(rewriter, rhs); - auto collapsedInit = collapseLeadingSingletonDim(rewriter, init); + auto collapsedLhs = collapseSingletonDimAt(rewriter, lhs, 0); + auto collapsedRhs = collapseSingletonDimAt(rewriter, rhs, 0); + auto collapsedInit = collapseSingletonDimAt(rewriter, init, 0); return SmallVector{collapsedLhs, collapsedRhs, collapsedInit}; } Value expandResult(PatternRewriter &rewriter, Value result, RankedTensorType expandedType) const override { - return rewriter.create( - result.getLoc(), expandedType, result, - getReassociationsForLeadingDims(expandedType.getRank())); + return rewriter.create( + result.getLoc(), expandedType, result, + getReassociationForReshapeAtDim(expandedType.getRank(), 0)); } }; @@ -968,11 +956,32 @@ template struct RankReduceMatmul : RankReduceContractionOps { using RankReduceContractionOps::RankReduceContractionOps; + static bool constexpr isBatched = + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value; + + static bool constexpr isLHSTransposed = + std::is_same::value || + std::is_same::value; + + static bool constexpr isRHSTransposed = + std::is_same::value || + std::is_same::value; + static bool constexpr isTranspose = + std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value; static bool constexpr reduceLeft = + (std::is_same::value && + std::is_same::value) || + (std::is_same::value && + std::is_same::value) || (std::is_same::value && std::is_same::value) || (std::is_same::value && @@ -980,37 +989,41 @@ struct RankReduceMatmul : RankReduceContractionOps { 
(std::is_same::value && std::is_same::value); + static int constexpr lhsTransposeOffset = (int)isLHSTransposed; + static int constexpr rhsTransposeOffset = (int)isRHSTransposed; + static int constexpr batchOffset = (int)isBatched; + bool checkTypes(Value lhs, Value rhs, Value init) const override { auto lhsType = cast(lhs.getType()); auto rhsType = cast(rhs.getType()); auto initType = cast(init.getType()); - int constexpr offset = (int)isTranspose; if constexpr (reduceLeft) - return lhsType.getShape().begin()[offset] == 1 && - initType.getShape().begin()[offset] == 1; + return lhsType.getShape().begin()[lhsTransposeOffset + batchOffset] == + 1 && + initType.getShape().begin()[lhsTransposeOffset + batchOffset] == 1; else - return rhsType.getShape().rbegin()[offset] == 1 && - initType.getShape().rbegin()[offset] == 1; + return rhsType.getShape().rbegin()[rhsTransposeOffset] == 1 && + initType.getShape().rbegin()[rhsTransposeOffset] == 1; } SmallVector collapseOperands(PatternRewriter &rewriter, Value lhs, Value rhs, Value init) const override { + if constexpr (reduceLeft) { - if constexpr (isTranspose) { - lhs = collapseTrailingSingletonDim(rewriter, lhs); - init = collapseTrailingSingletonDim(rewriter, init); - } else { - lhs = collapseLeadingSingletonDim(rewriter, lhs); - init = collapseLeadingSingletonDim(rewriter, init); - } + lhs = collapseSingletonDimAt(rewriter, lhs, + lhsTransposeOffset + batchOffset, + /*fromRight=*/isLHSTransposed); + init = collapseSingletonDimAt(rewriter, init, + lhsTransposeOffset + batchOffset, + /*fromRight*/ isLHSTransposed); } else { - if constexpr (isTranspose) { - rhs = collapseLeadingSingletonDim(rewriter, rhs); - init = collapseLeadingSingletonDim(rewriter, init); - } else { - rhs = collapseTrailingSingletonDim(rewriter, rhs); - init = collapseTrailingSingletonDim(rewriter, init); - } + auto rhsRank = cast(rhs.getType()).getRank(); + auto initRank = cast(init.getType()).getRank(); + rhs = collapseSingletonDimAt( + rewriter, rhs, rhsRank - rhsTransposeOffset - 1, /*fromRight=*/true); + init = collapseSingletonDimAt(rewriter, init, + initRank - rhsTransposeOffset - 1, + /*fromRight=*/true); } return SmallVector{lhs, rhs, init}; } @@ -1020,11 +1033,13 @@ struct RankReduceMatmul : RankReduceContractionOps { if constexpr (reduceLeft) return rewriter.create( result.getLoc(), expandedType, result, - getReassociationsForLeadingDims(expandedType.getRank())); + getReassociationForReshapeAtDim(expandedType.getRank(), 0)); else return rewriter.create( result.getLoc(), expandedType, result, - getReassociationsForTrailingDims(expandedType.getRank())); + getReassociationForReshapeAtDim(expandedType.getRank(), + expandedType.getRank() - 1, + /*fromRight=*/true)); } }; @@ -1033,6 +1048,7 @@ struct RankReduceMatmul : RankReduceContractionOps { void mlir::linalg::populateContractionOpRankReducingPatterns( RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); + // Unbatching patterns for unit batch size patterns.add>(context); patterns.add>( context); @@ -1040,10 +1056,21 @@ void mlir::linalg::populateContractionOpRankReducingPatterns( context); patterns.add>(context); patterns.add>(context); + + // Non-batch rank 1 reducing patterns patterns.add>(context); patterns.add>(context); patterns.add>(context); patterns.add>(context); + // Batch rank 1 reducing patterns + patterns.add>(context); + patterns.add>(context); + patterns.add>( + context); + patterns.add>( + context); + + // Non-batch rank 0 reducing patterns patterns.add>(context); 
patterns.add>(context); } From 7851ae12b78d9b4fbb947de3a6c9fa1654fe0ab6 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 20 Jun 2024 19:13:36 -0500 Subject: [PATCH 11/14] more refactor --- .../Linalg/Transforms/DropUnitDims.cpp | 229 +++++++++--------- .../Linalg/rank-reduce-contraction-ops.mlir | 55 ++--- 2 files changed, 141 insertions(+), 143 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 26a85e8225b24..771b40d6a2001 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -839,31 +839,32 @@ struct LinalgFoldUnitExtentDimsPass namespace { static SmallVector -getReassociationForReshapeAtDim(int64_t rank, int64_t pos, - bool fromRight = false) { +getReassociationForReshapeAtDim(int64_t rank, int64_t pos) { SmallVector reassociation(rank - 1, {0, 1}); + auto lastDim = pos == rank - 1; if (rank > 2) { - int64_t offsetPos = pos - (int64_t)fromRight; for (int64_t i = 0; i < rank - 1; i++) { - if (i == offsetPos) + if (i == pos || (lastDim && i == pos - 1)) reassociation[i] = ReassociationIndices{i, i + 1}; - else if (i < offsetPos) + else if (i < pos) reassociation[i] = ReassociationIndices{i}; else - reassociation[i] = ReassociationIndices{i + offsetPos + 1}; + reassociation[i] = ReassociationIndices{i + 1}; } } return reassociation; } static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val, - int64_t pos, bool fromRight = false) { + int64_t pos) { + if (pos < 0) + return val; auto valType = cast(val.getType()); SmallVector collapsedShape(valType.getShape()); collapsedShape.erase(collapsedShape.begin() + pos); return collapseValue( rewriter, val.getLoc(), val, collapsedShape, - getReassociationForReshapeAtDim(valType.getRank(), pos, fromRight), + getReassociationForReshapeAtDim(valType.getRank(), pos), ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape); } @@ -871,24 +872,52 @@ template struct RankReduceContractionOps : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(FromOpTy batchMatmulOp, + SmallVector + collapseOperands(PatternRewriter &rewriter, ArrayRef operands, + ArrayRef operandCollapseDims) const { + assert(operandCollapseDims.size() == 3 && operands.size() == 3 && + "expected 3 operands and dims"); + return llvm::to_vector(llvm::map_range( + llvm::zip(operands, operandCollapseDims), [&](auto pair) { + return collapseSingletonDimAt(rewriter, std::get<0>(pair), + std::get<1>(pair)); + })); + } + + Value expandResult(PatternRewriter &rewriter, Value result, + RankedTensorType expandedType, int64_t dim) const { + return rewriter.create( + result.getLoc(), expandedType, result, + getReassociationForReshapeAtDim(expandedType.getRank(), dim)); + } + + LogicalResult matchAndRewrite(FromOpTy contractionOp, PatternRewriter &rewriter) const override { - auto loc = batchMatmulOp.getLoc(); - auto inputs = batchMatmulOp.getDpsInputs(); - auto inits = batchMatmulOp.getDpsInits(); + auto loc = contractionOp.getLoc(); + auto inputs = contractionOp.getDpsInputs(); + auto inits = contractionOp.getDpsInits(); if (inputs.size() != 2 || inits.size() != 1) - return rewriter.notifyMatchFailure(batchMatmulOp, + return rewriter.notifyMatchFailure(contractionOp, "expected 2 inputs and 1 init"); auto lhs = inputs[0]; auto rhs = inputs[1]; auto init = inits[0]; + SmallVector operands{lhs, rhs, init}; - if (!checkTypes(lhs, rhs, init)) - return 
rewriter.notifyMatchFailure(batchMatmulOp, + auto maybeContractionDims = inferContractionDims(contractionOp); + if (failed(maybeContractionDims)) + return rewriter.notifyMatchFailure(contractionOp, + "could not infer contraction dims"); + + auto contractionDims = maybeContractionDims.value(); + SmallVector operandUnitDims; + if (failed(getOperandUnitDims(contractionOp, operandUnitDims))) + return rewriter.notifyMatchFailure(contractionOp, "no reducable dims found"); - auto collapsedOperands = collapseOperands(rewriter, lhs, rhs, init); + auto collapsedOperands = + collapseOperands(rewriter, operands, operandUnitDims); auto collapsedLhs = collapsedOperands[0]; auto collapsedRhs = collapsedOperands[1]; auto collapsedInit = collapsedOperands[2]; @@ -898,57 +927,63 @@ struct RankReduceContractionOps : OpRewritePattern { auto collapsedOp = rewriter.create( loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs}, ValueRange{collapsedInit}); - for (auto attr : batchMatmulOp->getAttrs()) { + for (auto attr : contractionOp->getAttrs()) { if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName) continue; collapsedOp->setAttr(attr.getName(), attr.getValue()); } - auto results = batchMatmulOp.getResults(); + auto results = contractionOp.getResults(); assert(results.size() < 2 && "expected at most one result"); if (results.size() < 1) - rewriter.replaceOp(batchMatmulOp, collapsedOp); + rewriter.replaceOp(contractionOp, collapsedOp); else rewriter.replaceOp( - batchMatmulOp, + contractionOp, expandResult(rewriter, collapsedOp.getResultTensors()[0], - cast(results[0].getType()))); + cast(results[0].getType()), + operandUnitDims[2])); return success(); } - virtual bool checkTypes(Value lhs, Value rhs, Value init) const = 0; - virtual SmallVector collapseOperands(PatternRewriter &rewriter, - Value lhs, Value rhs, - Value init) const = 0; - virtual Value expandResult(PatternRewriter &rewriter, Value result, - RankedTensorType expandedType) const = 0; + virtual LogicalResult + getOperandUnitDims(LinalgOp op, + SmallVectorImpl &operandUnitDindices) const = 0; }; template struct RankReduceBatched : RankReduceContractionOps { using RankReduceContractionOps::RankReduceContractionOps; - bool checkTypes(Value lhs, Value rhs, Value init) const override { - auto lhsType = cast(lhs.getType()); - auto rhsType = cast(rhs.getType()); - auto initType = cast(init.getType()); - return lhsType.getShape()[0] == 1 && rhsType.getShape()[0] == 1 && - initType.getShape()[0] == 1; - } + LogicalResult getOperandUnitDims( + LinalgOp op, + SmallVectorImpl &operandUnitDindices) const override { + auto inputs = op.getDpsInputs(); + auto inits = op.getDpsInits(); + if (inputs.size() != 2 || inits.size() != 1) + return failure(); - SmallVector collapseOperands(PatternRewriter &rewriter, Value lhs, - Value rhs, Value init) const override { - auto collapsedLhs = collapseSingletonDimAt(rewriter, lhs, 0); - auto collapsedRhs = collapseSingletonDimAt(rewriter, rhs, 0); - auto collapsedInit = collapseSingletonDimAt(rewriter, init, 0); - return SmallVector{collapsedLhs, collapsedRhs, collapsedInit}; - } - Value expandResult(PatternRewriter &rewriter, Value result, - RankedTensorType expandedType) const override { - return rewriter.create( - result.getLoc(), expandedType, result, - getReassociationForReshapeAtDim(expandedType.getRank(), 0)); + auto maybeContractionDims = inferContractionDims(op); + if (failed(maybeContractionDims)) + return failure(); + auto contractionDims = maybeContractionDims.value(); + + if 
(contractionDims.batch.size() != 1) + return failure(); + auto batchDim = contractionDims.batch[0]; + SmallVector, 2> bOperands; + op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands); + if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) { + return cast(std::get<0>(pair).getType()) + .getShape()[std::get<1>(pair)] != 1; + })) + return failure(); + + operandUnitDindices = SmallVector{std::get<1>(bOperands[0]), + std::get<1>(bOperands[1]), + std::get<1>(bOperands[2])}; + return success(); } }; @@ -956,27 +991,6 @@ template struct RankReduceMatmul : RankReduceContractionOps { using RankReduceContractionOps::RankReduceContractionOps; - static bool constexpr isBatched = - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value; - - static bool constexpr isLHSTransposed = - std::is_same::value || - std::is_same::value; - - static bool constexpr isRHSTransposed = - std::is_same::value || - std::is_same::value; - - static bool constexpr isTranspose = - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value; - static bool constexpr reduceLeft = (std::is_same::value && std::is_same::value) || @@ -989,57 +1003,44 @@ struct RankReduceMatmul : RankReduceContractionOps { (std::is_same::value && std::is_same::value); - static int constexpr lhsTransposeOffset = (int)isLHSTransposed; - static int constexpr rhsTransposeOffset = (int)isRHSTransposed; - static int constexpr batchOffset = (int)isBatched; - - bool checkTypes(Value lhs, Value rhs, Value init) const override { - auto lhsType = cast(lhs.getType()); - auto rhsType = cast(rhs.getType()); - auto initType = cast(init.getType()); - if constexpr (reduceLeft) - return lhsType.getShape().begin()[lhsTransposeOffset + batchOffset] == - 1 && - initType.getShape().begin()[lhsTransposeOffset + batchOffset] == 1; - else - return rhsType.getShape().rbegin()[rhsTransposeOffset] == 1 && - initType.getShape().rbegin()[rhsTransposeOffset] == 1; - } - - SmallVector collapseOperands(PatternRewriter &rewriter, Value lhs, - Value rhs, Value init) const override { + LogicalResult getOperandUnitDims( + LinalgOp op, + SmallVectorImpl &operandUnitDindices) const override { + auto maybeContractionDims = inferContractionDims(op); + if (failed(maybeContractionDims)) + return failure(); + auto contractionDims = maybeContractionDims.value(); if constexpr (reduceLeft) { - lhs = collapseSingletonDimAt(rewriter, lhs, - lhsTransposeOffset + batchOffset, - /*fromRight=*/isLHSTransposed); - init = collapseSingletonDimAt(rewriter, init, - lhsTransposeOffset + batchOffset, - /*fromRight*/ isLHSTransposed); + auto m = contractionDims.m[0]; + SmallVector, 2> mOperands; + op.mapIterationSpaceDimToAllOperandDims(m, mOperands); + if (mOperands.size() != 2) + return failure(); + if (llvm::all_of(mOperands, [](auto pair) { + return cast(std::get<0>(pair).getType()) + .getShape()[std::get<1>(pair)] == 1; + })) { + operandUnitDindices = SmallVector{ + std::get<1>(mOperands[0]), -1, std::get<1>(mOperands[1])}; + return success(); + } } else { - auto rhsRank = cast(rhs.getType()).getRank(); - auto initRank = cast(init.getType()).getRank(); - rhs = collapseSingletonDimAt( - rewriter, rhs, rhsRank - rhsTransposeOffset - 1, /*fromRight=*/true); - init = collapseSingletonDimAt(rewriter, init, - initRank - rhsTransposeOffset - 1, - /*fromRight=*/true); + auto n = contractionDims.n[0]; + SmallVector, 2> nOperands; + op.mapIterationSpaceDimToAllOperandDims(n, nOperands); + if 
(nOperands.size() != 2) + return failure(); + if (llvm::all_of(nOperands, [](auto pair) { + return cast(std::get<0>(pair).getType()) + .getShape()[std::get<1>(pair)] == 1; + })) { + operandUnitDindices = SmallVector{ + -1, std::get<1>(nOperands[0]), std::get<1>(nOperands[1])}; + return success(); + } } - return SmallVector{lhs, rhs, init}; - } - - Value expandResult(PatternRewriter &rewriter, Value result, - RankedTensorType expandedType) const override { - if constexpr (reduceLeft) - return rewriter.create( - result.getLoc(), expandedType, result, - getReassociationForReshapeAtDim(expandedType.getRank(), 0)); - else - return rewriter.create( - result.getLoc(), expandedType, result, - getReassociationForReshapeAtDim(expandedType.getRank(), - expandedType.getRank() - 1, - /*fromRight=*/true)); + return failure(); } }; diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir index 0548f3f860a89..70568be99474e 100644 --- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir +++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir @@ -1,23 +1,19 @@ //RUN: mlir-opt -test-linalg-rank-reduce-contraction-ops --canonicalize -split-input-file %s | FileCheck %s -func.func @singleton_batch_matmul_tensor(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?x?xf32>) -> tensor<1x?x?xf32> { +func.func @singleton_batch_matmul_tensor(%arg0 : tensor<1x128x512xf32>, %arg1 : tensor<1x512x256xf32>, %arg2: tensor<1x128x256xf32>) -> tensor<1x128x256xf32> { // CHECK-LABEL: @singleton_batch_matmul_tensor - // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32> - // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32> - // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?x?xf32> - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 - // CHECK-DAG: %[[C2:.*]] = arith.constant 2 + // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x128x512xf32> + // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x512x256xf32> + // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x128x256xf32> // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]] // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]] // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]] - // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor, tensor) outs(%[[COLLAPSED_INIT]] : tensor) - // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]] - // CHECK-NEXT: %[[DIM2:.*]] = tensor.dim %[[INIT]], %[[C2]] - // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1], [2]] output_shape [1, %[[DIM1]], %[[DIM2]]] + // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<128x512xf32>, tensor<512x256xf32>) outs(%[[COLLAPSED_INIT]] : tensor<128x256xf32>) + // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1], [2]] output_shape [1, 128, 256] // CHECK-NEXT: return %[[RES]] - %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?x?xf32>) - outs(%arg2 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32> - return %1 : tensor<1x?x?xf32> + %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x128x512xf32>, tensor<1x512x256xf32>) + outs(%arg2 : tensor<1x128x256xf32>) -> tensor<1x128x256xf32> + return %1 : tensor<1x128x256xf32> } // ----- @@ -39,22 +35,20 @@ func.func @singleton_batch_matmul_memref(%arg0 : memref<1x?x?xf32>, %arg1 : memr // ----- -func.func 
@singleton_batch_matvec(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> { +func.func @singleton_batch_matvec(%arg0 : tensor<1x128x512xf32>, %arg1 : tensor<1x512xf32>, %arg2: tensor<1x128xf32>) -> tensor<1x128xf32> { // CHECK-LABEL: @singleton_batch_matvec - // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32> - // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?xf32> - // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32> - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 + // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x128x512xf32> + // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x512xf32> + // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x128xf32> // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]] // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]] // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]] - // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor, tensor) outs(%[[COLLAPSED_INIT]] : tensor) - // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]] - // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]] + // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<128x512xf32>, tensor<512xf32>) outs(%[[COLLAPSED_INIT]] : tensor<128xf32>) + // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, 128] // CHECK-NEXT: return %[[RES]] - %1 = linalg.batch_matvec ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?xf32>) - outs(%arg2 : tensor<1x?xf32>) -> tensor<1x?xf32> - return %1 : tensor<1x?xf32> + %1 = linalg.batch_matvec ins(%arg0, %arg1 : tensor<1x128x512xf32>, tensor<1x512xf32>) + outs(%arg2 : tensor<1x128xf32>) -> tensor<1x128xf32> + return %1 : tensor<1x128xf32> } // ----- @@ -197,19 +191,22 @@ func.func @matvec_to_dot_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor, %a // ----- -func.func @matmul_transpose_a_to_vecmat(%arg0: memref, %arg1: memref, %arg2: memref) { +func.func @matmul_transpose_a_to_vecmat(%arg0: tensor<256x1xf32>, %arg1: tensor<256x512xf32>, %arg2: tensor<1x512xf32>) -> tensor<1x512xf32> { // CHECK-LABEL: @matmul_transpose_a_to_vecmat + // CHECK: collapse_shape {{.*}} into tensor<256xf32> + // CHECK: collapse_shape {{.*}} into tensor<512xf32> // CHECK: linalg.vecmat - linalg.matmul_transpose_a ins(%arg0, %arg1: memref, memref) outs(%arg2: memref) - return + // CHECK: expand_shape {{.*}} into tensor<1x512xf32> + %0 = linalg.matmul_transpose_a ins(%arg0, %arg1: tensor<256x1xf32>, tensor<256x512xf32>) outs(%arg2: tensor<1x512xf32>) -> tensor<1x512xf32> + return %0 : tensor<1x512xf32> } // ----- -func.func @matmul_transpose_b_to_matvec(%arg0: memref, %arg1: memref<1x?xf32>, %arg2: memref<1x?xf32>) { +func.func @matmul_transpose_b_to_matvec(%arg0: memref, %arg1: memref<1x?xf32>, %arg2: memref) { // CHECK-LABEL: @matmul_transpose_b_to_matvec // CHECK: linalg.matvec - linalg.matmul_transpose_b ins(%arg0, %arg1: memref, memref<1x?xf32>) outs(%arg2: memref<1x?xf32>) + linalg.matmul_transpose_b ins(%arg0, %arg1: memref, memref<1x?xf32>) outs(%arg2: memref) return } From acca39bc577ec6077ef71740b4176ccf65826a05 Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 21 Jun 2024 11:14:03 -0500 Subject: [PATCH 12/14] address comments and extra cleanup --- .../Dialect/Linalg/Transforms/Transforms.h | 4 +- .../Linalg/Transforms/DropUnitDims.cpp | 126 ++++++++++-------- 
.../Linalg/rank-reduce-contraction-ops.mlir | 9 ++ .../TestLinalgRankReduceContractionOps.cpp | 3 +- 4 files changed, 84 insertions(+), 58 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index c49383c600a57..3682a68b0e2c8 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1692,8 +1692,8 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns, void populateBlockPackMatmulPatterns(RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn); -/// Adds patterns that that reduce the rank of named contraction ops that have -/// unit dimensions in the operand(s) by converting to a senquence of `collapse_shape`, +/// Adds patterns that reduce the rank of named contraction ops that have +/// unit dimensions in the operand(s) by converting to a sequence of `collapse_shape`, /// ``, `expand_shape` (if on tensors). For example a /// `linalg.batch_matmul` with unit batch size will convert to `linalg.matmul` /// and a `linalg.matvec` with with unit spatial dim in lhs will convert to a `linalg.dot`. diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 771b40d6a2001..e1daeb3ad666e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -838,10 +838,12 @@ struct LinalgFoldUnitExtentDimsPass namespace { +/// Returns reassociation indices for collapsing/expanding a +/// tensor of rank `rank` at position `pos`. static SmallVector getReassociationForReshapeAtDim(int64_t rank, int64_t pos) { SmallVector reassociation(rank - 1, {0, 1}); - auto lastDim = pos == rank - 1; + bool lastDim = pos == rank - 1; if (rank > 2) { for (int64_t i = 0; i < rank - 1; i++) { if (i == pos || (lastDim && i == pos - 1)) @@ -855,6 +857,8 @@ getReassociationForReshapeAtDim(int64_t rank, int64_t pos) { return reassociation; } +/// Returns a collapsed `val` where the collapsing occurs at dim `pos`. +/// If `pos < 0`, then don't collapse. static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val, int64_t pos) { if (pos < 0) @@ -868,22 +872,30 @@ static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val, ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape); } +/// Base class for all rank reduction patterns for contraction ops +/// with unit dimensions. All patterns should convert one named op +/// to another named op. Intended to reduce only one iteration space dim +/// at a time. +/// Reducing multiple dims will happen with recusive application of +/// pattern rewrites. template struct RankReduceContractionOps : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - SmallVector + /// Collapse all collapsable operands. + SmallVector collapseOperands(PatternRewriter &rewriter, ArrayRef operands, ArrayRef operandCollapseDims) const { assert(operandCollapseDims.size() == 3 && operands.size() == 3 && "expected 3 operands and dims"); - return llvm::to_vector(llvm::map_range( + return llvm::map_to_vector( llvm::zip(operands, operandCollapseDims), [&](auto pair) { return collapseSingletonDimAt(rewriter, std::get<0>(pair), std::get<1>(pair)); - })); + }); } + /// Expand result tensor. 
Value expandResult(PatternRewriter &rewriter, Value result, RankedTensorType expandedType, int64_t dim) const { return rewriter.create( @@ -905,12 +917,6 @@ struct RankReduceContractionOps : OpRewritePattern { auto init = inits[0]; SmallVector operands{lhs, rhs, init}; - auto maybeContractionDims = inferContractionDims(contractionOp); - if (failed(maybeContractionDims)) - return rewriter.notifyMatchFailure(contractionOp, - "could not infer contraction dims"); - - auto contractionDims = maybeContractionDims.value(); SmallVector operandUnitDims; if (failed(getOperandUnitDims(contractionOp, operandUnitDims))) return rewriter.notifyMatchFailure(contractionOp, @@ -935,80 +941,89 @@ struct RankReduceContractionOps : OpRewritePattern { auto results = contractionOp.getResults(); assert(results.size() < 2 && "expected at most one result"); - if (results.size() < 1) + if (results.empty()) { rewriter.replaceOp(contractionOp, collapsedOp); - else + } else { rewriter.replaceOp( contractionOp, expandResult(rewriter, collapsedOp.getResultTensors()[0], cast(results[0].getType()), operandUnitDims[2])); + } return success(); } + /// Populate `operandUnitDims` with 3 indices indicating the unit dim + /// for each operand that should be collapsed in this pattern. If an + /// operand shouldn't be collapsed, the index should be negative. virtual LogicalResult getOperandUnitDims(LinalgOp op, - SmallVectorImpl &operandUnitDindices) const = 0; + SmallVectorImpl &operandUnitDims) const = 0; }; +/// Patterns for unbatching batched contraction ops template -struct RankReduceBatched : RankReduceContractionOps { +struct RankReduceToUnBatched : RankReduceContractionOps { using RankReduceContractionOps::RankReduceContractionOps; - LogicalResult getOperandUnitDims( - LinalgOp op, - SmallVectorImpl &operandUnitDindices) const override { - auto inputs = op.getDpsInputs(); - auto inits = op.getDpsInits(); - if (inputs.size() != 2 || inits.size() != 1) - return failure(); - + /// Look for unit batch dims to collapse. + LogicalResult + getOperandUnitDims(LinalgOp op, + SmallVectorImpl &operandUnitDims) const override { auto maybeContractionDims = inferContractionDims(op); - if (failed(maybeContractionDims)) + if (failed(maybeContractionDims)) { + LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims"); return failure(); + } auto contractionDims = maybeContractionDims.value(); if (contractionDims.batch.size() != 1) return failure(); auto batchDim = contractionDims.batch[0]; - SmallVector, 2> bOperands; + SmallVector, 3> bOperands; op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands); if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) { return cast(std::get<0>(pair).getType()) .getShape()[std::get<1>(pair)] != 1; - })) + })) { + LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found"); return failure(); + } - operandUnitDindices = SmallVector{std::get<1>(bOperands[0]), - std::get<1>(bOperands[1]), - std::get<1>(bOperands[2])}; + operandUnitDims = SmallVector{std::get<1>(bOperands[0]), + std::get<1>(bOperands[1]), + std::get<1>(bOperands[2])}; return success(); } }; +/// Patterns for reducing non-batch dimensions template struct RankReduceMatmul : RankReduceContractionOps { using RankReduceContractionOps::RankReduceContractionOps; + /// Helper for determining whether the lhs/init or rhs/init are reduced. 
static bool constexpr reduceLeft = - (std::is_same::value && - std::is_same::value) || - (std::is_same::value && - std::is_same::value) || - (std::is_same::value && - std::is_same::value) || - (std::is_same::value && - std::is_same::value) || - (std::is_same::value && - std::is_same::value); - - LogicalResult getOperandUnitDims( - LinalgOp op, - SmallVectorImpl &operandUnitDindices) const override { + (std::is_same_v && + std::is_same_v) || + (std::is_same_v && + std::is_same_v) || + (std::is_same_v && + std::is_same_v) || + (std::is_same_v && + std::is_same_v) || + (std::is_same_v && std::is_same_v); + + /// Look for non-batch spatial dims to collapse. + LogicalResult + getOperandUnitDims(LinalgOp op, + SmallVectorImpl &operandUnitDims) const override { auto maybeContractionDims = inferContractionDims(op); - if (failed(maybeContractionDims)) + if (failed(maybeContractionDims)) { + LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims"); return failure(); + } auto contractionDims = maybeContractionDims.value(); if constexpr (reduceLeft) { @@ -1021,8 +1036,8 @@ struct RankReduceMatmul : RankReduceContractionOps { return cast(std::get<0>(pair).getType()) .getShape()[std::get<1>(pair)] == 1; })) { - operandUnitDindices = SmallVector{ - std::get<1>(mOperands[0]), -1, std::get<1>(mOperands[1])}; + operandUnitDims = SmallVector{std::get<1>(mOperands[0]), -1, + std::get<1>(mOperands[1])}; return success(); } } else { @@ -1035,11 +1050,12 @@ struct RankReduceMatmul : RankReduceContractionOps { return cast(std::get<0>(pair).getType()) .getShape()[std::get<1>(pair)] == 1; })) { - operandUnitDindices = SmallVector{ - -1, std::get<1>(nOperands[0]), std::get<1>(nOperands[1])}; + operandUnitDims = SmallVector{-1, std::get<1>(nOperands[0]), + std::get<1>(nOperands[1])}; return success(); } } + LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found"); return failure(); } }; @@ -1050,13 +1066,15 @@ void mlir::linalg::populateContractionOpRankReducingPatterns( RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); // Unbatching patterns for unit batch size - patterns.add>(context); - patterns.add>( - context); - patterns.add>( - context); - patterns.add>(context); - patterns.add>(context); + patterns.add>(context); + patterns + .add>( + context); + patterns + .add>( + context); + patterns.add>(context); + patterns.add>(context); // Non-batch rank 1 reducing patterns patterns.add>(context); diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir index 70568be99474e..fd59a4a52e378 100644 --- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir +++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir @@ -212,6 +212,15 @@ func.func @matmul_transpose_b_to_matvec(%arg0: memref, %arg1: memref<1x // ----- +func.func @batchmatmul_transpose_b_to_to_dot(%arg0: tensor<1x1x?xf32>, %arg1: tensor<1x1x?xf32>, %arg2: tensor<1x1x1xf32>) -> tensor<1x1x1xf32> { + // CHECK-LABEL: @batchmatmul_transpose_b_to_to_dot + // CHECK: linalg.dot + %0 = linalg.batch_matmul_transpose_b ins(%arg0, %arg1: tensor<1x1x?xf32>, tensor<1x1x?xf32>) outs(%arg2: tensor<1x1x1xf32>) -> tensor<1x1x1xf32> + return %0 : tensor<1x1x1xf32> +} + +// ----- + func.func @nonsingleton_batch_matmul(%arg0 : tensor<2x?x?xf32>, %arg1 : tensor<2x?x?xf32>, %arg2: tensor<2x?x?xf32>) -> tensor<2x?x?xf32> { // CHECK-LABEL: @nonsingleton_batch_matmul // CHECK-NOT: collapse_shape diff --git 
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp
index 5ca27be30a687..8b455d7d68c30 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp
@@ -1,5 +1,4 @@
-//===- TestLinalgRankReduceContractionOps.cpp - Test Linalg rank reduce
-//contractions ---===//
+//===- TestLinalgRankReduceContractionOps.cpp -----------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.

From 01197ef9f661bc40348c48fd907aab7f1dc118b3 Mon Sep 17 00:00:00 2001
From: Sam
Date: Fri, 21 Jun 2024 11:29:46 -0500
Subject: [PATCH 13/14] add more tests

---
 .../Linalg/rank-reduce-contraction-ops.mlir   | 24 ++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
index fd59a4a52e378..c086d0fd7e633 100644
--- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -203,6 +203,18 @@ func.func @matmul_transpose_a_to_vecmat(%arg0: tensor<256x1xf32>, %arg1: tensor<256x512xf32>, %arg2: tensor<1x512xf32>) -> tensor<1x512xf32> {
 
 // -----
 
+func.func @batch_matmul_transpose_a_to_batch_vecmat(%arg0: tensor<64x256x1xf32>, %arg1: tensor<64x256x512xf32>, %arg2: tensor<64x1x512xf32>) -> tensor<64x1x512xf32> {
+  // CHECK-LABEL: @batch_matmul_transpose_a_to_batch_vecmat
+  // CHECK: collapse_shape {{.*}} into tensor<64x256xf32>
+  // CHECK: collapse_shape {{.*}} into tensor<64x512xf32>
+  // CHECK: linalg.batch_vecmat
+  // CHECK: expand_shape {{.*}} into tensor<64x1x512xf32>
+  %0 = linalg.batch_matmul_transpose_a ins(%arg0, %arg1: tensor<64x256x1xf32>, tensor<64x256x512xf32>) outs(%arg2: tensor<64x1x512xf32>) -> tensor<64x1x512xf32>
+  return %0 : tensor<64x1x512xf32>
+}
+
+// -----
+
 func.func @matmul_transpose_b_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<1x?xf32>, %arg2: memref<?x1xf32>) {
   // CHECK-LABEL: @matmul_transpose_b_to_matvec
   // CHECK: linalg.matvec
@@ -212,6 +224,18 @@ func.func @matmul_transpose_b_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<1x?xf32>, %arg2: memref<?x1xf32>) {
 
 // -----
 
+func.func @batchmatmul_transpose_b_to_batchmatvec_tensor(%arg0: tensor<64x128x256xf32>, %arg1: tensor<64x1x256xf32>, %arg2: tensor<64x128x1xf32>) -> tensor<64x128x1xf32> {
+  // CHECK-LABEL: @batchmatmul_transpose_b_to_batchmatvec_tensor
+  // CHECK: collapse_shape {{.*}} into tensor<64x256xf32>
+  // CHECK: collapse_shape {{.*}} into tensor<64x128xf32>
+  // CHECK: linalg.batch_matvec
+  // CHECK: expand_shape {{.*}} into tensor<64x128x1xf32>
+  %0 = linalg.batch_matmul_transpose_b ins(%arg0, %arg1: tensor<64x128x256xf32>, tensor<64x1x256xf32>) outs(%arg2: tensor<64x128x1xf32>) -> tensor<64x128x1xf32>
+  return %0 : tensor<64x128x1xf32>
+}
+
+// -----
+
 func.func @batchmatmul_transpose_b_to_dot(%arg0: tensor<1x1x?xf32>, %arg1: tensor<1x1x?xf32>, %arg2: tensor<1x1x1xf32>) -> tensor<1x1x1xf32> {
   // CHECK-LABEL: @batchmatmul_transpose_b_to_dot
   // CHECK: linalg.dot

From 52151f41149e6c80bf8e1e68eb9176ba1f3de341 Mon Sep 17 00:00:00 2001
From: Sam
Date: Fri, 21 Jun 2024 12:30:12 -0500
Subject: [PATCH 14/14] expand more autos

---
 .../Dialect/Linalg/Transforms/DropUnitDims.cpp | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index e1daeb3ad666e..36f8696bf1b27 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -922,11 +922,11 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
       return rewriter.notifyMatchFailure(contractionOp,
                                          "no reducable dims found");
 
-    auto collapsedOperands =
+    SmallVector<Value> collapsedOperands =
         collapseOperands(rewriter, operands, operandUnitDims);
-    auto collapsedLhs = collapsedOperands[0];
-    auto collapsedRhs = collapsedOperands[1];
-    auto collapsedInit = collapsedOperands[2];
+    Value collapsedLhs = collapsedOperands[0];
+    Value collapsedRhs = collapsedOperands[1];
+    Value collapsedInit = collapsedOperands[2];
     SmallVector<Type, 1> collapsedResultTy;
     if (isa<RankedTensorType>(collapsedInit.getType()))
       collapsedResultTy.push_back(collapsedInit.getType());
@@ -971,12 +971,13 @@ struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
   LogicalResult
   getOperandUnitDims(LinalgOp op,
                      SmallVectorImpl<int64_t> &operandUnitDims) const override {
-    auto maybeContractionDims = inferContractionDims(op);
+    FailureOr<ContractionDimensions> maybeContractionDims =
+        inferContractionDims(op);
     if (failed(maybeContractionDims)) {
       LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
       return failure();
     }
-    auto contractionDims = maybeContractionDims.value();
+    ContractionDimensions contractionDims = maybeContractionDims.value();
 
     if (contractionDims.batch.size() != 1)
       return failure();
@@ -1019,12 +1020,13 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
   LogicalResult
   getOperandUnitDims(LinalgOp op,
                      SmallVectorImpl<int64_t> &operandUnitDims) const override {
-    auto maybeContractionDims = inferContractionDims(op);
+    FailureOr<ContractionDimensions> maybeContractionDims =
+        inferContractionDims(op);
     if (failed(maybeContractionDims)) {
       LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
       return failure();
     }
-    auto contractionDims = maybeContractionDims.value();
+    ContractionDimensions contractionDims = maybeContractionDims.value();
 
     if constexpr (reduceLeft) {
       auto m = contractionDims.m[0];
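
Note on the net effect of these patterns (an illustrative sketch, not part of any
patch above; the function and value names here are made up): given a unit-batch
linalg.batch_matmul such as

  func.func @unit_batch(%lhs: tensor<1x4x8xf32>, %rhs: tensor<1x8x16xf32>,
                        %init: tensor<1x4x16xf32>) -> tensor<1x4x16xf32> {
    %0 = linalg.batch_matmul
        ins(%lhs, %rhs : tensor<1x4x8xf32>, tensor<1x8x16xf32>)
        outs(%init : tensor<1x4x16xf32>) -> tensor<1x4x16xf32>
    return %0 : tensor<1x4x16xf32>
  }

RankReduceToUnBatched<BatchMatmulOp, MatmulOp> collapses the unit batch dimension
of each operand, emits the unbatched op, and expands the result back, producing
roughly:

  %a = tensor.collapse_shape %lhs [[0, 1], [2]]
      : tensor<1x4x8xf32> into tensor<4x8xf32>
  %b = tensor.collapse_shape %rhs [[0, 1], [2]]
      : tensor<1x8x16xf32> into tensor<8x16xf32>
  %c = tensor.collapse_shape %init [[0, 1], [2]]
      : tensor<1x4x16xf32> into tensor<4x16xf32>
  %d = linalg.matmul ins(%a, %b : tensor<4x8xf32>, tensor<8x16xf32>)
      outs(%c : tensor<4x16xf32>) -> tensor<4x16xf32>
  %e = tensor.expand_shape %d [[0, 1], [2]] output_shape [1, 4, 16]
      : tensor<4x16xf32> into tensor<1x4x16xf32>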