@@ -44,6 +44,125 @@ def contiguous_hook(module, input):
         child.register_forward_pre_hook(contiguous_hook)


+def is_fused_module(module):
+    """Helper for `_propagate_qconfig_helper` to detect
+    whether the given module is fused.
+
+    Args:
+        module (object): input module
+
+    Returns:
+        (bool): True if the module is fused, False otherwise
+    """
+    # Fused modules are defined in torch.nn.intrinsic.modules.fused, so their
+    # type string (e.g. "<class 'torch.nn.intrinsic.modules.fused.ConvReLU2d'>")
+    # contains the substring 'fused'.
+    op_type = str(type(module))
+    return 'fused' in op_type
+
+
+def _set_input_scale_hook(model, op_cfgs):
+    """Insert hooks to observe the input/output scale and zero point of each op.
+
+    Args:
+        model (object): input model
+        op_cfgs (dict): dictionary of quantization configuration for each op
+
+    Returns:
+        hook_list (list): handles of the registered observer hooks
+    """
+    def input_scale_hook(module, input):
+        # Observe the first positional input with a fresh activation observer.
+        module.input_observer = module.qconfig.activation()
+        module.input_observer(input[0])
+        return input
+
+    def output_scale_hook(module, input, output):
+        module.output_observer = module.qconfig.activation()
+        module.output_observer(output)
+        return output
+
+    def ConvReLU2d_scale_hook(module, input):
+        # For fused Conv+ReLU, re-run the conv with the fake-quantized weight so
+        # the observer sees the conv output before the ReLU (a forward hook on
+        # the fused module would only see the post-ReLU result).
+        module.input_observer = module.qconfig.activation()
+        module.input_observer(input[0])
+        output = module._conv_forward(input[0], module.weight_fake_quant(module.weight),
+                                      module.bias)
+        module.output_observer = module.qconfig.activation()
+        module.output_observer(output)
+        return input
+
+    def LinearReLU_scale_hook(module, input):
+        # Same idea for fused Linear+ReLU: observe the pre-ReLU linear output.
+        import torch.nn.functional as F
+        module.input_observer = module.qconfig.activation()
+        module.input_observer(input[0])
+        output = F.linear(input[0], module.weight_fake_quant(module.weight), module.bias)
+        module.output_observer = module.qconfig.activation()
+        module.output_observer(output)
+        return input
+
+    hook_list = []
+    for name, module in model.named_modules():
+        if 'Conv' in str(module.__class__.__name__) or \
+           'Linear' in str(module.__class__.__name__):
+            if not hasattr(module, 'qconfig') or not module.qconfig:
+                continue
+            from torch.nn.intrinsic.qat import ConvBn2d, ConvReLU2d, ConvBnReLU2d, LinearReLU
+            if type(module) in [ConvBn2d, ConvBnReLU2d]:
+                handle_in = module.register_forward_pre_hook(input_scale_hook)
+                # module[0] == torch.nn.BatchNorm2d
+                module[0].qconfig = module.qconfig
+                handle_out = module[0].register_forward_hook(output_scale_hook)
+                hook_list.extend([handle_in, handle_out])
+            elif type(module) in [ConvReLU2d]:
+                handle_in_out = module.register_forward_pre_hook(ConvReLU2d_scale_hook)
+                hook_list.extend([handle_in_out])
+            elif type(module) in [LinearReLU]:
+                handle_in_out = module.register_forward_pre_hook(LinearReLU_scale_hook)
+                hook_list.extend([handle_in_out])
+            else:
+                # Any other fused module is skipped; plain Conv/Linear get both hooks.
+                if is_fused_module(module):
+                    continue
+                handle_in = module.register_forward_pre_hook(input_scale_hook)
+                handle_out = module.register_forward_hook(output_scale_hook)
+                hook_list.extend([handle_in, handle_out])
+    return hook_list
+
+
+def _get_input_scale(model, hook_list):
+    """Fetch input/output scale and zero point from the observers.
+
+    Args:
+        model (object): input model
+        hook_list (list): handles of the registered observer hooks
+
+    Returns:
+        scale_info (dict): input/output scale and zero point of each module
+    """
+    scale_info = {}
+    for name, module in model.named_modules():
+        from torch.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d
+        if type(module) in [ConvBn2d, ConvBnReLU2d]:
+            # For Conv+BN fusions the output observer lives on the BN submodule.
+            if hasattr(module, "input_observer") and hasattr(module[0], "output_observer"):
+                scale_in, zero_point_in = module.input_observer.calculate_qparams()
+                scale_out, zero_point_out = module[0].output_observer.calculate_qparams()
+                scale_info[name] = {
+                    'input_scale': float(scale_in),
+                    'input_zeropoint': int(zero_point_in),
+                    'output_scale': float(scale_out),
+                    'output_zeropoint': int(zero_point_out)
+                }
+        elif hasattr(module, "input_observer") and hasattr(module, "output_observer"):
+            scale_in, zero_point_in = module.input_observer.calculate_qparams()
+            scale_out, zero_point_out = module.output_observer.calculate_qparams()
+            scale_info[name] = {
+                'input_scale': float(scale_in),
+                'input_zeropoint': int(zero_point_in),
+                'output_scale': float(scale_out),
+                'output_zeropoint': int(zero_point_out)
+            }
+    for h in hook_list:
+        h.remove()
+    return scale_info
+
+
 def collate_torch_preds(results):
     batch = results[0]
     if isinstance(batch, list):
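
For context on how the two new helpers fit together: register the hooks, run a few calibration batches so every observer records a range, then collect the scales and detach the hooks. A minimal sketch, assuming a QAT-prepared model `q_model`, its per-op config dict `op_cfgs`, and a DataLoader `calib_loader` (all three names are illustrative, not from this commit):

import torch

# Register input/output observers on every quantizable Conv/Linear module.
hook_list = _set_input_scale_hook(q_model, op_cfgs)

# A short calibration pass populates the observers.
with torch.no_grad():
    for images, _ in calib_loader:
        q_model(images)

# Read back per-module qparams and remove all hooks.
scale_info = _get_input_scale(q_model, hook_list)
# scale_info[name] -> {'input_scale': ..., 'input_zeropoint': ...,
#                      'output_scale': ..., 'output_zeropoint': ...}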