Commit 83e2f10

Update on "Add generic fake quantized linear for QAT"
**Summary:** This commit adds a generic fake quantized linear module to replace the uses of the existing, more specific QAT linears. For example, `Int8DynActInt4WeightQATLinear` can be expressed as follows:

```
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear

activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=8)
fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config)
```

The main motivation is to provide a more flexible way to perform QAT on models with linear layers. Previously, we had to create a new linear class every time we wanted to experiment with different fake quantization settings, e.g. a different group size or bit width. Now we can express this easily using a single linear module.

**Test Plan:**

python test/quantization/test_qat.py -k test_fake_quantize_config_granularity
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity_error_cases
python test/quantization/test_qat.py -k test_fake_quantize_config_mapping_type
python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w
python test/quantization/test_qat.py -k test_fake_quantized_linear_4w

[ghstack-poisoned]
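As a concrete illustration of that flexibility, here is a sketch (not code from this commit): the constructor arguments follow the example above, while the shapes and the group size of 32 are illustrative assumptions.

```
import torch

from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear

# Same module, different settings: int8 per-token asymmetric activations
# as in the summary example, but int4 weights with group size 32 instead of 8.
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)

# (in_features=64, out_features=128, bias=False) -- illustrative values;
# in_features is chosen divisible by the weight group size.
fq_linear = FakeQuantizedLinear(64, 128, False, activation_config, weight_config)

x = torch.randn(2, 64)
y = fq_linear(x)  # forward pass applies fake quantization; output shape (2, 128)
```

No new linear subclass is needed for the new settings; only the two config objects change.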
2 parents: c0ed9ed + f9a2f4c

1 file changed: +24 additions, −20 deletions

torchao/quantization/quant_primitives.py

```diff
@@ -98,6 +98,19 @@ class TorchAODType(Enum):
     torch.int16: (-(2**15), 2**15 - 1),
     torch.int32: (-(2**31), 2**31 - 1),
 }
+_DTYPE_TO_BIT_WIDTH: Dict[Union[torch.dtype, TorchAODType], Tuple[int, int]] = {
+    TorchAODType.INT1: 1,
+    TorchAODType.INT2: 2,
+    TorchAODType.INT3: 3,
+    TorchAODType.INT4: 4,
+    TorchAODType.INT5: 5,
+    TorchAODType.INT6: 6,
+    TorchAODType.INT7: 7,
+    torch.int8: 8,
+    torch.int16: 16,
+    torch.int32: 32,
+}
+
 _SUB_BYTE_UINT_BOUNDS: Dict[Union[torch.dtype, TorchAODType], Tuple[int, int]] = {}
 _SUB_BYTE_INT_BOUNDS: Dict[Union[torch.dtype, TorchAODType], Tuple[int, int]] = {}
 
@@ -123,26 +136,17 @@ class TorchAODType(Enum):
 _DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_UINT_BOUNDS)
 _DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_INT_BOUNDS)
 
-_DTYPE_TO_BIT_WIDTH: Dict[Union[torch.dtype, TorchAODType], Tuple[int, int]] = {
-    torch.uint1: 1,
-    torch.uint2: 2,
-    torch.uint3: 3,
-    torch.uint4: 4,
-    torch.uint5: 5,
-    torch.uint6: 6,
-    torch.uint7: 7,
-    torch.uint8: 8,
-    TorchAODType.INT1: 1,
-    TorchAODType.INT2: 2,
-    TorchAODType.INT3: 3,
-    TorchAODType.INT4: 4,
-    TorchAODType.INT5: 5,
-    TorchAODType.INT6: 6,
-    TorchAODType.INT7: 7,
-    torch.int8: 8,
-    torch.int16: 16,
-    torch.int32: 32,
-}
+_DTYPE_TO_BIT_WIDTH.update({
+    torch.uint1: 1,
+    torch.uint2: 2,
+    torch.uint3: 3,
+    torch.uint4: 4,
+    torch.uint5: 5,
+    torch.uint6: 6,
+    torch.uint7: 7,
+    torch.uint8: 8,
+})
+
 assert _DTYPE_TO_BIT_WIDTH.keys() == _DTYPE_TO_QVALUE_BOUNDS.keys()
 
 _ONES_TABLE = [_n_ones(i) for i in range(8)]
```
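To see what the restructuring buys: the signed-int bit widths are now registered up front, and the `torch.uint1`–`torch.uint8` entries are merged in later via `update`, mirroring how `_SUB_BYTE_UINT_BOUNDS`/`_SUB_BYTE_INT_BOUNDS` are folded into `_DTYPE_TO_QVALUE_BOUNDS`. Below is a minimal self-contained sketch of the pattern; the availability guard is an assumption (the hunks do not show why the dict was split), and only a subset of `TorchAODType` is reproduced.

```
from enum import Enum
from typing import Dict, Union

import torch


class TorchAODType(Enum):
    # Partial stand-in for the enum defined in quant_primitives.py.
    INT4 = 4


# Entries whose dtypes exist in every supported torch version are
# registered unconditionally, as in the first hunk.
_DTYPE_TO_BIT_WIDTH: Dict[Union[torch.dtype, TorchAODType], int] = {
    TorchAODType.INT4: 4,
    torch.int8: 8,
    torch.int16: 16,
    torch.int32: 32,
}

# The sub-byte uint dtypes (torch.uint1 .. torch.uint7) only exist in
# newer torch releases, so merging them in afterwards (second hunk) lets
# the base dict stay valid everywhere -- this guard is an assumption.
if all(hasattr(torch, f"uint{i}") for i in range(1, 8)):
    _DTYPE_TO_BIT_WIDTH.update({getattr(torch, f"uint{i}"): i for i in range(1, 9)})

# A bit width b implies signed bounds (-(2 ** (b - 1)), 2 ** (b - 1) - 1),
# e.g. torch.int16 -> (-(2 ** 15), 2 ** 15 - 1) as in the context lines,
# which is why the file can assert that both dicts cover the same keys.
```

If that assumption is right, the split keeps the module importable on torch builds that predate the sub-byte uint dtypes while the `assert` at the end of the hunk still holds.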
