@@ -66,7 +66,7 @@ module attributes {transform.with_named_sequence} {
6666// -----
6767
6868#map = affine_map <(d0 , d1 , d2 ) -> (d0 , d1 , d2 )>
69- func.func @vectorize_nd_tensor_extract_constant_idx (%arg0: tensor <3 x3 xf32 >, %arg2: tensor <1 x1 x3 xf32 >) -> tensor <1 x1 x3 xf32 > {
69+ func.func @vectorize_nd_tensor_extract_scalar_broadcast (%arg0: tensor <3 x3 xf32 >, %arg2: tensor <1 x1 x3 xf32 >) -> tensor <1 x1 x3 xf32 > {
7070 %c0 = arith.constant 1 : index
7171 %c1 = arith.constant 2 : index
7272 %2 = linalg.generic {
@@ -80,17 +80,17 @@ func.func @vectorize_nd_tensor_extract_constant_idx(%arg0: tensor<3x3xf32>, %arg
8080 return %2 : tensor <1 x1 x3 xf32 >
8181}
8282
83- // CHECK: #[[$MAP:.* ]] = affine_map<(d0, d1) -> (0, 0, 0)>
84- // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_constant_idx (
83+ // CHECK: #[[$MAP:.+ ]] = affine_map<(d0, d1) -> (0, 0, 0)>
84+ // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_scalar_broadcast (
8585// CHECK-SAME: %[[ARG_0:.*]]: tensor<3x3xf32>,
8686// CHECK-SAME: %[[ARG_1:.*]]: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
8787// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
8888// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
89- // CHECK-DAG: %[[C0_f32_2 :.*]] = arith.constant 0.000000e+00 : f32
90- // CHECK-DAG : %[[C0_f32 :.*]] = arith.constant 0.000000e+00 : f32
91- // CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG_0]][%[[C1]], %[[C2]]], %[[C0_f32]] {in_bounds = [true, true, true], permutation_map = #[[$MAP]]} : tensor<3x3xf32>, vector<1x1x3xf32>
92- // CHECK: %[[C0_4 :.*]] = arith.constant 0 : index
93- // CHECK: vector.transfer_write %[[READ]], %[[ARG_1]][ %[[C0_4 ]], %[[C0_4 ]], %[[C0_4 ]]] : vector<1x1x3xf32>, tensor<1x1x3xf32>
89+ // CHECK-DAG: %[[C0 :.*]] = arith.constant 0 : index
90+ // CHECK: %[[MASK :.*]] = vector.constant_mask [1] : vector<1xi1>
91+ // CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]][%[[C1]], %[[C2]]], {{.*}} {in_bounds = [true, true, true], permutation_map = #[[$MAP]]} : tensor<3x3xf32>, vector<1x1x3xf32> } : vector<1xi1> -> vector<1x1x3xf32>
92+ // CHECK: %[[C0_2 :.*]] = arith.constant 0 : index
93+ // CHECK: vector.transfer_write %[[READ]], %[[ARG_1]]{{\[}} %[[C0_2 ]], %[[C0_2 ]], %[[C0_2 ]]] : vector<1x1x3xf32>, tensor<1x1x3xf32>
9494
9595module attributes {transform.with_named_sequence } {
9696 transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
@@ -823,7 +823,7 @@ func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> t
823823 return %out:tensor <1 x1 x4 xi32 >
824824}
825825
826- // CHECK: #[[$ATTR_1 :.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
826+ // CHECK: #[[$MAP :.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
827827// CHECK-LABEL: func.func @vectorize_scalar_broadcast_column_tensor(
828828// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
829829// CHECK: %[[VAL_1:.*]] = arith.constant 4 : index
@@ -844,12 +844,14 @@ func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> t
844844// CHECK: %[[VAL_16:.*]] = arith.constant dense<true> : vector<1x1x4xi1>
845845// CHECK: %[[VAL_17:.*]] = arith.constant dense<0> : vector<1x1x4xi32>
846846// CHECK: %[[VAL_18:.*]] = arith.constant 0 : index
847- // CHECK: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
848- // CHECK: %[[VAL_21:.*]] = vector.extract %[[VAL_20]][0] : index from vector<4xindex>
849- // CHECK: %[[VAL_22:.*]] = arith.constant 0 : i32
850- // CHECK: %[[VAL_23:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_21]], %[[VAL_2]]], %[[VAL_22]] {in_bounds = [true, true, true], permutation_map = #[[$ATTR_1]]} : tensor<15x1xi32>, vector<1x1x4xi32>
847+ // CHECK: %[[VAL_19:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
848+ // CHECK: %[[VAL_20:.*]] = vector.extract %[[VAL_19]][0] : index from vector<4xindex>
849+ // CHECK: %[[VAL_21:.*]] = arith.constant 0 : i32
850+ // CHECK: %[[VAL_22:.*]] = vector.constant_mask [1] : vector<1xi1>
851+ // CHECK: %[[VAL_23:.*]] = vector.mask %[[VAL_22]] { vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_20]], %[[VAL_2]]], %[[VAL_21]] {in_bounds = [true, true, true], permutation_map = #[[$MAP]]} : tensor<15x1xi32>, vector<1x1x4xi32> } : vector<1xi1> -> vector<1x1x4xi32>
851852// CHECK: %[[VAL_24:.*]] = arith.constant 0 : index
852853// CHECK: %[[VAL_25:.*]] = vector.transfer_write %[[VAL_23]], %[[VAL_0]]{{\[}}%[[VAL_24]], %[[VAL_24]], %[[VAL_24]]] : vector<1x1x4xi32>, tensor<1x1x4xi32>
854+ // CHECK: return %[[VAL_25]] : tensor<1x1x4xi32>
853855
854856module attributes {transform.with_named_sequence } {
855857 transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
0 commit comments