5050 SaveConfigCallback ,
5151)
5252from pytorch_lightning .utilities .exceptions import MisconfigurationException
53- from pytorch_lightning .utilities .imports import _TORCHVISION_AVAILABLE
53+ from pytorch_lightning .utilities .imports import _TORCH_GREATER_EQUAL_1_8 , _TORCHVISION_AVAILABLE
5454from tests .helpers import BoringDataModule , BoringModel
5555from tests .helpers .runif import RunIf
5656from tests .helpers .utils import no_warning_call
@@ -576,21 +576,17 @@ def on_fit_start(self):
576576 raise MisconfigurationException ("Error on fit start" )
577577
578578
579+ @RunIf (skip_windows = True )
579580@pytest .mark .parametrize ("logger" , (False , True ))
580- @pytest .mark .parametrize (
581- "trainer_kwargs" ,
582- (
583- # dict(strategy="ddp_spawn")
584- # dict(strategy="ddp")
585- # the previous accl_conn will choose singleDeviceStrategy for both strategy=ddp/ddp_spawn
586- # TODO revisit this test as it never worked with DDP or DDPSpawn
587- dict (strategy = "single_device" ),
588- pytest .param ({"tpu_cores" : 1 }, marks = RunIf (tpu = True )),
589- ),
590- )
591- def test_cli_distributed_save_config_callback (tmpdir , logger , trainer_kwargs ):
581+ @pytest .mark .parametrize ("strategy" , ("ddp_spawn" , "ddp" ))
582+ def test_cli_distributed_save_config_callback (tmpdir , logger , strategy ):
583+ if _TORCH_GREATER_EQUAL_1_8 :
584+ from torch .multiprocessing import ProcessRaisedException
585+ else :
586+ ProcessRaisedException = Exception
587+
592588 with mock .patch ("sys.argv" , ["any.py" , "fit" ]), pytest .raises (
593- MisconfigurationException , match = r"Error on fit start"
589+ ( MisconfigurationException , ProcessRaisedException ) , match = r"Error on fit start"
594590 ):
595591 LightningCLI (
596592 EarlyExitTestModel ,
@@ -599,7 +595,9 @@ def test_cli_distributed_save_config_callback(tmpdir, logger, trainer_kwargs):
599595 "logger" : logger ,
600596 "max_steps" : 1 ,
601597 "max_epochs" : 1 ,
602- ** trainer_kwargs ,
598+ "strategy" : strategy ,
599+ "accelerator" : "auto" ,
600+ "devices" : 1 ,
603601 },
604602 )
605603 if logger :
0 commit comments