Skip to content

Commit 890703d

Browse files
authored
Merge pull request #1059 from appsinfinity868:issue_23508_opencv_repo_bug_fix
postprocess_model() shifted above to fix the issue 23508 in opencv/opencv repository
2 parents 8c7c046 + 6e7d3c4 commit 890703d

File tree

1 file changed

+31
-30
lines changed

1 file changed

+31
-30
lines changed

testdata/dnn/onnx/generate_onnx_models.py

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,37 @@ def generate_slice_neg_starts():
524524

525525
generate_slice_neg_starts()
526526

527+
def postprocess_model(model_path, inputs_shapes):
528+
onnx_model = onnx.load(model_path)
529+
530+
def update_inputs_dims(model, input_dims):
531+
"""
532+
This function updates the sizes of dimensions of the model's inputs to the values
533+
provided in input_dims. if the dim value provided is negative, a unique dim_param
534+
will be set for that dimension.
535+
"""
536+
def update_dim(tensor, dim, i, j, dim_param_prefix):
537+
dim_proto = tensor.type.tensor_type.shape.dim[j]
538+
if isinstance(dim, int):
539+
if dim >= 0:
540+
dim_proto.dim_value = dim
541+
else:
542+
dim_proto.dim_param = dim_param_prefix + str(i) + '_' + str(j)
543+
elif isinstance(dim, str):
544+
dim_proto.dim_param = dim
545+
else:
546+
raise ValueError('Only int or str is accepted as dimension value, incorrect type: {}'.format(type(dim)))
547+
548+
for i, input_dim_arr in enumerate(input_dims):
549+
for j, dim in enumerate(input_dim_arr):
550+
update_dim(model.graph.input[i], dim, i, j, 'in_')
551+
552+
onnx.checker.check_model(model)
553+
return model
554+
555+
onnx_model = update_inputs_dims(onnx_model, inputs_shapes)
556+
onnx.save(onnx_model, model_path)
557+
527558
input_2 = Variable(torch.randn(6, 6))
528559
custom_slice_list = [
529560
slice(1, 3, 1),
@@ -1916,36 +1947,6 @@ def forward(self, x):
19161947
model = GatherMultiOutput()
19171948
save_data_and_model("gather_multi_output", x, model)
19181949

1919-
def postprocess_model(model_path, inputs_shapes):
1920-
onnx_model = onnx.load(model_path)
1921-
1922-
def update_inputs_dims(model, input_dims):
1923-
"""
1924-
This function updates the sizes of dimensions of the model's inputs to the values
1925-
provided in input_dims. if the dim value provided is negative, a unique dim_param
1926-
will be set for that dimension.
1927-
"""
1928-
def update_dim(tensor, dim, i, j, dim_param_prefix):
1929-
dim_proto = tensor.type.tensor_type.shape.dim[j]
1930-
if isinstance(dim, int):
1931-
if dim >= 0:
1932-
dim_proto.dim_value = dim
1933-
else:
1934-
dim_proto.dim_param = dim_param_prefix + str(i) + '_' + str(j)
1935-
elif isinstance(dim, str):
1936-
dim_proto.dim_param = dim
1937-
else:
1938-
raise ValueError('Only int or str is accepted as dimension value, incorrect type: {}'.format(type(dim)))
1939-
1940-
for i, input_dim_arr in enumerate(input_dims):
1941-
for j, dim in enumerate(input_dim_arr):
1942-
update_dim(model.graph.input[i], dim, i, j, 'in_')
1943-
1944-
onnx.checker.check_model(model)
1945-
return model
1946-
1947-
onnx_model = update_inputs_dims(onnx_model, inputs_shapes)
1948-
onnx.save(onnx_model, model_path)
19491950

19501951
class UnsqueezeAndConv(nn.Module):
19511952
def __init__(self):

0 commit comments

Comments
 (0)