
Commit 2b02536

nirda7 authored and XuehaoSun committed

[SW-214855] - Set scale attributes in INC to reduce graph recompilations (#162)
* [SW-219831] - Set scale attributes in INC to reduce graph recompilation
* add scaling methods ids
* fix scaling method ids check and set
* enable feature also for Load QuantMode
* move scale tensors to cpu when feature is enabled
* fix scaling methods ids to start at 1
* fix cr comments
* remove unnecessary imports
* fix cr comments
* fix more cr comments
* fix cr comments
* move scale to float on cpu in scale handler for dynamic scaling
* fix cr comments
* Add unit test
* fix sending scale tensor to bridge and unit-test bug
1 parent b959f1a commit 2b02536
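
In short, the commit gates a "runtime scale patching" mode behind the RUNTIME_SCALE_PATCHING environment variable; when the mode is active, INC keeps scales as patchable float32 CPU tensors and reports the scaling method's identity to the Habana bridge, so that scale changes cause fewer graph recompilations. A minimal sketch of enabling the mode around the public flow, mirroring the unit test added at the end of this commit (MyHpuModel is a placeholder; MAXABS_HW is one of the supported scale methods):

import os
import copy
import torch
from neural_compressor.torch.quantization import FP8Config, prepare, convert, finalize_calibration

os.environ["RUNTIME_SCALE_PATCHING"] = "1"   # feature flag introduced by this commit

model = MyHpuModel().eval().to("hpu").to(torch.bfloat16)   # placeholder model
inference_model = copy.deepcopy(model)

# Calibration pass with the maxabs observer.
measure_cfg = FP8Config.from_dict({"mode": "MEASURE", "observer": "maxabs",
                                   "dump_stats_path": "./inc_output"})
model = prepare(model, measure_cfg)
model(torch.ones(1, 2, dtype=torch.bfloat16, device="hpu"))   # placeholder calibration input
finalize_calibration(model)

# FP8 conversion with a scale method that supports runtime scale patching.
quant_cfg = FP8Config.from_dict({"mode": "QUANTIZE", "scale_method": "MAXABS_HW",
                                 "dump_stats_path": "./inc_output"})
inference_model = convert(inference_model, quant_cfg)   # set_runtime_scale_patching_mode() runs during conversion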

10 files changed, +147 −6 lines changed


neural_compressor/torch/algorithms/fp8_quant/_core/common.py

Lines changed: 6 additions & 0 deletions

@@ -19,6 +19,7 @@
 import numpy as np
 import torch
 from enum import Enum, auto
+from functools import lru_cache
 
 from .._quant_common.quant_config import get_hqt_config
 from ..utils.logger import logger
@@ -288,3 +289,8 @@ def create_mod_info_recursion(parent):
 def get_device_type_for_scales(mod):
     config = get_hqt_config(mod).cfg
     return config["device_for_scales"]
+
+
+@lru_cache
+def is_runtime_scale_patching():
+    return os.getenv("RUNTIME_SCALE_PATCHING", "False").lower() in ["true", "1"]
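
Because the new helper is wrapped in lru_cache, the RUNTIME_SCALE_PATCHING environment variable is effectively read once per process; code that flips the variable afterwards (as the unit test added in this commit does) has to clear the cache first. A small illustrative sketch, assuming the helper is imported from its new location:

import os
from neural_compressor.torch.algorithms.fp8_quant._core.common import is_runtime_scale_patching

os.environ["RUNTIME_SCALE_PATCHING"] = "1"
is_runtime_scale_patching.cache_clear()  # drop the cached value so the new setting is picked up
assert is_runtime_scale_patching()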

neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py

Lines changed: 1 addition & 1 deletion

@@ -204,7 +204,7 @@ def prepare_model_with_dummy_measurement(model, mod_list, scaling_method_name, s
         mode_type = config.cfg["mod_dict"][mod_type_str]
         mod_info = mod_types[mode_type]
 
-        op_obj = ops_quantizer.get_op_quantizer(mode_type, "dummy", mod, None, scale_config)
+        op_obj = ops_quantizer.get_op_quantizer("dummy", mod, None, scale_config, mode_type)
         dummy_mod_scales = op_obj.get_scales_module_config()
         dummy_mod_config = op_obj.scales_module_config_to_q_and_dq(dummy_mod_scales)
         dummy_mod_extra_config = ModuleExtraConfig(

neural_compressor/torch/algorithms/fp8_quant/_core/quantized_func_wrappers/hpu/hpu_quantized_func_wrapper.py

Lines changed: 3 additions & 0 deletions

@@ -18,6 +18,7 @@
 import torch
 
 from ..quantized_func_wrapper import QuantizedFuncWrapperBase, OP_TYPE, QuantizedFuncWrapperFactory
+from ...common import is_runtime_scale_patching
 from neural_compressor.torch.algorithms.fp8_quant._quant_common.quant_config import ScaleFormat
 try: # backwards compatibility for 1.16
     from habana_frameworks.torch.hpex.kernels import fp8_fused_sdpa
@@ -40,6 +41,8 @@ def get_default_quantized_func(self):
         raise NotImplementedError()
 
     def get_scalar_quantized_func(self):
+        if is_runtime_scale_patching():
+            return self.get_default_quantized_func()
         return self.get_default_quantized_func().scalar
 
     def get_quantized_func(self, scale_format):

neural_compressor/torch/algorithms/fp8_quant/_core/scale.py

Lines changed: 3 additions & 3 deletions

@@ -30,8 +30,8 @@ def load_layer_scales(mod, mod_name, config, mod_type_str, measurement, scales,
     )
     mod_extra_config = None
     if mod_name in scales or not config.cfg["use_stats_files"] or mod_name in measurement:
-        op_for_scale_obj = ops_quantizer.get_op_quantizer(module_type, scaling_method_name, mod,
-                                                          measurement.get(mod_name, None), scale_config)
+        op_for_scale_obj = ops_quantizer.get_op_quantizer(scaling_method_name, mod, measurement.get(mod_name, None),
+                                                          scale_config, module_type)
         if mod_name not in scales:
             logger.debug("Calculating scales for module %s", mod_name)
             # calculates scales for current module according to scalling_methods
@@ -61,7 +61,7 @@ def prepare_layer_scales(mod, mod_name, config, mod_type_str, measurement, scale
         module_type,
     )
     mod_extra_config = None
-    op_obj = ops_quantizer.get_op_quantizer(module_type, scaling_method_name, mod, None, scale_config)
+    op_obj = ops_quantizer.get_op_quantizer(scaling_method_name, mod, None, scale_config, module_type)
     logger.debug("Preparing dynamic scales for module %s", mod_name)
     # calculates scales for current module according to scaling_methods
     scales[mod_name] = op_obj.get_scales_module_config()  # ModuleConfig of scales

neural_compressor/torch/algorithms/fp8_quant/_core/scale_handler.py

Lines changed: 3 additions & 0 deletions

@@ -15,6 +15,7 @@
 import torch
 import types
 from .._quant_common.quant_config import ScaleFormat
+from .common import is_runtime_scale_patching
 
 
 def add_scale_registry(patched_mod):
@@ -34,6 +35,8 @@ def register_scale(patched_mod, name, scale, scale_format):
 
 
 def create_scale_tensor(orig_tensor, scale_format):
+    if is_runtime_scale_patching() and scale_format in ScaleFormat.__members__.values():
+        return orig_tensor.to("cpu").to(torch.float)
     if scale_format == ScaleFormat.CONST:
         if isinstance(orig_tensor, torch.Tensor):
             return torch.nn.Parameter(orig_tensor)
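
With the flag enabled, create_scale_tensor short-circuits for any ScaleFormat member and returns the scale as a plain float32 CPU tensor, which is what allows the value to be handed to the bridge and patched at runtime rather than captured as a graph constant. A small usage sketch, assuming an HPU device is available and RUNTIME_SCALE_PATCHING has already been enabled (with the cached check cleared):

import torch
from neural_compressor.torch.algorithms.fp8_quant._quant_common.quant_config import ScaleFormat
from neural_compressor.torch.algorithms.fp8_quant._core.scale_handler import create_scale_tensor

scale = torch.tensor(0.125, device="hpu")
patched = create_scale_tensor(scale, ScaleFormat.CONST)
print(patched.device, patched.dtype)  # expected: cpu torch.float32, instead of an HPU nn.Parameter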

neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/ops_quantizer.py

Lines changed: 1 addition & 1 deletion

@@ -380,5 +380,5 @@ def scales_module_config_to_q_and_dq(self, module):
     "row_parallel_linear": RowParallelLinearOpQuantizer
 }
 
-def get_op_quantizer(module_type, config, mod, measurement, params):
+def get_op_quantizer(config, mod, measurement, params, module_type):
     return ops_quantizer_map[module_type](config, mod, measurement, params, module_type)
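
The only change here is the position of module_type, which moves from the first to the last parameter of get_op_quantizer so that the signature matches the ops_quantizer_map constructor call on the next line; the call sites in quantize.py and scale.py above are updated accordingly. Any external caller would need the same reordering, roughly (a hypothetical call site for illustration):

# Before this commit:
op_obj = ops_quantizer.get_op_quantizer(module_type, scaling_method_name, mod, measurement, scale_config)

# After this commit, module_type moves to the end:
op_obj = ops_quantizer.get_op_quantizer(scaling_method_name, mod, measurement, scale_config, module_type)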

neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/scales_method.py

Lines changed: 1 addition & 1 deletion

@@ -161,7 +161,7 @@ def __init__(self, round_scale_method, optional_scales_list, params, device_for_
 
     def calc_scales(self, tensor, tensor_type, **additional_kwargs):
         self.scale = self.round_scale_method.calc(mmse_scale(tensor, self.optional_scales_list, self.lp_dtype, self.hp_dtype))
-        return self.scale
+        return self.scale
 
 class OptScalesPcs(ScalesMethod):
     def __init__(self, round_scale_method, optional_scales_list, params, device_for_scales, backoff):

neural_compressor/torch/algorithms/fp8_quant/_core/utils.py

Lines changed: 21 additions & 0 deletions

@@ -20,6 +20,10 @@
 from .measure import prepare_model as prepare_model_for_measure
 from .quantize import quantize
 from .scale import scale_method_mapping, scaling_params
+from .common import is_runtime_scale_patching
+
+import os
+import habana_frameworks.torch.utils.experimental as htexp
 
 
 def update_mod_dict(config):
@@ -75,6 +79,22 @@ def quantize_dynamic_op(config, mod_type):
     logger.trace(f"should_quantize {name=} {mod_type=} returning {ret}")
     return ret
 
+
+scaling_methods_list = list(scale_method_mapping.values())
+# exclude substrings of scaling methods that are not supported in runtime scale patching mode (used to reduce graph recompilations).
+exclude_substrings = ["pcs", "smoothquant"]
+runtime_scale_patching_supported_methods_list = [method for method in scaling_methods_list if not any(substr in method for substr in exclude_substrings)]
+
+
+def set_runtime_scale_patching_mode(scaling_method_name):
+    if is_runtime_scale_patching():
+        assert (
+            scaling_method_name in runtime_scale_patching_supported_methods_list
+        ), f"Scaling method \"{scaling_method_name}\" is not supported for runtime scale patching (graph recompile reduction). Cannot set scaling attributes."
+        htexp._set_scale_attributes("hw" in scaling_method_name or scaling_method_name == "unit_scale",
+                                    scaling_methods_list.index(scaling_method_name) + 1)
+
+
 def prepare_model(model):
     """Receives the parent module to quantize.
     Replaces its submodules with patched submodules that perform calibration and quantization.
@@ -101,4 +121,5 @@ def prepare_model(model):
     scaling_method_name = scale_method_mapping[(config.cfg["scale_method"], config.cfg["observer"])]
     scaling_params[scaling_method_name].update(config.cfg["scale_params"])
     config.cfg["scale_params"] = scaling_params[scaling_method_name]
+    set_runtime_scale_patching_mode(scaling_method_name)
     return quantize(model, mod_list)
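
set_runtime_scale_patching_mode() forwards two attributes to the Habana bridge via htexp._set_scale_attributes: whether the chosen method is hardware-aligned ("hw" in its name, or unit_scale) and a 1-based id of the method within the global scaling-methods list (the commit message's "fix scaling methods ids to start at 1"). A sketch of how those two values are derived, using a stand-in list since the real ordering comes from scale_method_mapping.values():

# Stand-in list for illustration; the real list is scale_method_mapping.values().
scaling_methods_list = ["unit_scale", "maxabs_hw", "maxabs_pow2", "maxabs_pcs_pow2"]
exclude_substrings = ["pcs", "smoothquant"]
supported = [m for m in scaling_methods_list if not any(s in m for s in exclude_substrings)]

scaling_method_name = "maxabs_hw"
assert scaling_method_name in supported  # unsupported methods trigger the AssertionError above
is_hw_aligned = "hw" in scaling_method_name or scaling_method_name == "unit_scale"
method_id = scaling_methods_list.index(scaling_method_name) + 1  # ids are 1-based in this commit
print(is_hw_aligned, method_id)  # True 2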

test/3x/torch/algorithms/fp8_quant/tester.py

Lines changed: 10 additions & 0 deletions

@@ -54,6 +54,16 @@
     torch.float32: "FP32",
 }
 
+RUNTIME_SCALE_PATCHING_SUPPORTED_METHODS_LIST = [
+    ScaleMethod.UNIT_SCALE,
+    ScaleMethod.HW_ALIGNED_SINGLE_SCALE,
+    ScaleMethod.MAXABS_HW,
+    ScaleMethod.MAXABS_POW2,
+    ScaleMethod.MAXABS_HW_OPT_WEIGHT,
+    ScaleMethod.MAXABS_POW2_OPT_WEIGHT,
+    ScaleMethod.MAXABS_ARBITRARY
+]
+
 # Expects to get an exception. If there's no exception, the test will fail
 def run_with_raised_exception(test_to_run, error, error_str):
     with pytest_raises(Exception) as exc:

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
+import os
+import pytest
+import torch
+import shutil
+import copy
+import habana_frameworks.torch.core as htcore
+
+from ..tester import RUNTIME_SCALE_PATCHING_SUPPORTED_METHODS_LIST, SCALE_METHODS_KEY_ERROR, run_with_raised_exception
+from neural_compressor.torch.algorithms.fp8_quant._core.common import is_runtime_scale_patching
+from neural_compressor.torch.algorithms.fp8_quant._quant_common.quant_config import ScaleMethod
+from neural_compressor.torch.quantization import FP8Config, convert, prepare, finalize_calibration
+
+os.environ["PT_HPU_WEIGHT_SHARING"] = "0"
+htcore.hpu_inference_set_env()
+
+
+class TinyBlock(torch.nn.Module):
+
+    def __init__(self):
+        super(TinyBlock, self).__init__()
+        self.pre_linear = torch.nn.Linear(2, 1, bias=False)
+        self.pre_linear.weight = torch.nn.Parameter(torch.ones([1, 2]))
+
+    def forward(self, x):
+        x = self.pre_linear(x)
+        return x
+
+
+class TinyModel(torch.nn.Module):
+
+    def __init__(self):
+        super(TinyModel, self).__init__()
+        self.block = TinyBlock()
+
+    def forward(self, x):
+        x = self.block(x)
+        return x
+
+
+@pytest.fixture
+def temp_directory():
+    # Create a temporary directory
+    temp_dir = "./test_runtime_scale_patching_outputs"
+    os.makedirs(temp_dir)
+    # Yield the temporary directory path to the test
+    yield temp_dir
+    # Cleanup: Remove the temporary directory after the test ends
+    shutil.rmtree(temp_dir)
+
+
+@pytest.mark.parametrize("scale_method", ScaleMethod)
+@pytest.mark.parametrize("scale_format", ["SCALAR", "CONST"])
+@pytest.mark.parametrize("dynamic_scale_patching", [True, False])
+def test_no_assert(scale_method, scale_format, dynamic_scale_patching, temp_directory):
+    if scale_method in SCALE_METHODS_KEY_ERROR:
+        pytest.xfail("KeyError")
+    model = TinyModel()
+    model.eval()
+    model = model.to("hpu").to(torch.bfloat16)
+    inference_model = copy.deepcopy(model)
+    htcore.hpu_inference_initialize()
+
+    measure_config_dict = {
+        "mode": "MEASURE",
+        "observer": "maxabs",
+        "allowlist": {"types": [], "names": []},
+        "blocklist": {"types": [], "names": []},
+        "dump_stats_path": f"{temp_directory}/inc_output"
+    }
+    quant_config_dict = {
+        "mode": "QUANTIZE",
+        "scale_format": scale_format,
+        "scale_method": scale_method.name,
+        "allowlist": {"types": [], "names": []},
+        "blocklist": {"types": [], "names": []},
+        "dump_stats_path": f"{temp_directory}/inc_output"
+    }
+    measure_config = FP8Config.from_dict(measure_config_dict)
+    quant_config = FP8Config.from_dict(quant_config_dict)
+
+    def run_convert():
+        convert(inference_model, quant_config)
+
+    is_runtime_scale_patching.cache_clear()
+    os.environ["RUNTIME_SCALE_PATCHING"] = "0"
+
+    model = prepare(model, measure_config)
+    input = torch.tensor([1.2, 2.1]).to(torch.bfloat16).to("hpu")
+    model(input)
+    finalize_calibration(model)
+
+    if dynamic_scale_patching:
+        os.environ["RUNTIME_SCALE_PATCHING"] = "1"
+        if not scale_method in RUNTIME_SCALE_PATCHING_SUPPORTED_METHODS_LIST:
+            run_with_raised_exception(run_convert, AssertionError, "Cannot set scaling attributes.")
+            return
+    # The following convert should run successfully without any asserts
+    inference_model = convert(inference_model, quant_config)
