Skip to content

Commit b186c70

Browse files
author
Anastasia Murzova
committed
Added Steps support in DNN Slice layer
1 parent 995c6d3 commit b186c70

13 files changed

+46
-1
lines changed
272 Bytes
Binary file not shown.
560 Bytes
Binary file not shown.
560 Bytes
Binary file not shown.
992 Bytes
Binary file not shown.
144 Bytes
Binary file not shown.
200 Bytes
Binary file not shown.
344 Bytes
Binary file not shown.
272 Bytes
Binary file not shown.

testdata/dnn/onnx/generate_onnx_models.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,17 +378,62 @@ def forward(self, x):
378378

379379
class Slice(nn.Module):
380380

381-
def __init__(self):
381+
def __init__(self, custom_sliice=None):
382+
self.custom_sliice=custom_sliice
382383
super(Slice, self).__init__()
383384

384385
def forward(self, x):
386+
if self.custom_sliice:
387+
return x[self.custom_sliice]
388+
385389
return x[..., 1:-1, 0:3]
386390

387391
input = Variable(torch.randn(1, 2, 4, 4))
388392
model = Slice()
389393
save_data_and_model("slice", input, model)
390394
save_data_and_model("slice_opset_11", input, model, version=11)
391395

396+
input_2 = Variable(torch.randn(6, 6))
397+
custom_slice_list = [
398+
slice(1, 3, 1),
399+
slice(0, 3, 2)
400+
]
401+
model_2 = Slice(custom_sliice=custom_slice_list)
402+
save_data_and_model("slice_opset_11_steps_2d", input_2, model_2, version=11)
403+
postprocess_model("models/slice_opset_11_steps_2d.onnx", [['height', 'width']])
404+
405+
input_3 = Variable(torch.randn(3, 6, 6))
406+
custom_slice_list_3 = [
407+
slice(None, None, 2),
408+
slice(None, None, 2),
409+
slice(None, None, 2)
410+
]
411+
model_3 = Slice(custom_sliice=custom_slice_list_3)
412+
save_data_and_model("slice_opset_11_steps_3d", input_3, model_3, version=11)
413+
postprocess_model("models/slice_opset_11_steps_3d.onnx", [[3, 'height', 'width']])
414+
415+
input_4 = Variable(torch.randn(1, 3, 6, 6))
416+
custom_slice_list_4 = [
417+
slice(0, 5, None),
418+
slice(None, None, None),
419+
slice(1, None, 2),
420+
slice(None, None, None)
421+
]
422+
model_4 = Slice(custom_sliice=custom_slice_list_4)
423+
save_data_and_model("slice_opset_11_steps_4d", input_4, model_4, version=11)
424+
postprocess_model("models/slice_opset_11_steps_4d.onnx", [["batch_size", 3, 'height', 'width']])
425+
426+
input_5 = Variable(torch.randn(1, 2, 3, 6, 6))
427+
custom_slice_list_5 = [
428+
slice(None, None, None),
429+
slice(None, None, None),
430+
slice(0, None, 3),
431+
slice(None, None, None),
432+
slice(None, None, 2)
433+
]
434+
model_5 = Slice(custom_sliice=custom_slice_list_5)
435+
save_data_and_model("slice_opset_11_steps_5d", input_5, model_5, version=11)
436+
392437
class Eltwise(nn.Module):
393438

394439
def __init__(self):
608 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)