Skip to content

Commit 8e9a7ec

Browse files
author
Tanyo Kwok
committed
add e2e unittest
1 parent 6847438 commit 8e9a7ec

File tree

5 files changed

+68
-21
lines changed

5 files changed

+68
-21
lines changed

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,7 @@ class DecomposeAtenTOp : public OpRewritePattern<AtenTOp> {
709709
};
710710
} // namespace
711711

712-
// Decompose aten.roll into aten.expand and aten.slice and aten.cat ops.
712+
// Decompose aten.roll into aten.slice and aten.cat ops.
713713
// https://pytorch.org/docs/stable/generated/torch.roll.html
714714
namespace {
715715
class DecomposeAtenRollOp : public OpRewritePattern<AtenRollOp> {
@@ -736,28 +736,43 @@ class DecomposeAtenRollOp : public OpRewritePattern<AtenRollOp> {
736736
Value constOne = rewriter.create<Torch::ConstantIntOp>(
737737
loc, rewriter.getI64IntegerAttr(1));
738738
auto self = op.self();
739-
Type listType = Torch::ListType::get(self.getType());
739+
auto selfTy = self.getType().cast<BaseTensorType>();
740740
// roll(input, shift, dim) = cat({
741741
// slice(input, dim, -shift, none),
742742
// slice(input, dim, 0, -shift)}, dim)
743-
auto ImitateRoll = [&](Value input, Value shift, Value dim) {
743+
auto imitateRoll = [&](Value input, Value shift, Value dim,
744+
int64_t cstDim) {
744745
Value negShift = rewriter.create<AtenNegIntOp>(loc, shift);
745-
Type sliceType = computeReductionType(
746-
rewriter, op, self.getType().cast<BaseTensorType>(), dim,
747-
/*keepDim=*/true);
746+
ArrayRef<int64_t> inputShape = selfTy.getSizes();
747+
SmallVector<int64_t> sizes;
748+
sizes.append(inputShape.begin(), inputShape.end());
749+
sizes[cstDim] = ShapedType::kDynamicSize;
750+
Type sliceTy = selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
751+
selfTy.getDtype());
748752
Value slice0 = rewriter.create<AtenSliceTensorOp>(
749-
loc, sliceType, input, dim, negShift, constNone, constOne);
753+
loc, sliceTy, input, dim, negShift, constNone, constOne);
750754
Value slice1 = rewriter.create<AtenSliceTensorOp>(
751-
loc, sliceType, input, dim, constZero, negShift, constOne);
755+
loc, sliceTy, input, dim, constZero, negShift, constOne);
752756

757+
Type listType = Torch::ListType::get(sliceTy);
753758
Value slices = rewriter.create<PrimListConstructOp>(
754759
loc, listType, llvm::ArrayRef<Value>{slice0, slice1});
755760
return rewriter.create<AtenCatOp>(loc, self.getType(), slices, dim);
756761
};
757-
auto output = self;
762+
int rank = getTensorRank(self);
763+
if (rank < 0)
764+
return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");
765+
Value output = self;
758766
auto nShifts = shifts.size();
759767
for (size_t k = 0; k < nShifts; ++k) {
760-
output = ImitateRoll(output, shifts[k], dims[k]);
768+
auto dim = dims[k];
769+
int64_t cstDim = -1;
770+
if (!matchPattern(dim, m_TorchConstantInt(&cstDim)))
771+
return rewriter.notifyMatchFailure(
772+
op, "unimplemented: dim must be constant");
773+
774+
cstDim = toPositiveDim(cstDim, rank);
775+
output = imitateRoll(output, shifts[k], dim, cstDim);
761776
}
762777
rewriter.replaceOp(op, output);
763778
return success();

lib/Dialect/Torch/Transforms/ShapeLibrary.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4213,6 +4213,10 @@ module {
42134213
}
42144214
return %7 : !torch.list<int>
42154215
}
4216+
func.func @__torch_mlir_shape_fn.aten.roll(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.list<int> {
4217+
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
4218+
return %0 : !torch.list<int>
4219+
}
42164220
func.func @__torch__.torch.jit._shape_functions.expand(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
42174221
%int-1 = torch.constant.int -1
42184222
%true = torch.constant.bool true

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,9 @@ def aten〇repeat(self: List[int], repeats: List[int]) -> List[int]:
611611
out.append(self[i] * repeats[i + leading_rank])
612612
return out
613613

614+
def aten〇roll(self: List[int], shifts: List[int], dims: List[int] = ()) -> List[int]:
615+
return upstream_shape_functions.unary(self)
616+
614617
def aten〇expand(self: List[int], size: List[int], implicit: bool = False) -> List[int]:
615618
return upstream_shape_functions.expand(self, size)
616619

python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,6 +1047,27 @@ def BroadcastToModule_basic(module, tu: TestUtils):
10471047
# ==============================================================================
10481048

10491049

1050+
class RollModule(torch.nn.Module):
1051+
1052+
def __init__(self):
1053+
super().__init__()
1054+
1055+
@export
1056+
@annotate_args([
1057+
None,
1058+
([3, -1, 2], torch.float32, True),
1059+
])
1060+
def forward(self, x):
1061+
return x.roll([2, -1], [0, 2])
1062+
1063+
1064+
@register_test_case(module_factory=lambda: RollModule())
1065+
def RollModule_basic(module, tu: TestUtils):
1066+
module.forward(tu.rand(3, 1, 2))
1067+
1068+
# ==============================================================================
1069+
1070+
10501071
class RepeatModule(torch.nn.Module):
10511072

10521073
def __init__(self):

test/Dialect/Torch/decompose-complex-ops.mlir

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,26 +1335,30 @@ func.func @torch.aten.std.dim(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vten
13351335

13361336
// -----
13371337
// CHECK-LABEL: func.func @torch.aten.roll(
1338-
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int, %[[ARG2:.*]]: !torch.int, %[[ARG3:.*]]: !torch.int, %[[ARG4:.*]]: !torch.int) -> !torch.vtensor<[?,?],f32> {
1338+
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int, %[[ARG2:.*]]: !torch.int) -> !torch.vtensor<[?,?],f32> {
13391339
// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[ARG1]], %[[ARG2]] : (!torch.int, !torch.int) -> !torch.list<int>
1340-
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[ARG2]], %[[ARG3]] : (!torch.int, !torch.int) -> !torch.list<int>
1340+
// CHECK: %[[INT1:.*]] = torch.constant.int 1
1341+
// CHECK: %[[INT:.*]]-2 = torch.constant.int -2
1342+
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT]]-2 : (!torch.int, !torch.int) -> !torch.list<int>
13411343
// CHECK: %[[NONE:.*]] = torch.constant.none
13421344
// CHECK: %[[INT0:.*]] = torch.constant.int 0
1343-
// CHECK: %[[INT1:.*]] = torch.constant.int 1
1345+
// CHECK: %[[INT1_0:.*]] = torch.constant.int 1
13441346
// CHECK: %[[T2:.*]] = torch.aten.neg.int %[[ARG1]] : !torch.int -> !torch.int
1345-
// CHECK: %[[T3:.*]] = torch.aten.slice.Tensor %[[ARG0]], %[[ARG2]], %[[T2]], %[[NONE]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[?,?],f32>
1346-
// CHECK: %[[T4:.*]] = torch.aten.slice.Tensor %[[ARG0]], %[[ARG2]], %[[INT0]], %[[T2]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
1347+
// CHECK: %[[T3:.*]] = torch.aten.slice.Tensor %[[ARG0]], %[[INT1]], %[[T2]], %[[NONE]], %[[INT1]]_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[?,?],f32>
1348+
// CHECK: %[[T4:.*]] = torch.aten.slice.Tensor %[[ARG0]], %[[INT1]], %[[INT0]], %[[T2]], %[[INT1]]_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
13471349
// CHECK: %[[T5:.*]] = torch.prim.ListConstruct %[[T3]], %[[T4]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor<[?,?],f32>>
1348-
// CHECK: %[[T6:.*]] = torch.aten.cat %[[T5]], %[[ARG2]] : !torch.list<vtensor<[?,?],f32>>, !torch.int -> !torch.vtensor<[?,?],f32>
1350+
// CHECK: %[[T6:.*]] = torch.aten.cat %[[T5]], %[[INT1]] : !torch.list<vtensor<[?,?],f32>>, !torch.int -> !torch.vtensor<[?,?],f32>
13491351
// CHECK: %[[T7:.*]] = torch.aten.neg.int %[[ARG2]] : !torch.int -> !torch.int
1350-
// CHECK: %[[T8:.*]] = torch.aten.slice.Tensor %[[T6]], %[[ARG3]], %[[T7]], %[[NONE]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[?,?],f32>
1351-
// CHECK: %[[T9:.*]] = torch.aten.slice.Tensor %[[T6]], %[[ARG3]], %[[INT0]], %[[T7]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
1352+
// CHECK: %[[T8:.*]] = torch.aten.slice.Tensor %[[T6]], %[[INT]]-2, %[[T7]], %[[NONE]], %[[INT]]1_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[?,?],f32>
1353+
// CHECK: %[[T9:.*]] = torch.aten.slice.Tensor %[[T6]], %[[INT]]-2, %[[INT]]0, %[[T7]], %[[INT]]1_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
13521354
// CHECK: %[[T10:.*]] = torch.prim.ListConstruct %[[T8]], %[[T9]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor<[?,?],f32>>
1353-
// CHECK: %[[T11:.*]] = torch.aten.cat %[[T10]], %[[ARG3]] : !torch.list<vtensor<[?,?],f32>>, !torch.int -> !torch.vtensor<[?,?],f32>
1355+
// CHECK: %[[T11:.*]] = torch.aten.cat %[[T10]], %[[INT]]-2 : !torch.list<vtensor<[?,?],f32>>, !torch.int -> !torch.vtensor<[?,?],f32>
13541356
// CHECK: return %[[T11]] : !torch.vtensor<[?,?],f32>
1355-
func.func @torch.aten.roll(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int) -> !torch.vtensor<[?,?],f32> {
1357+
func.func @torch.aten.roll(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor<[?,?],f32> {
13561358
%0 = torch.prim.ListConstruct %arg1, %arg2: (!torch.int, !torch.int) -> !torch.list<int>
1357-
%1 = torch.prim.ListConstruct %arg2, %arg3: (!torch.int, !torch.int) -> !torch.list<int>
1359+
%int1 = torch.constant.int 1
1360+
%int-2 = torch.constant.int -2
1361+
%1 = torch.prim.ListConstruct %int1, %int-2: (!torch.int, !torch.int) -> !torch.list<int>
13581362
%2 = torch.aten.roll %arg0, %0, %1 : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
13591363
return %2 : !torch.vtensor<[?,?],f32>
13601364
}

0 commit comments

Comments (0)