@@ -121,7 +121,7 @@ def main(args):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        print("Starting training for epoch", epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq)
+        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args)
        lr_scheduler.step()
        with torch.inference_mode():
            if epoch >= args.num_observer_update_epochs:
@@ -132,7 +132,7 @@ def main(args):
                model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
            print("Evaluate QAT model")

-            evaluate(model, criterion, data_loader_test, device=device)
+            evaluate(model, criterion, data_loader_test, device=device, log_suffix="QAT")
            quantized_eval_model = copy.deepcopy(model_without_ddp)
            quantized_eval_model.eval()
            quantized_eval_model.to(torch.device("cpu"))
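
Passing the whole `args` namespace into `train_one_epoch` (instead of only `args.print_freq`) is what lets the training loop read the new `--clip-grad-norm` flag added further down in this diff. Below is a minimal sketch of how the loop could apply it; the loop body is an assumption about the usual torchvision reference-script shape, not this repository's exact code, and only the clipping step is the point:

import torch

def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args):
    # Hypothetical inner loop, shown only to illustrate where clipping fits.
    model.train()
    for image, target in data_loader:
        image, target = image.to(device), target.to(device)
        loss = criterion(model(image), target)

        optimizer.zero_grad()
        loss.backward()
        if args.clip_grad_norm is not None:
            # Rescale gradients in place so their total norm stays at or
            # below the threshold supplied via --clip-grad-norm.
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
        optimizer.step()

Clipping between `backward()` and `step()` is the standard placement: the gradients must already exist, and they must be rescaled before the optimizer consumes them.
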
@@ -261,6 +261,7 @@ def get_args_parser(add_help=True):
    parser.add_argument(
        "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
    )
+    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")

    # Prototype models only
    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
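
Because the new flag defaults to `None`, clipping is strictly opt-in. A self-contained sketch of the flag's parsing semantics, using a throwaway parser rather than the script's own `get_args_parser`:

import argparse

# Minimal stand-in parser demonstrating how the new flag behaves.
parser = argparse.ArgumentParser()
parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")

# Equivalent to passing --clip-grad-norm 1.0 on the training script's CLI.
args = parser.parse_args(["--clip-grad-norm", "1.0"])
assert args.clip_grad_norm == 1.0

# Omitting the flag leaves it None, which disables clipping in the loop above.
assert parser.parse_args([]).clip_grad_norm is None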