3838 --accelerator dp \
3939 """
4040
41- ARGS_DP_AMP = ARGS_DP + """
42- --precision 16 \
43- """
44-
45- ARGS_DDP = ARGS_DEFAULT + """
46- --gpus 2 \
47- --accelerator ddp \
48- --precision 16 \
49- """
50-
51- ARGS_DDP_AMP = ARGS_DEFAULT + """
41+ ARGS_AMP = """
5242--precision 16 \
5343 """
5444
6151 ]
6252)
6353@pytest .mark .skipif (torch .cuda .device_count () < 2 , reason = "test requires multi-GPU machine" )
64- @pytest .mark .parametrize ('cli_args' , [ARGS_DP , ARGS_DP_AMP ])
54+ @pytest .mark .parametrize ('cli_args' , [ARGS_DP , ARGS_DP + ARGS_AMP ])
6555def test_examples_dp (tmpdir , import_cli , cli_args ):
6656
6757 module = importlib .import_module (import_cli )
@@ -72,24 +62,6 @@ def test_examples_dp(tmpdir, import_cli, cli_args):
7262 module .cli_main ()
7363
7464
75- # ToDo: fix this failing example
76- # @pytest.mark.parametrize('import_cli', [
77- # 'pl_examples.basic_examples.simple_image_classifier',
78- # 'pl_examples.basic_examples.backbone_image_classifier',
79- # 'pl_examples.basic_examples.autoencoder',
80- # ])
81- # @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
82- # @pytest.mark.parametrize('cli_args', [ARGS_DDP, ARGS_DDP_AMP])
83- # def test_examples_ddp(tmpdir, import_cli, cli_args):
84- #
85- # module = importlib.import_module(import_cli)
86- # # update the temp dir
87- # cli_args = cli_args % {'tmpdir': tmpdir}
88- #
89- # with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()):
90- # module.cli_main()
91-
92-
9365@pytest .mark .parametrize (
9466 'import_cli' , [
9567 'pl_examples.basic_examples.simple_image_classifier' ,
0 commit comments