diff --git a/testdata/dnn/onnx/data/input_split_sizes.npy b/testdata/dnn/onnx/data/input_split_sizes.npy new file mode 100644 index 000000000..3cffa7442 Binary files /dev/null and b/testdata/dnn/onnx/data/input_split_sizes.npy differ diff --git a/testdata/dnn/onnx/data/output_split_sizes.npy b/testdata/dnn/onnx/data/output_split_sizes.npy new file mode 100644 index 000000000..7463c44c5 Binary files /dev/null and b/testdata/dnn/onnx/data/output_split_sizes.npy differ diff --git a/testdata/dnn/onnx/generate_onnx_models.py b/testdata/dnn/onnx/generate_onnx_models.py index 6187bbe54..35261ff2c 100644 --- a/testdata/dnn/onnx/generate_onnx_models.py +++ b/testdata/dnn/onnx/generate_onnx_models.py @@ -555,6 +555,23 @@ def forward(self, x): model = Split(dim=0, split_size_sections=[1, 1]) save_data_and_model("split_4", input, model) +class SplitSizes(nn.Module): + def __init__(self, *args, **kwargs): + super(SplitSizes, self).__init__() + + def forward(self, x): + a, b, c, d = torch.split(x, [2, 3, 5, 10], 0) + a = torch.mul(a, 2) + b = torch.mul(b, 3) + c = torch.mul(c, 5) + d = torch.mul(d, 10) + tup = (a, b, c, d) + return torch.cat(tup) + +model = SplitSizes() +input_ = Variable(torch.tensor(list(range(20)), dtype=torch.float32)) +save_data_and_model("split_sizes", input_, model) + class SplitMax(nn.Module): def __init__(self): diff --git a/testdata/dnn/onnx/models/split_sizes.onnx b/testdata/dnn/onnx/models/split_sizes.onnx new file mode 100644 index 000000000..58f867693 Binary files /dev/null and b/testdata/dnn/onnx/models/split_sizes.onnx differ