@@ -21,7 +21,7 @@ def test_amp_single_gpu(tmpdir, backend):
         max_epochs=1,
         gpus=1,
         distributed_backend=backend,
-        precision=16
+        precision=16,
     )
 
     model = EvalModelTemplate()
@@ -39,18 +39,15 @@ def test_amp_multi_gpu(tmpdir, backend):
     tutils.set_random_master_port()
 
     model = EvalModelTemplate()
-
-    trainer_options = dict(
+    # tutils.run_model_test(trainer_options, model)
+    trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
         # gpus=2,
         gpus='0, 1',  # test init with gpu string
         distributed_backend=backend,
         precision=16,
     )
-
-    # tutils.run_model_test(trainer_options, model)
-    trainer = Trainer(**trainer_options)
     result = trainer.fit(model)
     assert result
 
5653
@@ -66,17 +63,15 @@ def test_multi_gpu_wandb(tmpdir, backend):
     model = EvalModelTemplate()
     logger = WandbLogger(name='utest')
 
-    trainer_options = dict(
+    # tutils.run_model_test(trainer_options, model)
+    trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
         gpus=2,
         distributed_backend=backend,
         precision=16,
         logger=logger,
-
     )
-    # tutils.run_model_test(trainer_options, model)
-    trainer = Trainer(**trainer_options)
     result = trainer.fit(model)
     assert result
     trainer.test(model)
@@ -106,6 +101,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
         precision=16,
         checkpoint_callback=checkpoint,
         logger=logger,
+        default_root_dir=tmpdir,
     )
     trainer.is_slurm_managing_tasks = True
     result = trainer.fit(model)
0 commit comments