     _get_reduction_params,
     choose_qparams_affine_with_min_max,
     MappingType,
+    Granularity,
+    PerAxis,
+    PerRow,
+    PerTensor,
     ZeroPointDomain,
 )
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

 from abc import ABCMeta, abstractmethod
-from dataclasses import dataclass
 from typing import Tuple, Optional, Any
 from functools import partial
 import logging


 logger = logging.getLogger(__name__)


-@dataclass(frozen=True)
-class GranularityType:
-    """
-    Base class for representing the granularity of quantization.
-
-    This class serves as a parent for specific granularity types used in
-    quantization operations, such as per-tensor or per-axis quantization.
-    """
-    pass
-
-@dataclass(frozen=True)
-class PerTensor(GranularityType):
-    """
-    Represents per-tensor granularity in quantization.
-
-    This granularity type calculates the quantization parameters
-    based on the entire tensor.
-    """
-    pass
-
-@dataclass(frozen=True)
-class PerAxis(GranularityType):
-    """
-    Represents per-axis granularity in quantization.
-
-    This granularity type calculates different quantization parameters
-    along a specified axis of the tensor.
-
-    For example, if the input tensor has shape [8, 16] and axis=0, then
-    the quantization parameters are calculated for each row of the tensor,
-    giving a total of 8 quantization parameters.
-
-    Attributes:
-        axis (int): The axis along which reduction is performed.
-    """
-    axis: int
-
-@dataclass(frozen=True)
-class PerGroup(GranularityType):
-    """
-    Represents per-channel group granularity in quantization.
-
-    This granularity type calculates different quantization parameters
-    for each group of <group_size> elements.
-
-    For example, if the input tensor has shape [8, 16] and the group size is 4,
-    the tensor is reshaped to [32, 4] and quantization parameters are calculated
-    for each group of 4 elements, giving a total of 32 quantization parameters.
-
-    Attributes:
-        group_size (int): The size of each quantization group.
-    """
-    group_size: int
-
-class PerRow(GranularityType):
-    """
-    Represents row-wise granularity in quantization.
-
-    This is a special case of per-axis quantization and is unique to Float8 matmuls,
-    where the input is quantized with a block_size of (1, ..., input.shape[-1]) and
-    the weight is quantized with a block_size of (1, weight.shape[1]).
-    """
-    pass
-
 # borrowed from torch.ao.quantization.observer
 class _PartialWrapper:
     def __init__(self, p):
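For orientation: the classes removed above are not deleted from the codebase but re-imported (see the `Granularity`, `PerAxis`, `PerRow`, `PerTensor` additions to the import block at the top of this diff). A minimal sketch of the parameter counts the old docstrings describe, using plain tensor arithmetic rather than any torchao API:

```python
import torch

x = torch.randn(8, 16)

# PerTensor: one (scale, zero_point) pair for the whole tensor.
num_per_tensor = 1

# PerAxis(axis=0): one pair per slice along axis 0 -> 8 pairs for an [8, 16] tensor.
num_per_axis = x.shape[0]  # 8

# PerGroup(group_size=4): the tensor is viewed as [-1, 4] -> 128 / 4 = 32 groups.
num_per_group = x.reshape(-1, 4).shape[0]  # 32
```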
@@ -120,23 +55,23 @@ def _with_args(cls_or_self, *args, **kwargs):


 def get_block_size(
-    input_shape: Tuple[int, ...], granularity_type: GranularityType
+    input_shape: Tuple[int, ...], granularity: Granularity
 ) -> Tuple[int, ...]:
     """Get the block size based on the input shape and granularity type.

     Args:
         input_shape: The input tensor shape, possibly with more than 2 dimensions
-        granularity_type: The granularity type of the quantization
+        granularity: The granularity type of the quantization
     """
-    if isinstance(granularity_type, PerTensor):
+    if isinstance(granularity, PerTensor):
         return input_shape
-    elif isinstance(granularity_type, PerAxis):
+    elif isinstance(granularity, PerAxis):
         block_size = list(input_shape)
-        block_size[granularity_type.axis] = 1
+        block_size[granularity.axis] = 1
         return tuple(block_size)
-    elif isinstance(granularity_type, PerRow):
+    elif isinstance(granularity, PerRow):
         return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
-    raise ValueError(f"Unsupported GranularityType: {granularity_type}")
+    raise ValueError(f"Unsupported Granularity: {granularity}")


 ABC: Any = ABCMeta("ABC", (object,), {})  # compatible with Python 2 *and* 3:
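A quick sketch of what `get_block_size` returns for each supported granularity after this change. Names are used as they appear in the diff, so inside this module no extra imports are needed; calling it from outside would require importing `get_block_size` and the granularity classes, whose exact module paths are not shown in this excerpt:

```python
shape = (2, 8, 16)

get_block_size(shape, PerTensor())      # (2, 8, 16): one set of qparams for the whole tensor
get_block_size(shape, PerAxis(axis=1))  # (2, 1, 16): qparams vary along axis 1
get_block_size(shape, PerRow())         # (1, 1, 16): per-row qparams, block spans the last dim
```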
@@ -146,7 +81,7 @@ class AffineQuantizedObserverBase(ABC, torch.nn.Module):
14681 """Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization)
14782
14883 Args:
149- `granularity_type ` and `block_size`: The granularity of the quantization,
84+ `granularity ` and `block_size`: The granularity of the quantization,
15085 must specify at least one, if both are specified `block_size` takes precedence
15186 Current supported granularity type are `PerTensor` and `PerAxis`
15287 other args: please see `:class:torchao.dtypes.AffineQuantizedTensor`
@@ -158,7 +93,7 @@ def __init__(
         self,
         mapping_type: MappingType,
         target_dtype: torch.dtype,
-        granularity_type: GranularityType,
+        granularity: Granularity,
         quant_min: Optional[int] = None,
         quant_max: Optional[int] = None,
         eps: Optional[float] = None,
@@ -168,11 +103,11 @@ def __init__(
         zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
     ):
         super().__init__()
-        assert granularity_type is not None, "granularity_type is None"
+        assert granularity is not None, "granularity is None"

         self.mapping_type = mapping_type
         self.target_dtype = target_dtype
-        self.granularity_type = granularity_type
+        self.granularity = granularity
         self.quant_min = quant_min
         self.quant_max = quant_max
         self.eps = eps
@@ -202,8 +137,8 @@ def forward(self, input: torch.Tensor):
             return input

         input_detached = input.detach()
-        assert self.granularity_type is not None, "granularity_type is None"
-        block_size = get_block_size(input_detached.shape, self.granularity_type)
+        assert self.granularity is not None, "granularity is None"
+        block_size = get_block_size(input_detached.shape, self.granularity)

         shape_for_reduction, reduction_dims = _get_reduction_params(
             block_size, input_detached.size()
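End-to-end, callers now pass `granularity=` where they previously passed `granularity_type=`. A hedged usage sketch, assuming a concrete subclass such as `AffineQuantizedMinMaxObserver` exists elsewhere in this file with the base constructor signature, and that `MappingType` and `PerAxis` are importable from the paths shown below (both import paths are assumptions, not confirmed by this excerpt):

```python
import torch
from torchao.quantization.observer import AffineQuantizedMinMaxObserver  # assumed subclass
from torchao.quantization.quant_primitives import MappingType, PerAxis   # assumed import path

# `granularity=` replaces the old `granularity_type=` keyword.
obs = AffineQuantizedMinMaxObserver(
    mapping_type=MappingType.SYMMETRIC,
    target_dtype=torch.int8,
    granularity=PerAxis(axis=0),
    eps=torch.finfo(torch.float32).eps,
)

obs(torch.randn(8, 16))                      # forward() records per-block statistics
scale, zero_point = obs.calculate_qparams()  # one set of qparams per row for axis=0
```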