Skip to content

Commit 7a9d63e

Browse files
ulivne authored and xinhe3 committed
[FSW-13914] Fix gaudi specific code in common location (#224)
Move Gaudi specific code to internal scopes, so it won't be imported in FS/JS env Signed-off-by: Xin He <[email protected]>
1 parent c2f9b67 commit 7a9d63e

File tree

4 files changed

+23
-7
lines changed

4 files changed

+23
-7
lines changed

neural_compressor/torch/algorithms/fp8_quant/_core/fp_utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from enum import Enum
1717
from .common import ModuleConfig
1818
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator, INCAcceleratorType
19+
from neural_compressor.torch.utils import logger
20+
1921
cur_accelerator = auto_detect_accelerator()
2022

2123
descale_fcn = lambda x, scale: torch.mul(x, scale)
@@ -122,9 +124,8 @@ def get_fullscales_by_expbias_set(dtype, device, expbias_set):
122124

123125
def get_fp8_hw_alligned_scales_by_device(dtype, device):
124126
if device not in [GAUDI2, GAUDI3]:
125-
raise ValueError(
126-
f"{device} is not supported"
127-
)
127+
logger.warning("hw aligned scales not supported for device {}".format(device))
128+
return None # only Gaudis support hw aligned scales
128129
exp_bias_set = EXP_BIAS_SETS.get((device, dtype), None)
129130
return (
130131
None
@@ -157,6 +158,10 @@ def calc_maxabs_scale(xmaxabs, fullscale, backoff=1):
157158
return scale
158159

159160
def mmse_scale_multi(x, ref_scale, scales, lp_dtype, hp_dtype):
161+
if not scales:
162+
raise ValueError(
163+
"got empty scale list. it is possible that scale method isn't supported by current device."
164+
)
160165
# TODO: SW-176672 move weights to hpu before the scale calculations
161166
x = x.to("hpu")
162167
Nch = x.shape[-1]
@@ -180,6 +185,10 @@ def mmse_scale_multi(x, ref_scale, scales, lp_dtype, hp_dtype):
180185

181186

182187
def mmse_scale(x, scales, lp_dtype, hp_dtype):
188+
if not scales:
189+
raise ValueError(
190+
"got empty scale list. it is possible that scale method isn't supported by current device."
191+
)
183192
# TODO: SW-176672 move weights to hpu before the scale calculations
184193
x = x.to("hpu")
185194
opt_err = torch.ones(1, dtype=hp_dtype, device=x.device) * torch.inf

neural_compressor/torch/algorithms/fp8_quant/_core/utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,9 @@
2121
from .quantize import quantize
2222
from .scale import scale_method_mapping, scaling_params
2323
from .common import is_runtime_scale_patching
24-
24+
from neural_compressor.torch.utils.auto_accelerator import is_any_gaudi_accelerator
2525
import os
2626
import re
27-
import habana_frameworks.torch.utils.experimental as htexp
2827

2928

3029
def update_mod_dict(config):
@@ -91,7 +90,8 @@ def quantize_dynamic_op(config, mod_type):
9190

9291

9392
def set_runtime_scale_patching_mode(scaling_method_name):
94-
if is_runtime_scale_patching() and hasattr(htexp, "_set_scale_attributes"):
93+
import habana_frameworks.torch.utils.experimental as htexp # importing in local scope since it is gaudi specific
94+
if is_runtime_scale_patching():
9595
assert (
9696
scaling_method_name in runtime_scale_patching_supported_methods_list
9797
), f"Scaling method \"{scaling_method_name}\" is not supported for runtime scale patching (graph recompile reduction). Cannot set scaling attributes."
@@ -125,5 +125,7 @@ def prepare_model(model):
125125
scaling_method_name = scale_method_mapping[(config.cfg["scale_method"], config.cfg["observer"])]
126126
scaling_params[scaling_method_name].update(config.cfg["scale_params"])
127127
config.cfg["scale_params"] = scaling_params[scaling_method_name]
128-
set_runtime_scale_patching_mode(scaling_method_name)
128+
129+
if is_any_gaudi_accelerator(config.cfg["device_type"]):
130+
set_runtime_scale_patching_mode(scaling_method_name)
129131
return quantize(model, mod_list)

neural_compressor/torch/utils/auto_accelerator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,3 +462,7 @@ def auto_detect_accelerator(device_name="auto") -> Auto_Accelerator:
462462
# INC_TARGET_DEVICE = "CPU" python ...
463463
# or
464464
# CUDA_VISIBLE_DEVICES="" python ...
465+
466+
467+
def is_any_gaudi_accelerator(acc_type: INCAcceleratorType):
468+
return acc_type.value > INCAcceleratorType.GAUDI_MIN.value
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{}

0 commit comments

Comments (0)