
Commit 9e11d45

HolyFalafel authored and XuehaoSun committed
[SW-199696] Implementing dynamic quantization design for linear ops (#188)
* Implementing dynamic quantization design for linear ops
* Using copy_ to store scale as a member, added qdq, removed dyn
* Added PatchedLinearBase to support all linear modules
* Testing dynamic quantization with scale compare
* CR comments - calling cguid
* Added PatchedLinearBase
* Fixed PatchedLinear forward_qdq
* Changed quant strategy - scale to fix ci
* Renamed QuantStrategy to QuantWrapper
* Removed instance member from QuantWrapper
* [SW-224403] Added ticket and throwing error when using row_parallel_linear_allreduce_quantization
* Changed QuantWrapper to a simple method that stores scale
* [SW-224538] Added ticket to TODO comment for init_linear
* Pushed requires_grad to the tensor creation
* Fixed merge
* Fixed load() flow - handling meta tensors with dummy scale
* [SW-224609] removed non tested dynamic qdq
* Update neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
* Update neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
* Moved copy_scale functions inside PatchedLinearBase
* Added and fixed test cases
* Increased tolerance for new test cases
* Update helper_modules.py
* Update helper_modules.py
* Some tests/ci fixes
* Update neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
* Update helper_modules.py
* cr comments + cguid check change
* Update helper_modules.py
* Update helper_modules.py copy scale
* Update neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
* Maxabs design and some structure changes
* Merged MaxAbsDynamicPts To base + cguid comments
* changed cguid calls to functions
* Log changes
* Update neural_compressor/torch/algorithms/fp8_quant/model_configs.py
* Update neural_compressor/torch/algorithms/fp8_quant/model_configs.py
* Re-set self.scale_input as before, value is none in dynamic
* Changing back dynamic scale_input to intermediate and not member
* Disabling test_linear_dynamic_quantization: not storing scale as member
* Reintroduce MaxAbsDynamicPts: in dynamic we don't save scale as a member
* weight to hpu comment
1 parent 689575b commit 9e11d45

File tree

19 files changed (+419 −113 lines)


neural_compressor/torch/algorithms/fp8_quant/_core/common.py

Lines changed: 5 additions & 0 deletions

@@ -197,3 +197,8 @@ def get_device_type_for_scales(mod):
 @lru_cache
 def is_runtime_scale_patching():
     return os.getenv("RUNTIME_SCALE_PATCHING", "False").lower() in ["true", "1"]
+
+
+#TODO [SW-224612]: Use cguid to calc scales and remove the check
+@lru_cache
+def is_calc_scale_with_cguid():
+    return os.getenv("CALC_SCALE_WITH_CGUID", "False").lower() in ["true", "1"]

neural_compressor/torch/algorithms/fp8_quant/_core/fp_utils.py

Lines changed: 22 additions & 0 deletions

@@ -13,6 +13,7 @@
 # limitations under the License.

 import torch
+from enum import Enum
 from .common import ModuleConfig
 from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator, INCAcceleratorType
 cur_accelerator = auto_detect_accelerator()
@@ -21,6 +22,27 @@
 scale_fcn = lambda x, scale: torch.div(x, scale)
 cast_fcn = lambda x, dtype: x.to(dtype=dtype)
 cast_to_fp8_fcn = lambda x, dtype, scale_inv=None: torch.ops.hpu.cast_to_fp8_v2(x, scale_inv, False, False, dtype)[0]
+def calculate_scale_maxabs(x, maxMode, **kwargs):
+    return torch.ops.hpu.calculate_scale_for_cast(
+        x, maxMode.value, ScaleCalculationRoundingMode.NO_SCALE_ROUNDING.value, **kwargs
+    )
+
+
+def calculate_scale_rounding(x, scaleMode, **kwargs):
+    return torch.ops.hpu.calculate_scale_for_cast(
+        x, ScaleCalculationMaxMode.NO_MAX_CALCULATION.value, scaleMode.value, **kwargs
+    )
+
+
+class ScaleCalculationMaxMode(Enum):
+    NO_MAX_CALCULATION = 0
+    MAX_ABS_PTS_CALCULATION = 1
+    MAX_ABS_PCS_CALCULATION = 2
+
+
+class ScaleCalculationRoundingMode(Enum):
+    NO_SCALE_ROUNDING = 0
+    SCALE_TO_POW2_ROUNDING = 1

 GAUDI2 = INCAcceleratorType.GAUDI2
 GAUDI3 = INCAcceleratorType.GAUDI3
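A hedged usage sketch of the new helpers (our illustration, not from the commit; it requires the Habana HPU stack, and the tensor shape/placement are invented). Both wrappers funnel into the same CGUID op, torch.ops.hpu.calculate_scale_for_cast: calculate_scale_maxabs computes a max-abs scale with rounding disabled, while calculate_scale_rounding only rounds an already-computed scale:

import torch
from neural_compressor.torch.algorithms.fp8_quant._core.fp_utils import (
    ScaleCalculationMaxMode,
    ScaleCalculationRoundingMode,
    calculate_scale_maxabs,
    calculate_scale_rounding,
)

x = torch.randn(128, 128, device="hpu")
# Per-tensor max-abs scale; rounding mode is pinned to NO_SCALE_ROUNDING inside the wrapper.
scale = calculate_scale_maxabs(x, ScaleCalculationMaxMode.MAX_ABS_PTS_CALCULATION)
# Round an existing scale to a power of two; max mode is pinned to NO_MAX_CALCULATION.
scale_pow2 = calculate_scale_rounding(scale, ScaleCalculationRoundingMode.SCALE_TO_POW2_ROUNDING)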

neural_compressor/torch/algorithms/fp8_quant/_core/quant_dequant.py

Lines changed: 9 additions & 2 deletions

@@ -30,6 +30,7 @@
     dequantize_per_tensor_from_fp8,
     quantize_per_channel_to_fp8,
     dequantize_per_channel_from_fp8,
+    invert_scale,
 )
 from .scale_handler import create_scale_tensor

@@ -126,14 +127,20 @@ def __init__(self, input_scales_creator, lp_dtype, hp_dtype, *args, **kwargs):

         self.cast_to_op = get_quantized_func_wrapper(OP_TYPE.CAST_TO_FP8, self.scale_format)

-    def forward(self, x):
+    def calculate_scales(self, x):
         scale = self.input_scales_creator.calc_scales(x, QuantTensorType.DYNAMIC)
-        scale_inv = self.input_scales_creator.calc_invert_scales()
+        scale_inv = self.input_scales_creator.invert_scales(scale)
+        return scale, scale_inv
+
+    def forward(self, x):
+        scale, scale_inv = self.calculate_scales(x)

         ret = self.cast_to_op(x, scale_inv, False, False, self.lp_dtype)

         return ret, scale

+    #TODO [SW-224609]: implement forward qdq
+
     def extra_repr(self) -> str:
         repr = super(QuantDynamicInput, self).extra_repr()
         return f"{repr} input_scales_creator={self.input_scales_creator}"

neural_compressor/torch/algorithms/fp8_quant/_core/scale.py

Lines changed: 2 additions & 2 deletions

@@ -86,7 +86,7 @@ def prepare_layer_scales(mod, mod_name, config, mod_type_str, measurement, scale
     (ScaleMethod.MAXABS_HW, "maxabs"): "act_maxabs_pts_pow2_hw_weight_maxabs_pts_pow2_hw",
     (ScaleMethod.MAXABS_POW2, "maxabs"): "act_maxabs_pts_pow2_weight_maxabs_pts_pow2",
     (ScaleMethod.MAXABS_ARBITRARY, "maxabs"): "act_maxabs_pts_weight_maxabs_pts_arbitrary",
-    (ScaleMethod.MAXABS_POW2_DYNAMIC, "maxabs"): "act_maxabs_pcs_dyn_pow2_weight_maxabs_pts_pow2_hw",  # TODO: remove when changing config parsing
+    (ScaleMethod.ACT_MAXABS_PCS_POW2_WEIGHT_MAXABS_PTS_POW2_HW, "maxabs"): "act_maxabs_pcs_pow2_weight_maxabs_pts_pow2_hw",  # TODO: remove when changing config parsing
     (ScaleMethod.MAXABS_HW_OPT_WEIGHT, "maxabs"): "act_maxabs_pts_hw_weight_opt_pts_hw",
     (
         ScaleMethod.MAXABS_POW2_OPT_WEIGHT,
@@ -138,7 +138,7 @@ def prepare_layer_scales(mod, mod_name, config, mod_type_str, measurement, scale
         "input_backoff": 0.25,
         "weight_backoff": 0.5,
     },
-    "act_maxabs_pcs_dyn_pow2_weight_maxabs_pts_pow2_hw": {
+    "act_maxabs_pcs_pow2_weight_maxabs_pts_pow2_hw": {
         "input_backoff": 1.0,
         "weight_backoff": 0.5,
    },

neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/ops_quantizer.py

Lines changed: 29 additions & 12 deletions

@@ -14,10 +14,11 @@
 from abc import abstractmethod

 import torch
-from neural_compressor.torch.algorithms.fp8_quant._quant_common.quant_config import get_hqt_config
-from .scale_method_factory import ScaleMethodFactory, QuantTensorName
+from neural_compressor.torch.algorithms.fp8_quant._quant_common.quant_config import get_hqt_config, is_supported_dynamic_op
+from .scale_method_factory import ScaleMethodFactory, QuantTensorName, ScaleValueType
 from ..common import ModuleConfig, QuantTensorType
 from ..quant_dequant import DequantOutput, QuantDequant, QuantDequantNone, QuantInput, QuantDynamicInput
+from ...utils.logger import logger
 from neural_compressor.torch.algorithms.fp8_quant._core.common import dequant_original_fp8_weight_if_needed


@@ -31,6 +32,8 @@ def __init__(self, config, mod, measurement, params, op_type):
         self.inputs_scales_creators = []
         self.output_scales_creators = []
         self.params_scales_creators = []
+        self.is_dynamic = get_hqt_config(self.mod).cfg["dynamic_quantization"] and is_supported_dynamic_op(op_type)
+        logger.debug("%s %s", self.__class__.__name__, self.__dict__)

     def get_module_configuration(self):
         scale_format = get_hqt_config(self.mod).cfg["scale_format"]
@@ -60,14 +63,19 @@ def calc_input_scales(self, num_of_inputs):
         input_scales = []
         for i in range(num_of_inputs):
             input_measurement = self.measurement.inputs[i] if self.measurement is not None else []
-            input_scales.append(
-                self.inputs_scales_creators[i].calc_scales(input_measurement, QuantTensorType.MEASUREMENTS)
-            )
+            input_scale = None
+            if not self.is_dynamic:
+                input_scale = self.inputs_scales_creators[i].calc_scales(
+                    input_measurement, QuantTensorType.MEASUREMENTS
+                )
+            input_scales.append(input_scale)
         return input_scales

     def calc_output_scales(self):
         output_measurement = self.measurement.outputs[0] if self.measurement is not None else []
-        output_scales = self.output_scales_creators[0].calc_scales(output_measurement, QuantTensorType.MEASUREMENTS)
+        output_scales = None
+        if not self.is_dynamic:
+            output_scales = self.output_scales_creators[0].calc_scales(output_measurement, QuantTensorType.MEASUREMENTS)
         return (output_scales,)

     def init_input_config(self, scales_inv, lp_dtype, hp_dtype, scale_format, use_qdq, fake_quant):
@@ -79,7 +87,7 @@ def init_input_config(self, scales_inv, lp_dtype, hp_dtype, scale_format, use_qd
         else:
             input_config = []
             for input_scales_creator, s_inv in zip(self.inputs_scales_creators, scales_inv):
-                if input_scales_creator.is_dynamic:
+                if self.is_dynamic:
                     input_config.append(
                         QuantDynamicInput(input_scales_creator, lp_dtype, hp_dtype, scale_format=scale_format)
                     )
@@ -92,29 +100,38 @@ class LinearOpQuantizer(BaseOpQuantizer):

     def __init__(self, config, mod, measurement, params, module_type):
         super().__init__(config, mod, measurement, params, module_type)
-        self.inputs_scales_creators.append(self.scales_method_factory.get_scale_method(QuantTensorName.INPUT))
+        self.inputs_scales_creators.append(self.scales_method_factory.get_scale_method(QuantTensorName.INPUT, self.is_dynamic))
         self.weight_och_scale_calc = self.scales_method_factory.get_scale_method(QuantTensorName.WEIGHT_OUT_CH)
         self.weight_ich_scale_calc = self.scales_method_factory.get_scale_method(QuantTensorName.WEIGHT_IN_CH)
-        self.output_scales_creators.append(self.scales_method_factory.get_scale_method(QuantTensorName.OUTPUT))
+        self.output_scales_creators.append(self.scales_method_factory.get_scale_method(QuantTensorName.OUTPUT, self.is_dynamic))

     def get_scales_module_config(self):
         input_scales = self.calc_input_scales(num_of_inputs=1)
         output_measurement = self.measurement.outputs[0] if self.measurement is not None else []
         rescaled_weight = self.mod.weight if hasattr(self.mod, 'weight') else None
+        if (
+            self.scales_method_factory.scale_value_type_map[QuantTensorName.WEIGHT_IN_CH]
+            is not ScaleValueType.DUMMY_SCALES
+        ):
+            # Calculating weight in hpu to support scale calculation CGUID torch.ops.hpu.calculate_scale_for_cast
+            rescaled_weight = rescaled_weight.to("hpu")
         if rescaled_weight is not None:
             rescaled_weight = dequant_original_fp8_weight_if_needed(self.mod, rescaled_weight)
         if self.weight_ich_scale_calc is not None:
             weight_scales_in_ch = self.weight_ich_scale_calc.calc_scales(input_scales[0], QuantTensorType.CONST)
             rescaled_weight = torch.div(rescaled_weight, weight_scales_in_ch.reshape([1, -1]))
         weights_scales_out_ch = self.weight_och_scale_calc.calc_scales(rescaled_weight, QuantTensorType.CONST)
+
         params_config = (
             {"weight": weights_scales_out_ch}
             if (self.weight_ich_scale_calc is None)
             else {"weight": {0: weights_scales_out_ch, 1: weight_scales_in_ch}}
         )
-        output_scales = self.output_scales_creators[0].calc_scales(
-            output_measurement, QuantTensorType.MEASUREMENTS, input0=weights_scales_out_ch, input1=input_scales[0]
-        )
+        output_scales = None
+        if not self.is_dynamic:
+            output_scales = self.output_scales_creators[0].calc_scales(
+                output_measurement, QuantTensorType.MEASUREMENTS, input0=weights_scales_out_ch, input1=input_scales[0]
+            )
         return ModuleConfig(
             input_scales,
             (output_scales,),
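To make the new gating concrete: an op is treated as dynamic only when the global config flag is on and the op type is supported, in which case the measured input/output scales are left as None and recomputed per call by QuantDynamicInput. A toy stand-in (the helper name and the exact supported set below are hypothetical; the diff delegates the check to is_supported_dynamic_op, and this commit targets linear ops):

SUPPORTED_DYNAMIC_OPS = {"linear"}  # per the commit title, only linear ops so far

def _is_dynamic(cfg, op_type):
    # Mirrors self.is_dynamic above: the global flag AND per-op support must both hold.
    return cfg.get("dynamic_quantization", False) and op_type in SUPPORTED_DYNAMIC_OPS

print(_is_dynamic({"dynamic_quantization": True}, "linear"))   # True
print(_is_dynamic({"dynamic_quantization": True}, "matmul"))   # False: falls back to static scales
print(_is_dynamic({"dynamic_quantization": False}, "linear"))  # False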

neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/round_scales_function.py

Lines changed: 13 additions & 1 deletion

@@ -13,10 +13,22 @@
 # limitations under the License.
 import torch

-from neural_compressor.torch.algorithms.fp8_quant._core.fp_utils import FP8_143_SCALES, FP8_143_SCALES_TRAITS
+from neural_compressor.torch.algorithms.fp8_quant._core.fp_utils import FP8_143_SCALES, FP8_143_SCALES_TRAITS, calculate_scale_rounding, ScaleCalculationRoundingMode
+#TODO [SW-224612]: Use cguid to calc scales and remove check
+from ..common import is_calc_scale_with_cguid


 class ScaleToPow2:
+    def __init__(self):
+        #TODO [SW-224612]: Use cguid to calc scales and remove check
+        if is_calc_scale_with_cguid():
+            self.calc = self.calc_with_cguid
+
+    #TODO [SW-224612]: Use cguid to calc scales and remove special function
+    def calc_with_cguid(self, scale):
+        scale_pow2 = calculate_scale_rounding(scale, ScaleCalculationRoundingMode.SCALE_TO_POW2_ROUNDING)
+        return scale_pow2
+
     def calc(self, scale):
         scale_pow2 = 2.0 ** torch.ceil(torch.log2(scale))
         return scale_pow2
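A worked example of the eager (non-CGUID) calc path above: scales are rounded up to the next power of two, and since scales divide the tensor in this codebase, rounding up only adds headroom. A runnable check (our illustration):

import torch

def scale_to_pow2(scale):
    # Same formula as ScaleToPow2.calc above.
    return 2.0 ** torch.ceil(torch.log2(scale))

print(scale_to_pow2(torch.tensor([3.0, 0.7, 8.0])))
# tensor([4., 1., 8.])  -- 3.0 rounds up to 4, 0.7 up to 1, exact powers of two stay put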

neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/scale_method_factory.py

Lines changed: 17 additions & 11 deletions

@@ -14,8 +14,6 @@
 from enum import Enum, auto

 from .round_scales_function import *
-# TODO [SW-217813]: support dynamic quantization in all ops and remove supported_dynamic_ops
-from ..._quant_common.quant_config import is_supported_dynamic_op
 from ..common import get_device_type_for_scales
 from .scales_method import *
 from ...utils.logger import logger
@@ -61,6 +59,7 @@ def parse_tensor_granularity(config):
     scale_granularity = ScaleGranularity.PTS
     if "pcs" in config or "smoothquant" in config:
         scale_granularity = ScaleGranularity.PCS
+    logger.trace("parse_tensor_granularity %s %s", config, scale_granularity)
     return scale_granularity

 # TODO [SW-217813]: support dynamic quantization in all ops and remove op_type
@@ -78,8 +77,6 @@ def parse_tensor_scale_value_type(config, op_type):
         scale_value_type = ScaleValueType.OPT
     elif "dummy" in config:
         scale_value_type = ScaleValueType.DUMMY_SCALES
-    elif "dyn" in config and is_supported_dynamic_op(op_type):
-        scale_value_type = ScaleValueType.DYNAMIC
     logger.trace(f"parse_tensor_scale_value_type {config=} {scale_value_type=}")
     return scale_value_type

@@ -121,17 +118,26 @@ def __init__(self, config, params, mod, op_type):
             QuantTensorName.WEIGHT_IN_CH: self.params.get("weight_backoff", 1.0),
             QuantTensorName.WEIGHT_OUT_CH: self.params.get("weight_backoff", 1.0),
             QuantTensorName.OUTPUT: self.params.get("output_backoff", self.params.get("input_backoff", 1.0)),}  # get output_backoff; if it doesn't exist, use input_backoff; if that doesn't exist, use 1
-        logger.debug("%s %s".format(self.__class__.__name__, self.__dict__))
+        logger.trace("%s %s", self.__class__.__name__, self.__dict__)

     ## TODO remove after SW-217369
     ## config string example: "act_maxabs_pts_weight_opt_pts_hw", round_method = pow2_hw, scale_value_type = maxabs, granularity = pts
     # all config strings in scale.py: scale_method_mapping
     # returns MaxAbsPts obj with pow2_hw as scale_round_method
-    def get_scale_method(self, tensor_name):
-        backoff = self.scale_backoff_map[tensor_name]
+    def get_scale_method(self, tensor_name, is_dynamic=False):
+        backoff = 1.0 if is_dynamic else self.scale_backoff_map[tensor_name]
         scale_round_method = self.scale_round_method_map[tensor_name]
         scale_value_type = self.scale_value_type_map[tensor_name]
         scale_granularity = self.scale_granularity_map[tensor_name]
+        logger.trace(
+            "get_scale_method backoff=%s scale_round_method=%s scale_value_type=%s scale_granularity=%s op_type=%s is_dynamic=%s",
+            backoff,
+            scale_round_method,
+            scale_value_type,
+            scale_granularity,
+            self.op_type,
+            is_dynamic,
+        )

         match (scale_value_type, scale_granularity, tensor_name, self.op_type):
             ## dummy
@@ -145,13 +151,13 @@ def get_scale_method(self, tensor_name):
                 if self.op_type in {"linear", "matmul"}:
                     if scale_value_type in {ScaleValueType.MAXABS, ScaleValueType.OPT}:
                         return MulAdditionalScales(scale_round_method, self.params, self.device_for_scales)
-                    if scale_value_type == ScaleValueType.DYNAMIC:
-                        return MulAdditionalDynamicScales(scale_round_method, self.params, self.device_for_scales)
             ## maxabs/opt in channel PTS
             case (_, ScaleGranularity.PTS, QuantTensorName.WEIGHT_IN_CH, _) \
                 if scale_value_type not in {ScaleValueType.SMOOTHQUANT_OPT, ScaleValueType.SMOOTHQUANT_MAXABS}:
                 return None
             case (ScaleValueType.MAXABS, ScaleGranularity.PTS, _, _):
+                if is_dynamic:
+                    return MaxAbsDynamicPts(scale_round_method, self.params, self.device_for_scales, backoff)
                 return MaxAbsPts(scale_round_method, self.params, self.device_for_scales, backoff)
             ## maxabs/opt in channel PCS
             case (_, ScaleGranularity.PCS, QuantTensorName.WEIGHT_IN_CH, _)\
@@ -160,6 +166,8 @@ def get_scale_method(self, tensor_name):
                 return InputChannelScale(scale_round_method, self.params, self.device_for_scales, in_channel_size)
             ## maxabs PCS
             case (ScaleValueType.MAXABS, ScaleGranularity.PCS, _, _):
+                if is_dynamic:
+                    return MaxAbsDynamicPcs(scale_round_method, self.params, self.device_for_scales, backoff)
                 return MaxAbsPcs(scale_round_method, self.params, self.device_for_scales, backoff)
             ## opt PTS
             case (ScaleValueType.OPT, ScaleGranularity.PTS, _, _):
@@ -188,8 +196,6 @@ def get_scale_method(self, tensor_name):
             case (ScaleValueType.SMOOTHQUANT_OPT, _, QuantTensorName.INPUT, _):
                 backoff_weight = self.params.get("weight_backoff", 1)
                 return InputSmoothQuantOpt(scale_round_method, self.mod.weight, self.params, self.device_for_scales, backoff, backoff_weight)
-            case (ScaleValueType.DYNAMIC, ScaleGranularity.PCS, QuantTensorName.INPUT, _):
-                return MaxAbsDynamicPcs(scale_round_method, self.params, self.device_for_scales, backoff)
             case _:
                 raise NotImplementedError("the config: scale_round_method: " + \
                     str(scale_round_method) +
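The structural change in this file: dynamic is no longer a separate ScaleValueType matched by dedicated case arms; it is a flag threaded into get_scale_method that selects the dynamic variant inside the existing MAXABS arms. A condensed, runnable sketch of that dispatch shape (toy enums and string results standing in for the real scale-method classes; requires Python 3.10+ for match):

from enum import Enum, auto

class ScaleValueType(Enum):
    MAXABS = auto()

class ScaleGranularity(Enum):
    PTS = auto()  # per-tensor scale
    PCS = auto()  # per-channel scale

def get_scale_method(value_type, granularity, is_dynamic=False):
    match (value_type, granularity):
        case (ScaleValueType.MAXABS, ScaleGranularity.PTS):
            return "MaxAbsDynamicPts" if is_dynamic else "MaxAbsPts"
        case (ScaleValueType.MAXABS, ScaleGranularity.PCS):
            return "MaxAbsDynamicPcs" if is_dynamic else "MaxAbsPcs"
        case _:
            raise NotImplementedError

print(get_scale_method(ScaleValueType.MAXABS, ScaleGranularity.PTS, is_dynamic=True))
# MaxAbsDynamicPts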
