Skip to content

Commit fd69fc4

Browse files
committed
Move and rename GranularityType -> Granularity
Summary: Move GranularityType to quant_primitives.py and rename it to Granularity, to be consistent with other similar fields like MappingType and ZeroPointDomain.
Test Plan: CI
ghstack-source-id: e9c8552
Pull Request resolved: #1038
1 parent 85c7e9a commit fd69fc4

File tree

3 files changed

+88
-86
lines changed

3 files changed

+88
-86
lines changed

torchao/prototype/awq/core.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,19 @@
99
from torchao.dtypes import to_affine_quantized_intx
1010
from torchao.quantization.quant_primitives import (
1111
MappingType,
12+
Granularity,
1213
ZeroPointDomain,
1314
)
1415
from torchao.quantization.observer import (
15-
AffineQuantizedObserverBase, GranularityType
16+
AffineQuantizedObserverBase,
1617
)
1718

1819

1920
class AWQObserver(AffineQuantizedObserverBase):
2021
def __init__(self,
2122
weight: torch.Tensor,
2223
bias: torch.Tensor,
23-
quantization_granularity: GranularityType,
24+
quantization_granularity: Granularity,
2425
mapping_type: MappingType,
2526
target_dtype: torch.dtype,
2627
n_validation_examples: int,
@@ -40,7 +41,7 @@ def __init__(self,
4041
Args:
4142
weight: The weight tensor to be observed.
4243
bias: The bias tensor to be observed.
43-
quantization_granularity: Granularity type which specifies how many weights share the same scale/zero point
44+
quantization_granularity: Granularity which specifies how many weights share the same scale/zero point
4445
input_dtype: The data type of the input tensor.
4546
mapping_type: Always set to asymmetric
4647
target_dtype: The target data type of the quantized tensor
@@ -153,4 +154,4 @@ def from_float(cls, float_linear: torch.nn.Linear, act_obs: AWQObserver):
153154
observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype)
154155
observed_linear.weight = float_linear.weight
155156
observed_linear.bias = float_linear.bias
156-
return observed_linear
157+
return observed_linear

torchao/quantization/observer.py

Lines changed: 14 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -3,87 +3,19 @@
33
_get_reduction_params,
44
choose_qparams_affine_with_min_max,
55
MappingType,
6+
Granularity,
67
ZeroPointDomain,
78
)
89
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
910

1011
from abc import ABCMeta, abstractmethod
11-
from dataclasses import dataclass
1212
from typing import Tuple, Optional, Any
1313
from functools import partial
1414
import logging
1515

1616
logger = logging.getLogger(__name__)
1717

1818

19-
@dataclass(frozen=True)
20-
class GranularityType:
21-
"""
22-
Base class for representing the granularity of quantization.
23-
24-
This class serves as a parent for specific granularity types used in
25-
quantization operations, such as per-tensor or per-axis quantization.
26-
"""
27-
pass
28-
29-
@dataclass(frozen=True)
30-
class PerTensor(GranularityType):
31-
"""
32-
Represents per-tensor granularity in quantization.
33-
34-
This granularity type calcualtes the quantization parameters
35-
based off the entire tensor.
36-
"""
37-
pass
38-
39-
@dataclass(frozen=True)
40-
class PerAxis(GranularityType):
41-
"""
42-
Represents per-axis granularity in quantization.
43-
44-
This granularity type calcualtes different quantization parameters
45-
along a specified axis of the tensor.
46-
47-
For example if the input tensor is shape [8, 16] and axis=0, then
48-
the quantization parameters are calculated for each row of the tensor.
49-
Giving a total of 8 quantization parameters.
50-
51-
52-
Attributes:
53-
axis (int): The axis along which reduction is performed.
54-
"""
55-
axis: int
56-
57-
@dataclass(frozen=True)
58-
59-
class PerGroup(GranularityType):
60-
"""
61-
Represents per-channel group granularity in quantization.
62-
63-
This granularity type calcualtes different quantization parameters
64-
for each group of <group_size> elements.
65-
66-
For example if the input tensor is shape [8, 16], and the group size is 4, then
67-
the input tensor is reshaped to [64, 4]
68-
quantization parameters are calculated for each group of 4 elements,
69-
giving a total of 64 quantization parameters.
70-
71-
Attributes:
72-
group_size (int): The size of each quantization group
73-
74-
"""
75-
group_size: int
76-
77-
class PerRow(GranularityType):
78-
"""
79-
Represents row-wise granularity in quantization.
80-
81-
This is a special case of per-axis quantization and is unique to Float8 matmuls
82-
where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight
83-
is quantized with a block_size of (1, weight.shape[1]).
84-
"""
85-
pass
86-
8719
# borrowed from torch.ao.quantization.observer
8820
class _PartialWrapper:
8921
def __init__(self, p):
@@ -120,23 +52,23 @@ def _with_args(cls_or_self, *args, **kwargs):
12052

12153

12254
def get_block_size(
123-
input_shape: Tuple[int, ...], granularity_type: GranularityType
55+
input_shape: Tuple[int, ...], granularity: Granularity
12456
) -> Tuple[int, ...]:
12557
"""Get the block size based on the input shape and granularity type.
12658
12759
Args:
12860
input_shape: The input tensor shape possibly more than 2 dimensions
129-
granularity_type: The granularity type of the quantization
61+
granularity: The granularity type of the quantization
13062
"""
131-
if isinstance(granularity_type, PerTensor):
63+
if isinstance(granularity, PerTensor):
13264
return input_shape
133-
elif isinstance(granularity_type, PerAxis):
65+
elif isinstance(granularity, PerAxis):
13466
block_size = list(input_shape)
135-
block_size[granularity_type.axis] = 1
67+
block_size[granularity.axis] = 1
13668
return tuple(block_size)
137-
elif isinstance(granularity_type, PerRow):
69+
elif isinstance(granularity, PerRow):
13870
return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
139-
raise ValueError(f"Unsupported GranularityType: {granularity_type}")
71+
raise ValueError(f"Unsupported Granularity: {granularity}")
14072

14173

14274
ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3:
@@ -146,7 +78,7 @@ class AffineQuantizedObserverBase(ABC, torch.nn.Module):
14678
"""Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization)
14779
14880
Args:
149-
`granularity_type` and `block_size`: The granularity of the quantization,
81+
`granularity` and `block_size`: The granularity of the quantization,
15082
must specify at least one, if both are specified `block_size` takes precedence
15183
Currently supported granularity types are `PerTensor` and `PerAxis`
15284
other args: please see `:class:torchao.dtypes.AffineQuantizedTensor`
@@ -158,7 +90,7 @@ def __init__(
15890
self,
15991
mapping_type: MappingType,
16092
target_dtype: torch.dtype,
161-
granularity_type: GranularityType,
93+
granularity: Granularity,
16294
quant_min: Optional[int] = None,
16395
quant_max: Optional[int] = None,
16496
eps: Optional[float] = None,
@@ -168,11 +100,11 @@ def __init__(
168100
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
169101
):
170102
super().__init__()
171-
assert granularity_type is not None, "granularity_type is None"
103+
assert granularity is not None, "granularity is None"
172104

173105
self.mapping_type = mapping_type
174106
self.target_dtype = target_dtype
175-
self.granularity_type = granularity_type
107+
self.granularity = granularity
176108
self.quant_min = quant_min
177109
self.quant_max = quant_max
178110
self.eps = eps
@@ -202,8 +134,8 @@ def forward(self, input: torch.Tensor):
202134
return input
203135

204136
input_detached = input.detach()
205-
assert self.granularity_type is not None, "granularity_type is None"
206-
block_size = get_block_size(input_detached.shape, self.granularity_type)
137+
assert self.granularity is not None, "granularity is None"
138+
block_size = get_block_size(input_detached.shape, self.granularity)
207139

208140
shape_for_reduction, reduction_dims = _get_reduction_params(
209141
block_size, input_detached.size()

torchao/quantization/quant_primitives.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from dataclasses import dataclass
78
from enum import Enum, auto
89
from typing import List, Optional, Tuple, Dict, Callable, Union
910
import torch, math
@@ -64,6 +65,74 @@ class ZeroPointDomain(Enum):
6465
INT = auto()
6566
FLOAT = auto()
6667

68+
@dataclass(frozen=True)
69+
class Granularity:
70+
"""
71+
Base class for representing the granularity of quantization.
72+
73+
This class serves as a parent for specific granularity types used in
74+
quantization operations, such as per-tensor or per-axis quantization.
75+
"""
76+
pass
77+
78+
@dataclass(frozen=True)
79+
class PerTensor(Granularity):
80+
"""
81+
Represents per-tensor granularity in quantization.
82+
83+
This granularity type calculates the quantization parameters
84+
based off the entire tensor.
85+
"""
86+
pass
87+
88+
@dataclass(frozen=True)
89+
class PerAxis(Granularity):
90+
"""
91+
Represents per-axis granularity in quantization.
92+
93+
This granularity type calculates different quantization parameters
94+
along a specified axis of the tensor.
95+
96+
For example if the input tensor is shape [8, 16] and axis=0, then
97+
the quantization parameters are calculated for each row of the tensor.
98+
This gives a total of 8 quantization parameters.
99+
100+
101+
Attributes:
102+
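axis (int): The axis along which separate quantization parameters are kept (one set per index along this axis); the reduction is performed over the remaining dimensions.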
103+
"""
104+
axis: int
105+
106+
@dataclass(frozen=True)
107+
108+
class PerGroup(Granularity):
109+
"""
110+
Represents per-channel group granularity in quantization.
111+
112+
This granularity type calculates different quantization parameters
113+
for each group of <group_size> elements.
114+
115+
For example if the input tensor is shape [8, 16], and the group size is 4, then
116+
the input tensor is reshaped to [32, 4] (8 * 16 = 128 elements, 4 per group),
117+
quantization parameters are calculated for each group of 4 elements,
118+
giving a total of 32 quantization parameters.
119+
120+
Attributes:
121+
group_size (int): The size of each quantization group
122+
123+
"""
124+
group_size: int
125+
126+
class PerRow(Granularity):
127+
"""
128+
Represents row-wise granularity in quantization.
129+
130+
This is a special case of per-axis quantization and is unique to Float8 matmuls
131+
where the input is quantized with a block_size of (1, ..., input.shape[-1]) and the weight
132+
is quantized with a block_size of (1, weight.shape[1]).
133+
"""
134+
pass
135+
67136
if TORCH_VERSION_AT_LEAST_2_5:
68137
torch.serialization.add_safe_globals([MappingType, ZeroPointDomain])
69138

0 commit comments

Comments
 (0)