@@ -524,6 +524,37 @@ def generate_slice_neg_starts():
524524
525525generate_slice_neg_starts ()
526526
527+ def postprocess_model (model_path , inputs_shapes ):
528+ onnx_model = onnx .load (model_path )
529+
530+ def update_inputs_dims (model , input_dims ):
531+ """
532+ This function updates the sizes of dimensions of the model's inputs to the values
533+ provided in input_dims. if the dim value provided is negative, a unique dim_param
534+ will be set for that dimension.
535+ """
536+ def update_dim (tensor , dim , i , j , dim_param_prefix ):
537+ dim_proto = tensor .type .tensor_type .shape .dim [j ]
538+ if isinstance (dim , int ):
539+ if dim >= 0 :
540+ dim_proto .dim_value = dim
541+ else :
542+ dim_proto .dim_param = dim_param_prefix + str (i ) + '_' + str (j )
543+ elif isinstance (dim , str ):
544+ dim_proto .dim_param = dim
545+ else :
546+ raise ValueError ('Only int or str is accepted as dimension value, incorrect type: {}' .format (type (dim )))
547+
548+ for i , input_dim_arr in enumerate (input_dims ):
549+ for j , dim in enumerate (input_dim_arr ):
550+ update_dim (model .graph .input [i ], dim , i , j , 'in_' )
551+
552+ onnx .checker .check_model (model )
553+ return model
554+
555+ onnx_model = update_inputs_dims (onnx_model , inputs_shapes )
556+ onnx .save (onnx_model , model_path )
557+
527558input_2 = Variable (torch .randn (6 , 6 ))
528559custom_slice_list = [
529560 slice (1 , 3 , 1 ),
@@ -1916,36 +1947,6 @@ def forward(self, x):
19161947model = GatherMultiOutput ()
19171948save_data_and_model ("gather_multi_output" , x , model )
19181949
1919- def postprocess_model (model_path , inputs_shapes ):
1920- onnx_model = onnx .load (model_path )
1921-
1922- def update_inputs_dims (model , input_dims ):
1923- """
1924- This function updates the sizes of dimensions of the model's inputs to the values
1925- provided in input_dims. if the dim value provided is negative, a unique dim_param
1926- will be set for that dimension.
1927- """
1928- def update_dim (tensor , dim , i , j , dim_param_prefix ):
1929- dim_proto = tensor .type .tensor_type .shape .dim [j ]
1930- if isinstance (dim , int ):
1931- if dim >= 0 :
1932- dim_proto .dim_value = dim
1933- else :
1934- dim_proto .dim_param = dim_param_prefix + str (i ) + '_' + str (j )
1935- elif isinstance (dim , str ):
1936- dim_proto .dim_param = dim
1937- else :
1938- raise ValueError ('Only int or str is accepted as dimension value, incorrect type: {}' .format (type (dim )))
1939-
1940- for i , input_dim_arr in enumerate (input_dims ):
1941- for j , dim in enumerate (input_dim_arr ):
1942- update_dim (model .graph .input [i ], dim , i , j , 'in_' )
1943-
1944- onnx .checker .check_model (model )
1945- return model
1946-
1947- onnx_model = update_inputs_dims (onnx_model , inputs_shapes )
1948- onnx .save (onnx_model , model_path )
19491950
19501951class UnsqueezeAndConv (nn .Module ):
19511952 def __init__ (self ):
0 commit comments