From b14e3fe50ef08979ad248382a34b10025649bffd Mon Sep 17 00:00:00 2001
From: Angel Zhang
Date: Sun, 21 Jul 2024 15:26:32 +0000
Subject: [PATCH 1/7] [mlir][spirv] Fix function signature legalization for
 n-D vectors

---
 .../SPIRV/Transforms/SPIRVConversion.cpp      | 10 ++++-
 .../func-signature-vector-unroll.mlir         | 44 +++++++++++++++++++
 2 files changed, 52 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index bf5044437fd09..c146589612b5e 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1098,13 +1098,19 @@ struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
       // the original operand of illegal type.
       auto originalShape =
           llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
-      SmallVector<int64_t> strides(targetShape->size(), 1);
+      SmallVector<int64_t> strides(originalShape.size(), 1);
+      SmallVector<int64_t> extractShape(originalShape.size(), 1);
+      extractShape.back() = targetShape->back();
       SmallVector<Type> newTypes;
       Value returnValue = returnOp.getOperand(origResultNo);
       for (SmallVector<int64_t> offsets :
            StaticTileOffsetRange(originalShape, *targetShape)) {
         Value result = rewriter.create<vector::ExtractStridedSliceOp>(
-            loc, returnValue, offsets, *targetShape, strides);
+            loc, returnValue, offsets, extractShape, strides);
+        SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
+        if (originalShape.size() > 1)
+          result =
+              rewriter.create<vector::ExtractOp>(loc, result, extractIndices);
         newOperands.push_back(result);
         newTypes.push_back(unrolledType);
       }
