@@ -1115,8 +1115,8 @@ func @fold_div_mixed() -> !shape.size {
11151115// CHECK-LABEL: @fold_index_cast_on_index
11161116func @fold_index_cast_on_index (%arg: index ) -> index {
11171117 // CHECK-NOT: size_to_index
1118- %casted = shape.size_to_index %arg : index
1119- return %casted : index
1118+ %0 = shape.size_to_index %arg : index
1119+ return %0 : index
11201120}
11211121
11221122// -----
@@ -1125,8 +1125,8 @@ func @fold_index_cast_on_index(%arg: index) -> index {
11251125// CHECK-LABEL: @fold_to_extent_tensor_on_tensor
11261126func @fold_to_extent_tensor_on_tensor (%arg: tensor <?xindex >) -> tensor <?xindex > {
11271127 // CHECK-NOT: to_extent_tensor
1128- %casted = shape.to_extent_tensor %arg : tensor <?xindex > -> tensor <?xindex >
1129- return %casted : tensor <?xindex >
1128+ %0 = shape.to_extent_tensor %arg : tensor <?xindex > -> tensor <?xindex >
1129+ return %0 : tensor <?xindex >
11301130}
11311131
11321132// -----
@@ -1264,9 +1264,9 @@ func @broadcast_as_from_extent_tensor(%a : tensor<?xindex>) -> !shape.shape {
12641264
12651265// -----
12661266
1267- // CHECK-LABEL: @casted_extent_tensor
1267+ // CHECK-LABEL: @cast_extent_tensor
12681268// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<?xindex>
1269- func @casted_extent_tensor (%arg : tensor <?x?x?xf32 >) -> tensor <?xindex > {
1269+ func @cast_extent_tensor (%arg : tensor <?x?x?xf32 >) -> tensor <?xindex > {
12701270 // CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<?xindex>
12711271 // CHECK: return %[[RESULT]] : tensor<?xindex>
12721272 %0 = shape.shape_of %arg : tensor <?x?x?xf32 > -> tensor <3 xindex >
@@ -1276,9 +1276,9 @@ func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
12761276
12771277// -----
12781278
1279- // CHECK-LABEL: @casted_extent_tensor
1279+ // CHECK-LABEL: @cast_extent_tensor
12801280// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<3xindex>
1281- func @casted_extent_tensor (%arg : tensor <?x?x?xf32 >) -> tensor <3 xindex > {
1281+ func @cast_extent_tensor (%arg : tensor <?x?x?xf32 >) -> tensor <3 xindex > {
12821282 // CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<3xindex>
12831283 // CHECK: return %[[RESULT]] : tensor<3xindex>
12841284 %0 = shape.shape_of %arg : tensor <?x?x?xf32 > -> tensor <?xindex >
@@ -1288,8 +1288,8 @@ func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<3xindex> {
12881288
12891289// -----
12901290
1291- // CHECK-LABEL: @casted_extent_tensor
1292- func @casted_extent_tensor (%arg : tensor <?x?x?x?xf32 >) -> tensor <3 xindex > {
1291+ // CHECK-LABEL: @cast_extent_tensor
1292+ func @cast_extent_tensor (%arg : tensor <?x?x?x?xf32 >) -> tensor <3 xindex > {
12931293 // CHECK: tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
12941294 %0 = shape.shape_of %arg : tensor <?x?x?x?xf32 > -> tensor <?xindex >
12951295 %1 = tensor.cast %0 : tensor <?xindex > to tensor <3 xindex >
@@ -1298,8 +1298,8 @@ func @casted_extent_tensor(%arg : tensor<?x?x?x?xf32>) -> tensor<3xindex> {
12981298
12991299// -----
13001300
1301- // CHECK-LABEL: @casted_extent_tensor
1302- func @casted_extent_tensor (%arg : tensor <*xf32 >) -> tensor <3 xindex > {
1301+ // CHECK-LABEL: @cast_extent_tensor
1302+ func @cast_extent_tensor (%arg : tensor <*xf32 >) -> tensor <3 xindex > {
13031303 // CHECK: tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
13041304 %0 = shape.shape_of %arg : tensor <*xf32 > -> tensor <?xindex >
13051305 %1 = tensor.cast %0 : tensor <?xindex > to tensor <3 xindex >
@@ -1335,3 +1335,21 @@ func @cstr_broadcastable_folding(%arg : tensor<?x4xf32>) {
13351335 %2 = shape.cstr_broadcastable %0 , %1: tensor <2 xindex >, tensor <1 xindex >
13361336 " use" (%2 ) : (!shape.witness ) -> ()
13371337}
1338+
1339+ // -----
1340+
1341+ // CHECK-LABEL: @cast_extent_tensor_operands
1342+ // CHECK-SAME: (%[[ARG0:.*]]: tensor<?xindex>, %[[ARG1:.*]]: tensor<3xindex>)
1343+ func @cast_extent_tensor_operands (%arg0 : tensor <?xindex >,
1344+ %arg1 : tensor <3 xindex >) -> (!shape.witness , tensor <?xindex >) {
1345+ // CHECK: %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?xindex> to tensor<3xindex>
1346+ // CHECK: %[[WIT:.*]] = shape.cstr_broadcastable %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex>
1347+ // CHECK: %[[RES:.*]] = shape.broadcast %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex>
1348+ // CHECK: return %[[WIT]], %[[RES]]
1349+ %0 = tensor.cast %arg0 : tensor <?xindex > to tensor <3 xindex >
1350+ %1 = tensor.cast %arg1 : tensor <3 xindex > to tensor <?xindex >
1351+ %2 = shape.cstr_broadcastable %0 , %1 : tensor <3 xindex >, tensor <?xindex >
1352+ %3 = shape.broadcast %0 , %1 :tensor <3 xindex >, tensor <?xindex >
1353+ -> tensor <?xindex >
1354+ return %2 , %3 : !shape.witness , tensor <?xindex >
1355+ }
0 commit comments