@@ -807,56 +807,41 @@ module attributes {transform.with_named_sequence} {
807
807
808
808
// -----
809
809
810
- func.func @vectorize_scalar_broadcast_column_tensor ( %in : tensor <1 x1 x4 xi32 >) -> tensor <1 x1 x4 xi32 > {
810
+ func.func @vectorize_scalar_read_with_broadcast_from_column_tensor ( %init : tensor <1 x1 x4 xi32 >) -> tensor <1 x1 x4 xi32 > {
811
811
%c4 = arith.constant 4 : index
812
812
%c0 = arith.constant 0 : index
813
- %cst = arith.constant dense <[[0 ], [1 ], [2 ], [3 ], [4 ], [5 ], [6 ], [7 ], [8 ], [9 ], [10 ], [11 ], [12 ], [13 ], [14 ]]> : tensor <15 x1 xi32 >
814
-
815
- %out = linalg.generic {index ing_maps = [affine_map <(d0 , d1 , d2 ) -> (d0 , d1 , d2 )>], iterator_types = [" parallel" , " parallel" , " parallel" ]} outs (%in : tensor <1 x1 x4 xi32 >) {
816
- ^bb0 (%out: i32 ):
817
- %8 = linalg.index 0 : index
818
- %idx_0 = linalg.index 0 : index
819
- %extracted = tensor.extract %cst [%idx_0 , %c0 ] : tensor <15 x1 xi32 >
820
- linalg.yield %extracted : i32
813
+ %src = arith.constant dense <[[0 ], [1 ], [2 ], [3 ], [4 ], [5 ], [6 ], [7 ], [8 ], [9 ], [10 ], [11 ], [12 ], [13 ], [14 ]]> : tensor <15 x1 xi32 >
814
+
815
+ %res = linalg.generic {
816
+ indexing_maps = [affine_map <(d0 , d1 , d2 ) -> (d0 , d1 , d2 )>],
817
+ iterator_types = [" parallel" , " parallel" , " parallel" ]}
818
+ outs (%init : tensor <1 x1 x4 xi32 >) {
819
+
820
+ ^bb0 (%out: i32 ):
821
+ %idx = linalg.index 0 : index
822
+ %extracted = tensor.extract %src [%idx , %c0 ] : tensor <15 x1 xi32 >
823
+ linalg.yield %extracted : i32
821
824
} -> tensor <1 x1 x4 xi32 >
822
825
823
- return %out: tensor <1 x1 x4 xi32 >
826
+ return %res : tensor <1 x1 x4 xi32 >
824
827
}
825
828
826
- // CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
827
- // CHECK-LABEL: func.func @vectorize_scalar_broadcast_column_tensor(
828
- // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
829
- // CHECK: %[[VAL_1:.*]] = arith.constant 4 : index
830
- // CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
831
- // CHECK: %[[VAL_3:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
832
- // CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
833
- // CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
834
- // CHECK: %[[VAL_6:.*]] = arith.constant 4 : index
835
- // CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
836
- // CHECK: %[[VAL_8:.*]] = arith.constant 0 : i32
837
- // CHECK: %[[VAL_9:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_7]], %[[VAL_7]]], %[[VAL_8]] : tensor<1x1x4xi32>, vector<1x1x4xi32>
838
- // CHECK: %[[VAL_10:.*]] = vector.step : vector<1xindex>
839
- // CHECK: %[[VAL_11:.*]] = vector.broadcast %[[VAL_10]] : vector<1xindex> to vector<4x1x1xindex>
840
- // CHECK: %[[VAL_12:.*]] = vector.transpose %[[VAL_11]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
841
- // CHECK: %[[VAL_13:.*]] = vector.step : vector<1xindex>
842
- // CHECK: %[[VAL_14:.*]] = vector.broadcast %[[VAL_13]] : vector<1xindex> to vector<4x1x1xindex>
843
- // CHECK: %[[VAL_15:.*]] = vector.transpose %[[VAL_14]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
844
- // CHECK: %[[VAL_16:.*]] = arith.constant dense<true> : vector<1x1x4xi1>
845
- // CHECK: %[[VAL_17:.*]] = arith.constant dense<0> : vector<1x1x4xi32>
846
- // CHECK: %[[VAL_18:.*]] = arith.constant 0 : index
847
- // CHECK: %[[VAL_19:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
848
- // CHECK: %[[VAL_20:.*]] = vector.extract %[[VAL_19]][0] : index from vector<4xindex>
849
- // CHECK: %[[VAL_21:.*]] = arith.constant 0 : i32
850
- // CHECK: %[[VAL_22:.*]] = vector.constant_mask [1] : vector<1xi1>
851
- // CHECK: %[[VAL_23:.*]] = vector.mask %[[VAL_22]] { vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_20]], %[[VAL_2]]], %[[VAL_21]] {in_bounds = [true, true, true], permutation_map = #[[$MAP]]} : tensor<15x1xi32>, vector<1x1x4xi32> } : vector<1xi1> -> vector<1x1x4xi32>
852
- // CHECK: %[[VAL_24:.*]] = arith.constant 0 : index
853
- // CHECK: %[[VAL_25:.*]] = vector.transfer_write %[[VAL_23]], %[[VAL_0]]{{\[}}%[[VAL_24]], %[[VAL_24]], %[[VAL_24]]] : vector<1x1x4xi32>, tensor<1x1x4xi32>
854
- // CHECK: return %[[VAL_25]] : tensor<1x1x4xi32>
829
+ // CHECK-LABEL: func.func @vectorize_scalar_read_with_broadcast_from_column_tensor(
830
+ // CHECK-SAME: %[[INIT:.*]]: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
831
+ // CHECK: %[[PAD:.*]] = arith.constant 0 : i32
832
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
833
+ // CHECK: %[[SRC:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
834
+ // CHECK: %[[IDX_VEC:.*]] = arith.constant dense<0> : vector<1xindex>
835
+ // CHECK: %[[IDX_ELT:.*]] = vector.extract %[[IDX_VEC]][0] : index from vector<1xindex>
836
+ // CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{\[}}%[[IDX_ELT]], %[[C0]]], %[[PAD]] : tensor<15x1xi32>, vector<i32>
837
+ // CHECK: %[[READ_BCAST:.*]] = vector.broadcast %[[READ]] : vector<i32> to vector<1x1x4xi32>
838
+ // CHECK: %[[RES:.*]] = vector.transfer_write %[[READ_BCAST]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x4xi32>, tensor<1x1x4xi32>
855
839
856
840
module attributes {transform.with_named_sequence } {
857
841
transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
858
- %0 = transform.structured.match ops {[" linalg.generic" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
859
- transform.structured.vectorize %0 vector_sizes [1 , 1 , 4 ]{ vectorize_nd_extract } : !transform.any_op
842
+ %0 = transform.structured.match ops {[" linalg.generic" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
843
+ %1 = transform.get_parent_op %0 {isolated_from_above } : (!transform.any_op ) -> !transform.any_op
844
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op ) -> !transform.any_op
860
845
transform.yield
861
846
}
862
847
}
0 commit comments