
Commit a813b12

Update hardware check conditions

1 parent ed76e9c commit a813b12

File tree

11 files changed: +53 -73 lines changed

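Every file in this commit follows the same pattern: the per-file module-level capability flags (is_cuda_8_9, is_H100, is_cuda_9_0) are deleted and replaced with the shared is_sm_89() and is_sm_90() helpers imported from torchao.utils. The helpers themselves are not part of the hunks shown below; the following is a minimal, hypothetical sketch of what they presumably look like, assuming they simply wrap the same torch.cuda.get_device_capability() check that the removed flags used.

# Hypothetical sketch only -- the actual torchao.utils implementation is not
# included in the diffs below.
import torch


def is_sm_89() -> bool:
    # True on CUDA devices with compute capability >= 8.9 (e.g. Ada/Hopper),
    # the minimum the tests below require for native float8 support.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


def is_sm_90() -> bool:
    # True on CUDA devices with compute capability >= 9.0 (e.g. H100).
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)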

test/dtypes/test_affine_quantized.py

Lines changed: 2 additions & 4 deletions
@@ -17,9 +17,7 @@
     int8_weight_only,
 )
 from torchao.quantization.quant_primitives import MappingType
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6
-
-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, is_sm_89


 def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"):
@@ -42,7 +40,7 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"):
             int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
         )

-    if is_cuda_8_9:
+    if is_sm_89():
         base_functions.append(float8_weight_only())

     return base_functions

test/dtypes/test_affine_quantized_float.py

Lines changed: 12 additions & 11 deletions
@@ -37,13 +37,14 @@
     MappingType,
     choose_qparams_affine,
 )
+from torchao.utils import (
+    is_sm_89,
+    is_sm_90,
+)

 random.seed(0)
 torch.manual_seed(0)

-is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
-

 class ToyLinearModel(torch.nn.Module):
     def __init__(self, in_features, out_features):
@@ -59,12 +60,12 @@ def forward(self, x):


 class TestAffineQuantizedFloat8Compile(InductorTestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9")
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
     @common_utils.parametrize("compile", [True, False])
     @common_utils.parametrize(
-        "granularity", [PerTensor(), PerRow()] if is_H100 else [PerTensor()]
+        "granularity", [PerTensor(), PerRow()] if is_sm_90() else [PerTensor()]
     )
     # Inputs are (M,..), K, N
     @common_utils.parametrize(
@@ -134,20 +135,20 @@ def test_fp8_linear_variants(
             compute_error(output_original, output_quantized) > 20
         ), f"Quantization error is too high got a SQNR of {error}"

-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9")
     def test_invalid_granularity(self):
         with pytest.raises(ValueError, match="Invalid granularity specification"):
             float8_dynamic_activation_float8_weight(granularity="invalid")

-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9")
     def test_mismatched_granularity(self):
         with pytest.raises(
             ValueError,
             match="Different granularities for activation and weight are not supported",
         ):
             float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))

-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9")
     def test_unsupported_granularity(self):
         class UnsupportedGranularity:
             pass
@@ -158,7 +159,7 @@ class UnsupportedGranularity:
         )

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9")
     def test_per_row_with_float32(self):
         with pytest.raises(
             AssertionError,
@@ -170,7 +171,7 @@ def test_per_row_with_float32(self):
         )

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9")
     @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
     def test_serialization(self, mode: str):
         # Create and quantize the model
@@ -240,7 +241,7 @@ def test_serialization(self, mode: str):
         ), f"Scales do not match for {layer_name}"

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9")
     def test_fp8_weight_dimension_warning(self):
         # Create model with incompatible dimensions (not multiples of 16)
         model = ToyLinearModel(10, 25).cuda()  # 10x25 and 25x10 weights

test/float8/test_base.py

Lines changed: 9 additions & 13 deletions
@@ -14,7 +14,7 @@
 import torch
 import torch.nn as nn

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89, is_sm_90

 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -60,10 +60,6 @@
 torch.manual_seed(0)


-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
-is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
-
-
 def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
     assert torch.all(a._scale == b._scale).item(), "scales are not identical"
     assert torch.all(a._data == b._data).item(), "data is not identical"
@@ -219,7 +215,7 @@ def test_axiswise_reshape(self):
         ],
     )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    @unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0")
+    @unittest.skipIf(not is_sm_90(), "Requires CUDA capability >= 9.0")
     def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity):
         a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
         b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")
@@ -333,7 +329,7 @@ def _test_linear_impl(
         # verify initialization flags got updated
         assert m_fp8.is_amax_initialized, "Amax was not properly initialized"

-    @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
+    @pytest.mark.parametrize("emulate", [True, False] if is_sm_89() else [True])
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize(
         "scaling_type_input",
@@ -415,7 +411,7 @@ def test_linear_from_recipe(
             config,
         )

-    @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
+    @pytest.mark.parametrize("emulate", [True, False] if is_sm_89() else [True])
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
@@ -462,7 +458,7 @@ def test_autocast_outputs(
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
-    @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
+    @pytest.mark.parametrize("emulate", [True, False] if is_sm_89() else [True])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool):
         m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
@@ -523,7 +519,7 @@ def test_repr(self):
         s = m.__repr__()
         assert "i:dyn_ten,w:del_ten,go:dyn_ten" in s

-    @unittest.skipIf(not is_cuda_8_9, "CUDA 8.9 not available")
+    @unittest.skipIf(not is_sm_89(), "CUDA 8.9 not available")
     def test_inference_mode(self):
         x = torch.randn(32, 32, device="cuda")
         m = nn.Sequential(nn.Linear(32, 32)).cuda()
@@ -534,7 +530,7 @@ def test_inference_mode(self):

 class TestScaledMM:
     @unittest.skipIf(
-        not is_cuda_8_9,
+        not is_sm_89(),
         "CUDA not available",
     )
     @pytest.mark.parametrize(
@@ -579,7 +575,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
         atol, rtol = 2e-3, 2e-3
         torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

-    @unittest.skipIf(not is_cuda_8_9, "CUDA not available")
+    @unittest.skipIf(not is_sm_89(), "CUDA not available")
     def test_different_configs_error(self):
         x_fp32 = torch.randn(16, 16, device="cuda")
         x_scale = torch.tensor(1.0, device="cuda")
@@ -615,7 +611,7 @@ def test_different_configs_error(self):
         a @ b

     @unittest.skipIf(
-        not is_cuda_8_9,
+        not is_sm_89(),
         "CUDA not available",
     )
     @pytest.mark.parametrize(

test/float8/test_compile.py

Lines changed: 11 additions & 15 deletions
@@ -11,7 +11,7 @@

 import pytest

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89, is_sm_90

 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -46,10 +46,6 @@
 from torchao.float8.float8_utils import e4m3_dtype
 from torchao.testing.float8.test_utils import get_test_float8_linear_config

-# TODO(future PR): standardize IS_H100 with the rest of the codebase
-is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
-

 def _test_compile_base(
     backend: str,
@@ -99,7 +95,7 @@ def _test_compile_base(
     "scaling_type_grad_output",
     [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
 )
-@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
+@pytest.mark.parametrize("emulate", [False, True] if is_sm_89() else [True])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
 def test_eager_only(
@@ -126,7 +122,7 @@ def test_eager_only(


 @pytest.mark.parametrize("fullgraph", [True])
-@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
+@pytest.mark.parametrize("emulate", [False, True] if is_sm_89() else [True])
 @pytest.mark.parametrize(
     "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
@@ -177,7 +173,7 @@ def test_aot_eager(
     [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
 )
 @unittest.skipIf(
-    not torch.cuda.is_available() or not is_cuda_8_9,
+    not torch.cuda.is_available() or not is_sm_89(),
     "CUDA with float8 support not available",
 )
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@@ -215,7 +211,7 @@ def test_inductor_from_config_params(
         Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
     ],
 )
-@unittest.skipIf(not is_H100, "CUDA with capability 9.0 or greater not available")
+@unittest.skipIf(not is_sm_90(), "CUDA with capability 9.0 or greater not available")
 def test_inductor_from_recipe(recipe_name):
     torch._dynamo.reset()
     config = recipe_name_to_linear_config(recipe_name)
@@ -253,7 +249,7 @@ def forward(self, x):

     # TODO(future): figure out why the test below fails on CUDA capability 8.9
     @unittest.skipIf(
-        not torch.cuda.is_available() or not is_H100,
+        not torch.cuda.is_available() or not is_sm_90(),
         "CUDA with capability 9.0 or greater not available",
     )
     def test_float8_with_graph_break_in_the_middle(self):
@@ -269,7 +265,7 @@ def test_float8_with_graph_break_in_the_middle(self):
         torch.testing.assert_close(y_eager, y_compiled)

     @unittest.skipIf(
-        not torch.cuda.is_available() or not is_cuda_8_9,
+        not torch.cuda.is_available() or not is_sm_89(),
         "CUDA with float8 support not available",
     )
     def test_float8_graph_input(self):
@@ -293,7 +289,7 @@ def to_float(x):
         torch.testing.assert_close(y2_eager, y2_compiled)

     @unittest.skipIf(
-        not torch.cuda.is_available() or not is_cuda_8_9,
+        not torch.cuda.is_available() or not is_sm_89(),
         "CUDA with float8 support not available",
     )
     def test_float8_graph_output(self):
@@ -323,7 +319,7 @@ def test_float8_graph_output(self):


 @unittest.skipIf(
-    not torch.cuda.is_available() or not is_cuda_8_9,
+    not torch.cuda.is_available() or not is_sm_89(),
     "CUDA with float8 support not available",
 )
 def test_sync_amax_func():
@@ -364,7 +360,7 @@ def __exit__(self, *args):


 @unittest.skipIf(
-    not torch.cuda.is_available() or not is_cuda_8_9,
+    not torch.cuda.is_available() or not is_sm_89(),
     "CUDA with float8 support not available",
 )
 def test_sync_amax_func_cuda_graph_success():
@@ -396,7 +392,7 @@ def test_sync_amax_func_cuda_graph_success():


 @unittest.skipIf(
-    not is_cuda_8_9,
+    not is_sm_89(),
     "CUDA not available",
 )
 @pytest.mark.parametrize(

test/float8/test_fsdp2/test_fsdp2.py

Lines changed: 2 additions & 3 deletions
@@ -6,7 +6,7 @@

 import pytest

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89

 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -40,8 +40,7 @@
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
 from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp

-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
-if not is_cuda_8_9:
+if not is_sm_89():
     pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)

test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py

Lines changed: 2 additions & 3 deletions
@@ -3,7 +3,7 @@

 import pytest

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89

 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -30,8 +30,7 @@
 from torchao.float8.float8_tensor import GemmInputRole
 from torchao.testing.float8.fsdp2_utils import check_parity_fp8_comm_only

-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
-if not is_cuda_8_9:
+if not is_sm_89():
     pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)

test/float8/test_numerics_integration.py

Lines changed: 3 additions & 6 deletions
@@ -11,7 +11,7 @@

 import pytest

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89, is_sm_90

 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -34,9 +34,6 @@
 from torchao.float8.float8_utils import IS_ROCM, compute_error
 from torchao.testing.float8.test_utils import get_test_float8_linear_config

-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
-is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
-
 torch.manual_seed(0)


@@ -176,7 +173,7 @@ def _test_impl(self, config: Float8LinearConfig) -> None:
         "scaling_type_grad_output",
         [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
     )
-    @pytest.mark.skipif(not is_cuda_8_9, reason="requires SM89 compatible machine")
+    @pytest.mark.skipif(not is_sm_89(), reason="requires SM89 compatible machine")
     @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
     def test_encoder_fw_bw_from_config_params(
         self,
@@ -199,7 +196,7 @@ def test_encoder_fw_bw_from_config_params(
             Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
         ],
     )
-    @pytest.mark.skipif(not is_cuda_9_0, reason="requires SM90 compatible machine")
+    @pytest.mark.skipif(not is_sm_90(), reason="requires SM90 compatible machine")
     @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
     def test_encoder_fw_bw_from_recipe(
         self,
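The adoption pattern is identical across the files above. As a usage illustration only, a hypothetical test module (not part of this commit) written against the new helpers might look like the sketch below; the class and test names are made up.

import unittest

import torch

from torchao.utils import is_sm_89, is_sm_90


class TestFloat8HardwareGates(unittest.TestCase):
    # The skip condition is evaluated when the module is imported during
    # collection, so unsupported hardware skips the test up front.
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9")
    def test_runs_on_sm89_or_newer(self):
        a = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16)
        b = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16)
        self.assertEqual((a @ b).shape, (16, 16))

    # Mirrors the SM 9.0 gating used above for the H100-only code paths.
    @unittest.skipIf(not is_sm_90(), "Requires CUDA capability >= 9.0")
    def test_runs_on_sm90_or_newer(self):
        self.assertGreaterEqual(torch.cuda.get_device_capability(), (9, 0))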
