From 9a9e2cb48bd075b546029892357bfa423a43918e Mon Sep 17 00:00:00 2001 From: Smirnov Egor Date: Mon, 16 Aug 2021 15:45:30 +0300 Subject: [PATCH] add Split partial sum tests --- testdata/dnn/onnx/data/input_split_sizes.npy | Bin 0 -> 208 bytes testdata/dnn/onnx/data/output_split_sizes.npy | Bin 0 -> 208 bytes testdata/dnn/onnx/generate_onnx_models.py | 17 +++++++++++++++++ testdata/dnn/onnx/models/split_sizes.onnx | Bin 0 -> 495 bytes 4 files changed, 17 insertions(+) create mode 100644 testdata/dnn/onnx/data/input_split_sizes.npy create mode 100644 testdata/dnn/onnx/data/output_split_sizes.npy create mode 100644 testdata/dnn/onnx/models/split_sizes.onnx diff --git a/testdata/dnn/onnx/data/input_split_sizes.npy b/testdata/dnn/onnx/data/input_split_sizes.npy new file mode 100644 index 0000000000000000000000000000000000000000..3cffa7442d5a489ab457d574043c1293e849d0bd GIT binary patch literal 208 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$720EHL3bhL411<(IXs`!T4nPFN4M4mAhz|hq10ZH_WMB{gVg(>J0AdFq W4glf=AT9vn1|aSL;t4=J!w~?Zh$o}~ literal 0 HcmV?d00001 diff --git a/testdata/dnn/onnx/data/output_split_sizes.npy b/testdata/dnn/onnx/data/output_split_sizes.npy new file mode 100644 index 0000000000000000000000000000000000000000..7463c44c5960078907bae4212bde0ae917bb2e9d GIT binary patch literal 208 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$720EHL3bhL411<&#aA06K0K@{03=9rHd;*9+05Ojf1A_t(n*i|%Aie{{ YAAp$287K$DVnD0_#9BaX0>pOC0J6p=>Hq)$ literal 0 HcmV?d00001 diff --git a/testdata/dnn/onnx/generate_onnx_models.py b/testdata/dnn/onnx/generate_onnx_models.py index 6187bbe54..35261ff2c 100644 --- a/testdata/dnn/onnx/generate_onnx_models.py +++ b/testdata/dnn/onnx/generate_onnx_models.py @@ -555,6 +555,23 @@ def forward(self, x): model = Split(dim=0, split_size_sections=[1, 1]) save_data_and_model("split_4", input, model) +class SplitSizes(nn.Module): + def __init__(self, *args, **kwargs): + super(SplitSizes, self).__init__() + + def forward(self, x): + a, b, c, d = torch.split(x, [2, 3, 5, 10], 0) + a = torch.mul(a, 2) + b = torch.mul(b, 3) + c = torch.mul(c, 5) + d = torch.mul(d, 10) + tup = (a, b, c, d) + return torch.cat(tup) + +model = SplitSizes() +input_ = Variable(torch.tensor(list(range(20)), dtype=torch.float32)) +save_data_and_model("split_sizes", input_, model) + class SplitMax(nn.Module): def __init__(self): diff --git a/testdata/dnn/onnx/models/split_sizes.onnx b/testdata/dnn/onnx/models/split_sizes.onnx new file mode 100644 index 0000000000000000000000000000000000000000..58f86769309d3f757b04bbbc007ae57ea58e281f GIT binary patch literal 495 zcmaLUOKXEb5CCAuN3_$64Cw<;#iNkcM>MggZA&kOo_gy=Y!GS@72TBd-}tlobGj-C z_TpjLVTSo&W{d^t^$)L_-KV8^xAFb2g3q|%S=m+%fk)sI1O#hKf2)g}C$6mxfpd;( zddS;rsRP*|)Y`txWTr5UVFx