@@ -201,18 +201,28 @@ def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir):
201201 logger .log_hyperparams (params )
202202
203203
204+ @mock .patch ('pytorch_lightning.loggers.mlflow.time' )
204205@mock .patch ('pytorch_lightning.loggers.mlflow.mlflow' )
205206@mock .patch ('pytorch_lightning.loggers.mlflow.MlflowClient' )
206- def test_mlflow_logger_with_artifact_location (client , mlflow , tmpdir ):
207+ def test_mlflow_logger_experiment_calls (client , mlflow , time , tmpdir ):
207208 """
208- Test that the logger raises warning with special characters not accepted by MLFlow .
209+ Test that the logger calls methods on the mlflow experiment correctly .
209210 """
211+ time .return_value = 1
212+
210213 logger = MLFlowLogger ('test' , save_dir = tmpdir , artifact_location = 'my_artifact_location' )
211214 logger ._mlflow_client .get_experiment_by_name .return_value = None
212215
216+ params = {'test' : 'test_param' }
217+ logger .log_hyperparams (params )
218+
219+ logger .experiment .log_param .assert_called_once_with (logger .run_id , 'test' , 'test_param' )
220+
213221 metrics = {'some_metric' : 10 }
214222 logger .log_metrics (metrics )
215223
224+ logger .experiment .log_metric .assert_called_once_with (logger .run_id , 'some_metric' , 10 , 1000 , None )
225+
216226 logger ._mlflow_client .create_experiment .assert_called_once_with (
217227 name = 'test' ,
218228 artifact_location = 'my_artifact_location' ,
0 commit comments