Skip to content

Commit 620703b

Browse files
committed
Deduplicate tests
1 parent 8d1c423 commit 620703b

File tree

2 files changed

+6
-47
lines changed

2 files changed

+6
-47
lines changed

.azure-pipelines/gpu-tests.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ jobs:
108108
bash pl_examples/run_examples.sh --trainer.gpus=1
109109
bash pl_examples/run_examples.sh --trainer.gpus=2 --trainer.accelerator=ddp
110110
bash pl_examples/run_examples.sh --trainer.gpus=2 --trainer.accelerator=ddp --trainer.precision=16
111+
bash pl_examples/run_examples.sh --trainer.gpus=2 --trainer.accelerator=dp
112+
bash pl_examples/run_examples.sh --trainer.gpus=2 --trainer.accelerator=dp --trainer.precision=16
111113
env:
112114
PL_USE_MOCKED_MNIST: "1"
113115
displayName: 'Testing: examples'

pl_examples/test_examples.py

Lines changed: 4 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -11,70 +11,27 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
15-
import importlib
16-
import platform
1714
from unittest import mock
1815

1916
import pytest
20-
import torch
2117

2218
from pl_examples import _DALI_AVAILABLE
19+
from tests.helpers.runif import RunIf
2320

2421
ARGS_DEFAULT = (
2522
"--trainer.default_root_dir %(tmpdir)s "
2623
"--trainer.max_epochs 1 "
2724
"--trainer.limit_train_batches 2 "
2825
"--trainer.limit_val_batches 2 "
26+
"--trainer.limit_test_batches 2 "
27+
"--trainer.limit_predict_batches 2 "
2928
"--data.batch_size 32 "
3029
)
3130
ARGS_GPU = ARGS_DEFAULT + "--trainer.gpus 1 "
32-
ARGS_DP = ARGS_DEFAULT + "--trainer.gpus 2 --trainer.accelerator dp "
33-
ARGS_AMP = "--trainer.precision 16 "
34-
35-
36-
@pytest.mark.parametrize(
37-
"import_cli",
38-
[
39-
"pl_examples.basic_examples.simple_image_classifier",
40-
"pl_examples.basic_examples.backbone_image_classifier",
41-
"pl_examples.basic_examples.autoencoder",
42-
],
43-
)
44-
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
45-
@pytest.mark.parametrize("cli_args", [ARGS_DP, ARGS_DP + ARGS_AMP])
46-
def test_examples_dp(tmpdir, import_cli, cli_args):
47-
48-
module = importlib.import_module(import_cli)
49-
# update the temp dir
50-
cli_args = cli_args % {"tmpdir": tmpdir}
51-
52-
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()):
53-
module.cli_main()
54-
55-
56-
@pytest.mark.parametrize(
57-
"import_cli",
58-
[
59-
"pl_examples.basic_examples.simple_image_classifier",
60-
"pl_examples.basic_examples.backbone_image_classifier",
61-
"pl_examples.basic_examples.autoencoder",
62-
],
63-
)
64-
@pytest.mark.parametrize("cli_args", [ARGS_DEFAULT])
65-
def test_examples_cpu(tmpdir, import_cli, cli_args):
66-
67-
module = importlib.import_module(import_cli)
68-
# update the temp dir
69-
cli_args = cli_args % {"tmpdir": tmpdir}
70-
71-
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()):
72-
module.cli_main()
7331

7432

7533
@pytest.mark.skipif(not _DALI_AVAILABLE, reason="Nvidia DALI required")
76-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
77-
@pytest.mark.skipif(platform.system() != "Linux", reason="Only applies to Linux platform.")
34+
@RunIf(min_gpus=1, skip_windows=True)
7835
@pytest.mark.parametrize("cli_args", [ARGS_GPU])
7936
def test_examples_mnist_dali(tmpdir, cli_args):
8037
from pl_examples.basic_examples.dali_image_classifier import cli_main

0 commit comments

Comments
 (0)