Skip to content

Commit 2a2817c

Browse files
committed
Move and rename GranularityType -> Granularity
Summary: Move GranularityType to quant_primitives.py and rename it to Granularity, to be consistent with other similar fields like MappingType and ZeroPointDomain. Test Plan: CI. ghstack-source-id: b12bce5. Pull Request resolved: #1038.
1 parent 85c7e9a commit 2a2817c

File tree

14 files changed

+123
-110
lines changed

14 files changed

+123
-110
lines changed

test/dtypes/test_affine_quantized_float.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,15 @@
2626
float8_weight_only,
2727
quantize_,
2828
)
29-
from torchao.quantization.observer import PerRow, PerTensor
3029
from torchao.quantization.quant_api import (
3130
float8_static_activation_float8_weight,
3231
)
33-
from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine
32+
from torchao.quantization.quant_primitives import (
33+
MappingType,
34+
PerRow,
35+
PerTensor,
36+
choose_qparams_affine,
37+
)
3438

3539
random.seed(0)
3640
torch.manual_seed(0)

test/quantization/test_observer.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@
1111

1212
from torchao.quantization.observer import (
1313
AffineQuantizedMinMaxObserver,
14-
PerAxis,
15-
PerTensor,
1614
)
1715
from torchao.quantization.quant_api import (
1816
insert_observers_,
1917
)
2018
from torchao.quantization.quant_primitives import (
2119
MappingType,
20+
PerAxis,
21+
PerTensor,
2222
)
2323

2424

