Commit fbc0259

Update on "Add generic fake quantized linear for QAT"
**Summary:** This commit adds a generic fake quantized linear module to replace uses of the existing, more specific QAT linears. For example, `Int8DynActInt4WeightQATLinear` can be expressed as follows:

```
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear

activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=8)
fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config)
```

The main motivation is to provide a more flexible way to perform QAT on models with linear layers. Previously, we had to create a new linear class every time we wanted to experiment with different fake quantization settings, e.g. a different group size or bit width. Now we can express this easily with a single linear module; see the sketch after this message for an illustrative variation.

**Test Plan:**

```
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity_error_cases
python test/quantization/test_qat.py -k test_fake_quantize_config_mapping_type
python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w
python test/quantization/test_qat.py -k test_fake_quantized_linear_4w
```

[ghstack-poisoned]
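As a sketch of the flexibility claim above: changing the fake quantization settings now only means changing a config value rather than writing a new linear subclass. The group size and layer shape below are illustrative values, not taken from the commit:

```python
import torch

from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear

# Same int8 per-token activation fake quantization as in the summary,
# but int4 weights with a larger (illustrative) group size of 32.
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)

# in_features=64, out_features=128, bias=False -- arbitrary example shapes.
fq_linear = FakeQuantizedLinear(64, 128, False, activation_config, weight_config)
```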
2 parents: 83e2f10 + 5b4feb0

1 file changed (+13, −14 lines)


torchao/quantization/quant_primitives.py

```diff
@@ -106,14 +106,24 @@ class TorchAODType(Enum):
     TorchAODType.INT5: 5,
     TorchAODType.INT6: 6,
     TorchAODType.INT7: 7,
+    torch.uint8: 8,
     torch.int8: 8,
     torch.int16: 16,
     torch.int32: 32,
 }
 
 _SUB_BYTE_UINT_BOUNDS: Dict[Union[torch.dtype, TorchAODType], Tuple[int, int]] = {}
-_SUB_BYTE_INT_BOUNDS: Dict[Union[torch.dtype, TorchAODType], Tuple[int, int]] = {}
+_SUB_BYTE_INT_BOUNDS: Dict[Union[torch.dtype, TorchAODType], Tuple[int, int]] = {
+    TorchAODType.INT1: (-(2**0), 2**0 - 1),
+    TorchAODType.INT2: (-(2**1), 2**1 - 1),
+    TorchAODType.INT3: (-(2**2), 2**2 - 1),
+    TorchAODType.INT4: (-(2**3), 2**3 - 1),
+    TorchAODType.INT5: (-(2**4), 2**4 - 1),
+    TorchAODType.INT6: (-(2**5), 2**5 - 1),
+    TorchAODType.INT7: (-(2**6), 2**6 - 1),
+}
 
+# torch.uintX available only in PyTorch 2.3+
 if TORCH_VERSION_AT_LEAST_2_3:
     _SUB_BYTE_UINT_BOUNDS = {
         torch.uint1: (0, 2**1-1),
@@ -124,18 +134,6 @@ class TorchAODType(Enum):
         torch.uint6: (0, 2**6-1),
         torch.uint7: (0, 2**7-1),
     }
-    _SUB_BYTE_INT_BOUNDS = {
-        TorchAODType.INT1: (-(2**0), 2**0 - 1),
-        TorchAODType.INT2: (-(2**1), 2**1 - 1),
-        TorchAODType.INT3: (-(2**2), 2**2 - 1),
-        TorchAODType.INT4: (-(2**3), 2**3 - 1),
-        TorchAODType.INT5: (-(2**4), 2**4 - 1),
-        TorchAODType.INT6: (-(2**5), 2**5 - 1),
-        TorchAODType.INT7: (-(2**6), 2**6 - 1),
-    }
-    _DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_UINT_BOUNDS)
-    _DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_INT_BOUNDS)
-
     _DTYPE_TO_BIT_WIDTH.update({
         torch.uint1: 1,
         torch.uint2: 2,
@@ -144,9 +142,10 @@ class TorchAODType(Enum):
         torch.uint5: 5,
         torch.uint6: 6,
         torch.uint7: 7,
-        torch.uint8: 8,
     })
 
+_DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_UINT_BOUNDS)
+_DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_INT_BOUNDS)
 assert _DTYPE_TO_BIT_WIDTH.keys() == _DTYPE_TO_QVALUE_BOUNDS.keys()
 
 _ONES_TABLE = [_n_ones(i) for i in range(8)]
```
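For reference, each `_SUB_BYTE_INT_BOUNDS` entry in the diff is the two's-complement range of a b-bit signed integer, [-2^(b-1), 2^(b-1) - 1]. A minimal standalone check (plain Python, no torchao imports needed):

```python
# Reproduce the signed sub-byte bounds from the diff above.
for bits in range(1, 8):
    lo, hi = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    print(f"INT{bits}: ({lo}, {hi})")
# e.g. INT4 -> (-8, 7), matching TorchAODType.INT4: (-(2**3), 2**3 - 1)
```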
