
Commit ade3086

Borda authored and kaushikb11 committed
add skipif wrapper (Lightning-AI#6258)
1 parent e736666 commit ade3086

3 files changed: +49 −50 lines


tests/callbacks/test_quantization.py

Lines changed: 6 additions & 9 deletions
@@ -22,15 +22,12 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.datamodules import RegressDataModule
 from tests.helpers.simple_models import RegressionModel
-from tests.helpers.skipif import skipif_args
+from tests.helpers.skipif import SkipIf


-@pytest.mark.parametrize(
-    "observe",
-    ['average', pytest.param('histogram', marks=pytest.mark.skipif(**skipif_args(min_torch="1.5")))]
-)
+@pytest.mark.parametrize("observe", ['average', pytest.param('histogram', marks=SkipIf(min_torch="1.5"))])
 @pytest.mark.parametrize("fuse", [True, False])
-@pytest.mark.skipif(**skipif_args(quant_available=True))
+@SkipIf(quantization=True)
 def test_quantization(tmpdir, observe, fuse):
     """Parity test for quant model"""
     seed_everything(42)
@@ -65,7 +62,7 @@ def test_quantization(tmpdir, observe, fuse):
     assert torch.allclose(org_score, quant_score, atol=0.45)


-@pytest.mark.skipif(**skipif_args(quant_available=True))
+@SkipIf(quantization=True)
 def test_quantize_torchscript(tmpdir):
     """Test converting to torchscript"""
     dm = RegressDataModule()
@@ -81,7 +78,7 @@ def test_quantize_torchscript(tmpdir):
     tsmodel(tsmodel.quant(batch[0]))


-@pytest.mark.skipif(**skipif_args(quant_available=True))
+@SkipIf(quantization=True)
 def test_quantization_exceptions(tmpdir):
     """Test wrong fuse layers"""
     with pytest.raises(MisconfigurationException, match='Unsupported qconfig'):
@@ -124,7 +121,7 @@ def custom_trigger_last(trainer):
         (custom_trigger_last, 2),
     ]
 )
-@pytest.mark.skipif(**skipif_args(quant_available=True))
+@SkipIf(quantization=True)
 def test_quantization_triggers(tmpdir, trigger_fn, expected_count):
     """Test how many times the quant is called"""
     dm = RegressDataModule()
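Taken together, the changes in this file replace every `@pytest.mark.skipif(**skipif_args(...))` call with a single `@SkipIf(...)` decorator. A minimal sketch of the new style, assuming the `tests/helpers/skipif.py` helper shown further below is importable (the test name and body here are hypothetical):

import pytest

from tests.helpers.skipif import SkipIf

# Skipped entirely when torch quantization (fbgemm backend) is unavailable;
# otherwise runs once per parametrized `fuse` value.
@SkipIf(quantization=True)
@pytest.mark.parametrize("fuse", [True, False])
def test_needs_quantization(fuse):
    assert fuse in (True, False)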

tests/core/test_results.py

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.trainer.states import TrainerState
 from tests.helpers import BoringDataModule, BoringModel
-from tests.helpers.skipif import skipif_args
+from tests.helpers.skipif import SkipIf


 def _setup_ddp(rank, worldsize):
@@ -72,7 +72,7 @@ def test_result_reduce_ddp(result_cls):
         pytest.param(5, False, 0, id='nested_list_predictions'),
         pytest.param(6, False, 0, id='dict_list_predictions'),
         pytest.param(7, True, 0, id='write_dict_predictions'),
-        pytest.param(0, True, 1, id='full_loop_single_gpu', marks=pytest.mark.skipif(**skipif_args(min_gpus=1)))
+        pytest.param(0, True, 1, id='full_loop_single_gpu', marks=SkipIf(min_gpus=1))
     ]
 )
 def test_result_obj_predictions(tmpdir, test_option, do_train, gpus):
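Since `SkipIf(...)` evaluates to an ordinary `pytest.mark.skipif` mark, it also drops into `pytest.param(..., marks=...)` with no adapter, as the change above shows. A small sketch under the same assumption that the helper module is importable (hypothetical test):

import pytest

from tests.helpers.skipif import SkipIf

@pytest.mark.parametrize(
    "gpus",
    # The GPU case is skipped on CPU-only machines instead of failing.
    [0, pytest.param(1, marks=SkipIf(min_gpus=1))],
)
def test_gpu_option(gpus):
    assert gpus >= 0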

tests/helpers/skipif.py

Lines changed: 41 additions & 39 deletions
@@ -21,55 +21,57 @@
 from pytorch_lightning.utilities import _TORCH_QUANTIZE_AVAILABLE


-def skipif_args(
-    min_gpus: int = 0,
-    min_torch: Optional[str] = None,
-    quant_available: bool = False,
-) -> dict:
-    """ Creating aggregated arguments for standard pytest skipif, sot the usecase is::
-
-        @pytest.mark.skipif(**create_skipif(min_torch="99"))
-        def test_any_func(...):
-            ...
+class SkipIf:
+    """
+    SkipIf wrapper for simple marking specific cases, fully compatible with pytest.mark::

-    >>> from pprint import pprint
-    >>> pprint(skipif_args(min_torch="99", min_gpus=0))
-    {'condition': True, 'reason': 'Required: [torch>=99]'}
-    >>> pprint(skipif_args(min_torch="0.0", min_gpus=0))  # doctest: +NORMALIZE_WHITESPACE
-    {'condition': False, 'reason': 'Conditions satisfied, going ahead with the test.'}
+        @SkipIf(min_torch="0.0")
+        @pytest.mark.parametrize("arg1", [1, 2.0])
+        def test_wrapper(arg1):
+            assert arg1 > 0.0
     """
-    conditions = []
-    reasons = []

-    if min_gpus:
-        conditions.append(torch.cuda.device_count() < min_gpus)
-        reasons.append(f"GPUs>={min_gpus}")
+    def __new__(self, *args, min_gpus: int = 0, min_torch: Optional[str] = None, quantization: bool = False, **kwargs):
+        """
+        Args:
+            args: native pytest.mark.skipif arguments
+            min_gpus: min number of gpus required to run test
+            min_torch: minimum pytorch version to run test
+            quantization: if `torch.quantization` package is required to run test
+            kwargs: native pytest.mark.skipif keyword arguments
+        """
+        conditions = []
+        reasons = []

-    if min_torch:
-        torch_version = LooseVersion(get_distribution("torch").version)
-        conditions.append(torch_version < LooseVersion(min_torch))
-        reasons.append(f"torch>={min_torch}")
+        if min_gpus:
+            conditions.append(torch.cuda.device_count() < min_gpus)
+            reasons.append(f"GPUs>={min_gpus}")

-    if quant_available:
-        _miss_default = 'fbgemm' not in torch.backends.quantized.supported_engines
-        conditions.append(not _TORCH_QUANTIZE_AVAILABLE or _miss_default)
-        reasons.append("PyTorch quantization is available")
+        if min_torch:
+            torch_version = LooseVersion(get_distribution("torch").version)
+            conditions.append(torch_version < LooseVersion(min_torch))
+            reasons.append(f"torch>={min_torch}")

-    if not any(conditions):
-        return dict(condition=False, reason="Conditions satisfied, going ahead with the test.")
+        if quantization:
+            _miss_default = 'fbgemm' not in torch.backends.quantized.supported_engines
+            conditions.append(not _TORCH_QUANTIZE_AVAILABLE or _miss_default)
+            reasons.append("missing PyTorch quantization")

-    reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
-    return dict(
-        condition=any(conditions),
-        reason=f"Required: [{' + '.join(reasons)}]",
-    )
+        reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
+        return pytest.mark.skipif(
+            *args,
+            condition=any(conditions),
+            reason=f"Requires: [{' + '.join(reasons)}]",
+            **kwargs,
+        )


-@pytest.mark.skipif(**skipif_args(min_torch="99"))
+@SkipIf(min_torch="99")
 def test_always_skip():
     exit(1)


-@pytest.mark.skipif(**skipif_args(min_torch="0.0"))
-def test_always_pass():
-    assert True
+@pytest.mark.parametrize("arg1", [0.5, 1.0, 2.0])
+@SkipIf(min_torch="0.0")
+def test_wrapper(arg1):
+    assert arg1 > 0.0
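What makes both call sites work is the `__new__` trick: returning a foreign object from `__new__` bypasses `__init__`, so `SkipIf(...)` never creates a SkipIf instance and instead yields the fully-built `pytest.mark.skipif` mark, which pytest then treats like any other MarkDecorator. A self-contained sketch of the same pattern, with hypothetical names (`AlwaysSkip`, `test_demo`):

import pytest

class AlwaysSkip:
    def __new__(cls, reason: str = "demonstration"):
        # Returning a non-instance from __new__ skips __init__ entirely,
        # so AlwaysSkip(...) evaluates to a ready-made pytest mark.
        return pytest.mark.skipif(condition=True, reason=reason)

@AlwaysSkip("never runs")
def test_demo():
    raise AssertionError("unreachable")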
