|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | | - |
15 | | -import importlib |
16 | | -import platform |
17 | 14 | from unittest import mock |
18 | 15 |
|
19 | 16 | import pytest |
20 | | -import torch |
21 | 17 |
|
22 | 18 | from pl_examples import _DALI_AVAILABLE |
| 19 | +from tests.helpers.runif import RunIf |
23 | 20 |
|
24 | 21 | ARGS_DEFAULT = ( |
25 | 22 | "--trainer.default_root_dir %(tmpdir)s " |
26 | 23 | "--trainer.max_epochs 1 " |
27 | 24 | "--trainer.limit_train_batches 2 " |
28 | 25 | "--trainer.limit_val_batches 2 " |
| 26 | + "--trainer.limit_test_batches 2 " |
| 27 | + "--trainer.limit_predict_batches 2 " |
29 | 28 | "--data.batch_size 32 " |
30 | 29 | ) |
31 | 30 | ARGS_GPU = ARGS_DEFAULT + "--trainer.gpus 1 " |
32 | | -ARGS_DP = ARGS_DEFAULT + "--trainer.gpus 2 --trainer.accelerator dp " |
33 | | -ARGS_AMP = "--trainer.precision 16 " |
34 | | - |
35 | | - |
36 | | -@pytest.mark.parametrize( |
37 | | - "import_cli", |
38 | | - [ |
39 | | - "pl_examples.basic_examples.simple_image_classifier", |
40 | | - "pl_examples.basic_examples.backbone_image_classifier", |
41 | | - "pl_examples.basic_examples.autoencoder", |
42 | | - ], |
43 | | -) |
44 | | -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") |
45 | | -@pytest.mark.parametrize("cli_args", [ARGS_DP, ARGS_DP + ARGS_AMP]) |
46 | | -def test_examples_dp(tmpdir, import_cli, cli_args): |
47 | | - |
48 | | - module = importlib.import_module(import_cli) |
49 | | - # update the temp dir |
50 | | - cli_args = cli_args % {"tmpdir": tmpdir} |
51 | | - |
52 | | - with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()): |
53 | | - module.cli_main() |
54 | | - |
55 | | - |
56 | | -@pytest.mark.parametrize( |
57 | | - "import_cli", |
58 | | - [ |
59 | | - "pl_examples.basic_examples.simple_image_classifier", |
60 | | - "pl_examples.basic_examples.backbone_image_classifier", |
61 | | - "pl_examples.basic_examples.autoencoder", |
62 | | - ], |
63 | | -) |
64 | | -@pytest.mark.parametrize("cli_args", [ARGS_DEFAULT]) |
65 | | -def test_examples_cpu(tmpdir, import_cli, cli_args): |
66 | | - |
67 | | - module = importlib.import_module(import_cli) |
68 | | - # update the temp dir |
69 | | - cli_args = cli_args % {"tmpdir": tmpdir} |
70 | | - |
71 | | - with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()): |
72 | | - module.cli_main() |
73 | 31 |
|
74 | 32 |
|
75 | 33 | @pytest.mark.skipif(not _DALI_AVAILABLE, reason="Nvidia DALI required") |
76 | | -@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") |
77 | | -@pytest.mark.skipif(platform.system() != "Linux", reason="Only applies to Linux platform.") |
| 34 | +@RunIf(min_gpus=1, skip_windows=True) |
78 | 35 | @pytest.mark.parametrize("cli_args", [ARGS_GPU]) |
79 | 36 | def test_examples_mnist_dali(tmpdir, cli_args): |
80 | 37 | from pl_examples.basic_examples.dali_image_classifier import cli_main |
|
0 commit comments