Skip to content

Commit 164c8dc

Browse files
committed
Update test
1 parent a94060a commit 164c8dc

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

tests/loggers/test_mlflow.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,18 +201,28 @@ def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir):
201201
logger.log_hyperparams(params)
202202

203203

204+
@mock.patch('pytorch_lightning.loggers.mlflow.time')
204205
@mock.patch('pytorch_lightning.loggers.mlflow.mlflow')
205206
@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient')
206-
def test_mlflow_logger_with_artifact_location(client, mlflow, tmpdir):
207+
def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir):
207208
"""
208-
Test that the logger raises warning with special characters not accepted by MLFlow.
209+
Test that the logger calls methods on the mlflow experiment correctly.
209210
"""
211+
time.return_value = 1
212+
210213
logger = MLFlowLogger('test', save_dir=tmpdir, artifact_location='my_artifact_location')
211214
logger._mlflow_client.get_experiment_by_name.return_value = None
212215

216+
params = {'test': 'test_param'}
217+
logger.log_hyperparams(params)
218+
219+
logger.experiment.log_param.assert_called_once_with(logger.run_id, 'test', 'test_param')
220+
213221
metrics = {'some_metric': 10}
214222
logger.log_metrics(metrics)
215223

224+
logger.experiment.log_metric.assert_called_once_with(logger.run_id, 'some_metric', 10, 1000, None)
225+
216226
logger._mlflow_client.create_experiment.assert_called_once_with(
217227
name='test',
218228
artifact_location='my_artifact_location',

0 commit comments

Comments
 (0)