2222from tests .base import EvalModelTemplate , BoringModel
2323
2424
25+ def get_warnings (recwarn ):
26+ warnings_text = '\n ' .join (str (w .message ) for w in recwarn .list )
27+ recwarn .clear ()
28+ return warnings_text
29+
30+
2531@mock .patch ('pytorch_lightning.loggers.wandb.wandb' )
26- def test_wandb_logger_init (wandb ):
32+ def test_wandb_logger_init (wandb , recwarn ):
2733 """Verify that basic functionality of wandb logger works.
2834 Wandb doesn't work well with pytest so we have to mock it out here."""
2935
@@ -34,6 +40,9 @@ def test_wandb_logger_init(wandb):
3440 wandb .init .assert_called_once ()
3541 wandb .init ().log .assert_called_once_with ({'acc' : 1.0 }, step = None )
3642
43+ # mock wandb step
44+ wandb .init ().step = 0
45+
3746 # test wandb.init not called if there is a W&B run
3847 wandb .init ().log .reset_mock ()
3948 wandb .init .reset_mock ()
@@ -49,15 +58,28 @@ def test_wandb_logger_init(wandb):
4958 logger .log_metrics ({'acc' : 1.0 }, step = 3 )
5059 wandb .init ().log .assert_called_with ({'acc' : 1.0 }, step = 6 )
5160
61+ # log hyper parameters
5262 logger .log_hyperparams ({'test' : None , 'nested' : {'a' : 1 }, 'b' : [2 , 3 , 4 ]})
5363 wandb .init ().config .update .assert_called_once_with (
5464 {'test' : 'None' , 'nested/a' : 1 , 'b' : [2 , 3 , 4 ]},
5565 allow_val_change = True ,
5666 )
5767
68+ # watch a model
5869 logger .watch ('model' , 'log' , 10 )
5970 wandb .init ().watch .assert_called_once_with ('model' , log = 'log' , log_freq = 10 )
6071
72+ # verify warning for logging at a previous step
73+ assert 'Trying to log at a previous step' not in get_warnings (recwarn )
74+ # current step from wandb should be 6 (last logged step)
75+ logger .experiment .step = 6
76+ # logging at step 2 should raise a warning (step_offset is still 3)
77+ logger .log_metrics ({'acc' : 1.0 }, step = 2 )
78+ assert 'Trying to log at a previous step' in get_warnings (recwarn )
79+ # logging again at step 2 should not display again the same warning
80+ logger .log_metrics ({'acc' : 1.0 }, step = 2 )
81+ assert 'Trying to log at a previous step' not in get_warnings (recwarn )
82+
6183 assert logger .name == wandb .init ().project_name ()
6284 assert logger .version == wandb .init ().id
6385
@@ -71,6 +93,7 @@ def test_wandb_pickle(wandb, tmpdir):
7193 class Experiment :
7294 """ """
7395 id = 'the_id'
96+ step = 0
7497
7598 def project_name (self ):
7699 return 'the_project_name'
@@ -108,8 +131,11 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
108131 assert logger .name is None
109132
110133 # mock return values of experiment
134+ wandb .run = None
135+ wandb .init ().step = 0
111136 logger .experiment .id = '1'
112137 logger .experiment .project_name .return_value = 'project'
138+ logger .experiment .step = 0
113139
114140 for _ in range (2 ):
115141 _ = logger .experiment
0 commit comments