diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 676da6d176497..e30950bbf292d 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -354,6 +354,24 @@ class AffineMap {
   /// returns the resulting values. `this` must be symbol-less.
   SmallVector<int64_t> compose(ArrayRef<int64_t> values) const;
 
+  /// Returns the number of "zero" results (constant values == 0) in this map.
+  ///
+  /// Example:
+  /// * For `(d0, d1) -> (d0, d1, 0)` returns 1
+  /// * For `(d0, d1, d2) -> (d0, d1)` returns 0
+  /// * For `(d0, d1, d2) -> (d0, 0, d1, 0, d2)` returns 2
+  size_t getNumOfZeroResults() const;
+
+  /// Returns the AffineMap resulting from removing "zero" results (constant
+  /// values == 0) from this map.
+  ///
+  /// Example:
+  /// * For `(d0, d1) -> (d0, d1, 0)` returns `(d0, d1) -> (d0, d1)`
+  /// * For `(d0, d1, d2) -> (d0, d1)` returns `(d0, d1, d2) -> (d0, d1)`
+  /// * For `(d0, d1, d2) -> (d0, 0, d1, 0, d2)` returns
+  ///   `(d0, d1, d2) -> (d0, d1, d2)`
+  AffineMap dropZeroResults();
+
   /// Returns true if the AffineMap represents a subset (i.e. a projection) of a
   /// symbol-less permutation map. `allowZeroInResults` allows projected
   /// permutation maps with constant zero result expressions.
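Note for reviewers: the two helpers are meant to be used together, `getNumOfZeroResults` to reason about result counts and `dropZeroResults` to build a zero-free map. A minimal sketch of the intended semantics (illustrative only, not part of the patch; assumes a standalone MLIRContext):

  #include "mlir/IR/AffineExpr.h"
  #include "mlir/IR/AffineMap.h"
  #include "mlir/IR/MLIRContext.h"
  #include <cassert>

  using namespace mlir;

  static void exerciseZeroResultHelpers() {
    MLIRContext ctx;
    // Build (d0, d1, d2) -> (d0, 0, d2).
    AffineMap map = AffineMap::get(
        /*dimCount=*/3, /*symbolCount=*/0,
        {getAffineDimExpr(0, &ctx), getAffineConstantExpr(0, &ctx),
         getAffineDimExpr(2, &ctx)},
        &ctx);
    assert(map.getNumOfZeroResults() == 1);

    // Dropping zeros removes results only; the dim count is preserved,
    // so the result is (d0, d1, d2) -> (d0, d2).
    AffineMap dropped = map.dropZeroResults();
    assert(dropped.getNumResults() == 2 && dropped.getNumDims() == 3);
  }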
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 63dcda78d0f2b..d712ab8faa6cb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -224,10 +224,10 @@ struct VectorizationState {
   /// Masks an operation with the canonical vector mask if the operation needs
   /// masking. Returns the masked operation or the original operation if masking
   /// is not needed. If provided, the canonical mask for this operation is
-  /// permuted using `maybeMaskingMap`.
+  /// permuted using `maybeIndexingMap`.
   Operation *
   maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
-                std::optional<AffineMap> maybeMaskingMap = std::nullopt);
+                std::optional<AffineMap> maybeIndexingMap = std::nullopt);
 
 private:
   /// Initializes the iteration space static sizes using the Linalg op
@@ -422,16 +422,28 @@ Value VectorizationState::getOrCreateMaskFor(
   return mask;
 }
 
-/// Masks an operation with the canonical vector mask if the operation needs
-/// masking. Returns the masked operation or the original operation if masking
-/// is not needed. If provided, the canonical mask for this operation is
-/// permuted using `maybeMaskingMap`.
 Operation *
 VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
                                   LinalgOp linalgOp,
-                                  std::optional<AffineMap> maybeMaskingMap) {
+                                  std::optional<AffineMap> maybeIndexingMap) {
   LDBG("Trying to mask: " << *opToMask << "\n");
 
+  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
+  // The operand's indexing map may contain "zero" results, e.g.:
+  //   (d0, d1, d2, d3) -> (d0, d1, d2, 0)
+  // Applied to a canonical vector shape such as:
+  //   (1, 16, 16, 4)
+  // it would yield:
+  //   (1, 16, 16, 0)
+  // Instead, extract the following permutation map to use for masking:
+  //   (d0, d1, d2, d3) -> (d0, d1, d2)
+  // This way, the corresponding vector/mask type will be:
+  //   vector<1x16x16xty>
+  // rather than:
+  //   vector<1x16x16x0xty>
+  if (maybeIndexingMap)
+    maybeMaskingMap = maybeIndexingMap->dropZeroResults();
+
   // Create or retrieve mask for this operation.
   Value mask =
       getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
@@ -476,7 +488,8 @@ static AffineMap reindexIndexingMap(AffineMap map) {
   assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
          "expected projected permutation");
   auto res = compressUnusedDims(map);
-  assert(res.getNumDims() == res.getNumResults() &&
-         "expected reindexed map with same number of dims and results");
+  assert(res.getNumDims() ==
+             (res.getNumResults() - res.getNumOfZeroResults()) &&
+         "expected reindexed map with same number of dims and non-zero results");
   return res;
 }
@@ -1317,16 +1330,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
     // permutation map and masking map.
     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
 
-    // Remove zeros from indexing map to use it as masking map.
-    SmallVector<int64_t> zeroPos;
-    auto results = indexingMap.getResults();
-    for (const auto &result : llvm::enumerate(results)) {
-      if (isa<AffineConstantExpr>(result.value())) {
-        zeroPos.push_back(result.index());
-      }
-    }
-    AffineMap maskingMap = indexingMap.dropResults(zeroPos);
-
     AffineMap readMap;
     VectorType readType;
     Type elemType = getElementTypeOrSelf(opOperand->get());
@@ -1356,7 +1359,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
     Operation *read = rewriter.create<vector::TransferReadOp>(
         loc, readType, opOperand->get(), indices, readMap,
        ArrayRef<bool>(inBounds));
-    read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
+    read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
     Value readValue = read->getResult(0);
 
     // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 5cbd0b090492b..ea3c0723b0775 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -592,6 +592,28 @@ SmallVector<int64_t> AffineMap::compose(ArrayRef<int64_t> values) const {
   return res;
 }
 
+size_t AffineMap::getNumOfZeroResults() const {
+  size_t res = 0;
+  for (auto expr : getResults()) {
+    auto constExpr = dyn_cast<AffineConstantExpr>(expr);
+    if (constExpr && constExpr.getValue() == 0)
+      res++;
+  }
+
+  return res;
+}
+
+AffineMap AffineMap::dropZeroResults() {
+  SmallVector<AffineExpr> newExprs;
+
+  for (auto expr : getResults()) {
+    auto constExpr = dyn_cast<AffineConstantExpr>(expr);
+    if (!constExpr || constExpr.getValue() != 0)
+      newExprs.push_back(expr);
+  }
+  return AffineMap::get(getNumDims(), getNumSymbols(), newExprs, getContext());
+}
+
 bool AffineMap::isProjectedPermutation(bool allowZeroInResults) const {
   if (getNumSymbols() > 0)
     return false;
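For context, `dropZeroResults` computes the same map as the inline logic deleted from vectorizeAsLinalgGeneric above, modulo one detail: the removed loop dropped any constant result (a plain `isa` check), whereas the new helper drops only constants equal to zero. For the projected permutation maps handled here the two coincide, since zero is the only constant such maps may contain. An equivalent formulation in terms of the pre-existing `dropResults` API (sketch, for comparison only):

  // Equivalent to map.dropZeroResults(): gather positions of constant-zero
  // results, then drop them via the existing dropResults() API.
  static AffineMap dropZeroResultsViaPositions(AffineMap map) {
    SmallVector<int64_t> zeroPos;
    for (const auto &result : llvm::enumerate(map.getResults())) {
      auto constExpr = dyn_cast<AffineConstantExpr>(result.value());
      if (constExpr && constExpr.getValue() == 0)
        zeroPos.push_back(result.index());
    }
    return map.dropResults(zeroPos);
  }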
diff --git a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
index 3404b73102e6a..9a43d43cd9460 100644
--- a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
@@ -1964,3 +1964,43 @@ module attributes {transform.with_named_sequence} {
 // CHECK: %[[VAL_8:.*]] = vector.transpose %[[VAL_7]], [1, 0] : vector<1x4xf32> to vector<4x1xf32>
 // CHECK: vector.transfer_write %[[VAL_8]], %{{.*}} {in_bounds = [true, true]} : vector<4x1xf32>, tensor<4x1xf32>
 // CHECK: vector.transfer_write %[[VAL_7]], %{{.*}} {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+
+// -----
+
+// Extracted from: https://github.com/llvm/llvm-project/issues/97247
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>
+
+func.func @generic_with_reduction_and_broadcast(%arg0: tensor<1x12x197x197xf32>) -> (tensor<1x12x197x1xf32>) {
+  %0 = tensor.empty() : tensor<1x12x197x1xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0 : tensor<1x12x197x197xf32>) outs(%0 : tensor<1x12x197x1xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %818 = arith.addf %in, %out : f32
+    linalg.yield %818 : f32
+  } -> tensor<1x12x197x1xf32>
+  return %1 : tensor<1x12x197x1xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK: #[[$ATTR_32:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL: func.func @generic_with_reduction_and_broadcast(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x197x197xf32>) -> tensor<1x12x197x1xf32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<1x12x197x1xf32>
+// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true, true, true, true]} : tensor<1x12x197x197xf32>, vector<1x12x197x197xf32>
+// CHECK: %[[VAL_5:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true, true, true], permutation_map = #[[$ATTR_32]]} : tensor<1x12x197x1xf32>, vector<1x12x197xf32>
+// CHECK: %[[VAL_6:.*]] = vector.multi_reduction <add>, %[[VAL_4]], %[[VAL_5]] [3] : vector<1x12x197x197xf32> to vector<1x12x197xf32>
+// CHECK: %[[VAL_7:.*]] = vector.broadcast %[[VAL_6]] : vector<1x12x197xf32> to vector<1x1x12x197xf32>
+// CHECK: %[[VAL_8:.*]] = vector.transpose %[[VAL_7]], [1, 2, 3, 0] : vector<1x1x12x197xf32> to vector<1x12x197x1xf32>
+// CHECK: %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]]] {in_bounds = [true, true, true, true]} : vector<1x12x197x1xf32>, tensor<1x12x197x1xf32>
+// CHECK: return %[[VAL_9]] : tensor<1x12x197x1xf32>
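The broadcast map in this test also illustrates why the assertion in reindexIndexingMap had to be relaxed. A sketch of the invariant (illustrative only; reuses the includes from the first example, plus `compressUnusedDims` from mlir/IR/AffineMap.h):

  static void checkReindexInvariant(MLIRContext &ctx) {
    // The test's broadcast map: (d0, d1, d2, d3) -> (d0, d1, d2, 0).
    AffineMap map = AffineMap::get(
        /*dimCount=*/4, /*symbolCount=*/0,
        {getAffineDimExpr(0, &ctx), getAffineDimExpr(1, &ctx),
         getAffineDimExpr(2, &ctx), getAffineConstantExpr(0, &ctx)},
        &ctx);
    // d3 is unused, so compressing dims yields (d0, d1, d2) -> (d0, d1, d2, 0):
    // 3 dims but 4 results, one of which is the constant zero.
    AffineMap res = compressUnusedDims(map);
    assert(res.getNumDims() == 3 && res.getNumResults() == 4);
    // Hence dims == results - zero-results, as now asserted in
    // reindexIndexingMap.
    assert(res.getNumDims() == res.getNumResults() - res.getNumOfZeroResults());
  }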
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 783149971f0d6..0e2b2458d29cd 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -147,6 +147,51 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0, 0)>
+
+func.func @dynamic_generic_with_reduction_and_broadcast(%arg0: tensor<?x?xf32>, %init: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+  %0 = linalg.generic { indexing_maps = [#map, #map1],
+                        iterator_types = ["parallel", "reduction"]}
+    ins(%arg0 : tensor<?x?xf32>)
+    outs(%init : tensor<?x?xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %1 = arith.addf %in, %out : f32
+      linalg.yield %1 : f32
+    } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0)>
+
+// CHECK-LABEL: func.func @dynamic_generic_with_reduction_and_broadcast(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf32>
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf32>
+// CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_8:.*]] = vector.create_mask %[[VAL_3]], %[[VAL_5]] : vector<4x4xi1>
+// CHECK: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]], %[[VAL_7]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x4xf32> } : vector<4x4xi1> -> vector<4x4xf32>
+// CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_11:.*]] = vector.create_mask %[[VAL_3]] : vector<4xi1>
+// CHECK: %[[VAL_12:.*]] = vector.mask %[[VAL_11]] { vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_6]], %[[VAL_6]]], %[[VAL_10]] {in_bounds = [true], permutation_map = #[[$MAP]]} : tensor<?x?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK: %[[VAL_13:.*]] = vector.mask %[[VAL_8]] { vector.multi_reduction <add>, %[[VAL_9]], %[[VAL_12]] [1] : vector<4x4xf32> to vector<4xf32> } : vector<4x4xi1> -> vector<4xf32>
+// CHECK: %[[VAL_14:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_15:.*]] = vector.mask %[[VAL_11]] { vector.transfer_write %[[VAL_13]], %[[VAL_1]]{{\[}}%[[VAL_14]], %[[VAL_14]]] {in_bounds = [true], permutation_map = #[[$MAP]]} : vector<4xf32>, tensor<?x?xf32> } : vector<4xi1> -> tensor<?x?xf32>
+// CHECK: return %[[VAL_15]] : tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [4, 4] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @vectorize_dynamic_2d_transpose(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
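Tying it together for the dynamic test above: the canonical vector shape is (4, 4) (from `vector_sizes [4, 4]`), the output operand's indexing map is `(d0, d1) -> (d0, 0)`, and dropping the zero result leaves `(d0, d1) -> (d0)`, which selects only the first canonical dimension; hence the `vector<4xi1>` mask in the CHECK lines. A hypothetical helper showing that projection (`maskShapeFor` is illustrative, not an API introduced by this patch; assumes a zero-free projected permutation):

  // Project the canonical vector shape through a zero-free masking map.
  // For (d0, d1) -> (d0) applied to {4, 4}, this returns {4}.
  static SmallVector<int64_t> maskShapeFor(AffineMap maskingMap,
                                           ArrayRef<int64_t> canonicalShape) {
    SmallVector<int64_t> maskShape;
    for (AffineExpr expr : maskingMap.getResults())
      maskShape.push_back(
          canonicalShape[cast<AffineDimExpr>(expr).getPosition()]);
    return maskShape;
  }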