 
 import torch
 import torch.ao.quantization
+import torch.ao.quantization.quantize_fx
 import torch.utils.data
 import torchvision
 import utils
 from torch import nn
+from torchvision.models.quantization.utils import QuantizationWorkflowType
 from train import train_one_epoch, evaluate, load_data
 
 
@@ -22,6 +24,15 @@ def main(args):
     if args.post_training_quantize and args.distributed:
         raise RuntimeError("Post training quantization example should not be performed on distributed mode")
 
+    # Validate quantization workflow type
+    all_quantization_workflow_types = [t.value for t in QuantizationWorkflowType]
+    if args.quantization_workflow_type not in all_quantization_workflow_types:
+        raise RuntimeError(
+            "Unknown quantization workflow type '%s', must be one of: %s"
+            % (args.quantization_workflow_type, all_quantization_workflow_types)
+        )
+    quantization_workflow_type = QuantizationWorkflowType(args.quantization_workflow_type)
+
     # Set backend engine to ensure that quantized model runs on the correct kernels
     if args.backend not in torch.backends.quantized.supported_engines:
         raise RuntimeError("Quantized backend not supported: " + str(args.backend))
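
The `QuantizationWorkflowType` definition itself is not part of this hunk (it lives in `torchvision.models.quantization.utils`, added elsewhere in this PR). From how it is used above — `t.value` strings compared against the CLI flag, and an `FX_GRAPH_MODE` member later on — it is presumably a string-valued `Enum`. A minimal sketch under that assumption:

```python
from enum import Enum


class QuantizationWorkflowType(Enum):
    # Hypothetical reconstruction; the real definition is in
    # torchvision.models.quantization.utils and is not shown in this diff.
    EAGER_MODE = "eager_mode"
    FX_GRAPH_MODE = "fx_graph_mode"
```

With string values like these, `QuantizationWorkflowType(args.quantization_workflow_type)` round-trips the validated CLI string straight into an enum member.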
@@ -46,13 +57,21 @@ def main(args):
 
     print("Creating model", args.model)
     # when training quantized models, we always start from a pre-trained fp32 reference model
-    model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
+    model = torchvision.models.quantization.__dict__[args.model](
+        weights=args.weights,
+        quantize=args.test_only,
+        quantization_workflow_type=quantization_workflow_type,
+    )
     model.to(device)
 
     if not (args.test_only or args.post_training_quantize):
-        model.fuse_model(is_qat=True)
-        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
-        torch.ao.quantization.prepare_qat(model, inplace=True)
+        if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
+            qconfig_dict = torch.ao.quantization.get_default_qat_qconfig_dict(args.backend)
+            model = torch.ao.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict)
+        else:
+            model.fuse_model(is_qat=True)
+            model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
+            torch.ao.quantization.prepare_qat(model, inplace=True)
 
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
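
The two branches above do the same logical work: eager mode requires the model to fuse its own modules and carry a `qconfig` attribute, while `prepare_qat_fx` symbolically traces the model and inserts fake-quant observers (including fusion) automatically. A minimal standalone sketch of the FX QAT round trip, mirroring the calls used in this diff; note the API matches the PyTorch release this PR targets, while later releases renamed `get_default_qat_qconfig_dict` to `get_default_qat_qconfig_mapping` and made `example_inputs` a required argument of `prepare_qat_fx`:

```python
import torch
from torch.ao.quantization import get_default_qat_qconfig_dict
from torch.ao.quantization.quantize_fx import convert_fx, prepare_qat_fx

# Toy stand-in for the torchvision model; any traceable fp32 module works.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())

# Insert fake-quant observers; FX tracing handles conv/relu fusion itself,
# so there is no manual fuse_model() or qconfig assignment.
qconfig_dict = get_default_qat_qconfig_dict("fbgemm")
model.train()
prepared = prepare_qat_fx(model, qconfig_dict)

# ... run the usual training loop on `prepared` ...

# Lower to a real int8 model. convert_fx returns a new GraphModule;
# it does not modify `prepared` in place.
prepared.eval()
quantized = convert_fx(prepared)
out = quantized(torch.randn(1, 3, 32, 32))
```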
@@ -84,13 +103,20 @@ def main(args):
             ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
         )
         model.eval()
-        model.fuse_model(is_qat=False)
-        model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
-        torch.ao.quantization.prepare(model, inplace=True)
+        if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
+            qconfig_dict = torch.ao.quantization.get_default_qconfig_dict(args.backend)
+            model = torch.ao.quantization.quantize_fx.prepare_fx(model, qconfig_dict)
+        else:
+            model.fuse_model(is_qat=False)
+            model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
+            torch.ao.quantization.prepare(model, inplace=True)
         # Calibrate first
         print("Calibrating")
         evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
-        torch.ao.quantization.convert(model, inplace=True)
+        if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
+            model = torch.ao.quantization.quantize_fx.convert_fx(model)
+        else:
+            torch.ao.quantization.convert(model, inplace=True)
         if args.output_dir:
             print("Saving quantized model")
             if utils.is_main_process():
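
One subtlety visible in this hunk: the eager calls mutate the model (`inplace=True`), while `prepare_fx`/`convert_fx` return new `GraphModule`s, so the result must be rebound with `model = ...`; dropping the assignment in FX mode would silently keep evaluating the observed fp32 model. The hunk reuses `evaluate` for calibration, which is just inference under observers. A minimal PTQ sketch of the same flow, again matching the pre-`example_inputs` FX API used in this diff:

```python
import torch
from torch.ao.quantization import get_default_qconfig_dict
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
qconfig_dict = get_default_qconfig_dict("fbgemm")
prepared = prepare_fx(model, qconfig_dict)

# "Calibration" is plain inference: each forward pass lets the inserted
# observers record activation ranges used to pick quantization parameters.
with torch.inference_mode():
    for _ in range(10):
        prepared(torch.randn(1, 3, 32, 32))

quantized = convert_fx(prepared)  # returns a new module; not in-place
```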
@@ -125,7 +151,10 @@ def main(args):
         quantized_eval_model = copy.deepcopy(model_without_ddp)
         quantized_eval_model.eval()
         quantized_eval_model.to(torch.device("cpu"))
-        torch.ao.quantization.convert(quantized_eval_model, inplace=True)
+        if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
+            quantized_eval_model = torch.ao.quantization.quantize_fx.convert_fx(quantized_eval_model)
+        else:
+            torch.ao.quantization.convert(quantized_eval_model, inplace=True)
 
         print("Evaluate Quantized model")
         evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
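
This per-epoch evaluation converts a CPU deepcopy rather than the live model: the quantized kernels for both supported engines (`fbgemm`, `qnnpack`) run on CPU only, and converting a copy lets the fake-quant model keep training on its original device. As in the previous hunks, the FX branch must rebind the name because `convert_fx` is not in-place.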
@@ -233,6 +262,12 @@ def get_args_parser(add_help=True):
         help="Post training quantize the model",
         action="store_true",
     )
+    parser.add_argument(
+        "--quantization-workflow-type",
+        default="eager_mode",
+        type=str,
+        help="The quantization workflow type to use, either 'eager_mode' (default) or 'fx_graph_mode'",
+    )
 
     # distributed training parameters
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
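
Assuming this file is torchvision's `references/classification/train_quantization.py`, the new flag would be exercised as, e.g., `python train_quantization.py --model mobilenet_v3_large --backend fbgemm --post-training-quantize --quantization-workflow-type fx_graph_mode`. Omitting the flag keeps the existing eager-mode behavior via the `eager_mode` default, so the change is backward compatible.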