Skip to content

Commit ab29a16

Browse files
committed
Merge remote-tracking branch 'upstream/3.4' into merge-3.4
2 parents 2d8d5f9 + 5d47c1d commit ab29a16

File tree

4 files changed

+17
-0
lines changed

4 files changed

+17
-0
lines changed
208 Bytes
Binary file not shown.
208 Bytes
Binary file not shown.

testdata/dnn/onnx/generate_onnx_models.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,23 @@ def forward(self, x):
557557
model = Split(dim=0, split_size_sections=[1, 1])
558558
save_data_and_model("split_4", input, model)
559559

560+
class SplitSizes(nn.Module):
561+
def __init__(self, *args, **kwargs):
562+
super(SplitSizes, self).__init__()
563+
564+
def forward(self, x):
565+
a, b, c, d = torch.split(x, [2, 3, 5, 10], 0)
566+
a = torch.mul(a, 2)
567+
b = torch.mul(b, 3)
568+
c = torch.mul(c, 5)
569+
d = torch.mul(d, 10)
570+
tup = (a, b, c, d)
571+
return torch.cat(tup)
572+
573+
model = SplitSizes()
574+
input_ = Variable(torch.tensor(list(range(20)), dtype=torch.float32))
575+
save_data_and_model("split_sizes", input_, model)
576+
560577
class SplitMax(nn.Module):
561578

562579
def __init__(self):
495 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)