|
24 | 24 | import torch |
25 | 25 |
|
26 | 26 | from neural_compressor.common.base_config import BaseConfig, config_registry, register_config |
27 | | -from neural_compressor.common.utils import DEFAULT_WHITE_LIST, FP8_QUANT, GPTQ, OP_NAME_OR_MODULE_TYPE, RTN |
| 27 | +from neural_compressor.common.utils import ( |
| 28 | + DEFAULT_WHITE_LIST, |
| 29 | + FP8_QUANT, |
| 30 | + GPTQ, |
| 31 | + OP_NAME_OR_MODULE_TYPE, |
| 32 | + RTN, |
| 33 | + SMOOTH_QUANT, |
| 34 | + STATIC_QUANT, |
| 35 | +) |
28 | 36 | from neural_compressor.torch.utils.constants import PRIORITY_GPTQ, PRIORITY_RTN |
29 | 37 | from neural_compressor.torch.utils.utility import is_hpex_avaliable, logger |
30 | 38 |
|
@@ -282,6 +290,191 @@ def get_default_gptq_config() -> GPTQConfig: |
282 | 290 | return GPTQConfig() |
283 | 291 |
|
284 | 292 |
|
| 293 | +######################## Static Quant Config ############################### |
| 294 | +@register_config(framework_name=FRAMEWORK_NAME, algo_name=STATIC_QUANT) |
| 295 | +class StaticQuantConfig(BaseConfig): |
| 296 | + """Config class for static quantization.""" |
| 297 | + |
| 298 | + name = STATIC_QUANT |
| 299 | + params_list = [ |
| 300 | + "w_dtype", |
| 301 | + "w_sym", |
| 302 | + "w_granularity", |
| 303 | + "w_algo", |
| 304 | + "act_dtype", |
| 305 | + "act_sym", |
| 306 | + "act_granularity", |
| 307 | + "act_algo", |
| 308 | + ] |
| 309 | + supported_configs: List[OperatorConfig] = [] |
| 310 | + |
| 311 | + def __init__( |
| 312 | + self, |
| 313 | + w_dtype: str = "int8", |
| 314 | + w_sym: bool = True, |
| 315 | + w_granularity: str = "per_channel", |
| 316 | + w_algo: str = "minmax", |
| 317 | + act_dtype: str = "uint8", |
| 318 | + act_sym: bool = False, |
| 319 | + act_granularity: str = "per_tensor", |
| 320 | + act_algo: str = "kl", |
| 321 | + white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, |
| 322 | + ): |
| 323 | + """Init Static Quant Configs.""" |
| 324 | + super().__init__(white_list=white_list) |
| 325 | + self.w_dtype = w_dtype |
| 326 | + self.w_sym = w_sym |
| 327 | + self.w_granularity = w_granularity |
| 328 | + self.w_algo = w_algo |
| 329 | + self.act_dtype = act_dtype |
| 330 | + self.act_sym = act_sym |
| 331 | + self.act_granularity = act_granularity |
| 332 | + self.act_algo = act_algo |
| 333 | + self._post_init() |
| 334 | + |
| 335 | + @classmethod |
| 336 | + def register_supported_configs(cls) -> List[OperatorConfig]: |
| 337 | + supported_configs = [] |
| 338 | + # TODO(Yi) |
| 339 | + linear_static_config = StaticQuantConfig() |
| 340 | + operators = [torch.nn.Linear] |
| 341 | + supported_configs.append(OperatorConfig(config=linear_static_config, operators=operators)) |
| 342 | + cls.supported_configs = supported_configs |
| 343 | + |
| 344 | + @staticmethod |
| 345 | + def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: |
| 346 | + white_list = (torch.nn.Linear,) |
| 347 | + filter_result = [] |
| 348 | + for op_name, module in model.named_modules(): |
| 349 | + if isinstance(module, white_list): |
| 350 | + pair = (op_name, type(module).__name__) |
| 351 | + filter_result.append(pair) |
| 352 | + logger.debug(f"Get model info: {filter_result}") |
| 353 | + return filter_result |
| 354 | + |
| 355 | + |
| 356 | +# TODO(Yi) run `register_supported_configs` for all registered config. |
| 357 | +StaticQuantConfig.register_supported_configs() |
| 358 | + |
| 359 | + |
def get_default_static_config() -> StaticQuantConfig:
    """Build a static quant config populated entirely with default values.

    Returns:
        StaticQuantConfig: a freshly constructed all-defaults config.
    """
    default_config = StaticQuantConfig()
    return default_config
