Skip to content

Commit b1c40ad

Browse files
committed
Add Int4CPULayout and update int4 woq
1 parent d4ca98f commit b1c40ad

File tree

14 files changed

+436
-87
lines changed

14 files changed

+436
-87
lines changed

.github/workflows/regression_test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ jobs:
4040
gpu-arch-version: "12.1"
4141
- name: CUDA Nightly
4242
runs-on: linux.g5.12xlarge.nvidia.gpu
43-
torch-spec: '--pre torch==2.6.0.dev20241101 --index-url https://download.pytorch.org/whl/nightly/cu121'
43+
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
4444
gpu-arch-type: "cuda"
4545
gpu-arch-version: "12.1"
4646

@@ -61,7 +61,7 @@ jobs:
6161
gpu-arch-version: ""
6262
- name: CPU Nightly
6363
runs-on: linux.4xlarge
64-
torch-spec: '--pre torch==2.6.0.dev20241101 --index-url https://download.pytorch.org/whl/nightly/cpu'
64+
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cpu'
6565
gpu-arch-type: "cpu"
6666
gpu-arch-version: ""
6767

test/dtypes/test_affine_quantized.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
float8_weight_only,
1212
)
1313
from torchao.quantization.quant_primitives import MappingType
14-
from torchao.dtypes import SemiSparseLayout
14+
from torchao.dtypes import SemiSparseLayout, Int4CPULayout
1515
from torch.testing._internal import common_utils
16-
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
16+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6
1717

1818
import torch
1919
import unittest
@@ -22,15 +22,18 @@
2222
is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
2323

2424

25-
def get_quantization_functions(do_sparse: bool, do_int4: bool):
25+
def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"):
2626
base_functions = [
2727
int8_weight_only(),
2828
int8_dynamic_activation_int4_weight(),
2929
int8_dynamic_activation_int8_weight(),
3030
int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC),
3131
]
3232
if do_int4:
33-
base_functions.append(int4_weight_only(group_size=32))
33+
if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6:
34+
base_functions.append(int4_weight_only(group_size=32, layout=Int4CPULayout()))
35+
else:
36+
base_functions.append(int4_weight_only(group_size=32))
3437

3538
if do_sparse:
3639
base_functions.append(int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))
@@ -139,23 +142,24 @@ class TestAffineQuantizedBasic(TestCase):
139142
COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
140143
COMMON_DTYPES = [torch.bfloat16]
141144

142-
@common_utils.parametrize("apply_quant", get_quantization_functions(False, True))
143145
@common_utils.parametrize("device", COMMON_DEVICES)
144146
@common_utils.parametrize("dtype", COMMON_DTYPES)
145-
def test_flatten_unflatten(self, apply_quant, device, dtype):
146-
l = torch.nn.Linear(128, 256, dtype=dtype, device=device)
147-
ql = apply_quant(l)
148-
lp_tensor = ql.weight
149-
tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
150-
tensor_data_dict = {name: getattr(lp_tensor, name) for name in tensor_data_name_dict}
151-
outer_size = lp_tensor.size()
152-
outer_stride = lp_tensor.stride()
153-
reconstructed = type(lp_tensor).__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride)
154-
example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),)
155-
ref = ql(*example_inputs)
156-
ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False)
157-
reconstruct_res = ql(*example_inputs)
158-
self.assertEqual(reconstruct_res, ref)
147+
def test_flatten_unflatten(self, device, dtype):
148+
apply_quant_list = get_quantization_functions(False, True, device)
149+
for apply_quant in apply_quant_list:
150+
l = torch.nn.Linear(128, 256, dtype=dtype, device=device)
151+
ql = apply_quant(l)
152+
lp_tensor = ql.weight
153+
tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
154+
tensor_data_dict = {name: getattr(lp_tensor, name) for name in tensor_data_name_dict}
155+
outer_size = lp_tensor.size()
156+
outer_stride = lp_tensor.stride()
157+
reconstructed = type(lp_tensor).__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride)
158+
example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),)
159+
ref = ql(*example_inputs)
160+
ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False)
161+
reconstruct_res = ql(*example_inputs)
162+
self.assertEqual(reconstruct_res, ref)
159163

