diff --git a/testdata/dnn/onnx/data/input_split_0.npy b/testdata/dnn/onnx/data/input_split_0.npy new file mode 100644 index 000000000..5576f4a92 Binary files /dev/null and b/testdata/dnn/onnx/data/input_split_0.npy differ diff --git a/testdata/dnn/onnx/data/output_split_0.npy b/testdata/dnn/onnx/data/output_split_0.npy new file mode 100644 index 000000000..e014b04ca Binary files /dev/null and b/testdata/dnn/onnx/data/output_split_0.npy differ diff --git a/testdata/dnn/onnx/generate_onnx_models.py b/testdata/dnn/onnx/generate_onnx_models.py index bec86026f..10704ee46 100644 --- a/testdata/dnn/onnx/generate_onnx_models.py +++ b/testdata/dnn/onnx/generate_onnx_models.py @@ -868,6 +868,15 @@ def forward(self, x): tup = torch.split(x, self.split_size_sections, self.dim) return torch.cat(tup) +class SimpleSplit(nn.Module): + def forward(self, image): + return torch.cat([img for img in image]) + + +model = SimpleSplit() +input = torch.ones((1, 3, 2, 2)) +save_data_and_model("split_0", input, model, version=11) + model = Split() input = Variable(torch.tensor([1., 2.], dtype=torch.float32)) save_data_and_model("split_1", input, model) @@ -888,6 +897,8 @@ def forward(self, x): model = Split(dim=-1, split_size_sections=[1, 2]) save_data_and_model("split_6", input2, model, version=13) + + class SplitSizes(nn.Module): def __init__(self, *args, **kwargs): super(SplitSizes, self).__init__() diff --git a/testdata/dnn/onnx/models/split_0.onnx b/testdata/dnn/onnx/models/split_0.onnx new file mode 100644 index 000000000..1be2c819c Binary files /dev/null and b/testdata/dnn/onnx/models/split_0.onnx differ