Skip to content

Commit 03bd0f3

Browse files
authored
[mlir][vector] Remove MatrixMultiplyOp and FlatTransposeOp from Vector dialect (#144307)
This patch deletes `vector.matrix_multiply` and `vector.flat_transpose`, which are thin wrappers around the corresponding LLVM intrinsics: - `llvm.intr.matrix.multiply` - `llvm.intr.matrix.transpose` These Vector dialect ops did not provide additional semantics or abstraction beyond the LLVM intrinsics. Their removal simplifies the lowering pipeline without losing any functionality. The lowering chains: - `vector.contract` → `vector.matrix_multiply` → `llvm.intr.matrix.multiply` - `vector.transpose` → `vector.flat_transpose` → `llvm.intr.matrix.transpose` are now replaced with: - `vector.contract` → `llvm.intr.matrix.multiply` - `vector.transpose` → `llvm.intr.matrix.transpose` This was accomplished by directly replacing: - `vector::MatrixMultiplyOp` with `LLVM::MatrixMultiplyOp` - `vector::FlatTransposeOp` with `LLVM::MatrixTransposeOp` Note: To avoid a build-time dependency from `Vector` to `LLVM`, relevant transformations are moved from "Vector/Transforms" to `Conversion/VectorToLLVM`.
1 parent b832c49 commit 03bd0f3

File tree

18 files changed

+306
-558
lines changed

18 files changed

+306
-558
lines changed

mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,6 @@
1313
namespace mlir {
1414
class LLVMTypeConverter;
1515

16-
/// Collect a set of patterns to convert from Vector contractions to LLVM Matrix
17-
/// Intrinsics. To lower to assembly, the LLVM flag -lower-matrix-intrinsics
18-
/// will be needed when invoking LLVM.
19-
void populateVectorToLLVMMatrixConversionPatterns(
20-
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
21-
2216
/// Collect a set of patterns to convert from the Vector dialect to LLVM.
2317
void populateVectorToLLVMConversionPatterns(
2418
const LLVMTypeConverter &converter, RewritePatternSet &patterns,

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 0 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -2837,124 +2837,6 @@ def Vector_PrintOp :
28372837
}];
28382838
}
28392839

2840-
//===----------------------------------------------------------------------===//
2841-
// Ops used for supporting progressive lowering and conversion type changes.
2842-
// The Ops are typically not used directly by higher level dialects, but are
2843-
// used by intra-dialect rewriting rules to bring vector operations closer
2844-
// to the hardware ISA.
2845-
//===----------------------------------------------------------------------===//
2846-
2847-
/// Vector dialect matrix multiplication op that operates on flattened 1-D
2848-
/// MLIR vectors. This is the counterpart of llvm.matrix.multiply in MLIR.
2849-
/// This may seem redundant with vector.contract but it serves the purposes of
2850-
/// more progressive lowering and localized type conversion on the path:
2851-
/// `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`.
2852-
def Vector_MatmulOp : Vector_Op<"matrix_multiply", [Pure,
2853-
PredOpTrait<"lhs operand and result have same element type",
2854-
TCresVTEtIsSameAsOpBase<0, 0>>,
2855-
PredOpTrait<"rhs operand and result have same element type",
2856-
TCresVTEtIsSameAsOpBase<0, 1>>]>,
2857-
Arguments<(
2858-
// TODO: tighten vector element types that make sense.
2859-
ins FixedVectorOfRankAndType<[1],
2860-
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$lhs,
2861-
FixedVectorOfRankAndType<[1],
2862-
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$rhs,
2863-
I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)>,
2864-
Results<(
2865-
outs FixedVectorOfRankAndType<[1],
2866-
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)>
2867-
{
2868-
let summary = "Vector matrix multiplication op that operates on flattened 1-D"
2869-
" MLIR vectors";
2870-
let description = [{
2871-
This is the counterpart of llvm.matrix.multiply in MLIR. It serves the
2872-
purposes of more progressive lowering and localized type conversion.
2873-
Higher levels typically lower matrix multiplications into 'vector.contract'
2874-
operations. Subsequent rewriting rule progressively lower these operations
2875-
into 'vector.matrix_multiply' operations to bring the operations closer
2876-
to the hardware ISA.
2877-
2878-
The ‘vector.matrix_multiply’ op treats `lhs` as matrix with <lhs_rows> rows
2879-
and <lhs_columns> columns, `rhs` as matrix with <lhs_columns> rows and
2880-
<rhs_columns> and multiplies them. The result matrix is returned embedded in
2881-
the result vector.
2882-
2883-
Note, the corresponding LLVM intrinsic, `@llvm.matrix.multiply.*`, does not
2884-
support scalable vectors. Hence, this Op is only available for fixed-width
2885-
vectors. Also see:
2886-
2887-
http://llvm.org/docs/LangRef.html#llvm-matrix-multiply-intrinsic
2888-
2889-
Example:
2890-
2891-
```mlir
2892-
%C = vector.matrix_multiply %A, %B
2893-
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
2894-
(vector<64xf64>, vector<48xf64>) -> vector<12xf64>
2895-
```
2896-
}];
2897-
let builders = [
2898-
OpBuilder<(ins "Value":$lhs, "Value":$rhs, "unsigned":$lhsRows,
2899-
"unsigned":$lhsColumns, "unsigned":$rhsColumns),
2900-
[{
2901-
$_state.addOperands({lhs, rhs});
2902-
$_state.addAttribute("lhs_rows",$_builder.getI32IntegerAttr(lhsRows));
2903-
$_state.addAttribute("lhs_columns",$_builder.getI32IntegerAttr(lhsColumns));
2904-
$_state.addAttribute("rhs_columns",$_builder.getI32IntegerAttr(rhsColumns));
2905-
$_state.addTypes(VectorType::get(lhsRows * rhsColumns,
2906-
::llvm::cast<VectorType>(lhs.getType()).getElementType()));
2907-
}]>,
2908-
];
2909-
let assemblyFormat = "$lhs `,` $rhs attr-dict "
2910-
"`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";
2911-
}
2912-
2913-
/// Vector dialect matrix transposition op that operates on flattened 1-D
2914-
/// MLIR vectors. This is the counterpart of llvm.matrix.transpose in MLIR.
2915-
/// This may seem redundant with vector.transpose but it serves the purposes of
2916-
/// more progressive lowering and localized type conversion on the path:
2917-
/// `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`.
2918-
def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [Pure,
2919-
PredOpTrait<"source operand and result have same element type",
2920-
TCresVTEtIsSameAsOpBase<0, 0>>]>,
2921-
Arguments<(
2922-
// TODO: tighten vector element types that make sense.
2923-
ins FixedVectorOfRankAndType<[1],
2924-
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$matrix,
2925-
I32Attr:$rows, I32Attr:$columns)>,
2926-
Results<(
2927-
outs FixedVectorOfRankAndType<[1],
2928-
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)> {
2929-
let summary = "Vector matrix transposition on flattened 1-D MLIR vectors";
2930-
let description = [{
2931-
This is the counterpart of llvm.matrix.transpose in MLIR. It serves
2932-
the purposes of more progressive lowering and localized type conversion.
2933-
Higher levels typically lower matrix transpositions into 'vector.transpose'
2934-
operations. Subsequent rewriting rule progressively lower these operations
2935-
into 'vector.flat_transpose' operations to bring the operations closer
2936-
to the hardware ISA.
2937-
2938-
The `vector.flat_transpose` op treats the 1-D input `matrix` as
2939-
a 2-D matrix with <rows> rows and <columns> columns, and returns the
2940-
transposed matrix in flattened form in 'res'.
2941-
2942-
Note, the corresponding LLVM intrinsic, `@llvm.matrix.transpose.*`, does not
2943-
support scalable vectors. Hence, this Op is only available for fixed-width
2944-
vectors. Also see:
2945-
2946-
http://llvm.org/docs/LangRef.html#llvm-matrix-transpose-intrinsic
2947-
2948-
Example:
2949-
2950-
```mlir
2951-
%1 = vector.flat_transpose %0 {columns = 4 : i32, rows = 4 : i32}
2952-
: vector<16xf32> -> vector<16xf32>
2953-
```
2954-
}];
2955-
let assemblyFormat = "$matrix attr-dict `:` type($matrix) `->` type($res)";
2956-
}
2957-
29582840
//===----------------------------------------------------------------------===//
29592841
// SplatOp
29602842
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,28 @@ void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
303303
void populateVectorToFromElementsToShuffleTreePatterns(
304304
RewritePatternSet &patterns, PatternBenefit benefit = 1);
305305

306+
/// Populate the pattern set with the following patterns:
307+
///
308+
/// [ContractionOpToMatmulOpLowering]
309+
/// Lowers `vector.contract` to `llvm.intr.matrix.multiply`.
310+
///
311+
/// Given the high benefit, this will be prioriotised over other
312+
/// contract-lowering patterns. As such, the convert-vector-to-llvm pass will
313+
/// only run this registration conditionally.
314+
void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns,
315+
PatternBenefit benefit = 100);
316+
317+
/// Populate the pattern set with the following patterns:
318+
///
319+
/// [TransposeOpLowering]
320+
/// Lowers `vector.transpose` to `llvm.intr.matrix.flat_transpose`.
321+
///
322+
/// Given the high benefit, this will be prioriotised over other
323+
/// transpose-lowering patterns. As such, the convert-vector-to-llvm pass will
324+
/// only run this registration conditionally.
325+
void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns,
326+
PatternBenefit benefit = 100);
327+
306328
} // namespace vector
307329
} // namespace mlir
308330

0 commit comments

Comments
 (0)