Commit 7766454

pt 3.x config for static quant and smooth quant (#1568)
Signed-off-by: Cheng, Zixuan <[email protected]>
1 parent 52ea445 commit 7766454

File tree

5 files changed: +219 -1 lines changed


neural_compressor/common/utils/constants.py

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@
 COMPOSABLE_CONFIG = "composable_config"
 RTN = "rtn"
 STATIC_QUANT = "static_quant"
+SMOOTH_QUANT = "smooth_quant"
 GPTQ = "gptq"
 FP8_QUANT = "fp8_quant"

neural_compressor/torch/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -21,6 +21,10 @@
     get_default_rtn_config,
     GPTQConfig,
     get_default_gptq_config,
+    StaticQuantConfig,
+    get_default_static_config,
+    SmoothQuantConfig,
+    get_default_sq_config,
 )

 from neural_compressor.common.base_tuning import TuningConfig

neural_compressor/torch/quantization/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -18,4 +18,8 @@
     get_default_rtn_config,
     GPTQConfig,
     get_default_gptq_config,
+    StaticQuantConfig,
+    get_default_static_config,
+    SmoothQuantConfig,
+    get_default_sq_config,
 )
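
Together with the matching re-export in neural_compressor/torch/__init__.py above, this makes the new configs importable from either namespace. A minimal usage sketch based only on the names added in this commit (the quantization entry points themselves are untouched here, so only config construction is shown):

from neural_compressor.torch.quantization import (
    StaticQuantConfig,
    get_default_static_config,
    SmoothQuantConfig,
    get_default_sq_config,
)

# The default helpers simply return StaticQuantConfig() / SmoothQuantConfig().
static_config = get_default_static_config()
sq_config = get_default_sq_config()

# Non-default settings use the keyword arguments defined in config.py below.
sq_tuned = SmoothQuantConfig(alpha=0.8, folding=True)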

neural_compressor/torch/quantization/config.py

Lines changed: 194 additions & 1 deletion
@@ -24,7 +24,15 @@
 import torch

 from neural_compressor.common.base_config import BaseConfig, config_registry, register_config
-from neural_compressor.common.utils import DEFAULT_WHITE_LIST, FP8_QUANT, GPTQ, OP_NAME_OR_MODULE_TYPE, RTN
+from neural_compressor.common.utils import (
+    DEFAULT_WHITE_LIST,
+    FP8_QUANT,
+    GPTQ,
+    OP_NAME_OR_MODULE_TYPE,
+    RTN,
+    SMOOTH_QUANT,
+    STATIC_QUANT,
+)
 from neural_compressor.torch.utils.constants import PRIORITY_GPTQ, PRIORITY_RTN
 from neural_compressor.torch.utils.utility import is_hpex_avaliable, logger

@@ -282,6 +290,191 @@ def get_default_gptq_config() -> GPTQConfig:
     return GPTQConfig()


+######################## Static Quant Config ###############################
+@register_config(framework_name=FRAMEWORK_NAME, algo_name=STATIC_QUANT)
+class StaticQuantConfig(BaseConfig):
+    """Config class for static quantization."""
+
+    name = STATIC_QUANT
+    params_list = [
+        "w_dtype",
+        "w_sym",
+        "w_granularity",
+        "w_algo",
+        "act_dtype",
+        "act_sym",
+        "act_granularity",
+        "act_algo",
+    ]
+    supported_configs: List[OperatorConfig] = []
+
+    def __init__(
+        self,
+        w_dtype: str = "int8",
+        w_sym: bool = True,
+        w_granularity: str = "per_channel",
+        w_algo: str = "minmax",
+        act_dtype: str = "uint8",
+        act_sym: bool = False,
+        act_granularity: str = "per_tensor",
+        act_algo: str = "kl",
+        white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
+    ):
+        """Init Static Quant Configs."""
+        super().__init__(white_list=white_list)
+        self.w_dtype = w_dtype
+        self.w_sym = w_sym
+        self.w_granularity = w_granularity
+        self.w_algo = w_algo
+        self.act_dtype = act_dtype
+        self.act_sym = act_sym
+        self.act_granularity = act_granularity
+        self.act_algo = act_algo
+        self._post_init()
+
+    @classmethod
+    def register_supported_configs(cls) -> List[OperatorConfig]:
+        supported_configs = []
+        # TODO(Yi)
+        linear_static_config = StaticQuantConfig()
+        operators = [torch.nn.Linear]
+        supported_configs.append(OperatorConfig(config=linear_static_config, operators=operators))
+        cls.supported_configs = supported_configs
+
+    @staticmethod
+    def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
+        white_list = (torch.nn.Linear,)
+        filter_result = []
+        for op_name, module in model.named_modules():
+            if isinstance(module, white_list):
+                pair = (op_name, type(module).__name__)
+                filter_result.append(pair)
+        logger.debug(f"Get model info: {filter_result}")
+        return filter_result
+
+
+# TODO(Yi) run `register_supported_configs` for all registered config.
+StaticQuantConfig.register_supported_configs()
+
+
+def get_default_static_config() -> StaticQuantConfig:
+    """Generate the default static quant config.
+
+    Returns:
+        the default static quant config.
+    """
+    return StaticQuantConfig()
+
+
+######################## Smooth Quant Config ###############################
+@register_config(framework_name=FRAMEWORK_NAME, algo_name=SMOOTH_QUANT)
+class SmoothQuantConfig(BaseConfig):
+    """Config class for smooth quantization."""
+
+    name = SMOOTH_QUANT
+    params_list = [
+        "w_dtype",
+        "w_sym",
+        "w_granularity",
+        "w_algo",
+        "act_dtype",
+        "act_sym",
+        "act_granularity",
+        "act_algo",
+        "alpha",
+        "folding",
+        "scale_sharing",
+        "auto_alpha_args",
+    ]
+    supported_configs: List[OperatorConfig] = []
+
+    def __init__(
+        self,
+        w_dtype: str = "int8",
+        w_sym: bool = True,
+        w_granularity: str = "per_channel",
+        w_algo: str = "minmax",
+        act_dtype: str = "uint8",
+        act_sym: bool = False,
+        act_granularity: str = "per_tensor",
+        act_algo: str = "kl",
+        alpha: float = 0.5,
+        folding: bool = False,
+        # below for autotune
+        scale_sharing: bool = False,
+        init_alpha: float = 0.5,
+        alpha_min: float = 0.0,
+        alpha_max: float = 1.0,
+        alpha_step: float = 0.1,
+        shared_criterion: str = "max",
+        enable_blockwise_loss: bool = False,
+        auto_alpha_args: dict = None,
+        white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
+    ):
+        """Init SmoothQuant Configs."""
+        super().__init__(white_list=white_list)
+        self.w_dtype = w_dtype
+        self.w_sym = w_sym
+        self.w_granularity = w_granularity
+        self.w_algo = w_algo
+        self.act_dtype = act_dtype
+        self.act_sym = act_sym
+        self.act_granularity = act_granularity
+        self.act_algo = act_algo
+        self.alpha = alpha
+        self.folding = folding
+        # below for autotune
+        self.scale_sharing = scale_sharing
+        self.init_alpha = init_alpha
+        self.alpha_min = alpha_min
+        self.alpha_max = alpha_max
+        self.alpha_step = alpha_step
+        self.shared_criterion = shared_criterion
+        self.enable_blockwise_loss = enable_blockwise_loss
+        self.auto_alpha_args = {
+            "init_alpha": self.init_alpha,
+            "alpha_min": self.alpha_min,
+            "alpha_max": self.alpha_max,
+            "alpha_step": self.alpha_step,
+            "shared_criterion": self.shared_criterion,
+            "enable_blockwise_loss": self.enable_blockwise_loss,
+        }
+        self._post_init()
+
+    @classmethod
+    def register_supported_configs(cls) -> List[OperatorConfig]:
+        supported_configs = []
+        # TODO(Yi)
+        linear_sq_config = SmoothQuantConfig()
+        operators = [torch.nn.Linear]
+        supported_configs.append(OperatorConfig(config=linear_sq_config, operators=operators))
+        cls.supported_configs = supported_configs
+
+    @staticmethod
+    def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
+        white_list = (torch.nn.Linear,)
+        filter_result = []
+        for op_name, module in model.named_modules():
+            if isinstance(module, white_list):
+                pair = (op_name, type(module).__name__)
+                filter_result.append(pair)
+        logger.debug(f"Get model info: {filter_result}")
+        return filter_result
+
+
+# TODO(Yi) run `register_supported_configs` for all registered config.
+SmoothQuantConfig.register_supported_configs()
+
+
+def get_default_sq_config() -> SmoothQuantConfig:
+    """Generate the default smoothquant config.
+
+    Returns:
+        the default smoothquant config.
+    """
+    return SmoothQuantConfig()
+
+
 ######################## FP8 Config ###############################
 if is_hpex_avaliable():
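
Both new classes share the weight/activation fields and the Linear-only get_model_info filter; SmoothQuantConfig additionally packs its alpha-search knobs into auto_alpha_args inside __init__. A small sketch of both behaviors, assuming the package at this commit is importable (the Sequential model below is a hypothetical toy used only for illustration):

import torch

from neural_compressor.torch.quantization import SmoothQuantConfig, StaticQuantConfig

# Hypothetical two-Linear toy model, used only to exercise get_model_info.
toy = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4))

# Only torch.nn.Linear modules pass the white_list filter defined above.
print(StaticQuantConfig.get_model_info(toy))
# [('0', 'Linear'), ('2', 'Linear')]

# The alpha-search keywords are collected into auto_alpha_args by __init__.
sq = SmoothQuantConfig(alpha=0.5, alpha_min=0.3, alpha_max=0.7, alpha_step=0.05)
print(sq.auto_alpha_args["alpha_min"], sq.auto_alpha_args["alpha_max"])
# 0.3 0.7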

test/3x/torch/test_config.py

Lines changed: 16 additions & 0 deletions
@@ -321,6 +321,22 @@ def test_gptq_config(self):
         gptq_config2 = GPTQConfig.from_dict(quant_config_dict["gptq"])
         self.assertEqual(gptq_config1.to_dict(), gptq_config2.to_dict())

+    def test_static_quant_config(self):
+        from neural_compressor.torch.quantization import StaticQuantConfig
+
+        static_config1 = StaticQuantConfig(w_dtype="int8", act_sym=True, act_algo="minmax")
+        quant_config_dict = {"static": {"w_dtype": "int8", "act_sym": True, "act_algo": "minmax"}}
+        static_config2 = StaticQuantConfig.from_dict(quant_config_dict["static"])
+        self.assertEqual(static_config1.to_dict(), static_config2.to_dict())
+
+    def test_smooth_quant_config(self):
+        from neural_compressor.torch.quantization import SmoothQuantConfig
+
+        sq_config1 = SmoothQuantConfig(alpha=0.8, folding=True)
+        quant_config_dict = {"sq": {"alpha": 0.8, "folding": True}}
+        sq_config2 = SmoothQuantConfig.from_dict(quant_config_dict["sq"])
+        self.assertEqual(sq_config1.to_dict(), sq_config2.to_dict())
+

 class TestQuantConfigForAutotune(unittest.TestCase):
     def test_expand_config(self):
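
Because neural_compressor/torch/__init__.py re-exports the same symbols, the from_dict round trip exercised in these tests also works through the top-level torch namespace. A minimal sketch mirroring the new test cases (assumes the package at this commit is installed):

from neural_compressor.torch import SmoothQuantConfig, StaticQuantConfig

static_cfg = StaticQuantConfig.from_dict({"w_dtype": "int8", "act_sym": True, "act_algo": "minmax"})
sq_cfg = SmoothQuantConfig.from_dict({"alpha": 0.8, "folding": True})

# Unspecified fields fall back to the defaults, matching a kwargs-built config.
assert sq_cfg.to_dict() == SmoothQuantConfig(alpha=0.8, folding=True).to_dict()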
