 import datetime
 import os
 import time
+from enum import Enum
 
 import torch
 import torch.ao.quantization
+import torch.ao.quantization.quantize_fx
 import torch.utils.data
 import torchvision
 import utils
 from torch import nn
 from train import train_one_epoch, evaluate, load_data
 
 
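+# Two supported quantization workflows: eager mode (manual fuse/prepare/convert
+# on the quantizable model variants) and FX graph mode (automatic, via torch.fx tracing).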
+class QuantizationWorkflowType(Enum):
+    EAGER_MODE_QUANTIZATION = 1
+    FX_GRAPH_MODE_QUANTIZATION = 2
+
+
 def main(args):
     if args.output_dir:
         utils.mkdir(args.output_dir)
@@ -22,6 +29,15 @@ def main(args):
     if args.post_training_quantize and args.distributed:
         raise RuntimeError("Post training quantization example should not be performed on distributed mode")
 
+    # Validate quantization workflow type
+    quantization_workflow_type = args.quantization_workflow_type.upper()
+    if quantization_workflow_type not in QuantizationWorkflowType.__members__:
+        raise RuntimeError(
+            "Unknown workflow type '%s', please choose from: %s"
+            % (args.quantization_workflow_type, str(tuple([t.lower() for t in QuantizationWorkflowType.__members__])))
+        )
+    use_fx_graph_mode_quantization = quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE_QUANTIZATION.name
+
     # Set backend engine to ensure that quantized model runs on the correct kernels
     if args.backend not in torch.backends.quantized.supported_engines:
         raise RuntimeError("Quantized backend not supported: " + str(args.backend))
@@ -45,14 +61,23 @@ def main(args):
     )
 
     print("Creating model", args.model)
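+    # FX graph mode quantization can start from the plain fp32 models, since
+    # fusion is handled automatically during tracing; eager mode needs the
+    # quantizable variants from torchvision.models.quantization, which provide fuse_model().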
+    if use_fx_graph_mode_quantization:
+        model_namespace = torchvision.models
+    else:
+        model_namespace = torchvision.models.quantization
     # when training quantized models, we always start from a pre-trained fp32 reference model
-    model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
+    model = model_namespace.__dict__[args.model](weights=args.weights, quantize=args.test_only)
     model.to(device)
 
     if not (args.test_only or args.post_training_quantize):
-        model.fuse_model(is_qat=True)
-        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
-        torch.ao.quantization.prepare_qat(model, inplace=True)
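+        # prepare_qat_fx symbolically traces the model and inserts fake-quant
+        # modules automatically; the eager path fuses and prepares by hand.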
+        qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
+        if use_fx_graph_mode_quantization:
+            qconfig_dict = {"": qconfig}
+            model = torch.ao.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict)
+        else:
+            model.fuse_model(is_qat=True)
+            model.qconfig = qconfig
+            torch.ao.quantization.prepare_qat(model, inplace=True)
 
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -84,13 +109,21 @@ def main(args):
             ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
         )
         model.eval()
-        model.fuse_model(is_qat=False)
-        model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
-        torch.ao.quantization.prepare(model, inplace=True)
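+        # For post-training quantization, prepare_fx inserts observers during
+        # tracing; the eager path again requires explicit fuse/prepare calls.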
+        qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
+        if use_fx_graph_mode_quantization:
+            qconfig_dict = {"": qconfig}
+            model = torch.ao.quantization.quantize_fx.prepare_fx(model, qconfig_dict)
+        else:
+            model.fuse_model(is_qat=False)
+            model.qconfig = qconfig
+            torch.ao.quantization.prepare(model, inplace=True)
         # Calibrate first
         print("Calibrating")
         evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
-        torch.ao.quantization.convert(model, inplace=True)
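+        # convert_fx returns a new quantized GraphModule, whereas eager-mode
+        # convert rewrites the prepared model in place.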
+        if use_fx_graph_mode_quantization:
+            model = torch.ao.quantization.quantize_fx.convert_fx(model)
+        else:
+            torch.ao.quantization.convert(model, inplace=True)
         if args.output_dir:
             print("Saving quantized model")
             if utils.is_main_process():
@@ -125,7 +158,10 @@ def main(args):
         quantized_eval_model = copy.deepcopy(model_without_ddp)
         quantized_eval_model.eval()
         quantized_eval_model.to(torch.device("cpu"))
-        torch.ao.quantization.convert(quantized_eval_model, inplace=True)
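+        # As above: FX convert is out-of-place, eager convert is in-place.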
+        if use_fx_graph_mode_quantization:
+            quantized_eval_model = torch.ao.quantization.quantize_fx.convert_fx(quantized_eval_model)
+        else:
+            torch.ao.quantization.convert(quantized_eval_model, inplace=True)
 
         print("Evaluate Quantized model")
         evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
@@ -233,6 +269,12 @@ def get_args_parser(add_help=True):
         help="Post training quantize the model",
         action="store_true",
     )
+    parser.add_argument(
+        "--quantization-workflow-type",
+        default="eager_mode_quantization",
+        type=str,
+        help="The quantization workflow type to use, either 'eager_mode_quantization' (default) or 'fx_graph_mode_quantization'",
+    )
 
     # distributed training parameters
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")