@@ -378,17 +378,62 @@ def forward(self, x):
378378
379379class Slice (nn .Module ):
380380
381- def __init__ (self ):
381+ def __init__ (self , custom_sliice = None ):
382+ self .custom_sliice = custom_sliice
382383 super (Slice , self ).__init__ ()
383384
384385 def forward (self , x ):
386+ if self .custom_sliice :
387+ return x [self .custom_sliice ]
388+
385389 return x [..., 1 :- 1 , 0 :3 ]
386390
387391input = Variable (torch .randn (1 , 2 , 4 , 4 ))
388392model = Slice ()
389393save_data_and_model ("slice" , input , model )
390394save_data_and_model ("slice_opset_11" , input , model , version = 11 )
391395
396+ input_2 = Variable (torch .randn (6 , 6 ))
397+ custom_slice_list = [
398+ slice (1 , 3 , 1 ),
399+ slice (0 , 3 , 2 )
400+ ]
401+ model_2 = Slice (custom_sliice = custom_slice_list )
402+ save_data_and_model ("slice_opset_11_steps_2d" , input_2 , model_2 , version = 11 )
403+ postprocess_model ("models/slice_opset_11_steps_2d.onnx" , [['height' , 'width' ]])
404+
405+ input_3 = Variable (torch .randn (3 , 6 , 6 ))
406+ custom_slice_list_3 = [
407+ slice (None , None , 2 ),
408+ slice (None , None , 2 ),
409+ slice (None , None , 2 )
410+ ]
411+ model_3 = Slice (custom_sliice = custom_slice_list_3 )
412+ save_data_and_model ("slice_opset_11_steps_3d" , input_3 , model_3 , version = 11 )
413+ postprocess_model ("models/slice_opset_11_steps_3d.onnx" , [[3 , 'height' , 'width' ]])
414+
415+ input_4 = Variable (torch .randn (1 , 3 , 6 , 6 ))
416+ custom_slice_list_4 = [
417+ slice (0 , 5 , None ),
418+ slice (None , None , None ),
419+ slice (1 , None , 2 ),
420+ slice (None , None , None )
421+ ]
422+ model_4 = Slice (custom_sliice = custom_slice_list_4 )
423+ save_data_and_model ("slice_opset_11_steps_4d" , input_4 , model_4 , version = 11 )
424+ postprocess_model ("models/slice_opset_11_steps_4d.onnx" , [["batch_size" , 3 , 'height' , 'width' ]])
425+
426+ input_5 = Variable (torch .randn (1 , 2 , 3 , 6 , 6 ))
427+ custom_slice_list_5 = [
428+ slice (None , None , None ),
429+ slice (None , None , None ),
430+ slice (0 , None , 3 ),
431+ slice (None , None , None ),
432+ slice (None , None , 2 )
433+ ]
434+ model_5 = Slice (custom_sliice = custom_slice_list_5 )
435+ save_data_and_model ("slice_opset_11_steps_5d" , input_5 , model_5 , version = 11 )
436+
392437class Eltwise (nn .Module ):
393438
394439 def __init__ (self ):
0 commit comments