
Commit 928addc

Add Int4CPULayout and update int4 woq

1 parent: 2e338a8

14 files changed: 449 additions, 92 deletions

.github/workflows/regression_test.yml
Lines changed: 11 additions & 0 deletions

@@ -70,6 +70,12 @@ jobs:
           torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121'
           gpu-arch-type: "cuda"
           gpu-arch-version: "12.1"
+        - name: CUDA Nightly
+          runs-on: linux.g5.12xlarge.nvidia.gpu
+          torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
+          gpu-arch-type: "cuda"
+          gpu-arch-version: "12.1"
+
         - name: CPU 2.3
           runs-on: linux.4xlarge
           torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu'
@@ -85,6 +91,11 @@ jobs:
           torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cpu'
           gpu-arch-type: "cpu"
           gpu-arch-version: ""
+        - name: CPU Nightly
+          runs-on: linux.4xlarge
+          torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cpu'
+          gpu-arch-type: "cpu"
+          gpu-arch-version: ""
 
     uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
     with:

test/dtypes/test_affine_quantized.py
Lines changed: 27 additions & 26 deletions

@@ -8,7 +8,7 @@
     run_tests,
 )
 
-from torchao.dtypes import SemiSparseLayout
+from torchao.dtypes import SemiSparseLayout, Int4CPULayout
 from torchao.quantization import (
     float8_weight_only,
     int4_weight_only,
@@ -17,20 +17,23 @@
     int8_weight_only,
 )
 from torchao.quantization.quant_primitives import MappingType
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6
 
 is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 
-def get_quantization_functions(do_sparse: bool, do_int4: bool):
+def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"):
     base_functions = [
         int8_weight_only(),
         int8_dynamic_activation_int4_weight(),
         int8_dynamic_activation_int8_weight(),
         int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC),
     ]
     if do_int4:
-        base_functions.append(int4_weight_only(group_size=32))
+        if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6:
+            base_functions.append(int4_weight_only(group_size=32, layout=Int4CPULayout()))
+        else:
+            base_functions.append(int4_weight_only(group_size=32))
 
     if do_sparse:
         base_functions.append(
@@ -152,30 +155,28 @@ class TestAffineQuantizedBasic(TestCase):
     COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
     COMMON_DTYPES = [torch.bfloat16]
 
-    @common_utils.parametrize("apply_quant", get_quantization_functions(False, True))
     @common_utils.parametrize("device", COMMON_DEVICES)
     @common_utils.parametrize("dtype", COMMON_DTYPES)
-    def test_flatten_unflatten(self, apply_quant, device, dtype):
-        if device == "cpu":
-            self.skipTest(f"Temporarily skipping for {device}")
-
-        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
-        ql = apply_quant(linear)
-        lp_tensor = ql.weight
-        tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
-        tensor_data_dict = {
-            name: getattr(lp_tensor, name) for name in tensor_data_name_dict
-        }
-        outer_size = lp_tensor.size()
-        outer_stride = lp_tensor.stride()
-        reconstructed = type(lp_tensor).__tensor_unflatten__(
-            tensor_data_dict, tensor_attributes, outer_size, outer_stride
-        )
-        example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),)
-        ref = ql(*example_inputs)
-        ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False)
-        reconstruct_res = ql(*example_inputs)
-        self.assertEqual(reconstruct_res, ref)
+    def test_flatten_unflatten(self, device, dtype):
+        apply_quant_list = get_quantization_functions(False, True, device)
+        for apply_quant in apply_quant_list:
+            linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+            ql = apply_quant(linear)
+            lp_tensor = ql.weight
+            tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
+            tensor_data_dict = {
+                name: getattr(lp_tensor, name) for name in tensor_data_name_dict
+            }
+            outer_size = lp_tensor.size()
+            outer_stride = lp_tensor.stride()
+            reconstructed = type(lp_tensor).__tensor_unflatten__(
+                tensor_data_dict, tensor_attributes, outer_size, outer_stride
+            )
+            example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),)
+            ref = ql(*example_inputs)
+            ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False)
+            reconstruct_res = ql(*example_inputs)
+            self.assertEqual(reconstruct_res, ref)
 
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantized)
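
Aside: the round-trip exercised here is the tensor-subclass flatten/unflatten protocol (`__tensor_flatten__` decomposes a subclass into its inner plain tensors plus metadata; `__tensor_unflatten__` rebuilds it). A minimal single-iteration sketch of the loop above, calling the quantizer module-style exactly as the test's `apply_quant` does; shapes and dtype are taken from the test:

    import torch
    from torchao.quantization import int8_weight_only

    # Quantize one linear, then round-trip its weight through the
    # flatten/unflatten protocol and check the outputs still match.
    linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
    ql = int8_weight_only()(linear)  # same call shape as apply_quant(linear)
    lp_tensor = ql.weight
    names, attrs = lp_tensor.__tensor_flatten__()
    rebuilt = type(lp_tensor).__tensor_unflatten__(
        {n: getattr(lp_tensor, n) for n in names},
        attrs, lp_tensor.size(), lp_tensor.stride(),
    )
    x = torch.randn(32, 128, dtype=torch.bfloat16)
    ref = ql(x)
    ql.weight = torch.nn.Parameter(rebuilt, requires_grad=False)
    assert torch.equal(ql(x), ref)  # rebuilt weight behaves identically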

test/integration/test_integration.py
Lines changed: 14 additions & 4 deletions

@@ -19,7 +19,7 @@
 from torchao.quantization.dynamic_quant import (
     DynamicallyPerAxisQuantizedLinear,
 )
-from torchao.dtypes import TensorCoreTiledLayout
+from torchao.dtypes import TensorCoreTiledLayout, Int4CPULayout
 from torchao.quantization.quant_api import (
     int4_weight_only,
     int8_weight_only,
@@ -93,6 +93,7 @@
     is_fbcode,
     benchmark_model
 )
+from torchao.dtypes.utils import is_device
 
 logger = logging.getLogger("INFO")
 
@@ -133,7 +134,10 @@ def _int8da_int8w_api(mod):
     change_linear_weights_to_int8_dqtensors(mod)
 
 def _int4wo_api(mod):
-    if TORCH_VERSION_AT_LEAST_2_4:
+    if is_device(next(mod.parameters()).device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
+        quantize_(mod, int4_weight_only(layout=Int4CPULayout()), set_inductor_config=False)
+        unwrap_tensor_subclass(mod)
+    elif TORCH_VERSION_AT_LEAST_2_4:
         quantize_(mod, int4_weight_only(), set_inductor_config=False)
         if not TORCH_VERSION_AT_LEAST_2_5:
             unwrap_tensor_subclass(mod)
@@ -935,10 +939,16 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
             self.skipTest(f"Temporarily skipping for {device}")
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
+        layout_list = []
+        if device == 'cpu' and TORCH_VERSION_AT_LEAST_2_6:
+            layout_list.append(Int4CPULayout())
+        else:
+            for inner_k_tiles in [4, 2]:
+                layout_list.append(TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles))
         for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device=='cuda' else [])):
             for groupsize in [64, 32]:
-                for inner_k_tiles in [4, 2]:
-                    kwargs = {"groupsize": groupsize, "layout": TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles)}
+                for layout in layout_list:
+                    kwargs = {"groupsize": groupsize, "layout": layout}
 
                     def api(mod):
                         kwargs_copy = kwargs.copy()
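
The branch added to `_int4wo_api` is the dispatch downstream callers would mirror; a minimal sketch under the same version gates (the toy model and shapes here are placeholders, not from the diff):

    import torch
    from torchao.dtypes import Int4CPULayout
    from torchao.quantization.quant_api import int4_weight_only, quantize_
    from torchao.utils import TORCH_VERSION_AT_LEAST_2_6, unwrap_tensor_subclass

    model = torch.nn.Sequential(torch.nn.Linear(128, 128)).to(torch.bfloat16)

    # CPU with torch >= 2.6 takes the new Int4CPULayout path; everything
    # else stays on the default (TensorCoreTiledLayout) int4 path.
    if next(model.parameters()).device.type == "cpu" and TORCH_VERSION_AT_LEAST_2_6:
        quantize_(model, int4_weight_only(layout=Int4CPULayout()), set_inductor_config=False)
        unwrap_tensor_subclass(model)
    else:
        quantize_(model, int4_weight_only(), set_inductor_config=False)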

test/quantization/test_quant_primitives.py
Lines changed: 7 additions & 3 deletions

@@ -33,6 +33,7 @@
     TORCH_VERSION_AT_LEAST_2_6,
     is_fbcode,
 )
+from torchao.dtypes.utils import is_device
 
 _SEED = 1234
 torch.manual_seed(_SEED)
@@ -102,7 +103,8 @@ def _groupwise_affine_quantize_tensor_from_qparams(
         .reshape_as(w)
     )
     if TORCH_VERSION_AT_LEAST_2_5:
-        w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
+        if not (is_device(w.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
+            w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
 
     return w_int4x8
 
@@ -524,8 +526,10 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
         groupsize = 128
 
         if TORCH_VERSION_AT_LEAST_2_5:
-            input_uint8 = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
-            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_uint8, scales, zeros, n_bit, groupsize)
+            input_tmp = input
+            if not (is_device(input.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
+                input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
+            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_tmp, scales, zeros, n_bit, groupsize)
         else:
             w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
         w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
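
The guarded expression packs two int4 values into each uint8: the even column supplies the high nibble, the odd column the low nibble. On CPU with torch >= 2.6 the packing is skipped, so the tensor stays at one int4 value per uint8 element, evidently the format the new CPU path expects. A tiny worked example of the packing:

    import torch

    w_int4 = torch.tensor([[3, 12, 7, 1]], dtype=torch.uint8)
    packed = (w_int4[::, ::2] << 4 | w_int4[::, 1::2]).to(torch.uint8)
    print(packed)  # tensor([[ 60, 113]]): 3<<4 | 12 == 60, 7<<4 | 1 == 113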

torchao/dtypes/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -21,6 +21,7 @@
     SemiSparseLayout,
     TensorCoreTiledLayout,
     UintxLayout,
+    Int4CPULayout,
 )
 from .utils import (
     Layout,
@@ -48,4 +49,5 @@
     "UintxLayout",
     "MarlinQQQTensor",
     "MarlinQQQLayout",
+    "Int4CPULayout",
 ]
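
With `Int4CPULayout` exported from `torchao.dtypes`, the CPU int4 layout is selectable directly; a minimal usage sketch (torch >= 2.6 assumed, matching the gates in the tests above; model and shapes are illustrative):

    import torch
    from torchao.dtypes import Int4CPULayout
    from torchao.quantization import int4_weight_only, quantize_

    model = torch.nn.Sequential(torch.nn.Linear(128, 256)).to(torch.bfloat16)
    quantize_(model, int4_weight_only(group_size=32, layout=Int4CPULayout()))
    out = model(torch.randn(32, 128, dtype=torch.bfloat16))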

torchao/dtypes/uintx/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -12,6 +12,7 @@
 )
 from .tensor_core_tiled_layout import (
     TensorCoreTiledLayout,
+    Int4CPULayout,
 )
 from .uintx_layout import (
     UintxLayout,
@@ -23,5 +24,6 @@
     "MarlinSparseLayout",
     "SemiSparseLayout",
     "TensorCoreTiledLayout",
+    "Int4CPULayout",
     "MarlinQQQLayout",
 ]
