@@ -1,4 +1,3 @@
-import atexit
 import inspect
 import os
 import pickle
@@ -20,6 +19,7 @@
 from pytorch_lightning.loggers.base import DummyExperiment
 from tests.base import EvalModelTemplate
 from tests.loggers.test_comet import _patch_comet_atexit
+from tests.loggers.test_mlflow import mock_mlflow_run_creation


 def _get_logger_args(logger_class, save_dir):
@@ -34,27 +34,31 @@ def _get_logger_args(logger_class, save_dir):


 def test_loggers_fit_test_all(tmpdir, monkeypatch):
-    _patch_comet_atexit(monkeypatch)
37+ """ Verify that basic functionality of all loggers. """
+
+    _test_loggers_fit_test(tmpdir, TensorBoardLogger)
+
     with mock.patch('pytorch_lightning.loggers.comet.comet_ml'), \
          mock.patch('pytorch_lightning.loggers.comet.CometOfflineExperiment'):
+        _patch_comet_atexit(monkeypatch)
         _test_loggers_fit_test(tmpdir, CometLogger)

-    _test_loggers_fit_test(tmpdir, MLFlowLogger)
+    with mock.patch('pytorch_lightning.loggers.mlflow.mlflow'), \
+         mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient'):
+        _test_loggers_fit_test(tmpdir, MLFlowLogger)

     with mock.patch('pytorch_lightning.loggers.neptune.neptune'):
         _test_loggers_fit_test(tmpdir, NeptuneLogger)

-    _test_loggers_fit_test(tmpdir, TensorBoardLogger)
-    _test_loggers_fit_test(tmpdir, TestTubeLogger)
+    with mock.patch('pytorch_lightning.loggers.test_tube.Experiment'):
+        _test_loggers_fit_test(tmpdir, TestTubeLogger)

     with mock.patch('pytorch_lightning.loggers.wandb.wandb'):
         _test_loggers_fit_test(tmpdir, WandbLogger)


 def _test_loggers_fit_test(tmpdir, logger_class):
-    """Verify that basic functionality of all loggers."""
     os.environ['PL_DEV_DEBUG'] = '0'
-
     model = EvalModelTemplate()

     class StoreHistoryLogger(logger_class):
@@ -78,6 +82,13 @@ def log_metrics(self, metrics, step):
         logger.experiment.id = 'foo'
         logger.experiment.project_name = 'bar'

+    if logger_class == TestTubeLogger:
+        logger.experiment.version = 'foo'
+        logger.experiment.name = 'bar'
+
+    if logger_class == MLFlowLogger:
+        logger = mock_mlflow_run_creation(logger, experiment_id="foo", run_id="bar")
+
     trainer = Trainer(
         max_epochs=1,
         logger=logger,
@@ -109,21 +120,27 @@ def log_metrics(self, metrics, step):


 def test_loggers_save_dir_and_weights_save_path_all(tmpdir, monkeypatch):
-    _patch_comet_atexit(monkeypatch)
+    """ Test the combinations of save_dir, weights_save_path and default_root_dir. """
+
+    _test_loggers_save_dir_and_weights_save_path(tmpdir, TensorBoardLogger)
+
     with mock.patch('pytorch_lightning.loggers.comet.comet_ml'), \
          mock.patch('pytorch_lightning.loggers.comet.CometOfflineExperiment'):
+        _patch_comet_atexit(monkeypatch)
         _test_loggers_save_dir_and_weights_save_path(tmpdir, CometLogger)

-    _test_loggers_save_dir_and_weights_save_path(tmpdir, TensorBoardLogger)
-    _test_loggers_save_dir_and_weights_save_path(tmpdir, MLFlowLogger)
-    _test_loggers_save_dir_and_weights_save_path(tmpdir, TestTubeLogger)
+    with mock.patch('pytorch_lightning.loggers.mlflow.mlflow'), \
+         mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient'):
+        _test_loggers_save_dir_and_weights_save_path(tmpdir, MLFlowLogger)
+
+    with mock.patch('pytorch_lightning.loggers.test_tube.Experiment'):
+        _test_loggers_save_dir_and_weights_save_path(tmpdir, TestTubeLogger)

     with mock.patch('pytorch_lightning.loggers.wandb.wandb'):
         _test_loggers_save_dir_and_weights_save_path(tmpdir, WandbLogger)


 def _test_loggers_save_dir_and_weights_save_path(tmpdir, logger_class):
-    """ Test the combinations of save_dir, weights_save_path and default_root_dir. """

     class TestLogger(logger_class):
         # for this test it does not matter what these attributes are
@@ -255,18 +272,24 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_
         assert pl_module.logger.experiment.something(foo="bar") is None


-@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
 @pytest.mark.parametrize("logger_class", [
-    TensorBoardLogger,
+    CometLogger,
     MLFlowLogger,
-    # NeptuneLogger,  # TODO: fix: https://github.com/PyTorchLightning/pytorch-lightning/pull/3256
+    NeptuneLogger,
+    TensorBoardLogger,
     TestTubeLogger,
 ])
-@mock.patch('pytorch_lightning.loggers.neptune.neptune')
-def test_logger_created_on_rank_zero_only(neptune, tmpdir, monkeypatch, logger_class):
+@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
+def test_logger_created_on_rank_zero_only(tmpdir, monkeypatch, logger_class):
     """ Test that loggers get replaced by dummy loggers on global rank > 0"""
     _patch_comet_atexit(monkeypatch)
+    try:
+        _test_logger_created_on_rank_zero_only(tmpdir, logger_class)
+    except (ImportError, ModuleNotFoundError):
+        pytest.xfail(f"multi-process test requires {logger_class.__name__} dependencies to be installed.")
+

+def _test_logger_created_on_rank_zero_only(tmpdir, logger_class):
     logger_args = _get_logger_args(logger_class, tmpdir)
     logger = logger_class(**logger_args)
     model = EvalModelTemplate()
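
A quick aside on the pattern the hunks above rely on: each third-party backend (comet_ml, mlflow/MlflowClient, neptune, test_tube's Experiment, wandb) is replaced with unittest.mock objects, so constructing a logger and driving it through Trainer never reaches a real tracking server and needs no credentials. Below is a minimal standalone sketch of that idiom, not taken from the commit; run_mlflow_case and test_fn are hypothetical names, and only an importable pytorch_lightning is assumed.

from unittest import mock


def run_mlflow_case(tmpdir, logger_class, test_fn):
    # test_fn stands in for a helper such as _test_loggers_fit_test;
    # the patch targets mirror the ones used in the diff above.
    with mock.patch('pytorch_lightning.loggers.mlflow.mlflow'), \
         mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient'):
        test_fn(tmpdir, logger_class)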
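
The rank-zero test at the end leans on DummyExperiment, imported at the top of the file: on global rank > 0, Lightning swaps the real experiment object for a dummy whose attribute lookups resolve to no-ops, which is what the assert pl_module.logger.experiment.something(foo="bar") is None check above exercises. A small illustration of that behaviour, assuming only the import path shown in the diff:

from pytorch_lightning.loggers.base import DummyExperiment

dummy = DummyExperiment()
# Any attribute access yields a no-op callable, so arbitrary calls return None.
assert dummy.something(foo="bar") is None
assert dummy.log_metrics({"loss": 0.1}, step=0) is None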