@@ -42,7 +42,7 @@ def test_min_max_per_tensor_affine(self):
4242
obs = AffineQuantizedMinMaxObserver(
4343
MappingType.ASYMMETRIC,
4444
torch.uint8,
45-
granularity_type=PerTensor(),
45+
granularity=PerTensor(),
4646
eps=torch.finfo(torch.float32).eps,
4747
scale_dtype=torch.float,
4848
zero_point_dtype=torch.int,
@@ -54,7 +54,7 @@ def test_min_max_per_channel_affine(self):
5454
obs = AffineQuantizedMinMaxObserver(
5555
MappingType.ASYMMETRIC,
5656
torch.uint8,
57-
granularity_type=PerAxis(axis=0),
57+
granularity=PerAxis(axis=0),
5858
eps=torch.finfo(torch.float32).eps,
5959
scale_dtype=torch.float,
6060
zero_point_dtype=torch.int,
@@ -68,7 +68,7 @@ def test_block_size_calc_success(self):
6868
obs = AffineQuantizedMinMaxObserver(
6969
MappingType.SYMMETRIC,
7070
torch.float8_e4m3fn,
71-
granularity_type=PerTensor(),
71+
granularity=PerTensor(),
7272
eps=torch.finfo(torch.float32).eps,
7373
scale_dtype=torch.float,
7474
zero_point_dtype=torch.int,
@@ -87,7 +87,7 @@ def test_block_size_calc_success(self):
8787
obs = AffineQuantizedMinMaxObserver(
8888
MappingType.SYMMETRIC,
8989
torch.float8_e4m3fn,
90-
granularity_type=PerAxis(1),
90+
granularity=PerAxis(1),
9191
eps=torch.finfo(torch.float32).eps,
9292
scale_dtype=torch.float,
9393
zero_point_dtype=torch.int,
@@ -102,7 +102,7 @@ def test_block_size_row_errors(self):
102102
obs = AffineQuantizedMinMaxObserver(
103103
MappingType.SYMMETRIC,
104104
torch.float8_e4m3fn,
105-
granularity_type=PerAxis(0),
105+
granularity=PerAxis(0),
106106
eps=torch.finfo(torch.float32).eps,
107107
scale_dtype=torch.float,
108108
zero_point_dtype=torch.int,
@@ -121,7 +121,7 @@ def test_block_size_row_errors(self):
121121
obs = AffineQuantizedMinMaxObserver(
122122
MappingType.SYMMETRIC,
123123
torch.float8_e4m3fn,
124-
granularity_type=PerAxis(1),
124+
granularity=PerAxis(1),
125125
eps=torch.finfo(torch.float32).eps,
126126
scale_dtype=torch.float,
127127
zero_point_dtype=torch.int,
@@ -149,7 +149,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
149149
input_observer = AffineQuantizedMinMaxObserver(
150150
MappingType.SYMMETRIC,
151151
torch.float8_e4m3fn,
152-
granularity_type=PerTensor(),
152+
granularity=PerTensor(),
153153
eps=torch.finfo(torch.float32).eps,
154154
scale_dtype=torch.float,
155155
zero_point_dtype=torch.int,
@@ -159,7 +159,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
159159
weight_observer = AffineQuantizedMinMaxObserver(
160160
MappingType.SYMMETRIC,
161161
torch.float8_e4m3fn,
162-
granularity_type=PerTensor(),
162+
granularity=PerTensor(),
163163
eps=torch.finfo(torch.float32).eps,
164164
scale_dtype=torch.float,
165165
zero_point_dtype=torch.int,

torchao/_models/llama/eval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
float8_dynamic_activation_float8_weight,
2525
float8_static_activation_float8_weight,
2626
)
27-
from torchao.quantization.observer import PerRow, PerTensor
2827
from torchao._models._eval import TransformerEvalWrapper, InputRecorder
2928
from torchao._models.llama.model import prepare_inputs_for_model
29+
from torchao.quantization.quant_primitives import PerRow, PerTensor
3030

3131
from tokenizer import get_tokenizer
3232
import time
@@ -255,4 +255,4 @@ def run_evaluation(
255255
args.calibration_limit,
256256
args.calibration_seq_length,
257257
args.pad_calibration_inputs,
258-
)
258+
)

torchao/_models/llama/generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def main(
216216
float8_weight_only,
217217
float8_dynamic_activation_float8_weight,
218218
)
219-
from torchao.quantization.observer import PerTensor, PerRow
219+
from torchao.quantization.quant_primitives import PerTensor, PerRow
220220
if "int8wo" in quantization:
221221
quantize_(model, int8_weight_only())
222222
if "int8dq" in quantization:

torchao/prototype/awq/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33

44
from torchao.quantization.quant_primitives import (
55
MappingType,
6+
PerGroup,
67
ZeroPointDomain,
78
_DTYPE_TO_QVALUE_BOUNDS,
89
)
910
from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata
10-
from torchao.quantization.observer import PerGroup
1111
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
1212
from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayoutType
1313
from torchao.dtypes import(

torchao/prototype/awq/core.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,19 @@
99
from torchao.dtypes import to_affine_quantized_intx
1010
from torchao.quantization.quant_primitives import (
1111
MappingType,
12+
Granularity,
1213
ZeroPointDomain,
1314
)
1415
from torchao.quantization.observer import (
15-
AffineQuantizedObserverBase, GranularityType
16+
AffineQuantizedObserverBase,
1617
)
1718

1819

1920
class AWQObserver(AffineQuantizedObserverBase):
2021
def __init__(self,
2122
weight: torch.Tensor,
2223
bias: torch.Tensor,
23-
quantization_granularity: GranularityType,
24+
quantization_granularity: Granularity,
2425
mapping_type: MappingType,
2526
target_dtype: torch.dtype,
2627
n_validation_examples: int,
@@ -40,7 +41,7 @@ def __init__(self,
4041
Args:
4142
weight: The weight tensor to be observed.
4243
bias: The bias tensor to be observed.
43-
quantization_granularity: Granularity type which specifies how many weights share the same scale/zero point
44+
quantization_granularity: Granularity which specifies how many weights share the same scale/zero point
4445
input_dtype: The data type of the input tensor.
4546
mapping_type: Always set to asymmetric
4647
target_dtype: The target data type of the quantized tensor
@@ -153,4 +154,4 @@ def from_float(cls, float_linear: torch.nn.Linear, act_obs: AWQObserver):
153154
observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype)
154155
observed_linear.weight = float_linear.weight
155156
observed_linear.bias = float_linear.bias
156-
return observed_linear
157+
return observed_linear

torchao/quantization/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ change_linear_weights_to_int8_dqtensors(model)
137137
```python
138138
# for torch 2.4+
139139
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
140-
from torchao.quantization.observer import PerTensor
140+
from torchao.quantization.quant_api import PerTensor
141141
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
142142
```
143143

torchao/quantization/autoquant.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
1414
from torch.utils._python_dispatch import return_and_correct_aliasing
1515
from .quant_primitives import (
16+
PerAxis,
17+
PerRow,
18+
PerTensor,
1619
safe_int_mm,
1720
)
1821
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5
1922
from torchao.quantization.utils import quantize_activation_per_token_absmax
20-
from torchao.quantization.observer import PerAxis, PerTensor, PerRow
2123
from torchao.float8.inference import Float8MMConfig
2224

2325
import torch.nn.functional as F

torchao/quantization/observer.py

Lines changed: 17 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -3,87 +3,22 @@
33
_get_reduction_params,
44
choose_qparams_affine_with_min_max,
55
MappingType,
6+
Granularity,
7+
PerAxis,
8+
PerRow,
9+
PerTensor,
610
ZeroPointDomain,
711
)
812
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
913

1014
from abc import ABCMeta, abstractmethod
11-
from dataclasses import dataclass
1215
from typing import Tuple, Optional, Any
1316
from functools import partial
1417
import logging
1518

1619
logger = logging.getLogger(__name__)
1720

1821

19-
@dataclass(frozen=True)
20-
class GranularityType:
21-
"""
22-
Base class for representing the granularity of quantization.
23-
24-
This class serves as a parent for specific granularity types used in
25-
quantization operations, such as per-tensor or per-axis quantization.
26-
"""
27-
pass
28-
29-
@dataclass(frozen=True)
30-
class PerTensor(GranularityType):
31-
"""
32-
Represents per-tensor granularity in quantization.
33-
34-
This granularity type calculates the quantization parameters
35-
based off the entire tensor.
36-
"""
37-
pass
38-
39-
@dataclass(frozen=True)
40-
class PerAxis(GranularityType):
41-
"""
42-
Represents per-axis granularity in quantization.
43-
44-
This granularity type calculates different quantization parameters
45-
along a specified axis of the tensor.
46-
47-
For example if the input tensor is shape [8, 16] and axis=0, then
48-
the quantization parameters are calculated for each row of the tensor.
49-
Giving a total of 8 quantization parameters.
50-
51-
52-
Attributes:
53-
axis (int): The axis along which reduction is performed.
54-
"""
55-
axis: int
56-
57-
@dataclass(frozen=True)
58-
59-
class PerGroup(GranularityType):
60-
"""
61-
Represents per-channel group granularity in quantization.
62-
63-
This granularity type calculates different quantization parameters
64-
for each group of <group_size> elements.
65-
66-
For example if the input tensor is shape [8, 16], and the group size is 4, then
67-
the input tensor is reshaped to [32, 4]
68-
quantization parameters are calculated for each group of 4 elements,
69-
giving a total of 32 quantization parameters.
70-
71-
Attributes:
72-
group_size (int): The size of each quantization group
73-
74-
"""
75-
group_size: int
76-
77-
class PerRow(GranularityType):
78-
"""
79-
Represents row-wise granularity in quantization.
80-
81-
This is a special case of per-axis quantization and is unique to Float8 matmuls
82-
where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight
83-
is quantized with a block_size of (1, weight.shape[1]).
84-
"""
85-
pass
86-
8722
# borrowed from torch.ao.quantization.observer
8823
class _PartialWrapper:
8924
def __init__(self, p):
@@ -120,23 +55,23 @@ def _with_args(cls_or_self, *args, **kwargs):
12055

12156

12257
def get_block_size(
123-
input_shape: Tuple[int, ...], granularity_type: GranularityType
58+
input_shape: Tuple[int, ...], granularity: Granularity
12459
) -> Tuple[int, ...]:
12560
"""Get the block size based on the input shape and granularity type.
12661
12762
Args:
12863
input_shape: The input tensor shape possibly more than 2 dimensions
129-
granularity_type: The granularity type of the quantization
64+
granularity: The granularity type of the quantization
13065
"""
131-
if isinstance(granularity_type, PerTensor):
66+
if isinstance(granularity, PerTensor):
13267
return input_shape
133-
elif isinstance(granularity_type, PerAxis):
68+
elif isinstance(granularity, PerAxis):
13469
block_size = list(input_shape)
135-
block_size[granularity_type.axis] = 1
70+
block_size[granularity.axis] = 1
13671
return tuple(block_size)
137-
elif isinstance(granularity_type, PerRow):
72+
elif isinstance(granularity, PerRow):
13873
return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
139-
raise ValueError(f"Unsupported GranularityType: {granularity_type}")
74+
raise ValueError(f"Unsupported Granularity: {granularity}")
14075

14176

14277
ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3:
@@ -146,7 +81,7 @@ class AffineQuantizedObserverBase(ABC, torch.nn.Module):
14681
"""Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization)
14782
14883
Args:
149-
`granularity_type` and `block_size`: The granularity of the quantization,
84+
`granularity` and `block_size`: The granularity of the quantization,
15085
must specify at least one, if both are specified `block_size` takes precedence
15186
Currently supported granularity types are `PerTensor` and `PerAxis`
15287
other args: please see `:class:torchao.dtypes.AffineQuantizedTensor`
@@ -158,7 +93,7 @@ def __init__(
15893
self,
15994
mapping_type: MappingType,
16095
target_dtype: torch.dtype,
161-
granularity_type: GranularityType,
96+
granularity: Granularity,
16297
quant_min: Optional[int] = None,
16398
quant_max: Optional[int] = None,
16499
eps: Optional[float] = None,
@@ -168,11 +103,11 @@ def __init__(
168103
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
169104
):
170105
super().__init__()
171-
assert granularity_type is not None, "granularity_type is None"
106+
assert granularity is not None, "granularity is None"
172107

173108
self.mapping_type = mapping_type
174109
self.target_dtype = target_dtype
175-
self.granularity_type = granularity_type
110+
self.granularity = granularity
176111
self.quant_min = quant_min
177112
self.quant_max = quant_max
178113
self.eps = eps
@@ -202,8 +137,8 @@ def forward(self, input: torch.Tensor):
202137
return input
203138

204139
input_detached = input.detach()
205-
assert self.granularity_type is not None, "granularity_type is None"
206-
block_size = get_block_size(input_detached.shape, self.granularity_type)
140+
assert self.granularity is not None, "granularity is None"
141+
block_size = get_block_size(input_detached.shape, self.granularity)
207142

208143
shape_for_reduction, reduction_dims = _get_reduction_params(
209144
block_size, input_detached.size()

torchao/quantization/quant_api.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@
5454

5555
from .quant_primitives import (
5656
MappingType,
57+
PerRow,
58+
PerTensor,
5759
ZeroPointDomain,
5860
)
5961
from .weight_only import WeightOnlyInt8QuantLinear
@@ -71,7 +73,7 @@
7173
)
7274
from torchao.float8.inference import Float8MMConfig
7375

74-
from torchao.quantization.observer import PerTensor, PerRow, get_block_size
76+
from torchao.quantization.observer import get_block_size
7577

7678
logger = logging.getLogger(__name__)
7779

0 commit comments

Comments
 (0)