
Commit eb1569e

Yantom1 authored and xinhe3 committed
[SW-230641] Remove smoothquant related scale methods (#258)
1 parent 4bd3385 commit eb1569e

File tree

10 files changed: +7 -119 lines changed

neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/scale_method_config.py

Lines changed: 0 additions & 21 deletions
@@ -24,13 +24,10 @@ class ScaleMethodString(Enum):
     HW_ALIGNED_SINGLE_SCALE = auto()
     MAXABS_HW = auto()
     MAXABS_POW2 = auto()
-    SMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2 = auto()
-    WEAKSMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2 = auto()
     ACT_MAXABS_HW_WEIGHTS_PCS_MAXABS_POW2 = auto()
     ACT_MAXABS_HW_WEIGHTS_PCS_OPT_POW2 = auto()
     ACT_MAXABS_POW2_WEIGHTS_PCS_MAXABS_POW2 = auto()
     ACT_MAXABS_POW2_WEIGHTS_PCS_OPT_POW2 = auto()
-    SMOOTHQUANT_OPT = auto()
     MAXABS_HW_OPT_WEIGHT = auto()
     MAXABS_POW2_OPT_WEIGHT = auto()
     MAXABS_ARBITRARY = auto()
@@ -44,9 +41,6 @@ class ScaleValueType(Enum):
     MAXABS = auto()
     FIXED_VALUE = auto()
     OPT = auto()
-    SMOOTHQUANT_MAXABS = auto()
-    SMOOTHQUANT_OPT = auto()
-    SMOOTHQUANT_WEAK = auto()
     DUMMY_SCALES = auto()
 
 class ScaleRoundMethod(Enum):
@@ -150,16 +144,6 @@ def __eq__(self, other):
         CfgStr.WEIGHT: ScaleMethodConfig(granularity= ScaleGranularity.PCS, rounding_method= ScaleRoundMethod.POW2, backoff= 0.5),
         CfgStr.ACTIVATION: ScaleMethodConfig(rounding_method= ScaleRoundMethod.HW_ALIGNED, backoff= 0.25)
     },
-    ScaleMethodString.SMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2:
-    {
-        CfgStr.WEIGHT: ScaleMethodConfig(scale_value_type = ScaleValueType.SMOOTHQUANT_MAXABS, granularity= ScaleGranularity.PCS, rounding_method= ScaleRoundMethod.POW2, backoff= 0.5),
-        CfgStr.ACTIVATION: ScaleMethodConfig(scale_value_type = ScaleValueType.SMOOTHQUANT_MAXABS, granularity= ScaleGranularity.PCS, rounding_method= ScaleRoundMethod.POW2, backoff= 0.25, params={"alpha": 0.5})
-    },
-    ScaleMethodString.WEAKSMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2:
-    {
-        CfgStr.WEIGHT: ScaleMethodConfig(scale_value_type = ScaleValueType.SMOOTHQUANT_WEAK, granularity= ScaleGranularity.PCS, rounding_method= ScaleRoundMethod.POW2, backoff= 0.5),
-        CfgStr.ACTIVATION: ScaleMethodConfig(scale_value_type = ScaleValueType.SMOOTHQUANT_WEAK, granularity= ScaleGranularity.PCS, rounding_method= ScaleRoundMethod.POW2, backoff= 0.25, params={"alpha": 0.5})
-    },
     ScaleMethodString.ACT_MAXABS_HW_WEIGHTS_PCS_OPT_POW2:
     {
         CfgStr.WEIGHT: ScaleMethodConfig(scale_value_type = ScaleValueType.OPT, granularity= ScaleGranularity.PCS, rounding_method= ScaleRoundMethod.POW2, backoff= 0.5, params={"weight_scales": [2.0**s for s in range(-3, 5)]}),
@@ -175,11 +159,6 @@ def __eq__(self, other):
         CfgStr.WEIGHT: ScaleMethodConfig(scale_value_type = ScaleValueType.OPT, granularity= ScaleGranularity.PCS, rounding_method= ScaleRoundMethod.POW2, backoff= 0.5, params={"weight_scales": [2.0**s for s in range(-3, 5)]}),
         CfgStr.ACTIVATION: ScaleMethodConfig(rounding_method= ScaleRoundMethod.POW2, backoff= 0.25)
     },
-    ScaleMethodString.SMOOTHQUANT_OPT:
-    {
-        CfgStr.WEIGHT: ScaleMethodConfig(scale_value_type = ScaleValueType.SMOOTHQUANT_OPT, granularity= ScaleGranularity.PCS, rounding_method= ScaleRoundMethod.POW2, backoff= 0.5, params={"transformed_weight_scales": [2.0**s for s in range(-3, 5)]}),
-        CfgStr.ACTIVATION: ScaleMethodConfig(scale_value_type = ScaleValueType.SMOOTHQUANT_OPT, granularity= ScaleGranularity.PCS, rounding_method= ScaleRoundMethod.POW2, backoff= 0.25, params={"alpha": 0.5})
-    },
 }
 
 reverse_scale_method_mapping = {
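
Because the three SMOOTHQUANT members are deleted from ScaleMethodString outright, a config that still names one now fails at the enum lookup itself. A toy sketch of that failure mode, using a hypothetical mini-enum rather than the real class:

from enum import Enum, auto

class MiniScaleMethod(Enum):
    # Surviving members, mirroring the trimmed ScaleMethodString.
    MAXABS_HW = auto()
    MAXABS_POW2 = auto()

print(MiniScaleMethod["MAXABS_HW"])      # MiniScaleMethod.MAXABS_HW
try:
    MiniScaleMethod["SMOOTHQUANT_OPT"]   # removed name -> KeyError
except KeyError as err:
    print("unknown scale method:", err)

This is also why the SCALE_METHODS_KEY_ERROR xfail list disappears from the tests further down: the names it guarded can no longer be selected.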

neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/scale_method_factory.py

Lines changed: 1 addition & 24 deletions
@@ -87,8 +87,7 @@ def get_scale_method(self, tensor_type, is_dynamic=False):
                 if scale_value_type in {ScaleValueType.MAXABS, ScaleValueType.OPT}:
                     return MulAdditionalScales(scale_round_method, self.params, self.device_for_scales)
             ## maxabs/opt in channel PTS
-            case (_, ScaleGranularity.PTS, QuantTensorName.WEIGHT_IN_CH, _) \
-                if scale_value_type not in {ScaleValueType.SMOOTHQUANT_OPT, ScaleValueType.SMOOTHQUANT_MAXABS}:
+            case (_, ScaleGranularity.PTS, QuantTensorName.WEIGHT_IN_CH, _):
                 return None
             case (ScaleValueType.MAXABS, ScaleGranularity.PTS, _, _):
                 if is_dynamic:
@@ -111,28 +110,6 @@ def get_scale_method(self, tensor_type, is_dynamic=False):
             case (ScaleValueType.OPT, ScaleGranularity.PCS, _, _):
                 opt_list_of_scales = self.scale_method_config_map[tensor_type].params["weight_scales"]
                 return OptScalesPcs(scale_round_method, opt_list_of_scales, self.params, self.device_for_scales, backoff)
-            ## smooth quant
-            case (_, ScaleGranularity.PCS, QuantTensorName.WEIGHT_IN_CH, _) \
-                if scale_value_type in {ScaleValueType.SMOOTHQUANT_OPT, ScaleValueType.SMOOTHQUANT_MAXABS}:
-                return WeightIchSmoothQuant(scale_round_method, self.params, self.device_for_scales)
-            case (_, ScaleGranularity.PCS, QuantTensorName.OUTPUT, _) \
-                if scale_value_type in {ScaleValueType.SMOOTHQUANT_OPT, ScaleValueType.SMOOTHQUANT_MAXABS} \
-                and self.op_type in {"linear", "matmul"}:
-                return UseFirstAdditionalScales(scale_round_method, self.params, self.device_for_scales)
-            ## SMOOTHQUANT_MAXABS input and weight out channel
-            case (ScaleValueType.SMOOTHQUANT_MAXABS, ScaleGranularity.PCS, QuantTensorName.WEIGHT_OUT_CH, _):
-                return MaxAbsPcs(scale_round_method, self.params, self.device_for_scales, backoff)
-            case (ScaleValueType.SMOOTHQUANT_MAXABS, ScaleGranularity.PCS, QuantTensorName.INPUT, _):
-                alpha = self.scale_method_config_map[QuantTensorName.INPUT].params["alpha"]
-                return InputSmoothQuantMaxAbs(scale_round_method, self.mod.weight, self.params, self.device_for_scales, backoff, alpha)
-            ## SMOOTHQUANT_OPT input and weight out channel
-            case (ScaleValueType.SMOOTHQUANT_OPT, _, QuantTensorName.WEIGHT_OUT_CH, _):
-                opt_list_of_scales = self.scale_method_config_map[tensor_type].params["transformed_weight_scales"]
-                return OptScalesPcs(scale_round_method, opt_list_of_scales, self.params, self.device_for_scales, backoff)
-            case (ScaleValueType.SMOOTHQUANT_OPT, _, QuantTensorName.INPUT, _):
-                backoff_weight = self.scale_method_config_map[QuantTensorName.WEIGHT_OUT_CH].backoff
-                alpha = self.scale_method_config_map[QuantTensorName.INPUT].params["alpha"]
-                return InputSmoothQuantOpt(scale_round_method, self.mod.weight, self.params, self.device_for_scales, backoff, backoff_weight, alpha)
             case _:
                 raise NotImplementedError("the config: scale_round_method: " + \
                     str(scale_round_method) +
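
The one-line replacement above is safe because the guard it drops had become vacuous: with SMOOTHQUANT_OPT and SMOOTHQUANT_MAXABS gone from ScaleValueType, the condition scale_value_type not in {...} is always true, so the PTS weight-in-channel case can match unconditionally. A minimal, self-contained illustration of collapsing a guarded case (hypothetical names, Python 3.10+ match semantics):

from enum import Enum, auto

class Granularity(Enum):
    PTS = auto()
    PCS = auto()

def route(value_type: str, granularity: Granularity):
    match (value_type, granularity):
        # Before: case (_, Granularity.PTS) if value_type != "smoothquant":
        # Once "smoothquant" can no longer occur, the guard excludes nothing,
        # so it can be dropped and the pattern matches unconditionally.
        case (_, Granularity.PTS):
            return None
        case _:
            raise NotImplementedError(f"no scale method for {value_type}")

assert route("maxabs", Granularity.PTS) is None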

neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/scales_method.py

Lines changed: 0 additions & 54 deletions
@@ -225,60 +225,6 @@ def calc_scales(self, tensor, tensor_type, **additional_kwargs):
         return self.scale
 
 
-class InputSmoothQuantMaxAbs(ScalesMethod):
-    def __init__(self, round_scale_method, weight, params, device_for_scales, backoff, alpha):
-        super().__init__(round_scale_method, params, device_for_scales)
-        self.round_scale_method = round_scale_method
-        self.weight = weight
-        self.alpha = alpha
-        self.backoff = backoff
-        self.device_for_scales = device_for_scales
-
-    def calc_scales(self, tensor, tensor_type, **additional_kwargs):
-        weight_scale_in_ch = MaxAbsPcs(ScaleIdentity(), self.params, self.device_for_scales, 1.0, 1.0, dim=0).calc_scales(
-            self.weight, QuantTensorType.CONST)
-        input_range = torch.tensor(tensor, dtype=self.hp_dtype, device=self.device)
-        input_scale = MaxAbsPts(ScaleIdentity(), self.params, self.device_for_scales, 1.0, 1.0).calc_scales(tensor,
-            QuantTensorType.MEASUREMENTS)
-        input_scale = (input_scale ** self.alpha) / (weight_scale_in_ch ** (1 - self.alpha))
-        input_scale = self.round_scale_method.calc(input_scale)
-        input_range_post = input_range / input_scale
-        input_scale_post = calc_maxabs_scale(input_range_post.max(), self.fullscale, self.backoff)
-        input_scale_post = self.round_scale_method.calc(input_scale_post)
-        input_scale = input_scale * input_scale_post
-        self.scale = input_scale
-        return self.scale
-
-class InputSmoothQuantOpt(ScalesMethod):
-    def __init__(self, round_scale_method, weight, params, device_for_scales, backoff, backoff_weight, alpha):
-        super().__init__(round_scale_method, params, device_for_scales)
-        self.round_scale_method = round_scale_method
-        self.weight = weight
-        self.alpha = alpha
-        self.backoff = backoff
-        self.backoff_weight = backoff_weight
-        self.device_for_scales = device_for_scales
-
-    def calc_scales(self, tensor, tensor_type, **additional_kwargs):
-        weight_scale_in_ch = MaxAbsPcs(ScaleIdentity(), self.params, self.device_for_scales, self.backoff_weight,
-            self.fullscale, dim=0).calc_scales(self.weight, QuantTensorType.CONST)
-        input_scale = MaxAbsPts(ScaleIdentity(), self.params, self.device_for_scales, self.backoff,
-            self.fullscale).calc_scales(tensor, QuantTensorType.MEASUREMENTS)
-        input_scale = (input_scale ** self.alpha) / (weight_scale_in_ch ** (1 - self.alpha))
-        input_scale = self.round_scale_method.calc(input_scale)
-        self.scale = input_scale
-        return self.scale
-
-
-class WeightIchSmoothQuant(ScalesMethod):
-    def __init__(self, round_scale_method, params, device_for_scales):
-        super().__init__(round_scale_method, params, device_for_scales)
-
-    def calc_scales(self, tensor, tensor_type, **additional_kwargs):
-        self.scale = 1 / tensor
-        return self.scale
-
-
 class MaxAbsDynamicPcs(MaxAbsPcs):
 
     def __init__(self, round_scale_method, params, device_for_scales, backoff, fullscale=None):
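
For the record, the deleted classes implemented SmoothQuant's range-balancing rule: the input scale blends the measured activation max-abs against the per-input-channel weight max-abs through an alpha exponent (alpha was 0.5 in the removed configs), i.e. s = max|x|**alpha / max|w|**(1 - alpha). A standalone sketch of just that formula, with made-up ranges and none of the MaxAbsPcs/MaxAbsPts/rounding plumbing from the removed code:

import torch

def smoothquant_input_scale(x_absmax: torch.Tensor,
                            w_absmax_in_ch: torch.Tensor,
                            alpha: float = 0.5) -> torch.Tensor:
    # Per input channel: s = |x|_max**alpha / |w|_max**(1 - alpha).
    # Channels with large activation outliers get scaled down, shifting
    # quantization difficulty from activations into the weights.
    return (x_absmax ** alpha) / (w_absmax_in_ch ** (1 - alpha))

x_absmax = torch.tensor([8.0, 0.5, 2.0, 16.0])     # hypothetical measured activation ranges
w_absmax = torch.tensor([0.25, 1.0, 0.5, 0.125])   # hypothetical weight ranges
print(smoothquant_input_scale(x_absmax, w_absmax))
# tensor([ 5.6569,  0.7071,  2.0000, 11.3137])

The removed WeightIchSmoothQuant then returned the reciprocal (scale = 1 / tensor) on the weight input channels, presumably so the combined input/weight scaling is output-preserving, as in the SmoothQuant formulation.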

test/3x/torch/algorithms/fp8_quant/tester.py

Lines changed: 1 addition & 5 deletions
@@ -48,11 +48,7 @@
     ScaleMethodString.MAXABS_HW_OPT_WEIGHT,
     ScaleMethodString.MAXABS_POW2_OPT_WEIGHT,
 ]
-SCALE_METHODS_KEY_ERROR = [
-    ScaleMethodString.SMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2,
-    ScaleMethodString.WEAKSMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2,
-    ScaleMethodString.SMOOTHQUANT_OPT,
-]
+
 SCALE_METHODS_COMPILATION_ERROR = [
     ScaleMethodString.ACT_MAXABS_HW_WEIGHTS_PCS_MAXABS_POW2,
     ScaleMethodString.ACT_MAXABS_POW2_WEIGHTS_PCS_MAXABS_POW2,
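
With the SmoothQuant methods gone from the enum, an xfail list for them has nothing left to guard, so it is deleted rather than left empty. The gating pattern the remaining lists drive in the layer tests below looks roughly like this (a sketch; the list contents here are illustrative, not the full set):

import pytest

SCALE_METHODS_COMPILATION_ERROR = ["ACT_MAXABS_HW_WEIGHTS_PCS_MAXABS_POW2"]  # illustrative

def check_tests_to_skip(scale_method):
    # Same shape as the guards in test_conv2d.py / test_linear.py below:
    # known-bad configurations are marked xfail instead of hard-failing.
    if scale_method in SCALE_METHODS_COMPILATION_ERROR:
        pytest.xfail("Graph compile error")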

test/3x/torch/algorithms/fp8_quant/unit_tests/test_functions/test_config_json.py

Lines changed: 3 additions & 4 deletions
@@ -9,7 +9,7 @@
 from neural_compressor.torch.algorithms.fp8_quant._quant_common.helper_modules import Matmul
 from neural_compressor.torch.algorithms.fp8_quant._quant_common.quant_config import QuantMode
 from neural_compressor.torch.algorithms.fp8_quant._core.scale_methods.scale_method_config import ScaleMethodString
-from ...tester import run_with_raised_exception, get_internal_config, SCALE_METHODS_QUANT_ONLY, SCALE_METHODS_KEY_ERROR
+from ...tester import run_with_raised_exception, get_internal_config, SCALE_METHODS_QUANT_ONLY
 from ...test_hpu_utils import *
 
 class Model(torch.nn.Module):
@@ -49,9 +49,8 @@ def run_predefined_config():
         prepare_model._prep_model_with_predefined_config(model, config=config)
         fp8_quant.finish_measurements(model)
 
-    if scale_method in SCALE_METHODS_KEY_ERROR and quant_mode == QuantMode.QUANTIZE:
-        pytest.xfail("KeyError")
-    elif scale_method == ScaleMethodString.ACT_MAXABS_PCS_POW2_WEIGHT_MAXABS_PTS_POW2_HW:
+
+    if scale_method == ScaleMethodString.ACT_MAXABS_PCS_POW2_WEIGHT_MAXABS_PTS_POW2_HW:
         return run_with_raised_exception(run_predefined_config, ValueError, "Unsupported config: scale_method")
     # This is an expected exception, as test is not measuring before
     elif scale_method not in SCALE_METHODS_QUANT_ONLY:

test/3x/torch/algorithms/fp8_quant/unit_tests/test_layers/test_conv2d.py

Lines changed: 0 additions & 2 deletions
@@ -26,8 +26,6 @@ def test_conv2d_accuracy(hp_dtype: torch.dtype, lp_dtype: torch.dtype, scale_met
     # TODO [SW-196641]: fix the following issues:
     if scale_method in SCALE_METHODS_SEGFAULT:
         pytest.skip("Not supported")
-    if scale_method in SCALE_METHODS_KEY_ERROR:
-        pytest.xfail("KeyError")
     if scale_method in SCALE_METHODS_COMPILATION_ERROR:
         pytest.xfail("Graph compile error")
     quant_modes = QUANT_MODES_DEFAULT

test/3x/torch/algorithms/fp8_quant/unit_tests/test_layers/test_linear.py

Lines changed: 0 additions & 2 deletions
@@ -39,8 +39,6 @@ def get_test_vectors(*, dtype: torch.dtype, N: int, D_in: int, atol: float = 0.0
     )
 
 def check_tests_to_skip(scale_method, scale_format, dynamic_quantization, device_type = None):
-    if scale_method in SCALE_METHODS_KEY_ERROR:
-        pytest.xfail("KeyError")
     # TODO [SW-215692]: Fix segfault
     if scale_format == ScaleFormat.CONST or dynamic_quantization:
         if scale_method in [ScaleMethodString.MAXABS_HW_OPT_WEIGHT, ScaleMethodString.MAXABS_POW2_OPT_WEIGHT]:

test/3x/torch/algorithms/fp8_quant/unit_tests/test_layers/test_matmul.py

Lines changed: 0 additions & 3 deletions
@@ -55,9 +55,6 @@ def forward(self, x, y):
 @pytest.mark.parametrize("device_type", device_type)
 @pytest.mark.parametrize("dynamic_quantization", [True, False], ids=["dynamic_quantization", "static_quantization"])
 def test_matmul_accuracy(hp_dtype: torch.dtype, lp_dtype: torch.dtype, scale_method: ScaleMethodString, device_type: str, dynamic_quantization: bool):
-    # TODO [SW-196641]: fix the following issues:
-    if scale_method in SCALE_METHODS_KEY_ERROR:
-        pytest.xfail("KeyError")
     quant_modes = QUANT_MODES_DEFAULT
     atol = 0.2
     if scale_method in SCALE_METHODS_QUANT_ONLY or dynamic_quantization:

test/3x/torch/algorithms/fp8_quant/unit_tests/test_runtime_scale_patching.py

Lines changed: 1 addition & 3 deletions
@@ -6,7 +6,7 @@
 import habana_frameworks.torch.core as htcore
 import habana_frameworks.torch.utils.experimental as htexp
 
-from ..tester import RUNTIME_SCALE_PATCHING_SUPPORTED_METHODS_LIST, SCALE_METHODS_KEY_ERROR, run_with_raised_exception
+from ..tester import RUNTIME_SCALE_PATCHING_SUPPORTED_METHODS_LIST, run_with_raised_exception
 from neural_compressor.torch.algorithms.fp8_quant._core.common import is_runtime_scale_patching
 from neural_compressor.torch.algorithms.fp8_quant._quant_common.quant_config import ScaleMethodString
 from neural_compressor.torch.quantization import FP8Config, convert, prepare, finalize_calibration
@@ -52,8 +52,6 @@ def temp_directory():
 @pytest.mark.parametrize("scale_format", ["SCALAR", "CONST"])
 @pytest.mark.parametrize("dynamic_scale_patching", [True, False])
 def test_no_assert(scale_method, scale_format, dynamic_scale_patching, temp_directory):
-    if scale_method in SCALE_METHODS_KEY_ERROR:
-        pytest.xfail("KeyError")
     model = TinyModel()
     model.eval()
     model = model.to("hpu").to(torch.bfloat16)

test/3x/torch/algorithms/fp8_quant/unit_tests/test_scale_method_config.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ def forward(self, x):
 def check_tests_to_skip(scale_method, scale_value_type_weight = None, scale_value_type_activation = None):
     if scale_value_type_weight == ScaleValueType.DUMMY_SCALES or scale_value_type_activation == ScaleValueType.DUMMY_SCALES:
         pytest.xfail("Dummy scales is not a scale method")
-    if scale_method in SCALE_METHODS_KEY_ERROR or scale_method in SUPPORTED_DYNAMIC_SCALES:
+    if scale_method in SUPPORTED_DYNAMIC_SCALES:
         pytest.xfail("Key error")
 
 @pytest.mark.parametrize("scale_granularity_weight", ScaleGranularity)
