diff --git a/.azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt b/.azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt
index b5a69eaa938..622b61d1dc0 100644
--- a/.azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt
+++ b/.azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt
@@ -15,6 +15,7 @@
 /neural-compressor/neural_compressor/strategy
 /neural-compressor/neural_compressor/training.py
 /neural-compressor/neural_compressor/utils
+/neural_compressor/torch/algorithms/mx_quant
 /neural-compressor/neural_compressor/torch/algorithms/static_quant
 /neural-compressor/neural_compressor/torch/algorithms/smooth_quant
 /neural_compressor/torch/algorithms/pt2e_quant
diff --git a/neural_compressor/torch/algorithms/mx_quant/__init__.py b/neural_compressor/torch/algorithms/mx_quant/__init__.py
index e54bfa18052..d85854ffd8f 100644
--- a/neural_compressor/torch/algorithms/mx_quant/__init__.py
+++ b/neural_compressor/torch/algorithms/mx_quant/__init__.py
@@ -13,3 +13,4 @@
 # limitations under the License.
 
 # pylint:disable=import-error
+"""MX quantization."""
diff --git a/neural_compressor/torch/algorithms/mx_quant/mx.py b/neural_compressor/torch/algorithms/mx_quant/mx.py
index 76af3511e20..208f6ad2698 100644
--- a/neural_compressor/torch/algorithms/mx_quant/mx.py
+++ b/neural_compressor/torch/algorithms/mx_quant/mx.py
@@ -17,7 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+"""MX quantization."""
 
 from collections import OrderedDict
 
@@ -31,6 +31,8 @@
 
 
 class MXLinear(torch.nn.Linear):
