@@ -129,6 +129,26 @@ func.func @static_mixed_data_low_high_pad(%arg0 : tensor<4x5xf32>, %pad : f32)
129129
130130// -----
131131
132+ // CHECK-LABEL: @static_rank_reduce
133+ // CHECK-SAME: %[[ARG0:.*]]: tensor<8x16x4xf32>, %[[PADVAL:.*]]: f32
134+ // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 14, 4] [1, 1, 1] : tensor<8x16x4xf32> to tensor<1x14x4xf32>
135+ // CHECK: %[[PADDED:.*]] = tensor.pad %[[SLICE]] low[0, 2, 0] high[0, 0, 0] {
136+ // CHECK: } : tensor<1x14x4xf32> to tensor<1x16x4xf32>
137+ // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[PADDED]][0, 0, 0] [1, 16, 4] [1, 1, 1] : tensor<1x16x4xf32> to tensor<16x4xf32>
138+ // CHECK: return %[[RESULT]]
139+ func.func @static_rank_reduce (%arg0: tensor <8 x16 x4 xf32 >, %pad: f32 )
140+ -> tensor <16 x4 xf32 > {
141+ %0 = tensor.pad %arg0 low [0 , 2 , 0 ] high [0 , 0 , 0 ] {
142+ ^bb0 (%i: index , %j: index , %k: index ):
143+ tensor.yield %pad : f32
144+ } : tensor <8 x16 x4 xf32 > to tensor <8 x18 x4 xf32 >
145+ %1 = tensor.extract_slice %0 [0 , 0 , 0 ] [1 , 16 , 4 ] [1 , 1 , 1 ]
146+ : tensor <8 x18 x4 xf32 > to tensor <16 x4 xf32 >
147+ return %1 : tensor <16 x4 xf32 >
148+ }
149+
150+ // -----
151+
132152// CHECK-LABEL: @dynamic_high_pad
133153// CHECK-SAME: %[[ARG0:.*]]: tensor<?x5xf32>
134154// CHECK-NOT: tensor.pad
@@ -217,6 +237,27 @@ func.func @dynamic_zero_high_padding(%arg0 : tensor<?x?xf32>, %pad : f32,
217237 return %1 : tensor <?x?xf32 >
218238}
219239
240+ // -----
241+
242+ // CHECK-LABEL: @dynamic_rank_reduce
243+ // CHECK: %[[TEMP:.*]] = scf.if %{{.*}} -> (tensor<1x4xf32>) {
244+ // CHECK: tensor.generate
245+ // CHECK: } else {
246+ // CHECK: %[[SLICE:.*]] = tensor.extract_slice %{{.*}} : tensor<?x5xf32> to tensor<?x1xf32>
247+ // CHECK: tensor.pad %[[SLICE]] low[0, 0] high[%{{.*}}, 3] {
248+ // CHECK: } : tensor<?x1xf32> to tensor<1x4xf32>
249+ // CHECK: }
250+ // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[TEMP]]{{.*}} : tensor<1x4xf32> to tensor<4xf32>
251+ // CHECK: return %[[RESULT]]
252+ func.func @dynamic_rank_reduce (%arg0 : tensor <?x5 xf32 >, %s1: index , %pad : f32 ) -> tensor <4 xf32 > {
253+ %0 = tensor.pad %arg0 low [0 , 0 ] high [7 , 8 ] {
254+ ^bb0 (%arg1: index , %arg2: index ):
255+ tensor.yield %pad : f32
256+ } : tensor <?x5 xf32 > to tensor <?x13 xf32 >
257+ %1 = tensor.extract_slice %0 [2 , 4 ] [1 , 4 ] [1 , 1 ] : tensor <?x13 xf32 > to tensor <4 xf32 >
258+ return %1 : tensor <4 xf32 >
259+ }
260+
220261// -----
221262// CHECK-LABEL: @nopaddim_with_dynamic_extract(
222263// CHECK-SAME: %[[ARG0:.*]]: tensor<3x4x5xf32>
0 commit comments