@@ -1335,26 +1335,30 @@ func.func @torch.aten.std.dim(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vten
1335
1335
1336
1336
// -----
1337
1337
// CHECK-LABEL: func.func @torch.aten.roll(
1338
- // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int, %[[ARG2:.*]]: !torch.int, %[[ARG3:.*]]: !torch.int, %[[ARG4:.*]]: !torch.int ) -> !torch.vtensor<[?,?],f32> {
1338
+ // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int, %[[ARG2:.*]]: !torch.int) -> !torch.vtensor<[?,?],f32> {
1339
1339
// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[ARG1]], %[[ARG2]] : (!torch.int, !torch.int) -> !torch.list<int>
1340
- // CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[ARG2]], %[[ARG3]] : (!torch.int, !torch.int) -> !torch.list<int>
1340
+ // CHECK: %[[INT1:.*]] = torch.constant.int 1
1341
+ // CHECK: %[[INT:.*]]-2 = torch.constant.int -2
1342
+ // CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT]]-2 : (!torch.int, !torch.int) -> !torch.list<int>
1341
1343
// CHECK: %[[NONE:.*]] = torch.constant.none
1342
1344
// CHECK: %[[INT0:.*]] = torch.constant.int 0
1343
- // CHECK: %[[INT1 :.*]] = torch.constant.int 1
1345
+ // CHECK: %[[INT1_0 :.*]] = torch.constant.int 1
1344
1346
// CHECK: %[[T2:.*]] = torch.aten.neg.int %[[ARG1]] : !torch.int -> !torch.int
1345
- // CHECK: %[[T3:.*]] = torch.aten.slice.Tensor %[[ARG0]], %[[ARG2 ]], %[[T2]], %[[NONE]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[?,?],f32>
1346
- // CHECK: %[[T4:.*]] = torch.aten.slice.Tensor %[[ARG0]], %[[ARG2 ]], %[[INT0]], %[[T2]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
1347
+ // CHECK: %[[T3:.*]] = torch.aten.slice.Tensor %[[ARG0]], %[[INT1 ]], %[[T2]], %[[NONE]], %[[INT1]]_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[?,?],f32>
1348
+ // CHECK: %[[T4:.*]] = torch.aten.slice.Tensor %[[ARG0]], %[[INT1 ]], %[[INT0]], %[[T2]], %[[INT1]]_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
1347
1349
// CHECK: %[[T5:.*]] = torch.prim.ListConstruct %[[T3]], %[[T4]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor<[?,?],f32>>
1348
- // CHECK: %[[T6:.*]] = torch.aten.cat %[[T5]], %[[ARG2 ]] : !torch.list<vtensor<[?,?],f32>>, !torch.int -> !torch.vtensor<[?,?],f32>
1350
+ // CHECK: %[[T6:.*]] = torch.aten.cat %[[T5]], %[[INT1 ]] : !torch.list<vtensor<[?,?],f32>>, !torch.int -> !torch.vtensor<[?,?],f32>
1349
1351
// CHECK: %[[T7:.*]] = torch.aten.neg.int %[[ARG2]] : !torch.int -> !torch.int
1350
- // CHECK: %[[T8:.*]] = torch.aten.slice.Tensor %[[T6]], %[[ARG3]] , %[[T7]], %[[NONE]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[?,?],f32>
1351
- // CHECK: %[[T9:.*]] = torch.aten.slice.Tensor %[[T6]], %[[ARG3]] , %[[INT0]] , %[[T7]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
1352
+ // CHECK: %[[T8:.*]] = torch.aten.slice.Tensor %[[T6]], %[[INT]]-2 , %[[T7]], %[[NONE]], %[[INT]]1_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[?,?],f32>
1353
+ // CHECK: %[[T9:.*]] = torch.aten.slice.Tensor %[[T6]], %[[INT]]-2 , %[[INT]]0 , %[[T7]], %[[INT]]1_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
1352
1354
// CHECK: %[[T10:.*]] = torch.prim.ListConstruct %[[T8]], %[[T9]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor<[?,?],f32>>
1353
- // CHECK: %[[T11:.*]] = torch.aten.cat %[[T10]], %[[ARG3]] : !torch.list<vtensor<[?,?],f32>>, !torch.int -> !torch.vtensor<[?,?],f32>
1355
+ // CHECK: %[[T11:.*]] = torch.aten.cat %[[T10]], %[[INT]]-2 : !torch.list<vtensor<[?,?],f32>>, !torch.int -> !torch.vtensor<[?,?],f32>
1354
1356
// CHECK: return %[[T11]] : !torch.vtensor<[?,?],f32>
1355
- func.func @torch.aten.roll (%arg0: !torch.vtensor <[?,?],f32 >, %arg1: !torch.int , %arg2: !torch.int , %arg3: !torch.int , %arg4: !torch.int ) -> !torch.vtensor <[?,?],f32 > {
1357
+ func.func @torch.aten.roll (%arg0: !torch.vtensor <[?,?],f32 >, %arg1: !torch.int , %arg2: !torch.int ) -> !torch.vtensor <[?,?],f32 > {
1356
1358
%0 = torch.prim.ListConstruct %arg1 , %arg2: (!torch.int , !torch.int ) -> !torch.list <int >
1357
- %1 = torch.prim.ListConstruct %arg2 , %arg3: (!torch.int , !torch.int ) -> !torch.list <int >
1359
+ %int1 = torch.constant.int 1
1360
+ %int -2 = torch.constant.int -2
1361
+ %1 = torch.prim.ListConstruct %int1 , %int -2 : (!torch.int , !torch.int ) -> !torch.list <int >
1358
1362
%2 = torch.aten.roll %arg0 , %0 , %1 : !torch.vtensor <[?,?],f32 >, !torch.list <int >, !torch.list <int > -> !torch.vtensor <[?,?],f32 >
1359
1363
return %2 : !torch.vtensor <[?,?],f32 >
1360
1364
}
0 commit comments