160164
common_utils.instantiate_parametrized_tests(TestAffineQuantized)
161165
common_utils.instantiate_parametrized_tests(TestAffineQuantizedBasic)

test/integration/test_integration.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from torchao.quantization.dynamic_quant import (
2020
DynamicallyPerAxisQuantizedLinear,
2121
)
22-
from torchao.dtypes import TensorCoreTiledLayout
22+
from torchao.dtypes import TensorCoreTiledLayout, Int4CPULayout
2323
from torchao.quantization.quant_api import (
2424
int4_weight_only,
2525
int8_weight_only,
@@ -93,6 +93,7 @@
9393
is_fbcode,
9494
benchmark_model
9595
)
96+
from torchao.dtypes.utils import is_device
9697

9798
logger = logging.getLogger("INFO")
9899

@@ -133,7 +134,10 @@ def _int8da_int8w_api(mod):
133134
change_linear_weights_to_int8_dqtensors(mod)
134135

135136
def _int4wo_api(mod):
136-
if TORCH_VERSION_AT_LEAST_2_4:
137+
if is_device(next(mod.parameters()).device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
138+
quantize_(mod, int4_weight_only(layout=Int4CPULayout()), set_inductor_config=False)
139+
unwrap_tensor_subclass(mod)
140+
elif TORCH_VERSION_AT_LEAST_2_4:
137141
quantize_(mod, int4_weight_only(), set_inductor_config=False)
138142
if not TORCH_VERSION_AT_LEAST_2_5:
139143
unwrap_tensor_subclass(mod)
@@ -925,10 +929,16 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
925929
def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
926930
if dtype != torch.bfloat16:
927931
self.skipTest(f"Fails for {dtype}")
932+
layout_list = []
933+
if device == 'cpu' and TORCH_VERSION_AT_LEAST_2_6:
934+
layout_list.append(Int4CPULayout())
935+
else:
936+
for inner_k_tiles in [4, 2]:
937+
layout_list.append(TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles))
928938
for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device=='cuda' else [])):
929939
for groupsize in [64, 32]:
930-
for inner_k_tiles in [4, 2]:
931-
kwargs = {"groupsize": groupsize, "layout": TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles)}
940+
for layout in layout_list:
941+
kwargs = {"groupsize": groupsize, "layout": layout}
932942

933943
def api(mod):
934944
kwargs_copy = kwargs.copy()

test/quantization/test_quant_primitives.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
TORCH_VERSION_AT_LEAST_2_6,
3434
is_fbcode,
3535
)
36+
from torchao.dtypes.utils import is_device
3637

3738
_SEED = 1234
3839
torch.manual_seed(_SEED)
@@ -102,7 +103,8 @@ def _groupwise_affine_quantize_tensor_from_qparams(
102103
.reshape_as(w)
103104
)
104105
if TORCH_VERSION_AT_LEAST_2_5:
105-
w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
106+
if not (is_device(w.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
107+
w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
106108

107109
return w_int4x8
108110

@@ -524,8 +526,10 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
524526
groupsize = 128
525527

526528
if TORCH_VERSION_AT_LEAST_2_5:
527-
input_uint8 = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
528-
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_uint8, scales, zeros, n_bit, groupsize)
529+
input_tmp = input
530+
if not (is_device(input.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
531+
input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
532+
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_tmp, scales, zeros, n_bit, groupsize)
529533
else:
530534
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
531535
w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)

torchao/dtypes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
SemiSparseLayout,
2222
TensorCoreTiledLayout,
2323
UintxLayout,
24+
Int4CPULayout,
2425
)
2526
from .utils import (
2627
Layout,
@@ -48,4 +49,5 @@
4849
"UintxLayout",
4950
"MarlinQQQTensor",
5051
"MarlinQQQLayout",
52+
"Int4CPULayout",
5153
]

torchao/dtypes/uintx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
)
1313
from .tensor_core_tiled_layout import (
1414
TensorCoreTiledLayout,
15+
Int4CPULayout,
1516
)
1617
from .uintx_layout import (
1718
UintxLayout,
@@ -23,5 +24,6 @@
2324
"MarlinSparseLayout",
2425
"SemiSparseLayout",
2526
"TensorCoreTiledLayout",
27+
"Int4CPULayout",
2628
"MarlinQQQLayout",
2729
]

0 commit comments

Comments
 (0)