|
@@ -37,7 +37,7 @@
 from torchao.testing.utils import skip_if_rocm
 
 
-@skip_if_rocm("ROCm enablement in progress")
+@skip_if_rocm("ROCm not supported")
 def test_valid_scaled_grouped_mm_2d_3d():
     out_dtype = torch.bfloat16
     device = "cuda"
@@ -91,6 +91,7 @@ def test_valid_scaled_grouped_mm_2d_3d():
     assert torch.equal(b_t.grad, ref_b_t.grad)
 
 
+@skip_if_rocm("ROCm not supported")
 @pytest.mark.parametrize("m", [16, 17])
 @pytest.mark.parametrize("k", [16, 18])
 @pytest.mark.parametrize("n", [32, 33])
@@ -219,6 +220,7 @@ def compute_reference_forward(
     return output_ref
 
 
+@skip_if_rocm("ROCm not supported")
 @pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
 @pytest.mark.parametrize("num_experts", (1, 8, 16))
 def test_emulate_mxfp8_grouped_gemm(M, K, N, num_experts):
@@ -249,6 +251,7 @@ def test_emulate_mxfp8_grouped_gemm(M, K, N, num_experts):
     assert sqnr >= min_sqnr, f"sqnr {sqnr} is too low, must be >= {min_sqnr}"
 
 
+@skip_if_rocm("ROCm not supported")
 @pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
 @pytest.mark.parametrize("num_experts", (1, 8, 16))
 def test_mxfp8_grouped_gemm_with_dq_fwd(M, K, N, num_experts):
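For context, `skip_if_rocm` is applied here as a message-carrying decorator stacked above the `pytest.mark.parametrize` markers. The sketch below is a minimal, hypothetical stand-in showing how such a decorator can work; the ROCm check via `torch.version.hip` and the name `skip_if_rocm_sketch` are assumptions for illustration, not torchao's actual implementation.

import functools

import pytest
import torch


def skip_if_rocm_sketch(msg: str):
    # Hypothetical stand-in for torchao.testing.utils.skip_if_rocm.
    # Assumption: a ROCm build is detected via torch.version.hip, which is
    # a version string on ROCm builds and None on CUDA/CPU builds.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if torch.version.hip is not None:
                pytest.skip(msg)  # skip with the reason given at decoration time
            return fn(*args, **kwargs)
        return wrapper
    return decorator


@skip_if_rocm_sketch("ROCm not supported")
def test_decorator_sketch():
    assert torch.ones(1).sum().item() == 1.0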
|