|
21 | 21 | from pytorch_lightning.utilities import _TORCH_QUANTIZE_AVAILABLE |
22 | 22 |
|
23 | 23 |
|
24 | | -def skipif_args( |
25 | | - min_gpus: int = 0, |
26 | | - min_torch: Optional[str] = None, |
27 | | - quant_available: bool = False, |
28 | | -) -> dict: |
29 | | - """ Creating aggregated arguments for standard pytest skipif, sot the usecase is:: |
30 | | -
|
31 | | - @pytest.mark.skipif(**create_skipif(min_torch="99")) |
32 | | - def test_any_func(...): |
33 | | - ... |
| 24 | +class SkipIf: |
| 25 | + """ |
| 26 | + SkipIf wrapper for simple marking specific cases, fully compatible with pytest.mark:: |
34 | 27 |
|
35 | | - >>> from pprint import pprint |
36 | | - >>> pprint(skipif_args(min_torch="99", min_gpus=0)) |
37 | | - {'condition': True, 'reason': 'Required: [torch>=99]'} |
38 | | - >>> pprint(skipif_args(min_torch="0.0", min_gpus=0)) # doctest: +NORMALIZE_WHITESPACE |
39 | | - {'condition': False, 'reason': 'Conditions satisfied, going ahead with the test.'} |
| 28 | + @SkipIf(min_torch="0.0") |
| 29 | + @pytest.mark.parametrize("arg1", [1, 2.0]) |
| 30 | + def test_wrapper(arg1): |
| 31 | + assert arg1 > 0.0 |
40 | 32 | """ |
41 | | - conditions = [] |
42 | | - reasons = [] |
43 | 33 |
|
44 | | - if min_gpus: |
45 | | - conditions.append(torch.cuda.device_count() < min_gpus) |
46 | | - reasons.append(f"GPUs>={min_gpus}") |
| 34 | + def __new__(self, *args, min_gpus: int = 0, min_torch: Optional[str] = None, quantization: bool = False, **kwargs): |
| 35 | + """ |
| 36 | + Args: |
| 37 | + args: native pytest.mark.skipif arguments |
| 38 | + min_gpus: min number of gpus required to run test |
| 39 | + min_torch: minimum pytorch version to run test |
| 40 | + quantization: if `torch.quantization` package is required to run test |
| 41 | + kwargs: native pytest.mark.skipif keyword arguments |
| 42 | + """ |
| 43 | + conditions = [] |
| 44 | + reasons = [] |
47 | 45 |
|
48 | | - if min_torch: |
49 | | - torch_version = LooseVersion(get_distribution("torch").version) |
50 | | - conditions.append(torch_version < LooseVersion(min_torch)) |
51 | | - reasons.append(f"torch>={min_torch}") |
| 46 | + if min_gpus: |
| 47 | + conditions.append(torch.cuda.device_count() < min_gpus) |
| 48 | + reasons.append(f"GPUs>={min_gpus}") |
52 | 49 |
|
53 | | - if quant_available: |
54 | | - _miss_default = 'fbgemm' not in torch.backends.quantized.supported_engines |
55 | | - conditions.append(not _TORCH_QUANTIZE_AVAILABLE or _miss_default) |
56 | | - reasons.append("PyTorch quantization is available") |
| 50 | + if min_torch: |
| 51 | + torch_version = LooseVersion(get_distribution("torch").version) |
| 52 | + conditions.append(torch_version < LooseVersion(min_torch)) |
| 53 | + reasons.append(f"torch>={min_torch}") |
57 | 54 |
|
58 | | - if not any(conditions): |
59 | | - return dict(condition=False, reason="Conditions satisfied, going ahead with the test.") |
| 55 | + if quantization: |
| 56 | + _miss_default = 'fbgemm' not in torch.backends.quantized.supported_engines |
| 57 | + conditions.append(not _TORCH_QUANTIZE_AVAILABLE or _miss_default) |
| 58 | + reasons.append("missing PyTorch quantization") |
60 | 59 |
|
61 | | - reasons = [rs for cond, rs in zip(conditions, reasons) if cond] |
62 | | - return dict( |
63 | | - condition=any(conditions), |
64 | | - reason=f"Required: [{' + '.join(reasons)}]", |
65 | | - ) |
| 60 | + reasons = [rs for cond, rs in zip(conditions, reasons) if cond] |
| 61 | + return pytest.mark.skipif( |
| 62 | + *args, |
| 63 | + condition=any(conditions), |
| 64 | + reason=f"Requires: [{' + '.join(reasons)}]", |
| 65 | + **kwargs, |
| 66 | + ) |
66 | 67 |
|
67 | 68 |
|
68 | | -@pytest.mark.skipif(**skipif_args(min_torch="99")) |
| 69 | +@SkipIf(min_torch="99") |
69 | 70 | def test_always_skip(): |
70 | 71 | exit(1) |
71 | 72 |
|
72 | 73 |
|
73 | | -@pytest.mark.skipif(**skipif_args(min_torch="0.0")) |
74 | | -def test_always_pass(): |
75 | | - assert True |
| 74 | +@pytest.mark.parametrize("arg1", [0.5, 1.0, 2.0]) |
| 75 | +@SkipIf(min_torch="0.0") |
| 76 | +def test_wrapper(arg1): |
| 77 | + assert arg1 > 0.0 |
0 commit comments