Commit f03b194

Revert "pin nightly to 2.5.0.dev20240709+cu121 (#505)"

This reverts commit cc871c5.

1 parent 6e7cf71 commit f03b194

3 files changed: +32 -26 lines changed

.github/workflows/regression_test.yml

Lines changed: 2 additions & 2 deletions

@@ -33,7 +33,7 @@ jobs:
           gpu-arch-version: "12.1"
         - name: CUDA Nightly
           runs-on: linux.g5.12xlarge.nvidia.gpu
-          torch-spec: '--pre torch==2.5.0.dev20240709+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121'
+          torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
           gpu-arch-type: "cuda"
           gpu-arch-version: "12.1"
         - name: CPU 2.2.2
@@ -48,7 +48,7 @@ jobs:
           gpu-arch-version: ""
         - name: CPU Nightly
           runs-on: linux.4xlarge
-          torch-spec: '--pre torch==2.5.0.dev20240709+cpu --index-url https://download.pytorch.org/whl/nightly/cpu'
+          torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cpu'
           gpu-arch-type: "cpu"
           gpu-arch-version: ""
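With the pin removed, the two nightly jobs install whichever build is current at run time; the torch-spec value is presumably interpolated into the job's pip install step. A minimal hedged sketch (not part of the workflow) for logging which nightly a job actually resolved:

import torch

# The unpinned spec resolves to the latest nightly; the previously pinned
# build would print "2.5.0.dev20240709+cu121" here.
print(torch.__version__)
print(torch.version.cuda)  # CUDA toolkit the wheel targets; None on CPU builds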

test/integration/test_integration.py

Lines changed: 4 additions & 0 deletions

@@ -631,6 +631,7 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):

     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
+    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest("Currently only supports bfloat16.")
@@ -641,6 +642,7 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):

     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
+    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest("Currently only supports bfloat16.")
@@ -821,6 +823,7 @@ def test_int8_weight_only_quant_with_freeze(self, device, dtype):

     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
+    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_int4_weight_only_quant_subclass_api(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
@@ -835,6 +838,7 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):

     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
+    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
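All four new gates hinge on TORCH_VERSION_AFTER_2_5, a flag the test module imports. A minimal sketch of how such a flag can be computed, assuming a packaging-based comparison (the actual torchao helper may be implemented differently); the .dev0 lower bound is what makes a nightly such as 2.5.0.dev20240709+cu121 count as 2.5+:

import torch
from packaging.version import parse

# Hypothetical definition, for illustration only. parse() understands
# dev/nightly and local (+cu121) suffixes, so:
#   parse("2.5.0.dev20240709+cu121") >= parse("2.5.0.dev0")  -> True
#   parse("2.4.1")                   >= parse("2.5.0.dev0")  -> False
TORCH_VERSION_AFTER_2_5 = parse(torch.__version__) >= parse("2.5.0.dev0")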

test/test_ops.py

Lines changed: 26 additions & 24 deletions

@@ -95,6 +95,7 @@ def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
 TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES))

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str)
 def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
     N, K = shape
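The parametrize grids are Cartesian products over constants defined earlier in the file. A hedged sketch with hypothetical values (the real SHAPES, INNERKTILES, and QGROUP_SIZES sit outside this diff):

import itertools

# Hypothetical values, for illustration only; the real constants live
# near the top of test_ops.py.
SHAPES = [(256, 256), (512, 512)]
INNERKTILES = [2, 4, 8]
QGROUP_SIZES = [32, 64, 128, 256]

TEST_CONFIGS_UNPACK = list(itertools.product(SHAPES, INNERKTILES))
TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES))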
@@ -107,14 +108,15 @@ def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):

 # TODO: Fix "test_aot_dispatch_dynamic" test failure
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK , ids=str)
 def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
     test_utils = [
         "test_schema",
         "test_autograd_registration",
         "test_faketensor",
     ]
-
+
     # TODO: Figure out why test fails unless torch >= 2.5
     if TORCH_VERSION_AFTER_2_5:
         test_utils.append("test_aot_dispatch_dynamic")
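test_utils is the list of checks handed to the opcheck utility, with test_aot_dispatch_dynamic exercised only on 2.5+. A hedged sketch of the call shape, assuming the public torch.library.opcheck entry point (the file may import an internal variant) and a CUDA device:

import torch
from torch.library import opcheck  # assumption; available in recent PyTorch

inner_k_tiles = 2
# One int4 value per int32 element, as _convert_weight_to_int4pack expected pre-2.5.
q = torch.randint(0, 16, (256, 256), dtype=torch.int32, device="cuda")
packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)

test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"]
opcheck(
    torch.ops.torchao.unpack_tensor_core_tiled_layout,
    (packed, inner_k_tiles),
    test_utils=test_utils,
)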
@@ -137,10 +139,10 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
     assert scales.shape == zeros.shape

     midpoint = 2 ** (nbits - 1)
-
+
     #Convert fron u4 -> s4 and upcast to bfloat16
     q = q.sub(midpoint).to(dtype)
-
+
     # Dequantize
     q = q.reshape(-1, group_size)
     dq = q * scales.reshape(-1, 1) + zeros.reshape(-1, 1)
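For reference, a runnable sketch of the full dequant routine assembled from this hunk's context lines; the pieces outside the hunk (the shape handling at the edges and the final reshape) are assumptions:

import torch

def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
    # Assumed layout: q is (n, k) unsigned codes, one group per group_size columns.
    n, k = q.shape
    assert scales.shape == zeros.shape

    midpoint = 2 ** (nbits - 1)

    # Convert from u4 -> s4 and upcast to bfloat16
    q = q.sub(midpoint).to(dtype)

    # Dequantize one group per row
    q = q.reshape(-1, group_size)
    dq = q * scales.reshape(-1, 1) + zeros.reshape(-1, 1)
    return dq.reshape(n, k)  # assumed: restore the original weight shape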
@@ -152,18 +154,18 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
 @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
 def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size):
     n, k = shape
-    dtype = torch.bfloat16
+    dtype = torch.bfloat16

     device = "cuda"

     t = torch.randn(n, k, dtype=dtype, device=device)
     scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype)
-
+
     # Quantize
     q = groupwise_affine_quantize_tensor_from_qparams(
         t, scales, zeros, n_bit=4, groupsize=group_size
     )
-
+
     # Pack to tensor core layout
     packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
     scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
@@ -174,7 +176,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size):
     dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
         q, scales, zeros, n_bit=4, groupsize=group_size
     )
-
+
     # Dequantize by passing in an identity matrix as the activation
     a_eye = torch.eye(k, device=device, dtype=dtype)
     dq_id = torch.ops.aten._weight_int4pack_mm(
@@ -183,23 +185,23 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size):
         group_size,
         scales_and_zeros,
     ).t()
-
+
     # Actual operation to test
     dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles)
-
+
     # Compare results
     diff_ao_id = (dq_id - dq_ao).abs().max()
     diff_op_id = (dq_op - dq_id).abs().max()
     diff_op_ao = (dq_op - dq_ao).abs().max()
-
+
     # There are slight numerical differences when dequantizing with an identity matrix when compared to `groupwise_affine_dequantize`
     # Since the `dequantize_tensor_core_layout` kernel relies on the same underlying bit twiddling tricks for fast
     # conversion from u4 -> s4 -> bf16, the identity matrix dequant hack and `dequantize_tensor_core_layout` are
     # expected to give same results, while both will have similar numerical differences to `groupwise_affine_dequantize`.
-
-    # Test that the `dequant` kernel gives same results as identity matrix-based dequant
+
+    # Test that the `dequant` kernel gives same results as identity matrix-based dequant
     assert diff_op_id == 0
-
+
     # Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix
     assert diff_op_ao == diff_ao_id
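The identity-matrix "dequant hack" above works because _weight_int4pack_mm(a, packed, ...) computes a @ dequant(packed).T; with a = I_k the product is exactly the dequantized weight transposed, hence the trailing .t(). The same identity, sketched with a plain floating-point matmul standing in for the int4 kernel:

import torch

n, k = 64, 128  # hypothetical weight shape
w = torch.randn(n, k)
a_eye = torch.eye(k)

# I_k @ W.T == W.T, so transposing back recovers W exactly.
assert torch.equal((a_eye @ w.t()).t(), w)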

@@ -210,7 +212,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size):
 @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
 def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):
     n, k = shape
-    dtype = torch.bfloat16
+    dtype = torch.bfloat16
     device = "cuda"

     # Quantize and pack
@@ -222,13 +224,13 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):

     packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
     scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
-
+
     # Unpack and dequantize
     unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles)
     dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
         unpacked, scales, zeros, n_bit=4, groupsize=group_size
     )
-
+
     # Dequantize by passing in an identity matrix as the activation
     a_eye = torch.eye(k, device=device, dtype=dtype)
     dq_id = torch.ops.aten._weight_int4pack_mm(
@@ -237,23 +239,23 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):
         group_size,
         scales_and_zeros,
     ).t()
-
+
     # Actual operation to test
     dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles)
-
+
     # Compare results
     diff_ao_id = (dq_id - dq_ao).abs().max()
     diff_op_id = (dq_op - dq_id).abs().max()
     diff_op_ao = (dq_op - dq_ao).abs().max()
-
+
     # There are slight numerical differences when dequantizing with an identity matrix when compared to `groupwise_affine_dequantize`
     # Since the `dequantize_tensor_core_layout` kernel relies on the same underlying bit twiddling tricks for fast
     # conversion from u4 -> s4 -> bf16, the identity matrix dequant hack and `dequantize_tensor_core_layout` are
     # expected to give same results, while both will have similar numerical differences to `groupwise_affine_dequantize`.
-
-    # Test that the `dequant` kernel gives same results as identity matrix-based dequant
+
+    # Test that the `dequant` kernel gives same results as identity matrix-based dequant
     assert diff_op_id == 0
-
+
     # Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix
     assert diff_op_ao == diff_ao_id

@@ -271,7 +273,7 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size):
     scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device)
     zeros = torch.randn_like(scales)
     scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
-
+
     test_utils = [
         "test_schema",
         "test_autograd_registration",
@@ -287,4 +289,4 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size):
     )

 if __name__ == "__main__":
-    run_tests()
+    run_tests()
