From 1e6c8b3085b397fedfdba36e568230ab40bde68f Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Sat, 7 Dec 2024 16:11:55 +0000 Subject: [PATCH] [mlir][linalg] Enable Vectorization of 0-D tensor.extract This patch removes an assert in `vectorizeTensorExtract` that was blocking the vectorization of 0-D tensor.extract operations, e.g.: ```mlir %1 = tensor.extract %src[] : tensor<f32> ``` As demonstrated by the included tests, this case is already effectively supported. **Context** The removed assert was introduced in #109580 as a guard, pending proper support and testing for 0-D tensors. This PR addresses that previously undocumented TODO. Apologies for the oversight! **Updates and Tests** * Revised the existing test `@negative_no_loops` to ensure the `vectorize_nd_extract` attribute is included, allowing the vectorizer to process it. The test was renamed and variables updated for clarity. * Added a new test `@extract_scalar_from_0d_into_1d` to cover "mixed" 0-D/1-D tensor extraction, e.g.: ```mlir %res = linalg.generic { indexing_maps = [#map], iterator_types = ["parallel"] } outs(%init : tensor<1xf32>) { ^bb0(%in: f32): %1 = tensor.extract %src[] : tensor<f32> linalg.yield %1 : f32 } -> tensor<1xf32> return %res : tensor<1xf32> ``` **Additional updates** I also took the liberty of improving test coverage for 0-D tensors in the vectorizer tests: * Added a specific test for "0-D linalg.generic" in "vectorization-with-patterns.mlir". * Renamed several tests in "vectorization-with-patterns.mlir" to clarify that the 0-D case is now covered. 
--- .../Linalg/Transforms/Vectorization.cpp | 5 -- .../Linalg/vectorization-with-patterns.mlir | 48 ++++++++++++++- .../Linalg/vectorize-tensor-extract.mlir | 58 +++++++++++++++---- 3 files changed, 93 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index e5c96b52acee2..863f2280e46ce 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1115,11 +1115,6 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, // b. contiguous loads. // Both cases use vector.transfer_read. - assert(llvm::count_if(resultType.getShape(), - [](uint64_t dim) { return dim != 1; }) && - "Contiguous loads and scalar loads + broadcast only support 1-D " - "vectors ATM!"); - // Collect indices for `vector.transfer_read`. At this point, the indices will // either be scalars or would have been broadcast to vectors matching the // result type. 
For indices that are vectors, there are two options: diff --git a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir index 0c996bed996d3..b688a677500c2 100644 --- a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir +++ b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir @@ -122,6 +122,48 @@ module attributes {transform.with_named_sequence} { // ----- +#map = affine_map<() -> ()> + +// CHECK-LABEL: func.func @generic_0d( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>, %[[ARG_2:.*]]: tensor<f32>) +func.func @generic_0d(%arg0: tensor<f32>, %arg1: tensor<f32>, + %arg2: tensor<f32>) -> tensor<f32> { +// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[READ_0:.*]] = vector.transfer_read %[[ARG_0]][], %[[PAD]] : tensor<f32>, vector<f32> +// CHECK: %[[ARG_0_AS_SCALAR:.*]] = vector.extract %[[READ_0]][] : f32 from vector<f32> +// CHECK: %[[READ_1:.*]] = vector.transfer_read %[[ARG_1]][], %[[PAD]] : tensor<f32>, vector<f32> +// CHECK: %[[ARG_1_AS_SCALAR:.*]] = vector.extract %[[READ_1]][] : f32 from vector<f32> +// CHECK: %[[READ_2:.*]] = vector.transfer_read %[[ARG_2]][], %[[PAD]] : tensor<f32>, vector<f32> +// CHECK: %[[ARG_2_AS_SCALAR:.*]] = vector.extract %[[READ_2]][] : f32 from vector<f32> +// CHECK: %[[MULF:.*]] = arith.mulf %[[ARG_0_AS_SCALAR]], %[[ARG_1_AS_SCALAR]] : f32 +// CHECK: %[[ADDF:.*]] = arith.addf %[[ARG_2_AS_SCALAR]], %[[MULF]] : f32 +// CHECK: %[[ADDF_BCAST:.*]] = vector.broadcast %[[ADDF]] : f32 to vector<f32> +// CHECK: vector.transfer_write %[[ADDF_BCAST]], %[[ARG_2]][] : vector<f32>, tensor<f32> + %res = linalg.generic { + indexing_maps = [#map, #map, #map], + iterator_types = [] + } ins(%arg0, %arg1 : tensor<f32>, tensor<f32>) + outs(%arg2 : tensor<f32>) { + ^bb(%a: f32, %b: f32, %c: f32) : + %d = arith.mulf %a, %b: f32 + %e = arith.addf %c, %d: f32 + linalg.yield %e : f32 + } -> tensor<f32> + + return %res : tensor<f32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: 
!transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + #matmul_transpose_out_trait = { indexing_maps = [ affine_map<(m, n, k) -> (m, k)>, @@ -372,7 +414,7 @@ module attributes {transform.with_named_sequence} { // ----- // CHECK-LABEL: func @test_vectorize_fill -func.func @test_vectorize_fill_scalar(%A : memref<f32>, %arg0 : f32) { +func.func @test_vectorize_fill_0d(%A : memref<f32>, %arg0 : f32) { // CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[val:.*]]: f32) // CHECK: %[[VEC:.*]] = vector.broadcast %[[val]] : f32 to vector<f32> // CHECK: vector.transfer_write %[[VEC]], %[[M]][] : vector<f32>, memref<f32> @@ -410,8 +452,8 @@ module attributes {transform.with_named_sequence} { // ----- -// CHECK-LABEL: func @test_vectorize_copy_scalar -func.func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) { +// CHECK-LABEL: func @test_vectorize_copy_0d +func.func @test_vectorize_copy_0d(%A : memref<f32>, %B : memref<f32>) { // CHECK-SAME: (%[[A:.*]]: memref<f32>, %[[B:.*]]: memref<f32>) // CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<f32> // CHECK: %[[val:.*]] = vector.extract %[[V]][] : f32 from vector<f32> diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir index 1a93d1cd9b788..775ceed31be04 100644 --- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir +++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir @@ -39,29 +39,67 @@ module attributes {transform.with_named_sequence} { // ----- #map = affine_map<() -> ()> -func.func @negative_no_loops(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> { - 
%1 = linalg.generic { +func.func @extract_scalar_from_0d_into_0d(%src: tensor<f32>, %init: tensor<f32>) -> tensor<f32> { + %res = linalg.generic { indexing_maps = [#map], iterator_types = [] - } outs(%arg1 : tensor<f32>) { - ^bb0(%arg4: f32): - %2 = tensor.extract %arg0[] : tensor<f32> - linalg.yield %2 : f32 + } outs(%init : tensor<f32>) { + ^bb0(%in: f32): + %1 = tensor.extract %src[] : tensor<f32> + linalg.yield %1 : f32 } -> tensor<f32> - return %1 : tensor<f32> + + return %res : tensor<f32> } -// CHECK-LABEL: func.func @negative_no_loops -// CHECK: tensor.extract + +// CHECK-LABEL: func.func @extract_scalar_from_0d_into_0d( +// CHECK-SAME: %[[SRC:.*]]: tensor<f32>, +// CHECK-SAME: %[[INIT:.*]]: tensor<f32>) -> tensor<f32> { +// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]][], %[[PAD]] : tensor<f32>, vector<f32> +// CHECK: vector.transfer_write %[[READ]], %[[INIT]][] : vector<f32>, tensor<f32> module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op - %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op transform.yield } } +// ----- + +#map = affine_map<(n) -> (n)> +func.func @extract_scalar_from_0d_into_1d(%src: tensor<f32>, %init: tensor<1xf32>) -> tensor<1xf32> { + %res = linalg.generic { + indexing_maps = [#map], + iterator_types = ["parallel"] + } outs(%init : tensor<1xf32>) { + ^bb0(%in: f32): + %1 = tensor.extract %src[] : tensor<f32> + linalg.yield %1 : f32 + } -> tensor<1xf32> + + return %res : tensor<1xf32> +} +// CHECK-LABEL: func.func @extract_scalar_from_0d_into_1d( +// CHECK-SAME: %[[SRC:.*]]: 
tensor<f32>, +// CHECK-SAME: %[[INIT:.*]]: tensor<1xf32>) -> tensor<1xf32> { +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]][], %[[PAD]] : tensor<f32>, vector<f32> +// CHECK: %[[READ_BCAST:.*]] = vector.broadcast %[[READ]] : vector<f32> to vector<1xf32> +// CHECK: vector.transfer_write %[[READ_BCAST]], %[[INIT]][%[[C0]]] {in_bounds = [true]} : vector<1xf32>, tensor<1xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op + transform.yield + } +} // -----