| 367 | + |
| 368 | + |
######################## Smooth Quant Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=SMOOTH_QUANT)
class SmoothQuantConfig(BaseConfig):
    """Config class for smooth quantization.

    Extends the static-quant knobs with SmoothQuant-specific parameters
    (``alpha``, ``folding``) and the auto-alpha tuning arguments.
    """

    name = SMOOTH_QUANT
    # Names of the tunable parameters exposed to the tuning machinery.
    params_list = [
        "w_dtype",
        "w_sym",
        "w_granularity",
        "w_algo",
        "act_dtype",
        "act_sym",
        "act_granularity",
        "act_algo",
        "alpha",
        "folding",
        "scale_sharing",
        "auto_alpha_args",
    ]
    supported_configs: List[OperatorConfig] = []

    def __init__(
        self,
        w_dtype: str = "int8",
        w_sym: bool = True,
        w_granularity: str = "per_channel",
        w_algo: str = "minmax",
        act_dtype: str = "uint8",
        act_sym: bool = False,
        act_granularity: str = "per_tensor",
        act_algo: str = "kl",
        alpha: float = 0.5,
        folding: bool = False,
        # below for autotune
        scale_sharing: bool = False,
        init_alpha: float = 0.5,
        alpha_min: float = 0.0,
        alpha_max: float = 1.0,
        alpha_step: float = 0.1,
        shared_criterion: str = "max",
        enable_blockwise_loss: bool = False,
        auto_alpha_args: Optional[dict] = None,
        white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
    ):
        """Init SmoothQuant Configs.

        Args:
            w_dtype: data type for quantized weights. Default "int8".
            w_sym: whether weight quantization is symmetric. Default True.
            w_granularity: weight quantization granularity. Default "per_channel".
            w_algo: weight calibration algorithm. Default "minmax".
            act_dtype: data type for quantized activations. Default "uint8".
            act_sym: whether activation quantization is symmetric. Default False.
            act_granularity: activation quantization granularity. Default "per_tensor".
            act_algo: activation calibration algorithm. Default "kl".
            alpha: smoothing factor balancing weight/activation ranges. Default 0.5.
            folding: whether to fold smoothing scales into adjacent layers. Default False.
            scale_sharing: share scales among ops during auto-tuning. Default False.
            init_alpha: starting alpha for auto-tuning. Default 0.5.
            alpha_min: lower bound of the alpha search. Default 0.0.
            alpha_max: upper bound of the alpha search. Default 1.0.
            alpha_step: step size of the alpha search. Default 0.1.
            shared_criterion: criterion for shared-scale selection. Default "max".
            enable_blockwise_loss: use block-wise loss during auto-tuning. Default False.
            auto_alpha_args: accepted for interface compatibility only.
                NOTE(review): the passed value is ignored — the attribute is
                always rebuilt from the individual alpha parameters above;
                confirm whether callers rely on passing a dict here.
            white_list: op names/module types eligible for quantization.
        """
        super().__init__(white_list=white_list)
        self.w_dtype = w_dtype
        self.w_sym = w_sym
        self.w_granularity = w_granularity
        self.w_algo = w_algo
        self.act_dtype = act_dtype
        self.act_sym = act_sym
        self.act_granularity = act_granularity
        self.act_algo = act_algo
        self.alpha = alpha
        self.folding = folding
        # below for autotune
        self.scale_sharing = scale_sharing
        self.init_alpha = init_alpha
        self.alpha_min = alpha_min
        self.alpha_max = alpha_max
        self.alpha_step = alpha_step
        self.shared_criterion = shared_criterion
        self.enable_blockwise_loss = enable_blockwise_loss
        # Rebuilt from the individual parameters; any caller-supplied
        # auto_alpha_args dict is intentionally not used (behavior preserved).
        self.auto_alpha_args = {
            "init_alpha": self.init_alpha,
            "alpha_min": self.alpha_min,
            "alpha_max": self.alpha_max,
            "alpha_step": self.alpha_step,
            "shared_criterion": self.shared_criterion,
            "enable_blockwise_loss": self.enable_blockwise_loss,
        }
        self._post_init()

    @classmethod
    def register_supported_configs(cls) -> None:
        """Populate ``cls.supported_configs`` with the operator configs this algo supports.

        Note: the original annotation claimed ``List[OperatorConfig]`` but the
        method stores the list on the class and returns nothing.
        """
        supported_configs = []
        # TODO(Yi) extend beyond the default Linear-only coverage.
        linear_sq_config = SmoothQuantConfig()
        operators = [torch.nn.Linear]
        supported_configs.append(OperatorConfig(config=linear_sq_config, operators=operators))
        cls.supported_configs = supported_configs

    @staticmethod
    def get_model_info(model: torch.nn.Module) -> List[Tuple[str, str]]:
        """Collect ``(op_name, module_class_name)`` pairs for quantizable modules.

        Args:
            model: the model to scan via ``named_modules()``.

        Returns:
            A list of ``(op_name, type_name)`` tuples — both elements are
            strings (``type(module).__name__``), not callables.
        """
        white_list = (torch.nn.Linear,)
        filter_result = []
        for op_name, module in model.named_modules():
            if isinstance(module, white_list):
                pair = (op_name, type(module).__name__)
                filter_result.append(pair)
        logger.debug(f"Get model info: {filter_result}")
        return filter_result


# TODO(Yi) run `register_supported_configs` for all registered config.
SmoothQuantConfig.register_supported_configs()
| 467 | + |
| 468 | + |
def get_default_sq_config() -> SmoothQuantConfig:
    """Build a smoothquant config populated entirely with default values.

    Returns:
        SmoothQuantConfig: a freshly constructed all-defaults config.
    """
    default_config = SmoothQuantConfig()
    return default_config
| 476 | + |
| 477 | + |
285 | 478 | ######################## FP8 Config ############################### |
286 | 479 | if is_hpex_avaliable(): |
287 | 480 |
|
|
0 commit comments