|
27 | 27 | get_groupwise_affine_qparams, |
28 | 28 | groupwise_affine_quantize_tensor, |
29 | 29 | ) |
30 | | -from torchao.utils import TORCH_VERSION_AFTER_2_4 |
| 30 | +from torchao.utils import ( |
| 31 | + TORCH_VERSION_AFTER_2_4, |
| 32 | + TORCH_VERSION_AFTER_2_5, |
| 33 | +) |
31 | 34 |
|
32 | 35 |
|
33 | 36 | # TODO: put this in a common test utils file |
@@ -366,6 +369,8 @@ def _assert_close_4w(self, val, ref): |
366 | 369 |
|
367 | 370 | @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") |
368 | 371 | @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") |
| 372 | + # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517 |
| 373 | + @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now") |
369 | 374 | def test_qat_4w_primitives(self): |
370 | 375 | n_bit = 4 |
371 | 376 | group_size = 32 |
@@ -411,6 +416,8 @@ def test_qat_4w_primitives(self): |
411 | 416 |
|
412 | 417 | @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") |
413 | 418 | @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") |
| 419 | + # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517 |
| 420 | + @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now") |
414 | 421 | def test_qat_4w_linear(self): |
415 | 422 | from torchao.quantization.prototype.qat import Int4WeightOnlyQATLinear |
416 | 423 | from torchao.quantization.GPTQ import WeightOnlyInt4Linear |
@@ -439,6 +446,8 @@ def test_qat_4w_linear(self): |
439 | 446 |
|
440 | 447 | @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") |
441 | 448 | @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") |
| 449 | + # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517 |
| 450 | + @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now") |
442 | 451 | def test_qat_4w_quantizer(self): |
443 | 452 | from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer |
444 | 453 | from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer |
|
0 commit comments