diff --git a/CHANGELOG.md b/CHANGELOG.md index 5229fd565ab71..c606fdcde05e5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -64,6 +64,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for `precision=64`, enabling training with double precision ([#6595](https://github.com/PyTorchLightning/pytorch-lightning/pull/6595)) +- Added `artifact_location` argument to `MLFlowLogger` which will be passed to the `MlflowClient.create_experiment` call ([#6677](https://github.com/PyTorchLightning/pytorch-lightning/pull/6677)) + + ### Changed - Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259)) diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index 8754842bc4a22..88bed79904cf3 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -78,6 +78,8 @@ def any_lightning_module_function_or_hook(self): Defaults to `./mlflow` if `tracking_uri` is not provided. Has no effect if `tracking_uri` is provided. prefix: A string to put at the beginning of metric keys. + artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate + default. Raises: ImportError: @@ -93,6 +95,7 @@ def __init__( tags: Optional[Dict[str, Any]] = None, save_dir: Optional[str] = './mlruns', prefix: str = '', + artifact_location: Optional[str] = None, ): if mlflow is None: raise ImportError( @@ -109,6 +112,8 @@ def __init__( self._run_id = None self.tags = tags self._prefix = prefix + self._artifact_location = artifact_location + self._mlflow_client = MlflowClient(tracking_uri) @property @@ -129,7 +134,10 @@ def experiment(self) -> MlflowClient: self._experiment_id = expt.experiment_id else: log.warning(f'Experiment with name {self._experiment_name} not found. Creating it.') - self._experiment_id = self._mlflow_client.create_experiment(name=self._experiment_name) + self._experiment_id = self._mlflow_client.create_experiment( + name=self._experiment_name, + artifact_location=self._artifact_location, + ) if self._run_id is None: run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=self.tags) diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index d2673f48b871b..35bad766798b1 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -199,3 +199,31 @@ def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir): with pytest.warns(RuntimeWarning, match=f'Discard {key}={value}'): logger.log_hyperparams(params) + + +@mock.patch('pytorch_lightning.loggers.mlflow.time') +@mock.patch('pytorch_lightning.loggers.mlflow.mlflow') +@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient') +def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir): + """ + Test that the logger calls methods on the mlflow experiment correctly. + """ + time.return_value = 1 + + logger = MLFlowLogger('test', save_dir=tmpdir, artifact_location='my_artifact_location') + logger._mlflow_client.get_experiment_by_name.return_value = None + + params = {'test': 'test_param'} + logger.log_hyperparams(params) + + logger.experiment.log_param.assert_called_once_with(logger.run_id, 'test', 'test_param') + + metrics = {'some_metric': 10} + logger.log_metrics(metrics) + + logger.experiment.log_metric.assert_called_once_with(logger.run_id, 'some_metric', 10, 1000, None) + + logger._mlflow_client.create_experiment.assert_called_once_with( + name='test', + artifact_location='my_artifact_location', + )