@@ -434,16 +434,49 @@ def forward(self, x):
434434save_data_and_model ("slice" , input , model )
435435save_data_and_model ("slice_opset_11" , input , model , version = 11 )
436436
437- class SliceStarts ( nn . Module ):
438- def __init__ ( self , * args , ** kwargs ):
439- super ( SliceStarts , self ). __init__ ()
437+ def generate_slice_neg_starts ( ):
438+ x = np . random . randn ( 2 , 3 , 4 , 3 ). astype ( np . float32 )
439+ y = x [ - 1 : 2 , - 3 : - 1 , 2 : 3 , 1 : - 1 ]
440440
441- def forward (self , x ):
442- return x [- 1 :]
441+ starts = np .array ([- 1 , - 3 , 2 , 1 ], dtype = np .int64 )
442+ starts = onnx .numpy_helper .from_array (starts , name = 'starts' )
443+ ends = np .array ([ 2 , - 1 , 3 , - 1 ], dtype = np .int64 )
444+ ends = onnx .numpy_helper .from_array (ends , name = 'ends' )
443445
444- model = SliceStarts ()
445- input_ = Variable (torch .randn (1 , 10 , dtype = torch .float32 ))
446- save_data_and_model ("slice_neg_starts" , input_ , model )
446+ node = onnx .helper .make_node (
447+ 'Slice' ,
448+ inputs = ['X' , 'starts' , 'ends' ],
449+ outputs = ['Y' ],
450+ )
451+
452+ X = onnx .helper .make_tensor_value_info ('X' , onnx .TensorProto .FLOAT , list (x .shape ))
453+ Y = onnx .helper .make_tensor_value_info ('Y' , onnx .TensorProto .FLOAT , list (y .shape ))
454+
455+ graph = onnx .helper .make_graph (
456+ [node ], # nodes
457+ 'slice_neg_starts' , # name
458+ [X ], # inputs
459+ [Y ], # outputs
460+ )
461+
462+ graph .initializer .append (starts )
463+ graph .initializer .append (ends )
464+
465+ model = onnx .helper .make_model (graph , producer_name = 'onnx' )
466+ onnx .checker .check_model (model )
467+
468+ name = 'slice_neg_starts'
469+
470+ input_files = os .path .join ("data" , "input_" + name )
471+ np .save (input_files , x .data )
472+
473+ output_files = os .path .join ("data" , "output_" + name )
474+ np .save (output_files , np .ascontiguousarray (y .data ))
475+
476+ models_files = os .path .join ("models" , name + ".onnx" )
477+ onnx .save (model , models_files )
478+
479+ generate_slice_neg_starts ()
447480
448481input_2 = Variable (torch .randn (6 , 6 ))
449482custom_slice_list = [
0 commit comments