+    """Linear for MX data type."""
+
     def __init__(
         self,
         in_features,
@@ -39,6 +41,7 @@ def __init__(
         mx_specs=None,
         name=None,
     ):
+        """Initialization function."""
         self.mx_none = mx_specs is None
 
         self.name = name
@@ -46,6 +49,7 @@ def __init__(
         super().__init__(in_features, out_features, bias)
 
     def apply_mx_specs(self):
+        """Apply MX data type to weight."""
        if self.mx_specs is not None:
            if self.mx_specs.out_dtype != "float32":
                self.weight.data = quantize_elemwise_op(self.weight.data, mx_specs=self.mx_specs)
@@ -63,6 +67,7 @@ def apply_mx_specs(self):
         )
 
     def forward(self, input):
+        """Forward function."""
         if self.mx_none:
             return super().forward(input)
 
@@ -93,6 +98,8 @@ def forward(self, input):
 
 
 class MXQuantizer(Quantizer):
+    """Quantizer of MX data type."""
+
     def __init__(self, quant_config: OrderedDict = {}):
         """Init a MXQuantizer object.
 
diff --git a/neural_compressor/torch/algorithms/mx_quant/utils.py b/neural_compressor/torch/algorithms/mx_quant/utils.py
index 2da59c6c700..210e0255cc4 100644
--- a/neural_compressor/torch/algorithms/mx_quant/utils.py
+++ b/neural_compressor/torch/algorithms/mx_quant/utils.py
@@ -17,7 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
- +"""MX quantization utils.""" from enum import Enum, IntEnum @@ -28,6 +28,8 @@ class ElemFormat(Enum): + """Element format.""" + int8 = 1 int4 = 2 int2 = 3 @@ -44,6 +46,7 @@ class ElemFormat(Enum): @staticmethod def from_str(s): + """Get element format with str.""" assert s is not None, "String elem_format == None" s = s.lower() if hasattr(ElemFormat, s): @@ -53,6 +56,7 @@ def from_str(s): @staticmethod def is_bf(s): + """Whether the format is brain floating-point format.""" if isinstance(s, str): assert s is not None, "String elem_format == None" s = s.lower() @@ -65,6 +69,7 @@ def is_bf(s): @staticmethod def is_fp(s): + """Whether the format is floating-point format.""" if isinstance(s, str): assert s is not None, "String elem_format == None" s = s.lower() @@ -77,6 +82,7 @@ def is_fp(s): @staticmethod def is_int(s): + """Whether the format is integer format.""" if isinstance(s, str): assert s is not None, "String elem_format == None" s = s.lower() @@ -89,12 +95,15 @@ def is_int(s): class RoundingMode(IntEnum): + """Rounding mode.""" + nearest = 0 floor = 1 even = 2 @staticmethod def string_enums(): + """Rounding mode names.""" return [s.name for s in list(RoundingMode)] @@ -115,7 +124,9 @@ def _get_max_norm(ebits, mbits): def _get_format_params(fmt): - """Allowed formats: + """Get parameters of the format. + + Allowed formats: - intX: 2 <= X <= 32, assume sign-magnitude, 1.xxx representation - floatX/fpX: 16 <= X <= 28, assume top exp is used for NaN/Inf - bfloatX/bfX: 9 <= X <= 32 @@ -123,6 +134,9 @@ def _get_format_params(fmt): - fp6_e3m2/e2m3, no NaN/Inf - fp8_e4m3/e5m2, e5m2 normal NaN/Inf, e4m3 special behavior + Args: + fmt (str od ElemFormat): format + Returns: ebits: exponent bits mbits: mantissa bits: includes sign and implicit bits @@ -198,17 +212,19 @@ def _safe_rshift(x, bits, exp): def _round_mantissa(A, bits, round, clamp=False): - """ - Rounds mantissa to nearest bits depending on the rounding method 'round' + """Rounds mantissa to nearest bits depending on the rounding method 'round'. + Args: - A {PyTorch tensor} -- Input tensor - round {str} -- Rounding method - "floor" rounds to the floor - "nearest" rounds to ceil or floor, whichever is nearest + A (torch.Tensor): input tensor + bits (int): bit number of mantissa + round (str): rounding method + "floor" rounds to the floor + "nearest" rounds to ceil or floor, whichever is nearest + clamp (bool, optional): Whether do clip. Defaults to False. + Returns: - A {PyTorch tensor} -- Tensor with mantissas rounded + torch.Tensor: tensor with mantissas rounded """ - if round == "dither": rand_A = torch.rand_like(A, requires_grad=False) A = torch.sign(A) * torch.floor(torch.abs(A) + rand_A) @@ -235,16 +251,18 @@ def _shared_exponents(A, method="max", axes=None, ebits=0): """Get shared exponents for the passed matrix A. Args: - A {PyTorch tensor} -- Input tensor - method {str} -- Exponent selection method. - "max" uses the max absolute value - "none" uses an exponent for each value (i.e., no sharing) - axes {list(int)} -- List of integers which specifies the axes across which - shared exponents are calculated. + A (torch.Tensor): Input tensor + method (str, optional): Exponent selection method. + "max" uses the max absolute value. + "none" uses an exponent for each value (i.e., no sharing) + Defaults to "max". + axes (list(int), optional): list of integers which specifies the axes across which + shared exponents are calculated. Defaults to None. + ebits (int, optional): bit number of the shared exponents. Defaults to 0. 
+
     Returns:
-        shared_exp {PyTorch tensor} -- Tensor of shared exponents
+        shared_exp (torch.Tensor): tensor of shared exponents
     """
-
     if method == "max":
         if axes is None:
             shared_exp = torch.max(torch.abs(A))
@@ -346,21 +364,20 @@ def _undo_reshape_to_blocks(A, padded_shape, orig_shape, axes):
 
 
 def _quantize_elemwise_core(A, bits, exp_bits, max_norm, round="nearest", saturate_normals=False, allow_denorm=True):
-    """Core function used for element-wise quantization
-    Arguments:
-        A {PyTorch tensor} -- A tensor to be quantized
-        bits {int} -- Number of mantissa bits. Includes
-            sign bit and implicit one for floats
-        exp_bits {int} -- Number of exponent bits, 0 for ints
-        max_norm {float} -- Largest representable normal number
-        round {str} -- Rounding mode: (floor, nearest, even)
-        saturate_normals {bool} -- If True, normal numbers (i.e., not NaN/Inf)
-            that exceed max norm are clamped.
-            Must be True for correct MX conversion.
-        allow_denorm {bool} -- If False, flush denorm numbers in the
-            elem_format to zero.
+    """Core function used for element-wise quantization.
+
+    Args:
+        A (torch.Tensor): tensor to be quantized
+        bits (int): number of mantissa bits. Includes sign bit and implicit one for floats
+        exp_bits (int): number of exponent bits, 0 for ints
+        max_norm (float): largest representable normal number
+        round (str, optional): rounding mode: (floor, nearest, even). Defaults to "nearest".
+        saturate_normals (bool, optional): whether to clip normal numbers that exceed max norm.
+            Must be True for correct MX conversion. Defaults to False.
+        allow_denorm (bool, optional): if False, flush denorm numbers in the elem_format to zero. Defaults to True.
+
     Returns:
-        quantized tensor {PyTorch tensor} -- A tensor that has been quantized
+        torch.Tensor: tensor that has been quantized
     """
     # Flush values < min_norm to zero if denorms are not allowed
     if not allow_denorm and exp_bits > 0:
@@ -401,15 +418,20 @@ def _quantize_elemwise_core(A, bits, exp_bits, max_norm, round="nearest", satura
 
 
 def _quantize_fp(A, exp_bits=None, mantissa_bits=None, round="nearest", allow_denorm=True):
-    """Quantize values to IEEE fpX format.
-
-    The format defines NaN/Inf
-    and subnorm numbers in the same way as FP32 and FP16.
-    Arguments:
-        exp_bits {int} -- number of bits used to store exponent
-        mantissa_bits {int} -- number of bits used to store mantissa, not
-            including sign or implicit 1
-        round {str} -- Rounding mode, (floor, nearest, even)
+    """Quantize values to IEEE fpX format.
+
+    The format defines NaN/Inf and subnorm numbers in the same way as FP32 and FP16.
+
+    Args:
+        A (torch.Tensor): a tensor that needs to be quantized
+        exp_bits (int, optional): number of bits used to store exponent. Defaults to None.
+        mantissa_bits (int, optional): number of bits used to store mantissa.
+            Not including sign or implicit 1. Defaults to None.
+        round (str, optional): rounding mode, (floor, nearest, even). Defaults to "nearest".
+        allow_denorm (bool, optional): allow denorm numbers to exist. Defaults to True.
+
+    Returns:
+        torch.Tensor: tensor that has been quantized
     """
     # Shortcut for no quantization
     if exp_bits is None or mantissa_bits is None:
@@ -425,11 +447,17 @@ def _quantize_fp(A, exp_bits=None, mantissa_bits=None, round="nearest", allow_de
 
 
 def _quantize_bfloat(A, bfloat, round="nearest", allow_denorm=True):
-    """Quantize values to bfloatX format
-    Arguments:
-        bfloat {int} -- Total number of bits for bfloatX format,
-            Includes 1 sign, 8 exp bits, and variable
-            mantissa bits. Must be >= 9.
+ """Quantize values to bfloatX format. + + Args: + A (torch.Tensor): a tensor that needs to be quantized + bfloat (int): total number of bits for bfloatX format. + Includes 1 sign, 8 exp bits, and variable mantissa bits. Must be >= 9. + round (str, optional): rounding mode, (floor, nearest, even). Defaults to "nearest". + allow_denorm (bool, optional): allow denorm numbers to exist. Defaults to True. + + Returns: + torch.Tensor: tensor that has been quantized """ # Shortcut for no quantization if bfloat == 0 or bfloat == 32: @@ -443,12 +471,14 @@ def _quantize_bfloat(A, bfloat, round="nearest", allow_denorm=True): def quantize_elemwise_op(A, mx_specs): - """A function used for element-wise quantization with mx_specs - Arguments: - A {PyTorch tensor} -- a tensor that needs to be quantized - mx_specs {dictionary} -- dictionary to specify mx_specs + """A function used for element-wise quantization with mx_specs. + + Args: + A (torch.Tensor): a tensor that needs to be quantized + mx_specs (dict): dictionary to specify mx_specs + Returns: - quantized value {PyTorch tensor} -- a tensor that has been quantized + torch.Tensor: tensor that has been quantized """ if mx_specs is None: return A @@ -530,7 +560,7 @@ def _quantize_mx( def quantize_mx_op( - A, + A: torch.Tensor, elem_format: str, round: str, block_size: int, @@ -538,6 +568,7 @@ def quantize_mx_op( axes=None, expand_and_reshape=False, ): + """Quantize tensor to MX data type.""" if elem_format is None: return A elif type(elem_format) is str: