@@ -11,9 +11,9 @@
     float8_weight_only,
 )
 from torchao.quantization.quant_primitives import MappingType
-from torchao.dtypes import SemiSparseLayout
+from torchao.dtypes import SemiSparseLayout, Int4CPULayout
 from torch.testing._internal import common_utils
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6
 
 import torch
 import unittest
@@ -22,15 +22,18 @@
 is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 
-def get_quantization_functions(do_sparse: bool, do_int4: bool):
+def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"):
     base_functions = [
         int8_weight_only(),
         int8_dynamic_activation_int4_weight(),
         int8_dynamic_activation_int8_weight(),
         int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC),
     ]
     if do_int4:
-        base_functions.append(int4_weight_only(group_size=32))
+        if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6:
+            base_functions.append(int4_weight_only(group_size=32, layout=Int4CPULayout()))
+        else:
+            base_functions.append(int4_weight_only(group_size=32))
 
     if do_sparse:
         base_functions.append(int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))
@@ -139,23 +142,24 @@ class TestAffineQuantizedBasic(TestCase):
     COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
     COMMON_DTYPES = [torch.bfloat16]
 
-    @common_utils.parametrize("apply_quant", get_quantization_functions(False, True))
     @common_utils.parametrize("device", COMMON_DEVICES)
     @common_utils.parametrize("dtype", COMMON_DTYPES)
-    def test_flatten_unflatten(self, apply_quant, device, dtype):
-        l = torch.nn.Linear(128, 256, dtype=dtype, device=device)
-        ql = apply_quant(l)
-        lp_tensor = ql.weight
-        tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
-        tensor_data_dict = {name: getattr(lp_tensor, name) for name in tensor_data_name_dict}
-        outer_size = lp_tensor.size()
-        outer_stride = lp_tensor.stride()
-        reconstructed = type(lp_tensor).__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride)
-        example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),)
-        ref = ql(*example_inputs)
-        ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False)
-        reconstruct_res = ql(*example_inputs)
-        self.assertEqual(reconstruct_res, ref)
+    def test_flatten_unflatten(self, device, dtype):
+        apply_quant_list = get_quantization_functions(False, True, device)
+        for apply_quant in apply_quant_list:
+            l = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+            ql = apply_quant(l)
+            lp_tensor = ql.weight
+            tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
+            tensor_data_dict = {name: getattr(lp_tensor, name) for name in tensor_data_name_dict}
+            outer_size = lp_tensor.size()
+            outer_stride = lp_tensor.stride()
+            reconstructed = type(lp_tensor).__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride)
+            example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),)
+            ref = ql(*example_inputs)
+            ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False)
+            reconstruct_res = ql(*example_inputs)
+            self.assertEqual(reconstruct_res, ref)
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantized)
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedBasic)
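Usage note (not part of the diff): the helper now selects the int4 layout per device, so CPU runs on torch >= 2.6 use Int4CPULayout while other configurations keep the default layout. A minimal sketch of that same selection through torchao's quantize_ entry point, assuming a stand-in Linear module rather than the test's parametrized setup:

    import torch
    from torchao.quantization import quantize_, int4_weight_only
    from torchao.dtypes import Int4CPULayout
    from torchao.utils import TORCH_VERSION_AT_LEAST_2_6

    # stand-in module for illustration; any nn.Module with Linear layers works
    model = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cpu")

    if TORCH_VERSION_AT_LEAST_2_6:
        # CPU int4 path: uses the Int4CPULayout imported in this change
        quantize_(model, int4_weight_only(group_size=32, layout=Int4CPULayout()))
    else:
        # older torch versions fall back to the default int4 layout
        quantize_(model, int4_weight_only(group_size=32))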