Commit cf3234c

Add decorator for custom op and inductor decomp registration
Summary:
This PR adds a decorator that registers a custom op and also an inductor decomposition for it. The goal is for the torch.export path to see high-level ops such as quantize_affine instead of their decomposed form, because some backends (e.g. XNNPACK) want to work with these higher-level ops.

Test Plan:
Regression tests:
`python test/quantization/test_quant_api.py`
`python test/integration/test_integration.py`

Also check performance with:
`python tutorials/quantize_vit/run_vit_b_quant.py`

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent e6460c2 commit cf3234c
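
The diff below does not include the decorator itself, so here is a minimal sketch of the registration pattern the summary describes, assuming a recent PyTorch where `torch.library.Library` and inductor's `register_decomposition` are available. The namespace, schema, and op body are illustrative stand-ins, not torchao's exact API:

```python
import torch
from torch._inductor.decomposition import register_decomposition

# Illustrative namespace; torchao uses its own library object.
_lib = torch.library.Library("my_quant", "DEF")

def register_custom_op(lib, schema):
    """Register `fn` twice: as a dispatcher op, so torch.export records a
    single high-level node instead of tracing through the Python body, and
    as an inductor decomposition, so torch.compile can still inline the
    implementation when lowering."""
    def decorator(fn):
        op_name = schema.split("(")[0]
        lib.define(schema)
        # One implementation shared by all backends.
        lib.impl(op_name, fn, "CompositeImplicitAutograd")
        op = getattr(getattr(torch.ops, lib.ns), op_name)
        # Teach inductor to decompose the op back into the Python body.
        register_decomposition([op])(fn)
        return op  # calls now route through the dispatcher
    return decorator

@register_custom_op(_lib, "quantize_affine(Tensor input, float scale, int zero_point) -> Tensor")
def quantize_affine(input, scale, zero_point):
    # Toy affine quantization body, a stand-in for the real primitive.
    return torch.clamp(torch.round(input / scale) + zero_point, 0, 255).to(torch.uint8)
```

With this in place, exporting a module that calls `quantize_affine` should keep a single `my_quant.quantize_affine` node in the graph that backends like XNNPACK can pattern-match.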

File tree

7 files changed: +92, -94 lines changed

test/integration/test_integration.py

Lines changed: 1 addition & 2 deletions
@@ -37,7 +37,6 @@
     choose_qparams_affine,
     quantize_affine,
     dequantize_affine,
-    MappingType,
 )
 from torchao.quantization.utils import (
     dequantize_per_channel,
@@ -1436,7 +1435,7 @@ def test_get_model_size_aqt(self, api, test_device, test_dtype):
         api(model)
         size2 = torchao.utils.get_model_size_in_bytes(model)
         self.assertTrue(size2 < size)
-        
+
 
 
 

test/quantization/test_quant_api.py

Lines changed: 0 additions & 4 deletions
@@ -22,10 +22,6 @@
 from torchao.dtypes import (
     AffineQuantizedTensor,
 )
-from torchao.quantization.quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
 from torchao.quantization.subclass import (
     LinearActQuantizedTensor,
     Int8WeightOnlyQuantizedLinearWeight,

torchao/dtypes/aqt.py

Lines changed: 9 additions & 11 deletions
@@ -6,8 +6,6 @@
     choose_qparams_affine,
     quantize_affine,
     dequantize_affine,
-    ZeroPointDomain,
-    MappingType,
     int_scaled_matmul,
 )
 from torchao.quantization.utils import (
@@ -98,12 +96,12 @@ class AffineQuantizedTensor(torch.Tensor):
       shape (torch.Size): the shape for the Tensor
       quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
       quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
-      zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
+      zero_point_domain (str): the domain that zero_point is in, should be either "int" or "float"
       if zero_point is in integer domain, zero point is added to the quantized integer value during
       quantization
       if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
       value during quantization
-      default is ZeroPointDomain.INT
+      default is "int"
      input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes float Tensor as input and outputs an AffineQuantizedTensor object
      dtype: dtype for external representation of the tensor, e.g. torch.float32
    """
@@ -116,7 +114,7 @@ def __new__(
         shape: torch.Size,
         quant_min: Optional[int] = None,
         quant_max: Optional[int] = None,
-        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+        zero_point_domain: str = "int",
         dtype=None,
         strides=None,
     ):
@@ -138,7 +136,7 @@ def __init__(
         shape: torch.Size,
         quant_min: Optional[int] = None,
         quant_max: Optional[int] = None,
-        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+        zero_point_domain: str = "int",
         dtype=None,
         strides=None,
     ):
@@ -184,7 +182,7 @@ def __tensor_unflatten__(
     def from_float(
         cls,
         input_float: torch.Tensor,
-        mapping_type: MappingType,
+        mapping_type: str,
         block_size: Tuple[int, ...],
         target_dtype: torch.dtype,
         quant_min: Optional[int] = None,
@@ -193,7 +191,7 @@ def from_float(
         scale_dtype: Optional[torch.dtype] = None,
         zero_point_dtype: Optional[torch.dtype] = None,
         preserve_zero: bool = True,
-        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+        zero_point_domain: str = "int",
         extended_layout: str = "plain",
         # TODO: this is only for "tensor_core_tiled", need to figure out
         # the proper API for this arg
@@ -520,7 +518,7 @@ def get_plain(self):
         target_dtype = torch.int32
         quant_min = 0
         quant_max = 15
-        zero_point_domain = ZeroPointDomain.FLOAT
+        zero_point_domain = "float"
         assert len(block_size) == 2 and block_size[0] == 1
         groupsize = block_size[-1]
         dequantized = torch.ops.aten._weight_int4pack_mm(torch.eye(eye_shape, device=device, dtype=original_dtype), self.packed_weight, groupsize, self.scale_and_zero)
@@ -597,7 +595,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
             weight_is_uint4 and
             weight_qtensor.dtype == torch.bfloat16 and
             len(weight_qtensor.shape) == 2 and
-            weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
+            weight_qtensor.zero_point_domain == "float" and
             weight_qtensor.extended_layout == "tensor_core_tiled"
         ):
             assert weight_qtensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}"
@@ -640,7 +638,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
             len(weight_qtensor.block_size) == 2 and
             weight_qtensor.block_size[0] == 1 and
             weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
-            weight_qtensor.zero_point_domain == ZeroPointDomain.INT and
+            weight_qtensor.zero_point_domain == "int" and
             weight_qtensor.extended_layout == "plain"
         ):
             # TODO: enable cpu and mps efficient path
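
For intuition, here is a toy rendering of the two zero-point conventions that the docstring above distinguishes, now selected by the strings "int" and "float". This is schematic math following the docstring's description, not torchao's exact kernels:

```python
import torch

def quantize_zp_int(x, scale, zero_point, quant_min=0, quant_max=15):
    # "int" domain: the zero point is added to the quantized integer value.
    return torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)

def quantize_zp_float(x, scale, zero_point, quant_min=0, quant_max=15):
    # "float" domain: the zero point is subtracted from the unquantized
    # floating point value before scaling.
    return torch.clamp(torch.round((x - zero_point) / scale), quant_min, quant_max)
```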

torchao/quantization/quant_api.py

Lines changed: 9 additions & 13 deletions
@@ -31,10 +31,6 @@
     to_linear_act_quantized,
 )
 
-from .quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
 from .weight_only import WeightOnlyInt8QuantLinear
 from .unified import Quantizer, TwoStepQuantizer
 from .GPTQ import (
@@ -272,15 +268,15 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Union[str, Callable[
 
         # weight settings
         groupsize = 32
-        mapping_type = MappingType.ASYMMETRIC
+        mapping_type = "asymmetric"
         block_size = (1, groupsize)
         target_dtype = torch.int32
         quant_min = 0
         quant_max = 15
         eps = 1e-6
         preserve_zero = False
         zero_point_dtype = torch.bfloat16
-        zero_point_domain = ZeroPointDomain.FLOAT
+        zero_point_domain = "float"
 
         apply_weight_quant = lambda x: to_affine_quantized(
             x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps,
@@ -321,7 +317,7 @@ def apply_8da4w_quant(weight):
         from torchao.dtypes import to_affine_quantized
 
         # weight settings
-        mapping_type = MappingType.SYMMETRIC
+        mapping_type = "symmetric"
         block_size = (1, groupsize)
         target_dtype = torch.int8
         eps = torch.finfo(torch.float32).eps
@@ -338,7 +334,7 @@ def get_per_token_block_size(x):
             return block_size
 
         # input settings
-        input_mapping_type = MappingType.ASYMMETRIC
+        input_mapping_type = "asymmetric"
         input_target_dtype = torch.int8
         input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
 
@@ -363,15 +359,15 @@ def apply_int4wo_quant(weight):
         # avoid circular dep
         from torchao.dtypes import to_affine_quantized
 
-        mapping_type = MappingType.ASYMMETRIC
+        mapping_type = "asymmetric"
         block_size = (1, groupsize)
         target_dtype = torch.int32
         quant_min = 0
         quant_max = 15
         eps = 1e-6
         preserve_zero = False
         zero_point_dtype = torch.bfloat16
-        zero_point_domain = ZeroPointDomain.FLOAT
+        zero_point_domain = "float"
         return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, extended_layout="tensor_core_tiled", inner_k_tiles=inner_k_tiles)
 
     return apply_int4wo_quant
@@ -385,7 +381,7 @@ def apply_int8wo_quant(weight):
         # avoid circular dep
         from torchao.dtypes import to_affine_quantized
 
-        mapping_type = MappingType.SYMMETRIC
+        mapping_type = "symmetric"
         target_dtype = torch.int8
         eps = torch.finfo(torch.float32).eps
         zero_point_dtype = torch.int64
@@ -407,7 +403,7 @@ def apply_int8dyn_quant(weight):
         # avoid circular dep
         from torchao.dtypes import to_affine_quantized
         # weight settings
-        mapping_type = MappingType.SYMMETRIC
+        mapping_type = "symmetric"
         def get_weight_block_size(x):
             return (1, x.shape[1])
         target_dtype = torch.int8
@@ -421,7 +417,7 @@ def get_per_token_block_size(x):
                 block_size[i] = 1
             return block_size
 
-        input_mapping_type = MappingType.SYMMETRIC
+        input_mapping_type = "symmetric"
         input_target_dtype = torch.int8
         input_eps = 1e-5
         input_quant_min = -127