@@ -13,6 +13,7 @@
 # limitations under the License.
 """Rewrite the FP32 operators to FP16 or BF16 operators."""
 
+from collections import defaultdict
 from dataclasses import dataclass
 from functools import partial
 from typing import Any, Callable, Dict, List, Tuple
@@ -50,25 +51,31 @@ class PatternPair:
 
 # key: torch func
-# value: the tuple of args
-FuncArgsMappingType: TypeAlias = Dict[TorchFuncType, Tuple[torch.Tensor, ...]]
+# value: a list of arg tuples
+FuncArgsMappingType: TypeAlias = Dict[TorchFuncType, List[Tuple[torch.Tensor, ...]]]
 
 
 # Align with xiq, as it relies on xiq's set_module_xx capability
 FN_ARGS_MAPPING: FuncArgsMappingType = {
-    torch.nn.functional.linear: (torch.randn(0, 0), torch.randn(0, 0)),  # linear w/o bias
-    torch.nn.functional.linear: (torch.randn(0, 0), torch.randn(0, 0), torch.randn(0)),  # linear w/ bias
-    torch.nn.functional.conv2d: (torch.randn(1, 1, 1, 1), torch.randn(1, 1, 1, 1)),  # conv2d w/o bias
-    torch.nn.functional.conv2d: (torch.randn(1, 1, 1, 1), torch.randn(1, 1, 1, 1), torch.randn(1)),  # conv2d w/ bias
-    torch.matmul: (torch.randn(0, 0), torch.randn(0, 0)),  # matmul
-    torch.matmul: (torch.randn(0, 0, 0), torch.randn(0, 0, 0)),  # matmul
-    torch.matmul: (torch.randn(0, 0, 0, 0), torch.randn(0, 0, 0, 0)),  # matmul
+    # NOTE: order matters.
+    torch.nn.functional.linear: [
+        (torch.randn(0, 0), torch.randn(0, 0)),  # linear w/o bias
+        (torch.randn(0, 0), torch.randn(0, 0), torch.randn(0)),  # linear w/ bias
+    ],
+    torch.nn.functional.conv2d: [
+        (torch.randn(1, 1, 1, 1), torch.randn(1, 1, 1, 1)),  # conv2d w/o bias
+        (torch.randn(1, 1, 1, 1), torch.randn(1, 1, 1, 1), torch.randn(1)),  # conv2d w/ bias
+    ],
+    torch.matmul: [
+        (torch.randn(0, 0), torch.randn(0, 0)),
+        (torch.randn(0, 0, 0), torch.randn(0, 0, 0)),
+        (torch.randn(0, 0, 0, 0), torch.randn(0, 0, 0, 0)),
+    ],
 }
 
 # module cls <-> function name
 NN_MODULES_TO_NN_FN = {
     torch.nn.Linear: torch.nn.functional.linear,
     torch.nn.Conv2d: torch.nn.functional.conv2d,
-    torch.nn.MaxPool2d: torch.nn.functional.max_pool2d,
 }
 
 # Use the mapping from xiq
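Why the mapping's value type changed from a single tuple to a list of tuples: a Python dict literal does not reject repeated keys, it silently keeps only the last binding, so the removed version of `FN_ARGS_MAPPING` only ever registered one variant per function. A minimal standalone sketch of the pitfall:

```python
import torch

# Repeated keys in a dict literal do not raise; the later binding wins.
broken = {
    torch.nn.functional.linear: (torch.randn(0, 0), torch.randn(0, 0)),  # silently dropped
    torch.nn.functional.linear: (torch.randn(0, 0), torch.randn(0, 0), torch.randn(0)),  # kept
}
assert len(broken) == 1
assert len(broken[torch.nn.functional.linear]) == 3  # only the w/-bias variant survived
```

A list per function keeps every arg signature, and since registration below walks each list front to back, the `# NOTE: order matters.` comment presumably refers to that iteration order.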
@@ -78,7 +85,10 @@ class PatternPair:
 
 
-PatternRegistryType: TypeAlias = Dict[TorchFuncType, PatternPair]
+PatternRegistryType: TypeAlias = Dict[TorchFuncType, List[PatternPair]]
-HALF_PRECISION_PATTERN_REGISTRY: Dict[torch.dtype, PatternRegistryType] = {torch.float16: {}, torch.bfloat16: {}}
+HALF_PRECISION_PATTERN_REGISTRY: Dict[torch.dtype, PatternRegistryType] = {
+    torch.float16: defaultdict(list),
+    torch.bfloat16: defaultdict(list),
+}
 
 # FP16_PATTERN_REGISTRY: PatternRegistryType = HALF_PRECISION_PATTERN_REGISTRY[torch.float16]
 # BF16_PATTERN_REGISTRY: PatternRegistryType = HALF_PRECISION_PATTERN_REGISTRY[torch.bfloat16]
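The `defaultdict(list)` registry pairs with the `.append()` call added in `_register_pattern_pair` below: the first access to a missing function key materializes an empty list, so no existence check is needed. A quick illustration of the idiom, with placeholder strings standing in for real `PatternPair` objects:

```python
from collections import defaultdict

import torch

registry = {torch.float16: defaultdict(list), torch.bfloat16: defaultdict(list)}

# No KeyError and no explicit initialization: a missing key yields [].
registry[torch.float16][torch.nn.functional.linear].append("pair_wo_bias")
registry[torch.float16][torch.nn.functional.linear].append("pair_w_bias")

assert registry[torch.float16][torch.nn.functional.linear] == ["pair_wo_bias", "pair_w_bias"]
assert registry[torch.bfloat16] == {}  # the other dtype stays untouched
```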
@@ -108,10 +118,11 @@ def replace_fn_wrapper(fn_args, fn):
 
 
 def _register_pattern_pair(dtype: torch.dtype) -> None:
-    for fn, fn_args in FN_ARGS_MAPPING.items():
-        logger.debug(f"Registering search and replace patterns for {fn} with args: {fn_args}.")
-        pattern_pair = pattern_factory(fn, fn_args)
-        HALF_PRECISION_PATTERN_REGISTRY[dtype][fn] = pattern_pair
+    for fn, fn_args_lst in FN_ARGS_MAPPING.items():
+        for fn_args in fn_args_lst:
+            logger.debug(f"Registering search and replace patterns for {fn} with args: {fn_args}.")
+            pattern_pair = pattern_factory(fn, fn_args)
+            HALF_PRECISION_PATTERN_REGISTRY[dtype][fn].append(pattern_pair)
     utils.logger.debug(
         f"Registered {len(HALF_PRECISION_PATTERN_REGISTRY[dtype])} search and replace patterns for {dtype}."
     )
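For context on what a registered `PatternPair` is ultimately used for (`pattern_factory` itself is outside this diff): `torch.fx` ships a subgraph rewriter that traces a pattern callable, finds its occurrences in a `GraphModule`, and splices in a replacement. A self-contained sketch of that mechanism, with a hypothetical downcast-and-upcast replacement analogous to a half-precision rewrite:

```python
import torch
from torch.fx import subgraph_rewriter, symbolic_trace


def pattern(x, weight):
    return torch.nn.functional.linear(x, weight)


def replacement(x, weight):
    # Hypothetical fp16 rewrite: downcast the inputs, run the op, upcast the result.
    out = torch.nn.functional.linear(x.to(torch.float16), weight.to(torch.float16))
    return out.to(torch.float32)


class M(torch.nn.Module):
    def forward(self, x, weight):
        return torch.nn.functional.linear(x, weight)


gm = symbolic_trace(M())
subgraph_rewriter.replace_pattern(gm, pattern, replacement)  # rewrites gm in place
print(gm.code)  # the linear call is now wrapped in dtype casts
```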
@@ -194,9 +205,10 @@ def get_unquantized_node_set(gm: torch.fx.GraphModule):
 
 def transformation(gm: torch.fx.GraphModule, node_candidate_list: List[str], target_dtype: torch.dtype = torch.float16):
     """Convert the nodes in `node_candidate_list` to `target_dtype` if possible."""
-    for pattern_pair in HALF_PRECISION_PATTERN_REGISTRY[target_dtype].values():
-        apply_single_pattern_pair(gm, pattern_pair, node_candidate_list)
-    utils.logger.info("Half precision conversion is done:")
+    for pattern_pair_lst in HALF_PRECISION_PATTERN_REGISTRY[target_dtype].values():
+        for pattern_pair in pattern_pair_lst:
+            apply_single_pattern_pair(gm, pattern_pair, node_candidate_list)
+    utils.logger.info(f"Half precision conversion ({target_dtype}) completed.")
     if utils.level_name == "DEBUG":  # pragma: no cover
         gm.print_readable(True)
 
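A usage sketch for the updated entry point, assuming the registry has already been populated via `_register_pattern_pair` for `torch.float16`; the candidate list here is hypothetical and simply takes every `call_function` node from the traced graph, whereas in practice it comes from `get_half_precision_node_set`:

```python
import torch


class TinyModel(torch.nn.Module):
    def forward(self, x, w):
        return torch.matmul(x, w)


gm = torch.fx.symbolic_trace(TinyModel())
candidates = [n.name for n in gm.graph.nodes if n.op == "call_function"]
transformation(gm, candidates, target_dtype=torch.float16)
```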
@@ -249,5 +261,7 @@ def get_half_precision_node_set(gm, config):
     for node in possible_node_set:
         if node.target in SUPPORTED_OPERATORS:
             half_precision_node_set.add(node)
-    utils.logger.info(f"Found {len(half_precision_node_set)} nodes to convert to half precision.")
+    utils.logger.info(
+        f"Found {len(half_precision_node_set)} nodes to convert to half precision: {half_precision_node_set}"
+    )
     return half_precision_node_set