Skip to content

Commit d491087

Browse files
author
Peter Y. Yeh
committed
update skip_if_rocm import
lint
1 parent f52d14a commit d491087

File tree

14 files changed

+44
-48
lines changed

14 files changed

+44
-48
lines changed

test/dtypes/test_affine_quantized.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import unittest
33

44
import torch
5-
from test_utils import skip_if_rocm
65
from torch.testing._internal import common_utils
76
from torch.testing._internal.common_utils import (
87
TestCase,
@@ -22,6 +21,7 @@
2221
TORCH_VERSION_AT_LEAST_2_5,
2322
TORCH_VERSION_AT_LEAST_2_6,
2423
is_sm_at_least_89,
24+
skip_if_rocm,
2525
)
2626

2727

test/dtypes/test_floatx.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import unittest
33

44
import torch
5-
from test_utils import skip_if_rocm
65
from torch.testing._internal.common_utils import (
76
TestCase,
87
instantiate_parametrized_tests,
@@ -28,7 +27,7 @@
2827
fpx_weight_only,
2928
quantize_,
3029
)
31-
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode
30+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode, skip_if_rocm
3231

3332
_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
3433
_Floatx_DTYPES = [(3, 2), (2, 2)]

test/float8/test_base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,13 @@
1818
TORCH_VERSION_AT_LEAST_2_5,
1919
is_sm_at_least_89,
2020
is_sm_at_least_90,
21+
skip_if_rocm,
2122
)
2223

2324
if not TORCH_VERSION_AT_LEAST_2_5:
2425
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
2526

2627

27-
from test_utils import skip_if_rocm
28-
2928
from torchao.float8.config import (
3029
CastConfig,
3130
Float8LinearConfig,

test/hqq/test_hqq_affine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import unittest
22

33
import torch
4-
from test_utils import skip_if_rocm
54

65
from torchao.quantization import (
76
MappingType,
@@ -11,6 +10,7 @@
1110
)
1211
from torchao.utils import (
1312
TORCH_VERSION_AT_LEAST_2_3,
13+
skip_if_rocm,
1414
)
1515

1616
cuda_available = torch.cuda.is_available()

test/integration/test_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
benchmark_model,
8181
is_fbcode,
8282
is_sm_at_least_90,
83+
skip_if_rocm,
8384
unwrap_tensor_subclass,
8485
)
8586

@@ -90,7 +91,6 @@
9091
except ModuleNotFoundError:
9192
has_gemlite = False
9293

93-
from test_utils import skip_if_rocm
9494

9595
logger = logging.getLogger("INFO")
9696

test/kernel/test_galore_downproj.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88

99
import torch
1010
from galore_test_utils import make_data
11-
from test_utils import skip_if_rocm
1211

1312
from torchao.prototype.galore.kernels.matmul import set_tuner_top_k as matmul_tuner_topk
1413
from torchao.prototype.galore.kernels.matmul import triton_mm_launcher
14+
from torchao.utils import skip_if_rocm
1515

1616
torch.manual_seed(0)
1717

test/prototype/test_awq.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
import torch
66

77
from torchao.quantization import quantize_
8-
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5
8+
from torchao.utils import (
9+
TORCH_VERSION_AT_LEAST_2_3,
10+
TORCH_VERSION_AT_LEAST_2_5,
11+
skip_if_rocm,
12+
)
913

1014
if TORCH_VERSION_AT_LEAST_2_3:
1115
from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
1216

13-
from test_utils import skip_if_rocm
14-
1517

1618
class ToyLinearModel(torch.nn.Module):
1719
def __init__(self, m=512, n=256, k=128):

test/prototype/test_low_bit_optim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
TORCH_VERSION_AT_LEAST_2_4,
3131
TORCH_VERSION_AT_LEAST_2_5,
3232
get_available_devices,
33+
skip_if_rocm,
3334
)
3435

3536
try:
@@ -42,7 +43,6 @@
4243
except ImportError:
4344
lpmm = None
4445

45-
from test_utils import skip_if_rocm
4646

4747
_DEVICES = get_available_devices()
4848

test/prototype/test_splitk.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@
1313
except ImportError:
1414
triton_available = False
1515

16-
from test_utils import skip_if_rocm
1716

18-
from torchao.utils import skip_if_compute_capability_less_than
17+
from torchao.utils import skip_if_compute_capability_less_than, skip_if_rocm
1918

2019

2120
@unittest.skipIf(not triton_available, "Triton is required but not available")

test/quantization/test_galore_quant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313
dequantize_blockwise,
1414
quantize_blockwise,
1515
)
16-
from test_utils import skip_if_rocm
1716

1817
from torchao.prototype.galore.kernels import (
1918
triton_dequant_blockwise,
2019
triton_quantize_blockwise,
2120
)
21+
from torchao.utils import skip_if_rocm
2222

2323
SEED = 0
2424
torch.manual_seed(SEED)

0 commit comments

Comments
 (0)