4 changes: 2 additions & 2 deletions .github/workflows/regression_test.yml
@@ -33,7 +33,7 @@ jobs:
gpu-arch-version: "12.1"
- name: CUDA Nightly
runs-on: linux.g5.12xlarge.nvidia.gpu
torch-spec: '--pre torch==2.5.0.dev20240709+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121'
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
gpu-arch-type: "cuda"
gpu-arch-version: "12.1"
- name: CPU 2.2.2
@@ -48,7 +48,7 @@ jobs:
gpu-arch-version: ""
- name: CPU Nightly
runs-on: linux.4xlarge
torch-spec: '--pre torch==2.5.0.dev20240709+cpu --index-url https://download.pytorch.org/whl/nightly/cpu'
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cpu'
gpu-arch-type: "cpu"
gpu-arch-version: ""

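The nightly matrix entries above drop the pinned `2.5.0.dev20240709` build and track whatever the latest torch nightly is. As far as I can tell the `torch-spec` value is handed to pip verbatim by the workflow; the snippet below is only an illustration of the equivalent install step (the index URL comes from the matrix entry, everything else is illustrative):

```python
# Illustrative only, not part of the workflow: install the latest CUDA 12.1
# torch nightly the way the unpinned matrix entry above would.
import subprocess
import sys

subprocess.check_call([
    sys.executable, "-m", "pip", "install",
    "--pre", "torch",  # accept pre-release (nightly/dev) builds
    "--index-url", "https://download.pytorch.org/whl/nightly/cu121",
])
```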
4 changes: 4 additions & 0 deletions test/dtypes/test_affine_quantized.py
@@ -5,10 +5,14 @@
from torchao.quantization.quant_api import int4_weight_only
import torch
import unittest
from torchao.utils import (
TORCH_VERSION_AFTER_2_5,
)


class TestAffineQuantized(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_tensor_core_layout_transpose(self):
t = torch.rand(128, 256, dtype=torch.bfloat16, device="cuda")
shape = t.shape
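The new skip decorator keys off `TORCH_VERSION_AFTER_2_5` from `torchao.utils`. Its actual implementation is not shown in this diff; the sketch below is just one way such a flag can be computed (the `packaging` dependency and the `torch_version_after` helper are assumptions, not torchao API):

```python
# Hedged sketch of a version-gating flag like TORCH_VERSION_AFTER_2_5.
# The real torchao.utils implementation may differ.
import torch
from packaging.version import parse  # assumed available; not a torchao import


def torch_version_after(min_version: str) -> bool:
    # Drop any local build suffix such as "+cu121" so nightlies like
    # "2.5.0.dev20240709+cu121" compare cleanly.
    installed = parse(torch.__version__.split("+")[0])
    return installed >= parse(min_version)


# Dev/nightly builds of 2.5 count as "after 2.5" here, matching how the
# skip decorators in this PR are meant to exclude current nightlies.
TORCH_VERSION_AFTER_2_5 = torch_version_after("2.5.0.dev0")
```

A flag computed this way would then gate tests exactly as in the hunks above, e.g. `@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")`.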
5 changes: 5 additions & 0 deletions test/integration/test_integration.py
@@ -631,6 +631,7 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest("Currently only supports bfloat16.")
@@ -641,6 +642,7 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest("Currently only supports bfloat16.")
@@ -821,6 +823,7 @@ def test_int8_weight_only_quant_with_freeze(self, device, dtype):

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_int4_weight_only_quant_subclass_api(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
@@ -835,6 +838,7 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
@@ -1484,6 +1488,7 @@ def test_get_model_size_autoquant(self, device, dtype):
@parameterized.expand(
list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
)
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_get_model_size_aqt(self, api, test_device, test_dtype):
if test_dtype != torch.bfloat16:
self.skipTest(f"{api} in {test_dtype} is not supported yet")
53 changes: 29 additions & 24 deletions test/test_ops.py
@@ -95,6 +95,7 @@ def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES))

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str)
def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
N, K = shape
@@ -107,14 +108,15 @@ def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):

# TODO: Fix "test_aot_dispatch_dynamic" test failure
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK , ids=str)
def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
test_utils = [
"test_schema",
"test_autograd_registration",
"test_faketensor",
]

# TODO: Figure out why test fails unless torch >= 2.5
if TORCH_VERSION_AFTER_2_5:
test_utils.append("test_aot_dispatch_dynamic")
@@ -137,10 +139,10 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
assert scales.shape == zeros.shape

midpoint = 2 ** (nbits - 1)

# Convert from u4 -> s4 and upcast to bfloat16
q = q.sub(midpoint).to(dtype)

# Dequantize
q = q.reshape(-1, group_size)
dq = q * scales.reshape(-1, 1) + zeros.reshape(-1, 1)
@@ -149,21 +151,22 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size):
n, k = shape
dtype = torch.bfloat16
dtype = torch.bfloat16

device = "cuda"

t = torch.randn(n, k, dtype=dtype, device=device)
scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype)

# Quantize
q = groupwise_affine_quantize_tensor_from_qparams(
t, scales, zeros, n_bit=4, groupsize=group_size
)

# Pack to tensor core layout
packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
@@ -174,7 +177,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, in
dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
q, scales, zeros, n_bit=4, groupsize=group_size
)

# Dequantize by passing in an identity matrix as the activation
a_eye = torch.eye(k, device=device, dtype=dtype)
dq_id = torch.ops.aten._weight_int4pack_mm(
@@ -183,34 +186,35 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, in
group_size,
scales_and_zeros,
).t()

# Actual operation to test
dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles)

# Compare results
diff_ao_id = (dq_id - dq_ao).abs().max()
diff_op_id = (dq_op - dq_id).abs().max()
diff_op_ao = (dq_op - dq_ao).abs().max()

# There are slight numerical differences when dequantizing with an identity matrix when compared to `groupwise_affine_dequantize`
# Since the `dequantize_tensor_core_layout` kernel relies on the same underlying bit twiddling tricks for fast
# conversion from u4 -> s4 -> bf16, the identity matrix dequant hack and `dequantize_tensor_core_layout` are
# expected to give same results, while both will have similar numerical differences to `groupwise_affine_dequantize`.
# Test that the `dequant` kernel gives same results as identity matrix-based dequant

# Test that the `dequant` kernel gives same results as identity matrix-based dequant
assert diff_op_id == 0

# Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix
assert diff_op_ao == diff_ao_id

assert diff_op_ao < 1e-1

# This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):
n, k = shape
dtype = torch.bfloat16
dtype = torch.bfloat16
device = "cuda"

# Quantize and pack
@@ -222,13 +226,13 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shap

packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)

# Unpack and dequantize
unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles)
dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
unpacked, scales, zeros, n_bit=4, groupsize=group_size
)

# Dequantize by passing in an identity matrix as the activation
a_eye = torch.eye(k, device=device, dtype=dtype)
dq_id = torch.ops.aten._weight_int4pack_mm(
Expand All @@ -237,29 +241,30 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shap
group_size,
scales_and_zeros,
).t()

# Actual operation to test
dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles)

# Compare results
diff_ao_id = (dq_id - dq_ao).abs().max()
diff_op_id = (dq_op - dq_id).abs().max()
diff_op_ao = (dq_op - dq_ao).abs().max()

# There are slight numerical differences when dequantizing with an identity matrix when compared to `groupwise_affine_dequantize`
# Since the `dequantize_tensor_core_layout` kernel relies on the same underlying bit twiddling tricks for fast
# conversion from u4 -> s4 -> bf16, the identity matrix dequant hack and `dequantize_tensor_core_layout` are
# expected to give same results, while both will have similar numerical differences to `groupwise_affine_dequantize`.
# Test that the `dequant` kernel gives same results as identity matrix-based dequant

# Test that the `dequant` kernel gives same results as identity matrix-based dequant
assert diff_op_id == 0

# Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix
assert diff_op_ao == diff_ao_id

assert diff_op_ao < 1e-1

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size):
n, k = shape
@@ -271,7 +276,7 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size
scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device)
zeros = torch.randn_like(scales)
scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)

test_utils = [
"test_schema",
"test_autograd_registration",
Expand All @@ -287,4 +292,4 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size
)

if __name__ == "__main__":
run_tests()
run_tests()
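
For reference, the identity-matrix dequant trick described in the comments of test_ops.py above can be exercised in isolation. The sketch below follows the pre-2.5 packing convention these tests are now gated on; the helper import path, shapes, and group size are assumptions, and it needs a CUDA device:

```python
# Hedged sketch: dequantize a tensor-core-tiled int4 weight by feeding an
# identity activation to _weight_int4pack_mm, as the tests above do.
import torch
from torchao.quantization.utils import (  # import path assumed
    get_groupwise_affine_qparams,
    groupwise_affine_quantize_tensor_from_qparams,
    pack_tinygemm_scales_and_zeros,
)

n, k, group_size, inner_k_tiles = 256, 512, 128, 8  # illustrative sizes
w = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")

# Quantize to uint4 with groupwise affine qparams, then pack into the
# tensor-core tiled layout (pre-2.5 packing convention).
scales, zeros = get_groupwise_affine_qparams(w, n_bit=4, groupsize=group_size, dtype=torch.bfloat16)
q = groupwise_affine_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=group_size)
packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)

# eye(k) @ W_dq.T == W_dq.T, so transposing the matmul output recovers the
# dequantized weight with shape (n, k).
a_eye = torch.eye(k, dtype=torch.bfloat16, device="cuda")
w_dq = torch.ops.aten._weight_int4pack_mm(a_eye, packed, group_size, scales_and_zeros).t()
print(w_dq.shape)  # torch.Size([256, 512])
```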