Commit 5642f44
Update on "Add generic fake quantized linear for QAT"

**Summary:** This commit adds a generic fake quantized linear module to replace the existing, more specific QAT linears. For example, `Int8DynActInt4WeightQATLinear` can now be expressed as follows:

```python
import torch

from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear

activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=8)
fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config)
```

The main motivation is to provide a more flexible way to perform QAT on models with linear layers. Previously, we had to create a new linear class every time we wished to experiment with different fake quantization settings, e.g. a different group size or bit width. Now we can express these settings through a single linear module.

**Test Plan:**

```
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity_error_cases
python test/quantization/test_qat.py -k test_fake_quantize_config_mapping_type
python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w
python test/quantization/test_qat.py -k test_fake_quantized_linear_4w
```

[ghstack-poisoned]
2 parents fbc0259 + 756cb8d
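As a quick sanity check on the API in the summary above, a fake-quantized forward pass behaves like an ordinary linear forward. This is a minimal sketch, not part of the commit: the input shape and the assumption that `FakeQuantizedLinear` follows `nn.Linear`'s call signature are illustrative.

```python
import torch

from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear

activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=8)
fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config)

# Fake quantization keeps tensors in floating point, so the output is an
# ordinary torch.Tensor; the (4, 16) input shape here is an illustrative choice.
x = torch.randn(4, 16)
y = fq_linear(x)
print(y.shape)  # torch.Size([4, 32])
```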

2 files changed: +11, −8 lines

test/integration/test_integration.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -1126,6 +1126,7 @@ def test_shape_logger(self):
 class SmoothquantIntegrationTest(unittest.TestCase):
     @torch.no_grad()
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "newer dtypes not supported")
     def test_non_dynamically_quantizable_linear(self):
         if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
             self.skipTest("test requires SM capability of at least (8, 0).")
```

test/quantization/test_qat.py

Lines changed: 10 additions & 8 deletions

```diff
@@ -58,6 +58,7 @@
     groupwise_affine_quantize_tensor,
 )
 from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
 )
@@ -753,14 +754,15 @@ def test_fake_quantize_config_dtype(self):
         with self.assertRaisesRegex(ValueError, msg):
             FakeQuantizeConfig(torch.float32, "per_token")
         # OK
-        FakeQuantizeConfig(torch.uint1, "per_token")
-        FakeQuantizeConfig(torch.uint2, "per_token")
-        FakeQuantizeConfig(torch.uint3, "per_token")
-        FakeQuantizeConfig(torch.uint4, "per_token")
-        FakeQuantizeConfig(torch.uint5, "per_token")
-        FakeQuantizeConfig(torch.uint6, "per_token")
-        FakeQuantizeConfig(torch.uint7, "per_token")
-        FakeQuantizeConfig(torch.uint8, "per_token")
+        if TORCH_VERSION_AT_LEAST_2_3:
+            FakeQuantizeConfig(torch.uint1, "per_token")
+            FakeQuantizeConfig(torch.uint2, "per_token")
+            FakeQuantizeConfig(torch.uint3, "per_token")
+            FakeQuantizeConfig(torch.uint4, "per_token")
+            FakeQuantizeConfig(torch.uint5, "per_token")
+            FakeQuantizeConfig(torch.uint6, "per_token")
+            FakeQuantizeConfig(torch.uint7, "per_token")
+            FakeQuantizeConfig(torch.uint8, "per_token")
         FakeQuantizeConfig(TorchAODType.INT1, "per_token")
         FakeQuantizeConfig(TorchAODType.INT2, "per_token")
         FakeQuantizeConfig(TorchAODType.INT3, "per_token")
```
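The gating is needed because the sub-byte unsigned dtypes (`torch.uint1` through `torch.uint7`) only exist on newer PyTorch builds, as the "newer dtypes not supported" skip message in the first diff notes. Below is a minimal sketch of the same guard pattern, assuming only that `torchao.utils` exports the `TORCH_VERSION_AT_LEAST_2_3` flag used above:

```python
import torch
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3

# torch.uint1 .. torch.uint7 are only defined on PyTorch >= 2.3, so any
# code that references them must be version-gated, as in the test above.
if TORCH_VERSION_AT_LEAST_2_3:
    uint_dtypes = [getattr(torch, f"uint{i}") for i in range(1, 9)]
else:
    uint_dtypes = [torch.uint8]  # the only unsigned integer dtype on older builds
```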
