 import datetime
 import os
 import time
+from enum import Enum
 
 import torch
 import torch.ao.quantization
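+# quantize_fx provides the FX graph mode entry points (prepare_fx / prepare_qat_fx / convert_fx)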
+import torch.ao.quantization.quantize_fx
 import torch.utils.data
 import torchvision
 import utils
 from torch import nn
 from train import train_one_epoch, evaluate, load_data
 
 
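+# The two supported quantization workflows, selected via --quantization-workflow-type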
+class QuantizationWorkflowType(Enum):
+    EAGER_MODE_QUANTIZATION = 1
+    FX_GRAPH_MODE_QUANTIZATION = 2
+
+
 def main(args):
     if args.output_dir:
         utils.mkdir(args.output_dir)
@@ -22,6 +29,17 @@ def main(args):
     if args.post_training_quantize and args.distributed:
         raise RuntimeError("Post training quantization example should not be performed on distributed mode")
 
+    # Validate quantization workflow type
+    quantization_workflow_type = args.quantization_workflow_type.upper()
+    if quantization_workflow_type not in QuantizationWorkflowType.__members__:
+        raise RuntimeError(
+            "Unknown workflow type '%s', please choose from: %s"
+            % (args.quantization_workflow_type, str(tuple(t.lower() for t in QuantizationWorkflowType.__members__)))
+        )
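+    # Resolve the choice once; the eager-mode and FX-graph-mode paths branch on this flag below.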
+    use_fx_graph_mode_quantization = (
+        QuantizationWorkflowType[quantization_workflow_type] == QuantizationWorkflowType.FX_GRAPH_MODE_QUANTIZATION
+    )
+
     # Set backend engine to ensure that quantized model runs on the correct kernels
     if args.backend not in torch.backends.quantized.supported_engines:
         raise RuntimeError("Quantized backend not supported: " + str(args.backend))
@@ -46,13 +64,20 @@ def main(args):
 
     print("Creating model", args.model)
     # when training quantized models, we always start from a pre-trained fp32 reference model
-    model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
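+    # FX graph mode traces the plain fp32 model; eager mode needs the quantizable model variant.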
+    if use_fx_graph_mode_quantization:
+        model = torchvision.models.__dict__[args.model](weights=args.weights)
+    else:
+        model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
     model.to(device)
 
     if not (args.test_only or args.post_training_quantize):
-        model.fuse_model(is_qat=True)
-        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
-        torch.ao.quantization.prepare_qat(model, inplace=True)
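+        # prepare_qat_fx fuses modules and inserts fake-quant observers automatically,
+        # so the explicit fuse_model()/qconfig steps are only needed in eager mode.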
+        if use_fx_graph_mode_quantization:
+            qconfig_dict = torch.ao.quantization.get_default_qat_qconfig_dict(args.backend)
+            model = torch.ao.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict)
+        else:
+            model.fuse_model(is_qat=True)
+            model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
+            torch.ao.quantization.prepare_qat(model, inplace=True)
 
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -84,13 +109,20 @@ def main(args):
             ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
         )
         model.eval()
-        model.fuse_model(is_qat=False)
-        model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
-        torch.ao.quantization.prepare(model, inplace=True)
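+        # For post-training quantization, prepare_fx likewise handles fusion and observer insertion.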
+        if use_fx_graph_mode_quantization:
+            qconfig_dict = torch.ao.quantization.get_default_qconfig_dict(args.backend)
+            model = torch.ao.quantization.quantize_fx.prepare_fx(model, qconfig_dict)
+        else:
+            model.fuse_model(is_qat=False)
+            model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
+            torch.ao.quantization.prepare(model, inplace=True)
         # Calibrate first
         print("Calibrating")
         evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
-        torch.ao.quantization.convert(model, inplace=True)
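+        # convert_fx returns a new quantized GraphModule, while eager-mode convert mutates the model in place.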
+        if use_fx_graph_mode_quantization:
+            model = torch.ao.quantization.quantize_fx.convert_fx(model)
+        else:
+            torch.ao.quantization.convert(model, inplace=True)
         if args.output_dir:
             print("Saving quantized model")
             if utils.is_main_process():
@@ -125,7 +157,10 @@ def main(args):
             quantized_eval_model = copy.deepcopy(model_without_ddp)
             quantized_eval_model.eval()
             quantized_eval_model.to(torch.device("cpu"))
-            torch.ao.quantization.convert(quantized_eval_model, inplace=True)
+            if use_fx_graph_mode_quantization:
+                quantized_eval_model = torch.ao.quantization.quantize_fx.convert_fx(quantized_eval_model)
+            else:
+                torch.ao.quantization.convert(quantized_eval_model, inplace=True)
 
             print("Evaluate Quantized model")
             evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
@@ -233,6 +268,12 @@ def get_args_parser(add_help=True):
         help="Post training quantize the model",
         action="store_true",
     )
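+    # The value is validated in main() against the QuantizationWorkflowType members.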
+    parser.add_argument(
+        "--quantization-workflow-type",
+        default="eager_mode_quantization",
+        type=str,
+        help="The quantization workflow type to use, either 'eager_mode_quantization' (default) or 'fx_graph_mode_quantization'",
+    )
 
     # distributed training parameters
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")