Commit 3f73725

Yantom1 authored and xinhe3 committed
[SW-197607] INC- change hard coded gaudi 2 scales for optimal weight quantization (#221)

* [SW-197607] INC- change hard coded gaudi 2 scales for optimal weight quantization
* cr fix
1 parent d870a56 commit 3f73725

File tree

2 files changed: +11 -4 lines changed

neural_compressor/torch/algorithms/fp8_quant/_core/fp_utils.py

Lines changed: 9 additions & 2 deletions
@@ -120,21 +120,28 @@ def get_fullscales_by_expbias_set(dtype, device, expbias_set):
     return [get_fullscale(dtype, device, exp_bias=eb) for eb in expbias_set]
 
 
-def get_fp8_hw_alligned_scales(dtype, device):
+def get_fp8_hw_alligned_scales_by_device(dtype, device):
+    if device not in [GAUDI2, GAUDI3]:
+        raise ValueError(
+            f"{device} is not supported"
+        )
     exp_bias_set = EXP_BIAS_SETS.get((device, dtype), None)
     return (
         None
         if exp_bias_set is None
         else [x / get_fullscale(dtype, device) for x in get_fullscales_by_expbias_set(dtype, device, exp_bias_set)]
     )
 
+def get_fp8_hw_alligned_scales(dtype):
+    inc_device_type = auto_detect_accelerator().get_inc_accelerator_type()
+    return get_fp8_hw_alligned_scales_by_device(dtype, inc_device_type)
 
 DEVICES_SCALE_FACTORS = {
     INCAcceleratorType.GAUDI2: 4,
     INCAcceleratorType.GAUDI3: 1,
 }
 FP8_143_SCALES = {
-    device: get_fp8_hw_alligned_scales(torch.float8_e4m3fn, device) for device in DEVICES_SCALE_FACTORS.keys()
+    device: get_fp8_hw_alligned_scales_by_device(torch.float8_e4m3fn, device) for device in DEVICES_SCALE_FACTORS.keys()
 }
 FP8_143_SCALES_TRAITS = {
     device: (
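
For context, a minimal sketch of the reworked call path, assuming an Intel Gaudi environment where neural_compressor's fp8_quant internals are importable (the snippet is illustrative and not part of the commit):

import torch
from neural_compressor.torch.algorithms.fp8_quant._core.fp_utils import (
    get_fp8_hw_alligned_scales,
    get_fp8_hw_alligned_scales_by_device,
)

# New wrapper: the accelerator type (Gaudi 2 or Gaudi 3) is detected
# internally via auto_detect_accelerator(), so no device argument is needed.
scales = get_fp8_hw_alligned_scales(torch.float8_e4m3fn)

# The renamed device-explicit variant still exists and now raises
# ValueError for devices other than Gaudi 2 / Gaudi 3:
# scales = get_fp8_hw_alligned_scales_by_device(torch.float8_e4m3fn, device)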

neural_compressor/torch/algorithms/fp8_quant/_core/scale.py

Lines changed: 2 additions & 2 deletions
@@ -17,6 +17,7 @@
 from ..model_configs import ModuleConfig, ModuleExtraConfig
 from .scale_methods import ops_quantizer
 from .._quant_common.quant_config import ScaleMethod
+from .fp_utils import get_fp8_hw_alligned_scales
 import torch
 
 
@@ -77,7 +78,6 @@ def prepare_layer_scales(mod, mod_name, config, mod_type_str, measurement, scale
     )
     return mod_extra_config, save_file
 
-
 scale_method_mapping = {
     (ScaleMethod.UNIT_SCALE, "maxabs"): "unit_scale",
     (ScaleMethod.UNIT_SCALE, "maxabs_per_channel"): "unit_scale",
@@ -158,7 +158,7 @@ def prepare_layer_scales(mod, mod_name, config, mod_type_str, measurement, scale
     "act_maxabs_pts_hw_weight_opt_pts_hw": {
         "input_backoff": 0.25,
         "weight_backoff": 0.5,
-        "weight_scales": [2.0**s for s in [4, 0, -4, -8]],
+        "weight_scales": get_fp8_hw_alligned_scales(torch.float8_e4m3fn)
     },
     "smoothquant_weights_maxabs_pow2": {
         "input_backoff": 0.25,
