  8 |   8 |      run_tests,
  9 |   9 |  )
 10 |  10 |
 11 |     | -from torchao.dtypes import SemiSparseLayout
    |  11 | +from torchao.dtypes import SemiSparseLayout, Int4CPULayout
 12 |  12 |  from torchao.quantization import (
 13 |  13 |      float8_weight_only,
 14 |  14 |      int4_weight_only,
 17 |  17 |      int8_weight_only,
 18 |  18 |  )
 19 |  19 |  from torchao.quantization.quant_primitives import MappingType
 20 |     | -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
    |  20 | +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6
 21 |  21 |
 22 |  22 |  is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 23 |  23 |
 24 |  24 |
 25 |     | -def get_quantization_functions(do_sparse: bool, do_int4: bool):
    |  25 | +def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"):
 26 |  26 |      base_functions = [
 27 |  27 |          int8_weight_only(),
 28 |  28 |          int8_dynamic_activation_int4_weight(),
 29 |  29 |          int8_dynamic_activation_int8_weight(),
 30 |  30 |          int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC),
 31 |  31 |      ]
 32 |  32 |      if do_int4:
 33 |     | -        base_functions.append(int4_weight_only(group_size=32))
    |  33 | +        if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6:
    |  34 | +            base_functions.append(int4_weight_only(group_size=32, layout=Int4CPULayout()))
    |  35 | +        else:
    |  36 | +            base_functions.append(int4_weight_only(group_size=32))
 34 |  37 |
 35 |  38 |      if do_sparse:
 36 |  39 |          base_functions.append(
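
For reference, a minimal usage sketch of the updated helper (hypothetical snippet, not part of the PR; it assumes get_quantization_functions is imported from this test module and, for the CPU int4 path, a PyTorch 2.6+ build where Int4CPULayout is supported). Each returned function is applied directly to a linear module, mirroring how the refactored test below uses it:

import torch

# Sketch: build the per-device list and quantize a fresh linear module with each entry.
for apply_quant in get_quantization_functions(do_sparse=False, do_int4=True, device="cpu"):
    linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cpu")
    ql = apply_quant(linear)  # module whose weight is now an affine-quantized tensor subclass
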
@@ -152,30 +155,28 @@ class TestAffineQuantizedBasic(TestCase): |
152 | 155 |      COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
153 | 156 |      COMMON_DTYPES = [torch.bfloat16]
154 | 157 |
155 |     | -    @common_utils.parametrize("apply_quant", get_quantization_functions(False, True))
156 | 158 |      @common_utils.parametrize("device", COMMON_DEVICES)
157 | 159 |      @common_utils.parametrize("dtype", COMMON_DTYPES)
158 |     | -    def test_flatten_unflatten(self, apply_quant, device, dtype):
159 |     | -        if device == "cpu":
160 |     | -            self.skipTest(f"Temporarily skipping for {device}")
161 |     | -
162 |     | -        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
163 |     | -        ql = apply_quant(linear)
164 |     | -        lp_tensor = ql.weight
165 |     | -        tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
166 |     | -        tensor_data_dict = {
167 |     | -            name: getattr(lp_tensor, name) for name in tensor_data_name_dict
168 |     | -        }
169 |     | -        outer_size = lp_tensor.size()
170 |     | -        outer_stride = lp_tensor.stride()
171 |     | -        reconstructed = type(lp_tensor).__tensor_unflatten__(
172 |     | -            tensor_data_dict, tensor_attributes, outer_size, outer_stride
173 |     | -        )
174 |     | -        example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),)
175 |     | -        ref = ql(*example_inputs)
176 |     | -        ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False)
177 |     | -        reconstruct_res = ql(*example_inputs)
178 |     | -        self.assertEqual(reconstruct_res, ref)
    | 160 | +    def test_flatten_unflatten(self, device, dtype):
    | 161 | +        apply_quant_list = get_quantization_functions(False, True, device)
    | 162 | +        for apply_quant in apply_quant_list:
    | 163 | +            linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
    | 164 | +            ql = apply_quant(linear)
    | 165 | +            lp_tensor = ql.weight
    | 166 | +            tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
    | 167 | +            tensor_data_dict = {
    | 168 | +                name: getattr(lp_tensor, name) for name in tensor_data_name_dict
    | 169 | +            }
    | 170 | +            outer_size = lp_tensor.size()
    | 171 | +            outer_stride = lp_tensor.stride()
    | 172 | +            reconstructed = type(lp_tensor).__tensor_unflatten__(
    | 173 | +                tensor_data_dict, tensor_attributes, outer_size, outer_stride
    | 174 | +            )
    | 175 | +            example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),)
    | 176 | +            ref = ql(*example_inputs)
    | 177 | +            ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False)
    | 178 | +            reconstruct_res = ql(*example_inputs)
    | 179 | +            self.assertEqual(reconstruct_res, ref)
179 | 180 |
180 | 181 |
181 | 182 |  common_utils.instantiate_parametrized_tests(TestAffineQuantized)
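
The refactored test_flatten_unflatten exercises the tensor-subclass serialization hooks directly; as a standalone sketch (the helper name below is hypothetical, and lp_tensor can be any subclass implementing __tensor_flatten__ / __tensor_unflatten__, such as the quantized weights produced in this test), the round-trip the test checks is:

def roundtrip_flatten_unflatten(lp_tensor):
    # __tensor_flatten__ yields the names of the inner tensors plus the non-tensor attributes.
    inner_names, attrs = lp_tensor.__tensor_flatten__()
    inner_tensors = {name: getattr(lp_tensor, name) for name in inner_names}
    # __tensor_unflatten__ rebuilds an equivalent subclass instance from those pieces
    # together with the outer size and stride.
    return type(lp_tensor).__tensor_unflatten__(
        inner_tensors, attrs, lp_tensor.size(), lp_tensor.stride()
    )

The test then swaps the reconstructed tensor back in as the module's weight and checks that the forward output matches the original.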