 # limitations under the License.
 """Rewrite the FP32 operators to FP16 or BF16 operators."""

+from collections import defaultdict
 from dataclasses import dataclass
 from functools import partial
 from typing import Any, Callable, Dict, List, Tuple
 from torch.fx.subgraph_rewriter import Match
 from typing_extensions import TypeAlias

-from neural_compressor.common import utils
+from neural_compressor.common import logger, utils

 # =============================================================================
 # Search and replace patterns
@@ -50,25 +51,44 @@ class PatternPair:

 # key: torch func
-# value: the tuple of args
+# value: a list of arg tuples, one entry per supported signature
-FuncArgsMappingType: TypeAlias = Dict[TorchFuncType, Tuple[torch.Tensor, ...]]
+FuncArgsMappingType: TypeAlias = Dict[TorchFuncType, List[Tuple[torch.Tensor, ...]]]


-# Align with https://pytorch.org/docs/stable/amp.html#cpu-ops-that-can-autocast-to-bfloat16
-# TODO: complete the mapping
+# Align with xiq, as this relies on xiq's set_module_xx capability.
 FN_ARGS_MAPPING: FuncArgsMappingType = {
-    torch.nn.functional.linear: (torch.randn(0, 0), torch.randn(0, 0)),  # linear w/o bias
-    torch.nn.functional.linear: (torch.randn(0, 0), torch.randn(0, 0), torch.randn(0)),  # linear w/ bias
+    # Note: order matters.
+    torch.nn.functional.linear: [
+        (torch.randn(0, 0), torch.randn(0, 0)),  # linear w/o bias
+        (torch.randn(0, 0), torch.randn(0, 0), torch.randn(0)),  # linear w/ bias
+    ],
+    torch.nn.functional.conv2d: [
+        (torch.randn(1, 1, 1, 1), torch.randn(1, 1, 1, 1)),  # conv2d w/o bias
+        (torch.randn(1, 1, 1, 1), torch.randn(1, 1, 1, 1), torch.randn(1)),  # conv2d w/ bias
+    ],
+    torch.matmul: [
+        (torch.randn(0, 0), torch.randn(0, 0)),
+        (torch.randn(0, 0, 0), torch.randn(0, 0, 0)),
+        (torch.randn(0, 0, 0, 0), torch.randn(0, 0, 0, 0)),
+    ],
 }
-# TODO: complete the mapping
-FN_ATEN_OPS_MAPPING = {
-    torch.nn.functional.linear: torch.ops.aten.linear.default,
+
+# module cls <-> function name
+NN_MODULES_TO_NN_FN = {
+    torch.nn.Linear: torch.nn.functional.linear,
+    torch.nn.Conv2d: torch.nn.functional.conv2d,
 }

+# Use the mapping from xiq
+FN_ATEN_OPS_MAPPING = xiq._map_module_function_to_aten_operator_type()
+
 SUPPORTED_OPERATORS = FN_ATEN_OPS_MAPPING.values()


-PatternRegistryType: TypeAlias = Dict[TorchFuncType, PatternPair]
+PatternRegistryType: TypeAlias = Dict[TorchFuncType, List[PatternPair]]
-HALF_PRECISION_PATTERN_REGISTRY: Dict[torch.dtype, PatternRegistryType] = {torch.float16: {}, torch.bfloat16: {}}
+HALF_PRECISION_PATTERN_REGISTRY: Dict[torch.dtype, PatternRegistryType] = {
+    torch.float16: defaultdict(list),
+    torch.bfloat16: defaultdict(list),
+}

 # FP16_PATTERN_REGISTRY: PatternRegistryType = HALF_PRECISION_PATTERN_REGISTRY[torch.float16]
 # BF16_PATTERN_REGISTRY: PatternRegistryType = HALF_PRECISION_PATTERN_REGISTRY[torch.bfloat16]
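
The per-function lists of example args exist because torch.fx traces one graph per call signature. A minimal sketch (independent of this patch's `pattern_factory`) showing that linear with and without bias trace to different FX graphs, which is why each signature gets its own `PatternPair` in the `defaultdict(list)` registry:

```python
# Minimal sketch: each call signature traces to a distinct FX graph,
# so FN_ARGS_MAPPING keeps one example-arg tuple per signature.
import torch


def linear_no_bias(x, w):
    return torch.nn.functional.linear(x, w)


def linear_with_bias(x, w, b):
    return torch.nn.functional.linear(x, w, b)


print(torch.fx.symbolic_trace(linear_no_bias).graph)    # linear(x, w) -> one pattern
print(torch.fx.symbolic_trace(linear_with_bias).graph)  # linear(x, w, b) -> another
```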
@@ -98,15 +118,18 @@ def replace_fn_wrapper(fn_args, fn):


 def _register_pattern_pair(dtype: torch.dtype) -> None:
-    for fn, fn_args in FN_ARGS_MAPPING.items():
-        pattern_pair = pattern_factory(fn, fn_args)
-        HALF_PRECISION_PATTERN_REGISTRY[dtype][fn] = pattern_pair
-    utils.logger.info(
+    for fn, fn_args_lst in FN_ARGS_MAPPING.items():
+        for fn_args in fn_args_lst:
+            logger.debug(f"Registering search and replace patterns for {fn} with args: {fn_args}.")
+            pattern_pair = pattern_factory(fn, fn_args)
+            HALF_PRECISION_PATTERN_REGISTRY[dtype][fn].append(pattern_pair)
+    utils.logger.debug(
         f"Registered {len(HALF_PRECISION_PATTERN_REGISTRY[dtype])} search and replace patterns for {dtype}."
     )


 _register_pattern_pair(torch.float16)
+_register_pattern_pair(torch.bfloat16)


 def get_filter_fn(node_list, fn):
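
`pattern_factory` itself is defined elsewhere in this file. A hypothetical sketch of the kind of pair it is assumed to build for one signature (names are illustrative, not the file's actual code):

```python
# Hypothetical sketch of one search/replace pair: the search graph is
# the fp32 call, the replace graph downcasts floating-point inputs,
# runs the op in half precision, and restores the fp32 output dtype.
import torch


def _search_linear_bias(x, w, b):
    return torch.nn.functional.linear(x, w, b)


def _replace_linear_bias_fp16(x, w, b):
    out = torch.nn.functional.linear(x.to(torch.float16), w.to(torch.float16), b.to(torch.float16))
    return out.to(torch.float32)
```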
@@ -182,9 +205,10 @@ def get_unquantized_node_set(gm: torch.fx.GraphModule):

 def transformation(gm: torch.fx.GraphModule, node_candidate_list: List[str], target_dtype: torch.dtype = torch.float16):
     """Convert the nodes in `node_candidate_list` to `target_dtype` if possible."""
-    for pattern_pair in HALF_PRECISION_PATTERN_REGISTRY[target_dtype].values():
-        apply_single_pattern_pair(gm, pattern_pair, node_candidate_list)
-    utils.logger.info("Half precision conversion is done:")
+    for pattern_pair_lst in HALF_PRECISION_PATTERN_REGISTRY[target_dtype].values():
+        for pattern_pair in pattern_pair_lst:
+            apply_single_pattern_pair(gm, pattern_pair, node_candidate_list)
+    utils.logger.info(f"Half precision conversion ({target_dtype}) completed.")
     if utils.level_name == "DEBUG":  # pragma: no cover
         gm.print_readable(True)

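
`apply_single_pattern_pair` is defined earlier in the file; a hedged sketch of the call it presumably wraps, assuming `PatternPair` carries `fn`, `search_pattern`, and `replace_pattern` (these field names are an assumption) and using the file's own `get_filter_fn` to scope matches to the candidate nodes:

```python
# Hedged sketch of the rewrite step; field names on pattern_pair are
# assumptions. replace_pattern_with_filters is the public torch.fx API
# the file already imports Match for.
from torch.fx import subgraph_rewriter


def _apply_pair_sketch(gm, pattern_pair, node_candidate_list):
    filter_fn = get_filter_fn(node_candidate_list, pattern_pair.fn)  # defined above
    return subgraph_rewriter.replace_pattern_with_filters(
        gm, pattern_pair.search_pattern, pattern_pair.replace_pattern, [filter_fn]
    )
```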
@@ -201,11 +225,11 @@ def _parse_node_candidate_set_from_user_config(config, gm):
     op_name_filters = []
     for op_type_name, config in op_type_configs.items():  # pragma: no cover
         op_type = getattr(torch.nn, op_type_name)
-        if config.act_dtype == "fp16":  # pragma: no cover
+        if config.act_dtype in ["fp16", "bf16"]:  # pragma: no cover
             filter = xpq._get_module_type_filter(op_type)
             op_type_filters.append(filter)
     for op_name, config in op_name_configs.items():
-        if config.act_dtype == "fp16":  # pragma: no cover
+        if config.act_dtype in ["fp16", "bf16"]:  # pragma: no cover
             filter = xpq._get_module_name_filter(op_name)
             op_name_filters.append(filter)
     node_set_from_user_config = set()
@@ -237,5 +261,7 @@ def get_half_precision_node_set(gm, config):
     for node in possible_node_set:
         if node.target in SUPPORTED_OPERATORS:
             half_precision_node_set.add(node)
-    utils.logger.info(f"Found {len(half_precision_node_set)} nodes to convert to half precision.")
+    utils.logger.info(
+        f"Found {len(half_precision_node_set)} nodes to convert to half precision: {half_precision_node_set}"
+    )
     return half_precision_node_set
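
End-to-end sketch of how these helpers are assumed to compose (`gm` and `config` are placeholders; passing node names matches the `List[str]` annotation on `transformation`):

```python
# Collect the eligible nodes, then rewrite them to the requested
# half precision in place.
node_set = get_half_precision_node_set(gm, config)
transformation(gm, [node.name for node in node_set], target_dtype=torch.bfloat16)
```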