@@ -14,7 +14,11 @@
 import torch
 import torch.nn as nn
 
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    is_sm_at_least_89,
+    is_sm_at_least_90,
+)
 
 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -60,10 +64,6 @@
 torch.manual_seed(0)
 
 
-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
-is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
-
-
 def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
     assert torch.all(a._scale == b._scale).item(), "scales are not identical"
     assert torch.all(a._data == b._data).item(), "data is not identical"
@@ -219,7 +219,7 @@ def test_axiswise_reshape(self):
         ],
     )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    @unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0")
+    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
     def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity):
         a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
         b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")
@@ -333,7 +333,9 @@ def _test_linear_impl(
         # verify initialization flags got updated
         assert m_fp8.is_amax_initialized, "Amax was not properly initialized"
 
-    @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
+    @pytest.mark.parametrize(
+        "emulate", [True, False] if is_sm_at_least_89() else [True]
+    )
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize(
         "scaling_type_input",
@@ -415,7 +417,9 @@ def test_linear_from_recipe(
             config,
         )
 
-    @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
+    @pytest.mark.parametrize(
+        "emulate", [True, False] if is_sm_at_least_89() else [True]
+    )
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
@@ -462,7 +466,9 @@ def test_autocast_outputs(
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
-    @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
+    @pytest.mark.parametrize(
+        "emulate", [True, False] if is_sm_at_least_89() else [True]
+    )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool):
         m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
@@ -523,18 +529,33 @@ def test_repr(self):
         s = m.__repr__()
         assert "i:dyn_ten,w:del_ten,go:dyn_ten" in s
 
-    @unittest.skipIf(not is_cuda_8_9, "CUDA 8.9 not available")
+    @unittest.skipIf(not is_sm_at_least_89(), "CUDA 8.9 not available")
     def test_inference_mode(self):
         x = torch.randn(32, 32, device="cuda")
         m = nn.Sequential(nn.Linear(32, 32)).cuda()
         m = convert_to_float8_training(m)
         with torch.inference_mode(mode=True):
             m(x)
 
+    @unittest.skipIf(not is_sm_at_least_89(), "CUDA arch 8.9 not available")
+    def test_quantize(self):
+        x = torch.randn(32, 32, device="cuda")
+        m = nn.Sequential(nn.Linear(32, 32)).cuda()
+        m = convert_to_float8_training(m)
+        assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
+        from torchao.quantization.quant_api import float8_weight_only, quantize_
+
+        quantize_(m, float8_weight_only())
+        assert (
+            m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn
+        ), "Post quantization dtype should be torch.float8_e4m3fn"
+        with torch.no_grad():
+            m(x)
+
 
 class TestScaledMM:
     @unittest.skipIf(
-        not is_cuda_8_9,
+        not is_sm_at_least_89(),
         "CUDA not available",
     )
     @pytest.mark.parametrize(
@@ -576,10 +597,10 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
         if base_dtype in {torch.bfloat16, torch.float16}:
             atol, rtol = 7e-2, 7e-2
         else:
-            atol, rtol = 2e-3, 2e-3
+            atol, rtol = 3e-3, 3e-3
         torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
 
-    @unittest.skipIf(not is_cuda_8_9, "CUDA not available")
+    @unittest.skipIf(not is_sm_at_least_89(), "CUDA not available")
     def test_different_configs_error(self):
         x_fp32 = torch.randn(16, 16, device="cuda")
         x_scale = torch.tensor(1.0, device="cuda")
@@ -615,7 +636,7 @@ def test_different_configs_error(self):
             a @ b
 
     @unittest.skipIf(
-        not is_cuda_8_9,
+        not is_sm_at_least_89(),
        "CUDA not available",
     )
    @pytest.mark.parametrize(
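
Note on the helper swap: the module-level is_cuda_8_9 / is_cuda_9_0 flags removed above are replaced by the is_sm_at_least_89() / is_sm_at_least_90() helpers imported from torchao.utils, so all test files share one set of device-capability predicates. A minimal sketch of what such helpers can look like, derived from the inline checks this diff deletes (the actual torchao.utils implementation may differ):

    import torch

    def is_sm_at_least_89() -> bool:
        # CUDA device present and compute capability >= 8.9
        # (e.g. Ada- or Hopper-class GPUs with hardware float8 support).
        return (
            torch.cuda.is_available()
            and torch.cuda.get_device_capability() >= (8, 9)
        )

    def is_sm_at_least_90() -> bool:
        # CUDA device present and compute capability >= 9.0 (Hopper).
        return (
            torch.cuda.is_available()
            and torch.cuda.get_device_capability() >= (9, 0)
        )

Because the helpers are called inside skipIf / parametrize arguments, they still run when the test module is imported, just like the old flags did; the gain is centralization rather than lazier evaluation.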