@@ -52,9 +52,9 @@ def remove(self):
5252class HessianTrace :
5353 """HessianTrace Class.
5454
55- Please refer to Yao, Zhewei, et al. "Pyhessian: Neural networks through the lens of the hessian."
55+ Please refer to Yao, Zhewei, et al. "Pyhessian: Neural networks through the lens of the hessian."
5656 2020 IEEE international conference on big data (Big data). IEEE, 2020.
57- Dong, Zhen, et al. "Hawq-v2: Hessian aware trace-weighted quantization of neural networks."
57+ Dong, Zhen, et al. "Hawq-v2: Hessian aware trace-weighted quantization of neural networks."
5858 Advances in neural information processing systems 33 (2020): 18518-18529.
5959 https://github.com/openvinotoolkit/nncf/blob/develop/nncf/torch/quantization/hessian_trace.py
6060 """
@@ -173,7 +173,7 @@ def act_grad_hook(model, grad_input, grad_output):
173173 def _get_enable_act_grad_hook (self , name ):
174174 def enable_act_grad_hook (model , inputs , outputs ):
175175 input = inputs [0 ]
176- if input .requires_grad is False :
176+ if input .requires_grad is False : #
177177 input .requires_grad = True
178178 self .layer_acts [name ] = input
179179
@@ -251,13 +251,13 @@ def _sample_rademacher(self, params):
251251 r .masked_fill_ (r == 0 , - 1 )
252252 samples .append (r )
253253 return samples
254-
254+
255255 def _sample_rademacher_like_params (self ):
256256 def sample (parameter ):
257257 r = torch .randint_like (parameter , high = 2 , device = self .device )
258258 return r .masked_fill_ (r == 0 , - 1 )
259259 return [sample (p ) for p in self .params ]
260-
260+
261261 def _sample_normal_like_params (self ):
262262 return [torch .randn (p .size (), device = self .device ) for p in self .params ]
263263
@@ -391,7 +391,7 @@ def _insert_hook(self, model, target_module_list):
391391 for layer , module in model .named_modules ():
392392 for target_module in target_module_list :
393393 # print("layer:",layer)
394- # print("target_model:",target_module)
394+ # print("target_model:",target_module)
395395 if layer == target_module :
396396 logging .debug ("Collect: %s" % (module ))
397397 # print("Collect: %s" % (module))
@@ -408,7 +408,7 @@ def _insert_hook_quantize(self, model, target_module_list):
408408 # print("layer:",layer)
409409 length = len ("_model." )
410410 new_key = layer [length :]
411- # print("target_model:",target_module)
411+ # print("target_model:",target_module)
412412 if new_key == target_module :
413413 logging .debug ("Collect: %s" % (module ))
414414 # print("Collect: %s" % (module))
@@ -521,7 +521,7 @@ def compare_weights(
521521 float_dict : Dict [str , Any ], quantized_dict : Dict [str , Any ]
522522) -> Dict [str , Dict [str , torch .Tensor ]]:
523523 r"""Compare the weights of the float module with its corresponding quantized module.
524-
524+
525525 Returns a dict with key corresponding to module names and each entry being
526526 a dictionary with two keys 'float' and 'quantized', containing the float and
527527 quantized weights. This dict can be used to compare and compute the quantization
@@ -608,7 +608,7 @@ def hawq_top(fp32_model, q_model, dataloader, criterion, enable_act):
608608 op_qnt_tensor = weight_quant_loss [key ]['quantized' ].dequantize ()
609609 diff_l2 = (torch .norm (op_float_tensor - op_qnt_tensor , p = 2 ) ** 2 )
610610 pertur_lst [key ] = diff_l2
611-
611+
612612 if enable_act :
613613 act_to_traces = traces ['activation' ]
614614 for trace_i , pertur_i , act_i in zip (op_to_traces .keys (), pertur_lst .keys (), act_to_traces .keys ()):
0 commit comments