
Commit f292646

Update is_sm_ -> is_sm_at_least

ruff

1 parent 73957a5 commit f292646

File tree

13 files changed: +102 −60 lines changed

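This commit mechanically renames the CUDA compute-capability helpers is_sm_89/is_sm_90 to is_sm_at_least_89/is_sm_at_least_90 across the test suite; where the longer names no longer fit on one line, ruff re-wraps the call. The "at_least" suffix makes explicit that these checks are lower bounds, i.e. an SM 9.0 device also passes the 8.9 check. For context, a helper of this kind would presumably look like the minimal sketch below; this code is not part of the commit, and the actual torchao.utils implementation may differ.

    import torch

    def is_sm_at_least_89() -> bool:
        # True if a CUDA device with compute capability >= 8.9 (Ada or newer) is present.
        return (
            torch.cuda.is_available()
            and torch.version.cuda is not None
            and torch.cuda.get_device_capability() >= (8, 9)
        )

    def is_sm_at_least_90() -> bool:
        # True if a CUDA device with compute capability >= 9.0 (Hopper or newer) is present.
        return (
            torch.cuda.is_available()
            and torch.version.cuda is not None
            and torch.cuda.get_device_capability() >= (9, 0)
        )

In the diffs below, the tests gate float8 paths on these helpers (e.g. @unittest.skipIf(not is_sm_at_least_89(), ...)), while emulation-only paths remain parametrized to run on older GPUs.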

test/dtypes/test_affine_quantized.py

Lines changed: 6 additions & 2 deletions

@@ -17,7 +17,11 @@
     int8_weight_only,
 )
 from torchao.quantization.quant_primitives import MappingType
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, is_sm_89
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_6,
+    is_sm_at_least_89,
+)


 def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"):
@@ -40,7 +44,7 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"):
             int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
         )

-    if is_sm_89():
+    if is_sm_at_least_89():
         base_functions.append(float8_weight_only())

     return base_functions

test/dtypes/test_affine_quantized_float.py

Lines changed: 24 additions & 10 deletions

@@ -38,8 +38,8 @@
     choose_qparams_affine,
 )
 from torchao.utils import (
-    is_sm_89,
-    is_sm_90,
+    is_sm_at_least_89,
+    is_sm_at_least_90,
 )

 random.seed(0)
@@ -60,12 +60,14 @@ def forward(self, x):

 class TestAffineQuantizedFloat8Compile(InductorTestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
     @common_utils.parametrize("compile", [True, False])
     @common_utils.parametrize(
-        "granularity", [PerTensor(), PerRow()] if is_sm_90() else [PerTensor()]
+        "granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()]
     )
     # Inputs are (M,..), K, N
     @common_utils.parametrize(
@@ -135,20 +137,26 @@ def test_fp8_linear_variants(
             compute_error(output_original, output_quantized) > 20
         ), f"Quantization error is too high got a SQNR of {error}"

-    @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
     def test_invalid_granularity(self):
         with pytest.raises(ValueError, match="Invalid granularity specification"):
             float8_dynamic_activation_float8_weight(granularity="invalid")

-    @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
     def test_mismatched_granularity(self):
         with pytest.raises(
             ValueError,
             match="Different granularities for activation and weight are not supported",
         ):
             float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))

-    @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
     def test_unsupported_granularity(self):
         class UnsupportedGranularity:
             pass
@@ -159,7 +167,9 @@ class UnsupportedGranularity:
             )

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
     def test_per_row_with_float32(self):
         with pytest.raises(
             AssertionError,
@@ -171,7 +181,9 @@ def test_per_row_with_float32(self):
             )

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
     @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
     def test_serialization(self, mode: str):
         # Create and quantize the model
@@ -241,7 +253,9 @@ def test_serialization(self, mode: str):
         ), f"Scales do not match for {layer_name}"

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_sm_89(), "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
     def test_fp8_weight_dimension_warning(self):
         # Create model with incompatible dimensions (not multiples of 16)
         model = ToyLinearModel(10, 25).cuda()  # 10x25 and 25x10 weights

test/float8/test_base.py

Lines changed: 19 additions & 9 deletions

@@ -14,7 +14,11 @@
 import torch
 import torch.nn as nn

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89, is_sm_90
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    is_sm_at_least_89,
+    is_sm_at_least_90,
+)

 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -215,7 +219,7 @@ def test_axiswise_reshape(self):
         ],
     )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    @unittest.skipIf(not is_sm_90(), "Requires CUDA capability >= 9.0")
+    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
     def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity):
         a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
         b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")
@@ -329,7 +333,9 @@ def _test_linear_impl(
         # verify initialization flags got updated
         assert m_fp8.is_amax_initialized, "Amax was not properly initialized"

-    @pytest.mark.parametrize("emulate", [True, False] if is_sm_89() else [True])
+    @pytest.mark.parametrize(
+        "emulate", [True, False] if is_sm_at_least_89() else [True]
+    )
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize(
         "scaling_type_input",
@@ -411,7 +417,9 @@ def test_linear_from_recipe(
             config,
         )

-    @pytest.mark.parametrize("emulate", [True, False] if is_sm_89() else [True])
+    @pytest.mark.parametrize(
+        "emulate", [True, False] if is_sm_at_least_89() else [True]
+    )
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
@@ -458,7 +466,9 @@ def test_autocast_outputs(
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
-    @pytest.mark.parametrize("emulate", [True, False] if is_sm_89() else [True])
+    @pytest.mark.parametrize(
+        "emulate", [True, False] if is_sm_at_least_89() else [True]
+    )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool):
         m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
@@ -519,7 +529,7 @@ def test_repr(self):
         s = m.__repr__()
         assert "i:dyn_ten,w:del_ten,go:dyn_ten" in s

-    @unittest.skipIf(not is_sm_89(), "CUDA 8.9 not available")
+    @unittest.skipIf(not is_sm_at_least_89(), "CUDA 8.9 not available")
     def test_inference_mode(self):
         x = torch.randn(32, 32, device="cuda")
         m = nn.Sequential(nn.Linear(32, 32)).cuda()
@@ -530,7 +540,7 @@ def test_inference_mode(self):

 class TestScaledMM:
     @unittest.skipIf(
-        not is_sm_89(),
+        not is_sm_at_least_89(),
         "CUDA not available",
     )
     @pytest.mark.parametrize(
@@ -575,7 +585,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
         atol, rtol = 2e-3, 2e-3
         torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

-    @unittest.skipIf(not is_sm_89(), "CUDA not available")
+    @unittest.skipIf(not is_sm_at_least_89(), "CUDA not available")
     def test_different_configs_error(self):
         x_fp32 = torch.randn(16, 16, device="cuda")
         x_scale = torch.tensor(1.0, device="cuda")
@@ -611,7 +621,7 @@ def test_different_configs_error(self):
             a @ b

     @unittest.skipIf(
-        not is_sm_89(),
+        not is_sm_at_least_89(),
         "CUDA not available",
     )
     @pytest.mark.parametrize(

test/float8/test_compile.py

Lines changed: 17 additions & 11 deletions

@@ -11,7 +11,11 @@

 import pytest

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89, is_sm_90
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    is_sm_at_least_89,
+    is_sm_at_least_90,
+)

 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -95,7 +99,7 @@ def _test_compile_base(
     "scaling_type_grad_output",
     [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
 )
-@pytest.mark.parametrize("emulate", [False, True] if is_sm_89() else [True])
+@pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
 def test_eager_only(
@@ -122,7 +126,7 @@ def test_eager_only(


 @pytest.mark.parametrize("fullgraph", [True])
-@pytest.mark.parametrize("emulate", [False, True] if is_sm_89() else [True])
+@pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True])
 @pytest.mark.parametrize(
     "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
@@ -173,7 +177,7 @@ def test_aot_eager(
     [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
 )
 @unittest.skipIf(
-    not torch.cuda.is_available() or not is_sm_89(),
+    not torch.cuda.is_available() or not is_sm_at_least_89(),
     "CUDA with float8 support not available",
 )
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@@ -211,7 +215,9 @@ def test_inductor_from_config_params(
         Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
     ],
 )
-@unittest.skipIf(not is_sm_90(), "CUDA with capability 9.0 or greater not available")
+@unittest.skipIf(
+    not is_sm_at_least_90(), "CUDA with capability 9.0 or greater not available"
+)
 def test_inductor_from_recipe(recipe_name):
     torch._dynamo.reset()
     config = recipe_name_to_linear_config(recipe_name)
@@ -249,7 +255,7 @@ def forward(self, x):

     # TODO(future): figure out why the test below fails on CUDA capability 8.9
     @unittest.skipIf(
-        not torch.cuda.is_available() or not is_sm_90(),
+        not torch.cuda.is_available() or not is_sm_at_least_90(),
         "CUDA with capability 9.0 or greater not available",
     )
     def test_float8_with_graph_break_in_the_middle(self):
@@ -265,7 +271,7 @@ def test_float8_with_graph_break_in_the_middle(self):
         torch.testing.assert_close(y_eager, y_compiled)

     @unittest.skipIf(
-        not torch.cuda.is_available() or not is_sm_89(),
+        not torch.cuda.is_available() or not is_sm_at_least_89(),
         "CUDA with float8 support not available",
     )
     def test_float8_graph_input(self):
@@ -289,7 +295,7 @@ def to_float(x):
         torch.testing.assert_close(y2_eager, y2_compiled)

     @unittest.skipIf(
-        not torch.cuda.is_available() or not is_sm_89(),
+        not torch.cuda.is_available() or not is_sm_at_least_89(),
         "CUDA with float8 support not available",
     )
     def test_float8_graph_output(self):
@@ -319,7 +325,7 @@ def test_float8_graph_output(self):


 @unittest.skipIf(
-    not torch.cuda.is_available() or not is_sm_89(),
+    not torch.cuda.is_available() or not is_sm_at_least_89(),
     "CUDA with float8 support not available",
 )
 def test_sync_amax_func():
@@ -360,7 +366,7 @@ def __exit__(self, *args):


 @unittest.skipIf(
-    not torch.cuda.is_available() or not is_sm_89(),
+    not torch.cuda.is_available() or not is_sm_at_least_89(),
     "CUDA with float8 support not available",
 )
 def test_sync_amax_func_cuda_graph_success():
@@ -392,7 +398,7 @@ def test_sync_amax_func_cuda_graph_success():


 @unittest.skipIf(
-    not is_sm_89(),
+    not is_sm_at_least_89(),
     "CUDA not available",
 )
 @pytest.mark.parametrize(

test/float8/test_fsdp2/test_fsdp2.py

Lines changed: 2 additions & 2 deletions

@@ -6,7 +6,7 @@

 import pytest

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89

 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -40,7 +40,7 @@
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
 from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp

-if not is_sm_89():
+if not is_sm_at_least_89():
     pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)


test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py

Lines changed: 2 additions & 2 deletions

@@ -3,7 +3,7 @@

 import pytest

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89

 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -30,7 +30,7 @@
 from torchao.float8.float8_tensor import GemmInputRole
 from torchao.testing.float8.fsdp2_utils import check_parity_fp8_comm_only

-if not is_sm_89():
+if not is_sm_at_least_89():
     pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)


test/float8/test_numerics_integration.py

Lines changed: 11 additions & 3 deletions

@@ -11,7 +11,11 @@

 import pytest

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89, is_sm_90
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    is_sm_at_least_89,
+    is_sm_at_least_90,
+)

 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -173,7 +177,9 @@ def _test_impl(self, config: Float8LinearConfig) -> None:
         "scaling_type_grad_output",
         [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
     )
-    @pytest.mark.skipif(not is_sm_89(), reason="requires SM89 compatible machine")
+    @pytest.mark.skipif(
+        not is_sm_at_least_89(), reason="requires SM89 compatible machine"
+    )
     @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
     def test_encoder_fw_bw_from_config_params(
         self,
@@ -196,7 +202,9 @@ def test_encoder_fw_bw_from_config_params(
             Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
         ],
     )
-    @pytest.mark.skipif(not is_sm_90(), reason="requires SM90 compatible machine")
+    @pytest.mark.skipif(
+        not is_sm_at_least_90(), reason="requires SM90 compatible machine"
+    )
     @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
     def test_encoder_fw_bw_from_recipe(
         self,

0 commit comments