diff --git a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
index 347d282f9ee0c..c018ccb924983 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
@@ -66,6 +66,28 @@ func.func @simple_vector_8(%arg0 : vector<8xi32>) -> vector<8xi32> {
 
 // -----
 
+// CHECK-LABEL: @simple_vector_2d
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>)
+func.func @simple_vector_2d(%arg0 : vector<4x4xi32>) -> vector<4x4xi32> {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4x4xi32>
+  // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [1, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+  // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT1]] {offsets = [2, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+  // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [3, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+  // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+  // CHECK: %[[EXTRACT0_1:.*]] = vector.extract %[[EXTRACT0]][0] : vector<4xi32> from vector<1x4xi32>
+  // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+  // CHECK: %[[EXTRACT1_1:.*]] = vector.extract %[[EXTRACT1]][0] : vector<4xi32> from vector<1x4xi32>
+  // CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [2, 0], sizes =
[1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32> + // CHECK: %[[EXTRACT2_1:.*]] = vector.extract %[[EXTRACT2]][0] : vector<4xi32> from vector<1x4xi32> + // CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [3, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32> + // CHECK: %[[EXTRACT3_1:.*]] = vector.extract %[[EXTRACT3]][0] : vector<4xi32> from vector<1x4xi32> + // CHECK: return %[[EXTRACT0_1]], %[[EXTRACT1_1]], %[[EXTRACT2_1]], %[[EXTRACT3_1]] : vector<4xi32>, vector<4xi32>, vector<4xi32>, vector<4xi32> + return %arg0 : vector<4x4xi32> +} + +// ----- + // CHECK-LABEL: @vector_6and8 // CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>) func.func @vector_6and8(%arg0 : vector<6xi32>, %arg1 : vector<8xi32>) -> (vector<6xi32>, vector<8xi32>) { @@ -113,6 +135,28 @@ func.func @scalar_vector(%arg0 : vector<8xi32>, %arg1 : vector<3xi32>, %arg2 : i // ----- +// CHECK-LABEL: @vector_2dand1d +// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<3xi32>, %[[ARG3:.+]]: vector<3xi32>, %[[ARG4:.+]]: vector<4xi32>) +func.func @vector_2dand1d(%arg0 : vector<2x6xi32>, %arg1 : vector<4xi32>) -> (vector<2x6xi32>, vector<4xi32>) { + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2x6xi32> + // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<3xi32> into vector<2x6xi32> + // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [0, 3], strides = [1]} : vector<3xi32> into vector<2x6xi32> + // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT1]] {offsets = [1, 0], strides = [1]} : vector<3xi32> into vector<2x6xi32> + // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [1, 3], strides = [1]} : vector<3xi32> into vector<2x6xi32> + // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32> + // CHECK: %[[EXTRACT0_1:.*]] = vector.extract %[[EXTRACT0]][0] : vector<3xi32> from vector<1x3xi32> + // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 3], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32> + // CHECK: %[[EXTRACT1_1:.*]] = vector.extract %[[EXTRACT1]][0] : vector<3xi32> from vector<1x3xi32> + // CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32> + // CHECK: %[[EXTRACT2_1:.*]] = vector.extract %[[EXTRACT2]][0] : vector<3xi32> from vector<1x3xi32> + // CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 3], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32> + // CHECK: %[[EXTRACT3_1:.*]] = vector.extract %[[EXTRACT3]][0] : vector<3xi32> from vector<1x3xi32> + // CHECK: return %[[EXTRACT0_1]], %[[EXTRACT1_1]], %[[EXTRACT2_1]], %[[EXTRACT3_1]], %[[ARG4]] : vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<4xi32> + return %arg0, %arg1 : vector<2x6xi32>, vector<4xi32> +} + +// ----- + // CHECK-LABEL: @reduction // CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>, %[[ARG4:.+]]: i32) func.func @reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>, %arg2 : i32) 
    -> (i32) {

From c80e380068ccca536ec7d06f9e5062b97379f5df Mon Sep 17 00:00:00 2001
From: Angel Zhang
Date: Mon, 22 Jul 2024 19:23:05 +0000
Subject: [PATCH 2/7] [mlir][spirv] Initial version of vector unrolling for
 convert-to-spirv pass

---
 mlir/include/mlir/Conversion/Passes.td        |   5 +-
 .../SPIRV/Transforms/SPIRVConversion.h        |  11 ++
 .../ConvertToSPIRV/ConvertToSPIRVPass.cpp     |  74 +++++++++++
 .../SPIRV/Transforms/SPIRVConversion.cpp      |  59 +++++++--
 .../test/Conversion/ConvertToSPIRV/arith.mlir |   2 +-
 .../Conversion/ConvertToSPIRV/combined.mlir   |   2 +-
 .../test/Conversion/ConvertToSPIRV/index.mlir |   2 +-
 mlir/test/Conversion/ConvertToSPIRV/scf.mlir  |   2 +-
 .../Conversion/ConvertToSPIRV/simple.mlir     |   2 +-
 mlir/test/Conversion/ConvertToSPIRV/ub.mlir   |   2 +-
 .../ConvertToSPIRV/vector-unroll.mlir         | 102 +++++++++++++++
 .../Conversion/ConvertToSPIRV/vector.mlir     |   2 +-
 .../Conversion/ConvertToSPIRV/CMakeLists.txt  |   2 +
 .../TestSPIRVVectorUnrolling.cpp              | 119 ++++++++++++++++++
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 +
 15 files changed, 370 insertions(+), 18 deletions(-)
 create mode 100644 mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
 create mode 100644 mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVVectorUnrolling.cpp

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 748646e605827..b5bb2f42f2961 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -47,7 +47,10 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
   let options = [
     Option<"runSignatureConversion", "run-signature-conversion", "bool",
            /*default=*/"true",
-           "Run function signature conversion to convert vector types">
+           "Run function signature conversion to convert vector types">,
+    Option<"runVectorUnrolling", "run-vector-unrolling", "bool",
+           /*default=*/"true",
+           "Run vector unrolling to convert vector types in function bodies">
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 9ad3d5fc85dd3..195fbd0d0cd58 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -189,6 +189,17 @@ Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
                           MemRefType baseType, Value basePtr,
                           ValueRange indices, Location loc, OpBuilder &builder);
 
+int getComputeVectorSize(int64_t size);
+
+// GetNativeVectorShape implementation for reduction ops.
+SmallVector<int64_t> getNativeVectorShapeImpl(vector::ReductionOp op);
+
+// GetNativeVectorShape implementation for transpose ops.
+SmallVector<int64_t> getNativeVectorShapeImpl(vector::TransposeOp op);
+
+// For general ops.
+std::optional<SmallVector<int64_t>> getNativeVectorShape(Operation *op);
+
 } // namespace spirv
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 003a5feea9e9b..b82a244cfc973 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -17,6 +17,8 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
@@ -56,6 +58,78 @@ struct ConvertToSPIRVPass final
         return signalPassFailure();
     }
 
+    if (runVectorUnrolling) {
+
+      // Fold transpose ops if possible as we cannot unroll it later.
+      {
+        RewritePatternSet patterns(context);
+        vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
+        if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) {
+          return signalPassFailure();
+        }
+      }
+
+      // Unroll vectors to native vector size.
+      {
+        RewritePatternSet patterns(context);
+        auto options = vector::UnrollVectorOptions().setNativeShapeFn(
+            [=](auto op) { return mlir::spirv::getNativeVectorShape(op); });
+        populateVectorUnrollPatterns(patterns, options);
+        if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+          return signalPassFailure();
+      }
+
+      // Convert transpose ops into extract and insert pairs, in preparation
+      // of further transformations to canonicalize/cancel.
+      {
+        RewritePatternSet patterns(context);
+        auto options =
+            vector::VectorTransformsOptions().setVectorTransposeLowering(
+                vector::VectorTransposeLowering::EltWise);
+        vector::populateVectorTransposeLoweringPatterns(patterns, options);
+        vector::populateVectorShapeCastLoweringPatterns(patterns);
+        if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) {
+          return signalPassFailure();
+        }
+      }
+
+      // Run canonicalization to cast away leading size-1 dimensions.
+      {
+        RewritePatternSet patterns(context);
+
+        // Pull in casting way leading one dims to allow cancelling some
+        // read/write ops.
+        vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+        vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
+
+        // Decompose different rank insert_strided_slice and n-D
+        // extract_slided_slice.
+        vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
+            patterns);
+        vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
+
+        // Trimming leading unit dims may generate broadcast/shape_cast ops.
+        // Clean them up.
+        vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
+        vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
+
+        if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+          return signalPassFailure();
+      }
+
+      // Run all sorts of canonicalization patterns to clean up again.
+      {
+        RewritePatternSet patterns(context);
+        vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+        vector::InsertOp::getCanonicalizationPatterns(patterns, context);
+        vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
+        vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
+        vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
+        if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+          return signalPassFailure();
+      }
+    }
+
     spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
     std::unique_ptr<ConversionTarget> target =
         SPIRVConversionTarget::get(targetAttr);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index c146589612b5e..8470c7642e716 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -46,14 +46,6 @@ namespace {
 // Utility functions
 //===----------------------------------------------------------------------===//
 
-static int getComputeVectorSize(int64_t size) {
-  for (int i : {4, 3, 2}) {
-    if (size % i == 0)
-      return i;
-  }
-  return 1;
-}
-
 static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
   LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
   if (vecType.isScalable()) {
@@ -62,8 +54,8 @@ static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
     return std::nullopt;
   }
   SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
-  std::optional<SmallVector<int64_t>> targetShape =
-      SmallVector<int64_t>(1, getComputeVectorSize(vecType.getShape().back()));
+  std::optional<SmallVector<int64_t>> targetShape = SmallVector<int64_t>(
+      1, mlir::spirv::getComputeVectorSize(vecType.getShape().back()));
   if (!targetShape) {
     LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
     return std::nullopt;
   }
@@ -1291,6 +1283,53 @@ Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter,
                              builder);
 }
 
