Skip to content

Commit eab46f7

Browse files
committed
fix conv bn fusion isssue
Signed-off-by: Xin He <[email protected]>
1 parent 089b7e1 commit eab46f7

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

neural_compressor/experimental/export/torch2onnx.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,15 @@ def check_data(op_type, data, module_dict):
127127
for name, value in module_dict.items():
128128
if value.shape == data.shape:
129129
if (value == data).all():
130+
module_dict.pop(name)
130131
return name
131-
# Convolution weight data mismatch.
132-
elif op_type == 'Conv' and np.allclose(value, data):
133-
return name
132+
elif op_type == 'Conv':
133+
# Convolution weight data have fluction and BN fusion will insert scale.
134+
# We use the weight scale of the first output channel to check.
135+
weight_scale = value[0] / data[0]
136+
if np.allclose(weight_scale - np.mean(weight_scale), 0, atol=1.e-5):
137+
module_dict.pop(name)
138+
return name
134139
return None
135140

136141
module_dict = {}

0 commit comments

Comments
 (0)