|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | import os |
15 | | -from unittest.mock import patch |
| 15 | +from unittest.mock import patch, DEFAULT |
16 | 16 |
|
17 | 17 | import pytest |
18 | 18 |
|
@@ -99,6 +99,37 @@ def test_comet_logger_experiment_name(comet): |
99 | 99 | comet_experiment().set_name.assert_called_once_with(experiment_name) |
100 | 100 |
|
101 | 101 |
|
| 102 | +@patch('pytorch_lightning.loggers.comet.comet_ml') |
| 103 | +def test_comet_logger_manual_experiment_key(comet): |
| 104 | + """Test that Comet Logger respects manually set COMET_EXPERIMENT_KEY.""" |
| 105 | + |
| 106 | + api_key = "key" |
| 107 | + experiment_key = "96346da91469407a85641afe5766b554" |
| 108 | + |
| 109 | + instantation_environ = {} |
| 110 | + |
| 111 | + def save_os_environ(*args, **kwargs): |
| 112 | + nonlocal instantation_environ |
| 113 | + instantation_environ = os.environ.copy() |
| 114 | + |
| 115 | + return DEFAULT |
| 116 | + |
| 117 | + # Test api_key given |
| 118 | + with patch.dict(os.environ, {"COMET_EXPERIMENT_KEY": experiment_key}): |
| 119 | + with patch('pytorch_lightning.loggers.comet.CometExperiment', side_effect=save_os_environ) as comet_experiment: |
| 120 | + logger = CometLogger(api_key=api_key) |
| 121 | + |
| 122 | + assert logger.version == experiment_key |
| 123 | + |
| 124 | + assert logger._experiment is None |
| 125 | + |
| 126 | + _ = logger.experiment |
| 127 | + |
| 128 | + comet_experiment.assert_called_once_with(api_key=api_key, project_name=None) |
| 129 | + |
| 130 | + assert instantation_environ["COMET_EXPERIMENT_KEY"] == experiment_key |
| 131 | + |
| 132 | + |
102 | 133 | @patch('pytorch_lightning.loggers.comet.CometOfflineExperiment') |
103 | 134 | @patch('pytorch_lightning.loggers.comet.comet_ml') |
104 | 135 | def test_comet_logger_dirs_creation(comet, comet_experiment, tmpdir, monkeypatch): |
|
0 commit comments