+//===----------------------------------------------------------------------===//
+// Public functions for vector unrolling
+//===----------------------------------------------------------------------===//
+
+int mlir::spirv::getComputeVectorSize(int64_t size) {
+  for (int i : {4, 3, 2}) {
+    if (size % i == 0)
+      return i;
+  }
+  return 1;
+}
+
+SmallVector<int64_t>
+mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) {
+  VectorType srcVectorType = op.getSourceVectorType();
+  assert(srcVectorType.getRank() == 1); // Guaranteed by semantics
+  int64_t vectorSize =
+      mlir::spirv::getComputeVectorSize(srcVectorType.getDimSize(0));
+  return {vectorSize};
+}
+
+SmallVector<int64_t>
+mlir::spirv::getNativeVectorShapeImpl(vector::TransposeOp op) {
+  VectorType vectorType = op.getResultVectorType();
+  SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
+  nativeSize.back() =
+      mlir::spirv::getComputeVectorSize(vectorType.getShape().back());
+  return nativeSize;
+}
+
+std::optional<SmallVector<int64_t>>
+mlir::spirv::getNativeVectorShape(Operation *op) {
+  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
+    if (auto vecType = llvm::dyn_cast<VectorType>(op->getResultTypes()[0])) {
+      SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
+      nativeSize.back() =
+          mlir::spirv::getComputeVectorSize(vecType.getShape().back());
+      return nativeSize;
+    }
+  }
+
+  return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
+      .Case<vector::ReductionOp, vector::TransposeOp>(
+          [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
+      .Default([](Operation *) { return std::nullopt; });
+}
+
 //===----------------------------------------------------------------------===//
 // SPIR-V TypeConverter
//===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/ConvertToSPIRV/arith.mlir b/mlir/test/Conversion/ConvertToSPIRV/arith.mlir index 1a844a7cd018b..6418e931f7460 100644 --- a/mlir/test/Conversion/ConvertToSPIRV/arith.mlir +++ b/mlir/test/Conversion/ConvertToSPIRV/arith.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" -split-input-file %s | FileCheck %s +// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s //===----------------------------------------------------------------------===// // arithmetic ops diff --git a/mlir/test/Conversion/ConvertToSPIRV/combined.mlir b/mlir/test/Conversion/ConvertToSPIRV/combined.mlir index 02b938be775a3..311174bef15ed 100644 --- a/mlir/test/Conversion/ConvertToSPIRV/combined.mlir +++ b/mlir/test/Conversion/ConvertToSPIRV/combined.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s +// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s // CHECK-LABEL: @combined // CHECK: %[[C0_F32:.*]] = spirv.Constant 0.000000e+00 : f32 diff --git a/mlir/test/Conversion/ConvertToSPIRV/index.mlir b/mlir/test/Conversion/ConvertToSPIRV/index.mlir index e1cb18aac5d01..f4b116849fa93 100644 --- a/mlir/test/Conversion/ConvertToSPIRV/index.mlir +++ b/mlir/test/Conversion/ConvertToSPIRV/index.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-to-spirv="run-signature-conversion=false" | FileCheck %s +// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s // CHECK-LABEL: @basic func.func @basic(%a: index, %b: index) { diff --git a/mlir/test/Conversion/ConvertToSPIRV/scf.mlir b/mlir/test/Conversion/ConvertToSPIRV/scf.mlir index 58ec6ac61f6ac..246464928b81c 100644 --- a/mlir/test/Conversion/ConvertToSPIRV/scf.mlir +++ b/mlir/test/Conversion/ConvertToSPIRV/scf.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s +// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s // CHECK-LABEL: @if_yield // CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr diff --git a/mlir/test/Conversion/ConvertToSPIRV/simple.mlir b/mlir/test/Conversion/ConvertToSPIRV/simple.mlir index c5e0e6603d94a..00556140c3018 100644 --- a/mlir/test/Conversion/ConvertToSPIRV/simple.mlir +++ b/mlir/test/Conversion/ConvertToSPIRV/simple.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s +// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s // CHECK-LABEL: @return_scalar // CHECK-SAME: %[[ARG0:.*]]: i32 diff --git a/mlir/test/Conversion/ConvertToSPIRV/ub.mlir b/mlir/test/Conversion/ConvertToSPIRV/ub.mlir index a83bfb6f405a0..f34ca01c94f00 100644 --- a/mlir/test/Conversion/ConvertToSPIRV/ub.mlir +++ b/mlir/test/Conversion/ConvertToSPIRV/ub.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s +// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s // CHECK-LABEL: @ub // CHECK: %[[UNDEF:.*]] = spirv.Undef : i32 diff --git 
a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir new file mode 100644 index 0000000000000..54d9875002cb5 --- /dev/null +++ b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir @@ -0,0 +1,102 @@ +// RUN: mlir-opt -test-spirv-vector-unrolling -split-input-file %s | FileCheck %s + +// CHECK-LABEL: @vaddi +// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<3xi32>, %[[ARG3:.+]]: vector<3xi32>) +func.func @vaddi(%arg0 : vector<6xi32>, %arg1 : vector<6xi32>) -> (vector<6xi32>) { + // CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG2]] : vector<3xi32> + // CHECK: %[[ADD1:.*]] = arith.addi %[[ARG1]], %[[ARG3]] : vector<3xi32> + // CHECK: return %[[ADD0]], %[[ADD1]] : vector<3xi32>, vector<3xi32> + %0 = arith.addi %arg0, %arg1 : vector<6xi32> + return %0 : vector<6xi32> +} + +// CHECK-LABEL: @vaddi_2d +// CHECK-SAME: (%[[ARG0:.+]]: vector<2xi32>, %[[ARG1:.+]]: vector<2xi32>, %[[ARG2:.+]]: vector<2xi32>, %[[ARG3:.+]]: vector<2xi32>) +func.func @vaddi_2d(%arg0 : vector<2x2xi32>, %arg1 : vector<2x2xi32>) -> (vector<2x2xi32>) { + // CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG2]] : vector<2xi32> + // CHECK: %[[ADD1:.*]] = arith.addi %[[ARG1]], %[[ARG3]] : vector<2xi32> + // CHECK: return %[[ADD0]], %[[ADD1]] : vector<2xi32>, vector<2xi32> + %0 = arith.addi %arg0, %arg1 : vector<2x2xi32> + return %0 : vector<2x2xi32> +} + +// CHECK-LABEL: @vaddi_2d_8 +// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>, %[[ARG4:.+]]: vector<4xi32>, %[[ARG5:.+]]: vector<4xi32>, %[[ARG6:.+]]: vector<4xi32>, %[[ARG7:.+]]: vector<4xi32>) +func.func @vaddi_2d_8(%arg0 : vector<2x8xi32>, %arg1 : vector<2x8xi32>) -> (vector<2x8xi32>) { + // CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG4]] : vector<4xi32> + // CHECK: %[[ADD1:.*]] = arith.addi %[[ARG1]], %[[ARG5]] : vector<4xi32> + // CHECK: %[[ADD2:.*]] = arith.addi %[[ARG2]], %[[ARG6]] : vector<4xi32> + // CHECK: %[[ADD3:.*]] = arith.addi %[[ARG3]], %[[ARG7]] : vector<4xi32> + // CHECK: return %[[ADD0]], %[[ADD1]], %[[ADD2]], %[[ADD3]] : vector<4xi32>, vector<4xi32>, vector<4xi32>, vector<4xi32> + %0 = arith.addi %arg0, %arg1 : vector<2x8xi32> + return %0 : vector<2x8xi32> +} + +// ----- + +// CHECK-LABEL: @reduction_5 +// CHECK-SAME: (%[[ARG0:.+]]: vector<1xi32>, %[[ARG1:.+]]: vector<1xi32>, %[[ARG2:.+]]: vector<1xi32>, %[[ARG3:.+]]: vector<1xi32>, %[[ARG4:.+]]: vector<1xi32>) +func.func @reduction_5(%arg0 : vector<5xi32>) -> (i32) { + // CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][0] : i32 from vector<1xi32> + // CHECK: %[[EXTRACT1:.*]] = vector.extract %[[ARG1]][0] : i32 from vector<1xi32> + // CHECK: %[[ADD0:.*]] = arith.addi %[[EXTRACT0]], %[[EXTRACT1]] : i32 + // CHECK: %[[EXTRACT2:.*]] = vector.extract %[[ARG2]][0] : i32 from vector<1xi32> + // CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[EXTRACT2]] : i32 + // CHECK: %[[EXTRACT3:.*]] = vector.extract %[[ARG3]][0] : i32 from vector<1xi32> + // CHECK: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[EXTRACT3]] : i32 + // CHECK: %[[EXTRACT4:.*]] = vector.extract %[[ARG4]][0] : i32 from vector<1xi32> + // CHECK: %[[ADD3:.*]] = arith.addi %[[ADD2]], %[[EXTRACT4]] : i32 + // CHECK: return %[[ADD3]] : i32 + %0 = vector.reduction , %arg0 : vector<5xi32> into i32 + return %0 : i32 +} + +// CHECK-LABEL: @reduction_8 +// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>) +func.func @reduction_8(%arg0 : 
vector<8xi32>) -> (i32) { + // CHECK: %[[REDUCTION0:.*]] = vector.reduction , %[[ARG0]] : vector<4xi32> into i32 + // CHECK: %[[REDUCTION1:.*]] = vector.reduction , %[[ARG1]] : vector<4xi32> into i32 + // CHECK: %[[ADD:.*]] = arith.addi %[[REDUCTION0]], %[[REDUCTION1]] : i32 + // CHECK: return %[[ADD]] : i32 + %0 = vector.reduction , %arg0 : vector<8xi32> into i32 + return %0 : i32 +} + +// ----- + +// CHECK-LABEL: @vaddi_reduction +// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>) +func.func @vaddi_reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>) -> (i32) { + // CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG2]] : vector<4xi32> + // CHECK: %[[ADD1:.*]] = arith.addi %[[ARG1]], %[[ARG3]] : vector<4xi32> + // CHECK: %[[REDUCTION0:.*]] = vector.reduction , %[[ADD0]] : vector<4xi32> into i32 + // CHECK: %[[REDUCTION1:.*]] = vector.reduction , %[[ADD1]] : vector<4xi32> into i32 + // CHECK: %[[ADD2:.*]] = arith.addi %[[REDUCTION0]], %[[REDUCTION1]] : i32 + // CHECK: return %[[ADD2]] : i32 + %0 = arith.addi %arg0, %arg1 : vector<8xi32> + %1 = vector.reduction , %0 : vector<8xi32> into i32 + return %1 : i32 +} + +// ----- + +// CHECK-LABEL: @transpose +// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>) +func.func @transpose(%arg0 : vector<2x3xi32>) -> (vector<3x2xi32>) { + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2xi32> + // CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][0] : i32 from vector<3xi32> + // CHECK: %[[INSERT0:.*]]= vector.insert %[[EXTRACT0]], %[[CST]] [0] : i32 into vector<2xi32> + // CHECK: %[[EXTRACT1:.*]] = vector.extract %[[ARG1]][0] : i32 from vector<3xi32> + // CHECK: %[[INSERT1:.*]] = vector.insert %[[EXTRACT1]], %[[INSERT0]][1] : i32 into vector<2xi32> + // CHECK: %[[EXTRACT2:.*]] = vector.extract %[[ARG0]][1] : i32 from vector<3xi32> + // CHECK: %[[INSERT2:.*]] = vector.insert %[[EXTRACT2]], %[[CST]] [0] : i32 into vector<2xi32> + // CHECK: %[[EXTRACT3:.*]] = vector.extract %[[ARG1]][1] : i32 from vector<3xi32> + // CHECK: %[[INSERT3:.*]] = vector.insert %[[EXTRACT3]], %[[INSERT2]] [1] : i32 into vector<2xi32> + // CHECK: %[[EXTRACT4:.*]] = vector.extract %[[ARG0]][2] : i32 from vector<3xi32> + // CHECK: %[[INSERT4:.*]] = vector.insert %[[EXTRACT4]], %[[CST]] [0] : i32 into vector<2xi32> + // CHECK: %[[EXTRACT5:.*]] = vector.extract %[[ARG1]][2] : i32 from vector<3xi32> + // CHECK: %[[INSERT5:.*]] = vector.insert %[[EXTRACT5]], %[[INSERT4]] [1] : i32 into vector<2xi32> + // CHECK: return %[[INSERT1]], %[[INSERT3]], %[[INSERT5]] : vector<2xi32>, vector<2xi32>, vector<2xi32> + %0 = vector.transpose %arg0, [1, 0] : vector<2x3xi32> to vector<3x2xi32> + return %0 : vector<3x2xi32> +} \ No newline at end of file diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir index c63dd030f4747..e369eadca5730 100644 --- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir +++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -split-input-file -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s +// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s // CHECK-LABEL: @extract // CHECK-SAME: %[[ARG:.+]]: vector<2xf32> diff --git a/mlir/test/lib/Conversion/ConvertToSPIRV/CMakeLists.txt b/mlir/test/lib/Conversion/ConvertToSPIRV/CMakeLists.txt index 69b5787f7e851..aeade52c7ade5 100644 --- 
a/mlir/test/lib/Conversion/ConvertToSPIRV/CMakeLists.txt +++ b/mlir/test/lib/Conversion/ConvertToSPIRV/CMakeLists.txt @@ -1,6 +1,7 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRTestConvertToSPIRV TestSPIRVFuncSignatureConversion.cpp + TestSPIRVVectorUnrolling.cpp EXCLUDE_FROM_LIBMLIR @@ -13,4 +14,5 @@ add_mlir_library(MLIRTestConvertToSPIRV MLIRTransformUtils MLIRTransforms MLIRVectorDialect + MLIRVectorTransforms ) diff --git a/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVVectorUnrolling.cpp b/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVVectorUnrolling.cpp new file mode 100644 index 0000000000000..fbf7933b244fa --- /dev/null +++ b/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVVectorUnrolling.cpp @@ -0,0 +1,119 @@ +//===- TestSPIRVVectorUnrolling.cpp - Test signature conversion -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===-------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace { + +struct TestSPIRVVectorUnrolling final + : PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSPIRVVectorUnrolling) + + StringRef getArgument() const final { return "test-spirv-vector-unrolling"; } + + StringRef getDescription() const final { + return "Test patterns that unroll vectors to types supported by SPIR-V"; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + Operation *op = getOperation(); + + // Unroll vectors in function signatures to native vector size. + { + RewritePatternSet patterns(context); + populateFuncOpVectorRewritePatterns(patterns); + populateReturnOpVectorRewritePatterns(patterns); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteStrictness::ExistingOps; + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config))) + return signalPassFailure(); + } + + // Unroll vectors to native vector size. + { + RewritePatternSet patterns(context); + auto options = vector::UnrollVectorOptions().setNativeShapeFn( + [=](auto op) { return mlir::spirv::getNativeVectorShape(op); }); + populateVectorUnrollPatterns(patterns, options); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + return signalPassFailure(); + } + + // Convert transpose ops into extract and insert pairs, in preparation of + // further transformations to canonicalize/cancel. + { + RewritePatternSet patterns(context); + auto options = + vector::VectorTransformsOptions().setVectorTransposeLowering( + vector::VectorTransposeLowering::EltWise); + vector::populateVectorTransposeLoweringPatterns(patterns, options); + vector::populateVectorShapeCastLoweringPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) { + return signalPassFailure(); + } + } + + // Run canonicalization to cast away leading size-1 dimensions. 
+ { + RewritePatternSet patterns(context); + + // We need to pull in casting way leading one dims. + vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); + vector::ReductionOp::getCanonicalizationPatterns(patterns, context); + + // Decompose different rank insert_strided_slice and n-D + // extract_slided_slice. + vector::populateVectorInsertExtractStridedSliceDecompositionPatterns( + patterns); + vector::ExtractOp::getCanonicalizationPatterns(patterns, context); + + // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean + // them up. + vector::BroadcastOp::getCanonicalizationPatterns(patterns, context); + vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context); + + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + return signalPassFailure(); + } + + // Run all sorts of canonicalization patterns to clean up again. + { + RewritePatternSet patterns(context); + vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); + vector::InsertOp::getCanonicalizationPatterns(patterns, context); + vector::ExtractOp::getCanonicalizationPatterns(patterns, context); + vector::ReductionOp::getCanonicalizationPatterns(patterns, context); + vector::TransposeOp::getCanonicalizationPatterns(patterns, context); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + return signalPassFailure(); + } + } +}; + +} // namespace + +namespace test { +void registerTestSPIRVVectorUnrolling() { + PassRegistration(); +} +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 149f9d59961b8..0f29963da39bb 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -142,6 +142,7 @@ void registerTestSCFWrapInZeroTripCheckPasses(); void registerTestShapeMappingPass(); void registerTestSliceAnalysisPass(); void registerTestSPIRVFuncSignatureConversion(); +void registerTestSPIRVVectorUnrolling(); void registerTestTensorCopyInsertionPass(); void registerTestTensorTransforms(); void registerTestTopologicalSortAnalysisPass(); @@ -275,6 +276,7 @@ void registerTestPasses() { mlir::test::registerTestShapeMappingPass(); mlir::test::registerTestSliceAnalysisPass(); mlir::test::registerTestSPIRVFuncSignatureConversion(); + mlir::test::registerTestSPIRVVectorUnrolling(); mlir::test::registerTestTensorCopyInsertionPass(); mlir::test::registerTestTensorTransforms(); mlir::test::registerTestTopologicalSortAnalysisPass(); From 8915f1e2954b58d5ce5bc4b6dba9e19fa828f574 Mon Sep 17 00:00:00 2001 From: Angel Zhang Date: Tue, 23 Jul 2024 18:10:30 +0000 Subject: [PATCH 3/7] Formatting and documentation --- mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h | 2 ++ mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp | 1 - mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h index 195fbd0d0cd58..1236afe552a9c 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h @@ -189,6 +189,8 @@ Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder); +// Find the largest factor of size among {2,3,4} for the lowest dimension of +// the target shape. 
int getComputeVectorSize(int64_t size); // GetNativeVectorShape implementation for reduction ops. diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp index b82a244cfc973..c4a11793b1b82 100644 --- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp +++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp @@ -59,7 +59,6 @@ struct ConvertToSPIRVPass final } if (runVectorUnrolling) { - // Fold transpose ops if possible as we cannot unroll it later. { RewritePatternSet patterns(context); diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir index 54d9875002cb5..043f9422d8790 100644 --- a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir +++ b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir @@ -99,4 +99,4 @@ func.func @transpose(%arg0 : vector<2x3xi32>) -> (vector<3x2xi32>) { // CHECK: return %[[INSERT1]], %[[INSERT3]], %[[INSERT5]] : vector<2xi32>, vector<2xi32>, vector<2xi32> %0 = vector.transpose %arg0, [1, 0] : vector<2x3xi32> to vector<3x2xi32> return %0 : vector<3x2xi32> -} \ No newline at end of file +} From 9c312f1dea3521606b33c2837deaf0ae1120ce1f Mon Sep 17 00:00:00 2001 From: Angel Zhang Date: Tue, 23 Jul 2024 19:46:28 +0000 Subject: [PATCH 4/7] Remove redundant patterns and refactor code --- .../SPIRV/Transforms/SPIRVConversion.h | 8 ++ .../ConvertToSPIRV/ConvertToSPIRVPass.cpp | 88 ++----------------- .../SPIRV/Transforms/SPIRVConversion.cpp | 66 ++++++++++++++ .../TestSPIRVFuncSignatureConversion.cpp | 9 +- .../TestSPIRVVectorUnrolling.cpp | 73 +-------------- 5 files changed, 85 insertions(+), 159 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h index 1236afe552a9c..f54c93a09e727 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h @@ -19,8 +19,10 @@ #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/OneToNTypeConversion.h" #include "llvm/ADT/SmallSet.h" +#include "llvm/Support/LogicalResult.h" namespace mlir { @@ -202,6 +204,12 @@ SmallVector getNativeVectorShapeImpl(vector::TransposeOp op); // For general ops. std::optional> getNativeVectorShape(Operation *op); +// Unroll vectors in function signatures to native size. +LogicalResult unrollVectorsInSignatures(Operation *op); + +// Unroll vectors in function bodies to native size. +LogicalResult unrollVectorsInFuncBodies(Operation *op); + } // namespace spirv } // namespace mlir diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp index c4a11793b1b82..4694a147e1e94 100644 --- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp +++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp @@ -44,90 +44,16 @@ struct ConvertToSPIRVPass final using ConvertToSPIRVPassBase::ConvertToSPIRVPassBase; void runOnOperation() override { - MLIRContext *context = &getContext(); Operation *op = getOperation(); + MLIRContext *context = &getContext(); - if (runSignatureConversion) { - // Unroll vectors in function signatures to native vector size. 
- RewritePatternSet patterns(context); - populateFuncOpVectorRewritePatterns(patterns); - populateReturnOpVectorRewritePatterns(patterns); - GreedyRewriteConfig config; - config.strictMode = GreedyRewriteStrictness::ExistingOps; - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config))) - return signalPassFailure(); - } - - if (runVectorUnrolling) { - // Fold transpose ops if possible as we cannot unroll it later. - { - RewritePatternSet patterns(context); - vector::TransposeOp::getCanonicalizationPatterns(patterns, context); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) { - return signalPassFailure(); - } - } - - // Unroll vectors to native vector size. - { - RewritePatternSet patterns(context); - auto options = vector::UnrollVectorOptions().setNativeShapeFn( - [=](auto op) { return mlir::spirv::getNativeVectorShape(op); }); - populateVectorUnrollPatterns(patterns, options); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) - return signalPassFailure(); - } - - // Convert transpose ops into extract and insert pairs, in preparation - // of further transformations to canonicalize/cancel. - { - RewritePatternSet patterns(context); - auto options = - vector::VectorTransformsOptions().setVectorTransposeLowering( - vector::VectorTransposeLowering::EltWise); - vector::populateVectorTransposeLoweringPatterns(patterns, options); - vector::populateVectorShapeCastLoweringPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) { - return signalPassFailure(); - } - } - - // Run canonicalization to cast away leading size-1 dimensions. - { - RewritePatternSet patterns(context); - - // Pull in casting way leading one dims to allow cancelling some - // read/write ops. - vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); - vector::ReductionOp::getCanonicalizationPatterns(patterns, context); - - // Decompose different rank insert_strided_slice and n-D - // extract_slided_slice. - vector::populateVectorInsertExtractStridedSliceDecompositionPatterns( - patterns); - vector::ExtractOp::getCanonicalizationPatterns(patterns, context); - - // Trimming leading unit dims may generate broadcast/shape_cast ops. - // Clean them up. - vector::BroadcastOp::getCanonicalizationPatterns(patterns, context); - vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context); - - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) - return signalPassFailure(); - } + // Unroll vectors in function signatures to native size. + if (runSignatureConversion && failed(spirv::unrollVectorsInSignatures(op))) + return signalPassFailure(); - // Run all sorts of canonicalization patterns to clean up again. - { - RewritePatternSet patterns(context); - vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); - vector::InsertOp::getCanonicalizationPatterns(patterns, context); - vector::ExtractOp::getCanonicalizationPatterns(patterns, context); - vector::ReductionOp::getCanonicalizationPatterns(patterns, context); - vector::TransposeOp::getCanonicalizationPatterns(patterns, context); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) - return signalPassFailure(); - } - } + // Unroll vectors in function bodies to native size. 
+ if (runVectorUnrolling && failed(spirv::unrollVectorsInFuncBodies(op))) + return signalPassFailure(); spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op); std::unique_ptr target = diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 8470c7642e716..9f2cab30d56f3 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -20,17 +20,21 @@ #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/OneToNTypeConversion.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/MathExtras.h" #include @@ -1330,6 +1334,68 @@ mlir::spirv::getNativeVectorShape(Operation *op) { .Default([](Operation *) { return std::nullopt; }); } +LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) { + MLIRContext *context = op->getContext(); + RewritePatternSet patterns(context); + populateFuncOpVectorRewritePatterns(patterns); + populateReturnOpVectorRewritePatterns(patterns); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteStrictness::ExistingOps; + return applyPatternsAndFoldGreedily(op, std::move(patterns), config); +} + +LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) { + MLIRContext *context = op->getContext(); + + // Unroll vectors in function bodies to native vector size. + { + RewritePatternSet patterns(context); + auto options = vector::UnrollVectorOptions().setNativeShapeFn( + [](auto op) { return mlir::spirv::getNativeVectorShape(op); }); + populateVectorUnrollPatterns(patterns, options); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + return llvm::failure(); + } + + // Convert transpose ops into extract and insert pairs, in preparation of + // further transformations to canonicalize/cancel. + { + RewritePatternSet patterns(context); + auto options = vector::VectorTransformsOptions().setVectorTransposeLowering( + vector::VectorTransposeLowering::EltWise); + vector::populateVectorTransposeLoweringPatterns(patterns, options); + vector::populateVectorShapeCastLoweringPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + return llvm::failure(); + } + + // Run canonicalization to cast away leading size-1 dimensions. + { + RewritePatternSet patterns(context); + + // We need to pull in casting way leading one dims. + vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); + vector::ReductionOp::getCanonicalizationPatterns(patterns, context); + vector::TransposeOp::getCanonicalizationPatterns(patterns, context); + + // Decompose different rank insert_strided_slice and n-D + // extract_slided_slice. 
+ vector::populateVectorInsertExtractStridedSliceDecompositionPatterns( + patterns); + vector::InsertOp::getCanonicalizationPatterns(patterns, context); + vector::ExtractOp::getCanonicalizationPatterns(patterns, context); + + // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean + // them up. + vector::BroadcastOp::getCanonicalizationPatterns(patterns, context); + vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context); + + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + return llvm::failure(); + } + return llvm::success(); +} + //===----------------------------------------------------------------------===// // SPIR-V TypeConverter //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp b/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp index ec67f85f6f27b..4a792336caba4 100644 --- a/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp +++ b/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp @@ -37,13 +37,8 @@ struct TestSPIRVFuncSignatureConversion final } void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - populateFuncOpVectorRewritePatterns(patterns); - populateReturnOpVectorRewritePatterns(patterns); - GreedyRewriteConfig config; - config.strictMode = GreedyRewriteStrictness::ExistingOps; - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config); + Operation *op = getOperation(); + (void)spirv::unrollVectorsInSignatures(op); } }; diff --git a/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVVectorUnrolling.cpp b/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVVectorUnrolling.cpp index fbf7933b244fa..0bad43d5214b1 100644 --- a/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVVectorUnrolling.cpp +++ b/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVVectorUnrolling.cpp @@ -34,78 +34,9 @@ struct TestSPIRVVectorUnrolling final } void runOnOperation() override { - MLIRContext *context = &getContext(); Operation *op = getOperation(); - - // Unroll vectors in function signatures to native vector size. - { - RewritePatternSet patterns(context); - populateFuncOpVectorRewritePatterns(patterns); - populateReturnOpVectorRewritePatterns(patterns); - GreedyRewriteConfig config; - config.strictMode = GreedyRewriteStrictness::ExistingOps; - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config))) - return signalPassFailure(); - } - - // Unroll vectors to native vector size. - { - RewritePatternSet patterns(context); - auto options = vector::UnrollVectorOptions().setNativeShapeFn( - [=](auto op) { return mlir::spirv::getNativeVectorShape(op); }); - populateVectorUnrollPatterns(patterns, options); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) - return signalPassFailure(); - } - - // Convert transpose ops into extract and insert pairs, in preparation of - // further transformations to canonicalize/cancel. 
- { - RewritePatternSet patterns(context); - auto options = - vector::VectorTransformsOptions().setVectorTransposeLowering( - vector::VectorTransposeLowering::EltWise); - vector::populateVectorTransposeLoweringPatterns(patterns, options); - vector::populateVectorShapeCastLoweringPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) { - return signalPassFailure(); - } - } - - // Run canonicalization to cast away leading size-1 dimensions. - { - RewritePatternSet patterns(context); - - // We need to pull in casting way leading one dims. - vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); - vector::ReductionOp::getCanonicalizationPatterns(patterns, context); - - // Decompose different rank insert_strided_slice and n-D - // extract_slided_slice. - vector::populateVectorInsertExtractStridedSliceDecompositionPatterns( - patterns); - vector::ExtractOp::getCanonicalizationPatterns(patterns, context); - - // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean - // them up. - vector::BroadcastOp::getCanonicalizationPatterns(patterns, context); - vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context); - - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) - return signalPassFailure(); - } - - // Run all sorts of canonicalization patterns to clean up again. - { - RewritePatternSet patterns(context); - vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); - vector::InsertOp::getCanonicalizationPatterns(patterns, context); - vector::ExtractOp::getCanonicalizationPatterns(patterns, context); - vector::ReductionOp::getCanonicalizationPatterns(patterns, context); - vector::TransposeOp::getCanonicalizationPatterns(patterns, context); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) - return signalPassFailure(); - } + (void)spirv::unrollVectorsInSignatures(op); + (void)spirv::unrollVectorsInFuncBodies(op); } }; From 56597342b891e4a8b954eba61639a92234f60e89 Mon Sep 17 00:00:00 2001 From: Angel Zhang Date: Wed, 24 Jul 2024 08:23:51 -0400 Subject: [PATCH 5/7] Update mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp Co-authored-by: Jakub Kuderski --- mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 9f2cab30d56f3..8fa5543ae2348 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -1354,7 +1354,7 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) { [](auto op) { return mlir::spirv::getNativeVectorShape(op); }); populateVectorUnrollPatterns(patterns, options); if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) - return llvm::failure(); + return failure(); } // Convert transpose ops into extract and insert pairs, in preparation of From b75edd302c711562889482b9832018b38d37af3a Mon Sep 17 00:00:00 2001 From: Angel Zhang Date: Wed, 24 Jul 2024 13:01:13 +0000 Subject: [PATCH 6/7] Replace llvm::success and llvm::failure with success and failure --- mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 8fa5543ae2348..2ff4b55a9f1ed 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ 
b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1366,7 +1366,7 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
     vector::populateVectorTransposeLoweringPatterns(patterns, options);
     vector::populateVectorShapeCastLoweringPatterns(patterns);
     if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
-      return llvm::failure();
+      return failure();
   }
 
   // Run canonicalization to cast away leading size-1 dimensions.
@@ -1391,9 +1391,9 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
     vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
 
     if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
-      return llvm::failure();
+      return failure();
   }
-  return llvm::success();
+  return success();
 }
 
 //===----------------------------------------------------------------------===//

From 8ef9ad24867a463e5f786b8f8ca7d3fe1e866200 Mon Sep 17 00:00:00 2001
From: Angel Zhang
Date: Wed, 24 Jul 2024 14:18:30 +0000
Subject: [PATCH 7/7] Refactoring and comment

---
 mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 2ff4b55a9f1ed..d833ec9309baa 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1103,10 +1103,11 @@ struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
            StaticTileOffsetRange(originalShape, *targetShape)) {
         Value result = rewriter.create<vector::ExtractStridedSliceOp>(
             loc, returnValue, offsets, extractShape, strides);
-        SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
-        if (originalShape.size() > 1)
+        if (originalShape.size() > 1) {
+          SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
           result =
               rewriter.create<vector::ExtractOp>(loc, result, extractIndices);
+        }
         newOperands.push_back(result);
         newTypes.push_back(unrolledType);
       }
@@ -1320,7 +1321,7 @@ mlir::spirv::getNativeVectorShape(Operation *op) {
   if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
-    if (auto vecType = llvm::dyn_cast<VectorType>(op->getResultTypes()[0])) {
+    if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
       SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
       nativeSize.back() =
           mlir::spirv::getComputeVectorSize(vecType.getShape().back());
       return nativeSize;
     }
@@ -1339,6 +1340,9 @@ LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {
   RewritePatternSet patterns(context);
   populateFuncOpVectorRewritePatterns(patterns);
   populateReturnOpVectorRewritePatterns(patterns);
+  // We only want to apply signature conversion once to the existing func ops.
+  // Without specifying strictMode, the greedy pattern rewriter will keep
+  // looking for newly created func ops.
   GreedyRewriteConfig config;
   config.strictMode = GreedyRewriteStrictness::ExistingOps;
   return applyPatternsAndFoldGreedily(op, std::move(patterns), config);